Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Protein classes #137

Merged
merged 2 commits into from
Feb 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions proteinflow/data/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
proteinflow pickle file or a PDB file.

"""
import itertools
import os
import pickle
import string
Expand All @@ -22,6 +23,7 @@
import pandas as pd
from Bio import pairwise2
from biopandas.pdb import PandasPdb
from editdistance import eval as edit_distance
from torch import Tensor, from_numpy

try:
Expand Down Expand Up @@ -2080,6 +2082,25 @@ def apply_mask(self, mask):
out_dict["protein_id"] = self.id
return ProteinEntry.from_dict(out_dict)

def get_protein_class(self):
    """Get the protein class.

    An entry with exactly one chain is a `"single_chain"`. Any other entry is
    a `"homomer"` only if every pair of chains is near-identical: the shorter
    one is at least 90% as long as the longer one and the edit distance
    between them, normalized by the longer length, does not exceed 0.1.
    As soon as one pair fails either test the entry is a `"heteromer"`.

    Returns
    -------
    protein_class : str
        The protein class ("single_chain", "heteromer", "homomer")

    """
    # NOTE(review): the values returned by `get_chains()` are compared
    # directly (by length and by edit distance). If they are chain
    # identifiers rather than sequences, the comparison has to be done on
    # each chain's sequence instead -- confirm against `get_chains()`.
    chains = self.get_chains()
    if len(chains) == 1:
        return "single_chain"
    for chain1, chain2 in itertools.combinations(chains, 2):
        shorter, longer = sorted((len(chain1), len(chain2)))
        if longer == 0:
            # Two empty chains are identical; skipping them also avoids a
            # division by zero in the normalized edit distance below.
            continue
        # Lengths differing by more than 10% cannot belong to a homomer.
        # (The previous `>` form was true for every non-empty pair.)
        if shorter < 0.9 * longer:
            return "heteromer"
        if edit_distance(chain1, chain2) / longer > 0.1:
            return "heteromer"
    return "homomer"


class PDBEntry:
"""A class for parsing PDB entries."""
Expand Down
55 changes: 41 additions & 14 deletions proteinflow/data/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def from_args(
entry_type : {"biounit", "chain", "pair"}
the type of entries to generate (`"biounit"` for biounit-level, `"chain"` for chain-level, `"pair"` for chain-chain pairs)
classes_to_exclude : list of str, optional
a list of classes to exclude from the dataset (select from `"single_chains"`, `"heteromers"`, `"homomers"`)
a list of classes to exclude from the dataset (select from `"single_chain"`, `"heteromer"`, `"homomer"`)
lower_limit : int, default 15
the minimum number of residues to mask
upper_limit : int, default 100
Expand Down Expand Up @@ -360,7 +360,7 @@ def __init__(
the type of entries to generate (`"biounit"` for biounit-level complexes, `"chain"` for chain-level, `"pair"`
for chain-chain pairs (all pairs that are seen in the same biounit and have intersecting coordinate clouds))
classes_to_exclude : list of str, optional
a list of classes to exclude from the dataset (select from `"single_chains"`, `"heteromers"`, `"homomers"`)
a list of classes to exclude from the dataset (select from `"single_chain"`, `"heteromer"`, `"homomer"`)
shuffle_clusters : bool, default True
if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
min_cdr_length : int, optional
Expand Down Expand Up @@ -456,13 +456,10 @@ def __init__(
}
self.feature_functions.update(feature_functions or {})
if classes_to_exclude is not None and not all(
[
x in ["single_chains", "heteromers", "homomers"]
for x in classes_to_exclude
]
[x in ["single_chain", "heteromer", "homomer"] for x in classes_to_exclude]
):
raise ValueError(
"Invalid class to exclude, choose from 'single_chains', 'heteromers', 'homomers'"
"Invalid class to exclude, choose from 'single_chain', 'heteromer', 'homomer'"
)

if debug_file_path is not None:
Expand Down Expand Up @@ -507,7 +504,11 @@ def __init__(
)
output_tuples_list = p_map(
lambda x: self._process(
x, rewrite=rewrite, max_length=max_length, min_cdr_length=min_cdr_length
x,
rewrite=rewrite,
max_length=max_length,
min_cdr_length=min_cdr_length,
classes_to_exclude=classes_to_exclude,
),
to_process,
)
Expand Down Expand Up @@ -537,11 +538,12 @@ def __init__(
"Classes to exclude are given but no classes dictionary is found, please set classes_dict_path to the path of the classes dictionary"
)
to_exclude = set()
if classes is not None:
for c in classes_to_exclude:
for key, id_arr in classes.get(c, {}).items():
for id, _ in id_arr:
to_exclude.add(id)

# if classes is not None:
# for c in classes_to_exclude:
# for key, id_arr in classes.get(c, {}).items():
# for id, _ in id_arr:
# to_exclude.add(id)
if require_antigen or require_light_chain:
to_exclude.update(
self._exclude_by_chains(
Expand Down Expand Up @@ -613,6 +615,21 @@ def _check_chain_types(self, file):
chain_types.add("light")
return chain_types

def _exclude_by_class(
    self,
    classes_to_exclude,
):
    """Collect the IDs of entries that belong to any of the excluded classes.

    For every ID in `self.files`, the first file listed for each of its
    chains is unpickled and its `"classes"` field is compared with
    `classes_to_exclude`. An ID is excluded as soon as one of its files
    matches.

    Parameters
    ----------
    classes_to_exclude : list of str or str, optional
        the class name(s) to exclude; `None` excludes nothing

    Returns
    -------
    to_exclude : set
        the IDs from `self.files` with at least one file in an excluded class

    """
    if classes_to_exclude is None:
        return set()
    if isinstance(classes_to_exclude, str):
        classes_to_exclude = [classes_to_exclude]
    excluded_classes = set(classes_to_exclude)

    to_exclude = set()
    for entry_id, chain_files in self.files.items():
        for file_list in chain_files.values():
            # NOTE: pickle executes arbitrary code on load; only use this on
            # trusted dataset files.
            with open(file_list[0], "rb") as f:
                data = pickle.load(f)
            # NOTE(review): assumes data["classes"] is either one class name
            # or a collection of class names -- confirm against the writer.
            entry_classes = data["classes"]
            if isinstance(entry_classes, str):
                entry_classes = [entry_classes]
            # Match on any excluded class; testing the whole list for
            # membership (as before) never matched a list of names.
            if not excluded_classes.isdisjoint(entry_classes):
                to_exclude.add(entry_id)
                break  # already excluded, no need to read the other chains
    return to_exclude

def _exclude_by_chains(
self,
require_antigen,
Expand Down Expand Up @@ -771,13 +788,23 @@ def _sidechain_coords(self, data_entry, chains):
"""Return idechain coordinates."""
return data_entry.sidechain_coordinates(chains)

def _process(self, filename, rewrite=False, max_length=None, min_cdr_length=None):
def _process(
self,
filename,
rewrite=False,
max_length=None,
min_cdr_length=None,
classes_to_exclude=None,
):
"""Process a proteinflow file and save it as ProteinMPNN features."""
input_file = os.path.join(self.dataset_folder, filename)
no_extension_name = filename.split(".")[0]
data_entry = ProteinEntry.from_pickle(input_file)
if self.load_ligands:
ligands = ProteinEntry.retrieve_ligands_from_pickle(input_file)
if classes_to_exclude is not None:
if data_entry.get_protein_class() in classes_to_exclude:
return []
chains = data_entry.get_chains()
if self.entry_type == "biounit":
chain_sets = [chains]
Expand Down
2 changes: 1 addition & 1 deletion tests/test_download.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_download(tag):
)
subprocess.run(["proteinflow", "split", "--tag", tag], check=True)
for cluster_dict_path, entry_type, classes_to_exclude in [
(os.path.join(folder, "splits_dict/valid.pickle"), "chain", ["homomers"]),
(os.path.join(folder, "splits_dict/valid.pickle"), "chain", ["homomer"]),
(None, "pair", None),
]:
valid_loader = ProteinLoader.from_args(
Expand Down
Loading