Merge pull request #199 from ebetica/upstream_updates
Minor optimizations & fixes to support ESMFold
gahdritz authored Aug 23, 2022
2 parents 349fdbd + 023596d commit 4b41059
Showing 2 changed files with 62 additions and 32 deletions.
73 changes: 45 additions & 28 deletions openfold/model/structure_module.py
@@ -573,11 +573,11 @@ def __init__(
         self.epsilon = epsilon
         self.inf = inf
 
-        # To be lazily initialized later
-        self.default_frames = None
-        self.group_idx = None
-        self.atom_mask = None
-        self.lit_positions = None
+        # Buffers to be lazily initialized later
+        # self.default_frames
+        # self.group_idx
+        # self.atom_mask
+        # self.lit_positions
 
         self.layer_norm_s = LayerNorm(self.c_s)
         self.layer_norm_z = LayerNorm(self.c_z)
@@ -723,6 +723,7 @@ def forward(
                 "unnormalized_angles": unnormalized_angles,
                 "angles": angles,
                 "positions": pred_xyz,
+                "states": s,
             }
 
             outputs.append(preds)
@@ -742,32 +743,48 @@ def forward(
         return outputs
 
     def _init_residue_constants(self, float_dtype, device):
-        if self.default_frames is None:
-            self.default_frames = torch.tensor(
-                restype_rigid_group_default_frame,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
-            )
+        if not hasattr(self, "default_frames"):
+            self.register_buffer(
+                "default_frames",
+                torch.tensor(
+                    restype_rigid_group_default_frame,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
+            )
-        if self.group_idx is None:
-            self.group_idx = torch.tensor(
-                restype_atom14_to_rigid_group,
-                device=device,
-                requires_grad=False,
-            )
+        if not hasattr(self, "group_idx"):
+            self.register_buffer(
+                "group_idx",
+                torch.tensor(
+                    restype_atom14_to_rigid_group,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
+            )
-        if self.atom_mask is None:
-            self.atom_mask = torch.tensor(
-                restype_atom14_mask,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
-            )
+        if not hasattr(self, "atom_mask"):
+            self.register_buffer(
+                "atom_mask",
+                torch.tensor(
+                    restype_atom14_mask,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
+            )
-        if self.lit_positions is None:
-            self.lit_positions = torch.tensor(
-                restype_atom14_rigid_group_positions,
-                dtype=float_dtype,
-                device=device,
-                requires_grad=False,
-            )
+        if not hasattr(self, "lit_positions"):
+            self.register_buffer(
+                "lit_positions",
+                torch.tensor(
+                    restype_atom14_rigid_group_positions,
+                    dtype=float_dtype,
+                    device=device,
+                    requires_grad=False,
+                ),
+                persistent=False,
+            )
 
     def torsion_angles_to_frames(self, r, alpha, f):
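Note on the structure_module.py change: the residue-constant tensors are no longer plain attributes initialized to None and overwritten later; they are lazily registered as non-persistent buffers. A registered buffer follows module.to(device) / dtype casts like a parameter does, and persistent=False keeps it out of state_dict, so existing checkpoints still load. The hasattr check replaces the is-None check because register_buffer will not overwrite an attribute that already exists. A minimal sketch of this pattern, with illustrative names and a placeholder constant rather than OpenFold's residue tables:

import torch
import torch.nn as nn

class LazyConstants(nn.Module):
    def _init_constants(self, float_dtype, device):
        # Register the constant once, on first use, matching the dtype/device
        # of the incoming activations.
        if not hasattr(self, "lit_positions"):
            self.register_buffer(
                "lit_positions",
                torch.zeros(21, 14, 3, dtype=float_dtype, device=device),  # placeholder data
                persistent=False,  # excluded from state_dict, so old checkpoints keep loading
            )

    def forward(self, x):
        self._init_constants(x.dtype, x.device)
        # Once registered, the buffer also moves with later module.to(...) / .cuda() calls.
        return x + self.lit_positions.mean()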
21 changes: 17 additions & 4 deletions openfold/utils/rigid_utils.py
@@ -14,6 +14,7 @@
 # limitations under the License.
 
 from __future__ import annotations
+from functools import lru_cache
 from typing import Tuple, Any, Sequence, Callable, Optional
 
 import numpy as np
@@ -84,7 +85,7 @@ def rot_vec_mul(
         dim=-1,
     )
 
+@lru_cache(maxsize=None)
 def identity_rot_mats(
     batch_dims: Tuple[int],
     dtype: Optional[torch.dtype] = None,
@@ -101,6 +102,7 @@ def identity_trans(
     return rots
 
 
+@lru_cache(maxsize=None)
 def identity_trans(
     batch_dims: Tuple[int],
     dtype: Optional[torch.dtype] = None,
@@ -116,6 +118,7 @@ def identity_quats(
     return trans
 
 
+@lru_cache(maxsize=None)
 def identity_quats(
     batch_dims: Tuple[int],
     dtype: Optional[torch.dtype] = None,
@@ -175,7 +178,7 @@ def quat_to_rot(quat: torch.Tensor) -> torch.Tensor:
     quat = quat[..., None] * quat[..., None, :]
 
     # [4, 4, 3, 3]
-    mat = quat.new_tensor(_QTR_MAT, requires_grad=False)
+    mat = _get_quat("_QTR_MAT", dtype=quat.dtype, device=quat.device)
 
     # [*, 4, 4, 3, 3]
     shaped_qtr_mat = mat.view((1,) * len(quat.shape[:-2]) + mat.shape)
@@ -230,10 +233,20 @@ def rot_to_quat(
 
 _QUAT_MULTIPLY_BY_VEC = _QUAT_MULTIPLY[:, 1:, :]
 
+_CACHED_QUATS = {
+    "_QTR_MAT": _QTR_MAT,
+    "_QUAT_MULTIPLY": _QUAT_MULTIPLY,
+    "_QUAT_MULTIPLY_BY_VEC": _QUAT_MULTIPLY_BY_VEC
+}
+
+@lru_cache(maxsize=None)
+def _get_quat(quat_key, dtype, device):
+    return torch.tensor(_CACHED_QUATS[quat_key], dtype=dtype, device=device)
+
 
 def quat_multiply(quat1, quat2):
     """Multiply a quaternion by another quaternion."""
-    mat = quat1.new_tensor(_QUAT_MULTIPLY)
+    mat = _get_quat("_QUAT_MULTIPLY", dtype=quat1.dtype, device=quat1.device)
     reshaped_mat = mat.view((1,) * len(quat1.shape[:-1]) + mat.shape)
     return torch.sum(
         reshaped_mat *
@@ -245,7 +258,7 @@ def quat_multiply(quat1, quat2):
 
 def quat_multiply_by_vec(quat, vec):
     """Multiply a quaternion by a pure-vector quaternion."""
-    mat = quat.new_tensor(_QUAT_MULTIPLY_BY_VEC)
+    mat = _get_quat("_QUAT_MULTIPLY_BY_VEC", dtype=quat.dtype, device=quat.device)
     reshaped_mat = mat.view((1,) * len(quat.shape[:-1]) + mat.shape)
     return torch.sum(
         reshaped_mat *
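Note on the rigid_utils.py change: identity_rot_mats, identity_trans, identity_quats, and the new _get_quat lookup are memoized with functools.lru_cache, so each constant tensor is built only once per distinct combination of arguments, dtype, and device instead of being re-created (and re-copied to the device by new_tensor) on every call. The cache keys must be hashable, and callers should treat the returned tensors as read-only, since the same object is shared across calls. A small sketch of the idea under those assumptions, with stand-in names rather than OpenFold's actual constant tables:

from functools import lru_cache

import numpy as np
import torch

# Stand-in for module-level numpy tables such as _QTR_MAT / _QUAT_MULTIPLY.
_CONSTS = {"eye3": np.eye(3)}

@lru_cache(maxsize=None)
def _get_const(key: str, dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    # Built once per (key, dtype, device) triple; later calls return the cached tensor.
    return torch.tensor(_CONSTS[key], dtype=dtype, device=device)

quat = torch.randn(4)
m1 = _get_const("eye3", quat.dtype, quat.device)
m2 = _get_const("eye3", quat.dtype, quat.device)
assert m1 is m2  # same cached object: no re-allocation, no repeated host-to-device copy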
