Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

StructureData types and docs #17

Merged
merged 7 commits into from
Mar 25, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 3 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -145,16 +145,15 @@ trainer.train(train_loader, val_loader, test_loader)
1. The energy used for training should be energy/atom if you're fine-tuning the pretrained `CHGNet`.
2. The pretrained dataset of `CHGNet` comes from GGA+U DFT with [`MaterialsProject2020Compatibility`](https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/entries/compatibility.py#L826-L1102).
The parameter for VASP is described in [`MPRelaxSet`](https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/io/vasp/sets.py#L862-L879).
If you're fine-tuning with `MPRelaxSet``](https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/io/vasp/sets.py#L862-L879), it is recommended that apply the [`MP2020`](https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/entries/compatibility.py#L826-L1102)
If you're fine-tuning with [`MPRelaxSet`](https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/io/vasp/sets.py#L862-L879), it is recommended to apply the [`MP2020`](https://github.com/materialsproject/pymatgen/blob/v2023.2.28/pymatgen/entries/compatibility.py#L826-L1102)
compatibility to your energy labels so that they're consistent with the pretrained dataset.
3. If you're fine-tuning to functionals other than GGA, we recommend you to refit the [`AtomRef`](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/model/composition_model.py).
3. If you're fine-tuning to functionals other than GGA, we recommend you refit the [`AtomRef`](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/model/composition_model.py).
4. `CHGNet` stress is in units of GPa, and the unit conversion has already been included in
[`dataset.py`](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/data/dataset.py), so `VASP` stress can be fed directly to `StructureData`.
5. To avoid repeating the graph conversion step for every training run, we recommend you use [`GraphData`](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/data/dataset.py) defined in
[`dataset.py`](https://github.com/CederGroupHub/chgnet/blob/main/chgnet/data/dataset.py), which reads graphs directly from a saved directory. To create the saved graphs,
see [`examples/make_graphs.py`](https://github.com/CederGroupHub/chgnet/blob/main/examples/make_graphs.py).
6. Apple’s Metal Performance Shaders `MPS` is currently disabled before stable version of `pytorch` for
`MPS` is released.
6. Apple’s Metal Performance Shaders `MPS` is currently disabled until a stable version of `pytorch` for `MPS` is released.

## Reference

Expand Down
82 changes: 41 additions & 41 deletions chgnet/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os
import random
import warnings
from typing import Sequence

import numpy as np
import torch
Expand All @@ -24,48 +25,47 @@ class StructureData(Dataset):

def __init__(
self,
structures: list,
energies: list,
forces: list,
stresses: list = None,
magmoms: list = None,
structures: list[Structure],
energies: list[float],
forces: list[Sequence[Sequence[float]]],
stresses: list[Sequence[Sequence[float]]] = None,
magmoms: list[Sequence[Sequence[float]]] = None,
graph_converter: CrystalGraphConverter = None,
) -> None:
"""Initialize the dataset.

Args:
structures (list): a list of structures
energies (list): a list of energies
forces (list): a list of forces
stresses (List, optional): a list of stresses
magmoms (List, optional): a list of magmoms
graph_converter (CrystalGraphConverter, optional):
a CrystalGraphConverter to convert the structures,
if None, it will be set to CHGNet default converter
structures (list[Structure]): pymatgen Structure objects.
energies (list[float]): [data_size, 1]
forces (list[list[float]]): [data_size, n_atoms, 3]
stresses (list[list[float]], optional): [data_size, 3, 3]
magmoms (list[list[float]], optional): [data_size, n_atoms, 1]
graph_converter (CrystalGraphConverter, optional): Converts the structures to
graphs. If None, it will be set to CHGNet default converter.
"""
for idx, struct in enumerate(structures):
if not isinstance(struct, Structure):
raise ValueError(f"{idx} is not a pymatgen Structure object: {struct}")
self.structures = structures
self.energies = energies
self.forces = forces
self.stresses = stresses
self.magmoms = magmoms
self.keys = np.arange(len(structures))
random.shuffle(self.keys)
print(f"{len(self.structures)} structures imported")
if graph_converter is not None:
self.graph_converter = graph_converter
else:
self.graph_converter = CrystalGraphConverter(
atom_graph_cutoff=5, bond_graph_cutoff=3
)
self.failed_idx = []
self.failed_graph_id = {}
print(f"{len(structures)} structures imported")
self.graph_converter = graph_converter or CrystalGraphConverter(
atom_graph_cutoff=5, bond_graph_cutoff=3
)
self.failed_idx: list[int] = []
self.failed_graph_id: dict[str, str] = {}

def __len__(self) -> int:
"""Get the number of structures in this dataset."""
return len(self.keys)

@functools.lru_cache(maxsize=None) # Cache loaded structures
def __getitem__(self, idx) -> tuple[CrystalGraph, dict]:
def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict]:
"""Get one graph for a structure in this dataset.

Args:
Expand All @@ -78,7 +78,7 @@ def __getitem__(self, idx) -> tuple[CrystalGraph, dict]:
if idx not in self.failed_idx:
graph_id = self.keys[idx]
try:
struct = Structure.from_dict(self.structures[graph_id])
struct = self.structures[graph_id]
crystal_graph = self.graph_converter(
struct, graph_id=graph_id, mp_id=graph_id
)
Expand All @@ -101,9 +101,9 @@ def __getitem__(self, idx) -> tuple[CrystalGraph, dict]:

return crystal_graph, targets

# Omit structures with isolated atoms. Return another random selected structure
# Omit structures with isolated atoms. Return another randomly selected structure
except Exception:
struct = Structure.from_dict(self.structures[graph_id])
struct = self.structures[graph_id]
self.failed_graph_id[graph_id] = struct.composition.formula
self.failed_idx.append(idx)
idx = random.randint(0, len(self) - 1)
Expand All @@ -114,7 +114,7 @@ def __getitem__(self, idx) -> tuple[CrystalGraph, dict]:


class CIFData(Dataset):
"""A dataset from cifs."""
"""A dataset from CIFs."""

def __init__(
self,
Expand All @@ -124,7 +124,7 @@ def __init__(
graph_converter: CrystalGraphConverter = None,
**kwargs,
) -> None:
"""Initialize the dataset from a directory containing cifs.
"""Initialize the dataset from a directory containing CIFs.

Args:
cif_path (str): path that contain all the graphs, labels.json
Expand All @@ -149,20 +149,20 @@ def __init__(

self.energy_str = kwargs.pop("energy_str", "energy_per_atom")
self.targets = targets
self.failed_idx: list[str] = []
self.failed_idx: list[int] = []
self.failed_graph_id: dict[str, str] = {}

def __len__(self) -> int:
"""Get the number of structures in this dataset."""
return len(self.cif_ids)

@functools.lru_cache(maxsize=None) # Cache loaded structures
def __getitem__(self, idx):
def __getitem__(self, idx: int) -> tuple[CrystalGraph, dict[str, Tensor]]:
"""Get one item in the dataset.

Returns:
crystal_graph (CrystalGraph): graph of the crystal structure
targets (dict): list of targets. i.e. energy, force, stress
tuple[CrystalGraph, dict[str, Tensor]]: graph of the crystal structure
and dict of targets i.e. energy, force, stress
"""
if idx not in self.failed_idx:
try:
Expand Down Expand Up @@ -192,7 +192,7 @@ def __getitem__(self, idx):
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
return crystal_graph, targets

# Omit structures with isolated atoms. Return another random selected structure
# Omit structures with isolated atoms. Return another randomly selected structure
except Exception:
try:
graph_id = self.cif_ids[idx]
Expand Down Expand Up @@ -258,7 +258,7 @@ def __init__(

self.energy_str = energy_str
self.targets = targets
self.failed_idx: list[str] = []
self.failed_idx: list[int] = []
self.failed_graph_id: dict[str, str] = {}

def __len__(self) -> int:
Expand Down Expand Up @@ -436,16 +436,16 @@ def __init__(
data: str | dict,
graph_converter: CrystalGraphConverter,
targets: TrainTask = "efsm",
**kwargs,
energy_str: str = "energy_per_atom",
) -> None:
"""Initialize the dataset by reading Json files.

Args:
data (str | dict): json path or dir name that contain all the jsons
data (str | dict): file path or dir name that contain all the JSONs
graph_converter (CrystalGraphConverter): Converts pymatgen.core.Structure to graph
targets ('ef' | 'efs' | 'efsm'): the training targets e=energy, f=forces, s=stress,
m=magmoms. Default = "efsm".
**kwargs: other arguments
energy_str (str): key to get energy from the JSON file. Default = "energy_per_atom"
"""
if isinstance(data, str):
self.data = {}
Expand All @@ -464,14 +464,14 @@ def __init__(

self.keys = []
for mp_id, dic in self.data.items():
for graph_id, _ in dic.items():
for graph_id in dic:
self.keys.append((mp_id, graph_id))
random.shuffle(self.keys)
print(f"{len(self.data)} mp_ids, {len(self)} structures imported")
self.graph_converter = graph_converter
self.energy_str = kwargs.pop("energy_str", "energy_per_atom")
self.energy_str = energy_str
self.targets = targets
self.failed_idx: list[str] = []
self.failed_idx: list[int] = []
self.failed_graph_id: dict[str, str] = {}

def __len__(self) -> int:
Expand Down Expand Up @@ -515,7 +515,7 @@ def __getitem__(self, idx):
targets["m"] = torch.abs(torch.tensor(mag, dtype=datatype))
return crystal_graph, targets

# Omit structures with isolated atoms. Return another random selected structure
# Omit structures with isolated atoms. Return another randomly selected structure
except Exception:
structure = Structure.from_dict(self.data[mp_id][graph_id]["structure"])
self.failed_graph_id[graph_id] = structure.composition.formula
Expand Down
31 changes: 13 additions & 18 deletions chgnet/model/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def calculate(
atoms: Atoms | None = None,
desired_properties: list | None = None,
changed_properties: list | None = None,
):
) -> None:
"""Calculate various properties of the atoms using CHGNet.

Args:
Expand All @@ -94,9 +94,6 @@ def calculate(
Default is all properties.
changed_properties (list | None): The changes made to the system.
Default is all changes.

Returns:
None
"""
desired_properties = desired_properties or all_properties
changed_properties = changed_properties or all_changes
Expand Down Expand Up @@ -259,25 +256,23 @@ def compute_energy(self) -> float:
energy = self.atoms.get_potential_energy()
return energy

def save(self, filename: str):
def save(self, filename: str) -> None:
"""Save the trajectory to file.

Args:
filename (str): filename to save the trajectory
"""
with open(filename, "wb") as f:
pickle.dump(
{
"energy": self.energies,
"forces": self.forces,
"stresses": self.stresses,
"magmoms": self.magmoms,
"atom_positions": self.atom_positions,
"cell": self.cells,
"atomic_number": self.atoms.get_atomic_numbers(),
},
f,
)
out_pkl = {
"energy": self.energies,
"forces": self.forces,
"stresses": self.stresses,
"magmoms": self.magmoms,
"atom_positions": self.atom_positions,
"cell": self.cells,
"atomic_number": self.atoms.get_atomic_numbers(),
}
with open(filename, "wb") as file:
pickle.dump(out_pkl, file)


class MolecularDynamics:
Expand Down
12 changes: 6 additions & 6 deletions chgnet/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ def mae(prediction: Tensor, target: Tensor) -> Tensor:
"""Computes the mean absolute error between prediction and target
Parameters
----------
prediction: torch.Tensor (N, 1)
target: torch.Tensor (N, 1).
prediction: Tensor (N, 1)
target: Tensor (N, 1).
"""
return torch.mean(torch.abs(target - prediction))

Expand All @@ -52,8 +52,8 @@ def read_json(fjson):
Returns:
dictionary stored in fjson
"""
with open(fjson) as f:
return json.load(f)
with open(fjson) as file:
return json.load(file)


def write_json(d, fjson):
Expand All @@ -64,8 +64,8 @@ def write_json(d, fjson):
Returns:
written dictionary
"""
with open(fjson, "w") as f:
json.dump(d, f)
with open(fjson, "w") as file:
json.dump(d, file)


def solve_charge_by_mag(
Expand Down
19 changes: 9 additions & 10 deletions examples/crystaltoolkit_relax_viewer.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,7 @@
"source": [
"e_col = \"Energy (eV)\"\n",
"force_col = \"Force (eV/Å)\"\n",
"df_traj = pd.DataFrame()\n",
"df_traj[e_col] = trajectory.energies\n",
"df_traj = pd.DataFrame(trajectory.energies, columns=[e_col])\n",
"df_traj[force_col] = [\n",
" np.linalg.norm(force, axis=1).mean() # mean of norm of force on each atom\n",
" for force in trajectory.forces\n",
Expand Down Expand Up @@ -301,20 +300,19 @@
" go.Scatter(x=df.index, y=df[force_col], mode=\"lines\", name=\"Forces\", yaxis=\"y2\")\n",
" )\n",
"\n",
" # Set up the layout\n",
" fig.update_layout(\n",
" template=\"plotly_white\",\n",
" title=title,\n",
" xaxis=dict(title=\"Relaxation Step\"),\n",
" yaxis=dict(title=e_col),\n",
" yaxis2=dict(title=force_col, overlaying=\"y\", side=\"right\"),\n",
" legend=dict(orientation=\"h\", yanchor=\"bottom\", y=1.02, xanchor=\"right\", x=1),\n",
" legend=dict(yanchor=\"top\", y=1, xanchor=\"right\", x=1),\n",
" )\n",
"\n",
" # Add vertical line at the specified step\n",
" # vertical line at the specified step\n",
" fig.add_vline(x=step, line=dict(dash=\"dash\", width=1))\n",
"\n",
" # Add horizontal line for MP final energy\n",
" # horizontal line for DFT final energy\n",
" anno = dict(text=\"DFT final energy\", yanchor=\"top\")\n",
" fig.add_hline(\n",
" y=dft_energy,\n",
Expand Down Expand Up @@ -348,9 +346,7 @@
" slider,\n",
" html.Div([struct_comp.layout(), graph], style=dict(display=\"flex\", gap=\"2em\")),\n",
" ],\n",
" style=dict(\n",
" margin=\"2em auto\", placeItems=\"center\", textAlign=\"center\", maxWidth=\"1200px\"\n",
" ),\n",
" style=dict(margin=\"auto\", textAlign=\"center\", maxWidth=\"1200px\", padding=\"2em\"),\n",
")\n",
"\n",
"ctc.register_crystal_toolkit(app=app, layout=app.layout)\n",
Expand All @@ -369,8 +365,11 @@
" assert len(structure) == len(coords)\n",
" for site, coord in zip(structure, coords):\n",
" site.coords = coord\n",
"\n",
" title = make_title(structure.get_space_group_info())\n",
" return structure, plot_energy_and_forces(df_traj, step, e_col, force_col, title)\n",
" fig = plot_energy_and_forces(df_traj, step, e_col, force_col, title)\n",
"\n",
" return structure, fig\n",
"\n",
"\n",
"app.run_server(mode=\"inline\", height=800, use_reloader=False)"
Expand Down