Skip to content

Commit

Permalink
Release/0.6.0
Browse files Browse the repository at this point in the history
  • Loading branch information
benhorsburgh authored Apr 27, 2020
1 parent e03ecb3 commit 6dc52cc
Show file tree
Hide file tree
Showing 8 changed files with 107 additions and 140 deletions.
5 changes: 5 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
# Release 0.6.0

* support for newer versions of scikit-learn
* classification report now returns dict in line with scikit-learn

# Release 0.5.0

* Plotting now backed by pygraphviz. This allows:
Expand Down
2 changes: 1 addition & 1 deletion causalnex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@
causalnex toolkit for causal reasoning (Bayesian Networks / Inference)
"""

__version__ = "0.5.0"
__version__ = "0.6.0"

__all__ = ["structure", "discretiser", "evaluation", "inference", "network", "plots"]
10 changes: 4 additions & 6 deletions causalnex/evaluation/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
# limitations under the License.
"""Evaluation metrics for causal models."""

from typing import List, Tuple
from typing import Dict, List, Tuple

import pandas as pd
from sklearn import metrics
Expand Down Expand Up @@ -112,9 +112,7 @@ def roc_auc(
return roc, auc


def classification_report(
bn: BayesianNetwork, data: pd.DataFrame, node: str
) -> pd.DataFrame:
def classification_report(bn: BayesianNetwork, data: pd.DataFrame, node: str) -> Dict:
"""
Build a report showing the main classification metrics.
Expand Down Expand Up @@ -160,7 +158,7 @@ def classification_report(
>>> 'traffic': ['light', 'heavy', 'heavy', 'light']
>>> })
>>> from causalnex.evaluation import classification_report
>>> classification_report(bn, test_data, "traffic").to_dict()
>>> classification_report(bn, test_data, "traffic")
{'precision': {
'macro avg': 0.8333333333333333, 'micro avg': 0.75,
'traffic_heavy': 0.6666666666666666,
Expand Down Expand Up @@ -204,4 +202,4 @@ def classification_report(
output_dict=True,
)

return pd.DataFrame.from_dict(report, orient="index")
return report
5 changes: 5 additions & 0 deletions causalnex/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,11 @@ def _do(self, observation: str, state: Dict[Hashable, float]) -> None:
if sum(state.values()) != 1.0:
raise ValueError("The cpd for the provided observation must sum to 1")

if max(state.values()) > 1.0 or min(state.values()) < 0:
raise ValueError(
"The cpd for the provided observation must be between 0 and 1"
)

if not set(state.keys()) == set(self._cpds_original[observation]):
raise ValueError(
"The cpd states do not match expected states: expected {expected}, found {found}".format(
Expand Down
Loading

0 comments on commit 6dc52cc

Please sign in to comment.