Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow customizing style of model_graph nodes #7302

Merged
merged 8 commits into from
May 8, 2024

Conversation

wd60622
Copy link
Contributor

@wd60622 wd60622 commented May 7, 2024

Description

Allow users to override the default graphviz by passing mapping from node type to function that creates the node kwargs.

The default behavior is the same but now the user can override based on the node type defined:

  • Data
  • Free Random Variable
  • Observed Random Variable
  • Deterministic
  • Potential

User needs to define function from node variable to kwargs passed to graphviz / networkx

def node_formatter(var: TensorVariable) -> dict[str, Any]:
    return {"label": var.name}

then pass a mapping from node type to node_formatter.

node_mapping = {"Data": node_formatter}
pm.model_to_graphviz(model, node_formatters=node_mapping)

Example

import pymc as pm 

with pm.Model() as model: 
    a = pm.Normal("a")
    b = pm.Normal("b", mu=a)
    c = pm.Normal("c", mu=a)
    d = pm.Normal("d", mu=b + c, observed=0)

default = pm.model_to_graphviz(model)

simple_random_variable_formatter = {
    "Free Random Variable": lambda var: {"shape": "circle", "label": var.name}, 
}
pm.model_to_graphviz(model, node_formatters=simple_random_variable_formatter)

fancy_formatter = {
    "Free Random Variable": lambda var: {"shape": "polygon", "sides": "7", "label": var.name, "style": "dashed"}, 
    "Observed Random Variable": lambda var: {"shape": "circle", "label": var.name, "style": "solid"},
}
pm.model_to_graphviz(model, node_formatters=fancy_formatter)

Default:
default

Simple:
simple

Fancy:

polygon

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7302.org.readthedocs.build/en/7302/

@wd60622
Copy link
Contributor Author

wd60622 commented May 7, 2024

@drbenvincent might be interested in this example

with pm.Model() as model: 
    c = pm.Normal("c")
    z = pm.Normal("z", mu=c)
    y = pm.Normal("y", mu=c + z)

node_formatters = {
    "Basic Random Variable": lambda var: {"shape": "circle", "label": var.name},
}
model_graph.model_to_graphviz(model, node_formatters=node_formatters)

simple

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks sweet. Left some minor suggestions.

pymc/model_graph.py Outdated Show resolved Hide resolved
pymc/model_graph.py Outdated Show resolved Hide resolved
def get_node_type(var_name: VarName, model) -> NodeType:
v = model[var_name]

if v in model.potentials:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small optimization, most models don't have potentials, and most have many free_RVs, and deterministics, so I would suggest changing the order of this to be something like if deterministics elif free rvs elif observed rvs elif data elif potentials instead

@ricardoV94 ricardoV94 changed the title Customize Graph Nodes Allow customizing style of model_graph nodes May 7, 2024
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
@wd60622
Copy link
Contributor Author

wd60622 commented May 7, 2024

Thank you for the review @ricardoV94
Will add these changes and add some tests

Comment on lines 456 to 466
@pytest.mark.xfail(reason="Graphviz is not deterministic")
def test_custom_node_formatting_graphviz(simple_model):
node_formatters = {
"Free Random Variable": lambda var: {
"label": var.name,
},
}

G = model_to_graphviz(simple_model, node_formatters=node_formatters)
assert G.source == (
"digraph {\n\ta [label=a]\n\tb [label=b]" "\n\tc [label=c]\n\ta -> b\n\tb -> c\n}\n"
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any thoughts on to tests while accounting for graphviz putting the nodes in random order in the string? @ricardoV94

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think it might have something to do with our side actual and the set ordering. graphviz python source code for reference

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but not 100% sure about that

@wd60622
Copy link
Contributor Author

wd60622 commented May 8, 2024

Some random / unrelated tests seem to be randomly failing. Latest run passed

@ricardoV94 ricardoV94 merged commit 82eae9a into pymc-devs:main May 8, 2024
20 checks passed
@wd60622 wd60622 deleted the format-graph-nodes branch June 26, 2024 08:34
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

2 participants