diff --git a/.circleci/config.yml b/.circleci/config.yml new file mode 100644 index 0000000..c6ac140 --- /dev/null +++ b/.circleci/config.yml @@ -0,0 +1,134 @@ +version: 2 + +images: + python: &python + - image: circleci/buildpack-deps:stretch-browsers + +############################################################################### +utils: + prepare_container: &prepare_container + name: Prepare build container + command: | + sudo apt-get update + sudo apt-get install curl pandoc + sudo apt-get clean + curl https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh > miniconda.sh + bash miniconda.sh -b -p /home/circleci/miniconda + sudo rm -rf ~/.pyenv/ /opt/circleci/.pyenv/ + source /home/circleci/miniconda/etc/profile.d/conda.sh + conda create --name=causalnex_env python=${PYTHON_VERSION} -y + conda activate causalnex_env + conda install -y virtualenv + pip install -U pip setuptools wheel + activate_conda: &activate_conda + name: Activate conda environment + command: | + echo ". 
/home/circleci/miniconda/etc/profile.d/conda.sh" >> $BASH_ENV + echo "conda deactivate; conda activate causalnex_env" >> $BASH_ENV + + setup_requirements: &setup_requirements + name: Install PIP dependencies + command: | + echo "Python version: $(python --version 2>&1)" + pip install -r requirements.txt -U + pip install -r test_requirements.txt -U + conda install -y virtualenv + setup_pre_commit: &setup_pre_commit + name: Install pre-commit hooks + command: | + pre-commit install --install-hooks + pre-commit install --hook-type pre-push + linters: &linters + name: Run pylint and flake8 + command: make lint + + unit_tests: &unit_tests + name: Run tests + command: make test + + build_docs: &build_docs + # NOTE: doesn't work on python 3.5 + name: Build documentation + command: make build-docs + + install_package: &install_package + name: Install the package + command: make install + + unit_test_steps: &unit_test_steps + steps: + - checkout + - run: *prepare_container + - run: *activate_conda + - run: *setup_requirements + - run: *unit_tests + +############################################################################### +jobs: + unit_tests_35: + docker: *python + environment: + PYTHON_VERSION: '3.5' + <<: *unit_test_steps + + unit_tests_36: + docker: *python + environment: + PYTHON_VERSION: '3.6' + <<: *unit_test_steps + + unit_tests_37: + environment: + PYTHON_VERSION: '3.7' + docker: *python + <<: *unit_test_steps + + linters_37: + docker: *python + environment: + PYTHON_VERSION: '3.7' + steps: + - checkout + - run: *prepare_container + - run: *activate_conda + - run: *setup_requirements + - run: *setup_pre_commit + - run: *linters + - run: *install_package + + docs: + docker: *python + environment: + PYTHON_VERSION: '3.7' + steps: + - checkout + - run: *prepare_container + - run: *activate_conda + - run: *setup_requirements + - run: *build_docs + + all_circleci_checks_succeeded: + docker: + - image: circleci/python # any light-weight image + steps: + - run: + 
name: Success! + command: echo "All checks passed" + +############################################################################### +workflows: + version: 2 + regular: + jobs: + - unit_tests_35 + - unit_tests_36 + - unit_tests_37 + - linters_37 + - docs + - all_circleci_checks_succeeded: + requires: + - unit_tests_35 + - unit_tests_36 + - unit_tests_37 + - linters_37 + - docs diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..5a0426a --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +[report] +fail_under=100 +show_missing=True +exclude_lines = + pragma: no cover + raise NotImplementedError diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..8783059 --- /dev/null +++ b/.flake8 @@ -0,0 +1,8 @@ +# copied from black + +[flake8] +ignore = E203, E266, E501, W503 +exclude = causalnex/bbn +max-line-length = 80 +max-complexity = 18 +select = B,C,E,F,W,T4,B9 diff --git a/.github/ISSUE_TEMPLATE/bug-report.md b/.github/ISSUE_TEMPLATE/bug-report.md new file mode 100644 index 0000000..ba480fe --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug-report.md @@ -0,0 +1,40 @@ +--- +name: Bug report +about: If something isn't working +title: '' +labels: 'Issue: Bug Report' +assignees: '' + +--- + +## Description +Short description of the problem here. + +## Context +How has this bug affected you? What were you trying to accomplish? + +## Steps to Reproduce +1. [First Step] +2. [Second Step] +3. [And so on...] + +## Expected Result +Tell us what should happen. + +## Actual Result +Tell us what happens instead. + +``` +-- If you received an error, place it here. +``` + +``` +-- Separate them if you have more than one. 
+``` + +## Your Environment +Include as many relevant details about the environment in which you experienced the bug: + +* CausalNex version used (`pip show causalnex`): +* Python version used (`python -V`): +* Operating system and version: diff --git a/.github/ISSUE_TEMPLATE/feature-request.md b/.github/ISSUE_TEMPLATE/feature-request.md new file mode 100644 index 0000000..a7911c2 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature-request.md @@ -0,0 +1,20 @@ +--- +name: Feature request +about: Let us know if you have a feature request or enhancement +title: '<Title>' +labels: 'Issue: Feature Request' +assignees: '' + +--- + +## Description +Is your feature request related to a problem? A clear and concise description of what the problem is: "I'm always frustrated when ..." + +## Context +Why is this change important to you? How would you use it? How can it benefit other users? + +## Possible Implementation +(Optional) Suggest an idea for implementing the addition or change. + +## Possible Alternatives +(Optional) Describe any alternative solutions or features you've considered. diff --git a/.github/ISSUE_TEMPLATE/thank-you.md b/.github/ISSUE_TEMPLATE/thank-you.md new file mode 100644 index 0000000..b959a8c --- /dev/null +++ b/.github/ISSUE_TEMPLATE/thank-you.md @@ -0,0 +1,21 @@ +--- +name: Say thank you +about: Tell us how you use CausalNex and help us grow a community +title: '<Title>' +labels: 'Issue: Thank You' +assignees: '' + +--- + +## Let us know +If you (or your company) are using CausalNex - please let us know. We'd love to hear from you! + +## Making CausalNex even better +If you would like to help CausalNex - any of the following is greatly appreciated. 
+ +- [ ] Give the repository a star +- [ ] Help out with issues +- [ ] Review pull requests +- [ ] Blog about CausalNex +- [ ] Make tutorials +- [ ] Give talks diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 0000000..40e82b4 --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,22 @@ +## Notice + +- [ ] I acknowledge and agree that, by checking this box and clicking "Submit Pull Request": + +- I submit this contribution under the [Apache 2.0 license](https://www.apache.org/licenses/LICENSE-2.0.txt) and represent that I am entitled to do so on behalf of myself, my employer, or relevant third parties, as applicable. +- I certify that (a) this contribution is my original creation and / or (b) to the extent it is not my original creation, I am authorised to submit this contribution on behalf of the original creator(s) or their licensees. +- I certify that the use of this contribution as authorised by the Apache 2.0 license does not violate the intellectual property rights of anyone else. + +## Motivation and Context +Why was this PR created? + +## How has this been tested? +What testing strategies have you used? 
+ +## Checklist + +- [ ] Read the [contributing](/CONTRIBUTING.md) guidelines +- [ ] Opened this PR as a 'Draft Pull Request' if it is work-in-progress +- [ ] Updated the documentation to reflect the code changes +- [ ] Added a description of this change and added my name to the list of supporting contributions in the [`RELEASE.md`](/RELEASE.md) file +- [ ] Added tests to cover my changes +- [ ] Assigned myself to the PR diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2890d45 --- /dev/null +++ b/.gitignore @@ -0,0 +1,132 @@ +########################## +# Common files + +# IntelliJ +.idea/ +*.iml +out/ +.idea_modules/ + +### macOS +*.DS_Store +.AppleDouble +.LSOverride +.Trashes + +# Vim +*~ +.*.swo +.*.swp + +# emacs +*~ +\#*\# +/.emacs.desktop +/.emacs.desktop.lock +*.elc + +# vscode +.vscode/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# C extensions +*.so + +### Python template +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +.static_storage/ +.media/ +local_settings.py + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/tmp-build-artifacts +docs/build + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# pyenv +.python-version + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000..d8abb4f --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,9 @@ +# copied from black +[settings] +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +use_parentheses=True +line_length=88 +known_first_party=causalnex,tests +default_section=THIRDPARTY diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..4100c8c --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,78 @@ +# See https://pre-commit.com for more information +# See https://pre-commit.com/hooks.html for more hooks + +default_stages: [commit, manual] + +repos: +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v2.2.3 + hooks: + - id: trailing-whitespace + stages: [commit, manual] + - id: end-of-file-fixer + stages: [commit, manual] + - id: check-yaml # Checks yaml files for parseable syntax. +# exclude: + - id: check-json # Checks json files for parseable syntax. + - id: check-added-large-files + - id: check-case-conflict # Check for files that would conflict in case-insensitive filesystems + - id: check-merge-conflict # Check for files that contain merge conflict strings. 
+ - id: debug-statements # Check for debugger imports and py37+ `breakpoint()` calls in python source. +# exclude: + - id: detect-private-key # Detects the presence of private keys + - id: requirements-txt-fixer # Sorts entries in requirements.txt + - id: flake8 + exclude: ^causalnex/ebaybbn + +- repo: https://github.com/pre-commit/mirrors-isort + rev: v4.3.21 + hooks: + - id: isort + exclude: ^causalnex/ebaybbn + +- repo: local + hooks: + # It's impossible to specify per-directory configuration, so we just run it many times. + # https://github.com/PyCQA/pylint/issues/618 + # The first set of pylint checks is for local pre-commit; it only runs on the files changed. + - id: pylint-quick-causalnex + name: "Quick PyLint on causalnex/*" + language: system + types: [file, python] + files: ^causalnex/ + exclude: ^causalnex/ebaybbn + entry: pylint -j0 --disable=unnecessary-pass,cyclic-import --ignore=ebaybbn + stages: [commit] + - id: pylint-quick-tests + name: "Quick PyLint on tests/*" + language: system + types: [file, python] + files: ^tests/ + entry: pylint -j0 --disable=missing-docstring,redefined-outer-name,duplicate-code,no-self-use,invalid-name,cyclic-import + stages: [commit] + + # The same pylint checks, but running on all files. It's for manual run with `make lint` + - id: pylint-causalnex + name: "PyLint on causalnex/*" + language: system + pass_filenames: false + stages: [manual] + entry: pylint -j0 --disable=unnecessary-pass,cyclic-import --ignore=ebaybbn causalnex + exclude: ^causalnex/ebaybbn + - id: pylint-tests + name: "PyLint on tests/*" + language: system + pass_filenames: false + stages: [manual] + entry: pylint -j0 --disable=missing-docstring,redefined-outer-name,duplicate-code,no-self-use,invalid-name,cyclic-import tests + # We need to make some exceptions for 3.5, that's why it's a custom runner. 
+ - id: black + name: "Black" + language: system + pass_filenames: false + entry: python -m tools.min_version 3.6 "pip install black" "black causalnex tests" + - id: legal + name: "Licence check" + language: system + pass_filenames: false + entry: make legal diff --git a/.pylintrc b/.pylintrc new file mode 100644 index 0000000..bfb1093 --- /dev/null +++ b/.pylintrc @@ -0,0 +1,425 @@ +[MASTER] + +# A comma-separated list of package or module names from where C extensions may +# be loaded. Extensions are loading into the active Python interpreter and may +# run arbitrary code +extension-pkg-whitelist=numpy + +# Add files or directories to the blacklist. They should be base names, not +# paths. +ignore=CVS + +# Add files or directories matching the regex patterns to the blacklist. The +# regex matches against base names, not paths. +# ignore-patterns= + +# Python code to execute, usually for sys.path manipulation such as +# pygtk.require(). +#init-hook= + +# Use multiple processes to speed up Pylint. +jobs=1 + +# List of plugins (as comma separated values of python modules names) to load, +# usually to register additional checkers. +load-plugins=pylint.extensions.docparams + +# Pickle collected data for later comparisons. +persistent=yes + +# Specify a configuration file. +#rcfile= + +# Allow loading of arbitrary C extensions. Extensions are imported into the +# active Python interpreter and may run arbitrary code. +unsafe-load-any-extension=no + + +[MESSAGES CONTROL] + +# Only show warnings with the listed confidence levels. Leave empty to show +# all. Valid levels: HIGH, INFERENCE, INFERENCE_FAILURE, UNDEFINED +confidence= + +# Disable the message, report, category or checker with the given id(s). 
You +# can either give multiple identifiers separated by comma (,) or put this +# option multiple times (only on the command line, not in the configuration +# file where it should appear only once). You can also use "--disable=all" to +# disable everything first and then reenable specific checks. For example, if +# you want to run only the similarities checker, you can use "--disable=all +# --enable=similarities". If you want to run only the classes checker, but have +# no Warning level messages displayed, use "--disable=all --enable=classes +# --disable=W" +disable=ungrouped-imports,bad-continuation,c-extension-no-member + +# Enable the message, report, category or checker with the given id(s). You can +# either give multiple identifiers separated by comma (,) or put this option +# multiple times (only on the command line, not in the configuration file where +# it should appear only once). See also the "--disable" option for examples. +enable=useless-suppression + + +[REPORTS] + +# Python expression which should return a note less than 10 (10 is the highest +# note). You have access to the variables errors warning, statement which +# respectively contain the number of errors / warnings messages and the total +# number of statements analyzed. This is used by the global evaluation report +# (RP0004). +evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / statement) * 10) + +# Template used to display messages. This is a python new-style format string +# used to format the message information. See doc for all details +#msg-template= + +# Set the output format. Available formats are text, parseable, colorized, json +# and msvs (visual studio). You can also give a reporter class, eg +# mypackage.mymodule.MyReporterClass. +output-format=text + +# Tells whether to display a full report or only the messages +reports=no + +# Activate the evaluation score. 
+score=yes + + +[REFACTORING] + +# Maximum number of nested blocks for function / method body +max-nested-blocks=5 + + +[BASIC] + +# Naming hint for argument names +argument-name-hint=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Regular expression matching correct argument names +argument-rgx=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Naming hint for attribute names +attr-name-hint=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Regular expression matching correct attribute names +attr-rgx=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Bad variable names which should always be refused, separated by a comma +bad-names=foo,bar,baz,toto,tutu,tata + +# Naming hint for class attribute names +class-attribute-name-hint=([A-Za-zX_][A-Za-zX0-9_]{2,30}|(__.*__))$ + +# Regular expression matching correct class attribute names +class-attribute-rgx=([A-Za-zX_][A-Za-zX0-9_]{2,30}|(__.*__))$ + +# Naming hint for class names +class-name-hint=[A-Z_][a-zXA-Z0-9]+$ + +# Regular expression matching correct class names +class-rgx=[A-Z_][a-zXA-Z0-9]+$ + +# Naming hint for constant names +const-name-hint=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Regular expression matching correct constant names +const-rgx=(([A-Z_][A-Z0-9_]*)|(__.*__))$ + +# Minimum line length for functions/classes that require docstrings, shorter +# ones are exempt. 
+docstring-min-length=-1 + +# Naming hint for function names +function-name-hint=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Regular expression matching correct function names +function-rgx=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Good variable names which should always be accepted, separated by a comma +good-names=ex,Run,_,io,df,ds,bn,sm,ax,X,W,E,a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p,q,r,s,t,u,v,w,x,y,z + +# Include a hint for the correct naming format with invalid-name +include-naming-hint=no + +# Naming hint for inline iteration names +inlinevar-name-hint=[A-Za-zX_][A-Za-zX0-9_]*$ + +# Regular expression matching correct inline iteration names +inlinevar-rgx=[A-Za-zX_][A-Za-zX0-9_]*$ + +# Naming hint for method names +method-name-hint=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Regular expression matching correct method names +method-rgx=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Naming hint for module names +module-name-hint=(([a-zX_][a-zX0-9_]*)|([A-Z][a-zXA-Z0-9]+))$ + +# Regular expression matching correct module names +module-rgx=(([a-zX_][a-zX0-9_]*)|([A-Z][a-zXA-Z0-9]+))$ + +# Colon-delimited sets of names that determine each other's naming style when +# the name regexes allow several styles. +name-group= + +# Regular expression which should only match function or class names that do +# not require a docstring. +no-docstring-rgx=^_ + +# List of decorators that produce properties, such as abc.abstractproperty. Add +# to this list to register other decorators that produce valid properties. +property-classes=abc.abstractproperty + +# Naming hint for variable names +variable-name-hint=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + +# Regular expression matching correct variable names +variable-rgx=(([a-zX][a-zX0-9_]{2,30})|(_[a-zX0-9_]*))$ + + +[FORMAT] + +# Expected format of line ending, e.g. empty (any line ending), LF or CRLF. +expected-line-ending-format= + +# Regexp for a line that is allowed to be longer than the limit. 
+ignore-long-lines=^\s*(# )?<?https?://\S+>?$ + +# Number of spaces of indent required inside a hanging or continued line. +indent-after-paren=4 + +# String used as indentation unit. This is usually " " (4 spaces) or "\t" (1 +# tab). +indent-string=' ' + +# Maximum number of characters on a single line. +max-line-length=120 + +# Maximum number of lines in a module +max-module-lines=1000 + +# List of optional constructs for which whitespace checking is disabled. `dict- +# separator` is used to allow tabulation in dicts, etc.: {1 : 1,\n222: 2}. +# `trailing-comma` allows a space between comma and closing bracket: (a, ). +# `empty-line` allows space-only lines. +no-space-check=trailing-comma,dict-separator + +# Allow the body of a class to be on the same line as the declaration if body +# contains single statement. +single-line-class-stmt=no + +# Allow the body of an if to be on the same line as the test if there is no +# else. +single-line-if-stmt=no + + +[LOGGING] + +# Logging modules to check that the string format arguments are in logging +# function parameter format +logging-modules=logging + + +[MISCELLANEOUS] + +# List of note tags to take in consideration, separated by a comma. +notes=FIXME,XXX,TODO + + +[SIMILARITIES] + +# Ignore comments when computing similarities. +ignore-comments=yes + +# Ignore docstrings when computing similarities. +ignore-docstrings=yes + +# Ignore imports when computing similarities. +ignore-imports=no + +# Minimum lines number of a similarity. +min-similarity-lines=20 + + +[SPELLING] + +# Spelling dictionary name. Available dictionaries: none. To make it working +# install python-enchant package. +spelling-dict= + +# List of comma separated words that should not be checked. +spelling-ignore-words= + +# A path to a file that contains private dictionary; one word per line. 
+spelling-private-dict-file= + +# Tells whether to store unknown words to indicated private dictionary in +# --spelling-private-dict-file option instead of raising a message. +spelling-store-unknown-words=no + + +[TYPECHECK] + +# List of decorators that produce context managers, such as +# contextlib.contextmanager. Add to this list to register other decorators that +# produce valid context managers. +contextmanager-decorators=contextlib.contextmanager + +# List of members which are set dynamically and missed by pylint inference +# system, and so shouldn't trigger E1101 when accessed. Python regular +# expressions are accepted. +generated-members= + +# Tells whether missing members accessed in mixin class should be ignored. A +# mixin class is detected if its name ends with "mixin" (case insensitive). +ignore-mixin-members=yes + +# This flag controls whether pylint should warn about no-member and similar +# checks whenever an opaque object is returned when inferring. The inference +# can return multiple potential results while evaluating a Python object, but +# some branches might not be evaluated, which results in partial inference. In +# that case, it might be useful to still emit no-member and other checks for +# the rest of the inferred objects. +ignore-on-opaque-inference=yes + +# List of class names for which member attributes should not be checked (useful +# for classes with dynamically set attributes). This supports the use of +# qualified names. +ignored-classes=optparse.Values,thread._local,_thread._local + +# List of module names for which member attributes should not be checked +# (useful for modules/projects where namespaces are manipulated during runtime +# and thus existing member attributes cannot be deduced by static analysis. It +# supports qualified module names, as well as Unix pattern matching. +ignored-modules= + +# Show a hint with possible names when a member name was not found. The aspect +# of finding the hint is based on edit distance. 
+missing-member-hint=yes + +# The minimum edit distance a name should have in order to be considered a +# similar match for a missing member name. +missing-member-hint-distance=1 + +# The total number of similar names that should be taken in consideration when +# showing a hint for a missing member. +missing-member-max-choices=1 + + +[VARIABLES] + +# List of additional names supposed to be defined in builtins. Remember that +# you should avoid to define new builtins when possible. +additional-builtins= + +# Tells whether unused global variables should be treated as a violation. +allow-global-unused-variables=yes + +# List of strings which can identify a callback function by name. A callback +# name must start or end with one of those strings. +callbacks=cb_,_cb + +# A regular expression matching the name of dummy variables (i.e. expectedly +# not used). +dummy-variables-rgx=_+$|(_[a-zXA-Z0-9_]*[a-zXA-Z0-9]+?$)|dummy|^ignored_|^unused_ + +# Argument names that match this expression will be ignored. Default to name +# with leading underscore +ignored-argument-names=_.*|^ignored_|^unused_ + +# Tells whether we should check for unused import in __init__ files. +init-import=no + +# List of qualified module names which can have objects that can redefine +# builtins. +redefining-builtins-modules=six.moves,future.builtins + + +[CLASSES] + +# List of method names used to declare (i.e. assign) instance attributes. +defining-attr-methods=__init__,__new__,setUp,fit,_init + +# List of member names, which should be excluded from the protected access +# warning. +exclude-protected=_asdict,_fields,_replace,_source,_make + +# List of valid names for the first argument in a class method. +valid-classmethod-first-arg=cls + +# List of valid names for the first argument in a metaclass class method. +valid-metaclass-classmethod-first-arg=mcs + + +[DESIGN] + +# Maximum number of arguments for function / method +max-args=8 + +# Maximum number of attributes for a class (see R0902). 
+max-attributes=10 + +# Maximum number of boolean expressions in a if statement +max-bool-expr=5 + +# Maximum number of branch for function / method body +max-branches=12 + +# Maximum number of locals for function / method body +max-locals=20 + +# Maximum number of parents for a class (see R0901). +max-parents=7 + +# Maximum number of public methods for a class (see R0904). +max-public-methods=20 + +# Maximum number of return / yield for function / method body +max-returns=6 + +# Maximum number of statements in function / method body +max-statements=50 + +# Minimum number of public methods for a class (see R0903). +min-public-methods=1 + + +[IMPORTS] + +# Allow wildcard imports from modules that define __all__. +allow-wildcard-with-all=no + +# Analyse import fallback blocks. This can be used to support both Python 2 and +# 3 compatible code, which means that the block might have code that exists +# only in one or another interpreter, leading to false positives when analysed. +analyse-fallback-blocks=no + +# Deprecated modules which should not be used, separated by a comma +deprecated-modules=optparse,tkinter.tix + +# Create a graph of external dependencies in the given file (report RP0402 must +# not be disabled) +ext-import-graph= + +# Create a graph of every (i.e. internal and external) dependencies in the +# given file (report RP0402 must not be disabled) +import-graph= + +# Create a graph of internal dependencies in the given file (report RP0402 must +# not be disabled) +int-import-graph= + +# Force import order to recognize a module as part of the standard +# compatibility libraries. +known-standard-library= + +# Force import order to recognize a module as part of a third party library. +known-third-party=enchant + + +[EXCEPTIONS] + +# Exceptions that will emit a warning when being caught. 
Defaults to +# "Exception" +overgeneral-exceptions=Exception diff --git a/.readthedocs.yml b/.readthedocs.yml new file mode 100644 index 0000000..25d1e46 --- /dev/null +++ b/.readthedocs.yml @@ -0,0 +1,31 @@ +# .readthedocs.yml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Build documentation in the docs/ directory with Sphinx +sphinx: + builder: html + configuration: docs/conf.py + fail_on_warning: true + +# Build documentation with MkDocs +# mkdocs: +# configuration: mkdocs.yml + +# Optionally build your docs in additional formats such as PDF and ePub +#formats: all + +# Optionally set the version of Python and requirements required to build your docs +python: + version: 3.6 + install: + - method: pip + path: . + extra_requirements: + - docs + - requirements: + - test_requirements.txt + - doc_requirements.txt diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000..01e82f6 --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,78 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +In the interest of fostering an open and welcoming environment, we as +contributors and maintainers pledge to making participation in our project and +our community a harassment-free experience for everyone, regardless of age, body +size, disability, ethnicity, sex characteristics, gender identity and expression, +level of experience, education, socio-economic status, nationality, personal +appearance, race, religion, or sexual identity and orientation. 
+ +## Our Standards + +Examples of behaviour that contributes to creating a positive environment +include: + +* Using welcoming and inclusive language +* Being respectful of differing viewpoints and experiences +* Gracefully accepting constructive criticism +* Focusing on what is best for the community +* Showing empathy towards other community members + +Examples of unacceptable behaviour by participants include: + +* The use of sexualised language or imagery and unwelcome sexual attention or + advances +* Trolling, insulting/derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or electronic + address, without explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Our Responsibilities + +Project maintainers are responsible for clarifying the standards of acceptable +behaviour and are expected to take appropriate and fair corrective action in +response to any instances of unacceptable behaviour. + +Project maintainers have the right and responsibility to remove, edit, or +reject comments, commits, code, wiki edits, issues, and other contributions +that are not aligned to this Code of Conduct, or to ban temporarily or +permanently any contributor for other behaviours that they deem inappropriate, +threatening, offensive, or harmful. + +## Scope + +This Code of Conduct applies both within project spaces and in public spaces +when an individual is representing the project or its community. Examples of +representing a project or community include using an official project e-mail +address, posting via an official social media account, or acting as an appointed +representative at an online or offline event. Representation of a project may be +further defined and clarified by project maintainers. 
+ +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behaviour may be +reported by contacting the project team at causalnex@quantumblack.com. All +complaints will be reviewed and investigated and will result in a response that +is deemed necessary and appropriate to the circumstances. The project team is +obligated to maintain confidentiality with regard to the reporter of an incident. +Further details of specific enforcement policies may be posted separately. + +Project maintainers who do not follow or enforce the Code of Conduct in good +faith may face temporary or permanent repercussions as determined by other +members of the project's leadership. + +**Investigation Timeline:** The project team will make all reasonable efforts to initiate and conclude the investigation in a timely fashion. Depending on the type of investigation the steps and timeline for each investigation will vary. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4, +available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see +https://www.contributor-covenant.org/faq diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000..8b175ca --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,174 @@ +# Introduction + +Thank you for considering contributing to CausalNex! CausalNex would not be possible without the generous sharing from leading researchers in causal inference and we hope to maintain the spirit of open discourse by welcoming contributions in the form of pull requests (PRs), issues or code reviews. You can add to code, [documentation](https://causalnex.readthedocs.io), or simply send us spelling and grammar fixes or extra tests. Contribute anything that you think improves the community for us all! 
+ +The following sections describe our vision and contribution process. + +## Vision + +Identifying causation from data remains a field of active research and CausalNex aims to become the leading library for causal reasoning and counterfactual analysis using Bayesian Networks. + +## Code of conduct + +The CausalNex team pledges to foster and maintain a welcoming and friendly community in all of our spaces. All members of our community are expected to follow our [Code of Conduct](/CODE_OF_CONDUCT.md) and we will do our best to enforce those principles and build a happy environment where everyone is treated with respect and dignity. + +# Get started + +We use [GitHub Issues](https://github.com/quantumblacklabs/causalnex/issues) to keep track of known bugs. We keep a close eye on them and try to make it clear when we have an internal fix in progress. Before reporting a new issue, please do your best to ensure your problem hasn't already been reported. If so, it's often better to just leave a comment on an existing issue, rather than create a new one. Old issues can also often include helpful tips and solutions to common problems. + +If you are looking for help with your code and our documentation hasn't helped you, please consider posting a question on [Stack Overflow](https://stackoverflow.com/questions/tagged/causalnex). If you tag it `causalnex` and `python`, more people will see it and may be able to help. We are unable to provide individual support via email. In the interest of community engagement we also believe that help is much more valuable if it's shared publicly, so that more people can benefit from it. + +If you're over on Stack Overflow and want to boost your points, take a look at the `causalnex` tag and see if you can help others out by sharing your knowledge. It's another great way to contribute.
+ +If you have already checked the existing issues in [GitHub issues](https://github.com/quantumblacklabs/causalnex/issues) and are still convinced that you have found odd or erroneous behaviour then please file an [issue](https://github.com/quantumblacklabs/causalnex/issues). We have a template that helps you provide the necessary information we'll need in order to address your query. + +## Feature requests + +### Suggest a new feature + +If you have new ideas for CausalNex functionality then please open a [GitHub issue](https://github.com/quantumblacklabs/causalnex/issues) with the label `Type: Enhancement`. You can submit an issue [here](https://github.com/quantumblacklabs/causalnex/issues) which describes the feature you would like to see, why you need it, and how it should work. + +### Contribute a new feature + +If you're unsure where to begin contributing to CausalNex, please start by looking through the `good first issues` on [GitHub](https://github.com/quantumblacklabs/causalnex/issues). + +We focus on two areas for contribution: `core` and [`contrib`](/causalnex/contrib/): +- `core` refers to the primary CausalNex library +- [`contrib`](/causalnex/contrib/) refers to features that could be added to `core` that do not introduce too many dependencies e.g. adding a new type of causal network to the `network` module. + +Typically, we only accept small contributions for the `core` CausalNex library but accept new features as `plugins` or additions to the [`contrib`](/causalnex/contrib/) module. We regularly review [`contrib`](/causalnex/contrib/) and may migrate modules to `core` if they prove to be essential for the functioning of the framework or if we believe that they are used by most projects. + +## Your first contribution + +Working on your first pull request?
You can learn how from these resources: +* [First timers only](https://www.firsttimersonly.com/) +* [How to contribute to an open source project on GitHub](https://egghead.io/courses/how-to-contribute-to-an-open-source-project-on-github) + + +### Guidelines + + - Aim for cross-platform compatibility on Windows, macOS and Linux + - We use [Anaconda](https://www.anaconda.com/distribution/) as a preferred virtual environment + - We use [SemVer](https://semver.org/) for versioning + +Our code is designed to be compatible with Python 3.5 onwards and our style guidelines are (in cascading order): + +* [PEP 8 conventions](https://www.python.org/dev/peps/pep-0008/) for all Python code +* [Google docstrings](https://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings) for code comments +* [PEP 484 type hints](https://www.python.org/dev/peps/pep-0484/) for all user-facing functions / class methods e.g. + +``` +def count_truthy(elements: List[Any]) -> int: + return sum(1 for elem in elements if elem) +``` + +> *Note:* We only accept contributions under the Apache 2.0 license and you should have permission to share the submitted code. + +Please note that each code file should have a licence header that includes the content of [`legal_header.txt`](https://github.com/quantumblacklabs/causalnex/blob/master/legal_header.txt). +There is an automated check to verify that it exists. The check will highlight any issues and suggest a solution. + +### Branching conventions +We use a branching model that helps us keep track of branches in a logical, consistent way. All branches should have the hyphen-separated convention of: `<type-of-change>/<short-description-of-change>` e.g.
`contrib/structure` + +| Types of changes | Description | +| ---------------- | ---------------------------------------------------------------------------- | +| `contrib` | Changes under `contrib/` and has no side-effects to other `contrib/` modules | +| `docs` | Changes to the documentation under `docs/source/` | +| `feature` | Non-breaking change which adds functionality | +| `fix` | Non-breaking change which fixes an issue | +| `tests` | Changes to project unit `tests/` and / or integration `features/` tests | + +## `core` contribution process + +Small contributions are accepted for the `core` library: + + 1. Fork the project by clicking **Fork** in the top-right corner of the [CausalNex GitHub repository](https://github.com/quantumblacklabs/causalnex) and then choosing the target account the repository will be forked to. + 2. Create a feature branch on your forked repository and push all your local changes to that feature branch. + 3. Before submitting a pull request (PR), please ensure that unit tests and linting are passing for your changes by running `make test` and `make lint` locally, have a look at the section [Running checks locally](/CONTRIBUTING.md#running-checks-locally) below. + 4. Open a PR against the `quantumblacklabs:develop` branch from your feature branch. + 5. Update the PR according to the reviewer's comments. + 6. Your PR will be merged by the CausalNex team once all the comments are addressed. + + > _Note:_ We will work with you to complete your contribution but we reserve the right to takeover abandoned PRs. + +## `contrib` contribution process + +You can also add new work to `contrib`: + + 1. Create an [issue](https://github.com/quantumblacklabs/causalnex/issues) describing your contribution. + 2. Fork the project by clicking **Fork** in the top-right corner of the [CausalNex GitHub repository](https://github.com/quantumblacklabs/causalnex) and then choosing the target account the repository will be forked to. + 3. 
Work in [`contrib`](/causalnex/contrib/) and create a feature branch on your forked repository and push all your local changes to that feature branch. + 4. Before submitting a pull request, please ensure that unit tests and linting are passing for your changes by running `make test` and `make lint` locally, have a look at the section [Running checks locally](/CONTRIBUTING.md#running-checks-locally) below. + 5. Include a `README.md` with instructions on how to use your contribution. + 6. Open a PR against the `quantumblacklabs:develop` branch from your feature branch and reference your issue in the PR description (e.g., `Resolves #<issue-number>`). + 7. Update the PR according to the reviewer's comments. + 8. Your PR will be merged by the CausalNex team once all the comments are addressed. + + > _Note:_ We will work with you to complete your contribution but we reserve the right to takeover abandoned PRs. + +## CI / CD and running checks locally +To run tests you need to install the test requirements. +Also we use [pre-commit](https://pre-commit.com) hooks for the repository to run the checks automatically. +It can all be installed using the following command: + +```bash +make install-test-requirements +make install-pre-commit +``` + +### Running checks locally + +All checks run by our CI / CD servers can be run locally on your computer. + +#### PEP-8 Standards (`pylint` and `flake8`) + +```bash +make lint +``` + +#### Unit tests, 100% coverage (`pytest`, `pytest-cov`) + +```bash +make test +``` + +> Note: We place [conftest.py](https://docs.pytest.org/en/latest/fixture.html#conftest-py-sharing-fixture-functions) files in some test directories to make fixtures reusable by any tests in that directory. 
If you need to see which test fixtures are available and where they come from, you can issue: + +```bash +pytest --fixtures path/to/the/test/location.py +``` + +#### Others + +Our CI / CD also checks that `causalnex` installs cleanly on a fresh Python virtual environment, a task which depends on successfully building the docs: + +```bash +make build-docs +``` + +This command will only work on Unix-like systems and requires `pandoc` to be installed. + +> ❗ Running `make build-docs` in a Python 3.5 environment may sometimes yield multiple warning messages like the following: `WARNING: toctree contains reference to nonexisting document '04_user_guide/04_user_guide'`. You can simply ignore them or switch to Python 3.6+ when building documentation. + +## Hints on pre-commit usage +The checks will automatically run on all the changed files on each commit. +Even more extensive set of checks (including the heavy set of `pylint` checks) +will run before the push. + +The pre-commit/pre-push checks can be omitted by running with `--no-verify` flag, as per below: + +```bash +git commit --no-verify <...> +git push --no-verify <...> +``` +(`-n` alias works for `git commit`, but not for `git push`) + +All checks will run during CI build, so skipping checks on push will +not allow you to merge your code with failing checks. + +You can uninstall the pre-commit hooks by running: + +```bash +make uninstall-pre-commit +``` +`pre-commit` will still be used by `make lint`, but will not install the git hooks. diff --git a/LICENSE.md b/LICENSE.md new file mode 100644 index 0000000..a9ec8fd --- /dev/null +++ b/LICENSE.md @@ -0,0 +1,27 @@ +Copyright 2019-2020 QuantumBlack Visual Analytics Limited + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
# Developer task runner for CausalNex.
#
# Every target below is a command, not a file to be built, so all of them are
# declared phony: make will run them even if a file or directory with the same
# name appears in the working tree.
.PHONY: install clean legal lint test package \
	install-doc-requirements build-docs \
	install-test-requirements install-pre-commit uninstall-pre-commit

# Install (or upgrade) the package from the current checkout.
install:
	pip install . -U

# Remove build artefacts, tool caches and the pre-commit cache.
# `|| true` keeps `clean` usable when pre-commit is not installed.
clean:
	rm -rf build dist docs/build pip-wheel-metadata .mypy_cache .pytest_cache
	find . -regex ".*/__pycache__" -exec rm -rf {} +
	find . -regex ".*\.egg-info" -exec rm -rf {} +
	pre-commit clean || true

# Run the licence / header script.
legal:
	python tools/license_and_headers.py

# Run every pre-commit hook, including those registered for the `manual`
# stage, against all files in the repository.
lint:
	pre-commit run -a --hook-stage manual

# Run the test suite under tests/.
test:
	pytest tests

# Build source and wheel distributions from a clean tree.
package: clean install
	python setup.py sdist bdist_wheel

SPHINXPROJ = causalnex

install-doc-requirements:
	pip install -r doc_requirements.txt -U

# Build the documentation; needs the package and the doc requirements.
build-docs: install install-doc-requirements
	./docs/build-docs.sh

install-test-requirements:
	pip install -r test_requirements.txt -U

# Install both git hooks. The pre-push hook has to be installed explicitly:
# `pre-commit install` on its own only registers the pre-commit hook type.
# This mirrors the CI setup step and `uninstall-pre-commit` below.
install-pre-commit: install-test-requirements
	pre-commit install --install-hooks
	pre-commit install --hook-type pre-push

# Remove both git hooks installed by `install-pre-commit`.
uninstall-pre-commit:
	pre-commit uninstall
	pre-commit uninstall --hook-type pre-push
Documentation Build | [![Documentation](https://readthedocs.org/projects/causalnex/badge/?version=latest)](https://causalnex.readthedocs.io/) | +| License | [![License](https://img.shields.io/badge/license-Apache%202.0-blue.svg)](https://opensource.org/licenses/Apache-2.0) | +| Code Style | [![Code Style: Black](https://img.shields.io/badge/code%20style-black-black.svg)](https://github.com/ambv/black) | + + +## What is CausalNex? + +> "A toolkit for causal reasoning with Bayesian Networks." + +CausalNex aims to become one of the leading libraries for causal reasoning and "what-if" analysis using Bayesian Networks. It helps to simplify the steps: + - To learn causal structures, + - To allow domain experts to augment the relationships, + - To estimate the effects of potential interventions using data. + +## Why CausalNex? + +CausalNex is built on our collective experience to leverage Bayesian Networks to identify causal relationships in data so that we can develop the right interventions from analytics. We developed CausalNex because: + +- We believe **leveraging Bayesian Networks** is more intuitive to describe causality compared to traditional machine learning methodologies that are built on pattern recognition and correlation analysis. +- Causal relationships are more accurate if we can easily **encode or augment domain expertise** in the graph model +- We can then use the graph model to **assess the impact** from changes to underlying features, i.e. counterfactual analysis, and **identify the right intervention**. + +In our experience, a data scientist generally has to use at least 3-4 different open-source libraries before arriving at the final step of finding the right intervention. CausalNex aims to simplify this end-to-end process for causality and counterfactual analysis. + +## What are the main features of CausalNex?
+ +The main features of this library are: + +- Use state-of-the-art structure learning methods to understand conditional dependencies between variables +- Allow domain knowledge to augment model relationship +- Build predictive models based on structural relationships +- Fit probability distribution of the Bayesian Networks +- Evaluate model quality with standard statistical checks. +- Visualisation that simplifies how causality is understood in Bayesian Networks +- Analyse the impact of interventions using Do-calculus + +## How do I install CausalNex? + +CausalNex is a Python package. To install it, simply run: + +```bash +pip install causalnex +``` + +See more detailed installation instructions, including how to set up Python virtual environments, in our [installation guide](https://causalnex.readthedocs.io/en/latest/02_getting_started/02_install.html) and get started with our [tutorial](https://causalnex.readthedocs.io/en/latest/03_tutorial/03_tutorial.html). + +## How do I use CausalNex? + +You can find the documentation for the latest stable release [here](https://causalnex.readthedocs.io/en/latest/). It explains: + +- An end-to-end [tutorial on how to use CausalNex](https://causalnex.readthedocs.io/en/latest/03_tutorial/03_tutorial.html) +- The [main concepts and methods](https://causalnex.readthedocs.io/en/latest/04_user_guide/04_user_guide.html) in using Bayesian Networks for Causal Inference + +> Note: You can find the notebook and markdown files used to build the docs in [`docs/source`](docs/source). + +## Can I contribute? + +Yes! We'd love you to join us and help us build CausalNex. Check out our [contributing](CONTRIBUTING.md) documentation. + +## How do I upgrade CausalNex? + +We use [SemVer](http://semver.org/) for versioning. The best way to upgrade safely is to check our [release notes](RELEASE.md) for any notable breaking changes. + +## What licence do you use? + +See our [LICENSE](LICENSE.md) for more detail. + +## We're hiring!
+ +Do you want to be part of the team that builds CausalNex and [other great products](https://quantumblack.com/labs) at QuantumBlack? If so, you're in luck! QuantumBlack is currently hiring Machine Learning Engineers who love using data to drive their decisions. Take a look at [our open positions](https://www.quantumblack.com/careers/current-openings#content) and see if you're a fit. diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000..3601855 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,9 @@ +# Release 0.4.0: + +The initial release of CausalNex. + +## Thanks for supporting contributions +CausalNex was originally designed by [Paul Beaumont](https://www.linkedin.com/in/pbeaumont/) and [Ben Horsburgh](https://www.linkedin.com/in/benhorsburgh/) to solve challenges they faced in inferring causality in their project work. This work was later turned into a product thanks to the following contributors: +[Yetunde Dada](https://github.com/yetudada), [Wesley Leong](https://www.linkedin.com/in/wesleyleong/), [Steve Ler](https://www.linkedin.com/in/song-lim-steve-ler-380366106/), [Viktoriia Oliinyk](https://www.linkedin.com/in/victoria-oleynik/), [Roxana Pamfil](https://www.linkedin.com/in/roxana-pamfil-1192053b/), [Nisara Sriwattanaworachai](https://www.linkedin.com/in/nisara-sriwattanaworachai-795b357/) and [Nikolaos Tsaousis](https://www.linkedin.com/in/ntsaousis/). + +CausalNex would also not be possible without the generous sharing from leading researchers in the field of causal inference and we are grateful to everyone who advised and supported us, filed issues or helped resolve them, asked and answered questions or simply were part of inspiring discussions.
diff --git a/causalnex/__init__.py b/causalnex/__init__.py new file mode 100644 index 0000000..a3b94b8 --- /dev/null +++ b/causalnex/__init__.py @@ -0,0 +1,35 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +causalnex toolkit for causal reasoning (Bayesian Networks / Inference) +""" + +__version__ = "0.4.0" + +__all__ = ["structure", "discretiser", "evaluation", "inference", "network", "plots"] diff --git a/causalnex/contrib/README.md b/causalnex/contrib/README.md new file mode 100644 index 0000000..4c2447f --- /dev/null +++ b/causalnex/contrib/README.md @@ -0,0 +1,39 @@ +# CausalNex contrib + +The contrib directory is meant to contain user contributions; these +contributions might get merged into core CausalNex at some point in the future. + +When you create a new module in `contrib`, place it exactly where it would be if it +was merged into core CausalNex. + +For example, functions to plot network diagrams are under the core package `causalnex.plotting`. If you are +contributing a new visualisation or plot you should have the following directory: +`causalnex/contrib/my_project/plotting/` - i.e., the name of your project before the +`causalnex` package path. + +This is what a module would look like under `causalnex/contrib`: +``` +causalnex/contrib/my_project/plotting/ + my_module.py + README.md +``` + +You should put your test files in `tests/contrib/my_project`: +``` +tests/contrib/my_project + test_my_module.py +``` + +## Requirements + +If your project has any requirements that are not in the core `requirements.txt` +file, please add them in `setup.py` like so: +``` +... +extras_require={ + 'my_project': ['requirement1==1.0.1', 'requirement2==2.0.1'], + }, +``` + +Please note that a readme with instructions about how to use your module +and 100% test coverage are required to accept a PR.
diff --git a/causalnex/contrib/__init__.py b/causalnex/contrib/__init__.py new file mode 100644 index 0000000..5da8261 --- /dev/null +++ b/causalnex/contrib/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/causalnex/discretiser/__init__.py b/causalnex/discretiser/__init__.py new file mode 100644 index 0000000..f8dd21d --- /dev/null +++ b/causalnex/discretiser/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``causalnex.discretiser`` provides functionality to discretise data. 
+""" + +__version__ = "0.4.0" + +__all__ = ["Discretiser"] + +from .discretiser import Discretiser diff --git a/causalnex/discretiser/discretiser.py b/causalnex/discretiser/discretiser.py new file mode 100644 index 0000000..763a38e --- /dev/null +++ b/causalnex/discretiser/discretiser.py @@ -0,0 +1,214 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
"""Tools to help discretise data."""

from typing import List

import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin


class Discretiser(BaseEstimator, TransformerMixin):
    """Allows the discretisation of numeric data.

    Example:
    ::
        >>> import causalnex
        >>> import pandas as pd
        >>>
        >>> df = pd.DataFrame({'Age': [12, 13, 18, 19, 22, 60]})
        >>>
        >>> from causalnex.discretiser import Discretiser
        >>> df["Transformed_Age"] = Discretiser(
        ...     method="fixed", numeric_split_points=[11, 18, 50]
        ... ).transform(df["Age"])
        >>> df.to_dict()
        {'Age': {0: 12, 1: 13, 2: 18, 3: 19, 4: 22, 5: 60},
         'Transformed_Age': {0: 1, 1: 1, 2: 2, 3: 2, 4: 2, 5: 3}}
    """

    def __init__(
        self,
        method: str = "uniform",
        num_buckets: int = None,
        outlier_percentile: float = None,
        numeric_split_points: List[float] = None,
        percentile_split_points: List[float] = None,
    ):
        """
        Creates a new Discretiser, that provides fit, fit_transform, and transform function to discretise data.

        Args:
            method (str): can be one of:
                - uniform: discretise data into uniformly spaced buckets. Note, complete uniformity
                cannot be guaranteed under all circumstances, for example, if 5 data points are to split
                into 2 buckets, then one will contain 2 points, and the other will contain 3.
                Provide num_buckets.
                - quantile: discretise data according to the distribution of values. For example, providing
                num_buckets=4 will discretise data into 4 buckets, [0-25th, 25th-50th, 50th-75th, 75th-100th]
                percentiles. Provide num_buckets.
                - outlier: discretise data into 3 buckets - [low_outliers, normal, high_outliers] based on
                outliers being below outlier_percentile, or above 1-outlier_percentile. Provide outlier_percentile.
                - fixed: discretise according to pre-defined split points. Provide numeric_split_points
                - percentiles: discretise data according to the distribution of percentiles values.
                Provide percentile_split_points.
            num_buckets: (int): used by method=uniform and method=quantile. Must be at least 1.
            outlier_percentile: used by method=outlier. Must satisfy 0 <= outlier_percentile < 0.5.
            numeric_split_points: used by method=fixed. Must be sorted in increasing order. Each split point
                is the inclusive lower bound of the next bucket: a value equal to a split point falls into
                the bucket above it. For integer data, to split such that values below 10 go into bucket 0,
                10 to 20 go into bucket 1, and above 20 go into bucket 2, provide [10, 21].
            percentile_split_points: used by method=percentiles. Values must lie in [0, 1] and be sorted in
                increasing order. To split such that values below 10th percentiles go into bucket 0,
                10th to below 75th percentiles go into bucket 1, and 75th percentiles and above go into
                bucket 2, provide [0.1, 0.75].

        Raises:
            ValueError: If an incorrect argument is passed.
        """

        # Constructor arguments are stored as given. For every method except
        # "fixed", numeric_split_points is overwritten by fit().
        self.method = method
        self.num_buckets = num_buckets
        self.outlier_percentile = outlier_percentile
        self.numeric_split_points = numeric_split_points
        self.percentile_split_points = percentile_split_points

        allowed_methods = ["uniform", "quantile", "outlier", "fixed", "percentiles"]

        if self.method not in allowed_methods:
            raise ValueError(
                "{0} is not a recognised method. Use one of: {1}".format(
                    self.method, " ".join(allowed_methods)
                )
            )

        if self.method in {"uniform", "quantile"}:
            if num_buckets is None:
                raise ValueError(
                    "{0} method expects {1}".format(self.method, "num_buckets")
                )
            # fit() divides by num_buckets and iterates range(num_buckets - 1),
            # so zero would crash there and a negative value would silently
            # produce no split points. Reject both here instead.
            if num_buckets < 1:
                raise ValueError("{0} must be at least 1".format("num_buckets"))

        if self.method == "outlier" and outlier_percentile is None:
            raise ValueError(
                "{0} method expects {1}".format(self.method, "outlier_percentile")
            )

        if outlier_percentile is not None and not 0 <= outlier_percentile < 0.5:
            raise ValueError(
                "{0} must be between 0 and 0.5".format("outlier_percentile")
            )

        if self.method == "fixed" and numeric_split_points is None:
            raise ValueError(
                "{0} method expects {1}".format(self.method, "numeric_split_points")
            )

        if numeric_split_points is not None and not self._is_sorted(
            numeric_split_points
        ):
            raise ValueError(
                "{0} must be monotonically increasing".format("numeric_split_points")
            )

        if self.method == "percentiles" and percentile_split_points is None:
            raise ValueError(
                "{0} method expects {1}".format(self.method, "percentile_split_points")
            )

        if percentile_split_points is not None and not all(
            0 <= p <= 1 for p in percentile_split_points
        ):
            raise ValueError(
                "{0} must be between 0 and 1".format("percentile_split_points")
            )

        if percentile_split_points is not None and not self._is_sorted(
            percentile_split_points
        ):
            raise ValueError(
                "{0} must be monotonically increasing".format("percentile_split_points")
            )

    @staticmethod
    def _is_sorted(values) -> bool:
        """Return True if ``values`` is in non-decreasing order.

        The input is copied into a list before comparing, so sorted tuples
        and numpy arrays are accepted as well as lists. Comparing
        ``sorted(values)`` directly against a tuple is always unequal, and
        against a numpy array yields an element-wise array instead of a bool.
        """
        as_list = list(values)
        return as_list == sorted(as_list)

    def fit(self, data: np.ndarray) -> "Discretiser":
        """
        Fit where split points are based on the input data.

        Args:
            data (np.ndarray): values used to learn where split points exist. Any input that
                np.asarray can convert is accepted; it is flattened before use and is not modified.

        Returns:
            self

        Raises:
            RuntimeError: If an attempt to fit fixed numeric_split_points is made.
        """

        # axis=None makes np.sort flatten the input and return a sorted copy.
        x = np.sort(np.asarray(data), axis=None)

        if self.method == "uniform":
            # Pick the element found at every len(x) / num_buckets position of
            # the sorted data as the boundary between consecutive buckets.
            bucket_width = len(x) / self.num_buckets
            self.numeric_split_points = [
                x[int(np.floor((n + 1) * bucket_width))]
                for n in range(self.num_buckets - 1)
            ]

        elif self.method == "quantile":
            bucket_width = 1.0 / self.num_buckets
            quantiles = [bucket_width * (n + 1) for n in range(self.num_buckets - 1)]
            self.numeric_split_points = np.quantile(x, quantiles)

        elif self.method == "outlier":
            self.numeric_split_points = np.quantile(
                x, [self.outlier_percentile, 1 - self.outlier_percentile]
            )

        elif self.method == "percentiles":
            # percentile_split_points are fractions in [0, 1]; np.percentile
            # expects values in [0, 100].
            percentiles = [p * 100 for p in self.percentile_split_points]
            self.numeric_split_points = np.percentile(x, percentiles)

        else:
            raise RuntimeError("cannot call fit using method=fixed")

        return self

    def transform(self, data: np.ndarray) -> np.ndarray:
        """
        Transform the input data into discretised digits, based on the numeric_split_points that were either
        learned through using fit(), or from initialisation if method="fixed".

        Args:
            data (np.ndarray): values that will be transformed into discretised digits.

        Returns:
            input data transformed into discretised digits. With right=False a value equal to a
            split point is placed in the bucket above that split point.
        """

        return np.digitize(data, self.numeric_split_points, right=False)
+# @online{ +# author = {Neville Newey,Anzar Afaq}, +# title = {bayesian-belief-networks}, +# organisation = {eBay}, +# codebase = {https://github.com/eBay/bayesian-belief-networks}, +# } +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from .bbn import * diff --git a/causalnex/ebaybbn/bbn.py b/causalnex/ebaybbn/bbn.py new file mode 100644 index 0000000..ee53e5c --- /dev/null +++ b/causalnex/ebaybbn/bbn.py @@ -0,0 +1,1009 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# The methods found in this file are adapted from a repository under Apache 2.0: +# eBay's Pythonic Bayesian Belief Network Framework. 
+# @online{ +# author = {Neville Newey,Anzar Afaq}, +# title = {bayesian-belief-networks}, +# organisation = {eBay}, +# codebase = {https://github.com/eBay/bayesian-belief-networks}, +# } +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
class BBN:
    """A Bayesian Belief Network stored as a Directed Acyclic Graph.

    Attributes set by the constructor:
        nodes: list of the node objects of the graph.
        vars_to_nodes: mapping ``variable name -> node`` (the dict passed in).
        domains: mapping ``variable name -> iterable of allowed values``;
            variables without an entry are treated as ``(True, False)``.
        name: optional name given to the network.
    """

    def __init__(self, nodes_dict, name=None, domains=None):
        """
        Args:
            nodes_dict: mapping ``variable name -> node``.
            name: optional name for the network.
            domains: optional mapping ``variable name -> allowed values``.
                When omitted every instance gets its own fresh dict (a shared
                ``{}`` default would leak domains between networks).
        """
        self.nodes = list(nodes_dict.values())
        self.vars_to_nodes = nodes_dict
        self.name = name
        self.domains = {} if domains is None else domains
        # For each node we want to explicitly record which variable it
        # 'introduced'. This cannot be recorded during Node instantiation
        # because at that point we do not yet know *which* of the variables
        # in the argument list is the one being modeled by the function
        # (unless there is only one argument).
        for variable_name, node in list(nodes_dict.items()):
            node.variable_name = variable_name

    def get_graphviz_source(self):
        """Return the graph as Graphviz ``dot`` source.

        Nodes and edges are emitted sorted by name so that the output is
        deterministic.
        """
        fh = StringIO()
        fh.write("digraph G {\n")
        fh.write(' graph [ dpi = 300 bgcolor="transparent" rankdir="LR"];\n')
        edges = set()
        for node in sorted(self.nodes, key=lambda x: x.name):
            fh.write(' %s [ shape="ellipse" color="blue"];\n' % node.name)
            for child in node.children:
                edges.add((node.name, child.name))
        for source, target in sorted(edges):
            fh.write(" %s -> %s;\n" % (source, target))
        fh.write("}\n")
        return fh.getvalue()

    def build_join_tree(self):
        """Build and return the join tree of this network."""
        return build_join_tree(self)

    def validate_keyvals(self, **kwds):
        """
        Validate evidence passed as keyword arguments to .query().

        Every key must be a variable name of the graph and every value must
        be in the domain of that variable.

        Returns:
            True when all the evidence is valid.

        Raises:
            VariableNotInGraphError: a key is not a variable of the graph.
            VariableValueNotInDomainError: a value is outside its domain.
        """
        known_variables = {n.variable_name for n in self.nodes}
        for k, v in list(kwds.items()):
            if k not in known_variables:
                raise VariableNotInGraphError(k)
            domain = self.domains.get(k, (True, False))
            if v not in domain:
                s = "{}={}".format(k, v)
                raise VariableValueNotInDomainError(s)
        return True

    def query(self, **kwds):
        """
        Compute the marginals of every variable given the evidence ``kwds``.

        Returns:
            dict mapping ``(variable name, value)`` to its marginal. When
            evidence is supplied the values of each variable are normalised
            so that they sum to one (unless their total is zero).
        """
        # First check that the keyvals provided are valid for this graph...
        self.validate_keyvals(**kwds)
        jt = self.build_join_tree()
        assignments = jt.assign_clusters(self)
        jt.initialize_potentials(assignments, self, kwds)

        jt.propagate()
        marginals = dict()
        normalizers = defaultdict(float)

        for node in self.nodes:
            for k, v in list(jt.marginal(node).items()):
                # For a single node the key of the marginal truth table
                # always has just one (variable, value) pair: unpack it.
                marginals[k[0]] = v
                # With evidence the potentials are unnormalised, so keep a
                # running total per variable.
                if kwds:
                    normalizers[k[0][0]] += v

        if kwds:
            for k in marginals:
                if normalizers[k[0]] != 0:
                    marginals[k] /= normalizers[k[0]]

        return marginals

    def draw_samples(self, query=None, n=1):
        """
        Draw ``n`` full assignments of the variables by ancestral-style
        sampling: unassigned variables are picked in random order and each is
        sampled from its marginal given everything assigned so far.

        Args:
            query: dict of currently evidenced variables (``None`` or empty
                for no evidence). It is copied, never modified.
            n: number of samples to draw.

        Returns:
            list of ``n`` dicts mapping variable name to sampled value; the
            evidence variables are included in every sample.
        """
        evidence = {} if query is None else query
        samples = []
        result_cache = dict()
        while len(samples) < n:
            sample = dict(evidence)
            while len(sample) < len(self.nodes):
                next_node = choice(
                    [node for node in self.nodes if node.variable_name not in sample]
                )
                cache_key = tuple(sorted(sample.items()))
                if cache_key not in result_cache:
                    result_cache[cache_key] = self.query(**sample)
                result = result_cache[cache_key]
                var_density = [
                    item
                    for item in result.items()
                    if item[0][0] == next_node.variable_name
                ]
                cumulative_density = var_density[:1]
                for outcome, mass in var_density[1:]:
                    cumulative_density.append(
                        (outcome, cumulative_density[-1][1] + mass)
                    )
                r = random()
                i = 0
                # Stop at the last entry: floating point rounding can leave
                # the accumulated mass marginally below the random draw, which
                # would otherwise run past the end of the table.
                last = len(cumulative_density) - 1
                while i < last and r > cumulative_density[i][1]:
                    i += 1
                sample[next_node.variable_name] = cumulative_density[i][0][1]
            samples.append(sample)
        return samples
+ variables.sort() + for variable in variables: + domain = bbn.domains.get(variable, [True, False]) + vals.append(list(product([variable], domain))) + permutations = product(*vals) + for permutation in permutations: + tt[permutation] = 1 + node.potential_tt = tt + + # Step 2: Note that in H&D the assignments are + # done as part of step 2 however we have + # seperated the assignment algorithm out and + # done these prior to step 1. + # Now for each assignment we want to + # generate a truth-table from the + # values of the bbn truth-tables that are + # assigned to the clusters... + + for clique, bbn_nodes in assignments.items(): + tt = dict() + vals = [] + variables = list(clique.variable_names) + variables.sort() + for variable in variables: + domain = bbn.domains.get(variable, [True, False]) + vals.append(list(product([variable], domain))) + permutations = product(*vals) + for permutation in permutations: + argvals = dict(permutation) + potential = 1 + for bbn_node in bbn_nodes: + bbn_node.clique = clique + # We could handle evidence here + # by altering the potential_tt. + # This is slightly different to + # the way that H&D do it. + + arg_list = [] + for arg_name in get_args(bbn_node.func): + arg_list.append(argvals[arg_name]) + + potential *= bbn_node.func(*arg_list) + tt[permutation] = potential + clique.potential_tt = tt + + if not evidence: + # We dont need to deal with likelihoods + # if we dont have any evidence. + return + + # Step 2b: Set each liklihood element ^V(v) to 1 + self.initial_likelihoods(assignments, bbn) + for clique, bbn_nodes in assignments.items(): + for node in bbn_nodes: + if node.variable_name in evidence: + for k, v in list(clique.potential_tt.items()): + # Encode the evidence in + # the clique potential... 
def initial_likelihoods(self, assignments, bbn):
    """Return a table assigning likelihood 1 to every ``(variable, value)``.

    ``assignments`` maps cliques to the BBN nodes assigned to them; only the
    nodes are used. The values of a variable come from ``bbn.domains`` and
    default to ``[True, False]``.
    """
    # TODO: Since this is the same every time we should probably
    # cache it.
    table = defaultdict(dict)
    for assigned_nodes in assignments.values():
        for bbn_node in assigned_nodes:
            name = bbn_node.variable_name
            domain = bbn.domains.get(name, [True, False])
            table.update(((name, value), 1) for value in domain)
    return table
def propagate(self, starting_clique=None):
    """Run global propagation over the join tree (refer to H&D pg. 20).

    Evidence is first collected towards ``starting_clique`` and then
    distributed back out from it. Every clique is unmarked before each of the
    two passes. When no starting clique is given the first clique node is
    used.
    """
    # Step 1: choose an arbitrary clique cluster as the starting cluster.
    root = self.clique_nodes[0] if starting_clique is None else starting_clique
    logging.debug("Starting propagating messages from: %s", root.name)

    # Step 2: unmark all clusters, then collect evidence towards the root.
    for clique in self.clique_nodes:
        clique.marked = False
        logging.debug("Marking node as not visited Node: %s", clique.name)
    self.collect_evidence(sender=root)

    # Step 3: unmark all clusters again, then distribute from the root.
    for clique in self.clique_nodes:
        clique.marked = False
    self.distribute_evidence(root)
def distribute_evidence(self, sender=None, receiver=None):
    """Push messages outwards from ``sender`` through the join tree.

    ``sender`` is marked, sends a message to each of its unmarked
    neighbouring cliques, and the procedure then recurses into every
    neighbour that is still unmarked. ``receiver`` is accepted for symmetry
    with collect_evidence but is not used here.
    """
    logging.debug("Distribute evidence from: %s", sender.name)
    # Step 1, Mark X
    sender.marked = True

    # Step 2, pass a message from X to each of its unmarked neighbouring
    # clusters. Passing a message does not mark anything, so the unmarked
    # neighbours can be determined up front.
    unmarked = [c for c in sender.neighbouring_cliques if not c.marked]
    for clique in unmarked:
        logging.debug("Pass message from: %s to %s", sender.name, clique.name)
        sender.pass_message(clique)

    # Step 3, call distribute_evidence on Xs unmarked neighbours. The mark is
    # re-checked on every iteration because the recursion marks cliques.
    for clique in sender.neighbouring_cliques:
        if clique.marked:
            continue
        logging.debug(
            "Distribute evidence from: %s to %s",
            clique.name,
            sender.name,
        )
        self.distribute_evidence(sender=clique, receiver=sender)
class Clique(object):
    """A cluster of graph nodes recorded as a clique during triangulation."""

    def __init__(self, cluster):
        # The cluster is stored as given, not copied.
        self.nodes = cluster

    def __repr__(self):
        # Upper-cased variable names of the members, in sorted order.
        ordered_names = sorted(member.variable_name for member in self.nodes)
        return "Clique_%s" % "".join(name.upper() for name in ordered_names)


def transform(x, X, R):
    """Project a potential truth table entry onto another variable space.

    ``x`` holds one element per variable of ``X`` (same order). The result
    holds the elements of ``x`` that belong to the variables of ``R``, in the
    order of ``R``. For example an entry ``[True, True, False]`` over
    ``X = [A, B, C]`` projected onto ``R = [C, A]`` gives ``(False, True)``.

    Here X is the argument list of the clique set and R the argument list of
    the sepset, which implies that R is always a subset of X; a variable of R
    missing from X raises ValueError.
    """
    return tuple(x[X.index(variable)] for variable in R)
@property
def variable_names(self):
    """Return the sorted list of variable names that this clique represents."""
    return sorted(member.variable_name for member in self.clique.nodes)


@property
def neighbouring_cliques(self):
    """Return the set of neighbouring cliques; used during propagation.

    All *immediate* neighbours are sepset nodes; it is the neighbours of
    these sepsets (excluding this node itself) that are the neighbouring
    cliques.
    """
    return {
        clique_node
        for sepset_node in self.neighbours
        for clique_node in sepset_node.neighbours
        if clique_node is not self
    }


def pass_message(self, target):
    """Pass a message from this node to the ``target`` clique node.

    The message travels through the sepset node that both cliques have as a
    neighbour: this node's potential is first projected onto the sepset and
    the sepset is then absorbed into the target.

    NB: after initializing the potential truth table on the JT it may be
    possible to construct a factor graph from the JT and use factor graph
    sum-product propagation instead; in theory this should be equivalent.
    """
    # Find the sepset node between the source and target nodes.
    shared = set(self.neighbours).intersection(target.neighbours)
    sepset_node = list(shared)[0]
    logging.debug("Pass message from: %s to: %s", self.name, target.name)

    # Step 1: projection
    sepset_text = str(sepset_node)
    logging.debug("Project into the Sepset node: %s", sepset_text)
    self.project(sepset_node)

    # Step 2: absorption
    sepset_text = str(sepset_node)
    logging.debug(" Send the summed marginals to the target: %s ", sepset_text)
    self.absorb(sepset_node, target)
+ We assign a new potential tt to + the sepset which consists of the + potential of the source node + with all variables not in R marginalized. + """ + assert sepset_node in self.neighbours + # First we make a copy of the + # old potential tt + + # Now we assign a new potential tt + # to the sepset by marginalizing + # out the variables from X that are not + # in the sepset + # ToDO test and check this function + # Todo check on the old sepset potentials and when they will be set + + sepset_node.potential_tt_old = copy.deepcopy(sepset_node.potential_tt) + tt = defaultdict(float) + for k, v in self.potential_tt.items(): + entry = transform(k, self.variable_names, sepset_node.variable_names) + tt[entry] += v + sepset_node.potential_tt = tt + + def absorb(self, sepset, target): + # Assign a new potential tt to + # Y (the target) + logging.debug( + "Absorb potentails from sepset node %s into clique %s", + sepset.name, + target.name, + ) + tt = dict() + + for k, v in list(target.potential_tt.items()): + # For each entry we multiply by + # sepsets new value and divide + # by sepsets old value... + # Note that nowhere in H&D is + # division on potentials defined. + # However in Barber page 12 + # an equation implies that + # the the division is equivalent + # to the original assignment. + # For now we will assume entry-wise + # division which seems logical. 
class SepSet(object):
    """Candidate separator between two cliques X and Y of the join tree."""

    def __init__(self, X, Y):
        """X and Y are cliques represented as sets."""
        self.X = X
        self.Y = Y
        # The nodes shared by both cliques.
        self.label = list(X.nodes.intersection(Y.nodes))

    @property
    def mass(self):
        """Number of nodes shared by the two cliques."""
        return len(self.label)

    @property
    def cost(self):
        """Since cost is used as a tie-breaker and is an optimization for
        inference time we punt on it for now: all variables in X and Y are
        assumed to be binary, hence a weight of 2 per variable.
        TODO: come back to this and compute actual weights
        """
        weight_x = 2 ** len(self.X.nodes)
        weight_y = 2 ** len(self.Y.nodes)
        return weight_x + weight_y

    @staticmethod
    def _trees_holding(clique, forest):
        """Return the trees of ``forest`` having a clique node for ``clique``.

        NOTE: For efficiency we should add an index that indexes cliques
        into the trees in the forest.
        """
        found = []
        for tree in forest:
            if clique in [n.clique for n in tree.clique_nodes]:
                found.append(tree)
        return found

    def insertable(self, forest):
        """A sepset can only be inserted into the JT if the cliques it
        separates are NOT already on the same tree."""
        x_trees = self._trees_holding(self.X, forest)
        y_trees = self._trees_holding(self.Y, forest)
        assert len(x_trees) == 1
        assert len(y_trees) == 1
        return x_trees[0] is not y_trees[0]

    def insert(self, forest):
        """Insert this sepset into the forest, merging two trees into one.

        Provided the two cliques are in different trees, inserting the
        sepset collapses both trees: the sepset node is added to X's tree
        and connected to both clique nodes, the nodes of Y's tree are moved
        over, and Y's tree is removed from the forest.
        """
        x_tree = self._trees_holding(self.X, forest)[0]
        y_tree = self._trees_holding(self.Y, forest)[0]

        # Create the sepset node, insert it into the X tree and connect it
        # to X's clique node.
        ss_node = JoinTreeSepSetNode(self, self)
        x_tree.nodes.append(ss_node)
        x_clique_node = self.X.node
        x_clique_node.neighbours.append(ss_node)
        ss_node.neighbours.append(x_clique_node)

        # Keep the X tree and drop the Y tree: copy over every node of the
        # Y tree that the X tree does not hold yet.
        for node in y_tree.nodes:
            if node not in x_tree.nodes:
                x_tree.nodes.append(node)

        # Connect the sepset node to Y's clique node (now in the X tree).
        y_clique_node = self.Y.node
        y_clique_node.neighbours.append(ss_node)
        ss_node.neighbours.append(y_clique_node)

        # And finally remove the Y tree from the forest.
        forest.remove(y_tree)

    def __repr__(self):
        shared_names = [x.variable_name.upper() for x in list(self.label)]
        return "SepSet_%s" % "".join(shared_names)
def make_node_func(variable_name, conditions):
    """Build the factor function of ``variable_name`` from its conditionals.

    Args:
        variable_name: name of the variable modelled by the function.
        conditions: iterable of ``(givens, conditionals)`` pairs. ``givens``
            is an iterable of ``(parent name, parent value)`` pairs and
            ``conditionals`` maps each value of the variable to its
            probability under those givens.

    Returns:
        A function named ``"f_" + variable_name`` taking the parent values in
        alphabetical order of parent name followed by the value of the
        variable itself, and returning the matching probability (KeyError
        for a combination that is not in the table). It carries two
        attributes: ``argspec`` (list of argument names) and ``_domain``
        (set of the values of the variable).

    Raises:
        ValueError: if ``conditions`` holds no probability at all, in which
            case the argument list of the function cannot be derived.
    """
    # Convention: the arguments are firstly the parent variables in
    # alphabetical order, followed always by the child variable.
    tt = dict()
    domain = set()
    last_key = None
    for givens, conditionals in conditions:
        parent_key = [(parent_name, val) for parent_name, val in sorted(givens)]
        # For each value in the conditional probabilities add a new key.
        for value, prob in list(conditionals.items()):
            last_key = tuple(parent_key + [(variable_name, value)])
            domain.add(value)
            tt[last_key] = prob

    if last_key is None:
        raise ValueError(
            "no conditional probabilities were provided for variable "
            "'{}'".format(variable_name)
        )

    # Every key names the same arguments; read them from the last one added.
    argspec = [name for name, _ in last_key]

    def node_func(*args):
        return tt[tuple(zip(argspec, args))]

    node_func.argspec = argspec
    node_func._domain = domain
    node_func.__name__ = "f_" + variable_name
    return node_func
def make_moralized_copy(gu, dag):
    """Return a moralized deep copy of ``gu``.

    ``gu`` is an undirected graph being a copy of ``dag``. In the returned
    copy every pair of parents of a common child in ``dag`` is connected;
    ``gu`` itself is left untouched.
    """
    moral = copy.deepcopy(gu)
    by_name = {n.name: n for n in moral.nodes}
    for child in dag.nodes:
        for first, second in combinations(child.parents, 2):
            left = by_name[first.name]
            right = by_name[second.name]
            if left not in right.neighbours:
                right.neighbours.append(left)
            if right not in left.neighbours:
                left.neighbours.append(right)
    return moral


def priority_func(node):
    """Specify the rules for computing the priority of a node.
    See Harwiche and Wang pg 12.

    Returns ``[introduced arcs, weight]``. The primary key is the number of
    edges that would have to be added to fully connect the "cluster" formed
    by the node and its neighbours. The secondary key still needs fixing:
    right now it is just 2 (because mostly the variables will be discrete
    binary).
    """
    members = [node] + node.neighbours
    missing = 0
    for first, second in combinations(members, 2):
        if first in second.neighbours:
            continue
        # Neighbour lists are expected to be symmetric.
        assert second not in first.neighbours
        missing += 1
    return [missing, 2]  # TODO: Fix this to look at domains


def construct_priority_queue(nodes, priority_func=priority_func):
    """Build a heap with one ``priority + [node name]`` entry per node.

    ``nodes`` is a mapping whose values are the nodes; the name at the end of
    each entry breaks ties between equal priorities.
    """
    heap = []
    for candidate in nodes.values():
        heapq.heappush(heap, priority_func(candidate) + [candidate.name])
    return heap
+ Argument cluster must be a set""" + if any([cluster.issubset(c.nodes) for c in cliques]): + return + cliques.append(Clique(cluster)) + + +def triangulate(gm, priority_func=priority_func): + """Triangulate the moralized Graph. (in Place) + and return the cliques of the triangulated + graph as well as the elimination ordering.""" + + # First we will make a copy of gm... + gm_ = copy.deepcopy(gm) + + # Now we will construct a priority q using + # the standard library heapq module. + # See docs for example of priority q tie + # breaking. We will use a 3 element list + # with entries as follows: + # - Number of edges added if V were selected + # - Weight of V (or cluster) + # - Pointer to node in gm_ + # Note that its unclear from Huang and Darwiche + # what is meant by the "number of values of V" + gmnodes = dict([(node.name, node) for node in gm.nodes]) + elimination_ordering = [] + cliques = [] + while True: + gm_nodes = dict([(node.name, node) for node in gm_.nodes]) + if not gm_nodes: + break + pq = construct_priority_queue(gm_nodes, priority_func) + # Now we select the first node in + # the priority q and any arcs that + # should be added in order to fully connect + # the cluster should be added to both + # gm and gm_ + v = gm_nodes[pq[0][2]] + cluster = [v] + v.neighbours + for node_a, node_b in combinations(cluster, 2): + if node_a not in node_b.neighbours: + node_b.neighbours.append(node_a) + node_a.neighbours.append(node_b) + # Now also add this new arc to gm... + gmnodes[node_b.name].neighbours.append(gmnodes[node_a.name]) + gmnodes[node_a.name].neighbours.append(gmnodes[node_b.name]) + gmcluster = set([gmnodes[c.name] for c in cluster]) + record_cliques(cliques, gmcluster) + # Now we need to remove v from gm_... + # This means we also have to remove it from all + # of its neighbours that reference it... 
+ for neighbour in v.neighbours: + neighbour.neighbours.remove(v) + gm_.nodes.remove(v) + elimination_ordering.append(v.name) + return cliques, elimination_ordering + + +def build_join_tree(dag, clique_priority_func=priority_func): + # First we will create an undirected copy + # of the dag + gu = make_undirected_copy(dag) + + # Now we create a copy of the undirected graph + # and connect all pairs of parents that are + # not already parents called the 'moralized' graph. + gm = make_moralized_copy(gu, dag) + + # Now we triangulate the moralized graph... + cliques, elimination_ordering = triangulate(gm, clique_priority_func) + + # Now we initialize the forest and sepsets + # Its unclear from Darwiche Huang whether we + # track a sepset for each tree or whether its + # a global list???? + # We will implement the Join Tree as an undirected + # graph for now... + + # First initialize a set of graphs where + # each graph initially consists of just one + # node for the clique. As these graphs get + # populated with sepsets connecting them + # they should collapse into a single tree. + forest = set() + for clique in cliques: + jt_node = JoinTreeCliqueNode(clique) + # Track a reference from the clique + # itself to the node, this will be + # handy later... (alternately we + # could just collapse clique and clique + # node into one class... 
+ clique.node = jt_node + tree = JoinTree([jt_node]) + forest.add(tree) + + # Initialize the SepSets + S = set() # track the sepsets + for X, Y in combinations(cliques, 2): + if X.nodes.intersection(Y.nodes): + S.add(SepSet(X, Y)) + sepsets_inserted = 0 + while sepsets_inserted < (len(cliques) - 1): + # Adding in name to make this sort deterministic + deco = [(s, -1 * s.mass, s.cost, s.__repr__()) for s in S] + deco.sort(key=lambda x: x[1:]) + candidate_sepset = deco[0][0] + for candidate_sepset, _, _, _ in deco: + if candidate_sepset.insertable(forest): + # Insert into forest and remove the sepset + candidate_sepset.insert(forest) + S.remove(candidate_sepset) + sepsets_inserted += 1 + break + + assert len(forest) == 1 + jt = list(forest)[0] + return jt diff --git a/causalnex/ebaybbn/exceptions.py b/causalnex/ebaybbn/exceptions.py new file mode 100644 index 0000000..b515d3a --- /dev/null +++ b/causalnex/ebaybbn/exceptions.py @@ -0,0 +1,100 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# The methods found in this file are adapted from a repository under Apache 2.0: +# eBay's Pythonic Bayesian Belief Network Framework. +# @online{ +# author = {Neville Newey,Anzar Afaq}, +# title = {bayesian-belief-networks}, +# organisation = {eBay}, +# codebase = {https://github.com/eBay/bayesian-belief-networks}, +# } +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. 
# IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#


class InvalidGraphException(Exception):
    """Raised if the graph verification method fails."""


class InvalidSampleException(Exception):
    """Should be raised if a sample is invalid."""


class InvalidInferenceMethod(Exception):
    """Raise if the user tries to set the inference method to an unknown string."""


class InsufficientSamplesException(Exception):
    """Raised when the inference method is 'sample_db' and there are
    less pre-generated samples than the graphs n_samples attribute."""


class NoSamplesInDB(Warning):
    """Warning category (a ``Warning`` subclass, not an ``Exception``).

    NOTE(review): presumably signals an empty sample database; its use is
    not visible here - confirm against the caller.
    """


class VariableNotInGraphError(Exception):
    """Exception raised when a graph is queried with a variable that is
    not part of the graph."""


class VariableValueNotInDomainError(Exception):
    """Raised when a BBN is queried with a value for a variable that is
    not within that variables domain."""


class IncorrectInferenceMethodError(Exception):
    """Raise when attempt is made to generate samples when the inference
    method is not 'sample_db'."""


# diff --git a/causalnex/ebaybbn/graph.py b/causalnex/ebaybbn/graph.py
# new file mode 100644
# index 0000000..df7ac01
# --- /dev/null
# +++ b/causalnex/ebaybbn/graph.py
# @@ -0,0 +1,74 @@
# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
#
# The methods found in this file are adapted from a repository under Apache 2.0:
# eBay's Pythonic Bayesian Belief Network Framework.
# @online{
# author = {Neville Newey,Anzar Afaq},
# title = {bayesian-belief-networks},
# organisation = {eBay},
# codebase = {https://github.com/eBay/bayesian-belief-networks},
# }
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks.
# You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Generic Graph Classes"""


class Node(object):
    """A node of a directed graph.

    Args:
        name: identifier of the node.
        parents: initial list of parent nodes.
        children: initial list of child nodes.
    """

    def __init__(self, name, parents=[], children=[]):
        self.name = name
        # The slices copy the arguments, so neither the shared default
        # lists nor the caller's lists are ever mutated through the node.
        self.parents = parents[:]
        self.children = children[:]

    def __repr__(self):
        # Wrap the name in a 1-tuple: a bare tuple name would otherwise be
        # unpacked by ``%`` as the argument list.
        return "<Node %s>" % (self.name,)


class UndirectedNode(object):
    """A node of an undirected graph.

    Args:
        name: identifier of the node.
        neighbours: initial list of adjacent nodes.
    """

    def __init__(self, name, neighbours=[]):
        self.name = name
        # Copy, for the same reason as in ``Node.__init__``.
        self.neighbours = neighbours[:]

    def __repr__(self):
        return "<UndirectedNode %s>" % (self.name,)


class UndirectedGraph(object):
    """Container for the nodes of an undirected graph.

    The ``nodes`` object is stored as given (it is not copied).
    """

    def __init__(self, nodes, name=None):
        self.nodes = nodes
        self.name = name


def connect(parent, child):
    """
    Make an edge between a parent node and a child node.

    Args:
        parent: node that gets ``child`` appended to its ``children``.
        child: node that gets ``parent`` appended to its ``parents``.
    """
    parent.children.append(child)
    child.parents.append(parent)


# diff --git a/causalnex/ebaybbn/utils.py b/causalnex/ebaybbn/utils.py
# new file mode 100644
# index 0000000..16265f0
# --- /dev/null
# +++ b/causalnex/ebaybbn/utils.py
# @@ -0,0 +1,94 @@
# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
#
# The methods found in this file are adapted from a repository under Apache 2.0:
# eBay's Pythonic Bayesian Belief Network Framework.
# @online{
# author = {Neville Newey,Anzar Afaq},
# title = {bayesian-belief-networks},
# organisation = {eBay},
# codebase = {https://github.com/eBay/bayesian-belief-networks},
# }
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Some Useful Helper Functions"""
import inspect

# TODO: Find a better location for get_args


def get_args(func):
    """
    Return the names of the arguments of a function as a list of strings.

    This is so that we can omit certain variables when we marginalize.
    Note that functions created by make_product_func do not return an
    argspec, so an ``argspec`` attribute is added to them at creation
    time; when that attribute is present it is returned as is.
    """

    if hasattr(func, "argspec"):
        return func.argspec
    # return inspect.getargspec(func).args
    return list(inspect.signature(func).parameters)


def make_key(*args):
    """Handy for short truth table keys.

    Concatenates the first character of the lower-cased ``str()`` of each
    argument.

    Raises:
        ValueError: if an argument has a ``value`` attribute.
        IndexError: if ``str()`` of an argument is empty.
    """
    key = ""
    for a in args:
        if hasattr(a, "value"):
            raise ValueError("Unexpected type")
        key += str(a).lower()[0]
    return key


def get_original_factors(factors):
    """
    Map every variable to the factor which first introduces it.

    To do this without enforcing a special naming convention such as
    'f_' for factors, or a special ordering, such as the last argument is
    always the new variable, the 'original' factor that introduces each
    variable is discovered iteratively: a factor with exactly one
    argument not accounted for yet is the one introducing that argument.

    Args:
        factors: iterable of factor functions (see ``get_args``).

    Returns:
        dict of variable name -> factor.

    Raises:
        ValueError: if a full pass over the factors cannot attribute any
            new variable while some factors are still unaccounted for.
    """
    original_factors = dict()
    while len(original_factors) < len(factors):
        resolved_before = len(original_factors)
        for factor in factors:
            args = get_args(factor)
            unaccounted_args = [a for a in args if a not in original_factors]
            if len(unaccounted_args) == 1:
                original_factors[unaccounted_args[0]] = factor
        if len(original_factors) == resolved_before:
            # No progress in this pass means no progress in any later
            # pass either: fail instead of looping forever.
            raise ValueError(
                "Cannot determine the original factor of every variable: "
                "resolved {0} of {1} factors".format(
                    len(original_factors), len(factors)
                )
            )
    return original_factors


# diff --git a/causalnex/evaluation/__init__.py b/causalnex/evaluation/__init__.py
# new file mode 100644
# index 0000000..e450aea
# --- /dev/null
# +++ b/causalnex/evaluation/__init__.py
# @@ -0,0 +1,37 @@
# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT.
IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``causalnex.evaluation`` provides functionality to evaluate causal models using standard metrics. +""" + +__version__ = "0.4.0" + +__all__ = ["roc_auc", "classification_report"] + +from .evaluation import classification_report, roc_auc diff --git a/causalnex/evaluation/evaluation.py b/causalnex/evaluation/evaluation.py new file mode 100644 index 0000000..2636c0b --- /dev/null +++ b/causalnex/evaluation/evaluation.py @@ -0,0 +1,207 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. 
# IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation metrics for causal models."""

from typing import List, Tuple

import pandas as pd
from sklearn import metrics

from causalnex.network import BayesianNetwork


def _build_ground_truth(
    bn: BayesianNetwork, data: pd.DataFrame, node: str
) -> pd.DataFrame:
    """
    One-hot encode the observed values of ``node``.

    Args:
        bn (BayesianNetwork): model whose ``node_states[node]`` lists the states of the node.
        data (pd.DataFrame): data holding a column named ``node``.
        node (str): name of the variable to encode.

    Returns:
        One column per state, in sorted column order.
    """

    ground_truth = pd.get_dummies(data[node])

    # it's possible that not all states are present in the test set, so we need to add them to ground truth
    for dummy in bn.node_states[node]:
        if dummy not in ground_truth.columns:
            ground_truth[dummy] = 0

    # update ground truth column names to be correct, since we may have added missing columns
    return ground_truth[sorted(ground_truth.columns)]


def roc_auc(
    bn: BayesianNetwork, data: pd.DataFrame, node: str
) -> Tuple[List[Tuple[float, float]], float]:
    """
    Build a report of the micro-average Receiver-Operating Characteristics (ROC), and the Area Under the ROC curve
    Micro-average computes roc_auc over all predictions for all states of node.

    Args:
        bn (BayesianNetwork): model to compute roc_auc.
        data (pd.DataFrame): test data that will be used to calculate ROC.
        node (str): name of the variable to generate the report for.

    Returns:
        roc - auc tuple
         - roc (List[Tuple[float, float]]): list of [(fpr, tpr)] observations.
         - auc float: auc for the node predictions.

    Example:
    ::
        >>> from causalnex.structure import StructureModel
        >>> from causalnex.network import BayesianNetwork
        >>>
        >>> sm = StructureModel()
        >>> sm.add_edges_from([
        >>>                    ('rush_hour', 'traffic'),
        >>>                    ('weather', 'traffic')
        >>>                    ])
        >>> bn = BayesianNetwork(sm)
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        >>>                      'rush_hour': [True, False, False, False, True, False, True],
        >>>                      'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'],
        >>>                      'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy']
        >>>                      })
        >>> bn = bn.fit_node_states_and_cpds(data)
        >>> test_data = pd.DataFrame({
        >>>                         'rush_hour': [False, False, True, True],
        >>>                         'weather': ['Good', 'Bad', 'Good', 'Bad'],
        >>>                         'traffic': ['light', 'heavy', 'heavy', 'light']
        >>>                         })
        >>> from causalnex.evaluation import roc_auc
        >>> roc, auc = roc_auc(bn, test_data, "traffic")
        >>> print(auc)
        0.75
    """

    ground_truth = _build_ground_truth(bn, data, node)
    predictions = bn.predict_probability(data, node)

    # update column names to match those of ground_truth
    # (presumably the prediction columns are named "<node>_<state>" - confirm
    # against BayesianNetwork.predict_probability).
    # Only the exact "<node>_" prefix is removed: str.lstrip() treats its
    # argument as a set of characters and would also eat the leading
    # characters of a state name made of those characters.
    prefix = node + "_"

    def _state_label(column):
        if isinstance(column, str) and column.startswith(prefix):
            return column[len(prefix) :]
        return column

    predictions = predictions.rename(columns=_state_label)
    predictions = predictions[sorted(predictions.columns)]

    fpr, tpr, _ = metrics.roc_curve(
        ground_truth.values.ravel(), predictions.values.ravel()
    )
    roc = list(zip(fpr, tpr))
    auc = metrics.auc(fpr, tpr)

    return roc, auc


def classification_report(
    bn: BayesianNetwork, data: pd.DataFrame, node: str
) -> pd.DataFrame:
    """
    Build a report showing the main classification metrics.

    Args:
        bn (BayesianNetwork): model to compute classification report using.
        data (pd.DataFrame): test data that will be used for predictions.
        node (str): name of the variable to generate report for.

    Returns:
        DataFrame summary of the precision, recall, F1 score for each class.

        The reported averages include micro average (averaging the
        total true positives, false negatives and false positives), macro
        average (averaging the unweighted mean per label), weighted average
        (averaging the support-weighted mean per label) and sample average
        (only for multilabel classification).

        Note that in binary classification, recall of the positive class
        is also known as "sensitivity"; recall of the negative class is
        "specificity".

    Example:
    ::
        >>> from causalnex.structure import StructureModel
        >>> from causalnex.network import BayesianNetwork
        >>>
        >>> sm = StructureModel()
        >>> sm.add_edges_from([
        >>>                    ('rush_hour', 'traffic'),
        >>>                    ('weather', 'traffic')
        >>>                    ])
        >>> bn = BayesianNetwork(sm)
        >>> import pandas as pd
        >>> data = pd.DataFrame({
        >>>                      'rush_hour': [True, False, False, False, True, False, True],
        >>>                      'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'],
        >>>                      'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy']
        >>>                      })
        >>> bn = bn.fit_node_states_and_cpds(data)
        >>> test_data = pd.DataFrame({
        >>>                         'rush_hour': [False, False, True, True],
        >>>                         'weather': ['Good', 'Bad', 'Good', 'Bad'],
        >>>                         'traffic': ['light', 'heavy', 'heavy', 'light']
        >>>                         })
        >>> from causalnex.evaluation import classification_report
        >>> classification_report(bn, test_data, "traffic").to_dict()
        {'precision': {
            'macro avg': 0.8333333333333333,
            'micro avg': 0.75,
            'traffic_heavy': 0.6666666666666666,
            'traffic_light': 1.0,
            'weighted avg': 0.8333333333333333
          },
         'recall': {
            'macro avg': 0.75,
            'micro avg': 0.75,
            'traffic_heavy': 1.0,
            'traffic_light': 0.5,
            'weighted avg': 0.75
          },
         'f1-score': {
            'macro avg': 0.7333333333333334,
            'micro avg': 0.75,
            'traffic_heavy': 0.8,
            'traffic_light': 0.6666666666666666,
            'weighted avg': 0.7333333333333334
          },
         'support': {
            'macro avg': 4,
            'micro avg': 4,
            'traffic_heavy': 2,
            'traffic_light': 2,
            'weighted avg': 4
          }}
    """

    predictions = bn.predict(data, node)

    # labels and target_names are both derived from the same sorted states,
    # so the i-th display name belongs to the i-th label.
    labels = sorted(bn.node_states[node])
    target_names = ["{0}_{1}".format(node, str(v)) for v in labels]
    report = metrics.classification_report(
        y_true=data[node],
        y_pred=predictions,
        labels=labels,
        target_names=target_names,
        output_dict=True,
    )

    return pd.DataFrame.from_dict(report, orient="index")


# diff --git a/causalnex/inference/__init__.py b/causalnex/inference/__init__.py
# new file mode 100644
# index 0000000..4628b6e
# --- /dev/null
# +++ b/causalnex/inference/__init__.py
# @@ -0,0 +1,37 @@
# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
# (either separately or in combination, "QuantumBlack Trademarks") are
# trademarks of QuantumBlack. The License does not grant you any right or
# license to the QuantumBlack Trademarks.
You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``causalnex.inference`` provides functionality to make inferences based on interventions and observations. +""" + +__version__ = "0.4.0" + +__all__ = ["InferenceEngine"] + +from .inference import InferenceEngine diff --git a/causalnex/inference/inference.py b/causalnex/inference/inference.py new file mode 100644 index 0000000..352dda4 --- /dev/null +++ b/causalnex/inference/inference.py @@ -0,0 +1,333 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. 
# You may not use the QuantumBlack
# Trademarks or any confusingly similar mark as a trademark for your product,
# or use the QuantumBlack Trademarks in any other manner that might cause
# confusion in the marketplace, including but not limited to in advertising,
# on websites, or on software.
#
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This module contains the implementation of ``InferenceEngine``.

``InferenceEngine`` provides tools to make inferences based on interventions and observations.
"""

import copy
import inspect
import math
import re
import types
from typing import Callable, Dict, Hashable, Tuple, Union

import pandas as pd

from causalnex.ebaybbn import build_bbn
from causalnex.network import BayesianNetwork


class InferenceEngine:
    """
    An ``InferenceEngine`` provides methods to query marginals based on observations and
    make interventions (Do-Calculus) on a ``BayesianNetwork``.

    Example:
    ::
        >>> # Create a Bayesian Network with a manually defined DAG
        >>> from causalnex.structure.structuremodel import StructureModel
        >>> from causalnex.network import BayesianNetwork
        >>> from causalnex.inference import InferenceEngine
        >>>
        >>> sm = StructureModel()
        >>> sm.add_edges_from([
        >>>                    ('rush_hour', 'traffic'),
        >>>                    ('weather', 'traffic')
        >>>                    ])
        >>> data = pd.DataFrame({
        >>>                      'rush_hour': [True, False, False, False, True, False, True],
        >>>                      'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'],
        >>>                      'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy']
        >>>                      })
        >>> bn = BayesianNetwork(sm)
        >>> # Inference can only be performed on the `BayesianNetwork` with learned nodes states and CPDs
        >>> bn = bn.fit_node_states_and_cpds(data)
        >>>
        >>> # Create an `InferenceEngine` to query marginals and make interventions
        >>> ie = InferenceEngine(bn)
        >>> # Query the marginals as learned from data
        >>> ie.query()['traffic']
        {'heavy': 0.7142857142857142, 'light': 0.2857142857142857}
        >>> # Query the marginals given observations
        >>> ie.query({'rush_hour': True, 'weather': 'Terrible'})['traffic']
        {'heavy': 1.0, 'light': 0.0}
        >>> # Make an intervention on the `BayesianNetwork`
        >>> ie.do_intervention('rush_hour', False)
        >>> # Query marginals on the intervened `BayesianNetwork`
        >>> ie.query()['traffic']
        {'heavy': 0.5, 'light': 0.5}
        >>> # Reset interventions
        >>> ie.reset_do('rush_hour')
        >>> ie.query()['traffic']
        {'heavy': 0.7142857142857142, 'light': 0.2857142857142857}
    """

    def __init__(self, bn: BayesianNetwork):
        """
        Create a new ``InferenceEngine`` from an existing ``BayesianNetwork``.

        It is expected that structure and probability distribution has already been learned
        for the ``BayesianNetwork`` that is to be used for inference.
        This Bayesian Network cannot contain any isolated nodes.

        Args:
            bn: Bayesian Network that inference will act on.

        Raises:
            ValueError: if the Bayesian Network contains isolates, or if a variable name is invalid,
                        or if the CPDs have not been learned yet.
        """

        bad_nodes = [node for node in bn.nodes if not re.match("^[0-9a-zA-Z_]+$", node)]
        if bad_nodes:
            raise ValueError(
                "Variable names must match ^[0-9a-zA-Z_]+$ - please fix the "
                "following nodes: {0}".format(bad_nodes)
            )

        if not bn.cpds:
            raise ValueError(
                "Bayesian Network does not contain any CPDs. You should fit CPDs "
                "before doing inference (see `BayesianNetwork.fit_cpds`)."
            )

        self._cpds = None

        self._create_cpds_dict_bn(bn)
        self._generate_domains_bn(bn)
        self._generate_bbn()

    def query(
        self, observations: Dict[str, Hashable] = None
    ) -> Dict[str, Dict[Hashable, float]]:
        """
        Query the ``BayesianNetwork`` for marginals given some observations.

        Args:
            observations: observed states of nodes in the Bayesian Network.
                          For instance, query({"node_a": 1, "node_b": 3})
                          If None or {}, the marginals for all nodes in the ``BayesianNetwork`` are returned.

        Returns:
            A dictionary of marginal probabilities of the network.
            For instance, :math:`P(a=1) = 0.3, P(a=2) = 0.7` -> {a: {1: 0.3, 2: 0.7}}
        """
        bbn_results = (
            self._bbn.query(**observations) if observations else self._bbn.query()
        )

        results = {node: dict() for node in self._cpds}
        for (node, state), prob in bbn_results.items():
            results[node][state] = prob

        return results

    def _do(self, observation: str, state: Dict[Hashable, float]) -> None:
        """
        Makes an intervention on the Bayesian Network.

        Args:
            observation: observation that the intervention is on.
            state: mapping of state -> probability.

        Raises:
            ValueError: if states do not match original states of the node, or probabilities do not sum to 1.
        """

        # Compare with a tolerance: an exact float comparison rejects valid
        # distributions whose sum is off by a rounding error.
        if not math.isclose(sum(state.values()), 1.0, abs_tol=1e-9):
            raise ValueError("The cpd for the provided observation must sum to 1")

        if not set(state.keys()) == set(self._cpds_original[observation]):
            raise ValueError(
                "The cpd states do not match expected states: expected {expected}, found {found}".format(
                    expected=set(self._cpds_original[observation]),
                    found=set(state.keys()),
                )
            )

        # The intervened node no longer depends on its parents, hence the
        # empty condition tuple.
        self._cpds[observation] = {s: {(): p} for s, p in state.items()}

    def do_intervention(
        self, node: str, state: Union[Hashable, Dict[Hashable, float]] = None
    ) -> None:
        """
        Make an intervention on the Bayesian Network.

        For instance,
            `do_intervention('X', 'x')` will set :math:`P(X=x)` to 1, and :math:`P(X=y)` to 0
            `do_intervention('X', {'x': 0.2, 'y': 0.8})` will set :math:`P(X=x)` to 0.2, and :math:`P(X=y)` to 0.8

        Args:
            node: the node that the intervention acts upon.
            state: state to update node to.
                - if Hashable: the intervention updates the state to 1, and all other states to 0;
                - if Dict[Hashable, float]: update states to all state -> probability in the dict.

        Raises:
            ValueError: if performing intervention would create an isolated node, if the states
                        do not match the original states of the node, or if the probabilities
                        do not sum to 1.
        """
        # The first argument of every node function is the node itself and
        # the remaining ones are its parents, so this checks that ``node``
        # is a parent of at least one node.
        is_parent_of_some_node = any(
            node in inspect.getargs(f.__code__)[0][1:]
            for f in self._node_functions.values()
        )
        if not is_parent_of_some_node:
            raise ValueError(
                "Do calculus cannot be applied because it would result in an isolate"
            )

        # A dict is never hashable, so anything that is not a dict is the
        # single state to force (the states need not be integers).
        if not isinstance(state, dict):
            state = {s: float(s == state) for s in self._cpds[node]}

        self._do(node, state)
        self._generate_bbn()

    def reset_do(self, observation: str) -> None:
        """
        Resets any do_interventions that have been applied to the observation.

        Args:
            observation: observation that will be reset.
        """

        self._cpds[observation] = self._cpds_original[observation]
        self._generate_bbn()

    def _generate_bbn(self):
        """Re-create the _bbn."""
        self._node_functions = self._create_node_functions()

        self._bbn = build_bbn(
            list(self._node_functions.values()), domains=self._domains
        )

    def _generate_domains_bn(self, bn):
        """Store, for every variable, the list of index values of its CPD as its domain."""

        self._domains = {
            variable: list(cpd.index.values) for variable, cpd in bn.cpds.items()
        }

    def _create_cpds_dict_bn(self, bn: BayesianNetwork) -> None:
        """
        Map CPDs in the ``BayesianNetwork`` to required format:

            >>> {"observation":
            >>>     {"state":
            >>>         {(("condition1_observation", "condition1_state"), ("conditionN_observation", "conditionN_state")):
            >>>             "probability"
            >>>         }
            >>>     }
            >>> }

        For example, :math:`P( Colour=red | Make=fender, Model=stratocaster) = 0.4`:
            >>> {"colour":
            >>>     {"red":
            >>>         {(("make", "fender"), ("model", "stratocaster")):
            >>>             0.4
            >>>         }
            >>>     }
            >>> }

        A deep copy of the mapping is kept in ``_cpds_original`` so that interventions can be reset.
        """

        lookup = {
            variable: {
                state: {
                    tuple(zip(cpd.columns.names, parent_value)): cpd.loc[state][
                        parent_value
                    ]
                    for parent_value in pd.MultiIndex.from_frame(cpd).names
                }
                for state in cpd.index.values
            }
            for variable, cpd in bn.cpds.items()
        }

        self._cpds = lookup
        self._cpds_original = copy.deepcopy(self._cpds)

    def _create_node_function(self, name: str, args: Tuple[str]):
        """
        Creates a new function that describes a node in the ``BayesianNetwork``.

        Args:
            name: name given to the created function.
            args: argument names of the created function; the first one is the
                  node itself, the remaining ones are its condition nodes.

        Returns:
            The function, looking its probability up in ``self._cpds``.
        """

        def template() -> float:
            """Template node function."""
            # use inspection to determine arguments to the function
            # initially there are none present, but caller will add appropriate arguments to the function
            # getargvalues was "inadvertently marked as deprecated in Python 3.5"
            # https://docs.python.org/3/library/inspect.html#inspect.getfullargspec
            arg_spec = inspect.getargvalues(  # pylint: disable=deprecated-method
                inspect.currentframe()
            )

            return self._cpds[arg_spec.args[0]][  # target name
                arg_spec.locals[arg_spec.args[0]]
            ][  # target state
                tuple([(arg, arg_spec.locals[arg]) for arg in arg_spec.args[1:]])
            ]  # conditions

        code = template.__code__
        if hasattr(code, "replace"):
            # Python 3.8+: the positional CodeType constructor gained extra
            # parameters, so the old positional call no longer lines up.
            # replace() changes exactly the fields the old call changed.
            # NOTE(review): the argument-injection trick itself relies on
            # interpreter internals - confirm on every supported version.
            template.__code__ = code.replace(
                co_argcount=len(args),
                co_nlocals=len(args),
                co_varnames=args,
                co_name=name,
            )
        else:
            template.__code__ = types.CodeType(
                len(args),
                code.co_kwonlyargcount,
                len(args),
                code.co_stacksize,
                code.co_flags,
                code.co_code,
                code.co_consts,
                code.co_names,
                args,
                code.co_filename,
                name,
                code.co_firstlineno,
                code.co_lnotab,
                code.co_freevars,
                code.co_cellvars,
            )
        template.__name__ = name

        return template

    def _create_node_functions(self) -> Dict[str, Callable]:
        """Creates all functions required to create a ``BayesianNetwork``."""

        node_functions = dict()

        for node, states in self._cpds.items():
            # since we only need condition names, which are consistent across all states,
            # then we can inspect the 0th element
            states_conditions = list(states.values())[0]

            # take any state, and get its conditions
            state_conditions = list(states_conditions.items())[0]
            condition_nodes = [n for n, v in state_conditions[0]]

            node_args = tuple([node] + condition_nodes)  # type: Tuple[str]
            function_name = "f_{node}".format(node=node)
            node_function = self._create_node_function(function_name, node_args)
            node_functions[node] = node_function

        return node_functions


# diff --git a/causalnex/network/__init__.py b/causalnex/network/__init__.py
# new file mode 100644
# index 0000000..bcd9ff2
# ---
/dev/null +++ b/causalnex/network/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``causalnex.network`` provides functionality to learn joint probability distribution of networks. 
+""" + +__version__ = "0.4.0" + +__all__ = ["BayesianNetwork"] + +from .network import BayesianNetwork diff --git a/causalnex/network/network.py b/causalnex/network/network.py new file mode 100644 index 0000000..84fe5fa --- /dev/null +++ b/causalnex/network/network.py @@ -0,0 +1,572 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the implementation of ``BayesianNetwork``. 
+ +``BayesianNetwork`` is a class that represents a probabilistic, weighted, directed acyclic graph (DAG) +describing causal relationships between variables and their distribution in a factorised way. +""" + +import re +from typing import Dict, Hashable, List, Set, Tuple + +import networkx as nx +import pandas as pd +from pgmpy.estimators import BayesianEstimator, MaximumLikelihoodEstimator +from pgmpy.models import BayesianModel + +from causalnex.structure import StructureModel + + +class BayesianNetwork: + """ + Base class for Bayesian Network (BN), a probabilistic weighted DAG where nodes represent variables, + edges represent the causal relationships between variables. + + ``BayesianNetwork`` stores nodes with their possible states, edges and + conditional probability distributions (CPDs) of each node. + + ``BayesianNetwork`` is built on top of the ``StructureModel``, which is an extension of ``networkx.DiGraph`` + (see :func:`causalnex.structure.structuremodel.StructureModel`). + + In order to define the ``BayesianNetwork``, users should provide a relevant ``StructureModel``. + Once ``BayesianNetwork`` is initialised, no changes to the ``StructureModel`` can be made + and CPDs can be learned from the data. + + The learned CPDs can be then used for likelihood estimation and predictions. + + Example: + :: + >>> # Create a Bayesian Network with a manually defined DAG. 
+ >>> from causalnex.structure import StructureModel + >>> from causalnex.network import BayesianNetwork + >>> + >>> sm = StructureModel() + >>> sm.add_edges_from([ + >>> ('rush_hour', 'traffic'), + >>> ('weather', 'traffic') + >>> ]) + >>> bn = BayesianNetwork(sm) + >>> # A created ``BayesianNetwork`` stores nodes and edges defined by the ``StructureModel`` + >>> bn.nodes + ['rush_hour', 'traffic', 'weather'] + >>> + >>> bn.edges + [('rush_hour', 'traffic'), ('weather', 'traffic')] + >>> # A ``BayesianNetwork`` doesn't store any CPDs yet + >>> bn.cpds + >>> {} + >>> + >>> # Learn the nodes' states from the data + >>> import pandas as pd + >>> data = pd.DataFrame({ + >>> 'rush_hour': [True, False, False, False, True, False, True], + >>> 'weather': ['Terrible', 'Good', 'Bad', 'Good', 'Bad', 'Bad', 'Good'], + >>> 'traffic': ['heavy', 'light', 'heavy', 'light', 'heavy', 'heavy', 'heavy'] + >>> }) + >>> bn = bn.fit_node_states(data) + >>> bn.node_states + {'rush_hour': {False, True}, 'weather': {'Bad', 'Good', 'Terrible'}, 'traffic': {'heavy', 'light'}} + >>> # Learn the CPDs from the data + >>> bn = bn.fit_cpds(data) + >>> # Use the learned CPDs to make predictions on the unseen data + >>> test_data = pd.DataFrame({ + >>> 'rush_hour': [False, False, True, True], + >>> 'weather': ['Good', 'Bad', 'Good', 'Bad'] + >>> }) + >>> bn.predict(test_data, "traffic").to_dict() + >>> {'traffic_prediction': {0: 'light', 1: 'heavy', 2: 'heavy', 3: 'heavy'}} + >>> bn.predict_probability(test_data, "traffic").to_dict() + {'traffic_prediction': {0: 'light', 1: 'heavy', 2: 'heavy', 3: 'heavy'}} + {'traffic_light': {0: 0.75, 1: 0.25, 2: 0.3333333333333333, 3: 0.3333333333333333}, + 'traffic_heavy': {0: 0.25, 1: 0.75, 2: 0.6666666666666666, 3: 0.6666666666666666}} + """ + + def __init__(self, structure: StructureModel): + """ + Create a ``BayesianNetwork`` with a DAG defined by ``StructureModel``. + + Args: + structure: a graph representing a causal relationship between variables. 
+ In the structure + - cycles are not allowed; + - multiple (parallel) edges are not allowed; + - isolated nodes and multiple components are not allowed. + + Raises: + ValueError: If the structure is not a connected DAG. + """ + n_components = nx.number_weakly_connected_components(structure) + + if n_components > 1: + raise ValueError( + "The given structure has {n_components} separated graph components. " + "Please make sure it has only one.".format(n_components=n_components) + ) + + if not nx.is_directed_acyclic_graph(structure): + cycle = nx.find_cycle(structure) + raise ValueError( + "The given structure is not acyclic. Please review the following cycle: {cycle}".format( + cycle=cycle + ) + ) + + # _node_states is a Dict in the form `dict: {node: dict: {state: index}}`. + # Underlying libraries expect all states to be integers from zero, and + # thus this dict is used to convert from state -> idx, and then back from idx -> state as required + self._node_states = None # type: Dict[str: Dict[Hashable, int]] + self._structure = structure + + # _model is a pgmpy Bayesian Model. + # It is used for: + # - probability fitting + # - predictions + self._model = BayesianModel() + self._model.add_edges_from(structure.edges) + + @property + def structure(self) -> StructureModel: + """ + ``StructureModel`` defining the DAG of the Bayesian Network. + + Returns: + A ``StructureModel`` of the Bayesian Network. + """ + return self._structure + + @property + def nodes(self) -> List[str]: + """ + List of all nodes contained within the Bayesian Network. + + Returns: + A list of node names. + """ + return list(self._model.nodes) + + @property + def node_states(self) -> Dict[str, Set[Hashable]]: + """ + Dictionary of all states that each node can take. + + Returns: + A dictionary of node and its possible states, in format of `dict: {node: state}`. 
+ """ + return {node: set(states.keys()) for node, states in self._node_states.items()} + + @node_states.setter + def node_states(self, nodes: Dict[str, Set[Hashable]]): + """ + Set the list of nodes that are contained within the Bayesian Network. + The states of all nodes must be provided. + + Args: + nodes: A dictionary of node and its possible states, in format of `dict: {node: state}`. + + Raises: + ValueError: if a node contains a None state. + KeyError: if a node is missing. + """ + missing_feature = set(self.nodes).difference(set(nodes.keys())) + if missing_feature: + raise KeyError( + "The data does not cover all the features found in the Bayesian Network. " + "Please check the following features: {nodes}".format( + nodes=missing_feature + ) + ) + + for node, states in nodes.items(): + if any(pd.isnull(list(states))): + raise ValueError("node '{node}' contains None state".format(node=node)) + self._node_states = { + n: {v: k for k, v in enumerate(sorted(nodes[n]))} for n in nodes + } + + @property + def edges(self) -> List[Tuple[str, str]]: + """ + List of all edges contained within the Bayesian Network, as a Tuple(from_node, to_node). + + Returns: + A list of all edges. + """ + return list(self._model.edges) + + @property + def cpds(self) -> Dict[str, pd.DataFrame]: + """ + Conditional Probability Distributions of each node within the Bayesian Network. + + The row-index of each dataframe is all possible states for the node. + The col-index of each dataframe is a MultiIndex that describes all possible permutations of parent states. + + For example, for a node :math:`P(A | B, D)`, where + .. 
math:: + - A \\in \\text{{"a", "b", "c", "d"}} + - B \\in \\text{{"x", "y", "z"}} + - C \\in \\text{{False, True}} + + >>> b x y z + >>> d False True False True False True + >>> a + >>> a 0.265306 0.214286 0.066667 0.25 0.444444 0.000000 + >>> b 0.183673 0.214286 0.200000 0.25 0.222222 0.666667 + >>> c 0.285714 0.285714 0.400000 0.25 0.333333 0.333333 + >>> d 0.265306 0.285714 0.333333 0.25 0.000000 0.000000 + + Returns: + Conditional Probability Distributions of each node within the Bayesian Network. + """ + cpds = dict() + for cpd in self._model.cpds: + + iterables = [ + sorted(self._node_states[var].keys()) for var in cpd.variables[1:] + ] + cols = [""] + if iterables: + cols = pd.MultiIndex.from_product(iterables, names=cpd.variables[1:]) + + cpds[cpd.variable] = pd.DataFrame( + cpd.values.reshape( + len(self._node_states[cpd.variable]), max(1, len(cols)) + ) + ) + cpds[cpd.variable][cpd.variable] = sorted( + self._node_states[cpd.variable].keys() + ) + cpds[cpd.variable].set_index([cpd.variable], inplace=True) + cpds[cpd.variable].columns = cols + + return cpds + + def fit_node_states(self, df: pd.DataFrame) -> "BayesianNetwork": + """ + Fit all states of nodes that can appear in the data. + The dataframe provided should contain every possible state (values that can be taken) for every column. + + Args: + df: data to fit node states from. Each column indicates a node and each row + an observed combination of states. + + Returns: + self + + Raises: + ValueError: if dataframe contains any missing data. + """ + self.node_states = {c: set(df[c].unique()) for c in df.columns} + + return self + + def _state_to_index( + self, df: pd.DataFrame, nodes: List[str] = None + ) -> pd.DataFrame: + """ + Transforms all values in df to an integer, as defined by the mapping from fit_node_states. + + Args: + df: data to transform + nodes: list of nodes to map to index. None means all. + + Returns: + The transformed dataframe. 
+ + Raises: + ValueError: if nodes have not been fit, or if column names do not match node names. + """ + + df.is_copy = False + cols = nodes if nodes else df.columns + for col in cols: + df[col] = df[col].map(self._node_states[col]) + df.is_copy = True + return df + + def fit_cpds( + self, + data: pd.DataFrame, + method: str = "MaximumLikelihoodEstimator", + bayes_prior: str = None, + equivalent_sample_size: int = None, + ) -> "BayesianNetwork": + """ + Learn conditional probability distributions for all nodes in the Bayesian Network, conditioned on + their incoming edges (parents). + + Args: + data: dataframe containing one column per node in the Bayesian Network. + method: how to fit probabilities. One of: + - "MaximumLikelihoodEstimator": fit probabilities using Maximum Likelihood Estimation; + - "BayesianEstimator": fit probabilities using Bayesian Parameter Estimation. Use bayes_prior. + bayes_prior: how to construct the Bayesian prior used by method="BayesianEstimator". One of: + - "K2": shorthand for dirichlet where all pseudo_counts are 1 + regardless of variable cardinality; + - "BDeu": equivalent of using Dirichlet and using uniform 'pseudo_counts' of + `equivalent_sample_size / (node_cardinality * np.prod(parents_cardinalities))` + for each node. Use equivelant_sample_size. + equivalent_sample_size: used by BDeu bayes_prior to compute pseudo_counts. + + Returns: + self + + Raises: + ValueError: if an invalid method or bayes_prior is specified. 
+ + """ + + transformed_data = data.copy(deep=True) # type: pd.DataFrame + transformed_data = self._state_to_index(transformed_data[self.nodes]) + + if method == "MaximumLikelihoodEstimator": + self._model.fit(data=transformed_data, estimator=MaximumLikelihoodEstimator) + + elif method == "BayesianEstimator": + valid_bayes_priors = ["BDeu", "K2"] + if bayes_prior not in valid_bayes_priors: + raise ValueError( + "unrecognised bayes_prior, please use on of %s" + % " ".join(valid_bayes_priors) + ) + + self._model.fit( + data=transformed_data, + estimator=BayesianEstimator, + prior_type=bayes_prior, + equivalent_sample_size=equivalent_sample_size, + ) + else: + valid_methods = ["MaximumLikelihoodEstimator", "BayesianEstimator"] + raise ValueError( + "unrecognised method, please use on of %s" % " ".join(valid_methods) + ) + + return self + + def fit_node_states_and_cpds( + self, + data: pd.DataFrame, + method: str = "MaximumLikelihoodEstimator", + bayes_prior: str = None, + equivalent_sample_size: int = None, + ) -> "BayesianNetwork": + """ + Call `fit_node_states` and then `fit_cpds`. + + Args: + data: dataframe containing one column per node in the Bayesian Network. + method: how to fit probabilities. One of: + - "MaximumLikelihoodEstimator": fit probabilities using Maximum Likelihood Estimation; + - "BayesianEstimator": fit probabilities using Bayesian Parameter Estimation. Use bayes_prior. + bayes_prior: how to construct the Bayesian prior used by method="BayesianEstimator". One of: + - "K2": shorthand for dirichlet where all pseudo_counts are 1 + regardless of variable cardinality; + - "BDeu": equivalent of using dirichlet and using uniform 'pseudo_counts' of + `equivalent_sample_size / (node_cardinality * np.prod(parents_cardinalities))` + for each node. Use equivelant_sample_size. + equivalent_sample_size: used by BDeu bayes_prior to compute pseudo_counts. 
+ + Returns: + self + """ + + return self.fit_node_states(data).fit_cpds( + data, method, bayes_prior, equivalent_sample_size + ) + + def predict(self, data: pd.DataFrame, node: str) -> pd.DataFrame: + """ + Predict the state of a node based on some input data, using the Bayesian Network. + + Args: + data: data to make prediction. + node: the node to predict. + + Returns: + A dataframe of predictions, containing a single column name {node}_prediction. + """ + + if all(parent in data.columns for parent in self._model.get_parents(node)): + return self._predict_from_complete_data(data, node) + + return self._predict_from_incomplete_data(data, node) + + def _predict_from_complete_data( + self, data: pd.DataFrame, node: str + ) -> pd.DataFrame: + """ + Predicts state of node given all parents of node exist within data. + This method inspects the CPD of node directly, since all parent states are known. + This avoids traversing the full network to compute marginals. + This method is fast. + + Args: + data: data to make prediction. + node: the node to predict. + + Returns: + A dataframe of predictions, containing a single column named {node}_prediction. + """ + transformed_data = data.copy(deep=True) # type: pd.DataFrame + + parents = sorted(self._model.get_parents(node)) + cpd = self.cpds[node] + + transformed_data[ + "{node}_prediction".format(node=node) + ] = transformed_data.apply( + lambda row: cpd[tuple([row[parent] for parent in parents])].idxmax() + if parents + else cpd[""].idxmax(), + axis=1, + ) + return transformed_data[[node + "_prediction"]] + + def _predict_from_incomplete_data( + self, data: pd.DataFrame, node: str + ) -> pd.DataFrame: + """ + Predicts state of node when some parents of node do not exist within data. + This method uses the pgmpy predict function, which predicts the most likely state for every node + that is not contained within data. + With incomplete data, pgmpy goes beyond parents in the network to determine the most likely predictions. 
+ This method is slow. + + Args: + data: data to make prediction. + node: the node to predict. + + Returns: + A dataframe of predictions, containing a single column name {node}_prediction. + """ + + transformed_data = data.copy(deep=True) # type: pd.DataFrame + self._state_to_index(transformed_data) + + # pgmpy will predict all missing data, so drop column we want to predict + transformed_data.drop(node, axis=1, inplace=True) + + predictions = self._model.predict(transformed_data)[[node]] + + return predictions.rename(columns={node: node + "_prediction"}) + + def predict_probability(self, data: pd.DataFrame, node: str) -> pd.DataFrame: + """ + Predict the probability of each possible state of a node, based on some input data. + + Args: + data: data to make prediction. + node: the node to predict probabilities. + + Returns: + A dataframe of predicted probabilities, contained one column per possible state, named {node}_{state}. + """ + + if all(parent in data.columns for parent in self._model.get_parents(node)): + return self._predict_probability_from_complete_data(data, node) + + return self._predict_probability_from_incomplete_data(data, node) + + def _predict_probability_from_complete_data( + self, data: pd.DataFrame, node: str + ) -> pd.DataFrame: + """ + Predict the probability of each possible state of a node, based on some input data. + This method inspects the CPD of node directly, since all parent states are known. + This avoids traversing the full network to compute marginals. + This method is fast. + + Args: + data: data to make prediction. + node: the node to predict probabilities. + + Returns: + A dataframe of predicted probabilities, contained one column per possible state, named {node}_{state}. 
+ """ + transformed_data = data.copy(deep=True) # type: pd.DataFrame + + parents = sorted(self._model.get_parents(node)) + cpd = self.cpds[node] + + def lookup_probability(row, s): + """Retrieve probability from CPD""" + if parents: + return cpd[tuple([row[parent] for parent in parents])].loc[s] + return cpd.at[s, ""] + + for state in self.node_states[node]: + transformed_data[ + "{n}_{s}".format(n=node, s=state) + ] = transformed_data.apply( + lambda row, st=state: lookup_probability(row, st), axis=1 + ) + + return transformed_data[ + ["{n}_{s}".format(n=node, s=state) for state in self.node_states[node]] + ] + + def _predict_probability_from_incomplete_data( + self, data: pd.DataFrame, node: str + ) -> pd.DataFrame: + """ + Predict the probability of each possible state of a node, based on some input data. + This method uses the pgmpy predict_probability function, which predicts the probability + of every state for every node that is not contained within data. + With incomplete data, pgmpy goes beyond parents in the network to determine the most likely predictions. + This method is slow. + + Args: + data: data to make prediction. + node: the node to predict probabilities. + + Returns: + A dataframe of predicted probabilities, contained one column per possible state, named {node}_{state}. 
+ """ + transformed_data = data.copy(deep=True) # type: pd.DataFrame + self._state_to_index(transformed_data) + + # pgmpy will predict all missing data, so drop column we want to predict + transformed_data.drop(node, axis=1, inplace=True) + + probability = self._model.predict_probability( + transformed_data + ) # type: pd.DataFrame + + # keep only probabilities for the node we are interested in + cols = [] + pattern = re.compile("^{node}_[0-9]+$".format(node=node)) + # disabled open pylint issue (https://github.com/PyCQA/pylint/issues/2962) + for col in probability.columns: # pylint: disable=E1133 + if pattern.match(col): + cols.append(col) + probability = probability[cols] + probability.columns = cols + + return probability diff --git a/causalnex/plots/__init__.py b/causalnex/plots/__init__.py new file mode 100644 index 0000000..0657d97 --- /dev/null +++ b/causalnex/plots/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. 
You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``causalnex.plots`` provides functionality to visualise structure models. +""" + +__version__ = "0.4.0" + +__all__ = ["plot_structure"] + +from .plots import plot_structure diff --git a/causalnex/plots/plots.py b/causalnex/plots/plots.py new file mode 100644 index 0000000..29bc478 --- /dev/null +++ b/causalnex/plots/plots.py @@ -0,0 +1,116 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. 
You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. +"""Plot Methods.""" +from typing import Dict, List, Tuple + +import matplotlib.pyplot as plt +import networkx as nx +from matplotlib.axes import Axes +from matplotlib.figure import Figure + +from causalnex.structure.structuremodel import StructureModel + + +def _setup_plot(ax: plt.Axes = None, title: str = None) -> (plt.Figure, plt.Axes): + """Initial setup of fig and ax to plot to.""" + + if not ax: + fig = plt.figure() # type: plt.Figure + ax = fig.add_subplot(1, 1, 1) # type: plt.Axes + + if title: + ax.set_title(title) + + return ax.get_figure(), ax + + +def plot_structure( + g: StructureModel, + ax: plt.Axes = None, + title: str = None, + show_labels: bool = True, + node_color: str = "r", + edge_color: str = "k", + label_color: str = "k", + node_positions: Dict[str, List[float]] = None, +) -> Tuple[Figure, Axes, Dict[str, List[float]]]: + """Plot the structure model to visualise the relationships between nodes. + + Args: + g: the structure model to plot. + ax: if provided then figure will be drawn to this Axes, otherwise a new Axes will be created. + title: if provided then the title will be drawn on the plot. + show_labels: if True then node labels will be drawn. + node_color: a single color format string, for example 'r' or '#ff0000'. default "r". + edge_color: a single color format string, for example 'r' or '#ff0000'. default "k". + label_color: a single color format string, for example 'r' or '#ff0000'. default "k". + node_positions: coordinates for node positions, ie {"node_a": [0, 0]}. + + Returns: + fig, ax, node_positions. 
+ + Example: + :: + >>> # Create a Bayesian Network with a manually defined DAG. + >>> from causalnex.structure import StructureModel + >>> from causalnex.network import BayesianNetwork + >>> + >>> sm = StructureModel() + >>> sm.add_edges_from([ + >>> ('rush_hour', 'traffic'), + >>> ('weather', 'traffic') + >>> ]) + >>> from causalnex.plots import plot_structure + >>> plot_structure(sm) + """ + + fig, ax = _setup_plot(ax, title) + + if not node_positions: + node_positions = nx.circular_layout(g) + + node_color = node_color if node_color else "r" + edge_color = edge_color if edge_color else "k" + label_color = label_color if label_color else "k" + + nx.draw_networkx_nodes( + g, node_positions, ax=ax, nodelist=g.nodes, node_color=node_color + ) + + for u, v in g.edges: + nx.draw_networkx_edges( + g, node_positions, ax=ax, edgelist=[(u, v)], edge_color=edge_color + ) + + if show_labels: + nx.draw_networkx_labels(g, node_positions, ax=ax, font_color=label_color) + + ax.set_axis_off() + plt.tight_layout() + + return fig, ax, node_positions diff --git a/causalnex/structure/__init__.py b/causalnex/structure/__init__.py new file mode 100644 index 0000000..b30fc19 --- /dev/null +++ b/causalnex/structure/__init__.py @@ -0,0 +1,37 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. 
IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +``causalnex.structure`` provides functionality to define or learn structure. +""" + +__version__ = "0.4.0" + +__all__ = ["StructureModel", "notears"] + +from .structuremodel import StructureModel diff --git a/causalnex/structure/notears.py b/causalnex/structure/notears.py new file mode 100644 index 0000000..febc99f --- /dev/null +++ b/causalnex/structure/notears.py @@ -0,0 +1,553 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# The methods found in this file are derived from a repository under Apache 2.0: +# DAGs with NO TEARS. +# @inproceedings{zheng2018dags, +# author = {Zheng, Xun and Aragam, Bryon and Ravikumar, Pradeep and Xing, Eric P.}, +# booktitle = {Advances in Neural Information Processing Systems}, +# title = {{DAGs with NO TEARS: Continuous Optimization for Structure Learning}}, +# year = {2018}, +# codebase = {https://github.com/xunzheng/notears} +# } +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tools to learn a ``StructureModel`` which describes the conditional dependencies between variables in a dataset. 
+""" + +import logging +import warnings +from copy import deepcopy +from typing import List, Tuple + +import numpy as np +import pandas as pd +import scipy.linalg as slin +import scipy.optimize as sopt + +from causalnex.structure.structuremodel import StructureModel + +__all__ = ["from_numpy", "from_pandas", "from_numpy_lasso", "from_pandas_lasso"] + + +def from_numpy( + X: np.ndarray, + max_iter: int = 100, + h_tol: float = 1e-8, + w_threshold: float = 0.0, + tabu_edges: List[Tuple[int, int]] = None, + tabu_parent_nodes: List[int] = None, + tabu_child_nodes: List[int] = None, +) -> StructureModel: + """ + Learn the `StructureModel`, the graph structure describing conditional dependencies between variables + in data presented as a numpy array. + + The optimisation is to minimise a score function :math:`F(W)` over the graph's + weighted adjacency matrix, :math:`W`, subject to a constraint function :math:`h(W)`, + where :math:`h(W) == 0` characterises an acyclic graph. + :math:`h(W) > 0` is a continuous, differentiable function that encapsulates how acyclic the graph is + (less == more acyclic). + Full details of this approach to structure learning are provided in the publication: + + Based on DAGs with NO TEARS. + @inproceedings{zheng2018dags, + author = {Zheng, Xun and Aragam, Bryon and Ravikumar, Pradeep and Xing, Eric P.}, + booktitle = {Advances in Neural Information Processing Systems}, + title = {{DAGs with NO TEARS: Continuous Optimization for Structure Learning}}, + year = {2018}, + codebase = {https://github.com/xunzheng/notears} + } + + Args: + X: 2d input data, axis=0 is data rows, axis=1 is data columns. Data must be row oriented. + max_iter: max number of dual ascent steps during optimisation. + h_tol: exit if h(W) < h_tol (as opposed to strict definition of 0). + w_threshold: fixed threshold for absolute edge weights. + tabu_edges: list of edges(from, to) not to be included in the graph. 
+ tabu_parent_nodes: list of nodes banned from being a parent of any other nodes. + tabu_child_nodes: list of nodes banned from being a child of any other nodes. + + Returns: + StructureModel: a graph of conditional dependencies between data variables. + + Raises: + ValueError: If X does not contain data. + """ + + # n examples, d properties + _, d = X.shape + + bnds = [ + (0, 0) + if i == j + else (0, 0) + if tabu_edges is not None and (i, j) in tabu_edges + else (0, 0) + if tabu_parent_nodes is not None and i in tabu_parent_nodes + else (0, 0) + if tabu_child_nodes is not None and j in tabu_child_nodes + else (None, None) + for i in range(d) + for j in range(d) + ] + + return _learn_structure(X, bnds, max_iter, h_tol, w_threshold) + + +def from_numpy_lasso( + X: np.ndarray, + beta: float, + max_iter: int = 100, + h_tol: float = 1e-8, + w_threshold: float = 0.0, + tabu_edges: List[Tuple[int, int]] = None, + tabu_parent_nodes: List[int] = None, + tabu_child_nodes: List[int] = None, +) -> StructureModel: + """ + Learn the `StructureModel`, the graph structure with lasso regularisation + describing conditional dependencies between variables in data presented as a numpy array. + + Based on DAGs with NO TEARS. + @inproceedings{zheng2018dags, + author = {Zheng, Xun and Aragam, Bryon and Ravikumar, Pradeep and Xing, Eric P.}, + booktitle = {Advances in Neural Information Processing Systems}, + title = {{DAGs with NO TEARS: Continuous Optimization for Structure Learning}}, + year = {2018}, + codebase = {https://github.com/xunzheng/notears} + } + + Args: + X: 2d input data, axis=0 is data rows, axis=1 is data columns. Data must be row oriented. + beta: Constant that multiplies the lasso term. + max_iter: max number of dual ascent steps during optimisation. + h_tol: exit if h(W) < h_tol (as opposed to strict definition of 0). + w_threshold: fixed threshold for absolute edge weights. + tabu_edges: list of edges(from, to) not to be included in the graph. 
+ tabu_parent_nodes: list of nodes banned from being a parent of any other nodes. + tabu_child_nodes: list of nodes banned from being a child of any other nodes. + + Returns: + StructureModel: a graph of conditional dependencies between data variables. + + Raises: + ValueError: If X does not contain data. + """ + + # n examples, d properties + _, d = X.shape + + # The lasso solver optimises a stacked vector [w_pos, w_neg] with W = w_pos - w_neg and + # penalises beta * sum(w_pos + w_neg). That is only an L1 penalty (and only bounded below) + # when every entry is >= 0, so free entries get (0, None); diagonal/tabu entries are pinned to 0. + bnds = [ + (0, 0) + if i == j + else (0, 0) + if tabu_edges is not None and (i, j) in tabu_edges + else (0, 0) + if tabu_parent_nodes is not None and i in tabu_parent_nodes + else (0, 0) + if tabu_child_nodes is not None and j in tabu_child_nodes + else (0, None) + for i in range(d) + for j in range(d) + ] * 2 + + return _learn_structure_lasso(X, beta, bnds, max_iter, h_tol, w_threshold) + + +def from_pandas( + X: pd.DataFrame, + max_iter: int = 100, + h_tol: float = 1e-8, + w_threshold: float = 0.0, + tabu_edges: List[Tuple[str, str]] = None, + tabu_parent_nodes: List[str] = None, + tabu_child_nodes: List[str] = None, +) -> StructureModel: + """ + Learn the `StructureModel`, the graph structure describing conditional dependencies between variables + in data presented as a pandas dataframe. + + The optimisation is to minimise a score function :math:`F(W)` over the graph's + weighted adjacency matrix, :math:`W`, subject to a constraint function :math:`h(W)`, + where :math:`h(W) == 0` characterises an acyclic graph. + :math:`h(W) > 0` is a continuous, differentiable function that encapsulates how acyclic the graph is + (less == more acyclic). + Full details of this approach to structure learning are provided in the publication: + + Based on DAGs with NO TEARS. 
+ @inproceedings{zheng2018dags, + author = {Zheng, Xun and Aragam, Bryon and Ravikumar, Pradeep and Xing, Eric P.}, + booktitle = {Advances in Neural Information Processing Systems}, + title = {{DAGs with NO TEARS: Continuous Optimization for Structure Learning}}, + year = {2018}, + codebase = {https://github.com/xunzheng/notears} + } + + Args: + X: input data. + max_iter: max number of dual ascent steps during optimisation. + h_tol: exit if h(W) < h_tol (as opposed to strict definition of 0). + w_threshold: fixed threshold for absolute edge weights. + tabu_edges: list of edges(from, to) not to be included in the graph. + tabu_parent_nodes: list of nodes banned from being a parent of any other nodes. + tabu_child_nodes: list of nodes banned from being a child of any other nodes. + + Returns: + StructureModel: graph of conditional dependencies between data variables. + + Raises: + ValueError: If X does not contain data. + """ + + data = deepcopy(X) + + non_numeric_cols = data.select_dtypes(exclude="number").columns + + if len(non_numeric_cols) > 0: + raise ValueError( + "All columns must have numeric data. 
" + "Consider mapping the following columns to int {non_numeric_cols}".format( + non_numeric_cols=non_numeric_cols + ) + ) + + col_idx = {c: i for i, c in enumerate(data.columns)} + idx_col = {i: c for c, i in col_idx.items()} + + if tabu_edges: + tabu_edges = [(col_idx[u], col_idx[v]) for u, v in tabu_edges] + if tabu_parent_nodes: + tabu_parent_nodes = [col_idx[n] for n in tabu_parent_nodes] + if tabu_child_nodes: + tabu_child_nodes = [col_idx[n] for n in tabu_child_nodes] + + g = from_numpy( + data.values, + max_iter, + h_tol, + w_threshold, + tabu_edges, + tabu_parent_nodes, + tabu_child_nodes, + ) + + sm = StructureModel() + sm.add_nodes_from(data.columns) + sm.add_weighted_edges_from( + [(idx_col[u], idx_col[v], w) for u, v, w in g.edges.data("weight")], + origin="learned", + ) + + return sm + + +def from_pandas_lasso( + X: pd.DataFrame, + beta: float, + max_iter: int = 100, + h_tol: float = 1e-8, + w_threshold: float = 0.0, + tabu_edges: List[Tuple[str, str]] = None, + tabu_parent_nodes: List[str] = None, + tabu_child_nodes: List[str] = None, +) -> StructureModel: + """ + Learn the `StructureModel`, the graph structure with lasso regularisation + describing conditional dependencies between variables in data presented as a pandas dataframe. + + Based on DAGs with NO TEARS. + @inproceedings{zheng2018dags, + author = {Zheng, Xun and Aragam, Bryon and Ravikumar, Pradeep and Xing, Eric P.}, + booktitle = {Advances in Neural Information Processing Systems}, + title = {{DAGs with NO TEARS: Continuous Optimization for Structure Learning}}, + year = {2018}, + codebase = {https://github.com/xunzheng/notears} + } + + Args: + X: input data. + beta: Constant that multiplies the lasso term. + max_iter: max number of dual ascent steps during optimisation. + h_tol: exit if h(W) < h_tol (as opposed to strict definition of 0). + w_threshold: fixed threshold for absolute edge weights. + tabu_edges: list of edges(from, to) not to be included in the graph. 
+ tabu_parent_nodes: list of nodes banned from being a parent of any other nodes. + tabu_child_nodes: list of nodes banned from being a child of any other nodes. + + Returns: + StructureModel: graph of conditional dependencies between data variables. + + Raises: + ValueError: If X does not contain data. + """ + + data = deepcopy(X) + + non_numeric_cols = data.select_dtypes(exclude="number").columns + + if not non_numeric_cols.empty: + raise ValueError( + "All columns must have numeric data. " + "Consider mapping the following columns to int {non_numeric_cols}".format( + non_numeric_cols=non_numeric_cols + ) + ) + + col_idx = {c: i for i, c in enumerate(data.columns)} + idx_col = {i: c for c, i in col_idx.items()} + + if tabu_edges: + tabu_edges = [(col_idx[u], col_idx[v]) for u, v in tabu_edges] + if tabu_parent_nodes: + tabu_parent_nodes = [col_idx[n] for n in tabu_parent_nodes] + if tabu_child_nodes: + tabu_child_nodes = [col_idx[n] for n in tabu_child_nodes] + + g = from_numpy_lasso( + data.values, + beta, + max_iter, + h_tol, + w_threshold, + tabu_edges, + tabu_parent_nodes, + tabu_child_nodes, + ) + + sm = StructureModel() + sm.add_nodes_from(data.columns) + sm.add_weighted_edges_from( + [(idx_col[u], idx_col[v], w) for u, v, w in g.edges.data("weight")], + origin="learned", + ) + + return sm + + +def _learn_structure( + X: np.ndarray, + bnds, + max_iter: int = 100, + h_tol: float = 1e-8, + w_threshold: float = 0.0, +) -> StructureModel: + """ + Based on initial implementation at https://github.com/xunzheng/notears + """ + + def _h(w: np.ndarray) -> float: + """ + Constraint function of the NOTEARS algorithm. + + Args: + w: current adjacency matrix. + + Returns: + float: DAGness of the adjacency matrix (0 == DAG, >0 == cyclic). + """ + + W = w.reshape([d, d]) + return np.trace(slin.expm(W * W)) - d + + def _func(w: np.ndarray) -> float: + """ + Objective function that the NOTEARS algorithm tries to minimise. + + Args: + w: current adjacency matrix. 
+ + Returns: + float: objective. + """ + + W = w.reshape([d, d]) + loss = 0.5 / n * np.square(np.linalg.norm(X.dot(np.eye(d, d) - W), "fro")) + h = _h(W) + return loss + 0.5 * rho * h * h + alpha * h + + def _grad(w: np.ndarray) -> np.ndarray: + """ + Gradient function used to compute next step in NOTEARS algorithm. + + Args: + w: the current adjacency matrix. + + Returns: + np.ndarray: gradient vector. + """ + + W = w.reshape([d, d]) + loss_grad = -1.0 / n * X.T.dot(X).dot(np.eye(d, d) - W) + E = slin.expm(W * W) + obj_grad = loss_grad + (rho * (np.trace(E) - d) + alpha) * E.T * W * 2 + return obj_grad.flatten() + + if X.size == 0: + raise ValueError("Input data X is empty, cannot learn any structure") + logging.info("Learning structure using 'NOTEARS' optimisation.") + + # n examples, d properties + n, d = X.shape + # initialise matrix to zeros + w_est, w_new = np.zeros(d * d), np.zeros(d * d) + + # initialise weights and constraints + rho, alpha, h, h_new = 1.0, 0.0, np.inf, np.inf + + # start optimisation + for n_iter in range(max_iter): + while rho < 1e20: + sol = sopt.minimize(_func, w_est, method="L-BFGS-B", jac=_grad, bounds=bnds) + w_new = sol.x + h_new = _h(w_new) + if h_new > 0.25 * h: + rho *= 10 + else: + break + w_est, h = w_new, h_new + alpha += rho * h + if h <= h_tol: + break + if h > h_tol and n_iter == max_iter - 1: + warnings.warn("Failed to converge. Consider increasing max_iter.") + + w_est[np.abs(w_est) <= w_threshold] = 0 + return StructureModel(w_est.reshape([d, d])) + + +def _learn_structure_lasso( + X: np.ndarray, + beta: float, + bnds, + max_iter: int = 100, + h_tol: float = 1e-8, + w_threshold: float = 0.0, +) -> StructureModel: + """ + Based on initial implementation at https://github.com/xunzheng/notears + """ + + def _h(w_vec: np.ndarray) -> float: + """ + Constraint function of the NOTEARS algorithm with lasso regularisation. + + Args: + w_vec: weight vector (wpos and wneg). 
+ + Returns: + float: DAGness of the adjacency matrix (0 == DAG, >0 == cyclic). + """ + + W = w_vec.reshape([d, d]) + return np.trace(slin.expm(W * W)) - d + + def _func(w_vec: np.ndarray) -> float: + """ + Objective function that the NOTEARS algorithm with lasso regularisation tries to minimise. + + Args: + w_vec: weight vector (wpos and wneg). + + Returns: + float: objective. + """ + + w_pos = w_vec[: d ** 2] + w_neg = w_vec[d ** 2 :] + + wmat_pos = w_pos.reshape([d, d]) + wmat_neg = w_neg.reshape([d, d]) + + wmat = wmat_pos - wmat_neg + loss = 0.5 / n * np.square(np.linalg.norm(X.dot(np.eye(d, d) - wmat), "fro")) + h_val = _h(wmat) + return loss + 0.5 * rho * h_val * h_val + alpha * h_val + beta * w_vec.sum() + + def _grad(w_vec: np.ndarray) -> np.ndarray: + """ + Gradient function used to compute next step in NOTEARS algorithm with lasso regularisation. + + Args: + w_vec: weight vector (wpos and wneg). + + Returns: + np.ndarray: gradient vector. + """ + + w_pos = w_vec[: d ** 2] + w_neg = w_vec[d ** 2 :] + + grad_vec = np.zeros(2 * d ** 2) + wmat_pos = w_pos.reshape([d, d]) + wmat_neg = w_neg.reshape([d, d]) + + wmat = wmat_pos - wmat_neg + + loss_grad = -1.0 / n * X.T.dot(X).dot(np.eye(d, d) - wmat) + exp_hdmrd = slin.expm(wmat * wmat) + obj_grad = ( + loss_grad + + (rho * (np.trace(exp_hdmrd) - d) + alpha) * exp_hdmrd.T * wmat * 2 + ) + lbd_grad = beta * np.ones(d * d) + grad_vec[: d ** 2] = obj_grad.flatten() + lbd_grad + grad_vec[d ** 2 :] = -obj_grad.flatten() + lbd_grad + + return grad_vec + + if X.size == 0: + raise ValueError("Input data X is empty, cannot learn any structure") + logging.info( + "Learning structure using 'NOTEARS' optimisation with lasso regularisation." 
+ ) + + n, d = X.shape + w_est, w_new = np.zeros(2 * d * d), np.zeros(2 * d * d) + rho, alpha, h_val, h_new = 1.0, 0.0, np.inf, np.inf + for n_iter in range(max_iter): + while rho < 1e20: + sol = sopt.minimize(_func, w_est, method="L-BFGS-B", jac=_grad, bounds=bnds) + w_new = sol.x + + h_new = _h( + w_new[: d ** 2].reshape([d, d]) - w_new[d ** 2 :].reshape([d, d]) + ) + if h_new > 0.25 * h_val: + rho *= 10 + else: + break + w_est, h_val = w_new, h_new + alpha += rho * h_val + if h_val <= h_tol: + break + if h_val > h_tol and n_iter == max_iter - 1: + warnings.warn("Failed to converge. Consider increasing max_iter.") + + w_new = w_est[: d ** 2].reshape([d, d]) - w_est[d ** 2 :].reshape([d, d]) + w_new[np.abs(w_new) < w_threshold] = 0 + return StructureModel(w_new.reshape([d, d])) diff --git a/causalnex/structure/structuremodel.py b/causalnex/structure/structuremodel.py new file mode 100644 index 0000000..4cf0d23 --- /dev/null +++ b/causalnex/structure/structuremodel.py @@ -0,0 +1,269 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. 
The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. +""" +This module contains the implementation of ``StructureModel``. + +``StructureModel`` is a class that describes relationships between variables as a graph. +""" + +from typing import List, Set, Union + +import networkx as nx +import numpy as np + + +def _validate_origin(origin: str) -> None: + """ + Checks that origin has a valid value. One of: + - unknown: edge exists for an unknown reason; + - learned: edge was created as the output of a machine-learning process; + - expert: edge was created by a domain expert. + + Args: + origin: the value to validate. + + Raises: + ValueError: if origin is not valid. + """ + + allowed = {"unknown", "learned", "expert"} + + if origin not in allowed: + raise ValueError( + "Unknown origin: must be one of {allowed} - got `{origin}`.".format( + allowed=allowed, origin=origin + ) + ) + + +class StructureModel(nx.DiGraph): + """ + Base class for structure models, which are an extension of ``networkx.DiGraph``. + + A ``StructureModel`` stores nodes and edges with optional data, or attributes. + + Edges have one required attribute, "origin", which describes how the edge was created. + Origin can be one of either unknown, learned, or expert. + + StructureModel hold directed edges, describing a cause -> effect relationship. + Cycles are permitted within a ``StructureModel``. + + Nodes can be arbitrary (hashable) Python objects with optional key/value attributes. + By convention None is not used as a node. 
+ + Edges are represented as links between nodes with optional key/value attributes. + """ + + def __init__(self, incoming_graph_data=None, origin="unknown", **attr): + """ + Create a ``StructureModel`` with incoming_graph_data, which has come from some origin. + + Args: + incoming_graph_data (Optional): input graph (optional, default: None) + Data to initialize graph. If None (default) an empty graph is created. + The data can be any format that is supported by the to_networkx_graph() + function, currently including edge list, dict of dicts, dict of lists, + NetworkX graph, NumPy matrix or 2d ndarray, SciPy sparse matrix, or PyGraphviz graph. + + origin (str): label for how the edges were created. Can be one of: + - unknown: edges exist for an unknown reason; + - learned: edges were created as the output of a machine-learning process; + - expert: edges were created by a domain expert. + + attr : Attributes to add to graph as key/value pairs (no attributes by default). + """ + + _validate_origin(origin) + super().__init__(incoming_graph_data, **attr) + for u_of_edge, v_of_edge in self.edges: + self[u_of_edge][v_of_edge]["origin"] = origin + + def to_directed_class(self): + """ + Returns the class to use for directed copies. + See :func:`networkx.DiGraph.to_directed()`. + """ + return StructureModel + + def to_undirected_class(self): + """ + Returns the class to use for undirected copies. + See :func:`networkx.DiGraph.to_undirected()`. + """ + return nx.Graph + + # disabled: W0221: Parameters differ from overridden 'add_edge' method (arguments-differ) + # this has been disabled because origin tracking is required for CausalGraphs + # implementing it in this way allows all 3rd party libraries and applications to + # integrate seamlessly, where edges will be given origin="unknown" where not provided + def add_edge( + self, u_of_edge: str, v_of_edge: str, origin: str = "unknown", **attr + ): # pylint: disable=W0221 + """ + Adds a causal relationship from u to v. 
+ + If u or v do not currently exist in the ``StructureModel`` then they will be created. + + By default a relationship will be given origin="unknown", but + may also be given "learned" or "expert" origin. + + Adding an edge that already exists updates the existing edge's data. + See :func:`networkx.DiGraph.add_edge`. + + Args: + u_of_edge: causal node. + v_of_edge: effect node. + origin: label for how the edge was created. Can be one of: + - unknown: edge exists for an unknown reason; + - learned: edge was created as the output of a machine-learning process; + - expert: edge was created by a domain expert. + **attr: Attributes to add to edge as key/value pairs (no attributes by default). + """ + _validate_origin(origin) + + attr.update({"origin": origin}) + super().add_edge(u_of_edge, v_of_edge, **attr) + + # disabled: W0221: Parameters differ from overridden 'add_edges_from' method (arguments-differ) + # this has been disabled because origin tracking is required for CausalGraphs + # implementing it in this way allows all 3rd party libraries and applications to + # integrate seamlessly, where edges will be given origin="unknown" where not provided + def add_edges_from( + self, + ebunch_to_add: Union[Set[tuple], List[tuple]], + origin: str = "unknown", + **attr + ): # pylint: disable=W0221 + """ + Adds a bunch of causal relationships, u -> v. + + If u or v do not currently exist in the ``StructureModel`` then they will be created. + + By default relationships will be given origin="unknown", + but may also be given "learned" or "expert" origin. + + Notes: + Adding an edge that already exists updates the existing edge's data. + See :func:`networkx.DiGraph.add_edges_from`. + + Args: + ebunch_to_add: container of edges. + Each edge given in the container will be added to the graph. + The edges must be given as 2-tuples (u, v) or + 3-tuples (u, v, d) where d is a dictionary containing edge data. + origin: label for how the edges were created. 
One of: + - unknown: edges exist for an unknown reason. + - learned: edges were created as the output of a machine-learning process. + - expert: edges were created by a domain expert. + **attr: Attributes to add to edge as key/value pairs (no attributes by default). + """ + + _validate_origin(origin) + + attr.update({"origin": origin}) + super().add_edges_from(ebunch_to_add, **attr) + + # disabled: W0221: Parameters differ from overridden 'add_weighted_edges_from' method (arguments-differ) + # this has been disabled because origin tracking is required for CausalGraphs + # implementing it in this way allows all 3rd party libraries and applications to + # integrate seamlessly, where edges will be given origin="unknown" where not provided + def add_weighted_edges_from( + self, + ebunch_to_add: Union[Set[tuple], List[tuple]], + weight: str = "weight", + origin: str = "unknown", + **attr + ): # pylint: disable=W0221 + """ + Adds a bunch of weighted causal relationships, u -> v. + + If u or v do not currently exist in the ``StructureModel`` then they will be created. + + By default relationships will be given origin="unknown", + but may also be given "learned" or "expert" origin. + + Notes: + Adding an edge that already exists updates the existing edge's data. + See :func:`networkx.DiGraph.add_weighted_edges_from`. + + Args: + ebunch_to_add: container of edges. + Each edge given in the container will be added to the graph. + The edges must be given as 3-tuples (u, v, w), + where w is the numeric weight of the edge u -> v. + weight : string, optional (default='weight'). + The attribute name for the edge weights to be added. + origin: label for how the edges were created. One of: + - unknown: edges exist for an unknown reason; + - learned: edges were created as the output of a machine-learning process; + - expert: edges were created by a domain expert. + **attr: Attributes to add to edge as key/value pairs (no attributes by default). 
+ """ + _validate_origin(origin) + + attr.update({"origin": origin}) + super().add_weighted_edges_from(ebunch_to_add, weight=weight, **attr) + + def edges_with_origin(self, origin) -> list: + """ + List of edges created with given origin attribute. + + Returns: + A list of edges with the given origin. + """ + + return [(u, v) for u, v in self.edges if self[u][v]["origin"] == origin] + + def remove_edges_below_threshold(self, threshold: float): + """ + Remove edges whose absolute weights are less than a defined threshold. + + Args: + threshold: edges whose absolute weight is less than this value are removed. + """ + + self.remove_edges_from( + [(u, v) for u, v, w in self.edges(data="weight") if np.abs(w) < threshold] + ) + + def get_largest_subgraph(self) -> "StructureModel": + """ + Get the weakly connected component of the Structure Model that has the most edges. + + Returns: + A copy of the weakly connected component with the most edges (the first one found wins ties). + If no component contains an edge, None is returned. + """ + largest_n_edges = 0 + largest_subgraph = None + + # nx.weakly_connected_component_subgraphs() was removed in networkx 2.4; + # building each component's subgraph copy explicitly is the documented equivalent. + for component in nx.weakly_connected_components(self): + subgraph = self.subgraph(component).copy() + if len(subgraph.edges) > largest_n_edges: + largest_n_edges = len(subgraph.edges) + largest_subgraph = subgraph + + return largest_subgraph diff --git a/doc_requirements.txt b/doc_requirements.txt new file mode 100644 index 0000000..f60e9a6 --- /dev/null +++ b/doc_requirements.txt @@ -0,0 +1,12 @@ +click>=7.0, <8.0 +ipykernel>=4.8.1, <5.0 +jupyter_client>=5.1.0, <6.0 +nbsphinx==0.4.2 +nbstripout==0.3.3 +patchy>=1.5, <2.0 +recommonmark==0.5.0 +sphinx-autodoc-typehints>=1.6.0, < 2.0 +sphinx-markdown-tables==0.0.9 +sphinx>=1.8.4, <2.0 +sphinx_copybutton==0.2.5 +sphinx_rtd_theme==0.4.3 diff --git a/docs/_templates/autosummary/base.rst b/docs/_templates/autosummary/base.rst new file mode 100644 index 0000000..b7556eb --- /dev/null +++ b/docs/_templates/autosummary/base.rst @@ -0,0 +1,5 @@ +{{ fullname | escape | underline}} + +.. currentmodule:: {{ module }} + +.. 
auto{{ objtype }}:: {{ objname }} diff --git a/docs/_templates/autosummary/class.rst b/docs/_templates/autosummary/class.rst new file mode 100644 index 0000000..aa25df8 --- /dev/null +++ b/docs/_templates/autosummary/class.rst @@ -0,0 +1,33 @@ +{{ fullname | escape | underline }} + +.. currentmodule:: {{ module }} + +.. autoclass:: {{ objname }} + :members: + :undoc-members: + :show-inheritance: + :inherited-members: + + {% block attributes %} + {% if attributes %} + .. rubric:: Attributes + + .. autosummary:: + {% for item in all_attributes %} + {%- if not item.startswith('_') %} + {{ name }}.{{ item }} + {%- endif -%} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block methods %} + {% if methods %} + .. rubric:: Methods + + .. autosummary:: + {% for item in methods %} + {{ name }}.{{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/_templates/autosummary/module.rst b/docs/_templates/autosummary/module.rst new file mode 100644 index 0000000..68f7527 --- /dev/null +++ b/docs/_templates/autosummary/module.rst @@ -0,0 +1,58 @@ +{{ fullname | escape | underline }} + +.. rubric:: Description + +.. automodule:: {{ fullname }} + + {% block public_modules %} + {% if public_modules %} + .. rubric:: Modules + + .. autosummary:: + :toctree: + :template: autosummary/module.rst + {% for item in public_modules %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block functions %} + {% if functions %} + .. rubric:: Functions + + .. autosummary:: + :toctree: + {% for item in functions %} + {%- if not item.startswith('_') %} + {{ item }} + {% endif %} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block classes %} + {% if classes %} + .. rubric:: Classes + + .. autosummary:: + :toctree: + :template: autosummary/class.rst + {% for item in classes %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} + + {% block exceptions %} + {% if exceptions %} + .. rubric:: Exceptions + + .. 
autosummary:: + :toctree: + :template: autosummary/class.rst + {% for item in exceptions %} + {{ item }} + {%- endfor %} + {% endif %} + {% endblock %} diff --git a/docs/build-docs.sh b/docs/build-docs.sh new file mode 100755 index 0000000..e96218c --- /dev/null +++ b/docs/build-docs.sh @@ -0,0 +1,44 @@ +#!/usr/bin/env bash + +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +set -e + +python -m ipykernel install --user --name=causalnex --display-name=causalnex + +# Move some files around. 
We need a separate build directory, which would +# have all the files, build scripts would shuffle the files, +# we don't want that happening on the actual code locally. +# When running on ReadTheDocs, sphinx-build would run directly on the original files, +# but we don't care about the code state there. +rm -rf docs/build +mkdir docs/build/ +cp -r docs/_templates docs/conf.py docs/build/ + +sphinx-build -c docs/ -Ea -j auto -D language=en docs/build/ docs/build/html diff --git a/docs/conf.py b/docs/conf.py new file mode 100644 index 0000000..506e091 --- /dev/null +++ b/docs/conf.py @@ -0,0 +1,503 @@ +#!/usr/bin/env python3 +# -*- coding: utf-8 -*- +# +# causalnex documentation build configuration file, +# created by, sphinx-quickstart on Mon Dec 18 11:31:24 2017. +# +# This file is execfile()d with the current directory set to its +# containing dir. +# +# Note that not all possible configuration values are present in this +# autogenerated file. +# +# All configuration values have a default; values that are commented out +# serve to show the default. + +import importlib +import re +import shutil +import sys +from distutils.dir_util import copy_tree +from inspect import getmembers, isclass, isfunction +from pathlib import Path +from typing import List + +import patchy +from click import secho, style +from sphinx.ext.autosummary.generate import generate_autosummary_docs + +from causalnex import __version__ as release + +# -- Project information ----------------------------------------------------- + +project = "causalnex" +copyright = "2020, QuantumBlack" +author = "QuantumBlack" + +# The short X.Y version. +version = re.match(r"^([0-9]+\.[0-9]+).*", release).group(1) + +# -- General configuration --------------------------------------------------- + +# If your documentation needs a minimal Sphinx version, state it here. +# +# needs_sphinx = '1.0' + +# Add any Sphinx extension module names here, as strings. 
They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + "sphinx.ext.autodoc", + "sphinx.ext.autosummary", + "sphinx.ext.napoleon", + "sphinx_autodoc_typehints", + "sphinx.ext.doctest", + "sphinx.ext.todo", + "sphinx.ext.coverage", + "sphinx.ext.mathjax", + "sphinx.ext.ifconfig", + "sphinx.ext.viewcode", + "nbsphinx", + "recommonmark", + "sphinx_markdown_tables", + "sphinx_copybutton", +] + +# enable autosummary plugin (table of contents for modules/classes/class +# methods) +autosummary_generate = True +autosummary_imported_members = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ["_templates"] + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = {".rst": "restructuredtext", ".md": "markdown"} + +# The master toctree document. +master_doc = "index" + +# The language for content autogenerated by Sphinx. Refer to documentation +# for a list of supported languages. +# +# This is also used if you do content translation via gettext catalogs. +# Usually you set "language" from the command line for these cases. +language = None + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path . +exclude_patterns = ["**cli*", "_build", "**.ipynb_checkpoints", "_templates"] + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = "sphinx" + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = "sphinx_rtd_theme" +here = Path(__file__).parent.absolute() +# html_logo = str(here / "causalnex_logo.svg") + +# Theme options are theme-specific and customize the look and feel of a theme +# further. 
For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + "collapse_navigation": False, + "style_external_links": True, + # "logo_only": True + # "github_url": "https://github.com/quantumblacklabs/causalnex" +} + +html_context = { + "display_github": True, + "github_url": "https://github.com/quantumblacklabs/causalnex/tree/develop/docs/source", +} + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ['_static'] + +# Custom sidebar templates, must be a dictionary that maps document names +# to template names. + +# The default sidebars (for documents that don't match any pattern) are +# defined by theme itself. Builtin themes are using these templates by +# default: ``['localtoc.html', 'relations.html', 'sourcelink.html', +# 'searchbox.html']``. +# +# html_sidebars = {} + +html_show_sourcelink = False + +# -- Options for HTMLHelp output --------------------------------------------- + +# Output file base name for HTML help builder. +htmlhelp_basename = "causalnexdoc" + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). 
+latex_documents = [ + (master_doc, "causalnex.tex", "causalnex Documentation", "QuantumBlack", "manual") +] + +# -- Options for manual page output ------------------------------------------ + +# One entry per manual page. List of tuples +# (source start file, name, description, authors, manual section). +man_pages = [(master_doc, "causalnex", "causalnex Documentation", [author], 1)] + +# -- Options for Texinfo output ---------------------------------------------- + +# Grouping the document tree into Texinfo files. List of tuples +# (source start file, target name, title, author, +# dir menu entry, description, category) +texinfo_documents = [ + ( + master_doc, + "causalnex", + "causalnex Documentation", + author, + "causalnex", + "Toolkit for causal reasoning (Bayesian Networks / Inference)", + "Data-Science", + ) +] + +# -- Options for todo extension ---------------------------------------------- + +# If true, `todo` and `todoList` produce output, else they produce nothing. +todo_include_todos = False + +# -- Extension configuration ------------------------------------------------- + +# nbsphinx_prolog = """ +# see here for prolog/epilog details: +# https://nbsphinx.readthedocs.io/en/0.4.0/prolog-and-epilog.html +# """ + +nbsphinx_epilog = """ +.. note:: + + Found a bug, or didn't find what you were looking for? 
`🙏Please file a + ticket <https://github.com/quantumblacklabs/causalnex/issues/new/choose>`_ +""" + +# -- NBconvert kernel config ------------------------------------------------- +nbsphinx_kernel_name = "causalnex" + + +# -- causalnex specific configuration ------------------ +MODULES = [] + + +def get_classes(module): + importlib.import_module(module) + return [obj[0] for obj in getmembers(sys.modules[module], lambda obj: isclass(obj))] + + +def get_functions(module): + importlib.import_module(module) + return [ + obj[0] for obj in getmembers(sys.modules[module], lambda obj: isfunction(obj)) + ] + + +def remove_arrows_in_examples(lines): + for i, line in enumerate(lines): + lines[i] = line.replace(">>>", "") + + +def autolink_replacements(what): + """ + Create a list containing replacement tuples of the form: + (``regex``, ``replacement``, ``obj``) for all classes and methods which are + imported in ``MODULES`` ``__init__.py`` files. The ``replacement`` + is a reStructuredText link to their documentation. + For example, if the docstring reads: + This DataSet loads and saves ... + Then the word ``DataSet``, will be replaced by + :class:`~causalnex.io.DataSet` + Works for plural as well, e.g: + These ``DataSet``s load and save + Will convert to: + These :class:`causalnex.io.DataSet` s load and + save + Args: + what (str) : The objects to create replacement tuples for. 
Possible + values ["class", "func"] + Returns: + List[Tuple[regex, str, str]]: A list of tuples: (regex, replacement, + obj), for all "what" objects imported in __init__.py files of + ``MODULES`` + """ + replacements = [] + suggestions = [] + for module in MODULES: + if what == "class": + objects = get_classes(module) + elif what == "func": + objects = get_functions(module) + + # Look for recognised class names/function names which are + # surrounded by double back-ticks + if what == "class": + # first do plural only for classes + replacements += [ + ( + r"``{}``s".format(obj), + ":{}:`~{}.{}`\\\\s".format(what, module, obj), + obj, + ) + for obj in objects + ] + + # singular + replacements += [ + (r"``{}``".format(obj), ":{}:`~{}.{}`".format(what, module, obj), obj) + for obj in objects + ] + + # Look for recognised class names/function names which are NOT + # surrounded by double back-ticks, so that we can log these in the + # terminal + if what == "class": + # first do plural only for classes + suggestions += [ + (r"(?<!\w|`){}s(?!\w|`{{2}})".format(obj), "``{}``s".format(obj), obj) + for obj in objects + ] + + # then singular + suggestions += [ + (r"(?<!\w|`){}(?!\w|`{{2}})".format(obj), "``{}``".format(obj), obj) + for obj in objects + ] + + return replacements, suggestions + + +def log_suggestions(lines: List[str], name: str): + """Use the ``suggestions`` list to log in the terminal places where the + developer has forgotten to surround with double back-ticks class + name/function name references. + + Args: + lines: The docstring lines. + name: The name of the object whose docstring is contained in lines. 
+ """ + title_printed = False + + for i in range(len(lines)): + if ">>>" in lines[i]: + continue + + for existing, replacement, obj in suggestions: + new = re.sub(existing, r"{}".format(replacement), lines[i]) + if new == lines[i]: + continue + if ":rtype:" in lines[i] or ":type " in lines[i]: + continue + + if not title_printed: + secho("-" * 50 + "\n" + name + ":\n" + "-" * 50, fg="blue") + title_printed = True + + print( + "[" + + str(i) + + "] " + + re.sub(existing, r"{}".format(style(obj, fg="magenta")), lines[i]) + ) + print( + "[" + + str(i) + + "] " + + re.sub(existing, r"``{}``".format(style(obj, fg="green")), lines[i]) + ) + + if title_printed: + print("\n") + + +def autolink_classes_and_methods(lines): + for i in range(len(lines)): + if ">>>" in lines[i]: + continue + + for existing, replacement, obj in replacements: + lines[i] = re.sub(existing, r"{}".format(replacement), lines[i]) + + +# Sphinx build passes six arguments +def autodoc_process_docstring(app, what, name, obj, options, lines): + try: + # guarded method to make sure build never fails + log_suggestions(lines, name) + autolink_classes_and_methods(lines) + except Exception as e: + print( + style( + "Failed to check for class name mentions that can be " + "converted to reStructuredText links in docstring of {}. " + "Error is: \n{}".format(name, str(e)), + fg="red", + ) + ) + + remove_arrows_in_examples(lines) + + +# Sphinx build method passes six arguments +def skip(app, what, name, obj, skip, options): + if name == "__init__": + return False + return skip + + +def _prepare_build_dir(app, config): + """Get current working directory to the state expected + by the ReadTheDocs builder. 
Shortly, it does the same as + ./build-docs.sh script except not running `sphinx-build` step.""" + build_root = Path(app.srcdir) + build_out = Path(app.outdir) + copy_tree(str(here / "source"), str(build_root)) + copy_tree(str(build_root / "api_docs"), str(build_root)) + shutil.rmtree(str(build_root / "api_docs")) + shutil.rmtree(str(build_out), ignore_errors=True) + copy_tree(str(build_root / "css"), str(build_out / "_static" / "css")) + copy_tree(str(build_root / "04_user_guide/images"), str(build_out / "04_user_guide")) + shutil.rmtree(str(build_root / "css")) + + +def setup(app): + app.connect("config-inited", _prepare_build_dir) + app.connect("autodoc-process-docstring", autodoc_process_docstring) + app.connect("autodoc-skip-member", skip) + app.add_stylesheet("css/qb1-sphinx-rtd.css") + # fix a bug with table wraps in Read the Docs Sphinx theme: + # https://rackerlabs.github.io/docs-rackspace/tools/rtd-tables.html + app.add_stylesheet("css/theme-overrides.css") + # add "Copy" button to code snippets + app.add_stylesheet("css/copybutton.css") + app.add_stylesheet("css/causalnex.css") + + # when using nbsphinx, to allow mathjax render properly + app.config._raw_config.pop('mathjax_config') + + +def fix_module_paths(): + """ + This method fixes the module paths of all class/functions we import in the + __init__.py file of the various causalnex submodules. 
+ """ + for module in MODULES: + mod = importlib.import_module(module) + if not hasattr(mod, "__all__"): + mod.__all__ = get_classes(module) + get_functions(module) + + +# (regex, restructuredText link replacement, object) list +replacements = [] + +# (regex, class/function name surrounded with back-ticks, object) list +suggestions = [] + +try: + # guarded code to make sure build never fails + replacements_f, suggestions_f = autolink_replacements("func") + replacements_c, suggestions_c = autolink_replacements("class") + replacements = replacements_f + replacements_c + suggestions = suggestions_f + suggestions_c +except Exception as e: + print( + style( + "Failed to create list of (regex, reStructuredText link " + "replacement) for class names and method names in docstrings. " + "Error is: \n{}".format(str(e)), + fg="red", + ) + ) + +fix_module_paths() + +patchy.patch( + generate_autosummary_docs, + """\ +@@ -3,7 +3,7 @@ def generate_autosummary_docs(sources, output_dir=None, suffix='.rst', + base_path=None, builder=None, template_dir=None, + imported_members=False, app=None): + # type: (List[unicode], unicode, unicode, Callable, Callable, unicode, Builder, unicode, bool, Any) -> None # NOQA +- ++ imported_members = True + showed_sources = list(sorted(sources)) + if len(showed_sources) > 20: + showed_sources = showed_sources[:10] + ['...'] + showed_sources[-10:] +""", +) + +patchy.patch( + generate_autosummary_docs, + """\ +@@ -96,6 +96,21 @@ def generate_autosummary_docs(sources, output_dir=None, suffix='.rst', + if x in include_public or not x.startswith('_')] + return public, items + ++ import importlib ++ def get_public_modules(obj, typ): ++ # type: (Any, str) -> List[str] ++ items = [] # type: List[str] ++ for item in getattr(obj, '__all__', []): ++ try: ++ importlib.import_module(name + '.' + item) ++ except ImportError: ++ continue ++ finally: ++ if item in sys.modules: ++ sys.modules.pop(name + '.' + item) ++ items.append(name + '.' 
+ item) ++ return items ++ + ns = {} # type: Dict[unicode, Any] +""", +) + +patchy.patch( + generate_autosummary_docs, + """\ +@@ -106,6 +106,9 @@ def generate_autosummary_docs(sources, output_dir=None, suffix='.rst', + get_members(obj, 'class', imported=imported_members) + ns['exceptions'], ns['all_exceptions'] = \\ + get_members(obj, 'exception', imported=imported_members) ++ ns['public_modules'] = get_public_modules(obj, 'module') ++ ns['functions'] = [m for m in ns['functions'] if not hasattr(obj, '__all__') or m in obj.__all__] ++ ns['classes'] = [m for m in ns['classes'] if not hasattr(obj, '__all__') or m in obj.__all__] + elif doc.objtype == 'class': + ns['members'] = dir(obj) + ns['inherited_members'] = \\ +""", +) diff --git a/docs/source/01_introduction/01_introduction.md b/docs/source/01_introduction/01_introduction.md new file mode 100644 index 0000000..e7e85e1 --- /dev/null +++ b/docs/source/01_introduction/01_introduction.md @@ -0,0 +1,42 @@ +# Introduction + + +CausalNex is a Python library that uses Bayesian Networks to combine machine learning and domain expertise for causal reasoning. +You can use CausalNex to uncover structural relationships in your data, learn complex distributions, +and observe the effect of potential interventions. 
+ ## Main features of CausalNex + +The CausalNex library has the following features: + +- Deploys state-of-the-art structure learning method, [DAG with NO TEARS](https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf), to understand conditional dependencies between variables +- Allows domain knowledge to augment model relationships +- Builds predictive models based on structural relationships +- Understands model probability +- Evaluates model quality with standard statistical checks +- Visualisation which simplifies how causality is understood +- Analyses the impact of interventions using Do-calculus + +## Learning About CausalNex + +In the next few chapters, you will learn how to install and set up CausalNex, and how to use it on your own projects. +Once you are set up, to get a feel for CausalNex, we suggest working through our example tutorial project. +Advanced users looking for in-depth information should consult the User Guide. +You can also check out the resources section for answers to frequently asked questions and the API reference documentation for further, detailed information. + +## Assumptions + +We have designed the documentation in general, and the tutorial in particular, for beginners to get started using Bayesian Networks on their projects. If you have only elementary knowledge of Python and Bayesian Networks, then you may find the CausalNex learning curve more challenging. However, we have simplified the tutorial by providing all the Python functions necessary to create your first CausalNex project. + +Note: There are a number of excellent online resources for learning Python, but be aware that +you should choose those that reference Python 3, as CausalNex is built for Python 3.5+. 
+There are many curated lists of online resources, such as: + +- [Official Python programming language website](https://www.python.org/) +- [List of free programming books and tutorials](https://github.com/EbookFoundation/free-programming-books/blob/master/free-programming-books.md#python) + +There are also several excellent online resources for learning about Bayesian Networks, such as: + +- [Lecture notes](https://ermongroup.github.io/cs228-notes/) on Probabilistic graphical models based on Stanford CS228; +- [An Introduction to Bayesian Network Theory and Usage](http://infoscience.epfl.ch/record/82584) by T. Stephenson; +- [PGMPY tutorial](https://github.com/pgmpy/pgmpy_notebook/blob/master/notebooks/2.%20Bayesian%20Networks.ipynb). diff --git a/docs/source/02_getting_started/01_prerequisites.md b/docs/source/02_getting_started/01_prerequisites.md new file mode 100644 index 0000000..a133417 --- /dev/null +++ b/docs/source/02_getting_started/01_prerequisites.md @@ -0,0 +1,71 @@ +# Installation prerequisites + +CausalNex supports macOS, Linux and Windows (7 / 8 / 10 and Windows Server 2016+). If you encounter any problems on +these platforms, please check the FAQ, and / or the Alchemy community support on Slack. + +## macOS / Linux + +In order to work effectively with CausalNex projects, we highly recommend you download and install +[Anaconda](https://www.anaconda.com/download/#macos) (Python 3.x version). + +## Windows + +You will require admin rights to complete the installation of the following tools on your machine: + +* [Anaconda](https://www.anaconda.com/download/#windows) (Python 3.x version) + +## Python virtual environments + +Python's virtual environments can be used to isolate the dependencies of different individual projects, +avoiding Python version conflicts. They also prevent permission issues for non-administrator users. 
+For more information, please refer to this +[guide](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html). + +### Using `conda` + +We recommend creating your virtual environment using [`conda`](https://conda.io/docs/), a package and environment +manager program bundled with Anaconda. + +#### Create an environment with `conda` + +Use [`conda create`](https://conda.io/docs/user-guide/tasks/manage-environments.html#id1) to create a python 3.6 +environment called `environment_name` by running: + +```bash +conda create --name environment_name python=3.6 +``` + +#### Activate an environment with `conda` + +Use [`conda activate`](https://conda.io/docs/user-guide/tasks/manage-environments.html#activating-an-environment) +to activate an environment called `environment_name` by running: + +```bash +conda activate environment_name +``` + +When you want to deactivate the environment you are using with CausalNex, you can use +[`conda deactivate`](https://conda.io/docs/user-guide/tasks/manage-environments.html#id6): + +```bash +conda deactivate +``` + +#### Other `conda` commands + +To list all existing `conda` environments: + +```bash +conda env list +``` + +To delete an environment: + +```bash +conda remove --name environment_name --all +``` + +### Alternatives to `conda` + +If you prefer an alternative environment manager such as [`venv`](https://docs.python.org/3/library/venv.html), +[`pyenv`](https://github.com/pyenv/pyenv), etc, please read their respective documentation. diff --git a/docs/source/02_getting_started/02_install.md b/docs/source/02_getting_started/02_install.md new file mode 100644 index 0000000..de3cddc --- /dev/null +++ b/docs/source/02_getting_started/02_install.md @@ -0,0 +1,21 @@ +## Installation guide + +We recommend installing CausalNex in a new virtual environment for *each* of your projects. 
To install CausalNex: + +```bash +pip install causalnex +``` + +To check your installation: + +```bash +python -c "import causalnex" +``` + +If CausalNex is not installed correctly you will see an error message similar to the following: + +```bash +ModuleNotFoundError: No module named 'causalnex' +``` + +You should not see any output if CausalNex is correctly installed. diff --git a/docs/source/03_tutorial/03_tutorial.ipynb b/docs/source/03_tutorial/03_tutorial.ipynb new file mode 100644 index 0000000..2a2e5af --- /dev/null +++ b/docs/source/03_tutorial/03_tutorial.ipynb @@ -0,0 +1,2112 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# A first CausalNex tutorial" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This tutorial will walk you through an example workflow using CausalNex to estimate whether a student will pass or fail an exam, by looking at various influences like school support, relationship between family members, and others. We will use the [Student Performance Data Set](https://archive.ics.uci.edu/ml/datasets/Student+Performance) published in the [UCI Machine Learning Repository](http://archive.ics.uci.edu/ml).\n", + "\n", + "\n", + "To work through this tutorial, you first need to create a new Python 3 notebook and download the [student.zip](https://archive.ics.uci.edu/ml/machine-learning-databases/00320/student.zip) file and extract `student-por.csv` from the zip file into the same directory, then copy and paste the code cells from this tutorial into your notebook." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Structure Learning" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Defining the structure of a Bayesian Network (BN) model can be done based on machine learning, domain knowledge, or a combination of both, where experts and algorithms contribute as equal partners.\n", + "\n", + "Regardless of the approach, it is important to validate the structure by evaluating the BN - this will be covered later in the tutorial. In this section, we will focus on how to define a structure." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Structure from Domain Knowledge" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can manually define a structure model by specifying the relationships between different features.\n", + "\n", + "First, we must create an empty structure model." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from causalnex.structure import StructureModel\n", + "sm = StructureModel()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we can specify the relationships between features. 
For example, let's assume that experts tell us the following causal relationships are known (G1 is grade in semester 1):\n", + "* `health` -> `absences`\n", + "* `health` -> `G1`\n", + "\n", + "We can add these relationships into our structure model:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "sm.add_edges_from([\n", + " ('health', 'absences'),\n", + " ('health', 'G1')\n", + "])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Visualising the Structure" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can examine a StructureModel by looking at the output of `sm.edges`" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "OutEdgeView([('health', 'absences'), ('health', 'G1')])" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sm.edges" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "but it can often be more intuitive to visualise it. CausalNex provides a plotting module that allows us to do this." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ben_horsburgh/opt/anaconda3/envs/causal-test/lib/python3.7/site-packages/networkx/drawing/nx_pylab.py:563: MatplotlibDeprecationWarning: \n", + "The iterable function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use np.iterable instead.\n", + " if not cb.iterable(width):\n", + "/Users/ben_horsburgh/opt/anaconda3/envs/causal-test/lib/python3.7/site-packages/networkx/drawing/nx_pylab.py:660: MatplotlibDeprecationWarning: \n", + "The iterable function was deprecated in Matplotlib 3.1 and will be removed in 3.3. 
Use np.iterable instead.\n", + " if cb.iterable(node_size): # many node sizes\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAWgElEQVR4nO3de5ScdX3H8fc3QJINigEBEfDCzQpCihBEFAgK5S6IQDZYjyBajb1oj9Sj2NODpbX0VKFqtaXWKnraHpZLALkErFig5SaIGi5aASHK/WYQSCCR/fWP37Od2WV2s7szs8/zzLxf53CS7O7MPskfvM8z+8zziZQSkiRVzayyD0CSpFYMlCSpkgyUJKmSDJQkqZIMlCSpkgyUJKmSDJQkqZIMlCSpkgyUJKmSDJQkqZIMlCSpkgyUJKmSDJQkqZIMlCSpkjYs+wDWK2JL4ERgATAfWAWsAM4hpcfLPDRJUvdEZfegIvYCTgUOAxIw0PTZNUAAy4EzSOmWmT9ASVI3VTNQEUuBM4G5TPwy5DDwPHAKKZ09E4cmSZoZ1QtUI07zpvCo1RgpSeopXQtURJwEfCiltO8UHrQXcA1Ti9OI1cAiUrp1Go+VJFVM1a7iO5X8st50zC0eL0nqAdUJVL5a7zCmf0yzgMOJ2KJzByVJKkvbgYqIT0fEvRHxTETcFRHHjP50fCUino6In0XEgU2fOCkiflE87r4/gy+Sr9bjG8DOwKbAIcDK5icEzgZ2Il9z/kcjD8rSR+EfI+KnTcezR/H9to6ICyPi8Yi4LyI+1nQsb4mIWyPiNxHxaESc1e6/iySpPZ04g7oX2A94BfCXwL9FxKuLz+1dfH5z4DRgWURsFhEbA18GDkspvRx421HwMmDgEuBvgGXA48UTnzDmG14G3EJ+M9R5wFXFx8+HgSE4Ang/sAlwFPBkRMwCLgV+AmwDHAj8aUQcUjz0S8CXUkqbADsUTytJKlHbgUopnZ9SeiilNJxSGgLuBt5SfPox4IsppXXF5/6XHBDIl4jvGhEDKaWH988nR5xN/kHSzuR3EX8G+DGjz6I+TT57ei3wjuLzAF8HlsK9KaVbUnZPSmklsBewRUrp9JTS2pTSL4B/AZYUD10H7BgRm6eUnk0p3dTuv4skqT2deInv/RHx44hYFRGrgF3JZ0wAD6bRlwmuBLZOKT0HDAJLgYcj4vIbcrBYCXycHKD5wGbkl/AebHqSrZp+Pw94tvj9r4DX5SiO9Tpg65FjLI7zM8Cris9/EHgD8LOIuCUijpz6v4QkqZPautVRRLyOfCZyIHBjSunFiPgxxdkQsE1ERFOkXgt8ByCldBVwVUQMAH/9XjjmfljzGhj4c+D3p3E828DwNZCWRhxKfqnujcAc4JvAfSmlnVo9LqV0N3BC8VLge4ALIuKVRUglSSVo9wxqY/IJzuMAEfEB8hnUiC2Bj0XERhFxPPmVuysi4lURcXTxs6gXgGcfhkeAWAqcAdxZPMHTwPmTPJiTIZ2XY3kpcBbwx8BBwA+AZyLiUxExEBEbRMSukd93RUS8LyK2SCkNk+/1B8UZnSSpHG0FKqV0F/muDzcCjwK7Adc3fcnN5AvungA+BxyXUnqy+L6fAB4CngIWrYUPAcuPgeFPkX84tAm5dssndzjDS+CS4fw9NwBmFx9/DjgAOBrYHbivOJ6vky/sADgUuDMiniVfMLEkpbRmCv8UkqQOq9atjjpwJ4mAHwJfA95LDtVZ5KvVtwUuAIaA/ynOliRJFVWtQEFH7sUXEQH8K7AgpbSw+NiO5AszBoFXkl85HAJuSpX7R5AkVS9Q0LG7mUfE7JTS2hYf35lGrOaR3/c0BPzQWElSNVQzUA
ARC8lviTqc8fegriDvQU3rBrHFmdZuNGIFjVitMFaSVJ7qBmpEvrdeq0Xdb3VyUbeI1R7kUC0mn5kNAUPFxSCSpBlU/UCVoIjV3jRi9RSNWN1d5rFJUr8wUOtRvHn37eRYHUe+NH4kVveXeGiS1NMM1BRExAbAInKsjgXuIcfq/JTSA2UemyT1GgM1TRGxEfmuFYPkNwHfSY7VBSmlR8o8NknqBQaqAyJiDnAwOVZHAreRY3VhSumJMo9NkurKQHVYcfPbw8ixOhS4iRyri1JKvy7z2CSpTgxUFxU3wz2SHKuDgOvIsbokpfSbMo9NkqrOQM2QiBhZ+B0kX2hxNTlWlzrrIUkvZaBKEBGbAu8mx2of8mr9EHCFd1GXpMxAlSwiNiePJA4CewKXA+cC300pvVDmsUlSmQxUhUTEVuT3Vw2Sp7AuIZ9ZXZ1SWlfmsUnSTDNQFRUR2wLHk2O1A3AROVbXpJReLPPYJGkmGKgaiIjXk+8JOAhsA1yIw4uSepyBqhmHFyX1CwNVYw4vSuplBqoHOLwoqRcZqB7TYnhxDY15kJ+WeWySNBUGqoe1GF58kkas7inz2CRpfQxUn2gxvPggOVbnObwoqYoMVB8aM7z4HuBeHF6UVDEGqs8Vw4vvBJbg8KKkCjFQ+n8OL0qqEgOllhxelFQ2A6X1GjO8eCCN4cXvOLwoqVsMlKZkzPDi/jSGFy9zeFFSJxkoTVuL4cUrybFa7vCipHYZKHXEmOHFPcjDi0M4vChpmgyUOs7hRUmdYKDUVQ4vSpouA6UZM87w4rnA9Q4vShrLQKkUY4YXN6MxvHiz8yCSwECpAsYMLw7Q2LK6zVhJ/ctAqTLGGV4cKv673VhJ/cVAqZIcXpRkoFR5Di9K/clAqVYcXpT6h4FSbY0zvHgueXjxwTKPTVL7DJR6QtPw4iAvHV58tMxjkzQ9Bko9pxhe/D3ySvARNIYXlzm8KNWHgVJPazG8eCM5Vhc7vChVm4FS33B4UaoXA6W+5PCiVH0GSn3P4UWpmgyU1KTF8OJlNIYX15Z5bFK/MVDSOMYML76JxvDi9x1elLrPQEmT0GJ4cRk5Vtc6vCh1h4GSpqjF8OIF5Fg5vCh1kIGS2uDwotQ9BkrqEIcXpc4yUFKHtRheTDRi5fCiNEkGSuoihxel6TNQ0gxxeFGaGgMllaDF8OIDNIYXV5Z5bFJVGCipZC2GF+8hx8rhRfU1AyVVSDG8eCCN4cU7cHhRfcpASRVVDC8eTI6Vw4vqOwZKqoEJhhcvSimtKvPYpG4xUFLNOLyofmGgpBpzeFG9zEBJPWLM8OJbgatweFE1ZqCkHtQ0vLgEeDMOL6qGDJTU44rhxePIZ1a74PCiasJASX3E4UXViYGS+pTDi6o6AyWJiNiJRqwcXlQlGChJozi8qKowUJJacnhRZTNQktZrzPDiILAahxfVZQZK0pQ0DS8uIV8R6PCiusJASZq2YnhxXxrDi7/C4UV1iIGS1BERsSGjhxfvxuFFtcFASeo4hxfVCQZKUleNGV48EvghDi9qEgyUpBnj8KKmwkBJKoXDi1ofAyWpdA4vqhUDJalSiuHFY2gML16Jw4t9yUBJqqyI2IJ8yfogDi/2HQMlqRYcXuw/BkpS7Ti82B8MlKRac3ixdxkoST3D4cXeYqAk9aSI2IUcqyXAXBxerB0DJamnFfMgC2hsWQ3j8GItGChJfaOI1Z7kUC3G4cVKM1CS+lKxZbU3OVYOL1aQgZLU9xxerCYDJUlNHF6sDgMlSeNweLFcBkqSJqFpeHEJcAQOL3adgZKkKSqGFw8nn1kdgsOLXWGgJKkNEfEyRg8vXovDix1hoCSpQ4rhxaPJsdoPhxfbYqAkqQscXmyfgZKkLhszvLgHcCkOL66XgZKkGeTw4uQZKEkqSUS8hsbw4vY4vDiKgZKkCoiI7WhsWW1Np4cXI7YETiTf2X0+sApYAZxDSo+3/fxdYKAkqWKK4cWReZBNaWd4MWIv4F
TgMCABA02fXQMEsBw4g5RuafvgO8hASVKFFcOLI7Ga2vBixFLgzOJxsyb4ymHgeeAUUjq7A4fdERMdsCSpZCmlu1JKpwE7AxsB25Ij9fOI+FxELCh2rkZrxGkeE/y//hxg3/z5ecCZxeOaniYOiIgHOvTXmRIDJUk1UJwtrQO+AewInADMJl+yfldEfDYidgZGXtYbidNUzAv4p7Mi3t2xA2+DgZKkmknZrSmlTwLbASeTL3y4OiJW/ATOSfllvWk5AJau94tmgIGSpHrZPSJWRMTTETEEzE4p3Qh8D3gsYPuPwC63N/3//W+BHYCXk994ddE4T7x/8et+cMisiGcjYnDkcxFxSkQ8FhEPR8QHuvI3G8NASVK9LAYOJZ85LQBOiog3k1/6+8gaOP3DsPYo4IXiATsA/w08DZwGvA94uMUTX1f8+iN4fhg+m1IaKj60FfAKYBvgg8BXi1s5dZWBkqR6+XJK6aGU0lPknz/tDnwY+OeU0s1zYLeTYfYc4KbiAceT31g1i3wp4E7ADyb4BrPyy4MLmj60Djg9pbQupXQF8CzwOx39W7WwYbe/gSSpox5p+v1qcns2A06MiD/ZGDbeEFgLPFR80beBs4D7iz8/C0xiYXF+0++fTCn9dsz3fdk0jn1KPIOSpPr7FfC5lNL8x2DZKnJBTgBWAn8AfAV4knz7iF3J79hdj9KHFw2UJNXffwCfiIjrT4ejn4IXLweeAZ4j3ypii+ILvwncMcETvQq4J//4akU3D3gyDJQk1VCx5LsrsIj8s6ifA1v9Hax9E2xwTvF1uwCnAPuQ43M78PYJnvezwAdgziz4i4hY3KXDnxRvdSRJNRER84DDydc6HAzcQL7t0cUppVVNX7iMvOw7nZOQYeBiUjq27QNuk4GSpAqLiDnky8oHyXG6lRylZSmlJ8d50F7ANUz9ThKQf3y1iJRunc7xdpKBkqSKiYjZwEHkKB1F/nnQEHBhSunRST5J8734Jms1FbphrIGSpAqIiA2Bd5Cj9G7yz5SGgAtSSg9O80lrfTdzAyVJJYmIDYD9yFE6lnxV+BBwfkppZYe+yULyHtThjL8HdQV5D6r0l/WaGShJmkERMQt4KzlKxwOPkaN0Xkrp3i5+4y1ovaj7LRd1JalPFXtNC8lRWky+mcO55Cj9rMxjqzJvdSRJXVBE6XdpRGmYfKZ0BHDHlKfb+5CBkqQOiog30Zhon01evz0e+JFRmhpf4pOkNkXEG2hEaT45SkPAD4zS9BkoSZqGiNiORpS2Ai4gR+mGlNJwmcfWKwyUJE1SRLyG/POkQfJg4IXkKF2XUnqxzGPrRQZKkiYQEa8GjiNHaWfgYnKUvj9mI0kdZqAkaYzI7xk6lhyl3cl3Cx8C/jOltLbMY+snBkqSgIjYDDiGHKW9yXdXGAKuTCk9X+ax9SsDJalvRcQryLMUg8C+wPfIUbo8pfRcmccmAyWpzxRDf+8iR+md5FmKIeA7KaVnSjw0jWGgJPW8iBgg38Fh4qE/VYqBktSTiqG/Q4AlTHboT5VioCT1jIjYiNFDf7cz1aE/VYaBklRrxdDfAeQoHUMnhv5UCd4sVlLtFEN/+9IY+vslOUp7dmzoT6UzUJJqYYKhv7d1dehPpTFQkiprgqG/dzr01/sMlKRKcehPIwyUpEqIiF1ozFfMwaG/vudVfJJKExE70YjSpjj0pyYGStKMKob+RjaVXo1DfxqHgZLUdcXQ3/E0hv6W4dCf1sNASeoKh/7ULgMlqWMc+lMnGShJbXHoT91ioCRNmUN/mgkGStKkFEN/R9IY+rsWh/7URQZK0riKob/DyVE6BIf+NIMMlKRRmob+Bsm3F3LoT6UwUJIc+lMlGSipTzn0p6rzZrFSH3HoT3VioKQe59Cf6spAST2o2FTak8adwh36U+0YKKlHFFFaQCNKDv2p1gyUVHMO/alXeRWfVEMO/akfGCipJhz6U78xUFKFRcS2NKLk0J/6ioGSKiYitqKxPu
vQn/qWgZIqICI2J79xdgkO/UmAgZJKExGbMnrobzkO/Un/z0BJMygiNqEx9LcfDv1J4zJQUpdFxMbAu3DoT5oSAyV1gUN/UvsMlNQhDv1JnWWgpDY49Cd1j4GSpsihP2lmeLNYaRKKTaX9cOhPmjEGShpHMV/RPPT3OA79STPGQElNxgz9LQaeIw/9HejQnzSzDJT63pihv8VAIp8pHYlDf1JpDJT61pihv7nkKC3GoT+pEryKT33FoT+pPgyUel5EvJ5GlBz6k2rCQKknjRn62x64EIf+pFoxUOoZxdDfceQo7UJj6O+/Ukrryjw2SVNnoFRrTUN/g8CbcehP6hkGSrXTYujvSnKUljv0J/UOA6VacOhP6j8GSpXl0J/U3wyUKsWhP0kjDJRKVwz9HQwswaE/SQUDpVIUQ38Hks+UjsahP0ljGCjNmGLobxGNob97yHcKd+hP0kt4s1h1VTH0ty+Nob8HyGdKCx36kzQRA6WOazH09wQ5SvumlO4p89gk1YeBUkeMM/Q3hEN/kqbJQGnaHPqT1E0GSlMWETvTmK8YwKE/SV3gVXyalIjYkUaUNsOhP0ldZqA0rmLob2RTaRvy0N+5OPQnaQYYKI1SDP0dT47SDsAy8pnStQ79SZpJBkqthv4uIUfp+w79SSqLgepTLYb+LiNH6bsO/UmqAgPVRxz6k1QnBqrHOfQnqa4MVA8qhv6OJM9XOPQnqZYMVI8ohv4OI58pHQrcSGPo79dlHpskTYeBqrGmob9B8tDfbTSG/p4o89gkqV0GqmZaDP3dQY7SBQ79SeolBqoGImID4ABGD/0NAec79CepV3mz2Ipy6E9SvzNQFeLQnyQ1GKiStRj6W02+IetBKaWflnlsklQmA1WCFkN/kM+U3gXc7nyFJBmoGdVi6O+84ve3GSVJGs2r+LqsxdDf+eSzpZuNkiSNz0B1wThDf0PA9Q79SdLkGKgOcehPkjrLQLXBoT9J6h4DNUUO/UnSzDBQkzBm6O+tjB76W1PmsUlSrzJQ4yiG/o4iR2l/4GpylC5z6E+Sus9ANWka+hsk3zH8OhpDf78p89gkqd9UP1ARWwInku+8MB9YBawAziGlx9t/eof+JKmKqhuoiL2AU8nxSOQ7L4xYAwSwHDiDlG6Z2lM79CdJVVfNQEUsBc4E5gKzJvjKYeB54BRSOnv0U8RSYJ+U0onFn8cb+rswpfRIx/8OkqS2VC9QjTjNm8KjVtMUqSJOZ5HjdgL5LMyhP0mqkWoFKr+sd825MO/vyac4GwPbkX8I9VHgGuB08mtymwL3Nx69GlgU+b1J/wDMIb80+Evgq8B5Dv1JUn1M9PJZGU79Agx8HPgk8AjwKHA2cD2wlhysk4HPv/Sxc++Cr5H/m1N8LIBHU0qfN06SVC/VOYOK2HIVrNwG5n6bfKuGiXwP+BCjzqBI8MJu8Fd35j/uCrwReDGltLDzByxJ6qYq7UGdeAPEC+QrGKYjYPgOeIGUvtDJA5MkzbwqvcS34CmYszmjq/k28pufBsjvml2PAfL7pSRJNVelM6j5rwSeAH5L48BuKH7dlnxN+WSep9MHJkmaeVU6g1q1D/nqhkvafJ5OHIwkqVxVOoNaMR/WnAYDf0i+PvwQ8lV7K4CRu7MOk6/mW1d8zfPkys7On15TfLkkqeYqdRUfsBKY++/Al2i8D2p74IPASeSX/N4x5qGLyO+PIvfqtZ24R58kqVzVCRRAxDLyRXzTeelxGLiYlNZ3hbokqQaq9DMogDPIZ0HT8XzxeElSD6hWoPJdyU8h37ZoKkbuxXdr5w9KklSGKl0kkaV0NhHQ5t3MJUn1Vq2fQTWLWEjegzqc8fegriDvQXnmJEk9prqBGhGxBa0Xdb/l1XqS1LuqHyhJUl+q1kUSkiQVDJQkqZIMlCSpkgyUJKmSDJQkqZIMlCSpkgyUJKmSDJQkqZIMlCSpkgyUJKmSDJQkqZIMlCSpkgyUJKmSDJQkqZL+D8kWvkPMq7
vfAAAAAElFTkSuQmCC\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "from causalnex.plots import plot_structure\n", + "\n", + "_, _, _ = plot_structure(sm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Learning the Structure" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "As the number of variables grows, or when domain knowledge does not exist, it can be tedious to define a structure manually. We can use CausalNex to learn the structure model from data. The structure learning algorithm we are going to use here is the [NOTEARS](https://arxiv.org/abs/1803.01422) algorithm.\n", + "\n", + "When learning structure, we can use the entire dataset. Since structure should be considered as a joint effort between machine learning and domain experts, it is not always necessary to use a train / test split.\n", + "\n", + "But before we begin, we have to pre-process the data so that the NOTEARS algorithm can be used." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preparing the Data for Structure Learning" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>school</th>\n", + " <th>sex</th>\n", + " <th>age</th>\n", + " <th>address</th>\n", + " <th>famsize</th>\n", + " <th>Pstatus</th>\n", + " <th>Medu</th>\n", + " <th>Fedu</th>\n", + " <th>Mjob</th>\n", + " <th>Fjob</th>\n", + " <th>...</th>\n", + " <th>famrel</th>\n", + " <th>freetime</th>\n", + " <th>goout</th>\n", + " <th>Dalc</th>\n", + " <th>Walc</th>\n", + " <th>health</th>\n", + " <th>absences</th>\n", + " <th>G1</th>\n", + " <th>G2</th>\n", + " <th>G3</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>GP</td>\n", + " <td>F</td>\n", + " <td>18</td>\n", + " <td>U</td>\n", + " <td>GT3</td>\n", + " <td>A</td>\n", + " <td>4</td>\n", + " <td>4</td>\n", + " <td>at_home</td>\n", + " <td>teacher</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>0</td>\n", + " <td>11</td>\n", + " <td>11</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>GP</td>\n", + " <td>F</td>\n", + " <td>17</td>\n", + " <td>U</td>\n", + " <td>GT3</td>\n", + " <td>T</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>at_home</td>\n", + " <td>other</td>\n", + " <td>...</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " 
<td>1</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>9</td>\n", + " <td>11</td>\n", + " <td>11</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>GP</td>\n", + " <td>F</td>\n", + " <td>15</td>\n", + " <td>U</td>\n", + " <td>LE3</td>\n", + " <td>T</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>at_home</td>\n", + " <td>other</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>6</td>\n", + " <td>12</td>\n", + " <td>13</td>\n", + " <td>12</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>GP</td>\n", + " <td>F</td>\n", + " <td>15</td>\n", + " <td>U</td>\n", + " <td>GT3</td>\n", + " <td>T</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>health</td>\n", + " <td>services</td>\n", + " <td>...</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>5</td>\n", + " <td>0</td>\n", + " <td>14</td>\n", + " <td>14</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>GP</td>\n", + " <td>F</td>\n", + " <td>16</td>\n", + " <td>U</td>\n", + " <td>GT3</td>\n", + " <td>T</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>other</td>\n", + " <td>other</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>5</td>\n", + " <td>0</td>\n", + " <td>11</td>\n", + " <td>13</td>\n", + " <td>13</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 33 columns</p>\n", + "</div>" + ], + "text/plain": [ + " school sex age address famsize Pstatus Medu Fedu Mjob Fjob ... \\\n", + "0 GP F 18 U GT3 A 4 4 at_home teacher ... \n", + "1 GP F 17 U GT3 T 1 1 at_home other ... \n", + "2 GP F 15 U LE3 T 1 1 at_home other ... \n", + "3 GP F 15 U GT3 T 4 2 health services ... \n", + "4 GP F 16 U GT3 T 3 3 other other ... 
\n", + "\n", + " famrel freetime goout Dalc Walc health absences G1 G2 G3 \n", + "0 4 3 4 1 1 3 4 0 11 11 \n", + "1 5 3 3 1 1 3 2 9 11 11 \n", + "2 4 3 2 2 3 3 6 12 13 12 \n", + "3 3 2 2 1 1 5 0 14 14 14 \n", + "4 4 3 2 1 2 5 0 11 13 13 \n", + "\n", + "[5 rows x 33 columns]" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import pandas as pd\n", + "data = pd.read_csv('student-por.csv', delimiter=';')\n", + "data.head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Looking at the data, we can see that features consist of numeric and non-numeric columns. We can drop sensitive features such as sex that we do not want to include in our model." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>address</th>\n", + " <th>famsize</th>\n", + " <th>Pstatus</th>\n", + " <th>Medu</th>\n", + " <th>Fedu</th>\n", + " <th>traveltime</th>\n", + " <th>studytime</th>\n", + " <th>failures</th>\n", + " <th>schoolsup</th>\n", + " <th>famsup</th>\n", + " <th>...</th>\n", + " <th>famrel</th>\n", + " <th>freetime</th>\n", + " <th>goout</th>\n", + " <th>Dalc</th>\n", + " <th>Walc</th>\n", + " <th>health</th>\n", + " <th>absences</th>\n", + " <th>G1</th>\n", + " <th>G2</th>\n", + " <th>G3</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>U</td>\n", + " <td>GT3</td>\n", + " <td>A</td>\n", + " <td>4</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", 
+ " <td>0</td>\n", + " <td>yes</td>\n", + " <td>no</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>0</td>\n", + " <td>11</td>\n", + " <td>11</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>U</td>\n", + " <td>GT3</td>\n", + " <td>T</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>0</td>\n", + " <td>no</td>\n", + " <td>yes</td>\n", + " <td>...</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>9</td>\n", + " <td>11</td>\n", + " <td>11</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>U</td>\n", + " <td>LE3</td>\n", + " <td>T</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>0</td>\n", + " <td>yes</td>\n", + " <td>no</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>6</td>\n", + " <td>12</td>\n", + " <td>13</td>\n", + " <td>12</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>U</td>\n", + " <td>GT3</td>\n", + " <td>T</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>0</td>\n", + " <td>no</td>\n", + " <td>yes</td>\n", + " <td>...</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>5</td>\n", + " <td>0</td>\n", + " <td>14</td>\n", + " <td>14</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>U</td>\n", + " <td>GT3</td>\n", + " <td>T</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>0</td>\n", + " <td>no</td>\n", + " <td>yes</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " 
<td>5</td>\n", + " <td>0</td>\n", + " <td>11</td>\n", + " <td>13</td>\n", + " <td>13</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 26 columns</p>\n", + "</div>" + ], + "text/plain": [ + " address famsize Pstatus Medu Fedu traveltime studytime failures \\\n", + "0 U GT3 A 4 4 2 2 0 \n", + "1 U GT3 T 1 1 1 2 0 \n", + "2 U LE3 T 1 1 1 2 0 \n", + "3 U GT3 T 4 2 1 3 0 \n", + "4 U GT3 T 3 3 1 2 0 \n", + "\n", + " schoolsup famsup ... famrel freetime goout Dalc Walc health absences G1 \\\n", + "0 yes no ... 4 3 4 1 1 3 4 0 \n", + "1 no yes ... 5 3 3 1 1 3 2 9 \n", + "2 yes no ... 4 3 2 2 3 3 6 12 \n", + "3 no yes ... 3 2 2 1 1 5 0 14 \n", + "4 no yes ... 4 3 2 1 2 5 0 11 \n", + "\n", + " G2 G3 \n", + "0 11 11 \n", + "1 11 11 \n", + "2 13 12 \n", + "3 14 14 \n", + "4 13 13 \n", + "\n", + "[5 rows x 26 columns]" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "drop_col = ['school','sex','age','Mjob', 'Fjob','reason','guardian']\n", + "data = data.drop(columns=drop_col)\n", + "data.head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we want to make our data numeric, since this is what the NOTEARS expects. We can do this by label encoding non-numeric variables." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['address', 'famsize', 'Pstatus', 'schoolsup', 'famsup', 'paid', 'activities', 'nursery', 'higher', 'internet', 'romantic']\n" + ] + } + ], + "source": [ + "import numpy as np\n", + "struct_data = data.copy()\n", + "\n", + "non_numeric_columns = list(struct_data.select_dtypes(exclude=[np.number]).columns)\n", + "print(non_numeric_columns)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>address</th>\n", + " <th>famsize</th>\n", + " <th>Pstatus</th>\n", + " <th>Medu</th>\n", + " <th>Fedu</th>\n", + " <th>traveltime</th>\n", + " <th>studytime</th>\n", + " <th>failures</th>\n", + " <th>schoolsup</th>\n", + " <th>famsup</th>\n", + " <th>...</th>\n", + " <th>famrel</th>\n", + " <th>freetime</th>\n", + " <th>goout</th>\n", + " <th>Dalc</th>\n", + " <th>Walc</th>\n", + " <th>health</th>\n", + " <th>absences</th>\n", + " <th>G1</th>\n", + " <th>G2</th>\n", + " <th>G3</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>0</th>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>4</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>4</td>\n", + " <td>0</td>\n", + " 
<td>11</td>\n", + " <td>11</td>\n", + " </tr>\n", + " <tr>\n", + " <th>1</th>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>...</td>\n", + " <td>5</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>9</td>\n", + " <td>11</td>\n", + " <td>11</td>\n", + " </tr>\n", + " <tr>\n", + " <th>2</th>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>6</td>\n", + " <td>12</td>\n", + " <td>13</td>\n", + " <td>12</td>\n", + " </tr>\n", + " <tr>\n", + " <th>3</th>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>4</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>...</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>1</td>\n", + " <td>5</td>\n", + " <td>0</td>\n", + " <td>14</td>\n", + " <td>14</td>\n", + " <td>14</td>\n", + " </tr>\n", + " <tr>\n", + " <th>4</th>\n", + " <td>1</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>3</td>\n", + " <td>3</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>0</td>\n", + " <td>0</td>\n", + " <td>1</td>\n", + " <td>...</td>\n", + " <td>4</td>\n", + " <td>3</td>\n", + " <td>2</td>\n", + " <td>1</td>\n", + " <td>2</td>\n", + " <td>5</td>\n", + " <td>0</td>\n", + " <td>11</td>\n", + " <td>13</td>\n", + " <td>13</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "<p>5 rows × 26 columns</p>\n", + "</div>" + ], + "text/plain": [ + " address famsize Pstatus Medu Fedu traveltime 
studytime failures \\\n", + "0 1 0 0 4 4 2 2 0 \n", + "1 1 0 1 1 1 1 2 0 \n", + "2 1 1 1 1 1 1 2 0 \n", + "3 1 0 1 4 2 1 3 0 \n", + "4 1 0 1 3 3 1 2 0 \n", + "\n", + " schoolsup famsup ... famrel freetime goout Dalc Walc health \\\n", + "0 1 0 ... 4 3 4 1 1 3 \n", + "1 0 1 ... 5 3 3 1 1 3 \n", + "2 1 0 ... 4 3 2 2 3 3 \n", + "3 0 1 ... 3 2 2 1 1 5 \n", + "4 0 1 ... 4 3 2 1 2 5 \n", + "\n", + " absences G1 G2 G3 \n", + "0 4 0 11 11 \n", + "1 2 9 11 11 \n", + "2 6 12 13 12 \n", + "3 0 14 14 14 \n", + "4 0 11 13 13 \n", + "\n", + "[5 rows x 26 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from sklearn.preprocessing import LabelEncoder\n", + "\n", + "le = LabelEncoder()\n", + "for col in non_numeric_columns:\n", + " struct_data[col] = le.fit_transform(struct_data[col])\n", + " \n", + "struct_data.head(5)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now apply the NOTEARS algorithm to learn the structure." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "from causalnex.structure.notears import from_pandas\n", + "sm = from_pandas(struct_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "and visualise the learned StructureModel using the plot function." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ben_horsburgh/opt/anaconda3/envs/causal-test/lib/python3.7/site-packages/networkx/drawing/nx_pylab.py:563: MatplotlibDeprecationWarning: \n", + "The iterable function was deprecated in Matplotlib 3.1 and will be removed in 3.3. 
Use np.iterable instead.\n", + " if not cb.iterable(width):\n", + "/Users/ben_horsburgh/opt/anaconda3/envs/causal-test/lib/python3.7/site-packages/networkx/drawing/nx_pylab.py:660: MatplotlibDeprecationWarning: \n", + "The iterable function was deprecated in Matplotlib 3.1 and will be removed in 3.3. Use np.iterable instead.\n", + " if cb.iterable(node_size): # many node sizes\n" + ] + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "_, _, _ = plot_structure(sm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The reason why we have a fully connected graph here is we haven't applied thresholding to the weaker edges. Thresholding can be applied either by specifying the value for the parameter `w_threshold` in `from_pandas`, or we can remove the edges by calling the structure model function, `remove_edges_below_threshold`." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "sm.remove_edges_below_threshold(0.8)\n", + "_, _, _ = plot_structure(sm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this structure, we can see that there are some relationships that appear intuitively correct:\n", + "\n", + "* <strong>Pstatus</strong> affects <strong>famrel</strong> - if parents live apart, the quality of family relationship may be poor as a result. 
\n", + "* <strong>internet</strong> affects <strong>absences</strong> - The presence of internet at home may cause students to skip class.\n", + "* <strong>studytime</strong> affects <strong>G1</strong> - longer studytime should have a positive impact on a student's result. \n", + "\n", + "However, there are some relationships that are certainly incorrect:\n", + "\n", + "* <strong>higher</strong> affects <strong>Mother's education</strong> - this relationship does not make sense, as a student wanting to pursue higher education does not affect their mother's education. It could be the other way round.\n", + "\n", + "To avoid these erroneous relationships, we can re-run structure learning with some added constraints:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "sm = from_pandas(struct_data, tabu_edges=[(\"higher\", \"Medu\")], w_threshold=0.8)\n", + "_, _, _ = plot_structure(sm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Modifying the Structure" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To correct erroneous relationships, we can incorporate domain knowledge into the model after structure learning. We can modify the structure model by adding and deleting edges. For example, we can add and remove edges as:" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "scrolled": true + }, + "outputs": [], + "source": [ + "sm.add_edge(\"failures\", \"G1\")\n", + "sm.remove_edge(\"Pstatus\", \"G1\")\n", + "sm.remove_edge(\"address\", \"G1\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now visualise our updated structure to confirm it looks reasonable."
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "scrolled": false + }, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "_, _, _ = plot_structure(sm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can see there are two separate subgraphs here in the visualisation plot: `Dalc->Walc` and the other big subgraph. We can retrieve the largest subgraph easily by calling the StructureModel function `get_largest_subgraph()`." + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAagAAAEYCAYAAAAJeGK1AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8li6FKAAAgAElEQVR4nOydd5icZfWG7ycQSEILXUoITZHeq0BCl957R0Kx4E8pigJSlACCAioiAalGekeRGgHpHWkivfcgJaEkz++P8y6ZLJvdmdmZ+WZm3/u65kp25itnZme/873nfd7nyDaZTCaTyTQb/YoOIJPJZDKZrsgJKpPJZDJNSU5QmUwmk2lKcoLKZDKZTFOSE1Qmk8lkmpKcoDKZTCbTlOQElclkMpmmJCeoTCaTyTQlOUFlMplMpinJCSqTyWQyTUlOUJlMJpNpSnKCymQymUxTkhNUJpPJZJqSnKAymUwm05TkBJXJZDKZpiQnqEwmk8k0JTlBZTKZTKYpyQkqk8lkMk1JTlCZTCaTaUpygspkMplMUzJ10QFkMpmENAewO7AUMBgYCzwKnIP9dpGhZTJFINtFx5DJ9G2kFYFDgQ0BAwNLXh0HCPg7MBL7vsYHmMkUQ05QmUyRSPsBJwED6L7kPhEYDxyIfXojQstkiibPQWUyvUTSIpIelvShpAO62W4+SR9JmgpgqPSfP8EpwCB6/lvsl7Y7KSW1TKbtyXNQmUzvOQS41fYy3W1k+yVgegCkFYfCQlNVfpPYkaTux76/qmgzmRYhj6Aymd4zFHi8wn0O7Vfl399EGDAh5qwymbYmJ6hMdUhzIB2MdD7SNenfg5FmLzq0RiLpFmAt4PepfPdDSQ9J+p+klyUdWbLt/JL8lDQXIYj4kiOBXUp+foFQRnyRfh4O/Bz4FjAd9HsONj5FWlDSWZJel/SqpF92lA8lLSzpn5I+kPSOpIvq8gFkMnUkl/gyldGz4uxopD6jOLO9tqQxwAW2z5Q0HNiNGFEtAdwo6WHbV3bsMzRer1iddD4h5VsEmAgTz4SrgDuBhYHpgGuBl4E/AccANxDJcxpghWrfYyZTFHkElSmfmJwfA2xOqM4GdtpiYHp+c2BMX5zMtz3G9mO2J9p+FPgrMKx0m36wJF/97HpkD2Bx4q7yfRj4FCwG/J/tj22/BfwW2CFt/jlRepzb9njbd1T9pj
KZgsgJKoMkS1q4h4065NBlK86GwmkjpZNrE2VrIGllSbdKelvSB8B+wGyTbROLcCtmSMn/XwQmxOf8uqSxksYSI6c50iaHxKm4V9Ljkvaq5pyZTJHkBJXpmSjrdSSn8ncDLQ/7IvWl8tJo4GpgiO2ZgNOJRPElDoeIyZgO+KTk5ze6OHDpQYYAU8MEYDbbg9NjRtuLA9h+w/YI23MD+wKn9XgTksk0GTlBZcrhUKJ0VzGK+Y++pDibAXjP9nhJKwE7dd5gIjxGzNd9yTLAbcBLwAfAyB5OMheM+zo8BZwkaUZJ/SQtJGkYgKRtJc2bNn+fmPOa2Js3lsk0mpyg2hBJP0mqrg8lPS1pHUlTSfqZpGfT8w9IKq0arSvpmVQu+oMkAbwlzXk0bDoU+s1BzO5/ULLT1cS8yGBCafZk51jiO7bR6tL6ku5P6rY3Jf0mxTpc0iud4n9B0rrp/0dKulTSRSnuByUtXcOPq9Z8Fzha0ofAEcDFnTd4Ec6j06hqPWB7woRveWCTns+jfWAz4gbgCSIJXQrMlV5fEbhH0kfEr+mHtp+r6h1lMkVhOz/a6EGIvF4mJscB5gcWAg4m7twXIS6OSwOzpm1MKMAGA/MBbwPfts334eIFYeKz4A/BW4J3ARv8NHgQ+AbwZ+DjwQuBP02vDwXfGP//ZEgop3dN55seWCX9fzjwSqf38AKwbvr/kcSE/zZAf+Ag4Hmgf9Gfda8ecLlhgtNnVeFjguGywt9DfuRHnR95BNV+TACmBRaT1N/2C7afBfYGDrP9tINHbL9bst9xtsc63A5uJapO3AZrHAhakMgqI4ELifU5FwEbE3f/HZljHKF77sTAGUK1trCk2Wx/ZPvuCt7TA7Yvtf058Bui3LhKBfs3IyMJb71qGE/PVcBMpuXJCarNsP1f4P+Ikcdbki6UNDcxr/5sN7uWzst/QrLkGQvTDy15YSiRnN4EXks/d9AvneTVLg5+Pvwb+AbwlKT7JJVRxfqSlzv+Y3si8AowdwX7Nx+xRuxAJtdGlMMnhGFstjnKtD05QbUhtkfbXp3IHwaOJy7yC1V6rMHw0YslP79ErMOZk8gQpa85nWSeLo6zHLxqe0dCBn08cKmk6YCPKVEHJieEzm4UQ0pe7wfMS+TH1iZcyQ8EPpnY88LdiUxKTtnNPNMnyAmqzUjO2mtLmpYoBY0jLm5nAsdI+rqCpSTN2tPxVoc7fgN+HvgI+BkxmT81sB1wHXAzMUl0ElFbXO2rhxn3Y5ha0uxpBNQhs54I/AcYIGljSf2Bw9JhSlle0laSpiZGh58ClZQImxf79Gdgg2vgc8f7Gtdpi3HE7/FKYFhOTpm+RLY6aj+mBY4DFiXyxp3APkRVblrC/mY2QqK8ZU8H+wV8f3bYYk2YejywAfC79NoiwAXAD4iy3jLANYSsrBM6I6apHpc0iBh47WB7HDBO0neJBDoVcAJRwivlKiIvngv8F9gqzUe1Bd8Iq6IbDXsCuz8KW46HISuFa8ejwLnkjro9kzsStx25YWGmZ6TLCfuiakbcE4Ersbeu7tQ6EljY9i49bduqSPoXcILtq9LPewJr2t6z2MhahNyRuG3JJb5MOWTFWZ2QtDixFOC6gkNpTbI/ZFuTE1SmZ7LirJ6MAM62/UWPW2Ympwp/SHJH4pYil/gy5TPpgjCAbi4IhomKkVNWnHWDpAHEfNuKtp8veT6X+HoiOhLfcxaE5UhlfDI/XP8iPGz7mJrHlqkZeQSVKZ9INsMIRVmHQrCUcZ/BxP/Ag2TFWTlsTSxCfr7HLduMZGF1QS8Ocag62UVNiXOA1Sd/asALQE5OzU9W8WUqI8p1W6fOuV9RTA2HN++C/ZzLeuUwAvh90UG0HKHW27DH7aZMP2AjpNmzuq+5ySOoTHXYb2OfiL0b9mbp3xPvigZ9C6TJ/8wUkLQI8E3CyLWt6cK8eGPSkjpJH0l6JG33pUlw+nmyUZakXSW9OAiePzqWJABhgTIIKPXtep
BY7f0Y0ZDrLsIapaMR127Qf9P4rn5pWCzpEElvSXpd0haSNpL0H0nvSfpZSRz9JP00GS+/K+liSbPU9lPLQE5QmRqTJvv/TIwOMlNmb+Bc258VHUg9SYn4+8Q82wzEUrqngGOBi2xPb7tHd3pJiwF/BHZ9G64cC/07Fst9jXAcLrWNP59oLbwk0ZBrVWKheccK8X4w1ayTmjt2HGYAYYRyBDAK2IUwl18DOFzSAmnbHwBbEOXuuQkn+T+U9YFkKiInqEw9OAvYOYkAMp1ILh+7E4uT250pmRdXyjbAtbZvmw5mPIbJL167E4vGO074V2DXHg7Yf/IeZ58Dv0oLwC8kFrOfYvtD248TLU06Eul+wM9tv2L7U8L3cpvkdJKpITlBZWpOmvR/ENiq6FialM2Bf9t+puhA6k035sWVMjeTTIPHTgeU+nRtTmSQ54EbgZmAlXo44OeTr+171/aE9P8O8c+bJa+PIxkoEx6XV6TeaWOJNmgTCIvKTA3JCSpTL0YRFkuZr7IP8fn0CaZgXtzV+pbJjIOJslsHrzPJNPjRj2Fc6ZzTAMIb8gKivFc6eupK6jcRJrwLb1X0RibxMrCh7cEljwG2uzLyz/SCnKAy9eJqYFFJ3yg6kGZC0kJEqeiKomNpBN2YF78JzJ/c6Tt4GNhBUn9JKxBlvQ4uBTaRtPrdMPpw6N+5f/1uhKT8aiZPUHMSi806T/bdGT591XA68CtJQ9N7nF3S5lUeK9MNOUFl6kKa/D+HLJbozHeA821Xax3VanSYF79DCO7mIHzzLkmvvyvpwfT/w4mWMO8DRwGjOw6S5oG+B4xeFR79HP4zb6cTfYu4oC3H5H3K1gYWJ4Zjs8VTE9+BV96r3Bmlg1OIPHiDpA8JZ/2VqzxWphuyk0Smbkj6OnAHMKTd1WrlkNqJvASsbfvJbrbLThI9EQaxY5i8JMjawE6ERLIbPiEWkue1ek1OHkFl6kYSATxOzGFnYBPgv90lp0yZdOEPeR+hzNm++z2zP2QLkRNUpt5kscQk9gHOKDqItqGkI/FuwLrAycAMXW+dOxK3IDlBZerNFcAykhYsOpAiSRPqKxGT/ZlaYZ/+Eay1HXzyPny2R+5I3FbkhWU9kbt09grb4yWdT0wL/Kyn7duYvYDRqYtwpobMEEuinjBsRNd/q7kjcYuSRRJTInfprBmSFgVuAeZrp1bt5ZIcBl4g1s48Vsb2WSRRAYqOz9fbzuXTNiOX+Loid+msKUkU8F9CJNAX+TbwSjnJKVMZkuYC1iIZv2bai5ygOlNGl84NgXPjv1/p0ilpfknOvlxfYRR9d01UFkfUjz2AS21/WHQgmdqTS3ylTGFtRZl8AgxTLEh8Huif23hPQtJAYkH/srZfKjqeRiFpHqLrwxDbH5e5Ty7xlUFyoXgG2NH2vUXHk6k9eQQ1OYcyucNxJQxI+2e6IIkDRhNOCn2JvYCLy01OmYpYG/iQWAKVaUP6dIJKDdIOlfREP2ns7rDZeOj3PjFZMjswc/r/KyX7DWdSn4QJwEHAbNBvAdhy7h7XCfZpzgD26ivlz3SH/x1yea9ejABGOZeB2pY+naASOwMbPA2/+Q/ol8SKvj2BFwlfmoFEx7WuGAVcCzwE3A/jB/TdeZYeSSKBVwnRQF9gPaKNw4M9bpmpCEmzE80P/1J0LJn6kRMU/N72y1+Hrx8O/f5KLKrYmpiImgH4OfDPKex8MdHsZkjsN/DwmH/KTJm+JJbI4oj6sTtwpe2xPW6ZaVlygprUBG3wUOA1Qu2wL+GIPCOwJrHib0IXO7/GpCY1AIvnz7QnLgLWSOKBtkXS14g5kix/rjGSRCrvFR1Lpr7ki+mk/DL2JaJt50nA08A9wP+A29IGXRW652JShgN4MiqEmSlg+yNi4NnuCrU9gMts/6/oQNqQNYEvgDuLDiRTX3KCgu9JmvcZeOYYmLg9IQsaSHilvEc0ppkS2wGnEiKK92Dc0TB/vQNuA84A9u7UrK5tSO
9rb/Idfr3I4og+QlteICpkNHDDInDgQuDDiDmlcURzs1XofkZ/BDFTuzSwHAwYN0ngl5kCSTTwLiEiaEeGE5XivDanxkiahRDWnl90LJn606cX6kp6Adjb9k3picsJ+6JqEvdE4ErsrWsWYBujcN5Y1/Y2PW7cYki6ELjD9u+r3D8v1J0Ckn4IrGR756JjydSfPIKanJGENX81jE/7Z8pjNLCOpDmLDqSWJPnzt8ny55qTxBH7kEunfYacoErpoktnmeQunRWSxAOXE2KCdmI34Crb7xcdSBuyKtCfKa/6yLQZfTpB2Z7/y/LepCe/7NJJz4q83KWzd5wBjGgXsUSWP9edLI7oY7TFhaHmRLIZRnThHE+nLp2GceOB/8E/yF06e8O9RIIfXnActWIN4qblX0UH0m5IGgxsyZeNBDJ9gT7hiVYVUa7bmphTmKxLp+DRxeGbz0WPn1zWqxLblnQGcWd8S9Hx1IB8h18/dgJusP1W0YFkGkefVvH1BklLEzZ889vuymQiUwaSZibsoRa2/U7R8VRLkj8/Rw3eR1bxTU4qnT4EHGz7xqLjyTSOXOKrEtuPEE5HfcX4tC4kMcHVhLigldkF+FsrJ9kmZgXCdezmogPJNJacoHpHR3kq0zvOAPZJd8otRxZH1J0RwJm2s41YHyPPQfWOi4BfS5rb9mtFB9PC/IsQF6wO3F5wLNWwCtGwckzBcbQdkmYAtgUWKzqWwpHmoNN8OPAocA7220WGVi9yguoFtj+SdAlhfPqrouNpVZJYYhSxCLMVE1QWR9SPHYAxtl8vOpDCkFYkunVvSHhWDyx5dRxwNNLfgZFpLWfbkEt8vaetjU8byPnApkk00TJImoksf64nfbunVliCjSEs2AYweXIi/TwgvT4mbd825ItqL7H9APA+sG7RsbQySVzwN0Js0ErsBNxk+82iA2k3JC0LzAHcUHQshRDJ5iSid2pP1+p+abuT2ilJ5QRVG7JYojaMooXEEtkbru6MAP7c6ss4JL0gaZykjyS9KekcSdN3s/0eg6WHmZScuuUFQESDLCYlqRVqEXvR5ARVG0YD67ab8WkBjCHKFSsXHEe5LE9MVt/U04aZypA0HTH/9OeiY6kRm9qeHliOkM0f1t3G88K8xN9CNQwg5qxanpygakCJ8enuRcfSyiSRQYdYohXI8uf6sS1wp+2Xe9yyhbD9KvB3YAlJe0h6TtKHkp6XtLOkRYHTn4RZp4d+g9N+1wHLEovBhgBHlhxzzfTvYGB64C7odwRsNn0IuACQNL8kS5o6/fyVc9f1jVdJTlC1YxRhfNoS5akm5lxgyyQ+aFpSiWY74OyiY2lT2lIcIWkIsBHwJNGMe0PbMwCrAQ/bfvIAuGplmPgRoSMHmA44L/18HfBHwigU4Lb071jgI8Ly3cA3YcEpxDBdV+eu6RutETlB1Y57CGPZ4QXH0dIkscFNhPigmdkB+Gde/1Z7JC0BDCVEM+3ClZLGAncQ7UKOJNb+LSFpoO3XbT8OMDvM26/TtXk4sCTx5FLAjnTfc2QqmHom6E4R2+W5m42coGpEKk+dQeuUp5qZUTS/6CQ7R9SPEcDZtr8oOpAasoXtwbaH2v6u7Y+B7YH9gNclXSfpmwDTdCGMuAdYC5gdmAk4HejJU6t/9M76Ct2du9nICaq2XABsKGm2ogNpcW4CZpa0fNGBdEUyCp4buL7oWNoNSQOAnYGzio6l3tj+h+31gLmAp0g3PJ930TB1J2Az4GXgAyKzdKwK72pOYTrgw8n72X2tnHM3GzlB1ZAS49Ndi46llUmigzNp3tHoCOCsVpc/NylbAw/Yfr7oQOqJpDklbZ7mgz4lpo8mAvSHJ18Gf1ay/YfALIQ8715CNtzB7MSF/LmS55aCTx+C2STNl+Zzv1T1dXfuZiMnqNrTUmt5mpizge26Wy9SBJIGEVMA7SJ/bjbaUhzRBf2AHxMdEd4jGqTuL2nFP8DgxUBfAzpKMacBRwAzAEcT6p
wOBgE/B75FKPnuBjYA94cLCa++B4jWQN2eux5vsrfkflA1JiWmJ4ARtu8oOp5WRtKVwLW2zyw6lg4k7Q5sZ3vjOp6jT/aDkrQIMfc/nz3ZAKKtkTQNMXI8gCi5/WEsrDsDrNev6wpeT0wErsTeupZxFkEeQdWYFlzL08w0o1giiyPqxwjg3L6SnFKp7XCiYec+wAnAwsBDG8Eyn0O1JeTxwMgahVkoOUHVh/OAzVrN+LQJuR6YO4kSCkfS4sTakuuKjqXdkDQt0bSyaUbL9ULS8pLOJcQJ8xHrkdYiFvCeAJxzJ+wyLfyALgQTPfAJcCD2/TUNuiBygqoDyfj077Se8WlTkUQIZ9E8o6i9Cfnz50UH0oZsAfzb9jNFB1IPJPWXtL2kfxGuM08AC9seYftRSUsB9xEJa2nbN2KfDhxIJJ2eRAwTmZScTq/fO2ksOUHVj+wsURv+DOyYxAmFkeTPu9AH7vALYgRtKI6QNLuknxNlvP0JA9iFbB9v+11J/SQdSLSzPxHY1va7Xx4gks0wwjhiPNH/qZRx6fkrgWHtlJwgNyysJ2OIXi0rEevsMlVg+yVJdxPebEX2XNoKeKjd5c9FIGkhYGngiqJjqRWSliNKdFsAlwEb236k0zZDiO/0NMBKU/xuRblua6TZ6bqj7rm5o26mImxPlNSxlicnqN4xCjiIYhPUPsDvCzx/O7M3cJ7tT4sOpDdI6k80rzyAsGr6A/D1VPLvvO0OhB/eycDxZa2piyR0Yi1jbnayzLyOpPYbTwFDk+N5pgrSH/6LwHpFeIZJ+gbRin5IIxRmfUlmnn63LwFr236y6HiqQTGyGQF8l1gveypwZVdWTZIGEzc6KwI7u03EDPUiz0HVkWR8ejPNb3za1CRRwtnEnXYR7E0fkj83mE2B/7ZicpK0jKQ/A88Q8vBNba9p+9IpJKdhhGv4/4DlcnLqmZyg6k8zruVpRc4EdklihYaRFlHuThZH1IuWEkdImlrSNpJuI9wZ/kuU8fay/dAU9plG0nHAX4HvlZjFZnogJ6j6cyMwa7Man7YKaQL5IUKs0Eg2B56w/Z8Gn7ftkTSUEBFdWnQsPSFpNkk/JUp4PwR+Byxg+1h3I1CQtBgxB70osIztvIauAnKCqjMlxqd5FNV7zqDxn2Nf8YYrgu8Af7HdWTrdNEhaOomdngEWIdpmrGH7ku7Wwyn4AdFP8LS031uNibp9aA4VnzQHXcsnz2kT+eTZwGOSDrL9UdHBtDBXA3+Q9I1GjGgkLQgsQxvJn5uF1Hp8L2DDomPpTIptc0KNtzCRYBYpN8FImptYvzczsGq7Lj5uBMWOoKQVkS4nFFpHEQshN0n/HgW8hHQ50ooFRtlrbL9KqMC2LzqWViaJFM6lcWKJ7wAX2B7foPP1JTYEXrH9WNGBdCBpVkk/Icp4PyYS0/y2f1VBctoKeJAwFV89J6feUVyCkvYjFrNuTrQ5Gdhpi4Hp+c2BMWn7ViaLJWrDmcDuSbxQN5L8eU+yMWy9aBpxhKSlJI0CngUWA7ay/S3bF5VrayVphqToOwHY0vaR2RKr99QkQUl6QdK6XTy/hqSnu9hhP8LyY1BHDGOAeacc4yDgpBZPUtcD8yTPrUyVpNLeE8SNSz3ZGHjO9hN1Pk+fQ9I8wOrARQXGMJWkLSXdSvxtvkSU8XavVP4taTVCPj6BEELcVfuI+yZ1HUHZvt32IpM9GeW6juRUCR1JaoUahddQ0rqIP5NHUbWgEWKJLI6oH3sBFxUhtZY0i6SDidHSIcTveH7bx6R1i5Ucq7+kYwjz1wOT8WueY64hRZT4DiVKd9UwgJLWxVMiTXI2I2cBOxVtfNoGXAEsm0QMNUfSfMDKtID8udWQNBUxt9fQ0qmkJST9iUhMSxKmrKva/ms1C7CTu8i/gBWAZW1fWduIM1DbBLWMpEclfSDpIkkDJA2X9ErHButI6ywDW8wA/bYlFAOHdTrIScAcRFvJs0ue/xQ4CPoNgS2nkt
6SdLqkgQAd55H0E0lvdNq1abD9ErEmYpuiY2llkmjhAuJCVw/2AkbbrrQXT6Zn1gPesf1gvU+UynibS7oZuAF4FVjU9m6276vymJK0L3AncA6wke3XaxZ0ZjJqmaC2A74NLEDIxfcofVHSNA/ApbvBF+8BO/JV7e4bwAfEt+gs4HvA++m1nwL/AR6G8S/BKcA8wBElu38NmIUwaWzmbrZZLFEbRgF7JjFDzSjqDr8PUfeOxJJmTi0s/ktUXM4iynhH236jF8edk1jqsA+whu3TnM1M60otE9Sptl+z/R5wDbF+pJRV+sG0P4L+/Qk7gJU6bdCfyDj9gY2A6YGnAROF4t8Cs8LAeWLB3LHADiW7TwR+YfvTZl74R9ijLJxWmGeqJIkXniPEDLXk28Brth+t8XH7PJK+BqxNWP7U4/iLSzqd+F4sC+xgexXbo3vroyhpU0II8RixtqnlvANbkVomqNI7k0+I/FLK3HPCp6Xd+4Z02mBWJl85PAj4CHg7HXB5YhXv9DEAux6YvWTzt1thvUoTGJ+2E2dQ+9FyFkfUjz2Ay2rp7J/KeJtJuokwZn4DWMz2LrZ73eZG0nQp6Z0KbGf7Z9k0uHE0UiTx+pswbel4+OUyd5yNWBT1OGEx8RH81fZMtkuTYCsNtc8Edm208WkbcimwchI19Jokf16DAuXP7YqkfsRNWU3Ke5IGS/oxYUF0ODEfNDStP6rJnJBCcfwQcflZxvbttThupnwamaDumgCfngKffwFcBdxb5o79iML1j4A3o8Xxo5LmkbRBnWKtK7afI8oFWxYdSyuTRAyjCVFDLdgTuDhLhevCWsDHlP9n3yWSFpV0GtFCfQWilc1Kti+oVcPD5Fh+OFGOPyytjfqgFsfOVEbDEpTtz5aD7c6GqQcTEqxNgGnL3P94whRrFRjQD34B3ETMRbUqWSxRG0YB30nihqpJd/hZHFE/RgCjqhEVSOonaRNJNxBr+t8BFre9k+27aylUSO3nbwOGAcvbvrhWx85UTuM76ob33uZAv5WB/Yjb1nIwTBRcib113eJrEMmq52WyX1evkXQPcHRvWhlIWh8Yabvwtijt1lE3dZx9hmhP8X5P25fsNxNxefg+Ud0/lVjgW/PW8JKUznU8IcA6JXUiyBRIQxfqShp2CZzxOYw/l7Ar/3YF+48D/RT+WafwGkqaaD2PLJaoBbUQS+xDHj3Vi92Aq8pNTpK+Ken3RBlvZWBXYEXb59UpOc1GzGf+H9F6/rc5OTUHjXaSWGQ7OHcQTHUi+FJiQW6ZfHIn/PZ4OFDSH9rEjWEUsEe9jU/7ABcBaySRQ8Wk9S3rEPNZmRqSRiY9rn1KZbyNJF1P3IS+Dyxpe0fbd9VrvVGax36ESIYrNZO7eqbBCcr2Gbbn/Nwe8Bh8d+NQj/d0pzKR2O7Ade0DgaWBmYAHW71LbTI+fRLYrOhYWpkkariY8qvFndkDuLyW8ufMl6xB/A3/q6sXJc0o6QBiyeMviTVSQ20fntrU1AVJAyX9jkicu9o+qBWWqfQ1imu3YZ9OTEReCYwn1Hlf8hlM+AK+SK8PS9tje6ztXYAjgb9L+llvJ8gLpoguse1Ih1iiou90usOvmfw58xW6FEdIWiQliBeAbxE3CcvbPrfeiULSssADxDrKpW3fUs/zZaqn8SKJLqPQ7HTqqHsfjN0MNnkDFprS8F7SEKKB3TTEXdDzDYu5RqS1UC8T5YWWi7+ZkPQAcKjtGyrYZy3COmvpZrGtaReRhKRZCFeHhW2/k24eNiA61S5P3BT80fYr3RymlvFMBRwEHEisWhndLL/zTNc0h+t3tHU/sfSpleLOdi1gOHBr17v55dSH6kfAvZIOAs5rpS+d7fGS/kJInDt752YqYxQhdig7QaXtq5I/Z3pkFwELe1QAACAASURBVOBvwGeSfgD8gFgLdSrRFLBhlmSShhKiJAjBxYuNOnemeopt+d4N6YLRY/nL9kTbJwHrAgcDF0uatQEh1p
JRwF5N3CakVRgNrJNEDz2S1FsbEsvyMjUklU6/Tyx1fIGYi/oOsJztsxuVnJL7+C7AfcB1hEovJ6cWoWkTVOICYKN0IekW248QK8tfBh6WtF69g6sVth8nVES1Nj7tUySRw+V0ctLvht2AqytZm5PpnqTG2wC4A1iQED8sbXs7RwPTho1UJc1MiC5+Bmxg+wTbExp1/kzvaeoElS4c1xDrIMrZfrztHxNqrj9LOrmjZ1QLkMUStWEUsHdPYoly5c+Z8pA0g6TvAU8Qi10nAkckc9VybTdrGc/ahHz8LUJ88VCjY8j0nqZOUIkzgH3SBaUsbN9EyNHnAe6TtHS9gqshlwCrJuFHpnruIRShw3vYbvX07x11jabNkbSwpJOBF4k5430Ide6ShClyo+OZVtKJwPnACNsHNHn7nUw3tEKC6riAfKuSnVJfqu2AE4CbJB1UqQS5kSTj0wupnfFpnySVkMrxOazaG66vk+Z11pd0LXAXsUxkGdvb2L6NMHC90fZbDY5rSWKuaQGirPiPRp4/U3ua9oLdQckFp2IrGwfnASsS/n83N/kI5QxqYHya4QJgwynNXaa5ic2YpOrKlIGk6SV9lyjjnUisURxq+6e2X0rbiAbbRqV5rx8BtxB9Tbex/U6jzp+pH02foBLnAZulC0vF2H6BKPncADwgacfahVY7ktDjDWKtSKZK0tzl1YQIoit2Af6eL2LlIWkhSb8hynjrAPsTI5Qz08i/lI6+ojc1KLZ5ib/rbYCVk0Iwj4rbhJZIUOlC8ndg514cY4LtkYSs+AhJf5E0uFYx1pB6dInti4wCRnSeu8ziiPJIZbx1JV1NzOt9TkjEt7Y9ppskMAI4sxFmq5K2Ax4kWnAMc/RZy1SCNAfSwUjnI12T/j04mScUTnM4SZRBUuWcTA1W/Cej2ROATYHdbY/pfYS1QdL0hFR+MdeoM2hfJCWiJ4B9XNIJVdIqRAnwG83qWF2kk4Sk6QjV7AHABGJR7V+6GCl1tW/Hd3dx26/VMcaZgN8BqwC72O5VE8Q+SXQLPpS4YTfRNbiDcYCIQcFI7PsaH2DQEiOoxBhgELBSbw9k+xPb3ydKFaMlnSCp3N6JdaUGxqcZuhVLdIgjmjI5FYWkBSWdBLxElJi/Byxle1Q5ySmxA/DPOienNYhu1B8Dy+bkVAXSfsT1dHNgAJMnJ9LPA9LrY9L2hdAyCSpdUKoSS3RzzL8RcvRvAPdIWrxWx+4lZa3lyfTIZHOXkmYEtgLOKTKoZiGV8daRdBXRin0isILtLW3fWkWlom7iCEnTSBpJ3LwdYHt/2x/X41ztgqT5JXkyh5pINicBg46Efrt0f4h+xKDgpKKSVKtdAM8BtkoXmprg8AHckigZjJF0QBMkhgeAD4gJ6UyVlMxddvwd7gTcbPvN4qIqHknTSdoXeIwo4f2NUOMdXK1hsaRliPZu19cu0i+PvShwN7AEIWe/ptbn6BNEWe8kIulUQkeSWqH2QXVP0RfiikgXlpuBmqrwkhz9LGBV4iJ2vaS5a3mOSuMhiyVqxRlMEkuMSD/3SdId9a8JNd6GwA+BJWz/qQajkRHAWbW0EkojvO8BtwOnA5v19ZuLXnIoUbrrkS+++tSAtH9DaakElahpma8U2/8lHAb+BTwkaet6nKdMRgPrSpqjwBjagTFETX03YBYaJH9uFtJFfi1JVxAjcxGtXbawfXMtJNlJdLQD8OfeHqvkmHMRI7vdgdUczU5bQ9FVZyT9VNKzkj6U9ISkLdPzU0k6UdI7kp6j1NtTmuM52GgY9JsBWA8oXWPxAvHFOAuYD1g7PX83sBowGPotBVsuL21eEscekp5LcTwvaef0/MKS/inpgxTLRVW/Wdst9SCS6guEv1Y9z7My8AxwNjBjQe/1z8DBRX/mrf4AfgI8BRxWdCxlxrsncHYvjzGIGNU8RqgZ9wOmr1O8uwPX1fB4WxLrAY8G+hf9+2i2B7AtMHe6Fm5PCEbmSr/jp4AhxM
3YrYRCb2rDwSvDhB+Bx4P/CZ4evDPY4OdjO+8K/gj8CfgV8Czg68ATwNfD+AFxrtmB6YD/AYukmOYi1JsQBr0/T/ENAFav9r223AjKIZY4kzobq9q+B1iWWP/xsKSKrJZqRJdreTIVczEhhLmk6EDqjaShkk4g1HibEr3SFrd9ukMhWg/2oQalU4Xh7JmES8VWto+w/Xmvo2szbF9i+zVHq6GLiBvplQhrt5Ntv+ywehvZsc+/YZX7od8xRP+TNYkvR2eOJDLPQFIrifRInSanXRTGpqcgRDVLSBpo+3VHVwaIa+ZQYG6HgXfVfpctl6ASZwPbpXUXdcP2R7b3If7IL5P0S0n963nOTtwNfEqYb2aqZ23ijnytogOpB6mMN1zS5cTC1akJV4XNbN/kdFtbp3MvTnjfXdfL46xKyMdFCCHurEF4bYmk3SQ9LGmspLGEeGQ2YlRV6hz/Zd+rl2GOmYnk08HQLo5d6gP3InFHN7jk8QTMCczlmLPcnhi1vS7pOknfTLseQvwe75X0uKSq/UVbMkHZfpWYON2+Qee7CliGGFHdKWmRBp23ah/CzGSMAE6jzdqZSBokaW+ircTpxPzaUNs/tv1sg8IYQZQju5hX7xlJ/SUdRfj6HWz7O7Y/rGmEbYSiM/AoohnkrLYHA/8mEsLrTJ5j5uv4zzzw9vtEfa6Dl7o6fsn/hxArtseWPMbDaNvHAdj+h+31iPLeUykubL9he4TtuYF9gdMkLVzN+23JBJUox7G6Zth+A9iEmBe6Q9L+DSq9dTRtbLUuwU2BpKWItivHAbNIWr7gkHqNpPkkHUfc5G4BHEQ4j5xWxzJeV3EMIOzHqmqrIenrRLeClYlR0+U1DK9dmY6YL3obvnQdWSK9djFwgKR509q/n6bn+78M45YDfgF8RnzoPWn1d0nb/IOwFBkH434HH6fjzylp8+Q88inwEVHyQ9K2Co9EgPdTvFUtjG/lBHU9ME+6ADUEB38klH7fAa5Rme3Fe3HO94jvyZSMTzPdMwL4c7rDP4sWHUWlMt6aki4lSmEDCHXbJrZvcDHOGFsBD7nCtVPpvewD3En0bdrQ2darLGw/Qaxlugt4k+i79a/08iginzxClHpvTM8/ux8MuQA+v4dQTxxFzxeUIcBVwLGEKmI+GHgILEzkjX7Aj4HXgPeIaYj9064rEsYHHxGmzT90tT6JRStSeqlmOQr4XUHnngb4JTGs3qzO51qTUGKp6M+8lR6Eku1dYL708zzpj6kuarYaxv2lio+Yr96LSEpPExZEMxQdY4ptDLBthfvMkS5aDxGjvsLfR7s9gBUIF5X3idLvErYxXG6Y4KTcq/AxwXBZo99LK4+gIMptO6V1GA3F9me2DyMkn6dI+lMa7taD24k7ltXqdPx2ZRvgHqdeRW7w3GUvGSTpWKKMtzUhlV/U9h/cBHM0kr4BLErcZJe7z8ZEon2CEHE8Uafw+hzJCmpHSXcClwKPAgvZ3s/2v9NmI4nmktUwnhJVYKNo6QRl+0WiFcA2BcZwB+HnNy2xuLfXZrZdnCOLJaqjK/lz0zp0pNLX6kSpZAtivmF12xs7JqSbyeB2BHCu7c962jCJOU4D/gDs4Ghw2ON+mZ5Jc0GHA88Tv5MTiMR0omN6YBLhSn4gUK75bwefAAdi31+DkCuiZdptTIm0ivrHttdogli2If4I/wAc6yqVTVM49uzEeocFHA35Mt0gaTEmqdo+L3l+KmKh98a2Hy0ovMlIYoMdgR8A0xPGrdjuwcuzGBTO/y8Ba9j+Tw/brgD8hXhP37f9QQNCbHuS2OcAojP0JcRUx2Nl7txhGDuA7gcpE4mR04HYp/cq4Cpp6RFU4lpg4XRBKhTblwLLESKK2yUtVMNjv00IQ6pu2tjH2JuYx5lsoafDK64pxBJJDfUrooy3HbH6/puE32QzL1DdHHiiu+QkaWpJPyfsio6wvWtOTr0jSfK3l/Qv4HLgcWBh2/uUnZyAlGyGEdL+8U
T/p1LGpeevBIYVlZygDUZQAKlWP8D2j4uOBUDhhv4D4DBC6vln1+CDlrQO8Ftq0LSxnUkjkpeJeY6vqIckzUdM0g9x+b2OahWbiLnEA4D1iWUEv7f9dMk2hTUsLAdJNxLf6b9O4fUFCXXeeKIh6CuNjK/dSNWTfYjS738JB/qra1KhiWPvDixFrMUdS8xfnUvcFBdKuySoBYm5qCG2q50ErDmSliAuQM8DIxztH3pzvH7Af4CdHVZMmS6QtCOwl2MR4ZS2+Rtwoe3zGhTTAEKccQDQ0RH2nK5GFc2coLr7W0vJd3fg18SE+slNNm/WUkhalvi+bEEIH37XLGXpRtEOJT7SXfLDhMlk05DUMx2ms49I2rCXx2uID2EbUI43XEPEEpLmkXQMUcbbETiCaDd/SouWvPYGzu8iOc1KzIUcCKxj+zc5OVVOKuNtK+l2QiH5FFHGG9HXkhO0SYJKNNRZolxsf2r7EGJh9umSft9LWfw5wNaqYdPGdiK5EyxGz/Ln64AF6zF3mdR4q0r6K2FDMzMwzPa3bV/XqhduhQ/lnnTqmitpfWJx6IvAin3xQtpbJM0m6VDgOcLG6GRgQdvH23632OiKo50S1JXA4ukC1XTYvpWQo88CPCBpuSqP8wZwCzVu2thG7E0Z8ucknjg7bV8TJE0raVdCsXYBUQqb3/b3bT9Vq/MUyCbAs7afBJA0UNIphOhkd9sHNlOJvRWQtIyks4gqy8LApraH2b6slirgVqVtElS6IJ1HDS84tcb2WNs7AccQXXt/mmTPldKUo8WikTQNsAfle8OdBeya5od6c965FIanLxL+mkcRZbyTW7SMNyW+7EisaPF+P/A1QrRzc5GBtRJJ4biNpNsIG7P/Al93GOU+XHB4TUXbJKjEKGD3dKFqWmyPJuxINgBulTR/hYe4EZit2lFYG7MZPcifS+nt3KWklSX9hXBGmB1Y2/b6tq91DVufNwNJ+bgycLmkg4nv4Ehi4e173e6cAb4s4/2UKOP9kBDKLGh7ZG8FVO1KWyWodGF6irhQNTXJfmcdwpfsPkm7JhVUOfs2zVqeJmMfOs2PlEFFYolUxttF0r1E59AHiMXT321z657vEN/Va4ledyvaviAvd+gZSUsrGjE+QzTO3Nz2Go7Gg8283q1wmkNmLs1B11r8cyrV4kvaGdjN9gY1j7NOpHLJBcTCu/3LuSOVNA/RznuIo3lYn0bSAsTcT0VLDdJo+2XCUuiZbrabi+htsy/xezqVaHNe85FSs8nMJU1NOGdDSMh/3W4jxFqTPrPNCJn4wkQ/slFugrVFrUSxIyhpRaIL6ItE3X4XYiJ2l/TzS0iXI61YwVEvA5ZPF6yWINWdVySc0R9JC3J72udVoq1LKxifNoK9gQsqnaRPc5fnMoW5S0krSbqAKOPNCaxre13bV/eFi7Sir9BNhDP8eraP6wvvu1okzSLpEOBZoh3FacQI+9icnCqnuAQVflBjCNuUAURbgVIGpuc3B8ak7XskXaAuIEoSLYPtcbb/j2itcK6k35QxeZ/FEnx5t/oV+XMFnAns0TF3qXCG3knS3cBFhOvEgrb3t/14TYJuASStRczRzUf4XT5YcEhNi6QlJZ1BJKbFgK1sr2774lzGq56qEpSiz/zwqs86yaxwUBkx9EvbnVRukiIuVHulC1dLYftGQo4+HzE31V1Dxr8DQyQt2ZDgmpeNgeeqnQNKc5dPEoq+Iwgz2e8QIoCFbZ/kPmTQm+bZfk3c6P2cWMd1frFRNR+SppK0haRbCJ/Ml4BFbO9h+4GCw2sLqkpQthe3Paan7SS9IGndTk+uyKTkVAkdSWqFMuJ7nLAXekhS08rOp0RamLctcCJws6QfJ5ujztt9QfTE6uujqGrEEV+i+E5ORTR3mwdY3/Y6tq/qa+WsZM91LzFvsgwwP3CRG9hKvtmRNLOkg4jR0iGE0GYB27+0/Vax0bUXRZT4DiVKd90yhSb2A9L+5XAGMHcFcTUVDs
4FViJaa98oad4uNj2LaNrYuUTaJ0jy51UIm51K9uto8HZX2vdvwAfAcZ7U4K3PIKmfpB8CtxICkK2IbsR704vk305IWkLSnwiZ+FJEN+HVbF/Y08LwTHVUW+J7QdK6ko6UdLGk8yR9mEp/K6RtzifKVNdI+kjSIUhz3AkbrQb9BhN1rDElxx1O1BO+RQyXnkvPHZ6emwH6rQtbHCEtUhLLKpLulDRW0iMlpcfFCNeG36fz/76a91o0tp8nrPFvBh6UtH2n118E7qPApo0FsxfwV5fpSq7JG7ztQzR4W9j2SKKM1XIj7t6SFKH/AHYAVrV9VpKPrwe825fLVamMt7mkm4EbgFeJzsa7ORoAZuqJq+t5/wKwLnAkYam/EVEiGQnc3Xm7jp/vhWNmAV8HngC+ATwL+K3U934YeAj43+DPwZ+l5xYEPw3+BLwGTPg23JKOPw9xl7cRkWzXSz/Pnl5/FbiqmvfYjA9ice/TxIV0ppLntwJuKzq+Aj6PqQiJ+NJlbLs8odZ7nxgRLNXFNosBrwFTF/y+9iR6WTXiXNsQEvLDO79vwkF7v6J/zwX9DmYmjG+fB+4CdgKmKTquvvaoRYnvDtt/c9TqzycGRl1yLmy8EZNnkxWI2koHewCLA1MD/dNzexKr2wYCO0C/t6BjBLUL8Ld0/okOgcH96RQQF5vVq7QTajocLZeXAz4k5OhrppeuAb4uadHCgiuGbwOv2X6kqxc15QZvXTpDO0QWzxGii7ZG0oySzgWOBTazfYxLvN8kzUksJB9dVIxFIGkxSX8kvgfLANvbXtX2aOcyXsOpRYJ6o+T/nwADpqSeexNmuYRYidvxuINY/NPBkC72+1rJ/wcBn02awxoKbJvKe2MljSW62c6VXv8Y+B9hKdQW2P7Y9neB7wEXSjoOEOFy3tfKUyPoYn5E0uyKbq7PE03eTgIWsn2Ce3aGHkUD2nAUiaTVCffx8cBy7rq32B7A5bb/18jYiiCV8TZVNGK8mbimLeroAnxvweH1aeotkpjMpmJ2eG9Xwiai4/Ex0XK2g3K8fiZAx53ey0RvmsElj+lsH1dy/n/Shio329cRd3iLAncT03m7SZq2yLgahaS5gTWBC0ueW07S2URTxwWAjW0Pt325y3eGvgRYJYkv2ookDPkV8R5/aHtfd6HOS4rRLpN/OyFpsKQfEd+Xw4mbvKG2j3J0DcgUTL0T1JvAgh0/7AZ/u4aYjZ1A3L6NASrpBz0BPvskRmoQ6zQ2lbRBugsaIGl4idrtTeBtYHiyqmkrHJLWLYA/Ep/FuzRZ08Y6sidxof1U0naS7iB6QD1NOEPvPaXSX3c4xBZ/JcQXbYOkbwJ3EiX4ZWxf3c3mw4m/sbbs2ixpUUmnESPsFYCdbK9k+y+5jNdc1DtBjQQOS+W3g1aBU6+AT48lrJ+HEMZelXRvE+j15Atm+2XCaeJnRCJ6GTiYSe/rlPT6IMICqe1wMApYjRAN/L4dk3Ep6Q5/H2Ik/TyTGrwt4LDi6a0zdMdC75afu1TwXeB2wjFjU9tv9rDbCMI3rgmMOmtDktFvIukGQkr/NrC47Z2nUOLMNAGNN4sN773NqSI5GiYKrsTeurJTagXgYmKCvCW7mZaDpOmAt4i7331tX15wSDVHYax7HKEiPQ/4ne2H6nCee4CjbP+tx41rf+6amMVK+hqxkHsOYGfbT5exz2xEf6IF3AbuGZJmIkbb3ydmFU4BLrb9aaGBZcqiiIW6I4nqXsWMAx0fdv+V8gCxCLNHE9ZWxuFq/kfCAukESWdJmqHgsHqNJm/wdi2xvOBQ23vVIzklWlosIWlzwkPwQWJtU4/JKbEbcHWrJydJ30xrH58n+ljtSrQIOT8np9ah8QkqFrcdyKR5pHL55Ab4009hpKSjJfXveZeOU9pU2PenhRkFrE+4o08EHpa0WrEhVYe6bvC2MjAvqbNrHbkQGJbEGC2DpOkljQJ+C2xj+zCXaVYqScTfSL0/27qQyngbSb
qemN5+D1jC9o6272qnkmVfoRg3c/t0JiWpnkpuE9N2B25h7w8sS1x8/yXpGxWcdTSwrqL3VNuS7pSfJrq7jiA+58srTepFoskbvC0CbOHU4A3YGbjCdW6lntRtFxPloZZA0irEqGlqQgjxrwoPsTrx91bpfoWS1nQdQHzvjyH+1ue3fYTt14qNLtMbimu3EUlqGHAlUfIb12mLcen5K4FhaXtsv04sxD2XSFL7pDu/Hk7nD4AriMaI7c4ZJGm97SupPqk3jFTG20rSGGLt9vOEM/SeTm0e0u95BI27wx8FfEddGPU2E+mz+wWhYjw0fWbVrF8aAZzZKiMNSYtI+h3hWLMa8be9gu3zXGFfsExz0iwddWen646659JNk6/knHABYWm0t3twEpa0KpHYFmmVP8JqUPSReoX4Y30hPSdi0epRhOVhU6i0JM1CLDD+HhHzqcQC0a+UpRT9iU4lbIrqHnv6zB4AfpJcShpCJSIJSQsTfwP/A/aodsSgaEz4PCEk6q0Ksm6km4UNiE61yxE3EafbrmS1SqZFaI47Q/tt7BOxd8PeLP17YnfJKXbzk8CqwL+JuZZNejjT3cBnxMitbUl3j3+hpGljkqOfRixu3Re4ushyp6Sl0lzJs4S71da2v2X7om7mTBoqf07naUqxRJKP7034xI0Gvt3LctYuwN+bNTmlMt4PgKeAXxGNJIemObacnNqU5khQvcD2Z7Z/RrQ+/52k05Pcuqtt+5pY4itNG0uS+mOUl9RrRlpMvaWkWwml4UvAN23v7vAZ7G7f2YjS7gUNCLWU0cB6Cm+6pkBRcbiCkE4Pt31qb5ZPNLM4QtLXJZ1ClPHWIBZQL2/7nFzGa39aPkF1YPt2wvpnINGWYsUpbHoBsJGkWRsWXAE4ehq9yCTj3NLXyk7qtUDSLJIOJkZLBwN/ItbZHFPGotEOdgWusf1eveLsijR3eTlNMncpaUPCR+9pYGXXpgX9yoS/5ZgaHKvXJDXeBpKuIwQbHxOO9dvZvqMZStOZxtA2CQriYmJ7d8JX61pJh3UxgniPcP/etYgYG0y3o8UKknpVaFKDt2eBJamywVsB4ojOjAJGlCPGqReSBkn6A9H1dyfbP6nhep59aII5SUkzSPoe8ARwPNHuY6jtnzlcYzJ9jLZKUB3Yvpjo/zMc+KekBTttMgooS/3X4lwCrKauO/EC5SX1SlB9Grx9i/ARvqPauHrJ3YSidHgRJ5e0PLHgdjAxkhhTw2PPSPg3nlurY1YRw8KSTibKeMOJm5FlbZ9tu7O6N9OHaMsEBZAmTtcn7sLukbRnSUK6nXjvLbmAtVySs0RZxqcpqS9HCEi6SurdImlmSQcSNjmHEq3o57d9tHvvDF2oN1yJWKKhrvgp2R9KzNcd5fCNG1vj0+wE3FxBqbUmJJHH+pKuJYQe44iktK3t24sezWWag+aQmdcZSUsSqrZngH1sv5supkul0UPbovCuu5qY85lQxvb9CAnvz4FDgHO6u1hIWgz4AdEu/DrCG69m5pvNIn9OcvjniL5SPfWU6u259iTmDr9GGOLubvulOp3rQeCntm+ox/G7ON/0hJ3SDwhF7anA6DxSKoNQ3Xa1HOecnhTPLYuboK1vIx7EJPCJxFqbDQhD9bHA4KJja8B7vw/YsMJ9liS+/JcBs3Z6bSpgU+BGot/kkcBcdYr9+8CFRX+GKZbzgR/V+RwiRmvjgYOAfnU81/JE8q/bOUrOtRDwG6IlzGXESF1F/05b4gErGi43jDN8YnDJ45P0/OUOr8Hi463l96boABr+hmFtQt58KjFH872iY2rAex5B2ANVut+0REeUjqQ+GPgRIXq4l7AdmraOcSslybWL/gxTPGsSE/h1ubACsxD2Sq8AVzbg/fwJ+Hmdf3/rEiP4twnhw9Cif48t9YD9DB8bJnRKTJ0fE9J2+xUecw0ffaLE15lUNvojsArwKbEWp20/CIWj+UvAYg6rqEr334Mwau1PWE/91g3ooS
NpZaI0+w03QZuUNIf5JOFaUlPBhqR1gbOJOdMngNXcy3YbPZxveqJ/2uKusV9dWrKwK1HGm0i0uBjtaAaZKRdpP+Akop9duXwCHEiyhmt12lYk0R2OVgI7AocRpYdT1QbN6aaE7Q+J0eIe5e6jyRu8HQecBlxPlP4a1XW0QxxReHKC+oglFF2gf0u0G9/L9o+Iead6swNwWy2Tk6QFJJ1IrL9bn7CvWsr2mTk5VUgs+TgJGHQhsVBtOqKx18rEH6MJy/oFgRmBuYEfwaDP4SSiB17rU/QQrugHcfF9HfgnbVx+IMxin6OH+QZgJuD/CDXe/cSd8LTpNRFlvbeAnwBT1THeGYH3gTmL/uw6xTUbMXc5cw2OtTRh03UJMEvJ83sCZ9f5fdwDbFyD44jos3Yl8A5wAqHeLPx31dKPmFOacCJ4DvAl4P+BJ4IfBO8EHg/+L/j9VOZ7F7wW+ESYaLis8PdQg0efHEF14mRCQHETcJ+kndt0fdT9hKHo2l29qDIavDn4C5HsNgJukTS0TvHuCNziBsufe8KhJLyeSNRVkUanBxHfuROA7dxAhwxJSxM33Nf34hjTSdqXsMw6hXCgH2r7ECeD4kyVhFpvww+g3xHEaGkbYAbibmBZou49LVH+GZx2M1ESezY22yiZcLc0fT5BOdbo3EKMCjYg5NWj0zxV22D7Kz6EmrzB2z+JEcuS7qHBm+0XiUR3HfVL6k3pDZeoeqG3pCFEYtoCWMnRGqLR858jgLNcxrKDzkiaX9KviTLehsSShCVtn+FYd5fpPbsDvouYIN+8h41HE+WG2QgPrH3jadMk9ly9oc8nqMQoYISjffjy32PNFgAAIABJREFUhOLoEUldjjZamL8A66e5gs4N3obaPtz2q+UcyPYE2ydQh6SenBNmJWTszcitxMT1SpXsJGkHon3HjcAw28/XIbaeYhhEjE7/XME+krSWpCuIkbiI5LqF7VsKSLDtzlLAwHeIpFNq67IaMWIaCNyWntuJKI38B9gPSK7GA9NxWpqcoIIbgdkkLWd7nO0DiLvM8yWdKGnaguOrFV8DXgMep0YN3uqU1Dsa5zWFOKIzKa4zKVMsIWmwpAuI9WIb2h5ZzeilRmwL3O0yFv4m/78RhNT/90RJcKjtg2w/V+c4+zKDIe7Q3mFyxcydxATorHy1FfnXib413+10nFYmJyhiNEBY84woee4fxCT2AkQZa8mCwusVqYy3oaS/ExZP9xKikB1t31mLu99aJvUkf96OkFw3M+cAWycvuykiaThRefkAWM72A/UPrVtGEBWDKSJpqKTjiTLeJoRoZgnbf8plvIYwFqInzrREm+Ry+YJYpFh6nFYmJ6hJnA1sr5K2E2lCfBtCzXmLpB+pydt/d6CuG7zNRyjExhOLTmtKp6R+b5VJfTvg9nJLjUVRMne5Y1evS5o2XeRHA/vb/p4LllonW6oFibnDzq9J0jBJlxHGtFMT7Tw2t31zLuM1lEeBcYOBXxAjokuBD4lR08NE/xGIYXxHG/EngJGEpJLwNny0UQHXjaJlhM30IFa87zWF1xYketPcBMxbdKzdvIevE6qq9whXgtXp5HwA/BD4Sx1jEJEI3ybuvsu20iGcwzcp+nMsM9YNgPu7eH5x4CFCej17Fceti8ycuNH6VafnBhKdlx8hFiHvD0xf9Gfbpx8wR7IvssEXgFcEDwTPBl4J/Cfwp+A9kgx9EHgo+CDwuNhvnKv47jXbo/AAmulB+Mvd1c3rUxOLe98kpMGFx5zi6pcultcRN1THAkO62X4WUim7znFVlNSJSd1XgKmL/kwr+NxfIEp3HT8fQEwd7N35xqCC49Y8QRFLKd4GFkw/z0fccL9F9Edbr9p486MOj7QOqgd7oyk9JuR1UO3J34EhUypN2f7C9i+JRPZLSedJmqmhEZagKhu8OdbcXEudmzY6JtKHEZ1aH5C0bQ+7dMifG+Gk0GscYomziGaGcxPfn52AVRzuCc1UFtuKGN
XNK+nS9P8BhKXSprZvbLJ4+wxpnnhRSVtIOkTSP9aIUnm1Le3HEzcfrU/RGbLZHsDRwKllbDcd4ef3PLBGg2NcmFhg/C7hQrAGFd79Eonj8Ur360XMKxKy9nOBmbp4fVB6P0OL/g5U+L7mIaYH3iKmDHo9+qPGIyiijPdk+q4+RVgQzVD0Z5cfX/5+diDWLX1A6BwM3F5iFOsKHm1lGJtHUF/lLGBnSQO728j2x7b3J0o6F0saKWmaegVVhwZvtxEly1VrHWtXOLrpLkeYWT4saY1Om2wD3ONYBNwSJAXfL4nS3hm2j3ITjf4kDZF0LFE2XZBoXbKY7T84/BkzzcHlhA3ZjEQrm/HAToTh64HE30xPSy4m0mZGsZBVfF8hXSDvB06SdL+kbldj274GWAZYArhb0qK1jEfS9JK+S5Txfg1cAcxn+1D3ooldSmijKHGWqDclSf0HRFI/tiSp9yh/biYkfYsQVH1BrCdbp9iIgnQjs7qki4n4BhEXwFNsX+cmXVvWV0luJLsAMwOfp8ef3FGij2QzjBDcjCduTEsZl56/EhjWTskJyCW+0gdRkz+N+IV/QXxZDihzXxEX+7eJEkqvSmc0oMEbBTZtJBa8X0M4K2z8/+2deZicZZX2fycrTcCELRAEA4ZN2US2AQIBEdkEJiwCDjuJsgmKBCc6iiwOfgq44CAfEtEAQ8YAIgwB5wOJQwgS0AnBURARArJNhjWQBUzu74/7KbpS6e7auyrdz++6nivp6nd56q3q97znPPc5BycQD271d6CCeQ/G1TdeAg5Lrw3CXsq2DTpH1SG+9N09CUvEn8De0prAECzq2aLV1y6PlT6zdfHDwzzcJeAsHC5eu8t9YD3BeYKpgtvTv+f1BbVet9eo1RNop4Fd7GeTYRLwFnBMlcfYAifD3kWVXWZpQYM3nB/VkqaNRUZ9Ea5S0NYqsvTZPowLo25Q8ruK1i4rPE/FBgqvgV2SjNAM4ACKZP24csR9rb52eaz0ue0PPI+jIsXdAuqukt+XRg7xFSHpTbxO8kfsQQ2mMw+u0mP8CdgD38j+KyLGl9snVYY+Dbde+A42UKMlfUnNX5P5EVah9XoFd/mvcir2WEcBd0bEBr09j3KksNlpWDJ/HW5T8VLJZhWtXTZwPntExDRcTfx9WKhzkKS7tWIYb5UKnfZ1IqIjIq7En8nxkiZpxW4Br7V2hu1FNlAlyNUj/g57QUNwwmu1x3hX0tewtPeyiLg2lfBZgWiPBm+/wqGgVjU4G4/DfDvhtb+5EfH3LZrLSkREIRQ5ERuBq5JhXYH0IDEHiz2aNZfV0proI7jU0mzce+ns9GBUuv0HcXeGW5s1p0zlRMQO+Lu+LrC9pF+1eEptTzZQXZCMwzhgyi/hVSImEXE9EXekfydV0mtF0mwsoAh8490tPf3uGxG3YS9rOS7YerikmV3d/JqJOguf9ppYooRC19xio355d0a9N4mIQ7DQ4FFgN0mPl9nlGhrYbbdoHhtGxMX4QeZY4KvAlpK+n7z+7jgVuEE1FgLONIaIGBgRXwJ+CVwit7PJnlIltDrG2LYDdk7Z3IsFi0pyDRal12+Vm/qttD/OTTqy6OdjcVX8/8GhvM8Aw1r+Pj23UbgXVK/mxuCyTC8DQ0peXxOHzP6Mk157+3oMA67GeUNjq9hvMC7E+6E6z38yDiXuBtyEvfgfAFtVOZcXsKy85d+x/jqA0bjX2kxWsRy/dhjZg+oKrzfMxL3CVsOJjsV0pNcPA2am7Yt2jw9jkcSPImLzcIO3K/EaxjO41uN9apPK0JJexD2Ouix82kQmAFMlvVMyn4WSTgXOB26LiK9HxODemFBE7EJnlYXtJc2qdF9J7+LQ24Q6zj8UG6ZPAjfg1uybSjpL5T24Yg4G/iLpD7XOJVM7KVJyHI6S3Ansq1Uox69taLWF7I0BbIlDNQspJxuvM3sbS7dfxKG7d7HXdBmdNdCE+wItwOGgtlCu4e6oD/fi+YZgqf
aWZbYbhRV+vwE2b+J8BuHQ2cvAUXUcZ0z6bIdWud8o4MJ0TR7DPcoG1jGPO4ETW/296o8D5zRNw5Vadmj1fFbl0V88qPOxx7KmpO93u1XEzsDlOLmxGlYHLp/v5M1HcI5P4Jvec1q5wdsNeI3rDOwhjKzyfM3gP4CRaSG3NzgUeFzSEz1tJHt3B+JrNjsiGq44jIgxuLLGXrjw6/RajyXpKbxmVVa9mc69a0TciG9m6wL74Py3v6rGpoYR8QEs9Kn5fWRqI9yw81H8oLOT3NAzUyP9xUCNxjeAckzGoZ1aWA1XER+O13MWYqn6lhGxbunGcuhlVyxpnxsRB9d43oagLpo2NpmK5c8yP2BFo15WpFKOFIY5BXtn/wbsr8b0obqGHkQnqVfUcRExB/eLegR72GdK+mMDzn8KcJNa3H+qP5E+08tw2sQESedIKq36kKmWVrtwzR5YRr0M59q8hXsh/RcOvT0HfF0SgpF/8Tb6MWgj0AjQD0FzQNuChoPOLArtXQfaHfR50NqgyfCuHOI7BRue17BnMrpoPgI2K5njXnht6ipg9RZeq43wgnxTxRu4oeH/AqvVsO8Q4JtYAHBQHXMoZPE/irvFNvL9DcVimNLPeQMc3n0Rh/AOoYswHnUUi8W13J7D62ct+R71t4GrQMzDFV+a2sKmv40+70FJ+hhudX6WpDXwDekEYAReSD495d2cKBsPHgKexI/Un8ftaO/BLtjPsCSnwEO4CufLwD/BuxfCt4AvY7n0ejh0dFOZOf4nLq+/JvC7iGhJTpKkv2Ihx6eafKqa5c+S3pH0j7gC9FURcVVEVBWSjYgD8PfgKWAXSb+vdh5l5rgUV22fkM63S0TcgB9a1scL5vtJukM1hvF64ADgBUmPNvi4mRJSm4wv4IfgK7Bq95UWT6tv0WoL2RsDK/ImdPO77wLfEVz/tA2U/lrkJa0Nmlb08+Gg7xR5UBuXCCZ2dPmSU4uOPwCX8hmdfl7JgyqZz9HY3n2FOhbJ67hWhwKzm3j8Qfgabd2AYw0HrsctJHaqYPsOrKZ8FtinyddxG1zn8CEsV/8iFZaxoT4P6rbi718eTft8N8LPrQ+QBFB5NH70eQ+qlLQofV9ELIiIN4DTcLhnRGGb9Yu27+ji57eKft645PgLXHbmexHxekS8jkNmgWumlUXSvwE7Ah8Dfh0Rm1b2zhrGDGB0RGzTpOMfDDwjqZI1wR6R9Iak43HYbEZEfCUiBna1bUR8FGfxr4PDX/fVe/5uzrN+RHwNh3aF5fubSbpcTU7ODDdN3As7/5kmERGfwt+l+4BxWlEAlWkg/c5A4UXp23FL9OE4ITPw027VlMrJ1vHa1mcljSgaHXJViYqQQ2374TWSORFxYm/VypP7Gf2Y5oklGl4bTtI0ujHqRVn8d+Ms/k83w1BExM4RMRV7cxvi0lVnYZlxo8N43XEyMF3SW2W3zFRNRAxPn/HFwCclfUNt1P+rL9IfDdSawKuSlqSkzE+n1+ctr73FcoHFh3jxe3JEbA3vfanLtTpfCUnLJV2B+wydh/snrVPn/CqlKYVPI2JjnIT6s0YeF0Dun1Nq1DfBT7kH4hDgvzbynBExJCKOjYgHsaR7HjBG0mnyutYtwI5pHk0lIgbgNa9rmn2u/ki4weZcnGT/UbkBZ6bJ9EcDdQZwUUQsBL5G583yp7GyQ1QtcSFMwm0ypkXEm7is0YG1HlDSPNwu/Vng0Yj4RJ1zrOScz2Dp8xENPvQpwDQ1Sf5cYtQvwX2R7sWihJqbO5aSwnhfxWtLE/HnPUbSZZLeKy4si0BuwKKQZvNx4DVJv+2Fc/Ub0kPIpThsepak09UmFWD6AyH1am3S9ibiVly+qBbDvRy4DanRN/X3iIh9cSmdW4DJamKeRUQcgatujGvQ8QbiG/ohaqLCLCLWAn4IbIeN7MeAUyT9RwOOvSNwNhaSTAeulPRYmX22pjPVoKJwUEScDOwl6eQq5nYzcI/6WkfVFhLujn
0jbkY5QVJVrXcy9dMfPaieuJTaw3xL0v5NQ9K9WI4+Cng4IrZv4ulux0nGWzXoePsDLzXZOO2Lw2wvATtKOgG3Y58SEd+tJWQZEYMj4uiIeACHD/8bix4+U844ASQxyHzgoGrPXcUc18deY0NDmP2VlMB9Jk4RuRp3Ts7GqQVkA1WM48pfxLLwalgEfBHpkcZPakVSCOkYnKx6T0RMSusPjT5P3YVPS5hIk9ZHwn2SLse5R6dK+nzBu6zVqEfEehHxFez1nY5LYI2R9C1Vn+vSY2WJBnAicKt6br2RqYBww8wZOFdyD0nXKIeZWkY2UKU4RFIwUsvLbL2cTuPUa6EVmRvw2tQhwL2p/lqjuRY4IVXYrpmIGAXsjQtoNpSI2BY3CtwEy8dXCuVVY9QjYoeIuA74UzrmQZL2lnRrHYqt6cDuEbFRjft3S1J35q65DSDc/Xou/j6NVRdNIDO9SzZQXWFjMw4nPS4BStd6FqfXbwPG9aZxKiaJGfbBjdAeiYiGtsuQ9GdcWbveDrcNlz+nLP5zqTCLvyejnsJ4R0XE/cAvsFR8M0kTk0ilLtKi+k1YJNJo9sbfxYeacOx+QUSsERHX4q4D4yVdkCIImRaTRRLlcFHSE/Gi+wicLzUP+CnSglZOrZiUiHojrjN4hqSa8rq6OO7RwERJH69x/wG48eCn1KAQaPJEfooL9B5fbaJkEmxMwp7yPcBY4C/A94FfNCO3JSI+gtf1Ni2XF1WNSCIibsKVP65szExXYdwVoKu/1Z9097caEbvhaiQzgS9IWtg7k81URCPLUuTR2oHbflyJF+X3btAxC4VPx9S4/37YaDak7xWuE1goBTWoxmN8BCcjv4lvYjOAEb3w+TwMHFjBdhWVOsIVUF6nwhJKfXbU0P0adxwu9N8a3/L3kEeXI4f4+hCSFkn6HC7fdGNEfKve9SO58On11C6WmAjUvdBcbxZ/RAyKiCMj4j+BO3A94A/iqg9P4RyzfeqZYwU0WixxAnC7mlxCqa2poft1RGwOzMLtbnaQ9PPem3CmKlptIfNozsCV1H+OF33rKswKbIWfNAdXud9I3HJkeJ3n3xOr6X5Ila1AsJfxjzjR+X7gqK7eB06mfh5Xo6+qG24Vc1kzXY9RZbYr60HhpPI/Anu2+rvWslFD9+t3YOnn3KvtLNqkm3Ue3Y/sQfVR5Jj74TjkNzMizqlVji7pcaxqO6TKXU8Efi7pjVrOW08Wf0Rsnxa+nwS2wLkse0qari4WwCXdhUN/mwMPNaNYrry+MR0boHrZI/07qwHHWvVI3a+nweq7AsPw09CuuKmagG/jkvJr4gZk3wYGw5DvwiDBbyTlBfg2JxuoPozMFNz++xjg7lTxuhauoYoCsvXKn1MW/2+ArYGPSLqzgn0GRcThETETuBMLH7aQdIoqaL1dYtTvq8eo98CPgAkNOO5ngB/145vs5Mug4xysdnkJL0xejftfvION1FTsst4N/ADnOQxw08vJrZh0pkpa7cLl0TsD92H6Kv47PrKG/TtwF9xNKtx+b1yHsKowCg5dnQkswAau7P7A2sD5WBwyCwspqgpHdnHMMcCDuEzR+xv4OQQWjXy8h216DPEBa2FxxLqt/l61ZMDI12Dx6qCbqwjvfQ50VufPiwXrtfy95NHjyB5UP0HS3yRdjOvI/XNEXBcR76ti/8W4lM65EfHNiOjyCTQivpd+dzZViiNSQm8hi393ST16CBGxbURcg0UOHwYOlzRW0s9UZx6LpKfw2tcs3OX4yHqOV3RcYS+qHrHEccBdkv63EXNaBTlxNsRSrHyoBOEFyK1XfOnEhs8s01CygepnSHoI+CiOgsyNiLGV7BcRB+NmeJ/D7T+62+8ALN8dDxwWEXt0s13p8cdjz6KQxf9kN9sNjIi/j4hf4cjNs8CWkk5Sgyt5J6N+ETbql0bET6ox6j1wI/CJcN5OVeTKEQBs9yoMXReHBQrsjpOfOnARvWK+jsu+FC3+deB8qU
wbkw1UP0TSW5I+C3wemB4Rl0TE4DK7XYrXnAEGYsPQFS/jHBNwNY7zejpoRKyZxAzfpocs/ohYKyLOw97S+XhNbFNJl6jJhTyTUd8BWEoVRr2H472BFZa1PMHvim+uM+uZw6pARIyOiPGpGG4xI9bB8ebiPIPZOO65DivWKPsBXou6Eyf1FR+nwVPONJhsoPoxkm7HyrWPALMjYsseNt8LS9YLxuOv3Wz3Qvp3CfDvWJzRJSmLvyBe2EHSg11ss01E/F8seNgOOErS7pKmSXqnh/k2lC6M+jciYkgdhyyIJartQTYRuFZSuTqRfYH9cL+2+RHxYkTcHRHnA6/vho3NL8oc4Me4AOO9QBeFEBtSbSXTPLKB6udIehnLx6cAsyLi9NRuYNuIuKpwA5VLJ43Ff+vgEGFXFMr4TMFrQkvhPcn41IjYINW+uxB7EZMkTVBRiZkUxjssIu7FIoXngQ9JOkEt7mSajPoOdBr1WtuRPIiN/V6V7pDCi4fjKvN9mogYhh2h5dgWbYBbthwJzBsBiy/A3UdvxolNy+lseQuOo34Zt7j+4MqnWIzLIGXamFyLL/MeyYO6EZc22hpXWThM0oyibQYCt+4NN94Hoympe7Yd6DEXWj2t5Njn4JYVj2Al26u4keCLRdushQuqnoVDhd8Hbu5NT6lSkuH+LK5s8TXg6moEIekY5wC7SPqHkte7rMUXrprwcUkNEWy0E2k9bg8sTBmLv39zcXHfwdigfAu4UE5Cnw+sdiPwPSwXHYYN0anAScCW2M0vDusdh6Xo2MP/AG1UTzOzMtlAZVYgrUXNBnZKLz0NbK5CgVMnSE7GlRfEiqVlFmPjcxdwKcnbSU/+zwHvS/vMAsYVbugR8WEsvjgGhwWvlDSnee+ycRQZ9ZexwX25in3XxqHLMfh++g6W2G8BrI9TeqZhT+t+/Ll8WdIvG/keeptk3D9IpzEaiz2k2fi7MQt4WNLiJIYZCxwr6Zaig7R19+tMg2i1zj2P9hr4prEcGxLhm+NnJRWXlllWJudkWdrutHTM7+LQn4pGoe3F/wNeBC4ANmj1+6/xmg0GLknv49Aq9lsP+B32Ppdj8UfptT8//f9t4A1gq1a/3xquz0CsHD0brym9gMO207C3vD0wsJt9twe2Xel3LhBbVZmjovG2YKdWX5c8yo/sQWVWICK2Br4AbAtshpNgF8mtKS7HFdMrZdEyOG8Q/Av2rJbgLH7wzfbPOEIzXW0YxquWJKm/HrfwOFc99L9KXuM8bLiH4OtxLBZBFDyD/8Ge1QKsqF6Or+NEuUJIWxIRqwO7YM9nT1zJ5HnsBRY8pGdU783HIc+qv5P0coPRTO1kA9UPiYiTgAmSysqlI6LjZtjvCDfcq+ZGUGDR9XDcCZacn0ln6spCYB31scZwKZz5PXxzPk6WqHe13QBsuE/A1/VtfI3exoYrcKv5KVjePhgb+N8D+6lB/b4aQUSsw4rrR9vh91AwRg+oWUnFnUZqNXoO9y3H1y8bp1WIbKD6IdUYqLRDXfH+ZfCLQfBJbJyW4bWWDtwbaZVeT+mOVHniX3Dt0m9I+lu4g68kPZe2CeBLwEU4DLaJpOci4re4Msb7JL0bEQuxEZuBZfZLWvCWKJrzaDqN0Z7A+3HdxFnYS5ojaVEvTmonvC56EN2vi87A66INaZqZ6R2ygeqHVGWgrK6aj59Qa2UJ8IFwbuUawHAsmPiTmtC9tl1IhXl/gt/rBHyTfBXnfKlou9OxMVtd0pIkvNhC0h3p969gWfqh6uX8p6Ta3IZOMcOe2JgWjNEsYF5bfI6rSPfrTBW0ehEsj+YN3AfpKRxO+wOpcyhW4T6Ak+zfAB4H9i3a7ySsLls4Al6dCksLC8xTQFuBRoA+AXqmaPEZ0A9Bm4GGg84ALe/sanoeXl/5Y9F8PprOtyFwC15reRo4u2guu2Bp+ptYKXdFq69rlZ/BACwOWIw9x7foQkgBDBGMFEwSXC+4I/07aQ5s2IvzXQ
0boS9jg/o68ARwbfpejCH3Ucqjl0bLJ5BHEz9cN+fbMN0kj8brG6PSjeZvWAwxOP3uDSyIGJaMwZaSeAJu/n0yQLeBxoD+AHoXdDFotxIDdTDoNdB80Lqgu9LvroBf44XynXHIZTMcKhoA/BbnEg3B8uO/APun9/AgcHz6/xrA37X6utbwORyUDFRBnfccxe3qa2hZ3sC5rQUcjAsuzEoGdA5wBa6nOLLV1y+P/jtaPoE8evHDduLjYclAvVD8JJxuSscnA/U6cATQkZ7kJdABoGuLbp7LQB1FXhSg+4t+fxTo0vT/na1IO6eLOe0KPFvy2mRSuwlc9/NCVuHWEsAN6YHgTSx4EPDPkmqW7tcxl42BT+O1scewN3sPlvnvC6zR6uuVRx6FUVwMONPHiIgTgHOBTdJLa+AW6MuA5yUVL0DOx6GktyPiaBySm7ILvDUV93yfD5yD9eYFhN2i0ennDYp+tzp+HAd40d7RU11MczSwYUQUq9IG4vUNcGGAi4DHI+Jp4EJJ/17B228bJB0XEROx17gFLtmzoEqZ9IC03eVEQAVKtKQU/DCda0djsYCgsHZ0HTBXfUxJmek7ZAPVR4mI0bgg6b7Ag5KWRcRcHF4DeH9ERJGR+gBwO4CsrPtlRHSMgLtPhVEPwICNga8A/0DVLO7w+tGYLn73HPC0pM272lFuu3FsutkeDtwcEeuowtbv7YLcT+uxNG5JFTlmToPVv0NnqZ5N8Sr/6bhc+UU4k3ct4BkfqmCkHqFEkRYRQ4Ed6TRGe2BRxv3Ar7An+mTJg0km07bkYrF9l2HYwVkA79V326bo9yOBs1Ph1qOADwEzImL9VKh1GLD0dzBnYNrhNNxz47/Tz28A0yubS8h11M6LiB1TMdrNkhGdAyyMiC9FREcqFLtN+AZORBwXEevJ6rWCl9UXKnmXbVk+DBcm/PbK+64GTI6I4RFxYKqs/mvgFdyufkPgp8DWkjaTdLKkKZL+lI1TZlUie1B9FEl/iIjLschgOW6J80DRJg8Bm2Ppd6EN/Cvhrrbnpu31Csz9pp++PzYeBryFC+bNx1rx/bASo9t52EjOeFKakur8/SvOm3kGix/mR8QncajraVzb8wngn9IhDgCuSNUJ5gPHJG9k1SVi5Otw4AUQU/FiX4EdcGE/sHxxF7xAVMKApTB+Pdh/ATyMPaRvAL+R9GZT557J9CI5DypTnhSOosZKEsC40nBUvyZi0gy4+FAYuoTyT4n34CSqZ4peWwZL/wYXDJX+T7Ommcm0mhziy5THVcm/iI1NNRTqnmXjtCJVtywvZSAMHeqWFJlMnyWH+DKVIV2Nexfmumf1s0LL8sIf4ez070ZUvMiWW5Zn+jTZg8pUjo3NOOA2bIBK14IWp9dvw2G9bJy6puKW5eWO04jJZDLtSvagMtXhcN0Rue5ZXRRalnecgVUk+2PV3jw6W5Yvx2q+d9M2S/ATZepXkluWZ/o8WSSRyfQ2RQV4e2pZPhvYp2TXcVitQm5ZnukHZAOVybSC3LI8kylLXoPKZFrDpdgLqoUlaf9Mpk+TDVQm0wqydD+TKUsWSWQyrSJL9zOZHslrUJlMq8ktyzOZLskGKpNpF7J0P5NZgWygMplMJtOWZJFEJpPJZNqSbKAymUwm05ZkA5XJZDKZtiQbqEwmk8m0JdlAZTKZTKYtyQYqk8lkMm1JNlCZTCaTaUuygcpkMplMW5INVCaTyWTakmygMplMJtOWZAOVyWQymbYkG6hMJpPJtCXZQGUymUymLckGKpPJZDIcRqZcAAAAMklEQVRtSTZQmUwmk2lLsoHKZDKZTFuSDVQmk8lk2pJsoDKZTCbTlmQDlclkMpm25P8DwQR4OJKs6XAAAAAASUVORK5CYII=\n", + "text/plain": [ + "<Figure size 432x288 with 1 Axes>" + ] + }, + "metadata": { + "needs_background": 
"light" + }, + "output_type": "display_data" + } + ], + "source": [ + "sm = sm.get_largest_subgraph()\n", + "\n", + "_, _, _ = plot_structure(sm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After deciding on how the final structure model should look, we can instantiate a `BayesianNetwork`." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "from causalnex.network import BayesianNetwork\n", + "\n", + "bn = BayesianNetwork(sm)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We are now ready to move on to learning the conditional probability distribution of different features in the `BayesianNetwork`." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Fitting the Conditional Distribution of the Bayesian Network" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Preparing the Discretised Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Bayesian Networks in CausalNex support only discrete distributions. Any continuous features, or features with a large number of categories, should be discretised prior to fitting the Bayesian Network. Models containing variables with many possible values will typically be badly fit, and exhibit poor performance.\n", + "\n", + "For example, consider P(G2 | G1), where G1 and G2 have possible values 0 to 20. The discrete conditional probability distribution is therefore specified using 21x21 (441) possible combinations - most of which we will be unlikely to observe.\n", + "\n", + "CausalNex provides a few helper methods to make discretisation easier. Let's start by reducing the number of categories in some of the categorical features by combining similar values. We will make numeric features categorical by discretisation, and then give the buckets meaningful labels." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Cardinality of Categorical Features" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To reduce the cardinality of categorical features we can define a map `{old_value: new_value}`, and use this to update the feature. For example, in the `studytime` feature, we map values greater than 2 (a value of 2 means 2 to 5 hours; see https://archive.ics.uci.edu/ml/datasets/Student+Performance) to `long-studytime`, and the rest to `short-studytime`." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "discretised_data = data.copy()\n", + "\n", + "data_vals = {col: data[col].unique() for col in data.columns}\n", + "\n", + "failures_map = {v: 'no-failure' if v == [0]\n", + " else 'have-failure' for v in data_vals['failures']}\n", + "\n", + "studytime_map = {v: 'short-studytime' if v in [1,2]\n", + " else 'long-studytime' for v in data_vals['studytime']}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we have defined our maps `{old_value: new_value}` we can update each feature, applying the mapping transformation." + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "discretised_data[\"failures\"] = discretised_data[\"failures\"].map(failures_map)\n", + "discretised_data[\"studytime\"] = discretised_data[\"studytime\"].map(studytime_map)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Discretising Numeric Features" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To make numeric features categorical, they must first be discretised. CausalNex provides a helper class `causalnex.discretiser.Discretiser`, which supports several discretisation methods. For our data the `fixed` method will be applied, providing static values that define the bucket boundaries. 
For example, `absences` will be discretised into the buckets < 1, 1 to 9, and >=10. Each bucket will be labelled as an integer from zero." + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "from causalnex.discretiser import Discretiser\n", + "\n", + "discretised_data[\"absences\"] = Discretiser(method=\"fixed\", \n", + " numeric_split_points=[1, 10]).transform(discretised_data[\"absences\"].values)\n", + "\n", + "discretised_data[\"G1\"] = Discretiser(method=\"fixed\", \n", + " numeric_split_points=[10]).transform(discretised_data[\"G1\"].values)\n", + "\n", + "discretised_data[\"G2\"] = Discretiser(method=\"fixed\", \n", + " numeric_split_points=[10]).transform(discretised_data[\"G2\"].values)\n", + "\n", + "discretised_data[\"G3\"] = Discretiser(method=\"fixed\", \n", + " numeric_split_points=[10]).transform(discretised_data[\"G3\"].values)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Create Labels for Numeric Features" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To make the discretised categories more readable, we can map the category labels onto something more meaningful in the same way that we mapped category feature values." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "absences_map = {0: \"No-absence\", 1: \"Low-absence\", 2: \"High-absence\"}\n", + "\n", + "G1_map = {0: \"Fail\", 1: \"Pass\"}\n", + "G2_map = {0: \"Fail\", 1: \"Pass\"}\n", + "G3_map = {0: \"Fail\", 1: \"Pass\"}\n", + "\n", + "discretised_data[\"absences\"] = discretised_data[\"absences\"].map(absences_map)\n", + "discretised_data[\"G1\"] = discretised_data[\"G1\"].map(G1_map)\n", + "discretised_data[\"G2\"] = discretised_data[\"G2\"].map(G2_map)\n", + "discretised_data[\"G3\"] = discretised_data[\"G3\"].map(G3_map)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Train / Test Split" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Like many other machine learning models, we will use a train and test split to help us validate our findings." + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [], + "source": [ + "# Split 90% train and 10% test\n", + "from sklearn.model_selection import train_test_split\n", + "\n", + "train, test = train_test_split(discretised_data, train_size=0.9, test_size=0.1, random_state=7)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Probability" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "With the learnt structure model from earlier and the discretised data, we can now fit the probability distribution of the Bayesian Network. The first step in this is specifying all of the states that each node can take. This can be done either from data, or providing a dictionary of node values. We use the full dataset here to avoid cases where states in our test set do not exist in the training set. For real-world applications, these states may need to be provided using the dictionary method." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "bn = bn.fit_node_states(discretised_data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Fit Conditional Probability Distributions" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `fit_cpds` method of `BayesianNetwork` accepts a dataset to learn the conditional probability distributions (CPDs) of each node, along with a method of how to do this fit." + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/ben_horsburgh/opt/anaconda3/envs/causal-test/lib/python3.7/site-packages/pandas/core/generic.py:5069: FutureWarning: Attribute 'is_copy' is deprecated and will be removed in a future version.\n", + " object.__getattribute__(self, name)\n", + "/Users/ben_horsburgh/opt/anaconda3/envs/causal-test/lib/python3.7/site-packages/pandas/core/generic.py:5070: FutureWarning: Attribute 'is_copy' is deprecated and will be removed in a future version.\n", + " return object.__setattr__(self, name, value)\n" + ] + } + ], + "source": [ + "bn = bn.fit_cpds(train, method=\"BayesianEstimator\", bayes_prior=\"K2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Once we have the CPDs, we can inspect them through the `cpds` property, which is a dictionary of node->cpd." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead tr th {\n", + " text-align: left;\n", + " }\n", + "\n", + " .dataframe thead tr:last-of-type th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr>\n", + " <th>failures</th>\n", + " <th colspan=\"8\" halign=\"left\">have-failure</th>\n", + " <th colspan=\"8\" halign=\"left\">no-failure</th>\n", + " </tr>\n", + " <tr>\n", + " <th>higher</th>\n", + " <th colspan=\"4\" halign=\"left\">no</th>\n", + " <th colspan=\"4\" halign=\"left\">yes</th>\n", + " <th colspan=\"4\" halign=\"left\">no</th>\n", + " <th colspan=\"4\" halign=\"left\">yes</th>\n", + " </tr>\n", + " <tr>\n", + " <th>schoolsup</th>\n", + " <th colspan=\"2\" halign=\"left\">no</th>\n", + " <th colspan=\"2\" halign=\"left\">yes</th>\n", + " <th colspan=\"2\" halign=\"left\">no</th>\n", + " <th colspan=\"2\" halign=\"left\">yes</th>\n", + " <th colspan=\"2\" halign=\"left\">no</th>\n", + " <th colspan=\"2\" halign=\"left\">yes</th>\n", + " <th colspan=\"2\" halign=\"left\">no</th>\n", + " <th colspan=\"2\" halign=\"left\">yes</th>\n", + " </tr>\n", + " <tr>\n", + " <th>studytime</th>\n", + " <th>long-studytime</th>\n", + " <th>short-studytime</th>\n", + " <th>long-studytime</th>\n", + " <th>short-studytime</th>\n", + " <th>long-studytime</th>\n", + " <th>short-studytime</th>\n", + " <th>long-studytime</th>\n", + " <th>short-studytime</th>\n", + " <th>long-studytime</th>\n", + " <th>short-studytime</th>\n", + " <th>long-studytime</th>\n", + " <th>short-studytime</th>\n", + " <th>long-studytime</th>\n", + " <th>short-studytime</th>\n", + " <th>long-studytime</th>\n", + " 
<th>short-studytime</th>\n", + " </tr>\n", + " <tr>\n", + " <th>G1</th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " <th></th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>Fail</th>\n", + " <td>0.75</td>\n", + " <td>0.806452</td>\n", + " <td>0.5</td>\n", + " <td>0.75</td>\n", + " <td>0.5</td>\n", + " <td>0.612245</td>\n", + " <td>0.5</td>\n", + " <td>0.75</td>\n", + " <td>0.5</td>\n", + " <td>0.612903</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.032967</td>\n", + " <td>0.15016</td>\n", + " <td>0.111111</td>\n", + " <td>0.255814</td>\n", + " </tr>\n", + " <tr>\n", + " <th>Pass</th>\n", + " <td>0.25</td>\n", + " <td>0.193548</td>\n", + " <td>0.5</td>\n", + " <td>0.25</td>\n", + " <td>0.5</td>\n", + " <td>0.387755</td>\n", + " <td>0.5</td>\n", + " <td>0.25</td>\n", + " <td>0.5</td>\n", + " <td>0.387097</td>\n", + " <td>0.5</td>\n", + " <td>0.5</td>\n", + " <td>0.967033</td>\n", + " <td>0.84984</td>\n", + " <td>0.888889</td>\n", + " <td>0.744186</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + "failures have-failure \\\n", + "higher no \n", + "schoolsup no yes \n", + "studytime long-studytime short-studytime long-studytime short-studytime \n", + "G1 \n", + "Fail 0.75 0.806452 0.5 0.75 \n", + "Pass 0.25 0.193548 0.5 0.25 \n", + "\n", + "failures \\\n", + "higher yes \n", + "schoolsup no yes \n", + "studytime long-studytime short-studytime long-studytime short-studytime \n", + "G1 \n", + "Fail 0.5 0.612245 0.5 0.75 \n", + "Pass 0.5 0.387755 0.5 0.25 \n", + "\n", + "failures no-failure \\\n", + "higher no \n", + "schoolsup no yes \n", + "studytime long-studytime short-studytime long-studytime short-studytime \n", + "G1 \n", + "Fail 0.5 0.612903 0.5 0.5 \n", + 
"Pass 0.5 0.387097 0.5 0.5 \n", + "\n", + "failures \n", + "higher yes \n", + "schoolsup no yes \n", + "studytime long-studytime short-studytime long-studytime short-studytime \n", + "G1 \n", + "Fail 0.032967 0.15016 0.111111 0.255814 \n", + "Pass 0.967033 0.84984 0.888889 0.744186 " + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "bn.cpds[\"G1\"]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The CPD dictionaries are multi-indexed, and so the `loc` function can be a useful way to interact with them:" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Predict the State given the Input Data" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The `predict` method of `BayesianNetwork` allows us to make predictions based on the data using the learnt Bayesian Network. For example, we want to predict if a student fails or passes their exam based on the input data. 
Imagine we have an incoming student data that looks like this:" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "address U\n", + "famsize GT3\n", + "Pstatus T\n", + "Medu 3\n", + "Fedu 2\n", + "traveltime 1\n", + "studytime short-studytime\n", + "failures have-failure\n", + "schoolsup no\n", + "famsup yes\n", + "paid yes\n", + "activities yes\n", + "nursery yes\n", + "higher yes\n", + "internet yes\n", + "romantic no\n", + "famrel 5\n", + "freetime 5\n", + "goout 5\n", + "Dalc 2\n", + "Walc 4\n", + "health 5\n", + "absences Low-absence\n", + "G2 Fail\n", + "G3 Fail\n", + "Name: 18, dtype: object" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "discretised_data.loc[18, discretised_data.columns != 'G1']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Based on these data, we want to predict if this student fails their exam. Intuitively, we would expect this student to fail because they spend a shorter amount of time on their study and have failed in the past. Let's see how our Bayesian Network performs on this:" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "predictions = bn.predict(discretised_data, \"G1\")" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The prediction is 'Fail'\n" + ] + } + ], + "source": [ + "print('The prediction is \\'{prediction}\\''.format(prediction=predictions.loc[18, 'G1_prediction']))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The prediction by the Bayesian Network turns out to be a `Fail`. 
Let's compare this to the ground truth:" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The ground truth is 'Fail'\n" + ] + } + ], + "source": [ + "print('The ground truth is \\'{truth}\\''.format(truth=discretised_data.loc[18, 'G1']))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "which turns out to be the same." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model Quality" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To evaluate the quality of the model that has been learned, CausalNex supports two main approaches: Classification Report and Receiver Operating Characteristics (ROC) / Area Under the ROC Curve (AUC). In this section each will be discussed." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Classification Report" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To obtain a classification report using a BN, we need to provide a test set, and the node we are trying to classify. The report will predict the target node for all rows in the test set, and evaluate how well those predictions are made."
+ ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "<div>\n", + "<style scoped>\n", + " .dataframe tbody tr th:only-of-type {\n", + " vertical-align: middle;\n", + " }\n", + "\n", + " .dataframe tbody tr th {\n", + " vertical-align: top;\n", + " }\n", + "\n", + " .dataframe thead th {\n", + " text-align: right;\n", + " }\n", + "</style>\n", + "<table border=\"1\" class=\"dataframe\">\n", + " <thead>\n", + " <tr style=\"text-align: right;\">\n", + " <th></th>\n", + " <th>precision</th>\n", + " <th>recall</th>\n", + " <th>f1-score</th>\n", + " <th>support</th>\n", + " </tr>\n", + " </thead>\n", + " <tbody>\n", + " <tr>\n", + " <th>G1_Fail</th>\n", + " <td>0.777778</td>\n", + " <td>0.583333</td>\n", + " <td>0.666667</td>\n", + " <td>12</td>\n", + " </tr>\n", + " <tr>\n", + " <th>G1_Pass</th>\n", + " <td>0.910714</td>\n", + " <td>0.962264</td>\n", + " <td>0.935780</td>\n", + " <td>53</td>\n", + " </tr>\n", + " <tr>\n", + " <th>macro avg</th>\n", + " <td>0.844246</td>\n", + " <td>0.772799</td>\n", + " <td>0.801223</td>\n", + " <td>65</td>\n", + " </tr>\n", + " <tr>\n", + " <th>micro avg</th>\n", + " <td>0.892308</td>\n", + " <td>0.892308</td>\n", + " <td>0.892308</td>\n", + " <td>65</td>\n", + " </tr>\n", + " <tr>\n", + " <th>weighted avg</th>\n", + " <td>0.886172</td>\n", + " <td>0.892308</td>\n", + " <td>0.886097</td>\n", + " <td>65</td>\n", + " </tr>\n", + " </tbody>\n", + "</table>\n", + "</div>" + ], + "text/plain": [ + " precision recall f1-score support\n", + "G1_Fail 0.777778 0.583333 0.666667 12\n", + "G1_Pass 0.910714 0.962264 0.935780 53\n", + "macro avg 0.844246 0.772799 0.801223 65\n", + "micro avg 0.892308 0.892308 0.892308 65\n", + "weighted avg 0.886172 0.892308 0.886097 65" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from causalnex.evaluation import classification_report\n", + "classification_report(bn, test, 
\"G1\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This report shows that the model we have defined is able to classify whether a student passes their exam reasonably well.\n", + "\n", + "For the predictions where the student fails, the precision is good, but recall is bad. This implies that we can rely on predictions for this class when they are made, but we are likely to miss some of the predictions we should have made. Perhaps these missing predictions are as a result of something missing in our structure - this could be an interesting area to explore." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### ROC / AUC" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Receiver Operating Characteristics (ROC), and the Area Under the ROC Curve (AUC) can be obtained using the `roc_auc` method within the CausalNex evaluation module. Again, a test set and target node must be provided. The ROC curve is computed by micro-averaging predictions made across all states (classes) of the target node." + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.9181065088757396\n" + ] + } + ], + "source": [ + "from causalnex.evaluation import roc_auc\n", + "roc, auc = roc_auc(bn, test, \"G1\")\n", + "print(auc)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The AUC value for our model is high, giving us confidence in the performance." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Querying Marginals" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "After iterating over our model structure, CPDs, and validating our model quality, we can query our model under different observations to gain insights."
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Baseline Marginals" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To query the model for baseline marginals that reflect the population as a whole, a `query` method can be used. First let's update our model using the complete dataset, since the one we currently have was only built from training data." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:root:Replacing existing CPD for address\n", + "WARNING:root:Replacing existing CPD for absences\n", + "WARNING:root:Replacing existing CPD for Pstatus\n", + "WARNING:root:Replacing existing CPD for famrel\n", + "WARNING:root:Replacing existing CPD for studytime\n", + "WARNING:root:Replacing existing CPD for G1\n", + "WARNING:root:Replacing existing CPD for failures\n", + "WARNING:root:Replacing existing CPD for schoolsup\n", + "WARNING:root:Replacing existing CPD for paid\n", + "WARNING:root:Replacing existing CPD for higher\n", + "WARNING:root:Replacing existing CPD for internet\n", + "WARNING:root:Replacing existing CPD for G2\n", + "WARNING:root:Replacing existing CPD for G3\n" + ] + } + ], + "source": [ + "bn = bn.fit_cpds(discretised_data, method=\"BayesianEstimator\", bayes_prior=\"K2\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can safely ignore these warnings, which let us know we are replacing the previously existing CPDs. \n", + "\n", + "For inference, we must create a new InferenceEngine from our BayesianNetwork, which lets us query the model. The query method will compute the marginal likelihood of all states for all nodes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'Fail': 0.25260687281677224, 'Pass': 0.7473931271832277}" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from causalnex.inference import InferenceEngine\n", + "\n", + "ie = InferenceEngine(bn)\n", + "marginals = ie.query()\n", + "marginals[\"G1\"] " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The output observed tells us that `P(G1=Fail) = 0.25`, and the `P(G1=Pass) = 0.75`. As a quick sanity check, we can compute what proportion of our dataset are `Fail`, which should be approximately the same." + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[('Fail', 157), ('Pass', 492)]" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import numpy as np\n", + "labels, counts = np.unique(discretised_data[\"G1\"], return_counts=True)\n", + "list(zip(labels, counts))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The proportion of the students who fail is `157 / (157+492) = 0.242` - which is close to our computed marginal likelihood." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Marginals after Observations" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can also query the marginal likelihood of states in our network given some observations. These observations can be made anywhere in the network, and their impact will be propagated through to the node of interest.\n", + "\n", + "Let's look at the difference in the likelihood of `G1` based on `studytime`." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Marginal G1 | Short Studtyime {'Fail': 0.2776556433482524, 'Pass': 0.7223443566517477}\n", + "Marginal G1 | Long Studytime {'Fail': 0.15504850337837614, 'Pass': 0.8449514966216239}\n" + ] + } + ], + "source": [ + "marginals_short = ie.query({\"studytime\": \"short-studytime\"})\n", + "marginals_long = ie.query({\"studytime\": \"long-studytime\"})\n", + "print(\"Marginal G1 | Short Studtyime\", marginals_short[\"G1\"])\n", + "print(\"Marginal G1 | Long Studytime\", marginals_long[\"G1\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Based on our data we can see that students who study longer are more likely to pass their exam." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Do Calculus" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "CausalNex also supports simple Do-Calculus, allowing us to specify interventions. In this section we will take a look at the supported methods, and what they mean." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Updating a Node Distribution" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can apply an intervention to any node in our data, updating its distribution using a `do` operator. This can be thought of as asking our model \"What if\" something were different. For example, we could ask what would happen if 100% of students wanted to go on to do higher education."
+ ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "distribution before do {'no': 0.10752688172043011, 'yes': 0.8924731182795698}\n", + "distribution after do {'no': 0.0, 'yes': 0.9999999999999998}\n" + ] + } + ], + "source": [ + "print(\"distribution before do\", ie.query()[\"higher\"])\n", + "ie.do_intervention(\"higher\", \n", + " {'yes': 1.0, \n", + " 'no': 0.0})\n", + "print(\"distribution after do\", ie.query()[\"higher\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Resetting a Node Distribution" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can reset any interventions that we make by using the `reset_do` method, and providing the node that we want to reset." + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "ie.reset_do(\"higher\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Effect of Do on Marginals" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can again use `query` to examine the effect that an intervention has on our marginal likelihoods. In this case, we can look at how the likelihood of achieving a pass changes if 100% of students wanted to do higher education."
+ ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "marginal G1 {'Fail': 0.25260687281677224, 'Pass': 0.7473931271832277}\n", + "updated marginal G1 {'Fail': 0.20682952942551894, 'Pass': 0.7931704705744809}\n" + ] + } + ], + "source": [ + "print(\"marginal G1\", ie.query()[\"G1\"])\n", + "ie.do_intervention(\"higher\", \n", + " {'yes': 1.0, \n", + " 'no': 0.0})\n", + "print(\"updated marginal G1\", ie.query()[\"G1\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In this case, we can see that if 100% of students wanted to do higher education (as opposed to 90% in our data population), then we estimate that pass rate would increase from 74.7% to 79.3%." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python [conda env:causal-test] *", + "language": "python", + "name": "conda-env-causal-test-py" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/source/04_user_guide/04_user_guide.md b/docs/source/04_user_guide/04_user_guide.md new file mode 100644 index 0000000..347434e --- /dev/null +++ b/docs/source/04_user_guide/04_user_guide.md @@ -0,0 +1,304 @@ +# Causal Inference with Bayesian Networks. Main Concepts and Methods + +## 1. Causality + +### 1.1 Why is causality important? + +Experts and practitioners in various domains are commonly interested in discovering causal relationships to answer questions like + +> "What drives economical prosperity?", "What fraction of patients can a given drug save?", +"How much would a power failure cost to a given manufacturing plant?". 
+ +The ability to identify truly causal relationships is fundamental to developing impactful interventions in medicine, policy, business, and other domains. + +Often, in the absence of randomised control trials, there is a need for causal inference purely from observational data. +However, in this case the commonly known fact that + +> correlation does not imply causation + +comes to life. Therefore, it is crucial to distinguish between events that _cause_ specific outcomes and those that merely _correlate_. +One possible explanation for correlation between variables where neither causes the other is the presence of _confounding_ variables +that influence both the target and a driver of that target. Unobserved confounding variables are severe +threats when doing causal inference on observational data. +The research community has made significant contributions to develop methods and techniques for this type of analysis. +[Potential outcomes framework (Rubin causal model)](https://5harad.com/mse331/papers/rubin_causal_inference.pdf), +[propensity score matching](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC3144483/) and +[structural causal models](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC2836213/) are, arguably, the most popular frameworks for observational causal inference. + +Here, we focus on the structural causal models and one particular type, Bayesian Networks. + +Interested users can find more details in the references below. +- [Causal inference using potential outcomes: Design, +modeling, decisions. Journal of the American Statistical Association](https://5harad.com/mse331/papers/rubin_causal_inference.pdf) by D. Rubin; +- [Lecture notes on potential outcomes approach](http://statweb.stanford.edu/~rag/stat209/jorogosa06.pdf), Dept of Psychiatry & Behavioral Sciences, Stanford University by Booil Jo; +- [Probabilistic graphical models: principles and techniques](https://mitpress.mit.edu/books/probabilistic-graphical-models) by D. Koller and N. Friedman. 
+ +### 1.2 Structural Causal Models (SCMs) + +*Structural causal models* represent causal dependencies using graphical models that provide an intuitive visualisation by +representing variables as nodes and relationships between variables as edges in a graph. + +SCMs serve as a comprehensive framework unifying graphical models, structural equations, and counterfactual +and interventional logic. + +Graphical models serve as a language for structuring and visualising knowledge about the world and can incorporate both data-driven and human inputs. + +Counterfactuals enable the articulation of something there is a desire to know, and structural equations serve to tie the two together. + +SCMs have had a transformative impact on multiple data-intensive disciplines (e.g. epidemiology, economics, etc.), enabling the codification of the existing knowledge in diagrammatic and algebraic forms and consequently leveraging data to estimate the answers to interventional and counterfactual +questions. + +Bayesian Networks are one of the most widely used SCMs and are at the core of this library. + +More on SCMs: [Causality: Models, Reasoning, and Inference](http://bayes.cs.ucla.edu/BOOK-2K/) by J. Pearl. + + +## 2. Bayesian Networks (BNs) + +### 2.1 Directed Acyclic Graph (DAG) +A *graph* is a collection of *nodes* and *edges*, where the *nodes* are some objects, and *edges* between them represent some connection between these objects. +A *directed graph* is a graph in which each edge is orientated from one node to another node. +In a directed graph, an edge goes from a *parent* node to a *child* node. +A *path* in a directed graph is a sequence of edges such that the ending node of each edge is the starting node of the next edge in the sequence. +A *cycle* is a path in which the starting node of its first edge equals the ending node of its last edge. +A *directed acyclic graph* is a directed graph that has no cycles.
+ +<figure> + <img src="graph.png" width="210"/> + <figcaption>Figure 1: A simple directed acyclic graph.</figcaption> +</figure> + + +<figure> + <img src="graph_definitions.png" width="350"/> + <figcaption>Figure 2: A more complex graph with a cycle and an isolated node. + This graph can be turned into a DAG by removing one of the edges forming a cycle: (F, G), (E, F) or (G, E).</figcaption> +</figure> + +### 2.2 What Bayesian Networks are and are not +**What are Bayesian Networks?** + +*Bayesian Networks* are probabilistic graphical models that represent the dependency structure of a set of variables and their joint distribution efficiently in a factorised way. + +A Bayesian Network consists of a DAG, a causal graph where nodes represent random variables and edges represent the relationships between them, and a conditional probability distribution (CPD) associated with +each of the random variables. + +If a random variable has parents in the BN then the CPD represents \\(P(\text{variable|parents}) \\) i.e. the +probability of that variable given its parents. In the case when +the random variable has no parents, it simply represents \\(P(\text{variable}) \\) i.e. the probability of that variable. + +Even though we are interested in the joint distribution of the variables in the graph, Bayes' rule requires us to specify only the conditional distributions of each variable given its parents. + +> The links between variables in BNs encode dependency, not necessarily causality. In this package we are mostly interested in the case where BNs are causal. Hence, the edge between nodes should be seen as a *cause -> effect* relationship. + +Let's consider an example of a simple Bayesian network shown in the figure below. It shows how the actions of customer relationship managers (emails sent and meetings held) affect the bank's income. + +<figure> + <img src="BN.png" width="700"/> + <figcaption>Figure 3: A Bayesian Network describing a banking case study.
Tables attributed to the nodes show the CPDs of the corresponding variables given their parents (if present).</figcaption> +</figure> + +New sales and the number of meetings with a customer directly affect the bank's income. However, these +two drivers are not independent but the number of meetings also influences +whether a new sale takes place. In addition, system prompts indirectly influence +the bank's income through the generation of new sales. This example +shows that BNs are able to capture complex relationships between variables and represent dependencies between +drivers and include drivers that do not affect the target directly. + + +**Steps for working with a Bayesian Network** + +BN models are built in a multi-step process before they can be used for analysis. + +1. **Structure Learning**. The structure of a network describing the relationships between variables can be learned from data, or built from expert knowledge. +2. **Structure Review**. Each relationship should be validated, so that it can be asserted to be causal. This may involve flipping / removing / adding learned edges, or confirming expert knowledge from trusted literature or empirical beliefs. +3. **Likelihood Estimation**. The conditional probability distribution of each variable given its parents can be learned from data. +4. **Prediction & Inference**. The given structure and likelihoods can be used to make predictions, or perform observational and counterfactual inference. +CausalNex supports structure learning from continuous data, and expert opinion. CausalNex supports likelihood estimation and prediction/inference from discrete data. A `Discretiser` class is provided to help discretising continuous data in a meaningful way. + + +> Since BNs themselves are not inherently causal models, the structure learning algorithms on their own merely learn that there are dependencies between variables. 
A useful approach to the problem is to first group the features into themes and constrain the search space to respect how themes of variables relate. If there is further domain knowledge available, it can be used as additional constraints before learning a graph algorithmically. + +**What can we use Bayesian Networks for?** + +The probabilities of variables in Bayesian Networks update as observations are added to the model. +This is useful for inference or counterfactuals, and for predictive analytics. +Metrics can help us understand the strength of relationships between variables. + +- The sensitivity of nodes to changes in observations of other events can be used to assess what changes could lead to what effects; +- The active trail of a target node identifies which other variables have any effect on the target. + +### 2.3 Advantages and Drawbacks of Bayesian Networks + +**Advantages** + +- Bayesian Networks offer a graphical representation that is reasonably interpretable and easily explainable; +- Relationships captured between variables in a Bayesian Network are more complex yet hopefully more informative than a conventional model; +- Models can reflect both statistically significant information (learned from the data) and domain expertise simultaneously; +- Multiple metrics can be used to measure the significance of relationships and help identify the effect of specific actions; +- Bayesian Networks offer a mechanism for suggesting counterfactual actions and combining actions without aggressive independence assumptions. + +**Drawbacks** + +- Granularity of modelling may have to be lower. However, this may either not be necessary, or can be run in tandem with other techniques that provide accuracy +but are less interpretable; +- Computational complexity is higher. However, this can be offset with careful feature selection and a less granular discretisation policy, but at the expense of predictive power; +- This is (unfortunately) not a way of fully automating Causal Inference.
+ +## 3. `BayesianNetwork` + +The `BayesianNetwork` class is the central class for the causal inference analysis in the package. +It is built on top of the `StructureModel`, which is an extension of `networkx.DiGraph` + +`StructureModel` represents a causal graph, a DAG of the respective BN and holds directed edges, describing +a _cause -> effect_ relationship. In order to define the `BayesianNetwork`, users should provide a relevant `StructureModel`. + +> Cycles are permitted within a `StructureModel`. However, only **acyclic connected** `StructureModel` are allowed in the construction of `BayesianNetwork`; isolated nodes are not allowed. + +### 3.1 Defining the DAG with `StructureModel` + +Our package enables a _hybrid way_ to learn structure of the model. + +For instance, users can define a causal model **fully manually**, e.g., using the domain expertise: + +```python + from causalnex.structure import StructureModel + # Encoding the causal graph suggested by an expert + # d + # ↙ ↓ ↘ + # a ← b → c + # ↑ ↗ + # e + sm_manual = StructureModel() + sm_manual.add_edges_from( + [ + ("b", "a", origin="expert"), + ("b", "c", origin="expert"), + ("d", "a", origin="expert"), + ("d", "c", origin="expert"), + ("d", "b", origin="expert"), + ("e", "c", origin="expert"), + ("e", "b", origin="expert"), + ] + ) +``` +Or, users can learn the network structure **automatically** from the data using the [`NOTEARS`](https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf) algorithm. Moreover, if there is domain knowledge available, +it can be used as **additional constraints** before learning a graph algorithmically. 
+ +> The recently published [NOTEARS](https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf) algorithm for learning DAGs from data based on a continuous optimisation problem +has made it possible to overcome the challenges of combinatorial optimisation, giving a new impulse to the usage of BNs in machine learning applications. + +```python + from causalnex.structure.notears import from_pandas + from causalnex.network import BayesianNetwork + + # Unconstrained learning of the structure from data + sm = from_pandas(data) + # Imposing edges that are not allowed in the causal model + sm_with_tabu_edges = from_pandas(data, tabu_edges=[("e", "a")]) + # Imposing parent nodes that are not allowed in the causal model + sm_with_tabu_parents = from_pandas(data, tabu_parent_nodes=["a", "c"]) + # Imposing child nodes that are not allowed in the causal model + sm_with_tabu_children = from_pandas(data, tabu_child_nodes=["d", "e"]) +``` + +Finally, the output of the algorithm should be **inspected** and **adjusted**, if required, +by a domain expert. This is a targeted effort to encode important domain knowledge in models, and avoid spurious relationships. + +```python + # Removing the learned edge from the model + sm.remove_edge("a", "c") + # Changing the direction of the learned edge + sm.remove_edge("c", "d") + sm.add_edge("d", "c", origin="learned") + # Adding the edge that was not learned by the algorithm + sm.add_edge("a", "e", origin="expert") +``` + +> When defining the structure model, we recommend using the **entire** dataset **without** discretisation of continuous variables. + +### 3.2 Likelihood Estimation and Predictions with `BayesianNetwork` + +Once the graph has been determined, the `BayesianNetwork` can be initialised and the conditional probability distributions of the variables can be learned from the data. + +Maximum likelihood or Bayesian parameter estimation can be used for learning CPDs.
+> When learning CPDs of the BN, +> - The discretised data should be used (either high or low granularity of features and target variables can be used); +> - Overfitting and underfitting of CPDs can be avoided with an appropriate train/test split of the data. + +```python + from causalnex.network import BayesianNetwork + from causalnex.discretiser import Discretiser + + # Initialise BN with defined structure model + bn = BayesianNetwork(sm) + # First, learn all the possible states of the nodes using the whole dataset + bn.fit_node_states(data_discrete) + # Fit CPDs using the training dataset with the discretised continuous variable "c" + train_data_discrete = train_data.copy() + train_data_discrete["c"] = Discretiser(method="uniform").transform(discretised_data["c"].values) + bn.fit_cpds(train_data_discrete, method="BayesianEstimator", bayes_prior="K2") +``` + +Once the CPDs are learned, they can be used to predict the state of a node as well as the probability of each possible state of a node, based on some input data (e.g., previously unseen test data) and the learned CPDs: + +```python + predictions = bn.predict(test_data_discrete, "c") + predicted_probabilities = bn.predict_probability(test_data_discrete, "c") +``` +> When all parents of a given node exist within input data, the method inspects the CPDs directly and avoids traversing the full network. When some parents do not exist within input data, the most likely state for every node that is not contained within data is computed, and the predictions are made accordingly. + +## 4. Querying model and making interventions with `InferenceEngine` + +After iterating over the model structure, CPDs, and validating the model quality, we can +undertake inference on a BN to examine expected behaviour and gain insights. + +The `InferenceEngine` class provides methods to query marginals based on observations and to make interventions (a.k.a. DO-calculus) on a Bayesian Network.
+ +### 4.1 Querying marginals with `InferenceEngine.query` + +Inference and observation of evidence are done on the fly, following a deterministic [Junction Tree Algorithm (JTA)](https://ermongroup.github.io/cs228-notes/inference/jt/). + +To query the model for baseline marginals that reflect the population as a whole, the `query` method can be used. + +> We recommend updating the model using the complete dataset for this type of query. + +```python + from causalnex.inference import InferenceEngine + + # Updating the model on the whole dataset + bn.fit_cpds(data_discrete, method="BayesianEstimator", bayes_prior="K2") + ie = InferenceEngine(bn) + # Querying all the marginal probabilities of the model's distribution + marginals = ie.query({}) +``` + +Users can also query the marginals of states in a BN given some _observations_. +These observations can be made anywhere in the network; the marginal distributions of nodes (including the target variable) will be updated and their impact will be propagated through to the node of interest: + +```python + # Querying the marginal probabilities of the model's distribution + # after an observed state of the node "b" + marginals_after_observations = ie.query({"b": True}) +``` + +> - For complex networks, the JTA may take an hour to update the probabilities throughout the network; +> - This process cannot be parallelised, but multiple queries can be run in parallel. + +### 4.2 Making interventions (Do-calculus) with `InferenceEngine.do_intervention` + +Finally, users can use the insights from the inference and observation of evidence to encode taking _actions_ and observe the effect of these actions on the target variable. + +Our package supports simple Do-Calculus, allowing us to make an intervention on the Bayesian Network.
+ +Users can apply an intervention to any node in the data, updating its distribution using a _do_ operator, +examining the effect of that intervention by querying marginals and resetting any interventions: + +```python + # Doing an intervention to the node "d" + ie.do_intervention("d", True) + # Querying all the updated marginal probabilities of the model's distribution + marginals_after_interventions = ie.query({}) + # Re-introducing the original conditional dependencies + ie.reset_do("d") +``` diff --git a/docs/source/04_user_guide/images/BN.png b/docs/source/04_user_guide/images/BN.png new file mode 100644 index 0000000..d0659d3 Binary files /dev/null and b/docs/source/04_user_guide/images/BN.png differ diff --git a/docs/source/04_user_guide/images/graph.png b/docs/source/04_user_guide/images/graph.png new file mode 100644 index 0000000..b37389c Binary files /dev/null and b/docs/source/04_user_guide/images/graph.png differ diff --git a/docs/source/04_user_guide/images/graph_definitions.png b/docs/source/04_user_guide/images/graph_definitions.png new file mode 100644 index 0000000..3ba7230 Binary files /dev/null and b/docs/source/04_user_guide/images/graph_definitions.png differ diff --git a/docs/source/05_resources/05_faq.md b/docs/source/05_resources/05_faq.md new file mode 100644 index 0000000..2242776 --- /dev/null +++ b/docs/source/05_resources/05_faq.md @@ -0,0 +1,103 @@ +# Frequently asked questions + +> *Note:* This documentation is based on `CausalNex 0.4.0`, if you spot anything that is incorrect then please create an [issue](https://github.com/quantumblacklabs/causalnex/issues) or pull request. + +## What is CausalNex? + +[CausalNex](https://github.com/quantumblacklabs/causalnex) is a python library that allows data scientists and domain experts to co-develop models which go beyond correlation to consider causal relationships. 
It was originally designed by [Paul Beaumont](https://www.linkedin.com/in/pbeaumont/) and [Ben Horsburgh](https://www.linkedin.com/in/benhorsburgh/) to solve challenges they faced in inferencing causality in their project work. + +This work was later turned into a product thanks to the following contributors: [Ivan Danov](https://github.com/idanov), [Dmitrii Deriabin](https://github.com/DmitryDeryabin), [Yetunde Dada](https://github.com/yetudada), [Wesley Leong](https://www.linkedin.com/in/wesleyleong/), [Steve Ler](https://www.linkedin.com/in/song-lim-steve-ler-380366106/), [Viktoriia Oliinyk](https://www.linkedin.com/in/victoria-oleynik/), [Roxana Pamfil](https://www.linkedin.com/in/roxana-pamfil-1192053b/), [Fabian Peter](https://www.linkedin.com/in/fabian-peters-6291ab105/), [Nisara Sriwattanaworachai](https://www.linkedin.com/in/nisara-sriwattanaworachai-795b357/) and [Nikolaos Tsaousis](https://www.linkedin.com/in/ntsaousis/). + +## What are the benefits of using CausalNex? + +It is important to consider the primary benefits of CausalNex in the context of an end-to-end causality and counterfactual analysis. + +As we see it, CausalNex: + +- **Generates transparency and trust in models** it creates by allowing users to collaborate with domain experts during the modelling process. +- Uses an **optimised structure learning algorithm**, [NOTEARS](https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf) where the runtime to learn structure is no longer exponential but scales cubically with number of nodes. 
+- **Add known relationships or remove spurious correlations** so that your model can better consider causal relationships in data +- **Visualise networks using common tools** built upon [NetworkX](https://networkx.github.io/), allowing users to understand relationships in their data more intuitively, and work with experts to encode their knowledge +- **Streamlines the use of Bayesian Networks** for an end-to-end counterfactual analysis, which in the past was a complicated process involving the use of at least three separate open source libraries, each with its own interface. + +## When should you consider using CausalNex? + +CausalNex is created specifically for data scientists who would like an efficient and intuitive process to identify causal relationships and the right intervention through data and collaboration with domain experts. + +## Why NOTEARS algorithm over other structure learning methods? + +Historically, structure learning has been a very **hard** problem. We are interested in looking for the optimal directed acyclic graph (DAGs) that describes the conditional dependencies between variables. However, the search space for this is **combinatorial** and scales **super-exponentially** with the number of nodes. [NOTEARS](https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf) algorithm cleverly introduces a new optimisation heuristic and approach to solving this problem, where the runtime for this is no longer exponential but scales **cubically** with the number of nodes. + +## What is the recommended type of dataset to be used in NOTEARS? + +[NOTEARS](https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf) works by detecting if a small increase in the value of the node will result in an increase in another node. If there is, NOTEARS will be able to capture this and assert that this is a causal relationship. 
Therefore, we highly recommend that the dataset to be used is **continuous**. + +**Categorical variables** like blood type **won’t work** in this case. Nonetheless, after learning the structure using NOTEARS, one can still manually add the relationships for these features to the structure based on their domain knowledge. + +## What is the recommended number of samples for satisfactory performance? + +According to the benchmarking done on **synthetic datasets** in-house, it is highly recommended that **at least 1000 samples** are used to get satisfactory performance. We also discovered that increasing beyond 1000 samples **does not help improve the accuracy** regardless of the number of nodes, and it takes a **longer time** to run. + +## Why can my StructureModel be cyclic, but not my BayesianNetwork? + +StructureModel is used when discovering the causal structure of a dataset. Part of this process is adding, removing, and flipping edges, until the appropriate structure is completed. As edges are modified, cycles can be temporarily introduced into the structure, which would raise an Exception within a BayesianNetwork, which is a specialised **directed acyclic graph**. + +Once the structure is finalised, and is acyclic, then it can be used to create a [BayesianNetwork](https://causalnex.readthedocs.io/en/latest/04_user_guide/04_user_guide.html). + + +## Why a separate data pre-processing process for probability fitting and structure learning? / Why discretise data in probability fitting? + +We treat Bayesian Network probability fitting and Structure Learning as two separate problems. The data for Structure Learning should be continuous for the causal relationships to be learnt. **Once we know the causal relationships between all the nodes**, we can start doing probability fitting. At the moment, we **only support discrete Bayesian Network models**, and this requires the continuous features to be discretised. 
+ +## Why call fit_node_states before fit_cpds? + +Before fitting CPDs, the model first has to know how many states each node has in order to carry out the computations. Alternatively, one can also call **fit_node_states_and_cpds**. However, there is a chance that this might not work if one were to do train/test splitting as the model might not see all the possible states. + +For example, a rare blood type like AB-negative might appear in the test data but not in the training data. Therefore, we strongly encourage users to do **fit_node_states using all data** and **fit_cpds using training data** to test the model quality, so that the model knows all the possible states that each node can have. + +## What is Do-intervention and when to use it? + +[Do-intervention](https://causalnex.readthedocs.io/en/latest/04_user_guide/04_user_guide.html) is symbolically described as p(y|do(x)). It asks the question of what is the probability distribution of Y if we were to **set** the value of X to x **arbitrarily**. + +For example, we have 50% males and 50% females in the world, but we might be interested to learn about the probability distribution of happiness index if we had 80% females and 20% males in the world. + +Do-intervention is very useful in **counterfactual analysis**, where we are interested to know if the outcomes would have been different if we had taken a different action/intervention. + +## How can I make inference faster? + +At the moment, the algorithm calculates the probability of **every node** in a Bayesian Network. If users are interested in making inference of the target node faster, they can remove nodes that are independent from the target node, and also children of the target node. For example, if we have C<-A->B->D and we want to learn P(B|A), we can remove C and D to make the inference faster. + +## How does CausalNex compare to other projects, e.g. CausalML, DoWhy? 
+ +The following points describe how we are unique compared to other projects: +1) We are one of the very few causal packages that use **Bayesian Networks** to model the problems. Most of the causal packages use statistical matching techniques like **propensity score matching** to approach these problems. +2) One of the main hurdles to applying Bayesian Networks is finding the optimal graph structure. In CausalNex, we **simplify** this process by providing the ability for the users to learn the graph structure through: i) **encoding domain expertise** by manually adding the edges, and ii) **leveraging the data** using the state-of-the-art [structure learning algorithm](https://papers.nips.cc/paper/8157-dags-with-no-tears-continuous-optimization-for-structure-learning.pdf). +3) We provide the ability for the users to do **counterfactual analysis** using Bayesian Networks by introducing **Do-Calculus**, which is not commonly found in Bayesian Network packages. + +## What version of Python does CausalNex use? + +CausalNex is built for Python 3.5, 3.6 and 3.7. + +## How do I upgrade CausalNex? + +We use [SemVer](http://semver.org/) for versioning. The best way to upgrade safely is to check our [release notes](RELEASE.md) for any notable breaking changes. + +Once CausalNex is installed, you can check your version as follows: + +``` +pip show causalnex +``` + +To later upgrade CausalNex to a different version, simply run: + +``` +pip install causalnex -U +``` + +## How can I find out more about CausalNex? + +CausalNex is on GitHub, and our preferred community channel for feedback is through [GitHub issues](https://github.com/quantumblacklabs/causalnex/issues). You can find news about updates and new features introduced by heading over to [RELEASE.md](https://github.com/quantumblacklabs/causalnex/blob/develop/RELEASE.md). + +## Where can I learn more about Bayesian Networks? 
+ +You can read our [documentation](https://causalnex.readthedocs.io/en/latest/04_user_guide/04_user_guide.html) to know more about the concepts and other useful references with regard to using Bayesian Networks for Causal Inference. diff --git a/docs/source/api_docs/causalnex.rst b/docs/source/api_docs/causalnex.rst new file mode 100644 index 0000000..b98ba91 --- /dev/null +++ b/docs/source/api_docs/causalnex.rst @@ -0,0 +1,21 @@ +causalnex +========= + +.. rubric:: Description + +.. automodule:: causalnex + + + + .. rubric:: Modules + + .. autosummary:: + :toctree: + :template: autosummary/module.rst + + causalnex.structure + causalnex.plots + causalnex.discretiser + causalnex.network + causalnex.evaluation + causalnex.inference diff --git a/docs/source/api_docs/index.rst b/docs/source/api_docs/index.rst new file mode 100644 index 0000000..70b133c --- /dev/null +++ b/docs/source/api_docs/index.rst @@ -0,0 +1,99 @@ +.. causalnex documentation master file, created by + sphinx-quickstart on Mon Dec 18 11:31:24 2017. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +.. image:: causalnex_banner.png + :alt: CausalNex logo + :class: causalnex-logo + +Welcome to CausalNex's API docs and tutorials! +============================================= + +.. image:: https://circleci.com/gh/quantumblacklabs/causalnex/tree/master.svg?style=shield + :target: https://circleci.com/gh/quantumblacklabs/causalnex/tree/master + :alt: CircleCI build status + +.. image:: https://img.shields.io/badge/coverage-100%25-brightgreen.svg + :target: https://github.com/quantumblacklabs/causalnex + :alt: Test coverage + +.. image:: https://img.shields.io/badge/python-3.5%20%7C%203.6%20%7C%203.7-blue.svg + :target: https://pypi.org/project/causalnex/ + :alt: Python version 3.5, 3.6, 3.7 + +.. image:: https://badge.fury.io/py/causalnex.svg + :target: https://pypi.org/project/causalnex/ + :alt: PyPI package version + +.. 
image:: https://readthedocs.org/projects/causalnex/badge/?version=latest + :target: https://causalnex.readthedocs.io/ + :alt: Docs build status + +.. image:: https://img.shields.io/badge/code%20style-black-000000.svg + :target: https://github.com/ambv/black + :alt: Code style is Black + +.. image:: https://img.shields.io/badge/license-Apache%202.0-blue.svg + :target: https://opensource.org/licenses/Apache-2.0 + :alt: License is Apache 2.0 + +.. image:: https://pepy.tech/badge/causalnex + :target: https://pepy.tech/project/causalnex + :alt: Downloads + +.. toctree:: + :maxdepth: 2 + :caption: Introduction + + 01_introduction/01_introduction + +.. toctree:: + :maxdepth: 2 + :caption: Getting Started + + 02_getting_started/01_prerequisites + 02_getting_started/02_install + +.. toctree:: + :maxdepth: 2 + :caption: Tutorial + + 03_tutorial/03_tutorial.md + + +.. toctree:: + :maxdepth: 2 + :caption: User guide + + 04_user_guide/04_user_guide + 04_user_guide/04_reference + +.. toctree:: + :maxdepth: 2 + :caption: Resources + + 05_resources/05_faq + + +API Docs +======== + +.. toctree:: + :maxdepth: 0 + :caption: API Docs + :hidden: + + causalnex + +.. 
autosummary:: + :template: autosummary/module.rst + + causalnex + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/source/causalnex_banner.png b/docs/source/causalnex_banner.png new file mode 100644 index 0000000..06a5584 Binary files /dev/null and b/docs/source/causalnex_banner.png differ diff --git a/docs/source/css/causalnex.css b/docs/source/css/causalnex.css new file mode 100644 index 0000000..2c33866 --- /dev/null +++ b/docs/source/css/causalnex.css @@ -0,0 +1,22 @@ +table { + font-family: "Trebuchet MS", Arial, Helvetica, sans-serif; + display: block; + overflow: scroll; +} + +td, th { + border: 1px solid #ddd; + padding: 8px; +} + +tr:nth-child(even){background-color: #f2f2f2;} + +tr:hover {background-color: #ddd;} + +th { + padding-top: 12px; + padding-bottom: 12px; + text-align: left; + background-color: #34BB54; + color: white; +} diff --git a/docs/source/css/copybutton.css b/docs/source/css/copybutton.css new file mode 100644 index 0000000..6c01d71 --- /dev/null +++ b/docs/source/css/copybutton.css @@ -0,0 +1,60 @@ +/* Copied from: https://github.com/raw/choldgraf/sphinx-copybutton/master/sphinx_copybutton/_static/copybutton.css */ + + /* Copy buttons */ +a.copybtn { + position: absolute; + top: 2px; + right: 2px; + width: 1.7em; + height: 1.7em; + padding: .3em; + opacity: .6; + transition: opacity 0.5s; +} + + div.highlight { + position: relative; + background: #f5f5f5; +} + + .highlight:hover .copybtn { + opacity: 1; +} + + /** + * A minimal CSS-only tooltip copied from: + * https://codepen.io/mildrenben/pen/rVBrpK + * + * To use, write HTML like the following: + * + * <p class="o-tooltip--left" data-tooltip="Hey">Short</p> + */ + .o-tooltip--left { + position: relative; + } + + .o-tooltip--left:after { + opacity: 0; + visibility: hidden; + position: absolute; + content: attr(data-tooltip); + padding: 2px; + top: 0; + left: 0; + background: grey; + font-size: 1rem; + color: white; + 
white-space: nowrap; + z-index: 2; + border-radius: 2px; + transform: translateX(-102%) translateY(0); + transition: opacity 0.2s cubic-bezier(0.64, 0.09, 0.08, 1), transform 0.2s cubic-bezier(0.64, 0.09, 0.08, 1); +} + + .o-tooltip--left:hover:after { + display: block; + opacity: 1; + visibility: visible; + transform: translateX(-100%) translateY(0); + transition: opacity 0.2s cubic-bezier(0.64, 0.09, 0.08, 1), transform 0.2s cubic-bezier(0.64, 0.09, 0.08, 1); +} diff --git a/docs/source/css/qb1-sphinx-rtd.css b/docs/source/css/qb1-sphinx-rtd.css new file mode 100644 index 0000000..8dfcbe3 --- /dev/null +++ b/docs/source/css/qb1-sphinx-rtd.css @@ -0,0 +1,405 @@ +@import url("https://fonts.googleapis.com/css?family=Titillium+Web:300,400,600"); + +html, body.wy-body-for-nav { + margin: 0; + padding: 0; + -webkit-font-smoothing: antialiased; + font-family: 'Titillium Web', sans-serif; + font-weight: 400; + line-height: 2rem; +} + +html { + font-size: 62.5%; +} + +body.wy-body-for-nav { + font-size: 1.6rem; + background: rgb(250, 250, 250) !important; + color: black; +} + +.wy-side-nav-search { + text-align: left; +} + +.wy-side-nav-search input[type=text] { + display: block; + box-sizing: border-box; + width: 100%; + padding: 6px 12px; + color: #666; + background-color: #fff; + font-family: inherit; + font-size: 1.6rem; + border: 1px #ccc solid; + border-radius: 2px; + transition: all ease 0.15s; + box-shadow: none; +} + +.wy-side-nav-search input[type=text]:focus { + border-color: #888; + color: #333; +} + +.wy-body-for-nav .wy-nav-side { + position: fixed; + top: 0; + bottom: 0; + left: 0; + padding-bottom: 2em; + width: 300px; + overflow-x: hidden; + overflow-y: auto; + min-height: 100%; + background: white; + z-index: 200; + box-shadow: 0 2px 5px 0 rgba(0, 0, 0, .05); +} + +.wy-body-for-nav .wy-side-scroll { + width: 100%; + position: relative; + overflow-x: initial; + overflow-y: initial; + height: initial; +} + +.wy-body-for-nav .wy-side-nav-search { + width: 
100%; + background: none; + margin: 0; + padding: 0 20px; +} +.wy-body-for-nav .wy-side-nav-search a { + display: inline-flex; + flex-direction: row-reverse; + align-items: center; + font-size: 4rem; + margin: 0; + padding: 0; + height: 7rem; + line-height: 7rem; + color: black; +} +.wy-body-for-nav .wy-side-nav-search a:before { + display: none; +} +.wy-body-for-nav .wy-side-nav-search a img.logo { + width: 4.5rem; + height: auto; + margin: 0 0.3rem 0 -0.6rem; + padding: 0; +} +.wy-body-for-nav .wy-side-nav-search>div.version { + color: #555; + display: inline-block; + margin-left: 0.5em; + font-size: 1.4rem; +} + +.wy-body-for-nav .wy-menu-vertical { + width: 300px; + margin: 0; + padding: 0 20px; +} + +.wy-body-for-nav .wy-menu-vertical p.caption { + margin: 1.5em 0 0.2em; + padding: 0; + color: #666; + font-weight: bold; + font-size: 2rem; +} + +.wy-body-for-nav .wy-menu-vertical li.on a, .wy-body-for-nav .wy-menu-vertical li.current>a { + border-top: none; + border-bottom: none; +} + +.wy-body-for-nav .wy-menu-vertical li.current { + background: none; +} + +.wy-body-for-nav .wy-menu-vertical li { + margin: 0; + padding: 0 0 0 20px; +} + +.wy-body-for-nav .wy-menu-vertical li a { + display: block; + margin: 0; + padding: 1.2rem 0 !important; + font-size: 1.6rem; + line-height: 1; + color: #222; + background: none !important; +} + +.rst-content.style-external-links a.reference.external:after { + color: inherit; + opacity: 0.8; +} + +.wy-body-for-nav .wy-menu-vertical a:hover { + color: #000; +} + +.wy-body-for-nav li span.toctree-expand { + margin-left: -20px; +} + +.wy-body-for-nav .toctree-expand:before { + font-size: 15px; + margin-right: 5px; +} + +.wy-body-for-nav .wy-nav-content-wrap { + margin-left: 300px; + background: #fafafa; +} + +.wy-body-for-nav .wy-nav-content { + padding: 0; + max-width: initial; +} + +.wy-body-for-nav .wy-breadcrumbs { + display: -webkit-box; + display: -ms-flexbox; + display: flex; + -webkit-box-align: center; + -ms-flex-align: 
center; + align-items: center; + width: 100%; + height: 7rem; + padding-left: 3.2rem; + background: white; + background-image: initial; + background-position-x: initial; + background-position-y: initial; + background-size: initial; + background-repeat-x: initial; + background-repeat-y: initial; + background-attachment: initial; + background-origin: initial; + background-clip: initial; + background-color: white; + box-shadow: 0 2px 5px 0 rgba(0, 0, 0, .05); +} + +.wy-body-for-nav .wy-breadcrumbs + hr { + display: none; +} +.wy-body-for-nav .wy-breadcrumbs li { + font-size: 1.8rem; +} + +.wy-body-for-nav .wy-breadcrumbs li:not(:first-child) { + margin-left: 6px; +} + +.wy-body-for-nav .wy-nav-content .document { + padding: 3.2rem 3.2rem 2rem; + max-width: 100rem; +} + +.causalnex-logo { + margin-bottom: 3rem; +} + +.wy-body-for-nav footer { + padding: 32px; +} + +.wy-body-for-nav .wy-body-for-nav { + background: #fafafa !important; + background-image: none; + background-size: initial; +} + +.wy-body-for-nav .wy-menu-vertical li.current a:hover { + background: none; +} + +.wy-menu-vertical a span.toctree-expand { + color: #555 !important; +} + +.wy-menu-vertical a:hover span.toctree-expand { + color: #000 !important; +} + +.wy-body-for-nav .wy-menu-vertical li.current a { + border-right: none; +} + +.wy-body-for-nav .toctree-l2 { + padding-left: 15px !important; +} + +.wy-body-for-nav .toctree-l3 { + padding-left: 15px !important; +} + +.wy-body-for-nav .toctree-l4 { + padding-left: 15px !important; + } + +.wy-body-for-nav .toctree-l5 { + padding-left: 15px !important; +} + + +.wy-body-for-nav .toctree-l2.current > a, .wy-body-for-nav .toctree-l3.current > a,.wy-body-for-nav .toctree-l4.current > a { + font-weight: normal !important; +} +.wy-body-for-nav .toctree-l4 a { + word-break: break-word; +} + +.wy-body-for-nav b, .wy-body-for-nav strong { + font-weight: normal; +} + +.wy-plain-list-disc li, +.rst-content .section ul li, +.rst-content .toctree-wrapper ul li, 
+article ul li { + margin-top: 0.35em; + margin-bottom: 0.35em; +} + +h1, h2, .rst-content .toctree-wrapper p.caption, h3, h4, h5, h6, legend { + font-family: 'Titillium Web', sans-serif; + margin: 1em 0 0.3em; + line-height: 1.2em; +} + +.wy-body-for-nav .document > div > .section > *:not(:empty):first-of-type { + margin-top: 0; +} + +.wy-body-for-nav h1 { + font-size: 3.9rem; + letter-spacing: -0.3px; +} + +.wy-body-for-nav h2 { + font-size: 3.25rem; +} + +.wy-body-for-nav h3 { + font-size: 2.6rem; +} + +.wy-body-for-nav h4 { + font-size: 1.95rem; +} + +.wy-body-for-nav h5 { + font-size: 1.625rem; +} + +.wy-body-for-nav h6 { + font-size: 1.4625rem; +} + +.wy-body-for-nav .headerlink { + display: none !important; +} + +.wy-body-for-nav p { + font-size: 100%; + margin: 0 0 15px 0; + line-height: 1.5; +} + +.wy-body-for-nav blockquote { + margin: 1em 0; + padding-left: 1em; + border-left: 4px solid #ddd; + color: #6a6a6a; +} + +.wy-body-for-nav .rst-content a, .wy-body-for-nav footer a { + font-family: inherit; + font-size: inherit; + color: #006ea7; + text-decoration: none; +} + +.wy-body-for-nav .rst-content a:visited, .wy-body-for-nav footer a:visited { + color: #446b7f; +} + +.wy-body-for-nav .rst-content a:hover, .wy-body-for-nav footer a:hover { + text-decoration: underline; +} + +.wy-body-for-nav .rst-content .btn:hover, .wy-body-for-nav footer .btn:hover { + text-decoration: none; +} + +.wy-body-for-nav .wy-nav-top { + padding: 5px 20px; + background: white; + color: black; + box-shadow: 0 2px 5px 0 rgba(0, 0, 0, 0.05); +} + +.wy-body-for-nav .wy-nav-top i { + transform: translateY(6px); +} + +.wy-body-for-nav .wy-nav-top a { + font-size: 2.8rem; + color: black !important; +} + + +@media screen and (max-width: 768px) { + .wy-body-for-nav .wy-nav-side { + transform: translate(-300px, 0); + transition: all ease 0.3s; + } + .wy-body-for-nav .wy-nav-side.shift { + width: 85%; + transform: translate(0, 0); + } + .wy-body-for-nav .wy-nav-content-wrap { + 
margin-left: 0; + transform: translate(0, 0); + transition: all ease 0.3s; + } + .wy-body-for-nav .wy-nav-content-wrap.shift { + position: relative; + left: 0; + transform: translate(85%, 0); + } + .wy-body-for-nav .wy-breadcrumbs { + display: none; + } + .wy-body-for-nav .wy-nav-content .document { + padding: 20px; + } +} + +@media screen and (min-width: 1600px) { + .wy-body-for-nav .wy-nav-side { + width: 350px; + } + .wy-body-for-nav .wy-nav-content-wrap { + margin-left: 350px; + } + html { + font-size: 70%; + } +} + +/* Fix Read The Docs side-effects */ +.wy-body-for-nav .rst-versions { + font-size: 16px; + line-height: 1; +} diff --git a/docs/source/css/theme-overrides.css b/docs/source/css/theme-overrides.css new file mode 100644 index 0000000..c928bd0 --- /dev/null +++ b/docs/source/css/theme-overrides.css @@ -0,0 +1,11 @@ +/* override table width restrictions */ +@media screen and (min-width: 767px) { + + .wy-table-responsive table td { + white-space: normal; + } + + .wy-table-responsive { + overflow: visible; + } +} diff --git a/docs/source/examples/.gitkeep b/docs/source/examples/.gitkeep new file mode 100644 index 0000000..e69de29 diff --git a/legal_header.txt b/legal_header.txt new file mode 100644 index 0000000..5da8261 --- /dev/null +++ b/legal_header.txt @@ -0,0 +1,27 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. 
IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0a0069c --- /dev/null +++ b/requirements.txt @@ -0,0 +1,9 @@ +matplotlib>=3.0.3, <4.0 +networkx==2.2 +numpy>=1.14.2, <2.0 +pandas==0.24.0 +pgmpy==0.1.6 +prettytable==0.7.2 +scikit-learn==0.20.2 +scipy>=1.2.0, <1.3 +wrapt>=1.11.0, <1.12 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..10fbf27 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,7 @@ +[tool:pytest] +addopts=--cov-report term-missing + --cov causalnex + --cov ebaybbn + --cov tests + --no-cov-on-fail + -ra diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8e1462d --- /dev/null +++ b/setup.py @@ -0,0 +1,77 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +from os import path + +from setuptools import find_packages, setup + +name = "causalnex" +here = path.abspath(path.dirname(__file__)) + +# get package version +with open(path.join(here, name, "__init__.py"), encoding="utf-8") as f: + result = re.search(r'__version__ = ["\']([^"\']+)', f.read()) + if not result: + raise ValueError("Can't find the version in causalnex/__init__.py") + version = result.group(1) + +# get the dependencies and installs +with open("requirements.txt", "r", encoding="utf-8") as f: + requires = [x.strip() for x in f if x.strip()] + +# get test dependencies and installs +with open("test_requirements.txt", "r", encoding="utf-8") as f: + test_requires = [x.strip() for x in f if x.strip() and not x.startswith("-r")] + +# Get the long description from the README file +with open(path.join(here, "README.md"), encoding="utf-8") as f: + readme = f.read() + +setup( + name=name, + version=version, + description="Toolkit for causal reasoning (Bayesian Networks / Inference)", + long_description=readme, + long_description_content_type="text/markdown", + url="https://github.com/quantumblacklabs/causalnex", + python_requires=">=3.5, <3.8", + author="QuantumBlack Labs", + author_email="causalnex@quantumblack.com", + packages=find_packages(exclude=["docs*", "tests*", "tools*"]), + include_package_data=True, + tests_require=test_requires, + install_requires=requires, + keywords="Causal Reasoning, Bayesian Network, Inference, Structure Learning, Do-Calculus", + classifiers=[ + "Development Status :: 4 - Beta", + "Programming Language :: Python :: 3.5", + "Programming Language :: Python :: 3.6", + "Programming Language :: Python :: 3.7", + ], +) diff --git a/test_requirements.txt b/test_requirements.txt new file mode 100644 index 0000000..756b28a --- /dev/null +++ b/test_requirements.txt @@ -0,0 +1,9 @@ +-r requirements.txt +flake8>=3.5,<4.0 +isort>=4.3.16, <5.0 +mock>=2.0.0,<3.0 +pre-commit>=1.17.0, <2.0.0 +pylint>=2.3.1, <3.0 +pytest-cov>=2.5, 
<3.0 +pytest-mock>=1.7.1,<2.0 +pytest>=4.3.0,<5.0 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..6c320f5 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
+""" +causalnex +Toolkit for causal reasoning (Bayesian Networks / Inference) +""" diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..0ca3604 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,489 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
from typing import Dict

import numpy as np
import pandas as pd
import pytest
from pgmpy.models import BayesianModel

from causalnex.network import BayesianNetwork
from causalnex.structure import StructureModel
from causalnex.structure.notears import from_pandas


def _bucket_c(value):
    """
    Discretise a raw ``c`` value into one of five buckets:
        - 0: x < 20
        - 1: 20 <= x < 40
        - 2: 40 <= x < 60
        - 3: 60 <= x < 80
        - 4: 80 <= x

    Shared by every fixture that discretises column "c" so the bucket edges
    are defined in exactly one place.
    """
    for bucket, upper_bound in enumerate((20, 40, 60, 80)):
        if value < upper_bound:
            return bucket
    return 4


@pytest.fixture
def train_model() -> StructureModel:
    """
    This Bayesian Model structure will be used in all tests, and all fixtures will adhere to this structure.

    Cause-only nodes: [d, e]
    Effect-only nodes: [a, c]
    Cause / Effect nodes: [b]

            d
         ↙  ↓  ↘
        a ← b → c
            ↑  ↗
            e
    """
    model = StructureModel()
    model.add_edges_from(
        [
            ("b", "a"),
            ("b", "c"),
            ("d", "a"),
            ("d", "c"),
            ("d", "b"),
            ("e", "c"),
            ("e", "b"),
        ]
    )
    return model


@pytest.fixture
def train_model_idx(train_model) -> BayesianModel:
    """
    This Bayesian model is identical to the train_model() fixture, with the exception that node names
    are integers from zero to four, mapped by:

        {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4}
    """
    model = BayesianModel()
    idx_map = {"a": 0, "b": 1, "c": 2, "d": 3, "e": 4}
    model.add_edges_from([(idx_map[u], idx_map[v]) for u, v in train_model.edges])
    return model


@pytest.fixture
def train_data() -> pd.DataFrame:
    """
    Training data for testing Bayesian Networks. There are 98 samples, with 5 columns:

    - a: {"a", "b", "c", "d"}
    - b: {"x", "y", "z"}
    - c: 0.0 - 100.0
    - d: Boolean
    - e: Boolean

    This data was generated by constructing the Bayesian Model train_model(), and then sampling
    from this structure. Since e and d are both independent of all other nodes, these were sampled first for
    each row (from their respective pre-defined distributions). This then allows the sampling of all further
    variables based on their conditional dependencies.

    The approximate distributions used to sample from can be viewed by inspecting the ``*_cpds`` fixtures
    defined below.
    """

    data_arr = [
        ["a", "x", 73.78658346945414, False, False],
        ["d", "x", 12.765853213346603, False, False],
        ["c", "y", 22.43657132589221, False, False],
        ["a", "x", 4.267744937038964, False, False],
        ["b", "x", 62.87087344904927, False, False],
        ["c", "x", 31.55295196889971, False, False],
        ["a", "x", 37.403388911083965, False, False],
        ["b", "x", 63.171968604247155, False, False],
        ["d", "x", 11.140539452118263, False, False],
        ["d", "x", 0.1555338799942385, True, False],
        ["c", "x", 9.269926225399187, False, True],
        ["b", "z", 75.38846241765208, True, True],
        ["c", "z", 33.10212378889936, False, True],
        ["b", "z", 57.04657630213301, True, True],
        ["b", "x", 72.03855905511072, True, False],
        ["c", "x", 5.106018765399956, False, False],
        ["c", "z", 5.802617702038839, False, True],
        ["c", "x", 17.22538330530506, False, False],
        ["a", "y", 87.05395007052729, False, False],
        ["d", "y", 19.09989481093348, False, False],
        ["c", "x", 4.313272835124353, True, False],
        ["b", "x", 13.660704178900938, True, True],
        ["b", "x", 7.693287813764131, False, False],
        ["c", "y", 32.791770073523246, False, False],
        ["c", "y", 12.039098492465282, False, False],
        ["a", "x", 51.97718339128754, False, False],
        ["d", "x", 8.393970656769238, False, False],
        ["a", "x", 0.3610815726384886, False, False],
        ["a", "y", 35.31788713900731, True, False],
        ["b", "x", 35.84702992379284, False, True],
        ["c", "y", 32.872350426703356, True, False],
        ["a", "x", 21.218746335586868, False, True],
        ["b", "y", 71.5495653029006, True, False],
        ["c", "x", 15.393846082097575, False, False],
        ["d", "y", 4.514559208625406, False, False],
        ["d", "x", 0.704928173400301, False, False],
        ["c", "y", 34.10829794112354, True, False],
        ["d", "x", 6.84602512195673, False, False],
        ["b", "y", 25.43743439885204, False, False],
        ["d", "x", 7.544831467091971, False, False],
        ["d", "x", 13.923699372025073, False, False],
        ["b", "x", 21.493005760070915, False, False],
        ["a", "x", 41.353977640369436, False, False],
        ["c", "z", 10.015835005248583, True, True],
        ["c", "z", 29.40115954319444, False, True],
        ["c", "x", 17.305145945035388, False, False],
        ["b", "x", 57.3687951851441, False, False],
        ["a", "x", 59.31395756039643, False, False],
        ["d", "x", 19.557939187075984, False, False],
        ["d", "y", 15.739556224725082, False, False],
        ["c", "x", 6.850626809845993, True, False],
        ["c", "x", 7.774579861173826, False, False],
        ["c", "x", 20.807136344297092, True, False],
        ["b", "y", 29.406207780312343, False, False],
        ["a", "x", 34.38851648220974, False, False],
        ["d", "x", 1.0951104244381218, True, False],
        ["c", "x", 37.27483338042188, False, False],
        ["b", "x", 15.745994603442064, False, False],
        ["c", "x", 17.78180189764816, False, True],
        ["a", "x", 17.067548428231493, True, False],
        ["c", "x", 26.857320012899727, False, False],
        ["a", "x", 41.0038510689549, False, True],
        ["d", "x", 0.2299684913699096, False, True],
        ["a", "x", 57.35885570158893, True, False],
        ["d", "x", 12.40118443712448, False, False],
        ["c", "x", 22.624550487374112, False, False],
        ["a", "x", 93.08587619178269, False, False],
        ["b", "y", 18.33030505634329, False, False],
        ["a", "z", 64.29945681859853, False, True],
        ["b", "x", 73.66024742961967, False, False],
        ["b", "x", 16.717397443478287, False, True],
        ["c", "y", 4.642615342125205, False, True],
        ["c", "x", 9.431345661106931, False, False],
        ["c", "y", 31.76238774237109, False, False],
        ["c", "y", 3.6961806894707965, False, False],
        ["d", "y", 2.298895066631253, True, False],
        ["d", "y", 13.222298172220462, False, False],
        ["c", "x", 28.301638775451153, False, False],
        ["d", "x", 7.702270580869413, True, False],
        ["a", "y", 41.38492280508702, True, False],
        ["d", "x", 13.047815503255656, True, False],
        ["c", "x", 22.14641490202623, False, False],
        ["b", "z", 43.13007970158368, False, True],
        ["b", "x", 60.09518672623882, True, False],
        ["a", "x", 79.6370082234198, False, False],
        ["d", "x", 16.60880504367762, False, False],
        ["a", "z", 22.88783470451029, False, True],
        ["a", "x", 33.66416643964188, False, False],
        ["b", "y", 69.91787304290465, True, True],
        ["c", "x", 31.941092922567663, True, False],
        ["d", "x", 16.739638908154518, False, False],
        ["a", "z", 11.129589373273108, False, True],
        ["d", "y", 4.96943558614434, True, False],
        ["d", "y", 6.585354730457387, False, False],
        ["d", "x", 9.859942318446954, False, False],
        ["b", "z", 18.541485302271496, False, True],
        ["a", "x", 87.53473074574995, True, False],
        ["a", "z", 59.61068083691302, False, True],
    ]

    data = pd.DataFrame(data_arr, columns=["a", "b", "c", "d", "e"])
    return data


@pytest.fixture
def train_data_discrete(train_data) -> pd.DataFrame:
    """
    train_data in discretised form. This maps "c" into 5 buckets:
        - 0: x < 20
        - 1: 20 <= x < 40
        - 2: 40 <= x < 60
        - 3: 60 <= x < 80
        - 4: 80 <= x
    """
    df = train_data.copy(deep=True)  # type: pd.DataFrame
    df["c"] = df["c"].apply(_bucket_c)
    return df


@pytest.fixture
def train_data_idx(train_data) -> pd.DataFrame:
    """
    train_data in integer index form. This maps each column into values from 0..n
    """

    df = train_data.copy(deep=True)  # type: pd.DataFrame

    df["a"] = df["a"].map({"a": 0, "b": 1, "c": 2, "d": 3})
    df["b"] = df["b"].map({"x": 0, "y": 1, "z": 2})
    df["c"] = df["c"].apply(_bucket_c)
    df["d"] = df["d"].map({True: 1, False: 0})
    df["e"] = df["e"].map({True: 1, False: 0})
    return df


@pytest.fixture
def train_data_idx_cpds(train_data_idx) -> Dict[str, np.ndarray]:
    """Conditional probability distributions of train_data_idx in the train_model"""

    return create_cpds(train_data_idx)


@pytest.fixture
def train_data_discrete_cpds(train_data_discrete) -> Dict[str, np.ndarray]:
    """Conditional probability distributions of train_data_discrete in the train_model"""

    return create_cpds(train_data_discrete)


@pytest.fixture
def train_data_discrete_cpds_k2(train_data_discrete) -> Dict[str, np.ndarray]:
    """Conditional probability distributions of train_data_discrete, computed with a pseudo-count of 1"""

    return create_cpds(train_data_discrete, pc=1)


def _parent_configurations(parents, df_vals):
    """
    List every combination of parent states as a tuple, in column order.

    The first parent varies slowest and the last parent varies fastest, which
    is the order produced by nesting one ``for`` loop per parent.
    """
    configurations = [()]
    for parent in parents:
        configurations = [
            config + (value,) for config in configurations for value in df_vals[parent]
        ]
    return configurations


def _conditional_cpd(df, df_vals, node, parents, pc):
    """
    Estimate the table P(node | parents) from the rows of ``df``.

    Rows follow ``df_vals[node]``; columns follow ``_parent_configurations``.
    Each cell is ``(joint count + pc) / (parent count + pc * number of node states)``.

    A parent configuration that never occurs in ``df`` gets a uniform column.
    Without that fallback the ratio is 0 / 0 whenever ``pc`` is 0.
    """
    states = df_vals[node]

    # One boolean row-mask per parent configuration; with no parents there is
    # a single all-True mask, so the "parent count" is simply len(df).
    parent_masks = []
    for config in _parent_configurations(parents, df_vals):
        mask = np.ones(len(df), dtype=bool)
        for parent, value in zip(parents, config):
            mask &= (df[parent] == value).values
        parent_masks.append(mask)

    table = []
    for state in states:
        state_mask = (df[node] == state).values
        row = []
        for mask in parent_masks:
            parent_count = int(mask.sum())
            if parent_count == 0:
                row.append(1 / len(states))
                continue
            joint_count = int((state_mask & mask).sum())
            row.append((joint_count + pc) / (parent_count + (pc * len(states))))
        table.append(row)

    return np.array(table)


def create_cpds(data, pc=0):
    """
    Build the conditional probability tables of the train_model structure from ``data``.

    Args:
        data: DataFrame with the discrete columns "a", "b", "c", "d" and "e".
        pc: pseudo-count added to every joint count (0 gives plain relative frequencies).

    Returns:
        Dictionary mapping each node name to a 2-D array with one row per node state
        (sorted) and one column per parent configuration (sorted, first parent slowest).
        Parents are: a <- (b, d), b <- (d, e), c <- (b, d, e); d and e have none, so
        their tables have a single column.
    """

    df = data.copy(deep=True)  # type: pd.DataFrame

    df_vals = {col: sorted(df[col].unique()) for col in df.columns}

    parents = {
        "a": ["b", "d"],
        "b": ["d", "e"],
        "c": ["b", "d", "e"],
        "d": [],
        "e": [],
    }

    return {
        node: _conditional_cpd(df, df_vals, node, parents[node], pc)
        for node in ("a", "b", "c", "d", "e")
    }


@pytest.fixture
def train_data_idx_marginals(train_data_idx_cpds):
    """Marginals derived from train_data_idx_cpds, keyed by integer state"""

    return create_marginals(
        train_data_idx_cpds,
        {
            "a": list(range(4)),
            "b": list(range(3)),
            "c": list(range(5)),
            "d": list(range(2)),
            "e": list(range(2)),
        },
    )


@pytest.fixture
def train_data_discrete_marginals(train_data_discrete_cpds):
    """Marginals derived from train_data_discrete_cpds, keyed by the original state labels"""

    return create_marginals(
        train_data_discrete_cpds,
        {
            "a": ["a", "b", "c", "d"],
            "b": ["x", "y", "z"],
            "c": [0, 1, 2, 3, 4],
            "d": [False, True],
            "e": [False, True],
        },
    )


def create_marginals(cpds, data_vals):
    """
    Derive per-node marginals from the tables produced by create_cpds.

    Args:
        cpds: dictionary of node name -> CPD array, as returned by create_cpds.
        data_vals: dictionary of node name -> list of state labels, indexed by CPD row.

    Returns:
        Dictionary of node name -> {state label: probability}.

    Each CPD column is weighted by the product of its parents' marginals and the
    weighted columns are summed per row. Note that parent marginals are simply
    multiplied together (for "a" this is p(b) * p(d), for "c" it is p(b) * p(d) * p(e)).
    """

    def _root_prior(cpd):
        # parent-less nodes store their distribution in a single column
        return {i: cpd[i, 0] for i in range(len(cpd))}

    def _marginalise(cpd, *parent_priors):
        # weight of each column: product of the parents' marginals,
        # first parent varying slowest (same order as the CPD columns)
        weights = [1.0]
        for prior in parent_priors:
            weights = [w * prior[i] for w in weights for i in range(len(prior))]
        return dict(enumerate((np.array(weights) * cpd).sum(axis=1)))

    p_d = _root_prior(cpds["d"])
    p_e = _root_prior(cpds["e"])
    p_b = _marginalise(cpds["b"], p_d, p_e)
    p_a = _marginalise(cpds["a"], p_b, p_d)
    p_c = _marginalise(cpds["c"], p_b, p_d, p_e)

    marginals = {
        "a": {data_vals["a"][k]: v for k, v in p_a.items()},
        "b": {data_vals["b"][k]: v for k, v in p_b.items()},
        "c": {data_vals["c"][k]: v for k, v in p_c.items()},
        "d": {data_vals["d"][k]: v for k, v in p_d.items()},
        "e": {data_vals["e"][k]: v for k, v in p_e.items()},
    }

    return marginals


@pytest.fixture
def test_data_c() -> pd.DataFrame:
    """Test data created so that C should be perfectly predicted based on train_data_cpds.

    Given the two independent variables are set randomly (d, e), all other variables are set to be
    from the category with maximum likelihood in train_data_cpds"""

    data_arr = [
        ["a", "x", 1, False, False],
        ["b", "x", 2, False, True],
        ["c", "x", 3, True, False],
        ["d", "x", 4, True, True],
        ["d", "y", 1, False, False],
        ["c", "y", 2, False, True],
        ["b", "y", 23, True, False],
        ["a", "y", 64, True, True],
        ["c", "z", 1, False, False],
        ["a", "z", 2, False, True],
        ["d", "z", 3, True, False],
        ["b", "z", 0, True, True],
    ]

    data = pd.DataFrame(data_arr, columns=["a", "b", "c", "d", "e"])
    return data


@pytest.fixture
def test_data_c_discrete(test_data_c) -> pd.DataFrame:
    """Test data C that has been discretised (see train_data_discrete)"""
    df = test_data_c.copy(deep=True)  # type: pd.DataFrame
    df["c"] = df["c"].apply(_bucket_c)
    return df


@pytest.fixture
def test_data_c_likelihood(train_data_discrete_cpds) -> pd.DataFrame:
    """
    The CPD of "c" from train_data_discrete_cpds, transposed into a DataFrame:
    one row per parent configuration, one column ("c_0" .. "c_4") per state of "c".
    """

    # Known bug in pylint with generated Dict: https://github.com/PyCQA/pylint/issues/1498
    cpd_c = train_data_discrete_cpds["c"]  # pylint: disable=unsubscriptable-object

    likelihood = pd.DataFrame(
        cpd_c.T.tolist(), columns=["c_0", "c_1", "c_2", "c_3", "c_4"]
    )
    return likelihood


@pytest.fixture
def bn(train_data_idx, train_data_discrete) -> BayesianNetwork:
    """BayesianNetwork whose structure is learned from train_data_idx and fitted on train_data_discrete"""
    return BayesianNetwork(
        from_pandas(train_data_idx, w_threshold=0.3)
    ).fit_node_states_and_cpds(train_data_discrete)
diff --git a/tests/ebaybbn/__init__.py b/tests/ebaybbn/__init__.py new file mode 100644 index 0000000..5da8261 --- /dev/null +++ b/tests/ebaybbn/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/tests/ebaybbn/conftest.py b/tests/ebaybbn/conftest.py new file mode 100644 index 0000000..b106bdc --- /dev/null +++ b/tests/ebaybbn/conftest.py @@ -0,0 +1,205 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# The methods found in this file are adapted from a repository under Apache 2.0: +# eBay's Pythonic Bayesian Belief Network Framework. +# @online{ +# author = {Neville Newey,Anzar Afaq}, +# title = {bayesian-belief-networks}, +# organisation = {eBay}, +# codebase = {https://github.com/eBay/bayesian-belief-networks}, +# } +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
import pytest

from causalnex.ebaybbn import (
    BBN,
    Node,
    build_bbn,
    build_join_tree,
    combinations,
    make_moralized_copy,
    make_undirected_copy,
)
from causalnex.ebaybbn.utils import make_key


@pytest.fixture
def sprinkler_graph():
    """The Sprinkler example wired by hand out of Node objects and wrapped in a BBN.

    Edges: Cloudy -> Sprinkler, Cloudy -> Rain, Sprinkler -> WetGrass, Rain -> WetGrass.
    """
    cloudy = Node("Cloudy")
    sprinkler = Node("Sprinkler")
    rain = Node("Rain")
    wet_grass = Node("WetGrass")

    # parent links
    sprinkler.parents = [cloudy]
    rain.parents = [cloudy]
    wet_grass.parents = [sprinkler, rain]

    # child links
    cloudy.children = [sprinkler, rain]
    sprinkler.children = [wet_grass]
    rain.children = [wet_grass]

    return BBN(
        {
            "cloudy": cloudy,
            "sprinkler": sprinkler,
            "rain": rain,
            "wet_grass": wet_grass,
        }
    )


@pytest.fixture
def sprinkler_bbn():
    """Sprinkler BBN assembled by build_bbn from three factor functions."""

    # NOTE: the factor functions' names and parameter names are left untouched;
    # they are handed to build_bbn as-is.

    def f_rain(rain):
        return 0.2 if rain is True else 0.8

    def f_sprinkler(rain, sprinkler):
        probabilities = {
            (False, True): 0.4,
            (False, False): 0.6,
            (True, True): 0.01,
            (True, False): 0.99,
        }
        return probabilities[(rain, sprinkler)]

    def f_grass_wet(sprinkler, rain, grass_wet):
        probabilities = {
            "fft": 0.0,
            "fff": 1.0,
            "ftt": 0.8,
            "ftf": 0.2,
            "tft": 0.9,
            "tff": 0.1,
            "ttt": 0.99,
            "ttf": 0.01,
        }
        return probabilities[make_key(sprinkler, rain, grass_wet)]

    return build_bbn(f_rain, f_sprinkler, f_grass_wet)


@pytest.fixture
def huang_darwiche_nodes():
    """The factor functions for the Huang & Darwiche example, in the order a..h."""

    def f_a(a):
        # same value for either state of ``a``
        return 1 / 2

    def f_b(a, b):
        table = {"tt": 0.5, "ft": 0.4, "tf": 0.5, "ff": 0.6}
        return table[make_key(a, b)]

    def f_c(a, c):
        table = {"tt": 0.7, "ft": 0.2, "tf": 0.3, "ff": 0.8}
        return table[make_key(a, c)]

    def f_d(b, d):
        table = {"tt": 0.9, "ft": 0.5, "tf": 0.1, "ff": 0.5}
        return table[make_key(b, d)]

    def f_e(c, e):
        table = {"tt": 0.3, "ft": 0.6, "tf": 0.7, "ff": 0.4}
        return table[make_key(c, e)]

    def f_f(d, e, f):
        table = {
            "ttt": 0.01,
            "ttf": 0.99,
            "tft": 0.01,
            "tff": 0.99,
            "ftt": 0.01,
            "ftf": 0.99,
            "fft": 0.99,
            "fff": 0.01,
        }
        return table[make_key(d, e, f)]

    def f_g(c, g):
        table = {"tt": 0.8, "tf": 0.2, "ft": 0.1, "ff": 0.9}
        return table[make_key(c, g)]

    def f_h(e, g, h):
        table = {
            "ttt": 0.05,
            "ttf": 0.95,
            "tft": 0.95,
            "tff": 0.05,
            "ftt": 0.95,
            "ftf": 0.05,
            "fft": 0.95,
            "fff": 0.05,
        }
        return table[make_key(e, g, h)]

    return [f_a, f_b, f_c, f_d, f_e, f_f, f_g, f_h]


@pytest.fixture
def huang_darwiche_dag(huang_darwiche_nodes):
    """BBN built from the list of Huang & Darwiche factor functions."""
    return build_bbn(huang_darwiche_nodes)


@pytest.fixture
def huang_darwiche_moralized(huang_darwiche_dag):
    """Moralized copy of the undirected copy of the Huang & Darwiche DAG."""
    undirected = make_undirected_copy(huang_darwiche_dag)
    return make_moralized_copy(undirected, huang_darwiche_dag)


@pytest.fixture
def huang_darwiche_jt(huang_darwiche_dag):
    """Join tree of the Huang & Darwiche DAG, built with a priority function that has fixed tie-breakers."""

    # second element of the priority; any node not listed gets 10
    tie_breakers = {"f_h": 0, "f_g": 1, "f_c": 2, "f_b": 3, "f_d": 4, "f_e": 5}

    def priority_func_override(node):
        # count the pairs in the node's cluster that are not yet neighbours
        introduced_arcs = 0
        for node_a, node_b in combinations([node] + node.neighbours, 2):
            if node_a in node_b.neighbours:
                continue
            assert node_b not in node_a.neighbours
            introduced_arcs += 1
        return [introduced_arcs, tie_breakers.get(node.name, 10)]

    return build_join_tree(huang_darwiche_dag, priority_func_override)
+# @online{ +# author = {Neville Newey,Anzar Afaq}, +# title = {bayesian-belief-networks}, +# organisation = {eBay}, +# codebase = {https://github.com/eBay/bayesian-belief-networks}, +# } +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
def _rounded(value, places):
    """Round ``value`` to ``places`` decimal places with the built-in ``round``."""
    return round(value, places)


def r3(x):
    """Return ``x`` rounded to 3 decimal places."""
    return _rounded(x, 3)


def r5(x):
    """Return ``x`` rounded to 5 decimal places."""
    return _rounded(x, 5)
nodes["f_c"].parents == [nodes["f_a"]] + assert nodes["f_d"].parents == [nodes["f_b"]] + assert nodes["f_e"].parents == [nodes["f_c"]] + assert nodes["f_f"].parents == [nodes["f_d"], nodes["f_e"]] + assert nodes["f_g"].parents == [nodes["f_c"]] + assert nodes["f_h"].parents == [nodes["f_e"], nodes["f_g"]] + + def test_make_undirecred_copy(self, huang_darwiche_dag): + ug = make_undirected_copy(huang_darwiche_dag) + nodes = {node.name: node for node in ug.nodes} + assert set(nodes["f_a"].neighbours) == set([nodes["f_b"], nodes["f_c"]]) + assert set(nodes["f_b"].neighbours) == set([nodes["f_a"], nodes["f_d"]]) + assert set(nodes["f_c"].neighbours) == set( + [nodes["f_a"], nodes["f_e"], nodes["f_g"]] + ) + assert set(nodes["f_d"].neighbours) == set([nodes["f_b"], nodes["f_f"]]) + assert set(nodes["f_e"].neighbours) == set( + [nodes["f_c"], nodes["f_f"], nodes["f_h"]] + ) + assert set(nodes["f_f"].neighbours) == set([nodes["f_d"], nodes["f_e"]]) + assert set(nodes["f_g"].neighbours) == set([nodes["f_c"], nodes["f_h"]]) + assert set(nodes["f_h"].neighbours) == set([nodes["f_e"], nodes["f_g"]]) + + def test_make_moralized_copy(self, huang_darwiche_dag): + gu = make_undirected_copy(huang_darwiche_dag) + gm = make_moralized_copy(gu, huang_darwiche_dag) + nodes = {node.name: node for node in gm.nodes} + assert set(nodes["f_a"].neighbours) == set([nodes["f_b"], nodes["f_c"]]) + assert set(nodes["f_b"].neighbours) == set([nodes["f_a"], nodes["f_d"]]) + assert set(nodes["f_c"].neighbours) == set( + [nodes["f_a"], nodes["f_e"], nodes["f_g"]] + ) + assert set(nodes["f_d"].neighbours) == set( + [nodes["f_b"], nodes["f_f"], nodes["f_e"]] + ) + assert set(nodes["f_e"].neighbours) == set( + [nodes["f_c"], nodes["f_f"], nodes["f_h"], nodes["f_d"], nodes["f_g"]] + ) + assert set(nodes["f_f"].neighbours) == set([nodes["f_d"], nodes["f_e"]]) + assert set(nodes["f_g"].neighbours) == set( + [nodes["f_c"], nodes["f_h"], nodes["f_e"]] + ) + assert set(nodes["f_h"].neighbours) == 
set([nodes["f_e"], nodes["f_g"]]) + + def test_triangulate(self, huang_darwiche_moralized): + + # Because of ties in the priority q we will + # override the priority function here to + # insert tie breakers to ensure the same + # elimination ordering as Darwich Huang. + def priority_func_override(node): + introduced_arcs = 0 + cluster = [node] + node.neighbours + for node_a, node_b in combinations(cluster, 2): + if node_a not in node_b.neighbours: + assert node_b not in node_a.neighbours + introduced_arcs += 1 + introduced_arcs_dict = { + "f_h": [introduced_arcs, 0], + "f_g": [introduced_arcs, 1], + "f_c": [introduced_arcs, 2], + "f_b": [introduced_arcs, 3], + "f_d": [introduced_arcs, 4], + "f_e": [introduced_arcs, 5], + "others": [introduced_arcs, 10], + } + if node.name in introduced_arcs_dict: + return introduced_arcs_dict[node.name] + + return introduced_arcs_dict["others"] + + cliques, elimination_ordering = triangulate( + huang_darwiche_moralized, priority_func_override + ) + nodes = {node.name: node for node in huang_darwiche_moralized.nodes} + assert len(cliques) == 6 + assert cliques[0].nodes == set([nodes["f_e"], nodes["f_g"], nodes["f_h"]]) + assert cliques[1].nodes == set([nodes["f_c"], nodes["f_e"], nodes["f_g"]]) + assert cliques[2].nodes == set([nodes["f_d"], nodes["f_e"], nodes["f_f"]]) + assert cliques[3].nodes == set([nodes["f_a"], nodes["f_c"], nodes["f_e"]]) + assert cliques[4].nodes == set([nodes["f_a"], nodes["f_b"], nodes["f_d"]]) + assert cliques[5].nodes == set([nodes["f_a"], nodes["f_d"], nodes["f_e"]]) + + assert elimination_ordering == [ + "f_h", + "f_g", + "f_f", + "f_c", + "f_b", + "f_d", + "f_e", + "f_a", + ] + # Now lets ensure the triangulated graph is + # the same as Darwiche Huang fig. 2 pg. 
13 + nodes = {node.name: node for node in huang_darwiche_moralized.nodes} + assert set(nodes["f_a"].neighbours) == set( + [nodes["f_b"], nodes["f_c"], nodes["f_d"], nodes["f_e"]] + ) + assert set(nodes["f_b"].neighbours) == set([nodes["f_a"], nodes["f_d"]]) + assert set(nodes["f_c"].neighbours) == set( + [nodes["f_a"], nodes["f_e"], nodes["f_g"]] + ) + assert set(nodes["f_d"].neighbours) == set( + [nodes["f_b"], nodes["f_f"], nodes["f_e"], nodes["f_a"]] + ) + assert set(nodes["f_e"].neighbours) == set( + [ + nodes["f_c"], + nodes["f_f"], + nodes["f_h"], + nodes["f_d"], + nodes["f_g"], + nodes["f_a"], + ] + ) + assert set(nodes["f_f"].neighbours) == set([nodes["f_d"], nodes["f_e"]]) + assert set(nodes["f_g"].neighbours) == set( + [nodes["f_c"], nodes["f_h"], nodes["f_e"]] + ) + assert set(nodes["f_h"].neighbours) == set([nodes["f_e"], nodes["f_g"]]) + + def test_triangulate_no_tie_break(self, huang_darwiche_moralized): + # Now lets see what happens if + # we dont enforce the tie-breakers... + # It seems the triangulated graph is + # different adding edges from d to c + # and b to c + # Will be interesting to see whether + # inference will still be correct. 
+ triangulate(huang_darwiche_moralized) + nodes = {node.name: node for node in huang_darwiche_moralized.nodes} + assert set(nodes["f_a"].neighbours) == set([nodes["f_b"], nodes["f_c"]]) + assert set(nodes["f_b"].neighbours) == set( + [nodes["f_a"], nodes["f_d"], nodes["f_c"]] + ) + assert set(nodes["f_c"].neighbours) == set( + [nodes["f_a"], nodes["f_e"], nodes["f_g"], nodes["f_b"], nodes["f_d"]] + ) + assert set(nodes["f_d"].neighbours) == set( + [nodes["f_b"], nodes["f_f"], nodes["f_e"], nodes["f_c"]] + ) + assert set(nodes["f_e"].neighbours) == set( + [nodes["f_c"], nodes["f_f"], nodes["f_h"], nodes["f_d"], nodes["f_g"]] + ) + assert set(nodes["f_f"].neighbours) == set([nodes["f_d"], nodes["f_e"]]) + assert set(nodes["f_g"].neighbours) == set( + [nodes["f_c"], nodes["f_h"], nodes["f_e"]] + ) + assert set(nodes["f_h"].neighbours) == set([nodes["f_e"], nodes["f_g"]]) + + def test_build_join_tree(self, huang_darwiche_dag): + def priority_func_override(node): + introduced_arcs = 0 + cluster = [node] + node.neighbours + for node_a, node_b in combinations(cluster, 2): + if node_a not in node_b.neighbours: + assert node_b not in node_a.neighbours + introduced_arcs += 1 + introduced_arcs_dict = { + "f_h": [introduced_arcs, 0], + "f_g": [introduced_arcs, 1], + "f_c": [introduced_arcs, 2], + "f_b": [introduced_arcs, 3], + "f_d": [introduced_arcs, 4], + "f_e": [introduced_arcs, 5], + "others": [introduced_arcs, 10], + } + if node.name in introduced_arcs_dict: + return introduced_arcs_dict[node.name] + + return introduced_arcs_dict["others"] + + jt = build_join_tree(huang_darwiche_dag, priority_func_override) + for node in jt.sepset_nodes: + assert {n.clique for n in node.neighbours} == {node.sepset.X, node.sepset.Y} + # clique nodes. 
+ + def test_initialize_potentials(self, huang_darwiche_jt, huang_darwiche_dag): + # Seems like there can be multiple assignments so + # for this test we will set the assignments explicitely + cliques = {node.name: node for node in huang_darwiche_jt.nodes} + bbn_nodes = {node.name: node for node in huang_darwiche_dag.nodes} + assignments = { + cliques["Clique_ACE"]: [bbn_nodes["f_c"], bbn_nodes["f_e"]], + cliques["Clique_ABD"]: [ + bbn_nodes["f_a"], + bbn_nodes["f_b"], + bbn_nodes["f_d"], + ], + } + huang_darwiche_jt.initialize_potentials(assignments, huang_darwiche_dag) + for node in huang_darwiche_jt.sepset_nodes: + for v in node.potential_tt.values(): + assert v == 1 + + # Note that in H&D there are two places that show + # initial potentials, one is for ABD and AD + # and the second is for ACE and CE + # We should test both here but we must enforce + # the assignments above because alternate and + # equally correct Junction Trees will give + # different potentials. + def r(x): + return round(x, 3) + + tt = cliques["Clique_ACE"].potential_tt + assert r(tt[("a", True), ("c", True), ("e", True)]) == 0.21 + assert r(tt[("a", True), ("c", True), ("e", False)]) == 0.49 + assert r(tt[("a", True), ("c", False), ("e", True)]) == 0.18 + assert r(tt[("a", True), ("c", False), ("e", False)]) == 0.12 + assert r(tt[("a", False), ("c", True), ("e", True)]) == 0.06 + assert r(tt[("a", False), ("c", True), ("e", False)]) == 0.14 + assert r(tt[("a", False), ("c", False), ("e", True)]) == 0.48 + assert r(tt[("a", False), ("c", False), ("e", False)]) == 0.32 + + tt = cliques["Clique_ABD"].potential_tt + assert r(tt[("a", True), ("b", True), ("d", True)]) == 0.225 + assert r(tt[("a", True), ("b", True), ("d", False)]) == 0.025 + assert r(tt[("a", True), ("b", False), ("d", True)]) == 0.125 + assert r(tt[("a", True), ("b", False), ("d", False)]) == 0.125 + assert r(tt[("a", False), ("b", True), ("d", True)]) == 0.180 + assert r(tt[("a", False), ("b", True), ("d", False)]) == 0.020 + 
assert r(tt[("a", False), ("b", False), ("d", True)]) == 0.150 + assert r(tt[("a", False), ("b", False), ("d", False)]) == 0.150 + + def test_jtclique_node_variable_names(self, huang_darwiche_jt): + for node in huang_darwiche_jt.clique_nodes: + if "ADE" in node.name: + assert set(node.variable_names) == set(["a", "d", "e"]) + + def test_propagate(self, huang_darwiche_jt, huang_darwiche_dag): + jt_cliques = {node.name: node for node in huang_darwiche_jt.clique_nodes} + assignments = huang_darwiche_jt.assign_clusters(huang_darwiche_dag) + huang_darwiche_jt.initialize_potentials(assignments, huang_darwiche_dag) + + huang_darwiche_jt.propagate(starting_clique=jt_cliques["Clique_ACE"]) + tt = jt_cliques["Clique_DEF"].potential_tt + assert r5(tt[(("d", False), ("e", True), ("f", True))]) == 0.00150 + assert r5(tt[(("d", True), ("e", False), ("f", True))]) == 0.00365 + assert r5(tt[(("d", False), ("e", False), ("f", True))]) == 0.16800 + assert r5(tt[(("d", True), ("e", True), ("f", True))]) == 0.00315 + assert r5(tt[(("d", False), ("e", False), ("f", False))]) == 0.00170 + assert r5(tt[(("d", True), ("e", True), ("f", False))]) == 0.31155 + assert r5(tt[(("d", False), ("e", True), ("f", False))]) == 0.14880 + assert r5(tt[(("d", True), ("e", False), ("f", False))]) == 0.36165 + + def test_marginal(self, huang_darwiche_jt, huang_darwiche_dag): + # The remaining marginals here come + # from the module itself, however they + # have been corrobarted by running + # inference using the sampling inference + # engine and the same results are + # achieved. 
+ """ + +------+-------+----------+ + | Node | Value | Marginal | + +------+-------+----------+ + | a | False | 0.500000 | + | a | True | 0.500000 | + | b | False | 0.550000 | + | b | True | 0.450000 | + | c | False | 0.550000 | + | c | True | 0.450000 | + | d | False | 0.320000 | + | d | True | 0.680000 | + | e | False | 0.535000 | + | e | True | 0.465000 | + | f | False | 0.823694 | + | f | True | 0.176306 | + | g | False | 0.585000 | + | g | True | 0.415000 | + | h | False | 0.176900 | + | h | True | 0.823100 | + +------+-------+----------+ + """ + bbn_nodes = {node.name: node for node in huang_darwiche_dag.nodes} + assignments = huang_darwiche_jt.assign_clusters(huang_darwiche_dag) + huang_darwiche_jt.initialize_potentials(assignments, huang_darwiche_dag) + huang_darwiche_jt.propagate() + + # These test values come directly from + # pg. 22 of H & D + p_A = huang_darwiche_jt.marginal(bbn_nodes["f_a"]) + assert r3(p_A[(("a", True),)]) == 0.5 + assert r3(p_A[(("a", False),)]) == 0.5 + + p_D = huang_darwiche_jt.marginal(bbn_nodes["f_d"]) + assert r3(p_D[(("d", True),)]) == 0.68 + assert r3(p_D[(("d", False),)]) == 0.32 + + p_B = huang_darwiche_jt.marginal(bbn_nodes["f_b"]) + assert r3(p_B[(("b", True),)]) == 0.45 + assert r3(p_B[(("b", False),)]) == 0.55 + + p_C = huang_darwiche_jt.marginal(bbn_nodes["f_c"]) + assert r3(p_C[(("c", True),)]) == 0.45 + assert r3(p_C[(("c", False),)]) == 0.55 + + p_E = huang_darwiche_jt.marginal(bbn_nodes["f_e"]) + assert r3(p_E[(("e", True),)]) == 0.465 + assert r3(p_E[(("e", False),)]) == 0.535 + + p_F = huang_darwiche_jt.marginal(bbn_nodes["f_f"]) + assert r3(p_F[(("f", True),)]) == 0.176 + assert r3(p_F[(("f", False),)]) == 0.824 + + p_G = huang_darwiche_jt.marginal(bbn_nodes["f_g"]) + assert r3(p_G[(("g", True),)]) == 0.415 + assert r3(p_G[(("g", False),)]) == 0.585 + + p_H = huang_darwiche_jt.marginal(bbn_nodes["f_h"]) + assert r3(p_H[(("h", True),)]) == 0.823 + assert r3(p_H[(("h", False),)]) == 0.177 + + +def 
test_make_node_func(): + UPDATE = { + "prize_door": [ + # For nodes that have no parents + # use the empty list to specify + # the conditioned upon variables + # ie conditioned on the empty set + [[], {"A": 1 / 3, "B": 1 / 3, "C": 1 / 3}] + ], + "guest_door": [[[], {"A": 1 / 3, "B": 1 / 3, "C": 1 / 3}]], + "monty_door": [ + [[["prize_door", "A"], ["guest_door", "A"]], {"A": 0, "B": 0.5, "C": 0.5}], + [[["prize_door", "A"], ["guest_door", "B"]], {"A": 0, "B": 0, "C": 1}], + [[["prize_door", "A"], ["guest_door", "C"]], {"A": 0, "B": 1, "C": 0}], + [[["prize_door", "B"], ["guest_door", "A"]], {"A": 0, "B": 0, "C": 1}], + [[["prize_door", "B"], ["guest_door", "B"]], {"A": 0.5, "B": 0, "C": 0.5}], + [[["prize_door", "B"], ["guest_door", "C"]], {"A": 1, "B": 0, "C": 0}], + [[["prize_door", "C"], ["guest_door", "A"]], {"A": 0, "B": 1, "C": 0}], + [[["prize_door", "C"], ["guest_door", "B"]], {"A": 1, "B": 0, "C": 0}], + [[["prize_door", "C"], ["guest_door", "C"]], {"A": 0.5, "B": 0.5, "C": 0}], + ], + } + + node_func = make_node_func("prize_door", UPDATE["prize_door"]) + assert get_args(node_func) == ["prize_door"] + assert node_func("A") == 1 / 3 + assert node_func("B") == 1 / 3 + assert node_func("C") == 1 / 3 + + node_func = make_node_func("guest_door", UPDATE["guest_door"]) + assert get_args(node_func) == ["guest_door"] + assert node_func("A") == 1 / 3 + assert node_func("B") == 1 / 3 + assert node_func("C") == 1 / 3 + + node_func = make_node_func("monty_door", UPDATE["monty_door"]) + assert get_args(node_func) == ["guest_door", "prize_door", "monty_door"] + assert node_func("A", "A", "A") == 0 + assert node_func("A", "A", "B") == 0.5 + assert node_func("A", "A", "C") == 0.5 + assert node_func("A", "B", "A") == 0 + assert node_func("A", "B", "B") == 0 + assert node_func("A", "B", "C") == 1 + assert node_func("A", "C", "A") == 0 + assert node_func("A", "C", "B") == 1 + assert node_func("A", "C", "C") == 0 + assert node_func("B", "A", "A") == 0 + assert node_func("B", 
"A", "B") == 0 + assert node_func("B", "A", "C") == 1 + assert node_func("B", "B", "A") == 0.5 + assert node_func("B", "B", "B") == 0 + assert node_func("B", "B", "C") == 0.5 + assert node_func("B", "C", "A") == 1 + assert node_func("B", "C", "B") == 0 + assert node_func("B", "C", "C") == 0 + assert node_func("C", "A", "A") == 0 + assert node_func("C", "A", "B") == 1 + assert node_func("C", "A", "C") == 0 + assert node_func("C", "B", "A") == 1 + assert node_func("C", "B", "B") == 0 + assert node_func("C", "B", "C") == 0 + assert node_func("C", "C", "A") == 0.5 + assert node_func("C", "C", "B") == 0.5 + assert node_func("C", "C", "C") == 0 + + +def close_enough(x, y, r=3): + return round(x, r) == round(y, r) + + +def test_build_bbn_from_conditionals(): + UPDATE = { + "prize_door": [ + # For nodes that have no parents + # use the empty list to specify + # the conditioned upon variables + # ie conditioned on the empty set + [[], {"A": 1 / 3, "B": 1 / 3, "C": 1 / 3}] + ], + "guest_door": [[[], {"A": 1 / 3, "B": 1 / 3, "C": 1 / 3}]], + "monty_door": [ + [[["prize_door", "A"], ["guest_door", "A"]], {"A": 0, "B": 0.5, "C": 0.5}], + [[["prize_door", "A"], ["guest_door", "B"]], {"A": 0, "B": 0, "C": 1}], + [[["prize_door", "A"], ["guest_door", "C"]], {"A": 0, "B": 1, "C": 0}], + [[["prize_door", "B"], ["guest_door", "A"]], {"A": 0, "B": 0, "C": 1}], + [[["prize_door", "B"], ["guest_door", "B"]], {"A": 0.5, "B": 0, "C": 0.5}], + [[["prize_door", "B"], ["guest_door", "C"]], {"A": 1, "B": 0, "C": 0}], + [[["prize_door", "C"], ["guest_door", "A"]], {"A": 0, "B": 1, "C": 0}], + [[["prize_door", "C"], ["guest_door", "B"]], {"A": 1, "B": 0, "C": 0}], + [[["prize_door", "C"], ["guest_door", "C"]], {"A": 0.5, "B": 0.5, "C": 0}], + ], + } + g = build_bbn_from_conditionals(UPDATE) + result = g.query() + assert close_enough(result[("guest_door", "A")], 0.333) + assert close_enough(result[("guest_door", "B")], 0.333) + assert close_enough(result[("guest_door", "C")], 0.333) + assert 
close_enough(result[("monty_door", "A")], 0.333) + assert close_enough(result[("monty_door", "B")], 0.333) + assert close_enough(result[("monty_door", "C")], 0.333) + assert close_enough(result[("prize_door", "A")], 0.333) + assert close_enough(result[("prize_door", "B")], 0.333) + assert close_enough(result[("prize_door", "C")], 0.333) + + result = g.query(guest_door="A", monty_door="B") + assert close_enough(result[("guest_door", "A")], 1) + assert close_enough(result[("guest_door", "B")], 0) + assert close_enough(result[("guest_door", "C")], 0) + assert close_enough(result[("monty_door", "A")], 0) + assert close_enough(result[("monty_door", "B")], 1) + assert close_enough(result[("monty_door", "C")], 0) + assert close_enough(result[("prize_door", "A")], 0.333) + assert close_enough(result[("prize_door", "B")], 0) + assert close_enough(result[("prize_door", "C")], 0.667) + + +def valid_sample(samples, query_result): + """For a group of samples from + a query result ensure that + the sample is approximately equivalent + to the query_result which is the + true distribution.""" + counts = Counter() + for sample in samples: + for var, val in sample.items(): + counts[(var, val)] += 1 + # Now lets normalize for each count... 
+ differences = [] + for k, v in counts.items(): + counts[k] = v / len(samples) + difference = abs(counts.get(k, 0) - query_result[k]) + differences.append(difference) + return all([not round(difference, 2) > 0.01 for difference in differences]) + + +def test_draw_sample_sprinkler(sprinkler_bbn): + + query_result = sprinkler_bbn.query() + samples = sprinkler_bbn.draw_samples({}, 10000) + assert valid_sample(samples, query_result) + + +def test_repr(): + + assert repr(Node("test")) == "<Node test>" + assert repr(UndirectedNode("test")) == "<UndirectedNode test>" + assert ( + repr(BBNNode(get_original_factors)) + == "<BBNNode get_original_factors (['factors'])>" + ) + assert ( + repr(JoinTreeCliqueNode(UndirectedNode("test"))) + == "<JoinTreeCliqueNode: <UndirectedNode test>>" + ) + + +def test_exception(sprinkler_bbn): + with pytest.raises(VariableValueNotInDomainError): + sprinkler_bbn.query(rain="No") + with pytest.raises(VariableNotInGraphError): + sprinkler_bbn.query(sunny="True") + + +def test_make_key(): + class DummyTest: + def __init__(self, value): + + self.value = value + + def dummy_method(self, value): # Add this method to by pass linting + self.value = value + + test = DummyTest(8) + test.dummy_method(10) + with pytest.raises(ValueError, match=r"Unexpected type"): + make_key(test) + + +def test_insert_duplicate_clique(huang_darwiche_moralized): + + cliques, _ = triangulate(huang_darwiche_moralized, priority_func) + + forest = set() + for clique in cliques: + jt_node = JoinTreeCliqueNode(clique) + clique.node = jt_node + tree = JoinTree([jt_node]) + forest.add(tree) + + s = SepSet(cliques[0], cliques[0]) + assert s.insertable(forest) is False + s_copy = copy.deepcopy(s) + s.insert(forest) + assert len(s.X.node.neighbours) > len(s_copy.X.node.neighbours) diff --git a/tests/structure/__init__.py b/tests/structure/__init__.py new file mode 100644 index 0000000..5da8261 --- /dev/null +++ b/tests/structure/__init__.py @@ -0,0 +1,27 @@ +# Copyright 2019-2020 
QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/tests/structure/test_notears.py b/tests/structure/test_notears.py new file mode 100644 index 0000000..baa2b8b --- /dev/null +++ b/tests/structure/test_notears.py @@ -0,0 +1,575 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import networkx as nx +import numpy as np +import pandas as pd +import pytest + +from causalnex.structure.notears import ( + from_numpy, + from_numpy_lasso, + from_pandas, + from_pandas_lasso, +) + + +class TestFromPandas: + """Test behaviour of the from_pandas method""" + + def test_all_columns_in_structure(self, train_data_idx): + """Every column that is in the data should become a node in the learned structure""" + + g = from_pandas(train_data_idx) + assert len(g.nodes) == len(train_data_idx.columns) + + def test_isolated_nodes_exist(self, train_data_idx): + """Isolated nodes should still be in the learned structure""" + + g = from_pandas(train_data_idx, w_threshold=1.0) + assert len(g.nodes) == len(train_data_idx.columns) + + def test_expected_structure_learned(self, train_data_idx, train_model): + """Given a small data set that can be examined by hand, the structure should be deterministic""" + + g = from_pandas(train_data_idx, w_threshold=0.3) + assert set(g.edges) == set(train_model.edges) + + def test_empty_data_raises_error(self): + """ + Providing an empty data set should result in a Value Error explaining that data must not be empty. + This error is useful to catch and handle gracefully, because otherwise the user would experience + misleading division by zero, or unpacking errors. + """ + + with pytest.raises(ValueError): + from_pandas(pd.DataFrame(data=[], columns=["a"])) + + def test_non_numeric_data_raises_error(self): + """Only numeric data frames should be supported""" + + with pytest.raises(ValueError, match="All columns must have numeric data.*"): + from_pandas(pd.DataFrame(data=["x"], columns=["a"])) + + def test_single_iter_gets_converged_fail_warnings(self, train_data_idx): + """ + With a single iteration on this dataset, learn_structure fails to converge and should give warnings. + """ + + with pytest.warns( + UserWarning, match="Failed to converge. Consider increasing max_iter." 
+ ): + from_pandas(train_data_idx, max_iter=1) + + def test_certain_relationships_get_near_certain_weight(self): + """If observations reliably show a==b and !a==!b then the relationship from a->b should be certain""" + + data = pd.DataFrame([[0, 1] for _ in range(10)], columns=["a", "b"]) + g = from_pandas(data) + assert all( + [ + 0.99 <= weight <= 1 + for u, v, weight in g.edges(data="weight") + if u == 0 and v == 1 + ] + ) + + def test_inverse_relationships_get_negative_weight(self): + """If observations indicate a==!b and b==!a then the weight of the relationship from a-> should be negative""" + + data = pd.DataFrame([[0, 1] for _ in range(10)], columns=["a", "b"]) + data.append(pd.DataFrame([[1, 0] for _ in range(10)], columns=["a", "b"])) + g = from_pandas(data) + assert all( + [weight < 0 for u, v, weight in g.edges(data="weight") if u == 0 and v == 1] + ) + + def test_no_cycles(self, train_data_idx): + """ + The learned structure should be acyclic + """ + + g = from_pandas(train_data_idx, w_threshold=0.3) + assert nx.algorithms.is_directed_acyclic_graph(g) + + def test_tabu_edges_on_non_existing_edges_do_nothing(self, train_data_idx): + """If tabu edges do not exist in the original unconstrained network then nothing changes""" + + g1 = from_pandas(train_data_idx, w_threshold=0.3) + g2 = from_pandas( + train_data_idx, w_threshold=0.3, tabu_edges=[("a", "d"), ("e", "a")] + ) + assert set(g1.edges) == set(g2.edges) + + def test_tabu_expected_edges(self, train_data_idx): + """Tabu edges should not exist in the network""" + + tabu_e = [("d", "a"), ("b", "c")] + g = from_pandas(train_data_idx, tabu_edges=tabu_e) + assert [e not in g.edges for e in tabu_e] + + def test_tabu_expected_parent_nodes(self, train_data_idx): + """Tabu parent nodes should not have any outgoing edges""" + + tabu_p = ["a", "d", "b"] + g = from_pandas(train_data_idx, tabu_parent_nodes=tabu_p) + assert [p not in [e[0] for e in g.edges] for p in tabu_p] + + def 
test_tabu_expected_child_nodes(self, train_data_idx): + """Tabu child nodes should not have any ingoing edges""" + + tabu_c = ["a", "d", "b"] + g = from_pandas(train_data_idx, tabu_child_nodes=tabu_c) + assert [c not in [e[1] for e in g.edges] for c in tabu_c] + + def test_multiple_tabu(self, train_data_idx): + """Any edge related to tabu edges/parent nodes/child nodes should not exist in the network""" + + tabu_e = [("d", "a"), ("b", "c")] + tabu_p = ["b"] + tabu_c = ["a", "d"] + g = from_pandas( + train_data_idx, + tabu_edges=tabu_e, + tabu_parent_nodes=tabu_p, + tabu_child_nodes=tabu_c, + ) + assert [e not in g.edges for e in tabu_e] + assert [p not in [e[0] for e in g.edges] for p in tabu_p] + assert [c not in [e[1] for e in g.edges] for c in tabu_c] + + +class TestFromPandasLasso: + """Test behaviour of the from_pandas_lasso method""" + + def test_all_columns_in_structure(self, train_data_idx): + """Every columns that is in the data should become a node in the learned structure""" + + g = from_pandas_lasso(train_data_idx, 0.1) + assert len(g.nodes) == len(train_data_idx.columns) + + def test_isolated_nodes_exist(self, train_data_idx): + """Isolated nodes should still be in the learned structure""" + + g = from_pandas_lasso(train_data_idx, 0.1, w_threshold=1.0) + assert len(g.nodes) == len(train_data_idx.columns) + + def test_expected_structure_learned(self, train_data_idx, train_model): + """Given an extremely small alpha and small data set that can be examined by hand, + the structure should be deterministic""" + + g = from_pandas_lasso(train_data_idx, 1e-8, w_threshold=0.3) + assert set(g.edges) == set(train_model.edges) + + def test_empty_data_raises_error(self): + """ + Providing an empty data set should result in a Value Error explaining that data must not be empty. + This error is useful to catch and handle gracefully, because otherwise the user would experience + misleading division by zero, or unpacking errors. 
+ """ + + with pytest.raises(ValueError): + from_pandas_lasso(pd.DataFrame(data=[], columns=["a"]), 0.1) + + def test_non_numeric_data_raises_error(self): + """Only numeric data frames should be supported""" + + with pytest.raises(ValueError, match="All columns must have numeric data.*"): + from_pandas_lasso(pd.DataFrame(data=["x"], columns=["a"]), 0.1) + + def test_single_iter_gets_converged_fail_warnings(self, train_data_idx): + """ + With a single iteration on this dataset, learn_structure fails to converge and should give warnings. + """ + + with pytest.warns( + UserWarning, match="Failed to converge. Consider increasing max_iter." + ): + from_pandas_lasso(train_data_idx, 0.1, max_iter=1) + + def test_certain_relationships_get_near_certain_weight(self): + """If observations reliably show a==b and !a==!b then the relationship from a->b should be certain""" + + data = pd.DataFrame([[0, 1] for _ in range(10)], columns=["a", "b"]) + g = from_pandas_lasso(data, 0.1) + assert all( + [ + 0.99 <= weight <= 1 + for u, v, weight in g.edges(data="weight") + if u == 0 and v == 1 + ] + ) + + def test_inverse_relationships_get_negative_weight(self): + """If observations indicate a==!b and b==!a then the weight of the relationship from a-> should be negative""" + + data = pd.DataFrame([[0, 1] for _ in range(10)], columns=["a", "b"]) + data.append(pd.DataFrame([[1, 0] for _ in range(10)], columns=["a", "b"])) + g = from_pandas_lasso(data, 0.1) + assert all( + [weight < 0 for u, v, weight in g.edges(data="weight") if u == 0 and v == 1] + ) + + def test_no_cycles(self, train_data_idx): + """ + The learned structure should be acyclic + """ + + g = from_pandas_lasso(train_data_idx, 0.1, w_threshold=0.3) + assert nx.algorithms.is_directed_acyclic_graph(g) + + def test_tabu_expected_edges(self, train_data_idx): + """Tabu edges should not exist in the network""" + + tabu_e = [("d", "a"), ("b", "c")] + g = from_pandas_lasso(train_data_idx, 0.1, tabu_edges=tabu_e) + assert [e not in 
g.edges for e in tabu_e] + + def test_tabu_expected_parent_nodes(self, train_data_idx): + """Tabu parent nodes should not have any outgoing edges""" + + tabu_p = ["a", "d", "b"] + g = from_pandas_lasso(train_data_idx, 0.1, tabu_parent_nodes=tabu_p) + assert [p not in [e[0] for e in g.edges] for p in tabu_p] + + def test_tabu_expected_child_nodes(self, train_data_idx): + """Tabu child nodes should not have any ingoing edges""" + + tabu_c = ["a", "d", "b"] + g = from_pandas_lasso(train_data_idx, 0.1, tabu_child_nodes=tabu_c) + assert [c not in [e[1] for e in g.edges] for c in tabu_c] + + def test_multiple_tabu(self, train_data_idx): + """Any edge related to tabu edges/parent nodes/child nodes should not exist in the network""" + + tabu_e = [("d", "a"), ("b", "c")] + tabu_p = ["b"] + tabu_c = ["a", "d"] + g = from_pandas_lasso( + train_data_idx, + 0.1, + tabu_edges=tabu_e, + tabu_parent_nodes=tabu_p, + tabu_child_nodes=tabu_c, + ) + assert [e not in g.edges for e in tabu_e] + assert [p not in [e[0] for e in g.edges] for p in tabu_p] + assert [c not in [e[1] for e in g.edges] for c in tabu_c] + + def test_sparsity(self, train_data_idx): + """Structure learnt from larger lambda should be sparser than smaller lambda""" + + g1 = from_pandas_lasso(train_data_idx, 0.1, w_threshold=0.3) + g2 = from_pandas_lasso(train_data_idx, 1e-6, w_threshold=0.3) + assert len(g1.edges) > len(g2.edges) + + def test_sparsity_against_without_reg(self, train_data_idx): + """Structure learnt from regularisation should be sparser than the one without""" + + g1 = from_pandas_lasso(train_data_idx, 0.1, w_threshold=0.3) + g2 = from_pandas(train_data_idx, w_threshold=0.3) + assert len(g1.edges) > len(g2.edges) + + def test_f1_score(self, train_data_idx, train_model): + """Structure learnt from regularisation should have very high f1 score relative to the ground truth""" + g = from_pandas_lasso(train_data_idx, 0.1, w_threshold=0.3) + print(sorted(list(g.edges))) + print(train_model.edges) + + 
n_predictions_made = len(g.edges) + n_correct_predictions = len(set(g.edges).intersection(set(train_model.edges))) + n_relevant_predictions = len(train_model.edges) + + precision = n_correct_predictions / n_predictions_made + recall = n_correct_predictions / n_relevant_predictions + f1_score = 2 * (precision * recall) / (precision + recall) + + assert f1_score > 0.8 + + +class TestFromNumpy: + """Test behaviour of the from_numpy method""" + + def test_all_columns_in_structure(self, train_data_idx): + """Every column that is in the data should become a node in the learned structure""" + + g = from_numpy(train_data_idx.values) + assert (len(g.nodes)) == len(train_data_idx.columns) + + def test_isolated_nodes_exist(self, train_data_idx): + """Isolated nodes should still be in the learned structure""" + + g = from_numpy(train_data_idx.values, w_threshold=1.0) + assert len(g.nodes) == len(train_data_idx.columns) + + def test_expected_structure_learned(self, train_data_idx, train_model_idx): + """Given a small data set that can be examined by hand, the structure should be deterministic""" + + g = from_numpy(train_data_idx.values, w_threshold=0.3) + assert set(g.edges) == set(train_model_idx.edges) + + def test_empty_data_raises_error(self): + """ + Providing an empty data set should result in a Value Error explaining that data must not be empty. + This error is useful to catch and handle gracefully, because otherwise the user would experience + misleading division by zero, or unpacking errors. + """ + + with pytest.raises(ValueError): + from_numpy(np.empty([0, 5])) + + def test_single_iter_gets_converged_fail_warnings(self, train_data_idx): + """ + With a single iteration on this dataset, learn_structure fails to converge and should give warnings. + """ + + with pytest.warns( + UserWarning, match="Failed to converge. Consider increasing max_iter." 
+ ): + from_numpy(train_data_idx.values, max_iter=1) + + def test_certain_relationships_get_near_certain_weight(self): + """If observations reliably show a==b and !a==!b then the relationship from a->b should be certain""" + + data = pd.DataFrame([[0, 1] for _ in range(10)], columns=["a", "b"]) + g = from_numpy(data.values) + assert all( + [ + 0.99 <= weight <= 1 + for u, v, weight in g.edges(data="weight") + if u == 0 and v == 1 + ] + ) + + def test_inverse_relationships_get_negative_weight(self): + """If observations indicate a==!b and b==!a then the weight of the relationship from a-> should be negative""" + + data = pd.DataFrame([[0, 1] for _ in range(10)], columns=["a", "b"]) + data.append(pd.DataFrame([[1, 0] for _ in range(10)], columns=["a", "b"])) + g = from_numpy(data.values) + assert all( + [weight < 0 for u, v, weight in g.edges(data="weight") if u == 0 and v == 1] + ) + + def test_no_cycles(self, train_data_idx): + """ + The learned structure should be acyclic + """ + + g = from_numpy(train_data_idx.values, w_threshold=0.3) + assert nx.algorithms.is_directed_acyclic_graph(g) + + def test_tabu_edges_on_non_existing_edges_do_nothing(self, train_data_idx): + """If tabu edges do not exist in the original unconstrained network then nothing changes""" + + g1 = from_numpy(train_data_idx.values, w_threshold=0.3) + g2 = from_numpy( + train_data_idx.values, w_threshold=0.3, tabu_edges=[(0, 3), (4, 0), (1, 6)] + ) + assert set(g1.edges) == set(g2.edges) + + def test_tabu_expected_edges(self, train_data_idx): + """Tabu edges should not exist in the network""" + + tabu_e = [(3, 0), (1, 2)] + g = from_numpy(train_data_idx.values, tabu_edges=tabu_e) + assert [e not in g.edges for e in tabu_e] + + def test_tabu_expected_parent_nodes(self, train_data_idx): + """Tabu parent nodes should not have any outgoing edges""" + + tabu_p = [0, 3, 1] + g = from_numpy(train_data_idx.values, tabu_parent_nodes=tabu_p) + assert [p not in [e[0] for e in g.edges] for p in tabu_p] + + 
def test_tabu_expected_child_nodes(self, train_data_idx): + """Tabu child nodes should not have any ingoing edges""" + + tabu_c = [0, 3, 1] + g = from_numpy(train_data_idx.values, tabu_child_nodes=tabu_c) + assert [c not in [e[1] for e in g.edges] for c in tabu_c] + + def test_multiple_tabu(self, train_data_idx): + """Any edge related to tabu edges/parent nodes/child nodes should not exist in the network""" + + tabu_e = [(3, 0), (1, 2)] + tabu_p = [1] + tabu_c = [0, 3] + g = from_numpy( + train_data_idx.values, + tabu_edges=tabu_e, + tabu_parent_nodes=tabu_p, + tabu_child_nodes=tabu_c, + ) + assert [e not in g.edges for e in tabu_e] + assert [p not in [e[0] for e in g.edges] for p in tabu_p] + assert [c not in [e[1] for e in g.edges] for c in tabu_c] + + +class TestFromNumpyLasso: + """Test behaviour of the from_numpy_lasso method""" + + def test_all_columns_in_structure(self, train_data_idx): + """Every columns that is in the data should become a node in the learned structure""" + + g = from_numpy_lasso(train_data_idx.values, 0.1) + assert len(g.nodes) == len(train_data_idx.columns) + + def test_isolated_nodes_exist(self, train_data_idx): + """Isolated nodes should still be in the learned structure""" + + g = from_numpy_lasso(train_data_idx.values, 0.1, w_threshold=1.0) + assert len(g.nodes) == len(train_data_idx.columns) + + def test_expected_structure_learned(self, train_data_idx, train_model_idx): + """Given an extremely small lambda_lasso and small data set that can be examined by hand, + the structure should be deterministic""" + + g = from_numpy_lasso(train_data_idx.values, 1e-8, w_threshold=0.3) + assert set(g.edges) == set(train_model_idx.edges) + + def test_empty_data_raises_error(self): + """ + Providing an empty data set should result in a Value Error explaining that data must not be empty. + This error is useful to catch and handle gracefully, because otherwise the user would experience + misleading division by zero, or unpacking errors. 
+ """ + + with pytest.raises(ValueError): + from_numpy_lasso(np.empty([0, 5]), 0.1) + + def test_single_iter_gets_converged_fail_warnings(self, train_data_idx): + """ + With a single iteration on this dataset, learn_structure fails to converge and should give warnings. + """ + + with pytest.warns( + UserWarning, match="Failed to converge. Consider increasing max_iter." + ): + from_numpy_lasso(train_data_idx.values, 0.1, max_iter=1) + + def test_certain_relationships_get_near_certain_weight(self): + """If observations reliably show a==b and !a==!b then the relationship from a->b should be certain""" + + data = pd.DataFrame([[0, 1] for _ in range(10)], columns=["a", "b"]) + g = from_numpy_lasso(data.values, 0.1) + assert all( + [ + 0.99 <= weight <= 1 + for u, v, weight in g.edges(data="weight") + if u == 0 and v == 1 + ] + ) + + def test_inverse_relationships_get_negative_weight(self): + """If observations indicate a==!b and b==!a then the weight of the relationship from a-> should be negative""" + + data = pd.DataFrame([[0, 1] for _ in range(10)], columns=["a", "b"]) + data.append(pd.DataFrame([[1, 0] for _ in range(10)], columns=["a", "b"])) + g = from_numpy_lasso(data.values, 0.1) + assert all( + [weight < 0 for u, v, weight in g.edges(data="weight") if u == 0 and v == 1] + ) + + def test_no_cycles(self, train_data_idx): + """ + The learned structure should be acyclic + """ + + g = from_numpy_lasso(train_data_idx.values, 0.1, w_threshold=0.3) + assert nx.algorithms.is_directed_acyclic_graph(g) + + def test_tabu_expected_edges(self, train_data_idx): + """Tabu edges should not exist in the network""" + + tabu_e = [("d", "a"), ("b", "c")] + g = from_numpy_lasso(train_data_idx.values, 0.1, tabu_edges=tabu_e) + assert [e not in g.edges for e in tabu_e] + + def test_tabu_expected_parent_nodes(self, train_data_idx): + """Tabu parent nodes should not have any outgoing edges""" + + tabu_p = ["a", "d", "b"] + g = from_numpy_lasso(train_data_idx.values, 0.1, 
tabu_parent_nodes=tabu_p) + assert [p not in [e[0] for e in g.edges] for p in tabu_p] + + def test_tabu_expected_child_nodes(self, train_data_idx): + """Tabu child nodes should not have any ingoing edges""" + + tabu_c = ["a", "d", "b"] + g = from_numpy_lasso(train_data_idx.values, 0.1, tabu_child_nodes=tabu_c) + assert [c not in [e[1] for e in g.edges] for c in tabu_c] + + def test_multiple_tabu(self, train_data_idx): + """Any edge related to tabu edges/parent nodes/child nodes should not exist in the network""" + + tabu_e = [("d", "a"), ("b", "c")] + tabu_p = ["b"] + tabu_c = ["a", "d"] + g = from_numpy_lasso( + train_data_idx.values, + 0.1, + tabu_edges=tabu_e, + tabu_parent_nodes=tabu_p, + tabu_child_nodes=tabu_c, + ) + assert [e not in g.edges for e in tabu_e] + assert [p not in [e[0] for e in g.edges] for p in tabu_p] + assert [c not in [e[1] for e in g.edges] for c in tabu_c] + + def test_sparsity(self, train_data_idx): + """Structure learnt from larger lambda should be sparser than smaller lambda""" + + g1 = from_numpy_lasso(train_data_idx.values, 0.1, w_threshold=0.3) + g2 = from_numpy_lasso(train_data_idx.values, 1e-6, w_threshold=0.3) + assert len(g1.edges) > len(g2.edges) + + def test_sparsity_against_without_reg(self, train_data_idx): + """Structure learnt from regularisation should be sparser than the one without""" + + g1 = from_numpy_lasso(train_data_idx.values, 0.1, w_threshold=0.3) + g2 = from_numpy(train_data_idx.values, w_threshold=0.3) + assert len(g1.edges) > len(g2.edges) + + def test_f1_score(self, train_data_idx, train_model_idx): + """Structure learnt from regularisation should have very high f1 score relative to the ground truth""" + g = from_numpy_lasso(train_data_idx.values, 0.1, w_threshold=0.3) + + print(g.edges) + print(train_model_idx.edges) + n_predictions_made = len(g.edges) + n_correct_predictions = len( + set(g.edges).intersection(set(train_model_idx.edges)) + ) + n_relevant_predictions = len(train_model_idx.edges) + + precision 
= n_correct_predictions / n_predictions_made + recall = n_correct_predictions / n_relevant_predictions + f1_score = 2 * (precision * recall) / (precision + recall) + + assert f1_score > 0.8 diff --git a/tests/structure/test_structuremodel.py b/tests/structure/test_structuremodel.py new file mode 100644 index 0000000..149bc8e --- /dev/null +++ b/tests/structure/test_structuremodel.py @@ -0,0 +1,479 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. 
+
+import pytest
+
+from causalnex.structure import StructureModel
+
+
+class TestStructureModel:
+    def test_init_has_origin(self):
+        """Creating a StructureModel using constructor should give all edges unknown origin"""
+
+        sm = StructureModel([(1, 2)])
+        assert (1, 2) in sm.edges
+        assert (1, 2, "unknown") in sm.edges.data("origin")
+
+    def test_init_with_origin(self):
+        """should be possible to specify origin during init"""
+
+        sm = StructureModel([(1, 2)], origin="learned")
+        assert (1, 2, "learned") in sm.edges.data("origin")
+
+    def test_edge_unknown_property(self):
+        """should return only edges whose origin is unknown"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, origin="unknown")
+        sm.add_edge(1, 3, origin="learned")
+        sm.add_edge(1, 4, origin="expert")
+
+        assert sm.edges_with_origin("unknown") == [(1, 2)]
+
+    def test_edge_learned_property(self):
+        """should return only edges whose origin is learned"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, origin="unknown")
+        sm.add_edge(1, 3, origin="learned")
+        sm.add_edge(1, 4, origin="expert")
+
+        assert sm.edges_with_origin("learned") == [(1, 3)]
+
+    def test_edge_expert_property(self):
+        """should return only edges whose origin is expert"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, origin="unknown")
+        sm.add_edge(1, 3, origin="learned")
+        sm.add_edge(1, 4, origin="expert")
+
+        assert sm.edges_with_origin("expert") == [(1, 4)]
+
+    def test_to_directed(self):
+        """should create a structure model"""
+
+        sm = StructureModel()
+        edges = [(1, 2), (2, 1), (2, 3), (3, 4)]
+        sm.add_edges_from(edges)
+
+        dag = sm.to_directed()
+        assert isinstance(dag, StructureModel)
+        assert all(edge in dag.edges for edge in edges)
+
+    def test_to_undirected(self):
+        """should create an undirected Graph"""
+
+        sm = StructureModel()
+        sm.add_edges_from([(1, 2), (2, 1), (2, 3), (3, 4)])
+
+        udg = sm.to_undirected()
+        assert all(edge in udg.edges for edge in [(2, 3), (3, 4)])
+        assert (1, 2) in udg.edges or (2, 1) in udg.edges
+        assert len(udg.edges) == 3
+
+
+class TestStructureModelAddEdge:
+    def test_add_edge_default(self):
+        """edges added with default origin should be identified as unknown origin"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2)
+
+        assert (1, 2) in sm.edges
+        assert (1, 2, "unknown") in sm.edges.data("origin")
+
+    def test_add_edge_unknown(self):
+        """edges added with unknown origin should be labelled as unknown origin"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, "unknown")
+
+        assert (1, 2) in sm.edges
+        assert (1, 2, "unknown") in sm.edges.data("origin")
+
+    def test_add_edge_learned(self):
+        """edges added with learned origin should be labelled as learned origin"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, "learned")
+
+        assert (1, 2) in sm.edges
+        assert (1, 2, "learned") in sm.edges.data("origin")
+
+    def test_add_edge_expert(self):
+        """edges added with expert origin should be labelled as expert origin"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, "expert")
+
+        assert (1, 2) in sm.edges
+        assert (1, 2, "expert") in sm.edges.data("origin")
+
+    def test_add_edge_other(self):
+        """edges added with other origin should throw an error"""
+
+        sm = StructureModel()
+
+        with pytest.raises(ValueError, match="^Unknown origin: must be one of.*$"):
+            sm.add_edge(1, 2, "other")
+
+    def test_add_edge_custom_attr(self):
+        """it should be possible to add an edge with custom attributes"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, x="Y")
+
+        assert (1, 2) in sm.edges
+        assert (1, 2, "Y") in sm.edges.data("x")
+
+    def test_add_edge_multiple_times(self):
+        """adding an edge again should update the edges origin attr"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, origin="unknown")
+        assert (1, 2, "unknown") in sm.edges.data("origin")
+        sm.add_edge(1, 2, origin="learned")
+        assert (1, 2, "learned") in sm.edges.data("origin")
+
+    def test_add_multiple_edges(self):
+        """it should be possible to add multiple edges with different origins"""
+
+        sm = StructureModel()
+        sm.add_edge(1, 2, origin="unknown")
+        sm.add_edge(1, 3, origin="learned")
+        sm.add_edge(1, 4, origin="expert")
+
+        assert (1, 2, "unknown") in sm.edges.data("origin")
+        assert (1, 3, "learned") in sm.edges.data("origin")
+        assert (1, 4, "expert") in sm.edges.data("origin")
+
+
+class TestStructureModelAddEdgesFrom:
+    def test_add_edges_from_default(self):
+        """edges added with default origin should be identified as unknown origin"""
+
+        sm = StructureModel()
+        edges = [(1, 2), (2, 3)]
+        sm.add_edges_from(edges)
+
+        assert all(edge in sm.edges for edge in edges)
+        assert all((u, v, "unknown") in sm.edges.data("origin") for u, v in edges)
+
+    def test_add_edges_from_unknown(self):
+        """edges added with unknown origin should be labelled as unknown origin"""
+
+        sm = StructureModel()
+        edges = [(1, 2), (2, 3)]
+        sm.add_edges_from(edges, "unknown")
+
+        assert all(edge in sm.edges for edge in edges)
+        assert all((u, v, "unknown") in sm.edges.data("origin") for u, v in edges)
+
+    def test_add_edges_from_learned(self):
+        """edges added with learned origin should be labelled as learned origin"""
+
+        sm = StructureModel()
+        edges = [(1, 2), (2, 3)]
+        sm.add_edges_from(edges, "learned")
+
+        assert all(edge in sm.edges for edge in edges)
+        assert all((u, v, "learned") in sm.edges.data("origin") for u, v in edges)
+
+    def test_add_edges_from_expert(self):
+        """edges added with expert origin should be labelled as expert origin"""
+
+        sm = StructureModel()
+        edges = [(1, 2), (2, 3)]
+        sm.add_edges_from(edges, "expert")
+
+        assert all(edge in sm.edges for edge in edges)
+        assert all((u, v, "expert") in sm.edges.data("origin") for u, v in edges)
+
+    def test_add_edges_from_other(self):
+        """edges added with other origin should throw an error"""
+
+        sm = StructureModel()
+
+        with pytest.raises(ValueError, match="^Unknown origin: must be one of.*$"):
+            sm.add_edges_from([(1, 2)], "other")
+
+    def test_add_edges_from_custom_attr(self):
+        """it should be possible to add edges with custom attributes"""
+
+        sm = StructureModel()
+        edges = [(1, 2), (2, 3)]
+        sm.add_edges_from(edges, x="Y")
+
+        assert all(edge in sm.edges for edge in edges)
+        assert all((u, v, "Y") in sm.edges.data("x") for u, v in edges)
+
+    def test_add_edges_from_multiple_times(self):
+        """adding edges again should update the edges origin attr"""
+
+        sm = StructureModel()
+        edges = [(1, 2), (2, 3)]
+        sm.add_edges_from(edges, "unknown")
+        assert all((u, v, "unknown") in sm.edges.data("origin") for u, v in edges)
+        sm.add_edges_from(edges, "learned")
+        assert all((u, v, "learned") in sm.edges.data("origin") for u, v in edges)
+
+    def test_add_multiple_edges(self):
+        """it should be possible to add multiple edges with different origins"""
+
+        sm = StructureModel()
+        sm.add_edges_from([(1, 2)], origin="unknown")
+        sm.add_edges_from([(1, 3)], origin="learned")
+        sm.add_edges_from([(1, 4)], origin="expert")
+
+        assert (1, 2, "unknown") in sm.edges.data("origin")
+        assert (1, 3, "learned") in sm.edges.data("origin")
+        assert (1, 4, "expert") in sm.edges.data("origin")
+
+
+class TestStructureModelAddWeightedEdgesFrom:
+    def test_add_weighted_edges_from_default(self):
+        """edges added with default origin should be identified as unknown origin"""
+
+        sm = StructureModel()
+        edges = [(1, 2, 0.5), (2, 3, 0.5)]
+        sm.add_weighted_edges_from(edges)
+
+        assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges)
+        assert all((u, v, "unknown") in sm.edges.data("origin") for u, v, w in edges)
+
+    def test_add_weighted_edges_from_unknown(self):
+        """edges added with unknown origin should be labelled as unknown origin"""
+
+        sm = StructureModel()
+        edges = [(1, 2, 0.5), (2, 3, 0.5)]
+        sm.add_weighted_edges_from(edges, origin="unknown")
+
+        assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges)
+        assert all((u, v, "unknown") in sm.edges.data("origin") for u, v, w in edges)
+
+    def test_add_weighted_edges_from_learned(self):
+        """edges added with learned origin should be labelled as learned origin"""
+
+        sm = StructureModel()
+        edges = [(1, 2, 0.5), (2, 3, 0.5)]
+        sm.add_weighted_edges_from(edges, origin="learned")
+
+        assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges)
+        assert all((u, v, "learned") in sm.edges.data("origin") for u, v, w in edges)
+
+    def test_add_weighted_edges_from_expert(self):
+        """edges added with expert origin should be labelled as expert origin"""
+
+        sm = StructureModel()
+        edges = [(1, 2, 0.5), (2, 3, 0.5)]
+        sm.add_weighted_edges_from(edges, origin="expert")
+
+        assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges)
+        assert all((u, v, "expert") in sm.edges.data("origin") for u, v, w in edges)
+
+    def test_add_weighted_edges_from_other(self):
+        """edges added with other origin should throw an error"""
+
+        sm = StructureModel()
+
+        with pytest.raises(ValueError, match="^Unknown origin: must be one of.*$"):
+            sm.add_weighted_edges_from([(1, 2, 0.5)], origin="other")
+
+    def test_add_weighted_edges_from_custom_attr(self):
+        """it should be possible to add edges with custom attributes"""
+
+        sm = StructureModel()
+        edges = [(1, 2, 0.5), (2, 3, 0.5)]
+        sm.add_weighted_edges_from(edges, x="Y")
+
+        assert all((u, v, w) in sm.edges.data("weight") for u, v, w in edges)
+        assert all((u, v, "Y") in sm.edges.data("x") for u, v, _ in edges)
+
+    def test_add_weighted_edges_from_multiple_times(self):
+        """adding edges again should update the edges origin attr"""
+
+        sm = StructureModel()
+        edges = [(1, 2, 0.5), (2, 3, 0.5)]
+        sm.add_weighted_edges_from(edges, origin="unknown")
+        assert all((u, v, "unknown") in sm.edges.data("origin") for u, v, _ in edges)
+        sm.add_weighted_edges_from(edges, origin="learned")
+        assert all((u, v, "learned") in sm.edges.data("origin") for u, v, _ in edges)
+
+    def test_add_multiple_weighted_edges(self):
+        """it should be possible to add multiple edges with different origins"""
+
+        sm = StructureModel()
+        sm.add_weighted_edges_from([(1, 2, 0.5)], origin="unknown")
+        sm.add_weighted_edges_from([(1, 3, 0.5)], origin="learned")
+        sm.add_weighted_edges_from([(1, 4, 0.5)], origin="expert")
+
+        assert (1, 2, "unknown") in sm.edges.data("origin")
+        assert (1, 3, "learned") in sm.edges.data("origin")
+        assert (1, 4, "expert") in sm.edges.data("origin")
+
+
+class TestStructureModelRemoveEdgesBelowThreshold:
+    def test_remove_edges_below_threshold(self):
+        """Edges whose weight is less than a defined threshold should be removed"""
+
+        sm = StructureModel()
+        strong_edges = [(1, 2, 1.0), (1, 3, 0.8), (1, 5, 2.0)]
+        weak_edges = [(1, 4, 0.4), (2, 3, 0.6), (3, 5, 0.5)]
+        sm.add_weighted_edges_from(strong_edges)
+        sm.add_weighted_edges_from(weak_edges)
+
+        sm.remove_edges_below_threshold(0.7)
+
+        assert set(sm.edges(data="weight")) == set(strong_edges)
+
+    def test_negative_weights(self):
+        """Negative edges whose absolute value is greater than the defined threshold should not be removed"""
+
+        sm = StructureModel()
+        strong_edges = [(1, 2, -3.0), (3, 1, 0.7), (1, 5, -2.0)]
+        weak_edges = [(1, 4, 0.4), (2, 3, -0.6), (3, 5, -0.5)]
+        sm.add_weighted_edges_from(strong_edges)
+        sm.add_weighted_edges_from(weak_edges)
+
+        sm.remove_edges_below_threshold(0.7)
+
+        assert set(sm.edges(data="weight")) == set(strong_edges)
+
+    def test_equal_weights(self):
+        """Edges whose absolute value is equal to the defined threshold should not be removed"""
+
+        sm = StructureModel()
+        strong_edges = [(1, 2, 1.0), (1, 5, 2.0)]
+        equal_edges = [(1, 3, 0.6), (2, 3, 0.6)]
+        weak_edges = [(1, 4, 0.4), (3, 5, 0.5)]
+        sm.add_weighted_edges_from(strong_edges)
+        sm.add_weighted_edges_from(equal_edges)
+        sm.add_weighted_edges_from(weak_edges)
+
+        sm.remove_edges_below_threshold(0.6)
+
+        assert set(sm.edges(data="weight")) == set.union(
+            set(strong_edges), set(equal_edges)
+        )
+
+    def test_graph_with_no_edges(self):
+        """Can still run even if the graph is without edges"""
+
+        sm = StructureModel()
+        nodes = [1, 2, 3]
+        sm.add_nodes_from(nodes)
+        sm.remove_edges_below_threshold(0.6)
+
+        assert set(sm.nodes) == set(nodes)
+        assert set(sm.edges) == set()
+
+
+class TestStructureModelGetLargestSubgraph:
+    @pytest.mark.parametrize(
+        "test_input, expected",
+        [
+            ([(0, 1), (1, 2), (1, 3), (4, 6)], [(0, 1), (1, 2), (1, 3)]),
+            ([(3, 4), (3, 5), (7, 6)], [(3, 4), (3, 5)]),
+        ],
+    )
+    def test_get_largest_subgraph(self, test_input, expected):
+        """Should be able to return the largest subgraph"""
+        sm = StructureModel()
+        sm.add_edges_from(test_input)
+        largest_subgraph = sm.get_largest_subgraph()
+
+        expected_graph = StructureModel()
+        expected_graph.add_edges_from(expected)
+
+        assert set(largest_subgraph.nodes) == set(expected_graph.nodes)
+        assert set(largest_subgraph.edges) == set(expected_graph.edges)
+
+    def test_more_than_one_largest(self):
+        """Return the first largest subgraph when there is more than one of the same size"""
+
+        edges = [(0, 1), (1, 2), (3, 4), (3, 5)]
+        sm = StructureModel()
+        sm.add_edges_from(edges)
+        largest_subgraph = sm.get_largest_subgraph()
+
+        expected_edges = [(0, 1), (1, 2)]
+        expected_graph = StructureModel()
+        expected_graph.add_edges_from(expected_edges)
+
+        assert set(largest_subgraph.nodes) == set(expected_graph.nodes)
+        assert set(largest_subgraph.edges) == set(expected_graph.edges)
+
+    def test_empty(self):
+        """Should return None if the structure model is empty"""
+
+        sm = StructureModel()
+        assert sm.get_largest_subgraph() is None
+
+    def test_isolates(self):
+        """Should return None if the structure model only contains isolates"""
+        nodes = [1, 3, 5, 2, 7]
+
+        sm = StructureModel()
+        sm.add_nodes_from(nodes)
+
+        assert sm.get_largest_subgraph() is None
+
+    def test_isolates_nodes_and_edges(self):
+        """Should be able to return the largest subgraph when the graph also contains isolated nodes"""
+
+        edges = [(0, 1), (1, 2), (1, 3), (5, 6)]
+        isolated_nodes = [7, 8, 9]
+        sm = StructureModel()
+        sm.add_edges_from(edges)
+        sm.add_nodes_from(isolated_nodes)
+        largest_subgraph = sm.get_largest_subgraph()
+
+        expected_edges = [(0, 1), (1, 2), (1, 3)]
+        expected_graph = StructureModel()
+        expected_graph.add_edges_from(expected_edges)
+
+        assert set(largest_subgraph.nodes) == set(expected_graph.nodes)
+        assert set(largest_subgraph.edges) == set(expected_graph.edges)
+
+    def test_different_origins_and_weights(self):
+        """The largest subgraph returned should still have the edge data preserved from the original graph"""
+
+        sm = StructureModel()
+        sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown")
+        sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned")
+        sm.add_weighted_edges_from([(5, 6, 0.7)], origin="expert")
+
+        largest_subgraph = sm.get_largest_subgraph()
+
+        assert set(largest_subgraph.edges.data("origin")) == set(
+            [(1, 2, "unknown"), (1, 3, "learned")]
+        )
+        assert set(largest_subgraph.edges.data("weight")) == set(
+            [(1, 2, 2.0), (1, 3, 1.0)]
+        )
diff --git a/tests/test_bayesiannetwork.py b/tests/test_bayesiannetwork.py
new file mode 100644
index 0000000..a54ae22
--- /dev/null
+++ b/tests/test_bayesiannetwork.py
@@ -0,0 +1,612 @@
+# Copyright 2019-2020 QuantumBlack Visual Analytics Limited
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
+# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND
+# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS
+# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN
+# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN
+# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+#
+# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo
+# (either separately or in combination, "QuantumBlack Trademarks") are
+# trademarks of QuantumBlack. The License does not grant you any right or
+# license to the QuantumBlack Trademarks. You may not use the QuantumBlack
+# Trademarks or any confusingly similar mark as a trademark for your product,
+# or use the QuantumBlack Trademarks in any other manner that might cause
+# confusion in the marketplace, including but not limited to in advertising,
+# on websites, or on software.
+#
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+
+import numpy as np
+import pandas as pd
+import pytest
+
+from causalnex.network import BayesianNetwork
+from causalnex.structure import StructureModel
+from causalnex.structure.notears import from_pandas
+
+
+class TestFitNodeStates:
+    """Test behaviour of fit node states method"""
+
+    @pytest.mark.parametrize(
+        "weighted_edges, data",
+        [
+            ([("a", "b", 1)], pd.DataFrame([[1, 1]], columns=["a", "b"])),
+            (
+                [("a", "b", 1)],
+                pd.DataFrame([[1, 1, 1, 1]], columns=["a", "b", "c", "d"]),
+            ),
+            # c and d are isolated nodes in the data
+        ],
+    )
+    def test_all_nodes_included(self, weighted_edges, data):
+        """No errors if all the nodes can be found in the columns of training data"""
+        cg = StructureModel()
+        cg.add_weighted_edges_from(weighted_edges)
+        bn = BayesianNetwork(cg).fit_node_states(data)
+        assert all(node in data.columns for node in bn.node_states.keys())
+
+    def test_all_states_included(self):
+        """All states in a node should be included"""
+        cg = StructureModel()
+        cg.add_weighted_edges_from([("a", "b", 1)])
+        bn = BayesianNetwork(cg).fit_node_states(
+            pd.DataFrame([[i, i] for i in range(10)], columns=["a", "b"])
+        )
+        assert all(v in bn.node_states["a"] for v in range(10))
+
+    def test_fit_with_null_states_raises_error(self):
+        """An error should be raised if fit is called with null data"""
+        cg = StructureModel()
+        cg.add_weighted_edges_from([("a", "b", 1)])
+        with pytest.raises(ValueError, match="node '.*' contains None state"):
+            BayesianNetwork(cg).fit_node_states(
+                pd.DataFrame([[None, 1]], columns=["a", "b"])
+            )
+
+    def test_fit_with_missing_feature_in_data(self):
+        """An error should be raised if fit is called with missing feature in data"""
+        cg = StructureModel()
+
+        cg.add_weighted_edges_from([("a", "e", 1)])
+        with pytest.raises(
+            KeyError,
+            match="The data does not cover all the features found in the Bayesian Network. "
+            "Please check the following features: {'e'}",
+        ):
+            BayesianNetwork(cg).fit_node_states(
+                pd.DataFrame([[1, 1, 1, 1]], columns=["a", "b", "c", "d"])
+            )
+
+
+class TestFitCPDSErrors:
+    """Test errors for fit CPDs method"""
+
+    def test_invalid_method(self, bn, train_data_discrete):
+        """a value error should be raised if an invalid method is provided"""
+
+        with pytest.raises(ValueError, match=r"unrecognised method.*"):
+            bn.fit_cpds(train_data_discrete, method="INVALID")
+
+    def test_invalid_prior(self, bn, train_data_discrete):
+        """a value error should be raised if an invalid prior is provided"""
+
+        with pytest.raises(ValueError, match=r"unrecognised bayes_prior.*"):
+            bn.fit_cpds(
+                train_data_discrete, method="BayesianEstimator", bayes_prior="INVALID"
+            )
+
+
+class TestFitCPDsMaximumLikelihoodEstimator:
+    """Test behaviour of fit_cpds using MLE"""
+
+    def test_cause_only_node(self, bn, train_data_discrete, train_data_discrete_cpds):
+        """Test that probabilities are fit correctly to nodes which are not caused by other nodes"""
+
+        bn.fit_cpds(train_data_discrete)
+        cpds = bn.cpds
+
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["d"].values.reshape(2)
+                    - train_data_discrete_cpds["d"].reshape(2)
+                )
+            )
+            < 1e-7
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["e"].values.reshape(2)
+                    - train_data_discrete_cpds["e"].reshape(2)
+                )
+            )
+            < 1e-7
+        )
+
+    def test_dependent_node(self, bn, train_data_discrete, train_data_discrete_cpds):
+        """Test that probabilities are fit correctly to nodes that are caused by other nodes"""
+
+        bn.fit_cpds(train_data_discrete)
+        cpds = bn.cpds
+
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["a"].values.reshape(24)
+                    - train_data_discrete_cpds["a"].reshape(24)
+                )
+            )
+            < 1e-7
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["b"].values.reshape(12)
+                    - train_data_discrete_cpds["b"].reshape(12)
+                )
+            )
+            < 1e-7
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["c"].values.reshape(60)
+                    - train_data_discrete_cpds["c"].reshape(60)
+                )
+            )
+            < 1e-7
+        )
+
+
+class TestFitBayesianEstimator:
+    """Test behaviour of fit_cpds using BE"""
+
+    def test_cause_only_node_bdeu(
+        self, bn, train_data_discrete, train_data_discrete_cpds
+    ):
+        """Test that probabilities are fit correctly to nodes which are not caused by other nodes"""
+
+        bn.fit_cpds(
+            train_data_discrete,
+            method="BayesianEstimator",
+            bayes_prior="BDeu",
+            equivalent_sample_size=5,
+        )
+        cpds = bn.cpds
+
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["d"].values.reshape(2)
+                    - train_data_discrete_cpds["d"].reshape(2)
+                )
+            )
+            < 0.02
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["e"].values.reshape(2)
+                    - train_data_discrete_cpds["e"].reshape(2)
+                )
+            )
+            < 0.02
+        )
+
+    def test_cause_only_node_k2(
+        self, bn, train_data_discrete, train_data_discrete_cpds
+    ):
+        """Test that probabilities are fit correctly to nodes which are not caused by other nodes"""
+
+        bn.fit_cpds(train_data_discrete, method="BayesianEstimator", bayes_prior="K2")
+        cpds = bn.cpds
+
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["d"].values.reshape(2)
+                    - train_data_discrete_cpds["d"].reshape(2)
+                )
+            )
+            < 0.02
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["e"].values.reshape(2)
+                    - train_data_discrete_cpds["e"].reshape(2)
+                )
+            )
+            < 0.02
+        )
+
+    def test_dependent_node_bdeu(
+        self, bn, train_data_discrete, train_data_discrete_cpds
+    ):
+        """Test that probabilities are fit correctly to nodes that are caused by other nodes"""
+
+        bn.fit_cpds(
+            train_data_discrete,
+            method="BayesianEstimator",
+            bayes_prior="BDeu",
+            equivalent_sample_size=1,
+        )
+        cpds = bn.cpds
+
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["a"].values.reshape(24)
+                    - train_data_discrete_cpds["a"].reshape(24)
+                )
+            )
+            < 0.02
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["b"].values.reshape(12)
+                    - train_data_discrete_cpds["b"].reshape(12)
+                )
+            )
+            < 0.02
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["c"].values.reshape(60)
+                    - train_data_discrete_cpds["c"].reshape(60)
+                )
+            )
+            < 0.02
+        )
+
+    def test_dependent_node_k2(
+        self, bn, train_data_discrete, train_data_discrete_cpds_k2
+    ):
+        """Test that probabilities are fit correctly to nodes that are caused by other nodes"""
+
+        bn.fit_cpds(train_data_discrete, method="BayesianEstimator", bayes_prior="K2")
+        cpds = bn.cpds
+
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["a"].values.reshape(24)
+                    - train_data_discrete_cpds_k2["a"].reshape(24)
+                )
+            )
+            < 1e-7
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["b"].values.reshape(12)
+                    - train_data_discrete_cpds_k2["b"].reshape(12)
+                )
+            )
+            < 1e-7
+        )
+        assert (
+            np.mean(
+                np.abs(
+                    cpds["c"].values.reshape(60)
+                    - train_data_discrete_cpds_k2["c"].reshape(60)
+                )
+            )
+            < 1e-7
+        )
+
+
+class TestPredictMaximumLikelihoodEstimator:
+    """Test behaviour of predict using MLE"""
+
+    def test_predictions_are_based_on_probabilities(
+        self, bn, train_data_discrete, test_data_c_discrete
+    ):
+        """Predictions made using the model should be based on the probabilities that are in the model"""
+
+        bn.fit_cpds(train_data_discrete)
+        predictions = bn.predict(test_data_c_discrete, "c")
+        assert np.all(
+            predictions.values.reshape(len(predictions.values))
+            == test_data_c_discrete["c"].values
+        )
+
+    def test_prediction_node_suffixed_as_prediction(
+        self, bn, train_data_discrete, test_data_c_discrete
+    ):
+        """The column that contains the values of the predicted node should be named node_prediction"""
+
+        bn.fit_cpds(train_data_discrete)
+        predictions = bn.predict(test_data_c_discrete, "c")
+        assert "c_prediction" in predictions.columns
+
+    def test_only_predicted_column_returned(
+        self, bn, train_data_discrete, test_data_c_discrete
+    ):
+        """The returned df should not contain any of the input data columns"""
+
+        bn.fit_cpds(train_data_discrete)
+        predictions = bn.predict(test_data_c_discrete, "c")
+        assert len(predictions.columns) == 1
+
+    def test_predictions_are_not_appended_to_input_df(
+        self, bn, train_data_discrete, test_data_c_discrete
+    ):
+        """The predictions should not be appended to the input df"""
+
+        expected_cols = test_data_c_discrete.columns
+        bn.fit_cpds(train_data_discrete)
+        bn.predict(test_data_c_discrete, "c")
+        assert np.array_equal(test_data_c_discrete.columns, expected_cols)
+
+    def test_missing_parent(self, bn, train_data_discrete, test_data_c_discrete):
+        """Predictions made when parents are missing should still be reasonably accurate"""
+
+        bn.fit_cpds(train_data_discrete)
+        predictions = bn.predict(test_data_c_discrete[["a", "b", "c", "d"]], "c")
+
+        n = len(test_data_c_discrete)
+
+        accuracy = (
+            1
+            - np.count_nonzero(
+                predictions.values.reshape(len(predictions.values))
+                - test_data_c_discrete["c"].values
+            )
+            / n
+        )
+
+        assert accuracy > 0.9
+
+    def test_missing_non_parent(self, bn, train_data_discrete, test_data_c_discrete):
+        """It should be possible to make predictions with non-parent nodes missing"""
+
+        bn.fit_cpds(train_data_discrete)
+        predictions = bn.predict(test_data_c_discrete[["b", "c", "d", "e"]], "c")
+        assert np.all(
+            predictions.values.reshape(len(predictions.values))
+            == test_data_c_discrete["c"].values
+        )
+
+
+class TestPredictBayesianEstimator:
+    """Test behaviour of predict using BE"""
+
+    def test_predictions_are_based_on_probabilities_dbeu(
+        self, bn, train_data_discrete, test_data_c_discrete
+    ):
+        """Predictions made using the model should be based on the probabilities that are in the model"""
+
+        bn.fit_cpds(
+            train_data_discrete,
+            method="BayesianEstimator",
+            bayes_prior="BDeu",
+            equivalent_sample_size=5,
+        )
+        predictions = bn.predict(test_data_c_discrete, "c")
+        assert np.all(
+            predictions.values.reshape(len(predictions.values))
+            == test_data_c_discrete["c"].values
+        )
+
+    def test_predictions_are_based_on_probabilities_k2(
+        self, bn, train_data_discrete, test_data_c_discrete
+    ):
+        """Predictions made using the model should be based on the probabilities that are in the model"""
+
+        bn.fit_cpds(
+            train_data_discrete,
+            method="BayesianEstimator",
+            bayes_prior="K2",
+            equivalent_sample_size=5,
+        )
+        predictions = bn.predict(test_data_c_discrete, "c")
+        assert np.all(
+            predictions.values.reshape(len(predictions.values))
+            == test_data_c_discrete["c"].values
+        )
+
+
+class TestPredictProbabilityMaximumLikelihoodEstimator:
+    """Test behaviour of predict_probability using MLE"""
+
+    def test_expected_probabilities_are_predicted(
+        self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood
+    ):
+        """Probabilities should return exactly correct on a hand computable scenario"""
+        bn.fit_cpds(train_data_discrete)
+        probability = bn.predict_probability(test_data_c_discrete, "c")
+
+        assert all(
+            np.isclose(
+                probability.values.flatten(), test_data_c_likelihood.values.flatten()
+            )
+        )
+
+    def test_missing_parent(
+        self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood
+    ):
+        """Probabilities made when parents are missing should still be reasonably accurate"""
+
+        bn.fit_cpds(train_data_discrete)
+        probability = bn.predict_probability(
+            test_data_c_discrete[["a", "b", "c", "d"]], "c"
+        )
+
+        n = len(probability.values.flatten())
+
+        accuracy = (
+            np.count_nonzero(
+                [
+                    1 if math.isclose(a, b, abs_tol=0.15) else 0
+                    for a, b in zip(
+                        probability.values.flatten(),
+                        test_data_c_likelihood.values.flatten(),
+                    )
+                ]
+            )
+            / n
+        )
+
+        assert accuracy > 0.8
+
+    def test_missing_non_parent(
+        self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood
+    ):
+        """It should be possible to make predictions with non-parent nodes missing"""
+
+        bn.fit_cpds(train_data_discrete)
+        probability = bn.predict_probability(
+            test_data_c_discrete[["b", "c", "d", "e"]], "c"
+        )
+        assert all(
+            np.isclose(
+                probability.values.flatten(), test_data_c_likelihood.values.flatten()
+            )
+        )
+
+
+class TestPredictProbabilityBayesianEstimator:
+    """Test behaviour of predict_probability using BayesianEstimator"""
+
+    def test_expected_probabilities_are_predicted(
+        self, bn, train_data_discrete, test_data_c_discrete, test_data_c_likelihood
+    ):
+        """Probabilities should return exactly correct on a hand computable scenario"""
+
+        bn.fit_cpds(
+            train_data_discrete,
+            method="BayesianEstimator",
+            bayes_prior="BDeu",
+            equivalent_sample_size=1,
+        )
+        probability = bn.predict_probability(test_data_c_discrete, "c")
+        assert all(
+            np.isclose(
+                probability.values.flatten(),
+                test_data_c_likelihood.values.flatten(),
+                atol=0.1,
+            )
+        )
+
+
+class TestFitNodesStatesAndCPDs:
+    """Test behaviour of helper function"""
+
+    def test_behaves_same_as_seperate_calls(self, train_data_idx, train_data_discrete):
+        bn1 = BayesianNetwork(from_pandas(train_data_idx, w_threshold=0.3))
+        bn2 = BayesianNetwork(from_pandas(train_data_idx, w_threshold=0.3))
+
+        bn1.fit_node_states(train_data_discrete).fit_cpds(train_data_discrete)
+        bn2.fit_node_states_and_cpds(train_data_discrete)
+
+        assert bn1.edges == bn2.edges
+        assert bn1.node_states == bn2.node_states
+
+        cpds1 = bn1.cpds
+        cpds2 = bn2.cpds
+
+        assert cpds1.keys() == cpds2.keys()
+
+        for k in cpds1:
+            assert cpds1[k].equals(cpds2[k])
+
+
+class TestCPDsProperty:
+    """Test behaviour of the CPDs property"""
+
+    def test_row_index_of_state_values(self, bn):
+        """CPDs should have row index set to values of all possible states of the node"""
+
+        assert bn.cpds["a"].index.tolist() == sorted(list(bn.node_states["a"]))
+
+    def test_col_index_of_parent_state_combinations(self, bn): +
"""CPDs should have a column multi-index of parent state permutations""" + + assert bn.cpds["a"].columns.names == ["b", "d"] + + +class TestInit: + """Test behaviour when constructing a BayesianNetwork""" + + def test_cycles_in_structure(self): + """An error should be raised if cycles are present""" + + with pytest.raises( + ValueError, + match=r"The given structure is not acyclic\. " + r"Please review the following cycle\.*", + ): + BayesianNetwork(StructureModel([(0, 1), (1, 2), (2, 0)])) + + @pytest.mark.parametrize( + "test_input,n_components", + [([(0, 1), (1, 2), (3, 4), (4, 6)], 2), ([(0, 1), (1, 2), (3, 4), (5, 6)], 3)], + ) + def test_disconnected_components(self, test_input, n_components): + """An error should be raised if there is more than one graph component""" + + with pytest.raises( + ValueError, + match=r"The given structure has " + + str(n_components) + + r" separated graph components\. " + r"Please make sure it has only one\.", + ): + BayesianNetwork(StructureModel(test_input)) + + +class TestStructure: + """Test behaviour of the property structure""" + + def test_get_structure(self): + """The structure retrieved should be the same""" + + sm = StructureModel() + + sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown") + sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned") + sm.add_weighted_edges_from([(3, 5, 0.7)], origin="expert") + + bn = BayesianNetwork(sm) + + sm_from_bn = bn.structure + + assert set(sm.edges.data("origin")) == set(sm_from_bn.edges.data("origin")) + assert set(sm.edges.data("weight")) == set(sm_from_bn.edges.data("weight")) + + assert set(sm.nodes) == set(sm_from_bn.nodes) + + def test_set_structure(self): + """An error should be raised if setting the structure""" + + sm = StructureModel() + sm.add_weighted_edges_from([(1, 2, 2.0)], origin="unknown") + sm.add_weighted_edges_from([(1, 3, 1.0)], origin="learned") + sm.add_weighted_edges_from([(3, 5, 0.7)], origin="expert") + + bn = BayesianNetwork(sm) + + new_sm = 
StructureModel() + new_sm.add_weighted_edges_from([(2, 5, 3.0)], origin="unknown") + new_sm.add_weighted_edges_from([(2, 3, 2.0)], origin="learned") + new_sm.add_weighted_edges_from([(3, 4, 1.7)], origin="expert") + + with pytest.raises(AttributeError, match=r"can't set attribute"): + bn.structure = new_sm diff --git a/tests/test_inference.py b/tests/test_inference.py new file mode 100644 index 0000000..99bd049 --- /dev/null +++ b/tests/test_inference.py @@ -0,0 +1,371 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License.
+ +import math + +import pytest + +from causalnex.inference import InferenceEngine +from causalnex.network import BayesianNetwork +from causalnex.structure import StructureModel +from causalnex.structure.notears import from_pandas + + +class TestInferenceEngineIdx: + def test_create_inference_from_bn(self, train_model, train_data_idx): + """It should be possible to create a new Inference object from an existing pgmpy model""" + + bn = BayesianNetwork(train_model).fit_node_states(train_data_idx) + bn.fit_cpds(train_data_idx) + InferenceEngine(bn) + + def test_create_inference_with_bad_variable_names_fails( + self, train_model, train_data_idx + ): + + model = StructureModel() + model.add_edges_from( + [ + (str(u).replace("a", "$a"), str(v).replace("a", "$a")) + for u, v in train_model.edges + ] + ) + + train_data_idx.rename(columns={"a": "$a"}, inplace=True) + + bn = BayesianNetwork(model).fit_node_states(train_data_idx) + bn.fit_cpds(train_data_idx) + + with pytest.raises(ValueError, match="Variable names must match.*"): + InferenceEngine(bn) + + def test_empty_query_returns_marginals( + self, train_model, train_data_idx, train_data_idx_marginals + ): + """An empty query should return all the marginal probabilities of the model's distribution""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + dist = ie.query({}) + + for node, states in dist.items(): + for state, p in states.items(): + assert math.isclose( + train_data_idx_marginals[node][state], p, abs_tol=0.05 + ) + + def test_observations_affect_marginals(self, train_model, train_data_idx): + """Observing the state of a node should affect the marginals of dependent nodes""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + + m1 = ie.query({}) + m2 = ie.query({"d": 1}) + + assert m2["d"][0] == 0 + assert m2["d"][1] == 1 + assert not 
math.isclose(m2["b"][1], m1["b"][1], abs_tol=0.01) + + def test_observations_does_not_affect_marginals_of_independent_nodes( + self, train_model, train_data_idx + ): + """Observing the state of a node should not affect the marginal probability of an independent node""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + + m1 = ie.query({}) + m2 = ie.query({"d": 1}) + + assert m2["d"][0] == 0 + assert m2["d"][1] == 1 + assert math.isclose(m2["e"][1], m1["e"][1], abs_tol=0.05) + + def test_do_sets_state_probability_to_one(self, train_model, train_data_idx): + """Do should update the probability of the given observation=state to 1""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + ie.do_intervention("d", 1) + assert math.isclose(ie.query()["d"][1], 1) + + def test_do_on_node_with_no_effects_not_allowed(self, train_model, train_data_idx): + """It should not be possible to create an isolated node in the network""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, + match="Do calculus cannot be applied because it would result in an isolate", + ): + ie.do_intervention("a", 1) + + def test_do_sets_other_state_probabilitys_to_zero( + self, train_model, train_data_idx + ): + """Do should update the probability of every other state for the observation to zero""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + ie.do_intervention("d", 1) + assert ie.query()["d"][0] == 0 + + def test_do_accepts_all_state_probabilities(self, train_model, train_data_idx): + """Do should accept a map of state->p and update p accordingly""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + 
ie = InferenceEngine(bn) + ie.do_intervention("d", {0: 0.7, 1: 0.3}) + assert math.isclose(ie.query()["d"][0], 0.7) + assert math.isclose(ie.query()["d"][1], 0.3) + + def test_do_expects_all_state_probabilities_sum_to_one( + self, train_model, train_data_idx + ): + """Do should accept only state probabilities where the full distribution is provided""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, match="The cpd for the provided observation must sum to 1" + ): + ie.do_intervention("d", {0: 0.7, 1: 0.4}) + + def test_do_expects_all_states_have_a_probability( + self, train_model, train_data_idx + ): + """Do should accept only state probabilities where all states in the original cpds are present""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, match="The cpd states do not match expected states*" + ): + ie.do_intervention("d", {1: 1}) + + def test_do_prevents_new_states_being_added(self, train_model, train_data_idx): + """Do should not allow the introduction of new states""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, match="The cpd states do not match expected states*" + ): + ie.do_intervention("d", {0: 0.7, 1: 0.3, 2: 0.0}) + + def test_do_reflected_in_query(self, train_model, train_data_idx): + """Do should adjust marginals returned by query when given a different observation""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + + assert ie.query({"a": 1})["d"][1] != 1 + ie.do_intervention("d", 1) + assert ie.query({"a": 1})["d"][1] == 1 + + def test_reset_do_sets_probabilities_back_to_initial_state( + self, train_model, 
train_data_idx, train_data_idx_marginals + ): + """Resetting Do operator should re-introduce the original conditional dependencies""" + + bn = BayesianNetwork(train_model) + bn.fit_node_states(train_data_idx).fit_cpds(train_data_idx) + + ie = InferenceEngine(bn) + ie.do_intervention("d", {0: 0.7, 1: 0.3}) + ie.reset_do("d") + + assert math.isclose(ie.query()["d"][0], train_data_idx_marginals["d"][0]) + assert math.isclose(ie.query()["d"][1], train_data_idx_marginals["d"][1]) + + +class TestInferenceEngineDiscrete: + """Test behaviour of query and interventions""" + + def test_query_when_cpds_not_fit(self, train_data_idx, train_data_discrete): + """An error should be raised if query before CPDs are fit""" + + bn = BayesianNetwork( + from_pandas(train_data_idx, w_threshold=0.3) + ).fit_node_states(train_data_discrete) + + with pytest.raises( + ValueError, match=r"Bayesian Network does not contain any CPDs.*" + ): + InferenceEngine(bn) + + def test_empty_query_returns_marginals(self, bn, train_data_discrete_marginals): + """An empty query should return all the marginal probabilities of the model's distribution""" + + ie = InferenceEngine(bn) + dist = ie.query({}) + + for node, states in dist.items(): + for state, p in states.items(): + assert math.isclose( + train_data_discrete_marginals[node][state], p, abs_tol=0.05 + ) + + def test_observations_affect_marginals(self, bn): + """Observing the state of a node should affect the marginals of dependent nodes""" + + ie = InferenceEngine(bn) + + m1 = ie.query({}) + m2 = ie.query({"d": True}) + + assert m2["d"][False] == 0 + assert m2["d"][True] == 1 + assert not math.isclose(m2["b"]["x"], m1["b"]["x"], abs_tol=0.05) + + def test_observations_does_not_affect_marginals_of_independent_nodes(self, bn): + """Observing the state of a node should not affect the marginal probability of an independent node""" + + ie = InferenceEngine(bn) + + m1 = ie.query({}) + m2 = ie.query({"d": True}) + + assert m2["d"][False] == 0 + assert 
m2["d"][True] == 1 + assert math.isclose(m2["e"][True], m1["e"][True], abs_tol=0.05) + + def test_do_sets_state_probability_to_one(self, bn): + """Do should update the probability of the given observation=state to 1""" + + ie = InferenceEngine(bn) + ie.do_intervention("d", True) + assert math.isclose(ie.query()["d"][True], 1) + + def test_do_on_node_with_no_effects_not_allowed(self, bn): + """It should not be possible to create an isolated node in the network""" + + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, + match="Do calculus cannot be applied because it would result in an isolate", + ): + ie.do_intervention("a", "b") + + def test_do_sets_other_state_probabilitys_to_zero(self, bn): + """Do should update the probability of every other state for the observation to zero""" + + ie = InferenceEngine(bn) + ie.do_intervention("d", True) + assert ie.query()["d"][False] == 0 + + def test_do_accepts_all_state_probabilities(self, bn): + """Do should accept a map of state->p and update p accordingly""" + + ie = InferenceEngine(bn) + ie.do_intervention("d", {False: 0.7, True: 0.3}) + assert math.isclose(ie.query()["d"][False], 0.7) + assert math.isclose(ie.query()["d"][True], 0.3) + + def test_do_expects_all_state_probabilities_sum_to_one(self, bn): + """Do should accept only state probabilities where the full distribution is provided""" + + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, match="The cpd for the provided observation must sum to 1" + ): + ie.do_intervention("d", {False: 0.7, True: 0.4}) + + def test_do_expects_all_states_have_a_probability(self, bn): + """Do should accept only state probabilities where all states in the original cpds are present""" + + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, match="The cpd states do not match expected states*" + ): + ie.do_intervention("d", {False: 1}) + + def test_do_prevents_new_states_being_added(self, bn): + """Do should not allow the introduction of new states""" + 
+ + ie = InferenceEngine(bn) + + with pytest.raises( + ValueError, match="The cpd states do not match expected states*" + ): + ie.do_intervention("d", {False: 0.7, True: 0.3, "other": 0.0}) + + def test_do_reflected_in_query(self, bn): + """Do should adjust marginals returned by query when given a different observation""" + + ie = InferenceEngine(bn) + + assert ie.query({"a": "b"})["d"][True] != 1 + ie.do_intervention("d", True) + assert ie.query({"a": "b"})["d"][True] == 1 + + def test_reset_do_sets_probabilities_back_to_initial_state( + self, bn, train_data_discrete_marginals + ): + """Resetting Do operator should re-introduce the original conditional dependencies""" + + ie = InferenceEngine(bn) + ie.do_intervention("d", {False: 0.7, True: 0.3}) + ie.reset_do("d") + + assert math.isclose( + ie.query()["d"][False], train_data_discrete_marginals["d"][False] + ) + assert math.isclose( + ie.query()["d"][True], train_data_discrete_marginals["d"][True] + ) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..9fee6b3 --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,400 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import random + +import numpy as np +import pandas as pd + +from causalnex.evaluation import classification_report, roc_auc +from causalnex.network import BayesianNetwork +from causalnex.structure.notears import from_pandas +from causalnex.structure.structuremodel import StructureModel + + +class TestROCAUCStates: + """Test behaviour of the roc_auc_states metric""" + + def test_roc_of_incorrect_has_fpr_lt_tpr(self): + """The ROC of incorrect predictions should have FPR < TPR""" + + # regardless of a or b, c=1 is always more likely to varying amounts (to create multiple threshold + # points in roc curve) + train = pd.DataFrame( + [[a, b, 0] for a in range(3) for b in range(3) for _ in range(1)] + + [ + [a, b, 1] + for a in range(3) + for b in range(3) + for _ in range(a * 1000 + b * 1000 + 1000) + ], + columns=["a", "b", "c"], + ) + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + assert np.allclose(bn.cpds["c"].loc[1].values, 1, atol=0.02) + + # in test, c=0 is always more likely (opposite of train) + test = pd.DataFrame( + [[a, b, 0] for a in range(3) for b in range(3) for _ in range(1000)] + + [[a, b, 1] for a in range(3) for b in range(3) for _ in range(1)], + 
columns=["a", "b", "c"], + ) + + roc, _ = roc_auc(bn, test, "c") + + assert len(roc) > 3 + assert all(fpr > tpr for fpr, tpr in roc if tpr not in [0.0, 1.0]) + + def test_auc_of_incorrect_close_to_zero(self): + """The AUC of incorrect predictions should be close to zero""" + + # regardless of a or b, c=1 is always more likely to varying amounts (to create multiple threshold + # points in roc curve) + train = pd.DataFrame( + [[a, b, 0] for a in range(3) for b in range(3) for _ in range(1)] + + [ + [a, b, 1] + for a in range(3) + for b in range(3) + for _ in range(a * 1000 + b * 1000 + 1000) + ], + columns=["a", "b", "c"], + ) + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + assert np.allclose(bn.cpds["c"].loc[1].values, 1, atol=0.02) + + # in test, c=0 is always more likely (opposite of train) + test = pd.DataFrame( + [[a, b, 0] for a in range(3) for b in range(3) for _ in range(1000)] + + [[a, b, 1] for a in range(3) for b in range(3) for _ in range(1)], + columns=["a", "b", "c"], + ) + + _, auc = roc_auc(bn, test, "c") + + assert math.isclose(auc, 0, abs_tol=0.001) + + def test_roc_of_random_has_unit_gradient(self): + """The ROC curve for random predictions should be a line from (0,0) to (1,1)""" + + # regardless of a or b, c=1 is always more likely to varying amounts (to create multiple threshold + # points in roc curve) + train = pd.DataFrame( + [[a, b, 0] for a in range(3) for b in range(3) for _ in range(1)] + + [ + [a, b, 1] + for a in range(3) + for b in range(3) + for _ in range(a * 1000 + b * 1000 + 1000) + ], + columns=["a", "b", "c"], + ) + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + assert np.allclose(bn.cpds["c"].loc[1].values, 1, atol=0.02) + + test = pd.DataFrame( + [ + [a, b, random.randint(0, 1)] + for a in 
range(3) + for b in range(3) + for _ in range(1000) + ], + columns=["a", "b", "c"], + ) + + roc, _ = roc_auc(bn, test, "c") + + assert len(roc) > 3 + assert all(math.isclose(a, b, abs_tol=0.03) for a, b in roc) + + def test_auc_of_random_is_half(self): + """The AUC of random predictions should be 0.5""" + + # regardless of a or b, c=1 is always more likely to varying amounts (to create multiple threshold + # points in roc curve) + train = pd.DataFrame( + [[a, b, 0] for _ in range(10) for a in range(3) for b in range(3)] + + [ + [a, b, 1] + for a in range(3) + for b in range(3) + for _ in range(a * 1000 + b * 1000 + 1000) + ], + columns=["a", "b", "c"], + ) + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + assert np.allclose(bn.cpds["c"].loc[1].values, 1, atol=0.02) + + test = pd.DataFrame( + [ + [a, b, random.randint(0, 1)] + for a in range(3) + for b in range(3) + for _ in range(1000) + ], + columns=["a", "b", "c"], + ) + + _, auc = roc_auc(bn, test, "c") + + assert math.isclose(auc, 0.5, abs_tol=0.03) + + def test_roc_of_accurate_predictions(self): + """TPR should always be better than FPR for accurate predictions""" + + # equal class (c) weighting to guarantee high ROC expected + train = pd.DataFrame( + [[a, b, 0] for a in range(0, 2) for b in range(0, 2) for _ in range(10)] + + [ + [a, b, 1] + for a in range(0, 2) + for b in range(0, 2) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [ + [a, b, 0] + for a in range(2, 4) + for b in range(2, 4) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [[a, b, 1] for a in range(2, 4) for b in range(2, 4) for _ in range(10)], + columns=["a", "b", "c"], + ) + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + roc, _ = roc_auc(bn, train, "c") + assert all(tpr > fpr for fpr, tpr in roc if tpr 
not in [0.0, 1.0]) + + def test_auc_of_accurate_predictions(self): + """AUC of accurate predictions should be 1""" + + # equal class (c) weighting to guarantee high ROC expected + train = pd.DataFrame( + [[a, b, 0] for a in range(0, 2) for b in range(0, 2) for _ in range(1)] + + [ + [a, b, 1] + for a in range(0, 2) + for b in range(0, 2) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [ + [a, b, 0] + for a in range(2, 4) + for b in range(2, 4) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [[a, b, 1] for a in range(2, 4) for b in range(2, 4) for _ in range(1)], + columns=["a", "b", "c"], + ) + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + _, auc = roc_auc(bn, train, "c") + assert math.isclose(auc, 1, abs_tol=0.001) + + def test_auc_with_missing_state_in_test(self): + """AUC should still be calculated correctly with states missing in test set""" + + # equal class (c) weighting to guarantee high ROC expected + train = pd.DataFrame( + [[a, b, 0] for a in range(0, 2) for b in range(0, 2) for _ in range(1)] + + [ + [a, b, 1] + for a in range(0, 2) + for b in range(0, 2) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [ + [a, b, 0] + for a in range(2, 4) + for b in range(2, 4) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [[a, b, 1] for a in range(2, 4) for b in range(2, 4) for _ in range(1)], + columns=["a", "b", "c"], + ) + + test = train[train["c"] == 1] + assert len(test["c"].unique()) == 1 + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + _, auc = roc_auc(bn, test, "c") + assert math.isclose(auc, 1, abs_tol=0.01) + + def test_auc_node_with_no_parents(self): + """Should be possible to compute auc for state with no parent nodes""" + + train = pd.DataFrame( + [[a, b, 0] for a in range(0, 2) for b in range(0, 2) for _ in 
range(1)] + + [ + [a, b, 1] + for a in range(0, 2) + for b in range(0, 2) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [ + [a, b, 0] + for a in range(2, 4) + for b in range(2, 4) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [[a, b, 1] for a in range(2, 4) for b in range(2, 4) for _ in range(1)], + columns=["a", "b", "c"], + ) + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + _, auc = roc_auc(bn, train, "a") + assert math.isclose(auc, 0.5, abs_tol=0.01) + + def test_auc_for_nonnumeric_features(self): + """AUC of accurate predictions should be 1 even after remapping numbers to strings""" + + # equal class (c) weighting to guarantee high ROC expected + train = pd.DataFrame( + [[a, b, 0] for a in range(0, 2) for b in range(0, 2) for _ in range(1)] + + [ + [a, b, 1] + for a in range(0, 2) + for b in range(0, 2) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [ + [a, b, 0] + for a in range(2, 4) + for b in range(2, 4) + for _ in range(a * 10 + b * 10 + 1000) + ] + + [[a, b, 1] for a in range(2, 4) for b in range(2, 4) for _ in range(1)], + columns=["a", "b", "c"], + ) + + # remap values in column c + train["c"] = train["c"].map({0: "f", 1: "g"}) + + cg = StructureModel() + cg.add_weighted_edges_from([("a", "c", 1), ("b", "c", 1)]) + + bn = BayesianNetwork(cg) + bn.fit_node_states(train) + bn.fit_cpds(train) + + _, auc = roc_auc(bn, train, "c") + assert math.isclose(auc, 1, abs_tol=0.001) + + +class TestClassificationReport: + """Test behaviour of classification_report""" + + def test_contains_expected_columns(self, test_data_c_discrete, bn): + """Check that the report contains all of the required data""" + + report = classification_report(bn, test_data_c_discrete, "c") + + assert set(report.columns) == {"recall", "precision", "support", "f1-score"} + + def test_contains_all_class_data( + self, test_data_c_discrete, bn, test_data_c_likelihood + 
): + """Check that the report contains data on each possible class""" + + report = classification_report(bn, test_data_c_discrete, "c") + + assert all(label in report.index for label in test_data_c_likelihood.columns) + + def test_report_ignores_unrequired_columns_in_data( + self, train_data_idx, train_data_discrete, test_data_c_discrete + ): + """Classification report should ignore any columns that are not needed by predict""" + + bn = BayesianNetwork( + from_pandas(train_data_idx, w_threshold=0.3) + ).fit_node_states(train_data_discrete) + train_data_discrete["NEW_COL"] = [1] * len(train_data_discrete) + bn.fit_cpds(train_data_discrete) + classification_report(bn, test_data_c_discrete, "c") + + def test_report_on_node_with_no_parents_based_on_modal_state( + self, bn, train_data_discrete + ): + """Classification Report on a node with no parents should reflect that predictions are on modal state""" + + report = classification_report(bn, train_data_discrete, "d") + assert report.loc["d_False", "recall"] == 1 # always predicts most likely class + assert report.loc["d_True", "recall"] == 0 diff --git a/tests/test_plotting.py b/tests/test_plotting.py new file mode 100644 index 0000000..5e28884 --- /dev/null +++ b/tests/test_plotting.py @@ -0,0 +1,136 @@ +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT.
from string import ascii_lowercase

import pytest
from matplotlib.colors import to_rgba
from matplotlib.patches import FancyArrowPatch

from causalnex.plots import plot_structure
from causalnex.structure import StructureModel


def _arrow_patches(ax):
    """Return the ``FancyArrowPatch`` objects on ``ax`` (the edges are read from these)."""
    return [patch for patch in ax.patches if isinstance(patch, FancyArrowPatch)]


def _chain_structure(num_nodes):
    """Build a StructureModel whose edges chain the first ``num_nodes`` letters: a->b->c..."""
    nodes = list(ascii_lowercase[:num_nodes])
    return StructureModel(list(zip(nodes[:-1], nodes[1:])))


class TestPlotStructure:
    """Test behaviour of plot structure method"""

    @pytest.mark.parametrize(
        "test_input,expected", [(None, ""), ("", ""), ("TEST", "TEST")]
    )
    def test_title(self, test_input, expected):
        """Title should be set correctly"""
        sm = StructureModel([("a", "b")])
        _, ax, _ = plot_structure(sm, title=test_input)
        assert ax.get_title() == expected

    def test_edges_exist(self):
        """All edges should exist"""

        for num_nodes in range(2, 10):
            _, ax, _ = plot_structure(_chain_structure(num_nodes))
            # a chain over n nodes has exactly n - 1 edges
            assert len(_arrow_patches(ax)) == num_nodes - 1

    @pytest.mark.parametrize(
        "test_input,expected",
        [("#123456", to_rgba("#123456")), ("blue", to_rgba("blue"))],
    )
    def test_edge_color(self, test_input, expected):
        """Edge color should be set if given"""
        sm = StructureModel([("a", "b")])
        _, ax, _ = plot_structure(sm, edge_color=test_input)
        assert _arrow_patches(ax)[0].get_edgecolor() == expected

    def test_nodes_exist(self):
        """All nodes should exist"""

        for num_nodes in range(2, 10):
            _, ax, _ = plot_structure(_chain_structure(num_nodes))
            # nodes are read from the offsets of the first collection on the axes
            ax_nodes = ax.collections[0].get_offsets()
            assert len(ax_nodes) == num_nodes

    @pytest.mark.parametrize(
        "input_positions,expected_positions",
        [({"a": [1, 1], "b": [2, 2]}, [[1.0, 1.0], [2.0, 2.0]])],
    )
    def test_node_positions_respected(self, input_positions, expected_positions):
        """Nodes should be at the positions provided"""
        sm = StructureModel([("a", "b")])
        _, ax, _ = plot_structure(sm, node_positions=input_positions)
        node_coords = [list(coord) for coord in ax.collections[0].get_offsets()]
        # sort the plotted coordinates so the comparison does not depend on draw order
        assert all(
            node_x == exp_x and node_y == exp_y
            for ((exp_x, exp_y), (node_x, node_y)) in zip(
                expected_positions, sorted(node_coords)
            )
        )

    @pytest.mark.parametrize(
        "test_input,expected",
        [("#123456", to_rgba("#123456")), ("blue", to_rgba("blue"))],
    )
    def test_node_color(self, test_input, expected):
        """Node color should be set if given"""
        sm = StructureModel([("a", "b")])
        _, ax, _ = plot_structure(sm, node_color=test_input)
        assert all(
            all(face_color == expected)
            for face_color in ax.collections[0].get_facecolors()
        )

    @pytest.mark.parametrize("test_input,expected", [(False, False), (True, True)])
    def test_show_labels(self, test_input, expected):
        """Labels should be hidden when show_labels set to False"""
        sm = StructureModel([("a", "b")])
        _, ax, _ = plot_structure(sm, show_labels=test_input)

        assert bool(ax.texts) == expected

    @pytest.mark.parametrize(
        "test_input,expected", [("r", "r"), ("#123456", "#123456")]
    )
    def test_label_colors(self, test_input, expected):
        """Labels should have color provided to them"""
        sm = StructureModel([("a", "b")])
        _, ax, _ = plot_structure(sm, show_labels=True, label_color=test_input)
        assert all(text.get_color() == expected for text in ax.texts)
import math

import numpy as np
import pytest

from causalnex.discretiser import Discretiser


class TestUniform:
    """Tests for the "uniform" discretisation method"""

    def test_fit_creates_exactly_uniform_splits_when_possible(self):
        """splits should be exactly uniform if possible"""

        arr = np.array(range(20))
        np.random.shuffle(arr)
        d = Discretiser(method="uniform", num_buckets=4)
        d.fit(arr)
        for n in range(2):
            assert 4 < (d.numeric_split_points[n + 1] - d.numeric_split_points[n]) <= 5

    def test_fit_creates_close_to_uniform_splits_when_uniform_not_possible(self):
        """splits should be close to uniform if uniform is not possible"""

        arr = np.array(range(9))
        np.random.shuffle(arr)
        d = Discretiser(method="uniform", num_buckets=4)
        d.fit(arr)

        assert len(d.numeric_split_points) == 3
        for n in range(2):
            assert 2 <= (d.numeric_split_points[n + 1] - d.numeric_split_points[n]) <= 3

    def test_fit_does_not_attempt_to_deal_with_identical_split_points(self):
        """if all data is identical, and num_buckets>1, then this is not possible.
        In this case the standard behaviour of numpy is followed, and many identical
        splits will be created. See transform for how these are applied"""

        arr = np.array([1 for _ in range(20)])
        d = Discretiser(method="uniform", num_buckets=4)
        d.fit(arr)
        assert np.array_equal(
            np.array([d.numeric_split_points[0] for _ in range(3)]),
            d.numeric_split_points,
        )

    def test_transform_uneven_split(self):
        """Data that cannot be split evenly between buckets should be transformed
        into near-even buckets"""

        arr = np.array([n + 1 for n in range(10)])
        np.random.shuffle(arr)
        d = Discretiser(method="uniform", num_buckets=4)
        d.fit(arr)
        unique, counts = np.unique(d.transform(arr), return_counts=True)
        # check all 4 buckets are used
        assert np.array_equal([0, 1, 2, 3], unique)
        # check largest difference in distribution is 1 item
        assert (np.max(counts) - np.min(counts)) <= 1

    def test_transform_larger_than_fit_range_goes_into_last_bucket(self):
        """If a value larger than the input is transformed, then it
        should go into the maximum bucket"""

        arr = np.array([n + 1 for n in range(10)])
        np.random.shuffle(arr)
        d = Discretiser(method="uniform", num_buckets=4)
        d.fit(arr)
        assert np.array_equal([3], d.transform(np.array([101])))

    def test_transform_smaller_than_fit_range_goes_into_first_bucket(self):
        """If a value smaller than the input is transformed, then it
        should go into the minimum bucket"""

        arr = np.array([n + 1 for n in range(10)])
        np.random.shuffle(arr)
        d = Discretiser(method="uniform", num_buckets=4)
        d.fit(arr)
        assert np.array_equal([0], d.transform(np.array([-101])))

    def test_fit_transform(self):
        """fit transform should give the same result as calling fit and
        transform separately"""

        arr = np.array([n + 1 for n in range(10)])
        np.random.shuffle(arr)

        d1 = Discretiser(method="uniform", num_buckets=4)
        d1.fit(arr)
        r1 = d1.transform(arr)

        d2 = Discretiser(method="uniform", num_buckets=4)
        r2 = d2.fit_transform(arr)

        assert np.array_equal(r1, r2)


class TestQuantile:
    """Tests for the "quantile" discretisation method"""

    def test_fit_uniform_data(self):
        """Fitting uniform data should produce uniform splits"""

        arr = np.array(range(100001))
        np.random.shuffle(arr)
        d = Discretiser(method="quantile", num_buckets=4)
        d.fit(arr)
        assert np.array_equal([25000, 50000, 75000], d.numeric_split_points)

    def test_fit_gauss_data(self):
        """Fitting gauss data should produce standard percentiles splits"""

        arr = np.random.normal(loc=0, scale=1, size=100001)
        np.random.shuffle(arr)
        d = Discretiser(method="quantile", num_buckets=4)
        d.fit(arr)
        # quartiles of the standard normal are approximately -0.675, 0 and 0.675
        assert math.isclose(-0.675, d.numeric_split_points[0], abs_tol=0.025)
        assert math.isclose(0, d.numeric_split_points[1], abs_tol=0.025)
        assert math.isclose(0.675, d.numeric_split_points[2], abs_tol=0.025)

    def test_transform_gauss(self):
        """Fitting gauss data should transform to predictable buckets"""

        arr = np.random.normal(loc=0, scale=1, size=1000000)
        np.random.shuffle(arr)
        d = Discretiser(method="quantile", num_buckets=4)
        d.fit(arr)
        unique, counts = np.unique(d.transform(arr), return_counts=True)
        # check all 4 buckets are used
        assert np.array_equal([0, 1, 2, 3], unique)
        # each quartile bucket should hold exactly a quarter of the samples
        assert np.array_equal([250000] * 4, counts)

    def test_fit_transform(self):
        """fit transform should give the same result as calling fit and
        transform separately"""

        arr = np.array([n + 1 for n in range(10)])
        np.random.shuffle(arr)

        d1 = Discretiser(method="quantile", num_buckets=4)
        d1.fit(arr)
        r1 = d1.transform(arr)

        d2 = Discretiser(method="quantile", num_buckets=4)
        r2 = d2.fit_transform(arr)

        assert np.array_equal(r1, r2)


class TestOutlier:
    """Tests for the "outlier" discretisation method"""

    def test_outlier_percentile_lower_boundary(self):
        """Discretiser should accept lower boundary down to zero"""

        Discretiser(method="outlier", outlier_percentile=0.0)
        Discretiser(method="outlier", outlier_percentile=-0.0)
        with pytest.raises(ValueError):
            Discretiser(method="outlier", outlier_percentile=-0.1)

    def test_outlier_percentile_upper_boundary(self):
        """Discretiser should accept upper boundary below half and reject half"""

        Discretiser(method="outlier", outlier_percentile=0.49)
        with pytest.raises(ValueError):
            Discretiser(method="outlier", outlier_percentile=0.5)

    def test_outlier_lower_percentile(self):
        """the split point for lower outliers should be at provided percentile"""

        arr = np.array(range(100001))
        np.random.shuffle(arr)
        d = Discretiser(method="outlier", outlier_percentile=0.2)
        d.fit(arr)
        assert d.numeric_split_points[0] == 20000

    def test_outlier_upper_percentile(self):
        """the split point for upper outliers should be at range - provided percentile"""

        arr = np.array(range(100001))
        np.random.shuffle(arr)
        d = Discretiser(method="outlier", outlier_percentile=0.2)
        d.fit(arr)
        assert d.numeric_split_points[1] == 80000

    def test_transform_outlier(self):
        """transforming outliers should put the expected amount of data in each bucket"""

        arr = np.array(range(100001))
        np.random.shuffle(arr)
        d = Discretiser(method="outlier", outlier_percentile=0.2)
        d.fit(arr)
        unique, counts = np.unique(d.transform(arr), return_counts=True)
        # check all 3 buckets are used
        assert np.array_equal([0, 1, 2], unique)
        # check largest difference in outliers is 1
        assert np.abs(counts[0] - counts[2]) <= 1

    def test_fit_transform(self):
        """fit transform should give the same result as calling fit and
        transform separately"""

        arr = np.array([n + 1 for n in range(10)])
        np.random.shuffle(arr)

        d1 = Discretiser(method="outlier", outlier_percentile=0.2)
        d1.fit(arr)
        r1 = d1.transform(arr)

        d2 = Discretiser(method="outlier", outlier_percentile=0.2)
        r2 = d2.fit_transform(arr)

        assert np.array_equal(r1, r2)


class TestFixed:
    """Tests for the "fixed" discretisation method"""

    def test_fit_raises_error(self):
        """since numeric split points are provided, calling fit should raise a RuntimeError"""

        d = Discretiser(method="fixed", numeric_split_points=[1])
        with pytest.raises(RuntimeError):
            d.fit(np.array([1]))

    def test_fit_transform_raises_error(self):
        """since numeric split points are provided, calling fit_transform should
        raise a RuntimeError"""

        d = Discretiser(method="fixed", numeric_split_points=[1])
        with pytest.raises(RuntimeError):
            d.fit_transform(np.array([1]))

    def test_transform_splits_using_defined_split_points(self):
        """transforming should be done using the provided numeric split points"""

        d = Discretiser(method="fixed", numeric_split_points=[10, 20, 30])
        transformed = d.transform(np.array([9, 10, 11, 19, 20, 21, 29, 30, 31]))
        # a value equal to a split point falls into the bucket above it
        assert np.array_equal(transformed, [0, 1, 1, 1, 2, 2, 2, 3, 3])


class TestErrorHandling:
    """Tests for argument validation in the Discretiser constructor"""

    def test_invalid_method(self):
        """a value error should be raised if an invalid method is given"""

        allowed_methods = ["uniform", "quantile", "outlier", "fixed", "percentiles"]
        selected_method = "INVALID"
        with pytest.raises(
            ValueError,
            match="{0} is not a recognised method. Use one of: {1}".format(
                selected_method, " ".join(allowed_methods)
            ),
        ):
            Discretiser(method=selected_method)

    def test_uniform_requires_num_buckets(self):
        """a value error should be raised if method=uniform and num_buckets is not provided"""

        selected_method = "uniform"
        with pytest.raises(
            ValueError,
            match="{0} method expects {1}".format(selected_method, "num_buckets"),
        ):
            Discretiser(method=selected_method)

    def test_quantile_requires_num_buckets(self):
        """a value error should be raised if method=quantile and num_buckets is not provided"""

        selected_method = "quantile"
        with pytest.raises(
            ValueError,
            match="{0} method expects {1}".format(selected_method, "num_buckets"),
        ):
            Discretiser(method=selected_method)

    def test_outlier_requires_outlier_percentile(self):
        """a value error should be raised if method=outlier and outlier_percentile is not provided"""

        selected_method = "outlier"
        with pytest.raises(
            ValueError,
            match="{0} method expects {1}".format(
                selected_method, "outlier_percentile"
            ),
        ):
            Discretiser(method=selected_method)

    def test_outlier_geq_zero(self):
        """a value error should be raised if outlier is not >= 0"""

        Discretiser(method="outlier", outlier_percentile=0.0)
        Discretiser(method="outlier", outlier_percentile=-0.0)
        Discretiser(method="outlier", outlier_percentile=0.1)
        with pytest.raises(
            ValueError,
            match="{0} must be between 0 and 0.5".format("outlier_percentile"),
        ):
            Discretiser(method="outlier", outlier_percentile=-0.0000001)

    def test_outlier_lt_half(self):
        """a value error should be raised if outlier is not < 0.5"""

        Discretiser(method="outlier", outlier_percentile=0.49)
        with pytest.raises(
            ValueError,
            match="{0} must be between 0 and 0.5".format("outlier_percentile"),
        ):
            Discretiser(method="outlier", outlier_percentile=0.5)

    def test_fixed_split_points(self):
        """a value error should be raised if method=fixed and no numeric split points are provided"""

        selected_method = "fixed"
        with pytest.raises(
            ValueError,
            match="{0} method expects {1}".format(
                selected_method, "numeric_split_points"
            ),
        ):
            Discretiser(method=selected_method)

    def test_fixed_split_points_monotonic(self):
        """a value error should be raised if numeric split points are not monotonically increasing"""

        Discretiser(method="fixed", numeric_split_points=[-1, -0, 0, 1])
        with pytest.raises(
            ValueError,
            match="{0} must be monotonically increasing".format("numeric_split_points"),
        ):
            Discretiser(method="fixed", numeric_split_points=[1, -1])

    def test_percentile_requires_percentile_split_points(self):
        """a value error should be raised if method=percentiles and no percentile split points are provided"""

        selected_method = "percentiles"
        with pytest.raises(
            ValueError,
            match="{0} method expects {1}".format(
                selected_method, "percentile_split_points"
            ),
        ):
            Discretiser(method=selected_method)

    def test_percentile_geq_zero(self):
        """a value error should be raised if not all percentiles split points >= 0"""

        Discretiser(method="percentiles", percentile_split_points=[-0.0, 0.0, 0.0001])
        with pytest.raises(
            ValueError,
            match="{0} must be between 0 and 1".format("percentile_split_points"),
        ):
            Discretiser(
                method="percentiles", percentile_split_points=[-0.0000001, 0.0001]
            )

    def test_percentile_leq_1(self):
        """a value error should be raised if not all percentile split points <= 1"""

        Discretiser(method="percentiles", percentile_split_points=[0.0001, 1])
        with pytest.raises(
            ValueError,
            match="{0} must be between 0 and 1".format("percentile_split_points"),
        ):
            Discretiser(
                method="percentiles", percentile_split_points=[0.0001, 1.0000001]
            )

    def test_percentile_split_points_monotonic(self):
        """a value error should be raised if percentile split points are not monotonically increasing"""

        Discretiser(method="percentiles", percentile_split_points=[0, -0, 0.1, 1])
        with pytest.raises(
            ValueError,
            match="{0} must be monotonically increasing".format(
                "percentile_split_points"
            ),
        ):
            Discretiser(method="percentiles", percentile_split_points=[1, 0.1])


class TestPercentile:
    """Tests for the "percentiles" discretisation method"""

    def test_fit_uniform_data(self):
        """Fitting uniform data should produce expected percentile splits of uniform distribution"""

        arr = np.array(range(100001))
        np.random.shuffle(arr)
        d = Discretiser(method="percentiles", percentile_split_points=[0.1, 0.4, 0.85])
        d.fit(arr)
        assert np.array_equal([10000, 40000, 85000], d.numeric_split_points)

    def test_fit_gauss_data(self):
        """Fitting gauss data should produce percentile splits of standard normal distribution"""

        arr = np.random.normal(loc=0, scale=1, size=100001)
        np.random.shuffle(arr)
        d = Discretiser(method="percentiles", percentile_split_points=[0.1, 0.4, 0.85])
        d.fit(arr)
        # expected values are the 10th, 40th and 85th percentiles of N(0, 1)
        assert math.isclose(-1.2815, d.numeric_split_points[0], abs_tol=0.025)
        assert math.isclose(-0.253, d.numeric_split_points[1], abs_tol=0.025)
        assert math.isclose(1.036, d.numeric_split_points[2], abs_tol=0.025)

    def test_transform_uniform(self):
        """Fitting uniform data should transform to predictable buckets"""

        arr = np.array(range(100001))
        np.random.shuffle(arr)
        d = Discretiser(
            method="percentiles", percentile_split_points=[0.10, 0.40, 0.85]
        )
        d.fit(arr)
        unique, counts = np.unique(d.transform(arr), return_counts=True)
        # check all 4 buckets are used
        assert np.array_equal([0, 1, 2, 3], unique)
        assert np.array_equal([10000, 30000, 45000, 15001], counts)

    def test_fit_transform(self):
        """fit transform should give the same result as calling fit and
        transform separately"""

        arr = np.array([n + 1 for n in range(10)])
        np.random.shuffle(arr)

        d1 = Discretiser(
            method="percentiles", percentile_split_points=[0.10, 0.40, 0.85]
        )
        d1.fit(arr)
        r1 = d1.transform(arr)

        d2 = Discretiser(
            method="percentiles", percentile_split_points=[0.10, 0.40, 0.85]
        )
        r2 = d2.fit_transform(arr)

        assert np.array_equal(r1, r2)
# Create a GitHub release for a given version tag.
#
# Usage: github_release.sh <github_user> <github_repo> <github_token> <version>
#
# Exits 0 when the API answers 201 or 422, non-zero otherwise.
# NOTE(review): 422 is presumably tolerated because the release/tag already
# exists - confirm against the API in use.

if [ "$#" -lt 4 ]; then
    echo "Usage: $0 <github_user> <github_repo> <github_token> <version>" >&2
    exit 1
fi

GITHUB_USER=$1
GITHUB_REPO=$2
GITHUB_TOKEN=$3
VERSION=$4

# NOTE(review): this host/path is not the public api.github.com endpoint;
# left unchanged - confirm it is the intended API gateway.
GITHUB_ENDPOINT="https://github.com/gitapi/repos/${GITHUB_USER}/${GITHUB_REPO}/releases"

# JSON body of the "create release" request; the tag and the title are both
# the bare version string.
PAYLOAD=$(cat <<-END
{
    "tag_name": "${VERSION}",
    "target_commitish": "master",
    "name": "${VERSION}",
    "body": "Release ${VERSION}",
    "draft": false,
    "prerelease": false
}
END
)

# Only the HTTP status code is kept; the response body is discarded.
# The endpoint is quoted so unusual user/repo names cannot be word-split.
STATUS=$(curl -o /dev/null -L -s -w "%{http_code}\n" -X POST -H "Authorization: token ${GITHUB_TOKEN}" \
    -H "Content-Type: application/json" "${GITHUB_ENDPOINT}" -d "${PAYLOAD}")

# The script's exit status is the result of this final test.
[ "${STATUS}" == "201" ] || [ "${STATUS}" == "422" ]
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. +# +# The QuantumBlack Visual Analytics Limited ("QuantumBlack") name and logo +# (either separately or in combination, "QuantumBlack Trademarks") are +# trademarks of QuantumBlack. The License does not grant you any right or +# license to the QuantumBlack Trademarks. You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +import glob + +PATHS_REQUIRING_HEADER = ["causalnex", "tests"] +LEGAL_HEADER_FILE = "legal_header.txt" +LICENSE_MD = "LICENSE.md" + +RED_COLOR = "\033[0;31m" +NO_COLOR = "\033[0m" + +LICENSE = """Copyright 2019-2020 QuantumBlack Visual Analytics Limited + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +NONINFRINGEMENT. 
def files_at_path(path: str):
    """Return every ``.py`` file found recursively under ``path``.

    Files whose path contains ``ebaybbn`` or ``structure/notears.py`` are
    left out of the result.

    Args:
        path: directory to search.

    Returns:
        List of matching file paths, as produced by ``glob.glob``.
    """
    return [
        fn
        for fn in glob.glob(path + "/**/*.py", recursive=True)
        if not ("ebaybbn" in fn or "structure/notears.py" in fn)
    ]


def files_missing_substring(file_names, substring):
    """Yield the names of non-empty files that do not contain ``substring``.

    Args:
        file_names: iterable of paths; each file is read as UTF-8.
        substring: text every non-empty file is expected to contain.

    Yields:
        Each file name whose content is not just whitespace and lacks
        ``substring``.

    Raises:
        IOError: if one of the files cannot be opened.
    """
    for file_name in file_names:
        with open(file_name, "r", encoding="utf-8") as current_file:
            content = current_file.read()

        # whitespace-only files are treated as empty and never reported
        if content.strip() and substring not in content:
            yield file_name


def main():
    """Check the legal header of all source files and the LICENSE.md contents.

    Exits with status 1 when a checked file lacks the legal header, or when
    LICENSE.md is missing or does not contain the expected license text;
    exits with status 0 otherwise.
    """
    exit_code = 0

    # read with the same encoding used for the files the header is compared to
    with open(LEGAL_HEADER_FILE, encoding="utf-8") as header_f:
        header = header_f.read()

    # find all .py files recursively
    files = [
        new_file for path in PATHS_REQUIRING_HEADER for new_file in files_at_path(path)
    ]

    # find all files which do not contain the header and are non-empty
    files_with_missing_header = list(files_missing_substring(files, header))

    # print all files without header in red, if any, and remember the failure
    if files_with_missing_header:
        print(
            RED_COLOR
            + "The legal header is missing from the following files:\n- "
            + "\n- ".join(files_with_missing_header)
            + NO_COLOR
            + "\nPlease add it by copy-pasting the below:\n\n"
            + header
            + "\n"
        )
        exit_code = 1

    # check the LICENSE.md exists and has the right contents
    try:
        files = list(files_missing_substring([LICENSE_MD], LICENSE))
        if files:
            print(
                RED_COLOR
                + "Please make sure the LICENSE.md file "
                + "at the root of the project "
                + "has the right contents."
                + NO_COLOR
            )
            # SystemExit instead of the site-provided ``exit`` helper, which
            # is meant for the interactive shell and may be unavailable
            raise SystemExit(1)
    except IOError:
        # LICENSE.md could not be opened
        print(
            RED_COLOR + "Please add the LICENSE.md file at the root of the project "
            "with the appropriate contents." + NO_COLOR
        )
        raise SystemExit(1)

    raise SystemExit(exit_code)


if __name__ == "__main__":
    main()
You may not use the QuantumBlack +# Trademarks or any confusingly similar mark as a trademark for your product, +# or use the QuantumBlack Trademarks in any other manner that might cause +# confusion in the marketplace, including but not limited to in advertising, +# on websites, or on software. +# +# See the License for the specific language governing permissions and +# limitations under the License. + +import platform +import shlex +import subprocess +import sys + +if __name__ == "__main__": + required_version = tuple(int(x) for x in sys.argv[1].strip().split(".")) + install_cmd = shlex.split(sys.argv[2]) + run_cmd = shlex.split(sys.argv[3]) + + current_version = tuple(map(int, platform.python_version_tuple()[:2])) + + if current_version < required_version: + print("Python version is too low, exiting") + sys.exit(0) + + try: + subprocess.run(run_cmd, check=True) + except FileNotFoundError: + subprocess.run(install_cmd, check=True) + subprocess.run(run_cmd, check=True) diff --git a/tools/python_version.sh b/tools/python_version.sh new file mode 100755 index 0000000..f4d929e --- /dev/null +++ b/tools/python_version.sh @@ -0,0 +1,41 @@ +#!/usr/bin/env bash + +# Copyright 2019-2020 QuantumBlack Visual Analytics Limited +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES +# OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, AND +# NONINFRINGEMENT. IN NO EVENT WILL THE LICENSOR OR OTHER CONTRIBUTORS +# BE LIABLE FOR ANY CLAIM, DAMAGES, OR OTHER LIABILITY, WHETHER IN AN +# ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF, OR IN +# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
# Print the version declared by `__version__ = "..."` in a package's
# __init__.py; exit 1 when no matching line is found.
#
# Usage: python_version.sh <package_dir>

PACKAGE_DIR=$1

if [ -z "${PACKAGE_DIR}" ]; then
    echo "Usage: $0 <package_dir>" >&2
    exit 1
fi

# Select the `__version__ = "X.Y"`, `"X.Y.Z"` or `"X.YrcN"` assignment line.
# The path is quoted so package directories containing spaces still work.
LINE=$(perl -ne "print if /^__version__\s+=\s+\"(\d+\.\d+(\.\d+|(rc\d+)*))\"$/" \
    "${PACKAGE_DIR}/__init__.py" | (head -n1 && tail -n1))

if [ -z "${LINE}" ]; then
    exit 1
else
    # Strip everything but the quoted version number.
    # ${LINE} and ${VERSION} are left unquoted on purpose so the output shape
    # (whitespace collapsed by word splitting) stays as before.
    VERSION=$(echo ${LINE} | perl -p -e "s/__version__\s+=\s+\"(\d+\.\d+(\.\d+|(rc\d+)*))\"/\1/g")
    echo ${VERSION}
fi
# Rewrite `__version__` in a package's __init__.py to a development version:
# the last numeric component is incremented and `.dev<N>` is appended, where
# <N> is the number of commits between the develop/master merge base and HEAD.
#
# Usage: python_version_dev_bump.sh <package_dir>
# Exits 1 when the merge base, the commit count or the version line is missing.

PACKAGE_DIR=$1

if [ -z "${PACKAGE_DIR}" ]; then
    echo "Usage: $0 <package_dir>" >&2
    exit 1
fi

LCA=$(git merge-base origin/develop origin/master)

# Without a merge base the range below would become "..HEAD", which counts
# zero commits and would silently stamp ".dev0"; fail instead.
if [ -z "${LCA}" ]; then
    exit 1
fi

CNT_LCA=$(git rev-list --count "${LCA}..HEAD")

# Select the `__version__ = "X.Y"`, `"X.Y.Z"` or `"X.YrcN"` assignment line.
LINE=$(perl -ne "print if /^__version__\s+=\s+\"(\d+\.\d+(\.\d+|(rc\d+)*))\"$/" \
    "${PACKAGE_DIR}/__init__.py" | (head -n1 && tail -n1))

if [ -n "${LINE}" ] && [ -n "${CNT_LCA}" ]; then
    # In-place edit: increment the number following the last "." or "rc" and
    # insert ".dev<CNT_LCA>" before the rest of the line. The single quotes
    # are closed around ${CNT_LCA} so the shell expands it.
    perl -pi -e 's/(__version__.*(\.|rc))(\d+)(.+)/$1.($3 + 1)."'".dev${CNT_LCA}"'".$4/ge' "${PACKAGE_DIR}/__init__.py"
else
    exit 1
fi