Merge pull request #137 from adaptyvbio/protein_classes
Protein classes
elkoz committed Feb 6, 2024
2 parents 683adce + cd2ba00 commit eb4f0fa
Showing 3 changed files with 63 additions and 15 deletions.
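
For context, a minimal usage sketch of the new class-based filtering (the import path, the dataset_folder keyword, and the folder path are assumptions; entry_type and classes_to_exclude are documented in the diff below):

    from proteinflow.data.torch import ProteinLoader

    # Build a chain-level loader that skips homomeric entries
    # (class names are now singular: "single_chain", "heteromer", "homomer").
    loader = ProteinLoader.from_args(
        dataset_folder="./proteinflow_sample/",  # placeholder path
        entry_type="chain",
        classes_to_exclude=["homomer"],
    )
    for batch in loader:
        ...  # iterate over featurized batches
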
21 changes: 21 additions & 0 deletions proteinflow/data/__init__.py
@@ -10,6 +10,7 @@
proteinflow pickle file or a PDB file.
"""
import itertools
import os
import pickle
import string
@@ -22,6 +23,7 @@
import pandas as pd
from Bio import pairwise2
from biopandas.pdb import PandasPdb
from editdistance import eval as edit_distance
from torch import Tensor, from_numpy

try:
@@ -2080,6 +2082,25 @@ def apply_mask(self, mask):
out_dict["protein_id"] = self.id
return ProteinEntry.from_dict(out_dict)

    def get_protein_class(self):
        """Get the protein class.

        Returns
        -------
        protein_class : str
            The protein class ("single_chain", "heteromer", "homomer")

        """
        if len(self.get_chains()) == 1:
            return "single_chain"
        else:
            for chain1, chain2 in itertools.combinations(self.get_chains(), 2):
                # chains whose lengths differ by more than ~10% make the entry a heteromer
                if len(chain1) < 0.9 * len(chain2) or len(chain2) < 0.9 * len(chain1):
                    return "heteromer"
                # so do chains that differ by more than 10% in normalized edit distance
                if edit_distance(chain1, chain2) / max(len(chain1), len(chain2)) > 0.1:
                    return "heteromer"
            return "homomer"


class PDBEntry:
"""A class for parsing PDB entries."""
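The classification rule added above can be summarized with a small standalone sketch (the classify helper and the toy sequences are illustrative, not part of the repository; it assumes the length check is meant to flag chains whose lengths differ by more than roughly 10%):

    import itertools

    from editdistance import eval as edit_distance

    def classify(sequences):
        """Toy version of ProteinEntry.get_protein_class on plain sequence strings."""
        if len(sequences) == 1:
            return "single_chain"
        for s1, s2 in itertools.combinations(sequences, 2):
            # lengths differing by more than ~10% -> heteromer
            if len(s1) < 0.9 * len(s2) or len(s2) < 0.9 * len(s1):
                return "heteromer"
            # sequences differing by more than 10% (normalized edit distance) -> heteromer
            if edit_distance(s1, s2) / max(len(s1), len(s2)) > 0.1:
                return "heteromer"
        return "homomer"

    print(classify(["ACDEFGHIKL"]))                # single_chain
    print(classify(["ACDEFGHIKL", "ACDEFGHIKL"]))  # homomer
    print(classify(["ACDEFGHIKL", "ACDE"]))        # heteromer
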
55 changes: 41 additions & 14 deletions proteinflow/data/torch.py
@@ -162,7 +162,7 @@ def from_args(
entry_type : {"biounit", "chain", "pair"}
the type of entries to generate (`"biounit"` for biounit-level, `"chain"` for chain-level, `"pair"` for chain-chain pairs)
classes_to_exclude : list of str, optional
- a list of classes to exclude from the dataset (select from `"single_chains"`, `"heteromers"`, `"homomers"`)
+ a list of classes to exclude from the dataset (select from `"single_chain"`, `"heteromer"`, `"homomer"`)
lower_limit : int, default 15
the minimum number of residues to mask
upper_limit : int, default 100
@@ -360,7 +360,7 @@ def __init__(
the type of entries to generate (`"biounit"` for biounit-level complexes, `"chain"` for chain-level, `"pair"`
for chain-chain pairs (all pairs that are seen in the same biounit and have intersecting coordinate clouds))
classes_to_exclude : list of str, optional
- a list of classes to exclude from the dataset (select from `"single_chains"`, `"heteromers"`, `"homomers"`)
+ a list of classes to exclude from the dataset (select from `"single_chain"`, `"heteromer"`, `"homomer"`)
shuffle_clusters : bool, default True
if `True`, a new representative is randomly selected for each cluster at each epoch (if `clustering_dict_path` is given)
min_cdr_length : int, optional
@@ -456,13 +456,10 @@ def __init__(
}
self.feature_functions.update(feature_functions or {})
if classes_to_exclude is not None and not all(
- [
-     x in ["single_chains", "heteromers", "homomers"]
-     for x in classes_to_exclude
- ]
+ [x in ["single_chain", "heteromer", "homomer"] for x in classes_to_exclude]
):
raise ValueError(
"Invalid class to exclude, choose from 'single_chains', 'heteromers', 'homomers'"
"Invalid class to exclude, choose from 'single_chain', 'heteromer', 'homomer'"
)

if debug_file_path is not None:
@@ -507,7 +504,11 @@ def __init__(
)
output_tuples_list = p_map(
lambda x: self._process(
- x, rewrite=rewrite, max_length=max_length, min_cdr_length=min_cdr_length
+ x,
+ rewrite=rewrite,
+ max_length=max_length,
+ min_cdr_length=min_cdr_length,
+ classes_to_exclude=classes_to_exclude,
),
to_process,
)
@@ -537,11 +538,12 @@ def __init__(
"Classes to exclude are given but no classes dictionary is found, please set classes_dict_path to the path of the classes dictionary"
)
to_exclude = set()
- if classes is not None:
-     for c in classes_to_exclude:
-         for key, id_arr in classes.get(c, {}).items():
-             for id, _ in id_arr:
-                 to_exclude.add(id)
+
+ # if classes is not None:
+ #     for c in classes_to_exclude:
+ #         for key, id_arr in classes.get(c, {}).items():
+ #             for id, _ in id_arr:
+ #                 to_exclude.add(id)
if require_antigen or require_light_chain:
to_exclude.update(
self._exclude_by_chains(
@@ -613,6 +615,21 @@ def _check_chain_types(self, file):
chain_types.add("light")
return chain_types

    def _exclude_by_class(
        self,
        classes_to_exclude,
    ):
        """Exclude entries whose class is in the classes to exclude."""
        to_exclude = set()
        for id in self.files:
            for chain in self.files[id]:
                filename = self.files[id][chain][0]
                with open(filename, "rb") as f:
                    data = pickle.load(f)
                # exclude the entry if any of its classes is on the exclusion list
                if any(c in data["classes"] for c in classes_to_exclude):
                    to_exclude.add(id)
        return to_exclude

def _exclude_by_chains(
self,
require_antigen,
@@ -771,13 +788,23 @@ def _sidechain_coords(self, data_entry, chains):
"""Return idechain coordinates."""
return data_entry.sidechain_coordinates(chains)

- def _process(self, filename, rewrite=False, max_length=None, min_cdr_length=None):
+ def _process(
+     self,
+     filename,
+     rewrite=False,
+     max_length=None,
+     min_cdr_length=None,
+     classes_to_exclude=None,
+ ):
"""Process a proteinflow file and save it as ProteinMPNN features."""
input_file = os.path.join(self.dataset_folder, filename)
no_extension_name = filename.split(".")[0]
data_entry = ProteinEntry.from_pickle(input_file)
if self.load_ligands:
ligands = ProteinEntry.retrieve_ligands_from_pickle(input_file)
if classes_to_exclude is not None:
if data_entry.get_protein_class() in classes_to_exclude:
return []
chains = data_entry.get_chains()
if self.entry_type == "biounit":
chain_sets = [chains]
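The _process changes above drop excluded entries by returning an empty feature list. The same filtering can be reproduced at the user level with the public API shown in this diff (file names below are placeholders):

    from proteinflow.data import ProteinEntry

    classes_to_exclude = ["homomer"]
    kept = []
    for filename in ["entry_1.pickle", "entry_2.pickle"]:  # placeholder paths
        entry = ProteinEntry.from_pickle(filename)
        # mirror the early return in _process: skip entries whose class is excluded
        if entry.get_protein_class() in classes_to_exclude:
            continue
        kept.append(filename)
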
2 changes: 1 addition & 1 deletion tests/test_download.py
@@ -19,7 +19,7 @@ def test_download(tag):
)
subprocess.run(["proteinflow", "split", "--tag", tag], check=True)
for cluster_dict_path, entry_type, classes_to_exclude in [
- (os.path.join(folder, "splits_dict/valid.pickle"), "chain", ["homomers"]),
+ (os.path.join(folder, "splits_dict/valid.pickle"), "chain", ["homomer"]),
(None, "pair", None),
]:
valid_loader = ProteinLoader.from_args(
