Skip to content

Commit

Permalink
Local Area graphs (ecmwf#30)
Browse files Browse the repository at this point in the history
Limited Area graphs

Co-authored-by: Helen Theissen <helen.theissen@ecmwf.int>
Co-authored-by: Jesper Dramsch <jesper.dramsch@ecmwf.int>
  • Loading branch information
3 people authored Oct 3, 2024
1 parent 7a35e10 commit 986ee9e
Show file tree
Hide file tree
Showing 25 changed files with 1,005 additions and 585 deletions.
6 changes: 5 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ Keep it human-readable, your future self will thank you!

### Added
- ci: hpc-config, CODEOWNERS (#49)
- feat: New node builder class, CutOutZarrDatasetNodes, to create nodes from 2 datasets. (#30)
- feat: New class, KNNAreaMaskBuilder, to specify Area of Interest (AOI) based on a set of nodes. (#30)
- feat: New node builder classes, LimitedAreaXXXXXNodes, to create nodes within an Area of Interest (AOI). (#30)
- feat: Expanded MultiScaleEdges to support multi-scale connections in limited area graphs. (#30)

### Changed
- ci: small fixes and updates pre-commit, downsteam-ci (#49)
Expand All @@ -20,7 +24,7 @@ Keep it human-readable, your future self will thank you!

### Added

- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere.
- HEALPixNodes - nodebuilder based on Hierarchical Equal Area isoLatitude Pixelation of a sphere

- Inspection tools: interactive plots, and distribution plots of edge & node attributes.

Expand Down
2 changes: 1 addition & 1 deletion docs/usage/getting_started.rst
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ following command:

.. code:: console
$ anemoi-graphs inspect graph.pt
$ anemoi-graphs inspect graph.pt output_plots
This will generate the following graph:

Expand Down
12 changes: 8 additions & 4 deletions src/anemoi/graphs/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,14 @@ def generate_graph(self) -> HeteroData:
graph, nodes_cfg.get("attributes", {})
)

for edges_cfg in self.config.edges:
graph = instantiate(edges_cfg.edge_builder, edges_cfg.source_name, edges_cfg.target_name).update_graph(
graph, edges_cfg.get("attributes", {})
)
for edges_cfg in self.config.get("edges", {}):
graph = instantiate(
edges_cfg.edge_builder,
edges_cfg.source_name,
edges_cfg.target_name,
source_mask_attr_name=edges_cfg.get("source_mask_attr_name", None),
target_mask_attr_name=edges_cfg.get("target_mask_attr_name", None),
).update_graph(graph, edges_cfg.get("attributes", {}))

return graph

Expand Down
177 changes: 122 additions & 55 deletions src/anemoi/graphs/edges/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,18 @@
import torch
from anemoi.utils.config import DotDict
from hydra.utils import instantiate
from scipy.sparse import coo_matrix
from sklearn.neighbors import NearestNeighbors
from torch_geometric.data import HeteroData
from torch_geometric.data.storage import NodeStorage

from anemoi.graphs import EARTH_RADIUS
from anemoi.graphs.generate import hexagonal
from anemoi.graphs.generate import icosahedral
from anemoi.graphs.nodes.builder import HexNodes
from anemoi.graphs.nodes.builder import TriNodes
from anemoi.graphs.generate import hex_icosahedron
from anemoi.graphs.generate import tri_icosahedron
from anemoi.graphs.nodes.builders.from_refined_icosahedron import HexNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaHexNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import LimitedAreaTriNodes
from anemoi.graphs.nodes.builders.from_refined_icosahedron import TriNodes
from anemoi.graphs.utils import get_grid_reference_distance

LOGGER = logging.getLogger(__name__)
Expand All @@ -26,9 +29,17 @@
class BaseEdgeBuilder(ABC):
"""Base class for edge builders."""

def __init__(self, source_name: str, target_name: str):
def __init__(
self,
source_name: str,
target_name: str,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
self.source_name = source_name
self.target_name = target_name
self.source_mask_attr_name = source_mask_attr_name
self.target_mask_attr_name = target_mask_attr_name

@property
def name(self) -> tuple[str, str, str]:
Expand Down Expand Up @@ -117,15 +128,48 @@ def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -
"""
graph = self.register_edges(graph)

if attrs_config is None:
return graph

graph = self.register_attributes(graph, attrs_config)
if attrs_config is not None:
graph = self.register_attributes(graph, attrs_config)

return graph


class KNNEdges(BaseEdgeBuilder):
class NodeMaskingMixin:
"""Mixin class for masking source/target nodes when building edges."""

def get_node_coordinates(
self, source_nodes: NodeStorage, target_nodes: NodeStorage
) -> tuple[np.ndarray, np.ndarray]:
"""Get the node coordinates."""
source_coords, target_coords = source_nodes.x.numpy(), target_nodes.x.numpy()

if self.source_mask_attr_name is not None:
source_coords = source_coords[source_nodes[self.source_mask_attr_name].squeeze()]

if self.target_mask_attr_name is not None:
target_coords = target_coords[target_nodes[self.target_mask_attr_name].squeeze()]

return source_coords, target_coords

def undo_masking(self, adj_matrix, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.target_mask_attr_name is not None:
target_mask = target_nodes[self.target_mask_attr_name].squeeze()
target_mapper = dict(zip(list(range(len(adj_matrix.row))), np.where(target_mask)[0]))
adj_matrix.row = np.vectorize(target_mapper.get)(adj_matrix.row)

if self.source_mask_attr_name is not None:
source_mask = source_nodes[self.source_mask_attr_name].squeeze()
source_mapper = dict(zip(list(range(len(adj_matrix.col))), np.where(source_mask)[0]))
adj_matrix.col = np.vectorize(source_mapper.get)(adj_matrix.col)

if self.source_mask_attr_name is not None or self.target_mask_attr_name is not None:
true_shape = target_nodes.x.shape[0], source_nodes.x.shape[0]
adj_matrix = coo_matrix((adj_matrix.data, (adj_matrix.row, adj_matrix.col)), shape=true_shape)

return adj_matrix


class KNNEdges(BaseEdgeBuilder, NodeMaskingMixin):
"""Computes KNN based edges and adds them to the graph.
Attributes
Expand All @@ -136,6 +180,10 @@ class KNNEdges(BaseEdgeBuilder):
The name of the target nodes.
num_nearest_neighbours : int
Number of nearest neighbours.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.
Methods
-------
Expand All @@ -147,27 +195,35 @@ class KNNEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, num_nearest_neighbours: int):
super().__init__(source_name, target_name)
def __init__(
self,
source_name: str,
target_name: str,
num_nearest_neighbours: int,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
):
super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name)
assert isinstance(num_nearest_neighbours, int), "Number of nearest neighbours must be an integer"
assert num_nearest_neighbours > 0, "Number of nearest neighbours must be positive"
self.num_nearest_neighbours = num_nearest_neighbours

def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarray) -> np.ndarray:
def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage) -> np.ndarray:
"""Compute the adjacency matrix for the KNN method.
Parameters
----------
source_nodes : np.ndarray
source_nodes : NodeStorage
The source nodes.
target_nodes : np.ndarray
target_nodes : NodeStorage
The target nodes.
Returns
-------
np.ndarray
The adjacency matrix.
"""
source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes)
assert self.num_nearest_neighbours is not None, "number of neighbors required for knn encoder"
LOGGER.info(
"Using KNN-Edges (with %d nearest neighbours) between %s and %s.",
Expand All @@ -177,16 +233,20 @@ def get_adjacency_matrix(self, source_nodes: np.ndarray, target_nodes: np.ndarra
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(source_nodes.x.numpy())
nearest_neighbour.fit(source_coords)
adj_matrix = nearest_neighbour.kneighbors_graph(
target_nodes.x.numpy(),
target_coords,
n_neighbors=self.num_nearest_neighbours,
mode="distance",
).tocoo()

# Post-process the adjacency matrix. Add masked nodes.
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)

return adj_matrix


class CutOffEdges(BaseEdgeBuilder):
class CutOffEdges(BaseEdgeBuilder, NodeMaskingMixin):
"""Computes cut-off based edges and adds them to the graph.
Attributes
Expand All @@ -197,6 +257,10 @@ class CutOffEdges(BaseEdgeBuilder):
The name of the target nodes.
cutoff_factor : float
Factor to multiply the grid reference distance to get the cut-off radius.
source_mask_attr_name : str | None
The name of the source mask attribute to filter edge connections.
target_mask_attr_name : str | None
The name of the target mask attribute to filter edge connections.
Methods
-------
Expand All @@ -208,8 +272,15 @@ class CutOffEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, cutoff_factor: float) -> None:
super().__init__(source_name, target_name)
def __init__(
self,
source_name: str,
target_name: str,
cutoff_factor: float,
source_mask_attr_name: str | None = None,
target_mask_attr_name: str | None = None,
) -> None:
super().__init__(source_name, target_name, source_mask_attr_name, target_mask_attr_name)
assert isinstance(cutoff_factor, (int, float)), "Cutoff factor must be a float"
assert cutoff_factor > 0, "Cutoff factor must be positive"
self.cutoff_factor = cutoff_factor
Expand Down Expand Up @@ -258,6 +329,7 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
np.ndarray
The adjacency matrix.
"""
source_coords, target_coords = self.get_node_coordinates(source_nodes, target_nodes)
LOGGER.info(
"Using CutOff-Edges (with radius = %.1f km) between %s and %s.",
self.radius * EARTH_RADIUS,
Expand All @@ -266,8 +338,12 @@ def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStor
)

nearest_neighbour = NearestNeighbors(metric="haversine", n_jobs=4)
nearest_neighbour.fit(source_nodes.x)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_nodes.x, radius=self.radius).tocoo()
nearest_neighbour.fit(source_coords)
adj_matrix = nearest_neighbour.radius_neighbors_graph(target_coords, radius=self.radius).tocoo()

# Post-process the adjacency matrix. Add masked nodes.
adj_matrix = self.undo_masking(adj_matrix, source_nodes, target_nodes)

return adj_matrix


Expand All @@ -294,61 +370,52 @@ class MultiScaleEdges(BaseEdgeBuilder):
Update the graph with the edges.
"""

def __init__(self, source_name: str, target_name: str, x_hops: int):
VALID_NODES = [TriNodes, HexNodes, LimitedAreaTriNodes, LimitedAreaHexNodes]

def __init__(self, source_name: str, target_name: str, x_hops: int, **kwargs):
super().__init__(source_name, target_name)
assert source_name == target_name, f"{self.__class__.__name__} requires source and target nodes to be the same."
assert isinstance(x_hops, int), "Number of x_hops must be an integer"
assert x_hops > 0, "Number of x_hops must be positive"
self.x_hops = x_hops
self.node_type = None

def adjacency_from_tri_nodes(self, source_nodes: NodeStorage):
source_nodes["_nx_graph"] = icosahedral.add_edges_to_nx_graph(
source_nodes["_nx_graph"],
resolutions=source_nodes["_resolutions"],
def add_edges_from_tri_nodes(self, nodes: NodeStorage) -> NodeStorage:
nodes["_nx_graph"] = tri_icosahedron.add_edges_to_nx_graph(
nodes["_nx_graph"],
resolutions=nodes["_resolutions"],
x_hops=self.x_hops,
) # HeteroData refuses to accept None

adjmat = nx.to_scipy_sparse_array(
source_nodes["_nx_graph"], nodelist=list(range(len(source_nodes["_nx_graph"]))), format="coo"
area_mask_builder=nodes.get("_area_mask_builder", None),
)
return adjmat

def adjacency_from_hex_nodes(self, source_nodes: NodeStorage):
return nodes

source_nodes["_nx_graph"] = hexagonal.add_edges_to_nx_graph(
source_nodes["_nx_graph"],
resolutions=source_nodes["_resolutions"],
def add_edges_from_hex_nodes(self, nodes: NodeStorage) -> NodeStorage:
nodes["_nx_graph"] = hex_icosahedron.add_edges_to_nx_graph(
nodes["_nx_graph"],
resolutions=nodes["_resolutions"],
x_hops=self.x_hops,
)

adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")
return adjmat
return nodes

def get_adjacency_matrix(self, source_nodes: NodeStorage, target_nodes: NodeStorage):
if self.node_type == TriNodes.__name__:
adjmat = self.adjacency_from_tri_nodes(source_nodes)
elif self.node_type == HexNodes.__name__:
adjmat = self.adjacency_from_hex_nodes(source_nodes)
if self.node_type in [TriNodes.__name__, LimitedAreaTriNodes.__name__]:
source_nodes = self.add_edges_from_tri_nodes(source_nodes)
elif self.node_type in [HexNodes.__name__, LimitedAreaHexNodes.__name__]:
source_nodes = self.add_edges_from_hex_nodes(source_nodes)
else:
raise ValueError(f"Invalid node type {self.node_type}")

adjmat = self.post_process_adjmat(source_nodes, adjmat)

return adjmat
adjmat = nx.to_scipy_sparse_array(source_nodes["_nx_graph"], format="coo")

def post_process_adjmat(self, nodes: NodeStorage, adjmat):
graph_sorted = {node_pos: i for i, node_pos in enumerate(nodes["_node_ordering"])}
sort_func = np.vectorize(graph_sorted.get)
adjmat.row = sort_func(adjmat.row)
adjmat.col = sort_func(adjmat.col)
return adjmat

def update_graph(self, graph: HeteroData, attrs_config: DotDict | None = None) -> HeteroData:
assert (
graph[self.source_name].node_type == TriNodes.__name__
or graph[self.source_name].node_type == HexNodes.__name__
), f"{self.__class__.__name__} requires {TriNodes.__name__} or {HexNodes.__name__}."

self.node_type = graph[self.source_name].node_type
valid_node_names = [n.__name__ for n in self.VALID_NODES]
assert (
self.node_type in valid_node_names
), f"{self.__class__.__name__} requires {','.join(valid_node_names)} nodes."

return super().update_graph(graph, attrs_config)
Loading

0 comments on commit 986ee9e

Please sign in to comment.