From f053bee25108a3e62c5ade6e0e7f5b142d1652b6 Mon Sep 17 00:00:00 2001 From: Ben Horsburgh Date: Wed, 25 Mar 2020 09:14:30 +0000 Subject: [PATCH 1/3] Merge master back to develop --- RELEASE.md | 10 ++++++++-- causalnex/__init__.py | 2 +- docs/source/05_resources/05_faq.md | 2 +- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index eb14d22..43e8353 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,5 +1,11 @@ -# Upcoming release - +# Release 0.5.0 + +* Plotting now backed by pygraphviz. This allows: + * More powerful layout manager + * Cleaner fully customisable theme + * Out-the-box styling for different node and edge types +* Can now get subgraphs from StructureModel containing a specific node +* Bugfix to resolve issue when fitting CPDs with some missing states in data * Minor documentation fixes and improvements # Release 0.4.3: diff --git a/causalnex/__init__.py b/causalnex/__init__.py index 9b33b2c..85b8230 100644 --- a/causalnex/__init__.py +++ b/causalnex/__init__.py @@ -30,6 +30,6 @@ causalnex toolkit for causal reasoning (Bayesian Networks / Inference) """ -__version__ = "0.4.3" +__version__ = "0.5.0" __all__ = ["structure", "discretiser", "evaluation", "inference", "network", "plots"] diff --git a/docs/source/05_resources/05_faq.md b/docs/source/05_resources/05_faq.md index 3e1f2f2..d097a9c 100644 --- a/docs/source/05_resources/05_faq.md +++ b/docs/source/05_resources/05_faq.md @@ -1,6 +1,6 @@ # Frequently asked questions -> *Note:* This documentation is based on `CausalNex 0.4.0`, if you spot anything that is incorrect then please create an [issue](https://github.com/quantumblacklabs/causalnex/issues) or pull request. +> *Note:* This documentation is based on `CausalNex 0.5.0`, if you spot anything that is incorrect then please create an [issue](https://github.com/quantumblacklabs/causalnex/issues) or pull request. ## What is CausalNex? From e0760ad46b71ec2526ef19d2fe940910704bdec7 Mon Sep 17 00:00:00 2001 From: Ben Horsburgh Date: Sun, 26 Apr 2020 17:32:41 +0100 Subject: [PATCH 2/3] Unpin Scikit Learn version --- RELEASE.md | 5 + causalnex/evaluation/evaluation.py | 10 +- docs/source/03_tutorial/03_tutorial.ipynb | 193 ++++++++-------------- requirements.txt | 2 +- tests/test_metrics.py | 14 +- 5 files changed, 85 insertions(+), 139 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 43e8353..2fa0806 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -1,3 +1,8 @@ +# Release future + +* unpinned scikit version +* classification report now returns dict in line with scikit-learn + # Release 0.5.0 * Plotting now backed by pygraphviz. This allows: diff --git a/causalnex/evaluation/evaluation.py b/causalnex/evaluation/evaluation.py index 2636c0b..172a7ad 100644 --- a/causalnex/evaluation/evaluation.py +++ b/causalnex/evaluation/evaluation.py @@ -27,7 +27,7 @@ # limitations under the License. """Evaluation metrics for causal models.""" -from typing import List, Tuple +from typing import Dict, List, Tuple import pandas as pd from sklearn import metrics @@ -112,9 +112,7 @@ def roc_auc( return roc, auc -def classification_report( - bn: BayesianNetwork, data: pd.DataFrame, node: str -) -> pd.DataFrame: +def classification_report(bn: BayesianNetwork, data: pd.DataFrame, node: str) -> Dict: """ Build a report showing the main classification metrics. @@ -160,7 +158,7 @@ def classification_report( >>> 'traffic': ['light', 'heavy', 'heavy', 'light'] >>> }) >>> from causalnex.evaluation import classification_report - >>> classification_report(bn, test_data, "traffic").to_dict() + >>> classification_report(bn, test_data, "traffic") {'precision': { 'macro avg': 0.8333333333333333, 'micro avg': 0.75, 'traffic_heavy': 0.6666666666666666, @@ -204,4 +202,4 @@ def classification_report( output_dict=True, ) - return pd.DataFrame.from_dict(report, orient="index") + return report diff --git a/docs/source/03_tutorial/03_tutorial.ipynb b/docs/source/03_tutorial/03_tutorial.ipynb index ca698ac..fbd20c0 100644 --- a/docs/source/03_tutorial/03_tutorial.ipynb +++ b/docs/source/03_tutorial/03_tutorial.ipynb @@ -55,6 +55,10 @@ "metadata": {}, "outputs": [], "source": [ + "# silence warnings\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", "from causalnex.structure import StructureModel\n", "sm = StructureModel()" ] @@ -126,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": { "scrolled": false }, @@ -138,7 +142,7 @@ "" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -184,7 +188,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -375,7 +379,7 @@ "[5 rows x 33 columns]" ] }, - "execution_count": 6, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -395,7 +399,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 6, "metadata": {}, "outputs": [ { @@ -593,7 +597,7 @@ "[5 rows x 26 columns]" ] }, - "execution_count": 7, + "execution_count": 6, "metadata": {}, "output_type": "execute_result" } @@ -613,7 +617,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -634,7 +638,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 8, "metadata": {}, "outputs": [ { @@ -832,7 +836,7 @@ "[5 rows x 26 columns]" ] }, - "execution_count": 9, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -856,7 +860,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 9, "metadata": {}, "outputs": [], "source": [ @@ -873,7 +877,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -883,7 +887,7 @@ "" ] }, - "execution_count": 12, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" } @@ -908,7 +912,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 11, "metadata": { "scrolled": false }, @@ -920,7 +924,7 @@ "" ] }, - "execution_count": 14, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } @@ -956,7 +960,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 12, "metadata": {}, "outputs": [], "source": [ @@ -965,7 +969,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -975,7 +979,7 @@ "" ] }, - "execution_count": 16, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -1007,7 +1011,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 14, "metadata": { "scrolled": true }, @@ -1027,7 +1031,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 15, "metadata": { "scrolled": false }, @@ -1039,7 +1043,7 @@ "" ] }, - "execution_count": 19, + "execution_count": 15, "metadata": {}, "output_type": "execute_result" } @@ -1064,7 +1068,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -1074,7 +1078,7 @@ "" ] }, - "execution_count": 21, + "execution_count": 16, "metadata": {}, "output_type": "execute_result" } @@ -1101,7 +1105,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -1158,7 +1162,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -1182,7 +1186,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ @@ -1206,7 +1210,7 @@ }, { "cell_type": "code", - "execution_count": 25, + "execution_count": 20, "metadata": {}, "outputs": [], "source": [ @@ -1241,7 +1245,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ @@ -1273,7 +1277,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ @@ -1299,7 +1303,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": 23, "metadata": {}, "outputs": [], "source": [ @@ -1322,10 +1326,9 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": 24, "metadata": {}, - "outputs": [ - ], + "outputs": [], "source": [ "bn = bn.fit_cpds(train, method=\"BayesianEstimator\", bayes_prior=\"K2\")" ] @@ -1339,7 +1342,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": 25, "metadata": {}, "outputs": [ { @@ -1504,7 +1507,7 @@ "Pass 0.967033 0.84984 0.888889 0.744186 " ] }, - "execution_count": 30, + "execution_count": 25, "metadata": {}, "output_type": "execute_result" } @@ -1536,7 +1539,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": 26, "metadata": {}, "outputs": [ { @@ -1570,7 +1573,7 @@ "Name: 18, dtype: object" ] }, - "execution_count": 31, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -1588,7 +1591,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": 27, "metadata": {}, "outputs": [], "source": [ @@ -1597,7 +1600,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": 28, "metadata": { "scrolled": true }, @@ -1623,7 +1626,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": 29, "metadata": {}, "outputs": [ { @@ -1675,86 +1678,32 @@ }, { "cell_type": "code", - "execution_count": 35, + "execution_count": 30, "metadata": {}, "outputs": [ { "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
precisionrecallf1-scoresupport
G1_Fail0.7777780.5833330.66666712
G1_Pass0.9107140.9622640.93578053
macro avg0.8442460.7727990.80122365
micro avg0.8923080.8923080.89230865
weighted avg0.8861720.8923080.88609765
\n", - "
" - ], "text/plain": [ - " precision recall f1-score support\n", - "G1_Fail 0.777778 0.583333 0.666667 12\n", - "G1_Pass 0.910714 0.962264 0.935780 53\n", - "macro avg 0.844246 0.772799 0.801223 65\n", - "micro avg 0.892308 0.892308 0.892308 65\n", - "weighted avg 0.886172 0.892308 0.886097 65" + "{'G1_Fail': {'precision': 0.7777777777777778,\n", + " 'recall': 0.5833333333333334,\n", + " 'f1-score': 0.6666666666666666,\n", + " 'support': 12},\n", + " 'G1_Pass': {'precision': 0.9107142857142857,\n", + " 'recall': 0.9622641509433962,\n", + " 'f1-score': 0.9357798165137615,\n", + " 'support': 53},\n", + " 'accuracy': 0.8923076923076924,\n", + " 'macro avg': {'precision': 0.8442460317460317,\n", + " 'recall': 0.7727987421383649,\n", + " 'f1-score': 0.8012232415902141,\n", + " 'support': 65},\n", + " 'weighted avg': {'precision': 0.8861721611721611,\n", + " 'recall': 0.8923076923076924,\n", + " 'f1-score': 0.8860973888496825,\n", + " 'support': 65}}" ] }, - "execution_count": 35, + "execution_count": 30, "metadata": {}, "output_type": "execute_result" } @@ -1789,7 +1738,7 @@ }, { "cell_type": "code", - "execution_count": 36, + "execution_count": 31, "metadata": {}, "outputs": [ { @@ -1843,7 +1792,7 @@ }, { "cell_type": "code", - "execution_count": 37, + "execution_count": 32, "metadata": {}, "outputs": [ { @@ -1881,7 +1830,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": 33, "metadata": {}, "outputs": [ { @@ -1890,7 +1839,7 @@ "{'Fail': 0.25260687281677224, 'Pass': 0.7473931271832277}" ] }, - "execution_count": 38, + "execution_count": 33, "metadata": {}, "output_type": "execute_result" } @@ -1912,7 +1861,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": 34, "metadata": {}, "outputs": [ { @@ -1921,7 +1870,7 @@ "[('Fail', 157), ('Pass', 492)]" ] }, - "execution_count": 39, + "execution_count": 34, "metadata": {}, "output_type": "execute_result" } @@ -1957,7 +1906,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": 35, "metadata": {}, "outputs": [ { @@ -2013,7 +1962,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": 36, "metadata": {}, "outputs": [ { @@ -2049,7 +1998,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": 37, "metadata": {}, "outputs": [], "source": [ @@ -2072,7 +2021,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": 38, "metadata": {}, "outputs": [ { diff --git a/requirements.txt b/requirements.txt index 1f4ae9b..fa8bdbd 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,6 @@ numpy>=1.14.2, <2.0 pandas==0.24.0 pgmpy==0.1.6 prettytable==0.7.2 -scikit-learn==0.20.2 +scikit-learn>=0.20.2, <0.23.0 scipy>=1.2.0, <1.3 wrapt>=1.11.0, <1.12 diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 9fee6b3..e4b9429 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -362,13 +362,6 @@ def test_auc_for_nonnumeric_features(self): class TestClassificationReport: """Test behaviour of classification_report""" - def test_contains_expected_columns(self, test_data_c_discrete, bn): - """Check that the report contains all of the required data""" - - report = classification_report(bn, test_data_c_discrete, "c") - - assert set(report.columns) == {"recall", "precision", "support", "f1-score"} - def test_contains_all_class_data( self, test_data_c_discrete, bn, test_data_c_likelihood ): @@ -376,7 +369,7 @@ def test_contains_all_class_data( report = classification_report(bn, test_data_c_discrete, "c") - assert (label in report.index for label in test_data_c_likelihood.columns) + assert (label in report for label in test_data_c_likelihood.columns) def test_report_ignores_unrequired_columns_in_data( self, train_data_idx, train_data_discrete, test_data_c_discrete @@ -396,5 +389,6 @@ def test_report_on_node_with_no_parents_based_on_modal_state( """Classification Report on a node with no parents should reflect that predictions are on modal state""" report = classification_report(bn, train_data_discrete, "d") - assert report.loc["d_False", "recall"] == 1 # always predicts most likely class - assert report.loc["d_True", "recall"] == 0 + + assert report["d_False"]["recall"] == 1 # always predicts most likely class + assert report["d_True"]["recall"] == 0 From 207a9e25853d786ed28762433bcc1c87a983a88b Mon Sep 17 00:00:00 2001 From: Tim Herfurth Date: Sun, 26 Apr 2020 20:31:43 +0200 Subject: [PATCH 3/3] restrict do intervention to CPD between 0 and 1 (#32) --- causalnex/inference/inference.py | 5 +++++ tests/test_inference.py | 16 ++++++++++++++++ 2 files changed, 21 insertions(+) diff --git a/causalnex/inference/inference.py b/causalnex/inference/inference.py index 352dda4..d653502 100644 --- a/causalnex/inference/inference.py +++ b/causalnex/inference/inference.py @@ -163,6 +163,11 @@ def _do(self, observation: str, state: Dict[Hashable, float]) -> None: if sum(state.values()) != 1.0: raise ValueError("The cpd for the provided observation must sum to 1") + if max(state.values()) > 1.0 or min(state.values()) < 0: + raise ValueError( + "The cpd for the provided observation must be between 0 and 1" + ) + if not set(state.keys()) == set(self._cpds_original[observation]): raise ValueError( "The cpd states do not match expected states: expected {expected}, found {found}".format( diff --git a/tests/test_inference.py b/tests/test_inference.py index 99bd049..731d1de 100644 --- a/tests/test_inference.py +++ b/tests/test_inference.py @@ -175,6 +175,22 @@ def test_do_expects_all_state_probabilities_sum_to_one( ): ie.do_intervention("d", {0: 0.7, 1: 0.4}) + def test_do_expects_all_state_probabilities_within_0_and_1( + self, train_model, train_data_idx + ): + """Do should accept only state probabilities where the full distribution is provided""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, + match="The cpd for the provided observation must be between 0 and 1", + ): + ie.do_intervention("d", {0: -1.0, 1: 2.0}) + def test_do_expects_all_states_have_a_probability( self, train_model, train_data_idx ):