feat: update notebook, add charge check, fix sdpa (#61)
KevinEloff authored Oct 11, 2024
1 parent d247f7f commit a99f2ae
Showing 6 changed files with 252 additions and 98 deletions.
49 changes: 47 additions & 2 deletions README.md
@@ -50,7 +50,7 @@ conda env create -f environment.yml
 conda activate instanovo
 ```

-Note: InstaNovo is built for Python >= 3.8, <3.12 and tested on Linux and Windows.
+Note: InstaNovo is built for Python >= 3.10, <3.12 and tested on Linux.

 ### Training

@@ -115,6 +115,46 @@ The configuration file for inference may be found under

 Note: the `denovo=True/False` flag controls whether metrics will be calculated.

+### Models
+
+InstaNovo 1.0.0 includes a new model `instanovo_extended.ckpt` trained on a larger dataset with more
+PTMs.
+
+**Training Datasets**
+
+- [ProteomeTools](https://www.proteometools.org/) Part
+[I (PXD004732)](https://www.ebi.ac.uk/pride/archive/projects/PXD004732),
+[II (PXD010595)](https://www.ebi.ac.uk/pride/archive/projects/PXD010595), and
+[III (PXD021013)](https://www.ebi.ac.uk/pride/archive/projects/PXD021013) \
+(referred to as the all-confidence ProteomeTools `AC-PT` dataset in our paper)
+- Additional PRIDE dataset with more modifications: \
+([PXD000666](https://www.ebi.ac.uk/pride/archive/projects/PXD000666), [PXD000867](https://www.ebi.ac.uk/pride/archive/projects/PXD000867),
+[PXD001839](https://www.ebi.ac.uk/pride/archive/projects/PXD001839), [PXD003155](https://www.ebi.ac.uk/pride/archive/projects/PXD003155),
+[PXD004364](https://www.ebi.ac.uk/pride/archive/projects/PXD004364), [PXD004612](https://www.ebi.ac.uk/pride/archive/projects/PXD004612),
+[PXD005230](https://www.ebi.ac.uk/pride/archive/projects/PXD005230), [PXD006692](https://www.ebi.ac.uk/pride/archive/projects/PXD006692),
+[PXD011360](https://www.ebi.ac.uk/pride/archive/projects/PXD011360), [PXD011536](https://www.ebi.ac.uk/pride/archive/projects/PXD011536),
+[PXD013543](https://www.ebi.ac.uk/pride/archive/projects/PXD013543), [PXD015928](https://www.ebi.ac.uk/pride/archive/projects/PXD015928),
+[PXD016793](https://www.ebi.ac.uk/pride/archive/projects/PXD016793), [PXD017671](https://www.ebi.ac.uk/pride/archive/projects/PXD017671),
+[PXD019431](https://www.ebi.ac.uk/pride/archive/projects/PXD019431), [PXD019852](https://www.ebi.ac.uk/pride/archive/projects/PXD019852),
+[PXD026910](https://www.ebi.ac.uk/pride/archive/projects/PXD026910), [PXD027772](https://www.ebi.ac.uk/pride/archive/projects/PXD027772))
+- Additional phosphorylation dataset \
+(not yet publicly released)
+
+**Natively Supported Modifications**
+
+- Oxidation of methionine
+- Cysteine alkylation / Carboxyamidomethylation
+- Asparagine and glutamine deamidation
+- Serine, Threonine, and Tyrosine phosphorylation
+- N-terminal ammonia loss
+- N-terminal carbamylation
+- N-terminal acetylation
+
+See residue configuration under
+[instanovo/configs/residues/extended.yaml](./instanovo/configs/residues/extended.yaml).
+
+## Additional features
+
 ### Spectrum Data Class

 InstaNovo introduces a Spectrum Data Class: [SpectrumDataFrame](./instanovo/utils/data_handler.py).
@@ -196,7 +236,7 @@ lazy_df = sdf.to_polars(return_lazy=True) # Returns a pl.LazyFrame
 sdf.write_mgf("path/to/output.mgf")
 ```

-**Additional Features:**
+**SpectrumDataFrame Features:**

 - The SpectrumDataFrame supports lazy loading with asynchronous prefetching, mitigating wait times
 between files.
@@ -291,3 +331,8 @@ The model checkpoints are licensed under Creative Commons Non-Commercial
 journal = {bioRxiv}
 }
 ```
+
+## Acknowledgements
+
+Big thanks to Pathmanaban Ramasamy, Tine Claeys, and Lennart Martens for providing us with
+additional phosphorylation training data.
1 change: 1 addition & 0 deletions instanovo/configs/inference/default.yaml
@@ -11,6 +11,7 @@ data_type: # .csv, .mgf, .mzml, .mzxml
 denovo: False
 num_beams: 1 # 1 defaults to greedy search with basic filtering
 max_length: 40
+max_charge: 10 # Must be <= model max charge
 isotope_error_range: [0, 1]
 use_knapsack: False
 save_beams: False
18 changes: 16 additions & 2 deletions instanovo/transformer/model.py
@@ -10,8 +10,6 @@
 from omegaconf import DictConfig
 from torch import nn
 from torch import Tensor
-from torch.nn.attention import sdpa_kernel
-from torch.nn.attention import SDPBackend

 from instanovo.constants import MAX_SEQUENCE_LENGTH
 from instanovo.transformer.layers import ConvPeakEmbedding
@@ -354,6 +352,14 @@ def _flash_encoder(
         latent_spectra = self.latent_spectrum.expand(x.shape[0], -1, -1)
         x = torch.cat([latent_spectra, x], dim=1).contiguous()

+        try:
+            from torch.nn.attention import sdpa_kernel
+            from torch.nn.attention import SDPBackend
+        except ImportError as err:
+            raise ImportError(
+                "Training InstaNovo with Flash attention enabled requires at least PyTorch v2.3. Please upgrade your PyTorch version."
+            ) from err
+
         with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
             x = self.encoder(x)

@@ -391,6 +397,14 @@ def _flash_decoder(

         c_mask = self._get_causal_mask(y.shape[1]).to(y.device)

+        try:
+            from torch.nn.attention import sdpa_kernel
+            from torch.nn.attention import SDPBackend
+        except ImportError as err:
+            raise ImportError(
+                "Training InstaNovo with Flash attention enabled requires at least PyTorch v2.3. Please upgrade your PyTorch version."
+            ) from err
+
         with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
             y_hat = self.decoder(y, x, tgt_mask=c_mask)

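Note on the model.py change: the same guarded import is repeated in `_flash_encoder` and `_flash_decoder`. One way to factor out this pattern is a small helper that imports an optional dependency only when it is needed and raises an `ImportError` with an actionable hint. The sketch below is illustrative only; `require_attr` is a hypothetical name that does not exist in InstaNovo, and it uses only the standard library.

```python
import importlib
from typing import Any


def require_attr(module_name: str, attr: str, hint: str) -> Any:
    """Lazily import `module_name` and return `attr` from it.

    Hypothetical helper: raises ImportError carrying `hint` if either the
    module or the attribute is missing, and chains the original error.
    """
    try:
        module = importlib.import_module(module_name)
        return getattr(module, attr)
    except (ImportError, AttributeError) as err:
        raise ImportError(f"{module_name}.{attr} is unavailable. {hint}") from err


# Possible use inside a flash-attention code path (not executed here):
# sdpa_kernel = require_attr(
#     "torch.nn.attention", "sdpa_kernel", "Flash attention requires PyTorch >= 2.3."
# )
```

Deferring the import to call time keeps the module importable on older PyTorch versions, so only users who enable flash attention hit the error.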
20 changes: 20 additions & 0 deletions instanovo/transformer/predict.py
@@ -84,6 +84,26 @@ def get_preds(
         else:
             raise

+    # Check max charge values:
+    original_size = len(sdf)
+    max_charge = config.get("max_charge", 10)
+    model_max_charge = model_config.get("max_charge", 10)
+    if max_charge > model_max_charge:
+        logger.warning(
+            f"Inference has been configured with max_charge={max_charge}, but model has max_charge={model_max_charge}."
+        )
+        logger.warning(f"Overwriting max_charge to model value: {model_max_charge}.")
+        max_charge = model_max_charge
+
+    sdf.filter_rows(
+        lambda row: (row["precursor_charge"] <= max_charge)
+        and (row["precursor_charge"] > 0)
+    )
+    if len(sdf) < original_size:
+        logger.warning(
+            f"Found {original_size - len(sdf)} rows with charge > {max_charge} or <= 0. These rows will be skipped."
+        )
+
     sdf.sample_subset(fraction=config.get("subset", 1.0), seed=42)
     logger.info(
         f"Data loaded, evaluating {config.get('subset', 1.0)*100:.1f}%, {len(sdf):,} samples in total."
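Note on the predict.py change: the charge check does two things. It caps the configured `max_charge` at the model's limit, then drops rows whose precursor charge falls outside `(0, max_charge]`. A minimal standalone sketch of that logic follows, assuming rows are plain dictionaries rather than a `SpectrumDataFrame`; the function name `clamp_and_filter_charges` is hypothetical.

```python
import logging
from typing import Dict, List, Tuple

logger = logging.getLogger(__name__)


def clamp_and_filter_charges(
    rows: List[Dict[str, int]], requested_max: int, model_max: int
) -> Tuple[List[Dict[str, int]], int]:
    """Return the rows with a usable precursor charge and the limit that was applied."""
    effective_max = requested_max
    if requested_max > model_max:
        logger.warning(
            "max_charge=%d exceeds the model limit; using %d instead.",
            requested_max,
            model_max,
        )
        effective_max = model_max

    # Keep charges in the half-open interval (0, effective_max].
    kept = [row for row in rows if 0 < row["precursor_charge"] <= effective_max]

    dropped = len(rows) - len(kept)
    if dropped:
        logger.warning(
            "Skipping %d rows with charge > %d or <= 0.", dropped, effective_max
        )
    return kept, effective_max
```

Returning the effective limit alongside the filtered rows lets the caller log or reuse the value that was actually applied, not the one that was requested.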
79 changes: 48 additions & 31 deletions instanovo/transformer/train.py
@@ -420,7 +420,39 @@ def train(
         else:
             raise

-    # TODO: Add automatic splitting if no validation set is specified.
+    if config.get("valid_path", None) is None:
+        logger.info("Validation path not specified, generating from training set.")
+        sequences = list(train_sdf.get_unique_sequences())
+        sequences = sorted(list(set([remove_modifications(x) for x in sequences])))
+        train_unique, valid_unique = train_test_split(
+            sequences,
+            test_size=config.get("valid_subset_of_train"),
+            random_state=42,
+        )
+        train_unique = set(train_unique)
+        valid_unique = set(valid_unique)
+
+        train_sdf.filter_rows(
+            lambda row: remove_modifications(row["sequence"]) in train_unique
+        )
+        valid_sdf.filter_rows(
+            lambda row: remove_modifications(row["sequence"]) in valid_unique
+        )
+        # Save splits
+        # TODO: Optionally load the data splits
+        # TODO: Allow loading of data splits in `predict.py`
+        # TODO: Upload to Aichor
+        split_path = os.path.join(
+            config.get("model_save_folder_path", "./checkpoints"), "splits.csv"
+        )
+        os.makedirs(os.path.dirname(split_path), exist_ok=True)
+        pd.DataFrame(
+            {
+                "modified_sequence": list(train_unique) + list(valid_unique),
+                "split": ["train"] * len(train_unique) + ["valid"] * len(valid_unique),
+            }
+        ).to_csv(str(split_path), index=False)
+        logger.info(f"Data splits saved to {split_path}")

     # Check residues
     if config.get("perform_data_checks", True):
@@ -463,40 +495,25 @@ def train(
                 f"{original_size[1]-new_size[1]:,d} ({(original_size[1]-new_size[1])/original_size[1]*100:.2f}%) validation rows dropped."
             )

-    # TODO Modify this code to work in the new SpectrumDataFrame
-    if config.get("valid_path", None) is None:
-        logger.info("Validation path not specified, generating from training set.")
-        sequences = list(train_sdf.get_unique_sequences())
-        sequences = sorted(list(set([remove_modifications(x) for x in sequences])))
-        train_unique, valid_unique = train_test_split(
-            sequences,
-            test_size=config.get("valid_subset_of_train"),
-            random_state=42,
-        )
-        train_unique = set(train_unique)
-        valid_unique = set(valid_unique)
-
+        # Check charge values:
+        original_size = (len(train_sdf), len(valid_sdf))
         train_sdf.filter_rows(
-            lambda row: remove_modifications(row["sequence"]) in train_unique
+            lambda row: (row["precursor_charge"] <= config.get("max_charge", 10))
+            and (row["precursor_charge"] > 0)
         )
+        if len(train_sdf) < original_size[0]:
+            logger.warning(
+                f"Found {original_size[0] - len(train_sdf)} rows in training set with charge > {config.get('max_charge', 10)} or <= 0. These rows will be skipped."
+            )
+
         valid_sdf.filter_rows(
-            lambda row: remove_modifications(row["sequence"]) in valid_unique
+            lambda row: (row["precursor_charge"] <= config.get("max_charge", 10))
+            and (row["precursor_charge"] > 0)
         )
-        # Save splits
-        # TODO: Optionally load the data splits
-        # TODO: Allow loading of data splits in `predict.py`
-        # TODO: Upload to Aichor
-        split_path = os.path.join(
-            config.get("model_save_folder_path", "./checkpoints"), "splits.csv"
-        )
-        os.makedirs(os.path.dirname(split_path), exist_ok=True)
-        pd.DataFrame(
-            {
-                "modified_sequence": list(train_unique) + list(valid_unique),
-                "split": ["train"] * len(train_unique) + ["valid"] * len(valid_unique),
-            }
-        ).to_csv(str(split_path), index=False)
-        logger.info(f"Data splits saved to {split_path}")
+        if len(valid_sdf) < original_size[1]:
+            logger.warning(
+                f"Found {original_size[1] - len(valid_sdf)} rows in validation set with charge > {config.get('max_charge', 10)} or <= 0. These rows will be skipped."
+            )

     train_sdf.sample_subset(fraction=config.get("train_subset", 1.0), seed=42)
     valid_sdf.sample_subset(fraction=config.get("valid_subset", 1.0), seed=42)
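Note on the train.py change: the validation set is built by splitting on unique unmodified peptide sequences, not on individual spectra, so a peptide never appears in both splits in differently modified forms. The commit uses scikit-learn's `train_test_split` for this. The sketch below swaps in a seeded shuffle from the standard library to show the same idea without extra dependencies. `strip_mods` is a hypothetical stand-in for InstaNovo's `remove_modifications` and assumes modifications are written in parentheses or square brackets.

```python
import random
import re
from typing import Iterable, Set, Tuple


def strip_mods(peptide: str) -> str:
    """Hypothetical stand-in: drop bracketed modifications, keep residue letters."""
    without_brackets = re.sub(r"[\(\[].*?[\)\]]", "", peptide)
    return re.sub(r"[^A-Z]", "", without_brackets)


def split_by_peptide(
    peptides: Iterable[str], valid_fraction: float, seed: int = 42
) -> Tuple[Set[str], Set[str]]:
    """Split unique unmodified sequences into disjoint train and validation sets."""
    # Sort before shuffling so the result depends only on the seed and the data.
    unique = sorted({strip_mods(p) for p in peptides})
    random.Random(seed).shuffle(unique)
    n_valid = max(1, round(len(unique) * valid_fraction))
    return set(unique[n_valid:]), set(unique[:n_valid])


def assign_split(peptide: str, train: Set[str], valid: Set[str]) -> str:
    """Look up which split a (possibly modified) peptide belongs to."""
    return "train" if strip_mods(peptide) in train else "valid"
```

Rows are then assigned by looking up their stripped sequence, which is what the two `filter_rows` calls in the diff do for `train_sdf` and `valid_sdf`.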