Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor model graph and allow suppressing dim lengths #7392

Merged
merged 17 commits into from
Jul 3, 2024

Conversation

wd60622
Copy link
Contributor

@wd60622 wd60622 commented Jun 26, 2024

Description

Pulled any graph related information into two methods:

  1. get_plates: Get plate meta information and the nodes that are associated with each plate
  2. edges: Edges between nodes as a list[tuple[VarName, VarName]]

The get_plates methods returns a list of Plate objects which store all the variable information. That data include:

  1. DimInfo with stores the dim names and lengths
  2. NodeInfo which stores the model variable and it's NodeType in the graph (introduced in Allow customizing style of model_graph nodes #7302)
  3. Plate which is a collection of the DimInfo and list[NodeInfo]

With list[tuple[VarName, VarName]] and list[Plate], a user can now make use of the exposed make_graph and make_networkx functions to create customized graphviz or networkx graphs.

The previous behavior of model_to_graphviz and model_to_networkx is still maintained. However, there is a new include_dim_lengths parameter that can be used to include the dim lengths in the plate labels.

The previous issue #6335 behavior has changed to now include all the variables on a plate with dlen instead of var_name_dim{d}. (See examples below)

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7392.org.readthedocs.build/en/7392/

pymc/model_graph.py Outdated Show resolved Hide resolved
# parents is a set of rv names that precede child rv nodes
for parent in parents:
yield child.replace(":", "&"), parent.replace(":", "&")

