-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Allow customizing style of model_graph nodes #7302
Conversation
@drbenvincent might be interested in this example with pm.Model() as model:
c = pm.Normal("c")
z = pm.Normal("z", mu=c)
y = pm.Normal("y", mu=c + z)
node_formatters = {
"Basic Random Variable": lambda var: {"shape": "circle", "label": var.name},
}
model_graph.model_to_graphviz(model, node_formatters=node_formatters) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks sweet. Left some minor suggestions.
pymc/model_graph.py
Outdated
def get_node_type(var_name: VarName, model) -> NodeType: | ||
v = model[var_name] | ||
|
||
if v in model.potentials: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Small optimization, most models don't have potentials, and most have many free_RVs, and deterministics, so I would suggest changing the order of this to be something like if deterministics elif free rvs elif observed rvs elif data elif potentials
instead
Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com>
Thank you for the review @ricardoV94 |
tests/test_model_graph.py
Outdated
@pytest.mark.xfail(reason="Graphviz is not deterministic") | ||
def test_custom_node_formatting_graphviz(simple_model): | ||
node_formatters = { | ||
"Free Random Variable": lambda var: { | ||
"label": var.name, | ||
}, | ||
} | ||
|
||
G = model_to_graphviz(simple_model, node_formatters=node_formatters) | ||
assert G.source == ( | ||
"digraph {\n\ta [label=a]\n\tb [label=b]" "\n\tc [label=c]\n\ta -> b\n\tb -> c\n}\n" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Any thoughts on to tests while accounting for graphviz putting the nodes in random order in the string? @ricardoV94
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Think it might have something to do with our side actual and the set ordering. graphviz python source code for reference
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
but not 100% sure about that
Some random / unrelated tests seem to be randomly failing. Latest run passed |
Description
Allow users to override the default graphviz by passing mapping from node type to function that creates the node kwargs.
The default behavior is the same but now the user can override based on the node type defined:
User needs to define function from node variable to kwargs passed to graphviz / networkx
then pass a mapping from node type to node_formatter.
Example
Default:
Simple:
Fancy:
Related Issue
Checklist
Type of change
📚 Documentation preview 📚: https://pymc--7302.org.readthedocs.build/en/7302/