Skip to content

Commit

Permalink
Remove existing causal mechanisms when creating GCM
Browse files Browse the repository at this point in the history
Before, when a causal graph had causal mechanisms assigned, they were also used when creating a new GCM object based on it. Now, they are removed (from a copied version of the graph).

Signed-off-by: Patrick Bloebaum <bloebp@amazon.com>
  • Loading branch information
bloebp committed Jun 18, 2024
1 parent 512d1b0 commit 48dcaeb
Showing 1 changed file with 17 additions and 1 deletion.
18 changes: 17 additions & 1 deletion dowhy/gcm/causal_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""This module defines the fundamental classes for graphical causal models (GCMs)."""

from copy import deepcopy
from typing import Any, Callable, Optional, Union

import networkx as nx
Expand Down Expand Up @@ -32,12 +33,17 @@ class ProbabilisticCausalModel:
causal mechanisms can be any general stochastic models."""

def __init__(
self, graph: Optional[DirectedGraph] = None, graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph
self,
graph: Optional[DirectedGraph] = None,
graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph,
remove_existing_mechanisms: bool = False,
):
"""
:param graph: Optional graph object to be used as causal graph.
:param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph
constructor.
:param remove_existing_mechanisms: If True, removes existing causal mechanisms assigned to nodes if they exist.
Otherwise, does not modify graph.
"""
# Todo: Remove after https://github.com/py-why/dowhy/pull/943.
from dowhy.causal_graph import CausalGraph
Expand All @@ -50,6 +56,16 @@ def __init__(
elif isinstance(graph, CausalGraph):
graph = graph_copier(graph._graph)

if remove_existing_mechanisms:
for node in graph.nodes:
if CAUSAL_MECHANISM in graph.nodes[node]:
del graph.nodes[node][CAUSAL_MECHANISM]

# Create deep copies to avoid referencing the original one.
for node in graph.nodes:
if CAUSAL_MECHANISM in graph.nodes[node]:
graph.nodes[node][CAUSAL_MECHANISM] = deepcopy(graph.nodes[node][CAUSAL_MECHANISM])

self.graph = graph
self.graph_copier = graph_copier

Expand Down

0 comments on commit 48dcaeb

Please sign in to comment.