def make_graph(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should make_graph and make_networkx now be functions that take plates and edges as inputs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would just remove calling get_plates and edges methods. Don't have much of a preference

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would make it more modular, in that if you find a way to create your own plates and edges, you can just pass it to the functions that then display it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, sure. I think that makes sense then.

The dictionary of {PlateMeta : set[NodeMeta]} is a bit weird and hard to work with. i.e. set is not subscritable and looking up by PlateMeta key is a bit tricky.

I was thinking of having another object, Plate which would be:

@dataclass 
class Plate: 
    plate_meta: PlateMeta
    nodes: list[NodeMeta]

and that would be in the input to make_graph and make_networkx instead. Making the signature: (plates: list[Plate], edges: list[tuple[str, str]], ...)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, does it make sense as a method still? Do you see model_to_graphviz taking this input as well?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lost track of the specific methods we're discussing. My low resolution guess was that once we have the plates / edges we can just pass them to a function that uses those to render graphviz or networkx graphs. Let me know if you were asking about something else or see a problem (or no point) with that

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sounds good. Let me push something up and you can give feedback

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed.
If user has arbitrary list[Plate] and list[tuple[VarName, VarName]] then they can use make_graph or make_networkx in order to make the graphviz or networkx, respectively.
pm.model_to_graphviz and pm.model_to_networkx are still wrappers.
ModelGraph class can be used to create the plates and edges in the previous manner if desired with the get_plates and edges methods

pymc/model_graph.py Outdated Show resolved Hide resolved
# parents is a set of rv names that precede child rv nodes
for parent in parents:
yield child.replace(":", "&"), parent.replace(":", "&")

def make_graph(
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That would just remove calling get_plates and edges methods. Don't have much of a preference

# must be preceded by 'cluster' to get a box around it
plate_label = create_plate_label(plate_meta, include_size=include_shape_size)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticing that the plate_label actually depends on var_name in case of previous "{var_name}_dim{d}". However, the plate_label is required before the looping of all_var_names. i.e. all_var_names is assumed to be one element? Maybe that should be an explicit case?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Didn't manage to follow, can you explain again?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The graph.subgraph name of "cluster" + plate_label is dependent on the var_name which used to be constructed in the get_plates method (the previous keys of dictionary where the plate_label).

However, after the subgraph is constructed, the all_var_names is looped over. This is assuming that all_var_names is only one element since the plate_label is used in the subgraph name.

for plate in self.get_plates(var_names):
plate_meta = plate.meta
all_vars = plate.variables
if plate_meta.names or plate_meta.sizes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we simplify? Could plate_meta be None for the scalar variables?

Copy link
Contributor Author

@wd60622 wd60622 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic would still be needed somewhere. Likely in get_plates then.
How about having the __bool__ method for Plate class that does this logic.
Then would act like None and read like:

if plate_meta: # Truthy if sizes or names
    # plate_meta has sizes or names that are not empty tuples

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You have that information when you defined the plate.meta no? Can't you do it immediately?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. I changed to have it happen in the get_plates methods. Scalars will have Plate(meta=None, variables=[...])

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's enough to check for sizes? It is not possible for a plate to have names, but not sizes?

We should rename those to dim_names, and dim_lengths. And perhaps use None for dim_lengths for which we don't know the name?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So IIUC, scalars should belong to a "Plate" with dim_names = (), and dim_lengths = ()?

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And now I understand your approach and I think it was better like you did. The __bool__ sounds fine as well!

Sorry I got confused by the names of the things

@@ -49,6 +49,9 @@ class PlateMeta:
def __hash__(self):
return hash((self.names, self.sizes))

def __bool__(self) -> bool:
return len(self.sizes) > 0 or len(self.names) > 0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is a plate without names?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

with pm.Model(): 
    pm.Normal("y", shape=3)

y has sizes but no dim names.
Currently creates Plate(meta=PlateMeta(names=(), sizes=(5, )), variables=[NodeMeta(var=y, node_type=...)])

Think there should be some cases to test now that this logic is exposed. Will be much easier to confirm

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

names here are dim names

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What happens for a Deterministic with dims=("test_dim", None)? Apparently we still allow None dims for things that are not RVs

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That y should be names=(None,) ?

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm thinking of pm.Deterministic("x", np.zeros((3, 3)), dims=("hello", None)) and pm.Deterministic("y", np.zeros((3, 3)), dims=(None, "hello"). We don't want to put those in the same plate because dims can't be repeated, so they are definitely different things?

Can we add a test for that?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a test. I had to wrap the data in as_tensor_variable or I'd get an error saying the data needs name attribute

Comment on lines 369 to 372
plate_meta = PlateMeta(
names=tuple(names),
sizes=tuple(sizes),
)
Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't understand this tbh. Are we creating one plate per variable? But a plate can contain multiple variables?

Also names is ambiguous, it is dim_names? We should name it like that to distinguish from var_names?
Also sizes -> dim_lengths

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah plates are hashable... so you mutate the same thing...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just working with what was there. The historical { str: set[VarName] } is created with loop which I changed to { PlateMeta : set[NodeMeta] }
But switched to list[Plate] ultimately.
Ideally, there could be more straight-foward path to list[Plate]

Copy link
Member

@ricardoV94 ricardoV94 Jun 27, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic feels rather convoluted to be honest. Maybe we can take a step back and see what is actually needed.
Step1: Collect the dim names and dim lengths of every variable we want to plot. This seems simple enough, and we can do in a loop
Step2: Merge variables that have identical dim_names and dim_lengths into "plates". The hashable Plate thing may be a good trick to achieve that, or just a defaultdict with keys: tuple[dim_names, dim_lengths]

Would the code be more readable if we didn't try to do both things at once?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: Updated comment above

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

defaultdict with dims_names and dim_lengths is same as what is currently happening. But there is a wrapper class around it. Personally, I find the class helpful and more user friendly. But I could be wrong

For instance,

Plate(
    DimInfo(names=("obs", "covariate"), sizes=(10, 5)), 
    variables=[
        NodeInfo(X, node_type=DATA), 
        NodeInfo(X_transform, node_type=DETERMINISTIC), 
        NodeInfo(tvp, node_type=FREE_RV),
    ]
)

over

(("obs", "covariate"), (10, 5), (X, X_transform, tvp), (DATA, DETERMINSTIC, FREE_RV))

lines up a bit better in my mind that the first two are related objects and the last two are related objects as well

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, just thinking of how easy we make it for users to define their custom stuff. Either way seems manageable

@wd60622
Copy link
Contributor Author

wd60622 commented Jun 27, 2024

The name Plate and PlateMeta come from the historical get_plates method of ModelGraph. However, get_plates also get scalars which were "" before and now Plate(meta=None, variables=[...])

Is Plate still a good name? It is a collection of variables all with the same dims. Plate in my mind is Bayesian graphical model and might deviate with the scalars.
PlateMeta might be more suited as DimsMeta since the names and sizes are the dims of the variables

Any thoughts here on terminology?

@ricardoV94
Copy link
Member

I'm okay with Plate or Cluster. Why the Meta in it?

pymc/model_graph.py Outdated Show resolved Hide resolved
@wd60622
Copy link
Contributor Author

wd60622 commented Jun 27, 2024

I'm okay with Plate or Cluster. Why the Meta in it?

Meta would be information about the variables / plate to construct a plate label. Previously it was always " x ".join([f"{dname} ({dlen})" for ...]
Meta just provides the parts to construct based on components presented before

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 27, 2024

I don't love the word meta, it's too abstract. Plate.dim_names, Plate.dim_lengths, Plate.vars? or Plate.var_names if that's what we are storing

@wd60622
Copy link
Contributor Author

wd60622 commented Jun 27, 2024

I don't love the word meta, it's too abstract. Plate.dim_names, Plate.dim_lengths, Plate.vars? or Plate.var_names if that's what we are storing

I think itd be nice to keep the names and sizes together since they are related. How about DimInfo

@ricardoV94
Copy link
Member

Is the question whether we represent a data structure that looks like (in terms of access): ((dims_names, dim_lengths), var_names) vs (dim_names, dim_lengths, var_names)? Seems like a tiny detail. I have a slight preference for having it flat but up to you

@ricardoV94
Copy link
Member

This PR refreshed my mind that #6485 and #7048 exist.

To summarize: We can have variables that have entries in named_vars_to_dims of type tuple[str | None, ...]. We can also have variables that don't show up in named_vars_to_dims at all? Which is odd, since we already allow None to represent unknown dims, so all variables could conceivable have entries (or we would not allow None).

Then dims can have coords or not, but always have dim_lengths, which always work when we do the fast_eval for dim_lengths, so that's not a problem that shows up here. I think that doesn't matter here for us. Just mentioning in case I brought it up by mistake in my comments.

Comment on lines 447 to 451
plate_label = create_plate_label(
plate.variables[0].var.name,
plate.meta,
include_size=include_shape_size,
)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should create_plate_label now take plate_formatters that among other things decides on whether to include_size?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I think that is fair. Where do you view that being exposed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I exposed the create_plate_label in both make_graph and make_networkx. However, left it out in the model_to_graphviz and model_to_networkx functions.

If a user defines Callable[[DimInfo], str] function, then that can be used in the first two, more general functions

pymc/model_graph.py Outdated Show resolved Hide resolved
@wd60622
Copy link
Contributor Author

wd60622 commented Jun 28, 2024

Is the question whether we represent a data structure that looks like (in terms of access): ((dims_names, dim_lengths), var_names) vs (dim_names, dim_lengths, var_names)? Seems like a tiny detail. I have a slight preference for having it flat but up to you

There is also the NodeType which is why I went for the small dataclass wrapper that contains TensorVariable and the preprocessed label. I think have a small data structure isn't the end of the world but also helps structure the problem a bit more. The user can clearly see what is part of the new data structures in my mind

pymc/model_graph.py Outdated Show resolved Hide resolved
@wd60622
Copy link
Contributor Author

wd60622 commented Jun 28, 2024

Need to

The 6335 comes up with this example:

# Current main branch
coords = {
    "obs": range(5),
}
with pm.Model(coords=coords) as model:
    data = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="data",
    )
    pm.Deterministic("C", data, dims=("obs", None))
    pm.Deterministic("D", data, dims=("obs", None))
    pm.Deterministic("E", data, dims=("obs", None))

pm.model_to_graphviz(model)

Result:
previous-with-none

Which makes sense that they will not be on the same plate, right?

@wd60622
Copy link
Contributor Author

wd60622 commented Jun 28, 2024

I did just catch this bug: It comes from the make_compute_graph which causes a self loop

from pymc.model_graph import ModelGraph

coords = {
    "obs": range(5),
}
with pm.Model(coords=coords) as model:
    data = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="C",
    )
    pm.Deterministic("C", data, dims=("obs", None))

error_compute_graph = ModelGraph(model).make_compute_graph() # defaultdict(set, {"C": {"C"}})
# Visualize error:
pm.model_to_graphviz(model)

Result:

compute-graph-bug

Shall I make a separate issue?

@ricardoV94
Copy link
Member

I think they should be in the same plate, because in the absense of dims, the shape is used to cluster RVs?

@ricardoV94
Copy link
Member

Self loop is beautiful :)

@wd60622
Copy link
Contributor Author

wd60622 commented Jun 29, 2024

I think they should be in the same plate, because in the absense of dims, the shape is used to cluster RVs?

How should the {var_name}_dim{d} be handled then to put them on the same plate?

Just "dim{d} ({dlen})"?

@ricardoV94
Copy link
Member

ricardoV94 commented Jun 29, 2024

Just the length? how does a plate without any dims look like?

I imagine the mix would be 50 x trial(30) or however the trial dim is usually displayed.

WDYT?

@wd60622
Copy link
Contributor Author

wd60622 commented Jun 30, 2024

Just the length? how does a plate without any dims look like?

I imagine the mix would be 50 x trial(30) or however the trial dim is usually displayed.

WDYT?

This mixing of dlen and "{dname} ({dlen})" is what I had in mind. That is the current behavior.

Here are some examples:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

coords = {
    "obs": range(5),
}
with pm.Model(coords=coords) as model:
    data = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="data",
    )
    C = pm.Deterministic("C", data, dims=("obs", None))
    D = pm.Deterministic("D", data, dims=("obs", None))
    E = pm.Deterministic("E", data, dims=("obs", None))

