diff --git a/genslm/__init__.py b/genslm/__init__.py
index 75beef31..6731c875 100644
--- a/genslm/__init__.py
+++ b/genslm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "0.0.2a1"
+__version__ = "0.0.3a1"
 
 # Public imports
 from genslm.dataset import SequenceDataset  # noqa
diff --git a/genslm/cmdline/gather_embeddings.py b/genslm/cmdline/gather_embeddings.py
deleted file mode 100644
index a70f231f..00000000
--- a/genslm/cmdline/gather_embeddings.py
+++ /dev/null
@@ -1,40 +0,0 @@
-from argparse import ArgumentParser
-from pathlib import Path
-from typing import Optional
-
-import numpy as np
-import numpy.typing as npt
-
-
-def gather_embeddings(
-    input_dir: Path, output_path: Optional[Path] = None
-) -> npt.ArrayLike:
-    """Gather embeddings produced via DDP into a single sorted numpy array."""
-
-    # Glob embedding and index files written by each rank
-    # (need to sort by uuid's to match the rank-label between indices and
-    # embeddings files)
-    index_files = sorted(input_dir.glob("indices-*.npy"))
-    embedding_files = sorted(input_dir.glob("embeddings-*.npy"))
-
-    # Load all index and embedding files into memory (fp16 means they are not large))
-    indices = np.concatenate([np.load(f) for f in index_files])
-    embeddings = np.concatenate([np.load(f) for f in embedding_files])
-
-    # Sort scattered indices
-    sort_inds = np.argsort(indices)
-    embeddings = embeddings[sort_inds]
-
-    if output_path is not None:
-        np.save(output_path, embeddings)
-
-    return embeddings
-
-
-if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("-i", "--input_dir", type=Path, required=True)
-    parser.add_argument("-o", "--output_path", type=Path, required=True)
-    args = parser.parse_args()
-
-    gather_embeddings(args.input_dir, args.output_path)
diff --git a/genslm/cmdline/gather_inference_h5.py b/genslm/cmdline/gather_inference_h5.py
new file mode 100644
index 00000000..9d6d8d14
--- /dev/null
+++ b/genslm/cmdline/gather_inference_h5.py
@@ -0,0 +1,124 @@
+"""
+Gather embeddings written by `run_inference.py`, collecting the per-rank
+files into a single h5py file of ExternalLinks to the original files.
+The links are necessary for matching the new H5 files back to the original
+fasta files, but they make the dataset brittle to being moved to a new
+location; copying the data into a single self-contained file is far too slow.
+
+The current implementation is coupled to the output format of `run_inference.py`.
+""" +import re +from argparse import ArgumentParser +from pathlib import Path +from typing import Optional + +import h5py + + +def gather_logits( + input_dir: Path, + output_path: Optional[Path] = None, + glob_pattern: str = "logits-*.h5", + verbose: bool = False, +): + + if output_path is None: + output_path = input_dir / "logits_gathered.h5" + + input_files = list(input_dir.glob(glob_pattern)) + # Glob embedding and index files written by each rank + with h5py.File(output_path, "w") as output_file: + output_file.create_group("logits") + output_file.create_group("na-hashes") + for i, h5_file in enumerate(input_files): + if verbose: + print("Loading", h5_file) + with h5py.File(h5_file, "r") as input_file: + resolved_path = h5_file.resolve() + + for seq_fasta_index in input_file["logits"].keys(): + output_file["logits"][str(seq_fasta_index)] = h5py.ExternalLink( + str(resolved_path), f"logits/{seq_fasta_index}" + ) + + hashes = input_file["na-hashes"] + indices = input_file["fasta-indices"] + for fasta_idx, na_hash in zip(indices, hashes): + output_file["na-hashes"].create_dataset( + f"{fasta_idx}", data=na_hash + ) + if verbose: + print("Wrote gathered output to", output_path, "\n") + + +def gather_embeddings( + input_dir: Path, + output_path: Optional[Path] = None, + glob_pattern: Optional[str] = None, + verbose: bool = False, +) -> None: + """Gather embeddings produced via DDP into a single h5 file.""" + + if glob_pattern is None: + glob_pattern = "*.h5" + + if output_path is None: + output_path = input_dir / "embeddings_gathered.h5" + + input_files = list(input_dir.glob(glob_pattern)) + # Glob embedding and index files written by each rank + with h5py.File(output_path, "w") as output_file: + output_file.create_group("embeddings") + output_file.create_group("na-hashes") + for i, h5_file in enumerate(input_files): + if verbose: + print("Loading", h5_file) + with h5py.File(h5_file, "r") as input_file: + resolved_path = h5_file.resolve() + + for seq_fasta_index in input_file["embeddings"].keys(): + output_file["embeddings"][seq_fasta_index] = h5py.ExternalLink( + str(resolved_path), f"embeddings/{seq_fasta_index}" + ) + + hashes = input_file["na-hashes"] + indices = input_file["fasta-indices"] + for fasta_idx, na_hash in zip(indices, hashes): + output_file["na-hashes"].create_dataset( + f"{fasta_idx}", data=na_hash + ) + if verbose: + print("Wrote gathered output to", output_path, "\n") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("-i", "--input_dir", type=Path, required=True) + parser.add_argument("-o", "--output_path", type=Path, required=True) + parser.add_argument( + "-g", "--embeddings_glob_pattern", type=str, default="embeddings-*.h5" + ) + parser.add_argument("-l", "--logits_glob_pattern", type=str, default="logits-*.h5") + parser.add_argument("--embeddings", action="store_true", help="Gather embeddings") + parser.add_argument("--logits", action="store_true", help="Gather logits.") + parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output.") + args = parser.parse_args() + + if args.embeddings: + files = list(args.input_dir.glob(args.embeddings_glob_pattern)) + layers = set() + layer_pattern = re.compile(r"layer-(\d+)") + for file in files: + if "layer" in file.name: + layer = layer_pattern.search(file.name).group(1) + layers.add(layer) + + for layer in layers: + glob_pattern = f"*layer-{layer}*.h5" + out_path = args.output_path / f"embeddings-gathered-layer-{layer}.h5" + + gather_embeddings(args.input_dir, out_path, 
glob_pattern, args.verbose) + + if args.logits: + out_path = args.output_path / "logits-gathered.h5" + gather_logits(args.input_dir, out_path, args.logits_glob_pattern, args.verbose) diff --git a/genslm/cmdline/inference_outputs.py b/genslm/cmdline/inference_outputs.py deleted file mode 100644 index e94e0cda..00000000 --- a/genslm/cmdline/inference_outputs.py +++ /dev/null @@ -1,183 +0,0 @@ -import os -from argparse import ArgumentParser -from pathlib import Path -from typing import Any, Dict, Optional - -import pytorch_lightning as pl -from pydantic import root_validator, validator -from torch.utils.data import DataLoader # Subset - -import genslm -from genslm.config import BaseSettings, WarmupLRSettings -from genslm.dataset import FastaDataset, FileBackedH5Dataset -from genslm.model import DNATransformer -from genslm.utils import ( - LoadDeepSpeedStrategy, - LoadPTCheckpointStrategy, - OutputsCallback, -) - - -class InferenceConfig(BaseSettings): - data_file: Path - """Data file to run inference on (HDF5).""" - embeddings_out_path: Path - """Directory to write embeddings to.""" - model_config_json: Path - """Huggingface json dict to load AutoConfig from.""" - load_pt_checkpoint: Optional[Path] = None - """Checkpoint pt file to initialze model weights.""" - load_ds_checkpoint: Optional[Path] = None - """DeepSpeed checkpoint file to initialze model weights.""" - precision: int = 16 - """Model precision.""" - num_nodes: int = 1 - """Number of nodes to use for inference.""" - batch_size: int = 32 - """Batch size to use for inference.""" - num_data_workers: int = 4 - """Number of subprocesses to use for data loading.""" - prefetch_factor: int = 2 - """Number of batches loaded in advance by each worker.""" - pin_memory: bool = True - """If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.""" - block_size: int = 2048 - """Only used when processing a directory of fasta files.""" - deepspeed_flops_profile: bool = False - """Always false when computing embeddings""" - - # Parameters needed to initialize DNATransformer (not used for inference) - tokenizer_file: Path = ( - Path(genslm.__file__).parent - / "tokenizer_files" - / "codon_wordlevel_100vocab.json" - ) - learning_rate: float = 5e-5 - warm_up_lr: Optional[WarmupLRSettings] = None - - @root_validator - def assert_checkpoint_file_specified(cls, values: Dict[str, Any]) -> Dict[str, Any]: - load_pt_checkpoint: Optional[Path] = values.get("load_pt_checkpoint") - load_ds_checkpoint: Optional[Path] = values.get("load_ds_checkpoint") - if load_pt_checkpoint is None and load_ds_checkpoint is None: - raise ValueError( - "At least one of load_pt_checkpoint or load_ds_checkpoint must be specified." 
- ) - return values - - @validator("data_file") - def data_file_exists(cls, v: Path) -> Path: - if not v.exists(): - raise FileNotFoundError(f"data_file path does not exist {v}.") - return v - - @validator("load_pt_checkpoint") - def load_pt_checkpoint_exists(cls, v: Optional[Path]) -> Optional[Path]: - if v is not None and not v.exists(): - raise FileNotFoundError(f"load_pt_checkpoint path does not exist {v}.") - return v - - @validator("load_ds_checkpoint") - def load_ds_checkpoint_exists(cls, v: Optional[Path]) -> Optional[Path]: - if v is not None and not v.exists(): - raise FileNotFoundError(f"load_ds_checkpoint path does not exist {v}.") - return v - - -# class DNATransformer(pl.LightningModule): -# def __init__(self, cfg: InferenceConfig) -> None: -# # Loads from a hugging face JSON file -# base_config = AutoConfig.from_pretrained(cfg.model_config_json) -# self.model = AutoModelForCausalLM.from_config(base_config) - -# def forward( -# self, batch: Dict[str, torch.Tensor], **kwargs: Dict[str, Any] -# ) -> ModelOutput: -# return self.model( -# batch["input_ids"], -# labels=batch["input_ids"], -# attention_mask=batch["attention_mask"], -# **kwargs, -# ) - -# def predict_step( -# self, batch: Dict[str, torch.Tensor], batch_idx: int -# ) -> ModelOutput: -# return self(batch, output_hidden_states=True) - - -def main(config: InferenceConfig) -> None: - # Setup torch environment - os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" - # torch.set_num_threads(config.num_data_workers) # type: ignore[attr-defined] - pl.seed_everything(0) - - if config.load_pt_checkpoint: - model_strategy = LoadPTCheckpointStrategy( - config.load_pt_checkpoint, cfg=config, generation_flag=True - ) - else: - model_strategy = LoadDeepSpeedStrategy( - config.load_ds_checkpoint, cfg=config, generation_flag=True - ) - - model: DNATransformer = model_strategy.get_model(DNATransformer) - - tmp_embeddings_dir = config.embeddings_out_path.with_suffix("") - - if args.attention: - print("Generating attention values...") - elif args.logits: - print("Generating logit values...") - else: - print("Generating embeddings values...") - - embedding_callback = OutputsCallback( - save_dir=tmp_embeddings_dir, - output_attentions=args.attention, - output_logits=args.logits, - ) - trainer = pl.Trainer( - gpus=-1, - precision=config.precision, - num_nodes=config.num_nodes, - callbacks=[embedding_callback], - strategy="ddp", - ) - - # Select datset type based on data_file type - if config.data_file.suffix == ".h5": - dataset = FileBackedH5Dataset(config.data_file) - elif config.data_file.is_dir(): - dataset = FastaDataset(config.data_file, config.block_size, model.tokenizer) - else: - raise ValueError(f"Couldn't process data_file: {config.data_file}") - - # dataset = Subset(dataset, np.arange(512)) # for testing - dataloader = DataLoader( - dataset, - batch_size=config.batch_size, - num_workers=config.num_data_workers, - prefetch_factor=config.prefetch_factor, - pin_memory=config.pin_memory, - ) - - print(f"Running inference with dataset length {len(dataloader)}") - trainer.predict(model, dataloaders=dataloader, return_predictions=False) - print("Done") - - # This approch has a bug since global_rank is not a single process - # if trainer.is_global_zero: - # gather_embeddings(tmp_embeddings_dir, config.embeddings_out_path) - # shutil.rmtree(tmp_embeddings_dir) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("-c", "--config", required=True) - parser.add_argument("-a", "--attention", action="store_true") - 
parser.add_argument("-l", "--logits", action="store_true") - args = parser.parse_args() - assert not (args.attention and args.logits) - config = InferenceConfig.from_yaml(args.config) - main(config) diff --git a/genslm/cmdline/run_inference.py b/genslm/cmdline/run_inference.py new file mode 100644 index 00000000..0c87e25f --- /dev/null +++ b/genslm/cmdline/run_inference.py @@ -0,0 +1,359 @@ +import functools +import hashlib +import os +import uuid +from argparse import ArgumentParser +from pathlib import Path +from typing import Any, Dict, List, Tuple, Union + +import h5py +import numpy as np +import pytorch_lightning as pl +import torch +import torch.multiprocessing as mp +from natsort import natsorted +from pytorch_lightning.callbacks import Callback +from torch.utils.data import DataLoader, Dataset # Subset +from tqdm import tqdm +from transformers import PreTrainedTokenizerFast + +from genslm.config import BaseSettings, path_validator +from genslm.inference import GenSLM +from genslm.utils import read_fasta_only_seq + + +class InferenceConfig(BaseSettings): + # Input files + model_id: str = "genslm_25M_patric" + """The genslm model to load.""" + model_cache_dir: Path + """The directory of the model weights.""" + data_file: Path + """Data file to run inference on (HDF5).""" + output_path: Path + """Directory to write embeddings, attentions, logits to.""" + + # Which outputs to generate + layer_bounds: Union[Tuple[int, int], List[int]] = (0, -1) + """Which layers to generate data for, all by default.""" + output_embeddings: bool = True + """Whether or not to generate and save embeddings.""" + output_attentions: bool = False + """Whether or not to generate and save attentions.""" + output_logits: bool = False + """Whether or not to generate and save logits.""" + + # Run time settings + num_nodes: int = 1 + """Number of nodes to use for inference.""" + precision: int = 16 + """Model precision.""" + batch_size: int = 32 + """Batch size to use for inference.""" + num_data_workers: int = 4 + """Number of subprocesses to use for data loading.""" + prefetch_factor: int = 2 + """Number of batches loaded in advance by each worker.""" + pin_memory: bool = True + """If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them.""" + + # validators + _data_file_exists = path_validator("data_file") + _model_cache_dir_exists = path_validator("model_cache_dir") + + +class InferenceSequenceDataset(Dataset): + """Dataset initialized from fasta files.""" + + def __init__( + self, + fasta_path: Path, + seq_length: int, + tokenizer: PreTrainedTokenizerFast, + kmer_size: int = 3, + ): + + # Read all fasta files into memory as strings + self.raw_sequences = self.read_sequences(fasta_path) + # Quick transformation to group sequences by kmers + self.sequences = [ + self.group_by_kmer(seq, kmer_size) for seq in self.raw_sequences + ] + + # Define tokenizer function, but wait to tokenize + # until a specific batch is requested + self.tokenizer_fn = functools.partial( + tokenizer, + max_length=seq_length, + padding="max_length", + truncation=True, + return_tensors="pt", + ) + + @staticmethod + def read_sequences(fasta_path: Path) -> List[str]: + sequences = [] + if fasta_path.is_dir(): + fasta_files = natsorted(fasta_path.glob("*.fasta")) + for fasta_file in tqdm(fasta_files, desc="Reading fasta files..."): + sequences.extend(read_fasta_only_seq(fasta_file)) + else: + sequences = read_fasta_only_seq(fasta_path) + return sequences + + @staticmethod + def group_by_kmer(seq: str, 
kmer: int) -> str: + return " ".join(seq[i : i + kmer] for i in range(0, len(seq), kmer)).upper() + + def __len__(self) -> int: + return len(self.sequences) + + def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: + raw_seq = self.raw_sequences[idx] + seq = self.sequences[idx] + batch_encoding = self.tokenizer_fn(seq) + # Squeeze so that batched tensors end up with (batch_size, seq_length) + # instead of (batch_size, 1, seq_length) + sample = { + "input_ids": batch_encoding["input_ids"].squeeze(), + "attention_mask": batch_encoding["attention_mask"], + "indices": torch.from_numpy(np.array([idx])), + "seq_lens": batch_encoding["attention_mask"].sum(1), + # Need raw string for hashing + "na_hash": hashlib.md5(raw_seq.encode("utf-8")).hexdigest(), + } + return sample + + +class OutputsCallback(Callback): + def __init__( + self, + save_dir: Path = Path("./outputs"), + layer_bounds: Tuple[int, int] = (0, -1), + output_embeddings: bool = True, + output_attentions: bool = False, + output_logits: bool = False, + ) -> None: + self.rank_label = uuid.uuid4() + + self.output_attentions = output_attentions + self.output_logits = output_logits + self.output_embeddings = output_embeddings + self.save_dir = save_dir + self.save_dir.mkdir(parents=True, exist_ok=True) + + if isinstance(layer_bounds, tuple): + self.layer_lb, self.layer_ub = layer_bounds + self.layers = None + elif isinstance(layer_bounds, list): + self.layer_lb, self.layer_ub = None, None + self.layers = layer_bounds + + # Embeddings: Key layer-id, value embedding array + self.attentions, self.indices, self.na_hashes = [], [], [] + + self.h5embeddings_open: Dict[int, h5py.File] = {} + self.h5logit_file = None + + self.h5_kwargs = { + # "compression": "gzip", + # "compression_opts": 4, Compression is too slow for current impl + "fletcher32": True, + } + + def on_predict_start( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + # Plus one for embedding layer + num_hidden_layers = pl_module.model.model.config.num_hidden_layers + 1 + + if self.layer_lb is not None and self.layer_lb < 0: + self.layer_lb = num_hidden_layers + self.layer_lb + if self.layer_ub is not None and self.layer_ub < 0: + self.layer_ub = num_hidden_layers + self.layer_ub + + if self.layers is None: + self.layers = list(range(self.layer_lb, self.layer_ub)) + + for ind in range(len(self.layers)): + layer_num = self.layers[ind] + if layer_num < 0: + self.layers[ind] = num_hidden_layers + layer_num + + if self.output_logits: + self.h5logit_file = h5py.File( + self.save_dir / f"logits-{self.rank_label}.h5", "w" + ) + self.h5logit_file.create_group("logits") + + def on_predict_batch_end( + self, + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + outputs: Any, + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + # outputs.hidden_states: (layer, batch_size, sequence_length, hidden_size) + seq_lens = batch["seq_lens"].detach().cpu().numpy().reshape(-1) + fasta_inds = batch["indices"].detach().cpu().numpy().reshape(-1) + + if self.output_attentions: + attend = torch.sum(outputs.attentions[0].detach().cpu().squeeze(), dim=0) + self.attentions.append(attend) + + if self.output_logits: + logits = outputs.logits.detach().cpu().numpy() + for logit, seq_len, fasta_ind in zip(logits, seq_lens, fasta_inds): + self.h5logit_file["logits"].create_dataset( + f"{fasta_ind}", + data=logit[:seq_len], + **self.h5_kwargs, + ) + + if self.output_embeddings: + for layer, embeddings in enumerate(outputs.hidden_states): + # User specified list 
of layers to take + if layer not in self.layers: + continue + + h5_file = self.h5embeddings_open.get(layer) + if h5_file is None: + name = ( + self.save_dir / f"embeddings-layer-{layer}-{self.rank_label}.h5" + ) + h5_file = h5py.File(name, "w") + h5_file.create_group("embeddings") + self.h5embeddings_open[layer] = h5_file + + embed = embeddings.detach().cpu().numpy() + for emb, seq_len, fasta_ind in zip(embed, seq_lens, fasta_inds): + h5_file["embeddings"].create_dataset( + f"{fasta_ind}", + data=emb[:seq_len], + **self.h5_kwargs, + ) + + h5_file.flush() + + self.na_hashes.extend(batch["na_hash"]) + self.indices.append(batch["indices"].detach().cpu()) + + def on_predict_end( + self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" + ) -> None: + + self.indices = torch.cat(self.indices).numpy().squeeze() + + if self.output_logits: + self.h5logit_file.create_dataset( + "fasta-indices", data=self.indices, **self.h5_kwargs + ) + self.h5logit_file.create_dataset( + "na-hashes", data=self.na_hashes, **self.h5_kwargs + ) + self.h5logit_file.close() + + if self.output_embeddings: + # Write indices to h5 files to map embeddings back to fasta file + for h5_file in self.h5embeddings_open.values(): + h5_file.create_dataset( + "fasta-indices", data=self.indices, **self.h5_kwargs + ) + h5_file.create_dataset( + "na-hashes", data=self.na_hashes, **self.h5_kwargs + ) + + # Close all h5 files + for h5_file in self.h5embeddings_open.values(): + h5_file.close() + + +class LightningGenSLM(pl.LightningModule): + """Lightning wrapper to facilitate distributed prediction.""" + + def __init__(self, model: GenSLM) -> None: + super().__init__() + self.model = model + + def forward(self, *args, **kwargs) -> Any: + return self.model(*args, **kwargs) + + def predict_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Any: + return self(batch["input_ids"], batch["attention_mask"]) + + +def main(config: InferenceConfig) -> None: + # Setup torch environment + os.environ["TOKENIZERS_PARALLELISM"] = "true" + os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE" + # Potential polaris fix for connection reset error + mp.set_start_method("spawn") + pl.seed_everything(42) + + # Load GenSLM model and inject into pytorch lightning + model = GenSLM(config.model_id, config.model_cache_dir) + # Set the default kwarg values once + model.forward = functools.partial( + model.forward, + output_hidden_states=config.output_embeddings, + output_attentions=config.output_attentions, + ) + ptl_model = LightningGenSLM(model) + + # Create callback to save model outputs to disk + outputs_callback = OutputsCallback( + save_dir=config.output_path, + layer_bounds=config.layer_bounds, + output_embeddings=config.output_embeddings, + output_attentions=config.output_attentions, + output_logits=config.output_logits, + ) + + # Use pytorch lightning trainer to take advantage of distribution strategies + trainer = pl.Trainer( + gpus=-1, + precision=config.precision, + num_nodes=config.num_nodes, + callbacks=[outputs_callback], + strategy="ddp", + logger=False, # Avoid lightning_logs dir + max_epochs=-1, # Avoid warning + ) + + # This dataset loads each sequence from each fasta file into memory + # as strings on each rank and then tokenizes on-the-fly. 
+ dataset = InferenceSequenceDataset( + config.data_file, model.seq_length, model.tokenizer + ) + # dataset = Subset(dataset, np.arange(512)) # for testing + dataloader = DataLoader( + dataset, + batch_size=config.batch_size, + num_workers=config.num_data_workers, + prefetch_factor=config.prefetch_factor, + pin_memory=config.pin_memory, + ) + + if trainer.is_global_zero: + print(f"Running inference with dataset length {len(dataloader)}") + if config.output_embeddings: + print("Generating embeddings values...") + if config.output_attentions: + print("Generating attention values...") + if config.output_logits: + print("Generating logit values...") + + trainer.predict(ptl_model, dataloaders=dataloader, return_predictions=False) + + if trainer.is_global_zero: + print("Done") + + +if __name__ == "__main__": + parser = ArgumentParser() + parser.add_argument("-c", "--config", required=True) + args = parser.parse_args() + config = InferenceConfig.from_yaml(args.config) + main(config) diff --git a/genslm/cmdline/write_fasta_files.py b/genslm/cmdline/write_fasta_files.py deleted file mode 100644 index 5a99407a..00000000 --- a/genslm/cmdline/write_fasta_files.py +++ /dev/null @@ -1,13 +0,0 @@ -from argparse import ArgumentParser -from pathlib import Path - -from genslm.dataset import write_individual_fasta_files - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("-f", "--fasta", type=Path, required=True) - parser.add_argument("-o", "--output_dir", type=Path, required=True) - parser.add_argument("-n", "--num_workers", type=int, default=1) - args = parser.parse_args() - - write_individual_fasta_files(args.fasta, args.output_dir, args.num_workers) diff --git a/genslm/config.py b/genslm/config.py index 8f2648c3..9450cf3d 100644 --- a/genslm/config.py +++ b/genslm/config.py @@ -1,4 +1,4 @@ -"""Model configuration.""" +"""Configuration.""" import json import os import warnings @@ -16,6 +16,21 @@ PathLike = Union[str, Path] +def _resolve_path_exists(value: Optional[Path]) -> Optional[Path]: + if value is None: + return None + p = value.resolve() + if not p.exists(): + raise FileNotFoundError(p) + return p + + +def path_validator(field: str) -> classmethod: + decorator = validator(field, allow_reuse=True) + _validator = decorator(_resolve_path_exists) + return _validator + + class BaseSettings(_BaseSettings): """Base settings to provide an easier interface to read/write YAML files.""" diff --git a/genslm/dataset.py b/genslm/dataset.py index 5e9aa17e..e40c3d40 100644 --- a/genslm/dataset.py +++ b/genslm/dataset.py @@ -11,7 +11,6 @@ import numpy as np import torch from Bio import SeqIO # type: ignore[import] -from natsort import natsorted from torch.utils.data import Dataset from tqdm import tqdm from transformers import BatchEncoding, PreTrainedTokenizerFast @@ -19,78 +18,13 @@ from genslm.config import PathLike +# TODO: Remove dependecy for BioPython +# NOTE: Legacy H5 conversion code def group_by_kmer(s: SeqIO.SeqRecord, n: int) -> str: seq = str(s.seq).upper() # need to make sure it's in upper case return " ".join(seq[i : i + n] for i in range(0, len(seq), n)) -def _write_fasta_file(seq: SeqIO.SeqRecord, output_file: Path) -> None: - SeqIO.write(seq, str(output_file), "fasta") - - -def write_individual_fasta_files( - fasta_file: Path, output_dir: Path, num_workers: int = 1 -) -> None: - output_dir.mkdir(exist_ok=True) - seqs = list(SeqIO.parse(fasta_file, "fasta")) - output_files = [output_dir / f"sequence-{i}.fasta" for i in range(len(seqs))] - print(f"Number of sequences: 
{len(seqs)}") - chunksize = max(1, len(seqs) // num_workers) - with ProcessPoolExecutor(max_workers=num_workers) as executor: - for _ in executor.map( - _write_fasta_file, seqs, output_files, chunksize=chunksize - ): - pass - - -class FastaDataset(Dataset): - def __init__( - self, - fasta_dir: PathLike, - block_size: int, - tokenizer: PreTrainedTokenizerFast, - kmer_size: int = 3, - small_subset: int = 0, - ): - self.block_size = block_size - self.tokenizer = tokenizer - self.kmer_size = kmer_size - - self.files = natsorted(Path(fasta_dir).glob("*.fasta")) - - # default of zero will not call this logic - if small_subset: - self.files = self.files[:small_subset] - - # Cache the samples in memory - self.samples: Dict[int, Dict[str, torch.Tensor]] = {} - - def __len__(self) -> int: - return len(self.files) - - def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: - # tokenize on the fly - try: - return self.samples[idx] - except KeyError: - sequence = list(SeqIO.parse(self.files[idx], "fasta"))[0] - batch_encoding = self.tokenizer( - group_by_kmer(sequence, self.kmer_size), - max_length=self.block_size, - padding="max_length", - return_tensors="pt", - ) - # Squeeze so that batched tensors end up with (batch_size, seq_length) - # instead of (batch_size, 1, seq_length) - sample = { - "input_ids": batch_encoding["input_ids"].squeeze(), - "attention_mask": batch_encoding["attention_mask"], - "indices": torch.from_numpy(np.array([idx])), - } - self.samples[idx] = sample - return sample - - class H5PreprocessMixin: @staticmethod def train_val_test_split( diff --git a/genslm/hpc/templates/polaris.j2 b/genslm/hpc/templates/polaris.j2 index de051837..e6b7d1dd 100644 --- a/genslm/hpc/templates/polaris.j2 +++ b/genslm/hpc/templates/polaris.j2 @@ -53,7 +53,7 @@ conda activate genslm echo "$(df -h /local/scratch)" # NCCL settings -export NCCL_DEBUG=info +export NCCL_DEBUG=WARN export NCCL_NET_GDR_LEVEL=PHB # For applications that internally handle binding MPI/OpenMP processes to GPUs diff --git a/genslm/inference.py b/genslm/inference.py index 0b1747bb..51a8b6e4 100644 --- a/genslm/inference.py +++ b/genslm/inference.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Dict +from typing import Any, Dict, Union import torch import torch.nn as nn @@ -10,6 +10,8 @@ import genslm +PathLike = Union[str, Path] + class GenSLM(nn.Module): @@ -44,14 +46,14 @@ class GenSLM(nn.Module): }, } - def __init__(self, model_id: str, model_cache_dir: str = ".") -> None: + def __init__(self, model_id: str, model_cache_dir: PathLike = ".") -> None: """GenSLM inference module. Parameters ---------- model_id : str A model ID corresponding to a pre-trained model. (e.g., genslm_25M_patric) - model_cache_dir : str, optional + model_cache_dir : PathLike, optional Directory where model weights have been downloaded to (defaults to current working directory). If model weights are not found, then they will be downloaded, by default "." @@ -62,7 +64,7 @@ def __init__(self, model_id: str, model_cache_dir: str = ".") -> None: If model_id is invalid. 
""" super().__init__() - self.model_cache_dir = model_cache_dir + self.model_cache_dir = Path(model_cache_dir) self.model_info = self.MODELS.get(model_id) if self.model_info is None: valid_model_ids = list(self.MODELS.keys()) @@ -87,7 +89,7 @@ def configure_model(self) -> AutoModelForCausalLM: base_config = AutoConfig.from_pretrained(self.model_info["config"]) model = AutoModelForCausalLM.from_config(base_config) - weight_path = Path(self.model_cache_dir) / self.model_info["weights"] + weight_path = self.model_cache_dir / self.model_info["weights"] if not weight_path.exists(): # TODO: Implement model download raise NotImplementedError @@ -105,10 +107,7 @@ def configure_tokenizer(self) -> PreTrainedTokenizerFast: return tokenizer def forward( - self, - input_ids: torch.Tensor, - attention_mask: torch.Tensor, - **kwargs: Dict[str, Any], + self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs: Any ) -> ModelOutput: return self.model( input_ids, labels=input_ids, attention_mask=attention_mask, **kwargs diff --git a/genslm/model.py b/genslm/model.py index c6d0c11a..8a418f5e 100644 --- a/genslm/model.py +++ b/genslm/model.py @@ -2,11 +2,8 @@ import os import warnings from argparse import ArgumentParser -from pathlib import Path -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List -import numpy as np -import numpy.typing as npt import pytorch_lightning as pl import torch import torch.multiprocessing as mp @@ -23,7 +20,6 @@ from tokenizers import Tokenizer from torch.optim.lr_scheduler import ReduceLROnPlateau from torch.utils.data import DataLoader -from tqdm import tqdm from transformers import ( AutoConfig, AutoModelForCausalLM, @@ -34,12 +30,10 @@ from genslm.blast import BLASTCallback from genslm.config import ModelSettings, PathLike, throughput_config -from genslm.dataset import CachingH5Dataset, FileBackedH5Dataset +from genslm.dataset import CachingH5Dataset from genslm.utils import ( LoadDeepSpeedStrategy, LoadPTCheckpointStrategy, - ModelLoadStrategy, - OutputsCallback, PerplexityCallback, SequenceGenerationCallback, ThroughputMonitor, @@ -441,106 +435,9 @@ def train(cfg: ModelSettings) -> None: # noqa print("Completed training.") -def generate_embeddings( - model: DNATransformer, dataloader: DataLoader, compute_mean: bool = False -) -> np.ndarray: - """Output embedding array of shape (num_seqs, block_size, hidden_dim).""" - embeddings = [] - for batch in tqdm(dataloader): - for key in ["input_ids", "attention_mask"]: - batch[key] = batch[key].cuda() - outputs = model(batch, output_hidden_states=True) - # outputs.hidden_states: (batch_size, sequence_length, hidden_size) - emb = outputs.hidden_states[0].detach().cpu().numpy() - if compute_mean: - # Compute average over sequence length - emb = np.mean(emb, axis=1) - embeddings.append(emb) - - embeddings = np.concatenate(embeddings) # type: ignore - return embeddings - - -# TODO: Make separate files for training and inference -def inference( - cfg: ModelSettings, - model_load_strategy: ModelLoadStrategy, - fasta_file: str, - output_path: Optional[PathLike] = None, -) -> npt.ArrayLike: - """Output embedding array of shape (num_seqs, hidden_dim).""" - model: DNATransformer = model_load_strategy.get_model(DNATransformer) - - embedding_callback = OutputsCallback() - trainer = pl.Trainer( - gpus=-1, - # default_root_dir=str(cfg.checkpoint_dir), - # strategy=DeepSpeedStrategy(stage=3), - strategy="ddp", - # accumulate_grad_batches=cfg.accumulate_grad_batches, - # num_sanity_val_steps=2, - 
precision=cfg.precision, - num_nodes=cfg.num_nodes, - callbacks=[embedding_callback], - ) - - dataset = FileBackedH5Dataset(fasta_file) - dataloader = model.get_dataloader(dataset, shuffle=False, drop_last=False) - print(f"Running inference with dataset length {len(dataloader)}") - trainer.predict(model, dataloaders=dataloader, return_predictions=False) - - embeddings = embedding_callback.embeddings - print(f"Embeddings shape: {embeddings.shape}") - if output_path: - assert Path(output_path).suffix == ".npy" - np.save(output_path, embeddings) - return embeddings - - -def test(cfg: ModelSettings) -> None: - """Run test dataset after loading from checkpoint""" - if cfg.load_pt_checkpoint is not None: - load_strategy = LoadPTCheckpointStrategy(cfg.load_pt_checkpoint, cfg=cfg) - model = load_strategy.get_model(DNATransformer) - elif cfg.load_ds_checkpoint is not None: - # Check if loading from checkpoint - this assumes that you're - # loading from a sharded DeepSpeed checkpoint!!! - load_strategy = LoadDeepSpeedStrategy(cfg.load_ds_checkpoint, cfg=cfg) - model = load_strategy.get_model(DNATransformer) - print(f"Loaded existing model at checkpoint {cfg.load_ds_checkpoint}....") - else: - print("WARNING: running test on randomly initialized architecture") - model = DNATransformer(cfg) - - trainer = pl.Trainer( - gpus=-1, - default_root_dir=str(cfg.checkpoint_dir), - strategy=DeepSpeedStrategy(stage=3), - accumulate_grad_batches=cfg.accumulate_grad_batches, - num_sanity_val_steps=2, - precision=cfg.precision, - max_epochs=cfg.epochs, - num_nodes=cfg.num_nodes, - ) - - output = trainer.test(model) - print(output) - - if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("-c", "--config", required=True) - parser.add_argument("--mode", default="train") - parser.add_argument("--inference_fasta", default="") - parser.add_argument("--inference_model_load", default="pt", help="deepspeed or pt") - parser.add_argument( - "--inference_pt_file", - type=Path, - help="Path to pytorch model weights if inference_model_load==pt", - ) - parser.add_argument( - "--inference_output_path", default="./embeddings.npy", type=Path - ) args = parser.parse_args() config = ModelSettings.from_yaml(args.config) @@ -562,35 +459,4 @@ def test(cfg: ModelSettings) -> None: # new config definition config = throughput_config(config) - if args.mode == "train": - train(config) - elif args.mode == "test": - test(config) - elif args.mode == "inference" and not config.compute_throughput: - if not args.inference_fasta: - raise ValueError("Must provide a fasta file to run inference on.") - - if args.inference_output_path.exists(): - raise FileExistsError( - f"inference_output_path: {args.inference_output_path} already exists!" 
- ) - - if args.inference_model_load == "pt": - model_strategy = LoadPTCheckpointStrategy( - args.inference_pt_file, cfg=config - ) - elif args.inference_model_load == "deepspeed": - if config.load_ds_checkpoint is None: - raise ValueError( - "load_from_checkpoint_dir must be set in the config file" - ) - model_strategy = LoadDeepSpeedStrategy( - config.load_ds_checkpoint, cfg=config - ) - else: - raise ValueError( - f"Invalid inference_model_load {args.inference_model_load}" - ) - inference( - config, model_strategy, args.inference_fasta, args.inference_output_path - ) + train(config) diff --git a/genslm/utils.py b/genslm/utils.py index 37d83103..7d85db2d 100644 --- a/genslm/utils.py +++ b/genslm/utils.py @@ -1,9 +1,9 @@ +import re import time -import uuid from abc import ABC, abstractmethod from pathlib import Path from statistics import mean -from typing import Any, Dict, List, Optional, Set, Type +from typing import Any, Dict, List, Optional, Set, Type, Union import numpy as np import pytorch_lightning as pl @@ -11,6 +11,7 @@ from Bio import SeqIO # type: ignore[import] from Bio.Seq import Seq # type: ignore[import] from Bio.SeqRecord import SeqRecord # type: ignore[import] +from pydantic import BaseModel from pytorch_lightning.callbacks import Callback from pytorch_lightning.utilities.deepspeed import ( convert_zero_checkpoint_to_fp32_state_dict, @@ -20,9 +21,54 @@ from transformers import PreTrainedTokenizerFast # , StoppingCriteriaList from transformers import StoppingCriteria +PathLike = Union[str, Path] + STOP_CODONS = {"TAA", "TAG", "TGA"} +class Sequence(BaseModel): + sequence: str + """Biological sequence (Nucleotide sequence).""" + tag: str + """Sequence description tag.""" + + +def read_fasta(fasta_file: PathLike) -> List[Sequence]: + """Reads fasta file sequences and description tags into dataclass.""" + text = Path(fasta_file).read_text() + pattern = re.compile("^>", re.MULTILINE) + non_parsed_seqs = re.split(pattern, text)[1:] + lines = [ + line.replace("\n", "") for seq in non_parsed_seqs for line in seq.split("\n", 1) + ] + + return [ + Sequence(sequence=seq, tag=tag) for seq, tag in zip(lines[1::2], lines[::2]) + ] + + +def read_fasta_only_seq(fasta_file: PathLike) -> List[str]: + """Reads fasta file sequences without description tag.""" + text = Path(fasta_file).read_text() + pattern = re.compile("^>", re.MULTILINE) + non_parsed_seqs = re.split(pattern, text)[1:] + lines = [ + line.replace("\n", "") for seq in non_parsed_seqs for line in seq.split("\n", 1) + ] + + return lines[1::2] + + +def write_fasta( + sequences: Union[Sequence, List[Sequence]], fasta_file: PathLike, mode: str = "w" +) -> None: + """Write or append sequences to a fasta file.""" + seqs = [sequences] if isinstance(sequences, Sequence) else sequences + with open(fasta_file, mode) as f: + for seq in seqs: + f.write(f">{seq.tag}\n{seq.sequence}\n") + + class FoundStopCodonCriteria(StoppingCriteria): # type: ignore[misc] def __init__(self, tokenizer: PreTrainedTokenizerFast) -> None: self.tokenizer = tokenizer @@ -538,77 +584,3 @@ def on_validation_epoch_end( self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" ) -> None: self._log_perplexity(pl_module, self.val_name, train=False, on_epoch=True) - - -class OutputsCallback(Callback): - def __init__( - self, - compute_mean: bool = True, - save_dir: Path = Path("./outputs"), - output_attentions=False, - output_logits=False, - ) -> None: - self.compute_mean = compute_mean - self.output_attentions = output_attentions - self.output_logits = 
output_logits - self.save_dir = save_dir - self.embeddings, self.attentions, self.logits, self.indices = [], [], [], [] - save_dir.mkdir(exist_ok=True) - - def _gather_data(self) -> None: - if self.output_attentions: - self.attentions = torch.stack(self.attentions).numpy() - print(self.attentions.shape) - elif self.output_logits: - self.logits = torch.cat(self.logits).numpy() - else: - self.embeddings = torch.cat(self.embeddings).numpy() - self.indices = torch.cat(self.indices).numpy().squeeze() - - def _save_embeddings(self) -> None: - rank_label = uuid.uuid4() - if self.output_attentions: - np.save(self.save_dir / f"attentions-{rank_label}.npy", self.attentions) - elif self.output_logits: - np.save(self.save_dir / f"logits-{rank_label}.npy", self.logits) - else: - np.save(self.save_dir / f"embeddings-{rank_label}.npy", self.embeddings) - np.save(self.save_dir / f"indices-{rank_label}.npy", self.indices) - - def on_predict_start( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" - ) -> None: - self.embeddings, self.attentions, self.indices = [], [], [] - - def on_predict_batch_end( - self, - trainer: "pl.Trainer", - pl_module: "pl.LightningModule", - outputs: Any, - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - # outputs.hidden_states: (batch_size, sequence_length, hidden_size) - if self.output_attentions: - attend = torch.sum(outputs.attentions[0].detach().cpu().squeeze(), dim=0) - self.attentions.append(attend) - elif self.output_logits: - logits = outputs.logits.detach().cpu() - self.logits.append(logits) - else: - if self.compute_mean: - # Compute average over sequence length - embed = outputs.hidden_states[0].detach().mean(dim=1).cpu() - else: - embed = outputs.hidden_states[0].detach().cpu() - - self.embeddings.append(embed) - - self.indices.append(batch["indices"].detach().cpu()) - - def on_predict_end( - self, trainer: "pl.Trainer", pl_module: "pl.LightningModule" - ) -> None: - self._gather_data() - self._save_embeddings() diff --git a/tests/test_dataset.py b/tests/test_dataset.py index cae88973..13b5d514 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -1,10 +1,12 @@ import itertools + import numpy as np from tokenizers import Tokenizer from torch.utils.data import DataLoader -from genslm import GenSLM, SequenceDataset from transformers import PreTrainedTokenizerFast +from genslm import GenSLM, SequenceDataset + def generate_random_sequence(min_length: int = 10, max_length: int = 2020) -> str: """Generate a sequence with random codons for testing.""" @@ -37,9 +39,11 @@ def test_dataset_length(): # Initialize dataset seq_length = 2048 dataset = SequenceDataset(sequences, seq_length, tokenizer, verbose=False) - dataloader = DataLoader(dataset) + dataloader = DataLoader(dataset, batch_size=1, shuffle=False) # Basic sanity check assert len(dataset) == num_seqs - for _ in dataloader: - break + for batch, seq in zip(dataloader, sequences): + batch_seq_len = batch["attention_mask"].sum().item() + # If exactly equal, no unknown tokens were added + assert batch_seq_len == len(seq) // 3
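The updated test above relies on the codon grouping performed before tokenization: each nucleotide sequence is split into non-overlapping 3-mers, so a sequence with no unknown tokens yields exactly len(seq) // 3 non-padding tokens. A minimal illustration of the grouping helper (mirroring InferenceSequenceDataset.group_by_kmer in run_inference.py; the example sequence is made up):

def group_by_kmer(seq: str, kmer: int = 3) -> str:
    # Split into non-overlapping k-mers (codons for kmer=3) and upper-case,
    # matching InferenceSequenceDataset.group_by_kmer above.
    return " ".join(seq[i : i + kmer] for i in range(0, len(seq), kmer)).upper()


assert group_by_kmer("atgaaatag") == "ATG AAA TAG"  # 9 nt -> 3 codon tokens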
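The new read_fasta, read_fasta_only_seq, and write_fasta helpers in genslm/utils.py replace the BioPython-based paths removed elsewhere in this diff. A small round-trip sketch, assuming genslm is importable and using a throwaway file with made-up sequences:

from pathlib import Path

from genslm.utils import Sequence, read_fasta, read_fasta_only_seq, write_fasta

fasta = Path("example.fasta")  # hypothetical scratch file
fasta.write_text(">seq-0 demo\nATGAAATAA\n>seq-1 demo\nATGCCCTGA\n")

records = read_fasta(fasta)
assert records[0].tag == "seq-0 demo"
assert records[0].sequence == "ATGAAATAA"

# read_fasta_only_seq returns the sequences without the description tags.
assert read_fasta_only_seq(fasta) == ["ATGAAATAA", "ATGCCCTGA"]

# write_fasta appends a record using the same ">tag\nsequence" layout.
write_fasta(Sequence(sequence="ATGTTTTAA", tag="seq-2 demo"), fasta, mode="a")
assert len(read_fasta(fasta)) == 3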
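run_inference.py is driven by a YAML file parsed into InferenceConfig. The sketch below writes such a file; the field names mirror InferenceConfig above, the paths are placeholders, and the module-style launch commands in the comments are an assumption about how the scripts are invoked:

from pathlib import Path

config_yaml = """\
model_id: genslm_25M_patric
model_cache_dir: /path/to/model_weights
data_file: /path/to/sequences.fasta
output_path: /path/to/inference_outputs
output_embeddings: true
output_logits: false
output_attentions: false
batch_size: 32
num_data_workers: 4
"""
Path("inference_config.yaml").write_text(config_yaml)

# Hypothetical launch (one process per GPU under the DDP strategy):
#   python -m genslm.cmdline.run_inference -c inference_config.yaml
# followed by merging the per-rank H5 files:
#   python -m genslm.cmdline.gather_inference_h5 -i /path/to/inference_outputs \
#       -o /path/to/gathered --embeddings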
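The gathered embeddings file produced by gather_inference_h5.py contains only ExternalLinks, so it stays valid as long as the per-rank files remain at the absolute paths recorded when the links were created. A minimal sketch of reading it back, with a placeholder file name and fasta index:

from pathlib import Path

import h5py
import numpy as np

gathered = Path("/path/to/gathered/embeddings-gathered-layer-0.h5")  # placeholder

with h5py.File(gathered, "r") as f:
    # Keys under "embeddings" are fasta indices; each dataset is an ExternalLink
    # into a per-rank file and is resolved transparently by h5py on access.
    fasta_indices = sorted(f["embeddings"].keys(), key=int)
    emb = np.asarray(f["embeddings"][fasta_indices[0]])
    print(emb.shape)  # (seq_len, hidden_size); padding positions were not saved

    # "na-hashes" maps each fasta index to the md5 of its raw nucleotide
    # sequence, useful for checking alignment with the original fasta file.
    md5 = f["na-hashes"][fasta_indices[0]][()]
    print(md5)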