pm.model_to_graphviz(model)

same-plate

# Same as above
pm.model_to_graphviz(model, include_dim_lengths=False)

same-plate-without

And larger example with various items:

import numpy as np
import pymc as pm
import pytensor.tensor as pt

coords = {
    "obs": range(5),
    "covariates": ["X1", "X2", "X3"],
}
with pm.Model(coords=coords) as model: 
    data1 = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="data1",
    )
    data2 = pt.as_tensor_variable(
        np.ones((5, 3)),
        name="data2",
    )
    C = pm.Deterministic("C", data1, dims=("obs", None))
    CT = pm.Deterministic("CT", C.T, dims=(None, "obs"))
    D = pm.Deterministic("D", C @ CT, dims=("obs", "obs"))

    E = pm.Deterministic("E", data2, dims=("obs", None))
    beta = pm.Normal("beta", dims="covariates")
    pm.Deterministic("product", E[:, None, :] * beta[:, None], dims=("obs", None, "covariates"))

pm.model_to_graphviz(model)

larger-example

pymc/model_graph.py Outdated Show resolved Hide resolved
pymc/model_graph.py Outdated Show resolved Hide resolved
pymc/model_graph.py Outdated Show resolved Hide resolved
@ricardoV94 ricardoV94 marked this pull request as ready for review July 3, 2024 10:48
Copy link

codecov bot commented Jul 3, 2024

Codecov Report

Attention: Patch coverage is 76.66667% with 28 lines in your changes missing coverage. Please review.

Project coverage is 92.18%. Comparing base (7af0a87) to head (e30f6d9).
Report is 19 commits behind head on main.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #7392      +/-   ##
==========================================
- Coverage   92.19%   92.18%   -0.01%     
==========================================
  Files         103      103              
  Lines       17214    17249      +35     
==========================================
+ Hits        15870    15901      +31     
- Misses       1344     1348       +4     
Files Coverage Δ
pymc/model_graph.py 87.25% <76.66%> (-0.13%) ⬇️

... and 5 files with indirect coverage changes

@ricardoV94 ricardoV94 merged commit f719796 into pymc-devs:main Jul 3, 2024
22 checks passed
@ricardoV94 ricardoV94 changed the title Abstract Graph Iteration Refactor model graph and allow suppressing dim lengths Jul 3, 2024
@ricardoV94
Copy link
Member

Thanks @wd60622

@wd60622 wd60622 deleted the abstract-graph-iteration branch July 3, 2024 12:29
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

ENH: Customizable plate labels
2 participants