Skip to content

Commit

Permalink
Merge pull request #423 from OmicsML/celltype_annotation_automl
Browse files Browse the repository at this point in the history
Celltype annotation automl
  • Loading branch information
JiayuanDing100 authored Jun 17, 2024
2 parents c69d017 + d4f410f commit 3620ce1
Show file tree
Hide file tree
Showing 106 changed files with 5,815 additions and 805 deletions.
10 changes: 7 additions & 3 deletions dance/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@
import anndata
import mudata
import numpy as np
import omegaconf
import pandas as pd
import scipy.sparse as sp
import torch

from dance import logger
from dance.typing import Any, Dict, FeatType, Iterator, List, Literal, Optional, Sequence, Tuple, Union
from dance.typing import Any, Dict, FeatType, Iterator, List, ListConfig, Literal, Optional, Sequence, Tuple, Union


def _ensure_iter(val: Optional[Union[List[str], str]]) -> Iterator[Optional[str]]:
Expand All @@ -34,7 +35,7 @@ def _check_types_and_sizes(types, sizes):
raise TypeError(f"Found mixed types: {types}. Input configs must be either all str or all lists.")
elif ((type_ := types.pop()) == list) and (len(sizes) > 1):
raise ValueError(f"Found mixed sizes lists: {sizes}. Input configs must be of same length.")
elif type_ not in (list, str):
elif type_ not in (list, str, ListConfig):
raise TypeError(f"Unknownn type {type_} found in config.")


Expand Down Expand Up @@ -240,7 +241,7 @@ def set_config_from_dict(self, config_dict: Dict[str, Any], *, overwrite: bool =
label_configs = [j for i, j in config_dict.items() if i in self._LABEL_CONFIGS and j is not None]

# Check type and length consistencies for feature and label configs
for i in (feature_configs, label_configs):
for i in [feature_configs, label_configs]:
types = set(map(type, i))
sizes = set(map(len, i))
_check_types_and_sizes(types, sizes)
Expand All @@ -249,6 +250,9 @@ def set_config_from_dict(self, config_dict: Dict[str, Any], *, overwrite: bool =
for config_key, config_val in config_dict.items():
# New config
if config_key not in self.config:
if isinstance(config_val, ListConfig):
config_val = omegaconf.OmegaConf.to_object(config_val)
logger.warning(f"transform ListConfig {config_val} to List")
self.config[config_key] = config_val
logger.info(f"Setting config {config_key!r} to {config_val!r}")
continue
Expand Down
190 changes: 129 additions & 61 deletions dance/datasets/singlemodality.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import pandas as pd
import scanpy as sc
from scipy.sparse import csr_matrix
from sklearn.model_selection import train_test_split

from dance import logger
from dance.data import Data
Expand Down Expand Up @@ -52,19 +53,42 @@ class CellTypeAnnotationDataset(BaseDataset):

def __init__(self, full_download=False, train_dataset=None, test_dataset=None, species=None, tissue=None,
valid_dataset=None, train_dir="train", test_dir="test", valid_dir="valid", map_path="map",
data_dir="./"):
data_dir="./", train_as_valid=False, val_size=0.2):
super().__init__(data_dir, full_download)

self.data_dir = data_dir
self.train_dataset = train_dataset
self.test_dataset = test_dataset
self.valid_dataset = train_dataset if valid_dataset is None else valid_dataset
self.species = species
self.tissue = tissue
self.train_dir = train_dir
self.test_dir = test_dir
self.valid_dir = valid_dir
self.map_path = map_path
self.train_as_valid = train_as_valid
self.bench_url_dict = self.BENCH_URL_DICT.copy()
self.available_data = self.AVAILABLE_DATA.copy()
self.valid_dataset = valid_dataset
if valid_dataset is None and self.train_as_valid:
self.valid_dataset = train_dataset
self.train2valid()
self.val_size = val_size

def train2valid(self):
    """Mirror every training entry as a validation entry.

    Each item of ``self.available_data`` whose ``"split"`` is ``"train"`` is
    duplicated with ``"split"`` set to ``"valid"`` and appended after the
    existing items. Each key of ``self.bench_url_dict`` that starts with
    ``"train"`` gets a sibling key, with the first ``"train"`` replaced by
    ``"valid"``, that points at the same value.

    Both attributes are rebound to the extended copies; the objects they
    previously referenced are not mutated.

    """
    logger.info("Copy train_dataset and use it as valid_dataset")

    # Work on copies so the containers currently bound to the instance stay
    # untouched until both extended versions have been built.
    extended_data = self.available_data.copy()
    for entry in self.available_data:
        if entry["split"] != "train":
            continue
        valid_entry = entry.copy()
        valid_entry['split'] = 'valid'
        extended_data.append(valid_entry)

    extended_urls = self.bench_url_dict.copy()
    # Only the leading "train" is rewritten (count=1); the rest of the key
    # is kept as is.
    extended_urls.update({
        name.replace("train", "valid", 1): url
        for name, url in self.bench_url_dict.items()
        if name.startswith("train")
    })

    self.available_data = extended_data
    self.bench_url_dict = extended_urls

def download_all(self):
if self.is_complete():
Expand All @@ -87,7 +111,8 @@ def download_all(self):

def get_all_filenames(self, filetype: str = "csv", feat_suffix: str = "data", label_suffix: str = "celltype"):
filenames = []
for id in self.train_dataset + self.test_dataset + self.valid_dataset:
for id in self.train_dataset + self.test_dataset + (self.valid_dataset
if self.valid_dataset is not None else []):
filenames.append(f"{self.species}_{self.tissue}{id}_{feat_suffix}.{filetype}")
filenames.append(f"{self.species}_{self.tissue}{id}_{label_suffix}.{filetype}")
return filenames
Expand All @@ -98,7 +123,7 @@ def download(self, download_map=True):

filenames = self.get_all_filenames()
# Download training and testing data
for name, url in self.BENCH_URL_DICT.items():
for name, url in self.bench_url_dict.items():
parts = name.split("_") # [train|test]_{species}_{tissue}{id}_[celltype|data].csv
filename = "_".join(parts[1:])
if filename in filenames:
Expand All @@ -115,7 +140,6 @@ def is_complete_all(self):
check = [
osp.join(self.data_dir, "train"),
osp.join(self.data_dir, "test"),
osp.join(self.data_dir, "valid"),
osp.join(self.data_dir, "pretrained")
]
for i in check:
Expand All @@ -126,7 +150,7 @@ def is_complete_all(self):

def is_complete(self):
"""Check if benchmarking data is complete."""
for name in self.BENCH_URL_DICT:
for name in self.bench_url_dict:
if any(i not in name for i in (self.species, self.tissue)):
continue
filename = name[name.find(self.species):]
Expand All @@ -150,58 +174,101 @@ def is_complete(self):
def _load_raw_data(self, ct_col: str = "Cell_type") -> Tuple[ad.AnnData, List[Set[str]], List[str], int]:
species = self.species
tissue = self.tissue
train_dataset_ids = self.train_dataset
test_dataset_ids = self.test_dataset
valid_dataset_ids = self.valid_dataset
data_dir = self.data_dir
train_dir = osp.join(data_dir, self.train_dir)
test_dir = osp.join(data_dir, self.test_dir)
valid_dir = osp.join(data_dir, self.valid_dir)
map_path = osp.join(data_dir, self.map_path, self.species)

# Load raw data
train_feat_paths, train_label_paths = self._get_data_paths(train_dir, species, tissue, train_dataset_ids)
valid_feat_paths, valid_label_paths = self._get_data_paths(valid_dir, species, tissue, valid_dataset_ids)
test_feat_paths, test_label_paths = self._get_data_paths(test_dir, species, tissue, test_dataset_ids)
train_feat, valid_feat, test_feat = (self._load_dfs(paths, transpose=True)
for paths in (train_feat_paths, valid_feat_paths, test_feat_paths))
train_label, valid_label, test_label = (self._load_dfs(paths)
for paths in (train_label_paths, valid_label_paths, test_label_paths))

# Combine features (only use features that are present in the training data)
train_size = train_feat.shape[0]
valid_size = valid_feat.shape[0]
feat_df = pd.concat(
train_feat.align(valid_feat, axis=1, join="left", fill_value=0) +
train_feat.align(test_feat, axis=1, join="left", fill_value=0)[1:]).fillna(0)
adata = ad.AnnData(feat_df, dtype=np.float32)

# Convert cell type labels and map test cell type names to train
cell_types = set(train_label[ct_col].unique())
idx_to_label = sorted(cell_types)
cell_type_mappings: Dict[str, Set[str]] = self.get_map_dict(map_path, tissue)
train_labels, valid_labels, test_labels = train_label[ct_col].tolist(), [], []
for i in valid_label[ct_col]:
valid_labels.append(i if i in cell_types else cell_type_mappings.get(i))
for i in test_label[ct_col]:
test_labels.append(i if i in cell_types else cell_type_mappings.get(i))
labels: List[Set[str]] = train_labels + valid_labels + test_labels

logger.debug("Mapped valid cell-types:")
for i, j, k in zip(valid_label.index, valid_label[ct_col], valid_labels):
logger.debug(f"{i}:{j}\t-> {k}")

logger.debug("Mapped test cell-types:")
for i, j, k in zip(test_label.index, test_label[ct_col], test_labels):
logger.debug(f"{i}:{j}\t-> {k}")

logger.info(f"Loaded expression data: {adata}")
logger.info(f"Number of training samples: {train_feat.shape[0]:,}")
logger.info(f"Number of valid samples: {valid_feat.shape[0]:,}")
logger.info(f"Number of testing samples: {test_feat.shape[0]:,}")
logger.info(f"Cell-types (n={len(idx_to_label)}):\n{pprint.pformat(idx_to_label)}")

return adata, labels, idx_to_label, train_size, valid_size
valid_feat = None
if self.valid_dataset is not None:
train_dataset_ids = self.train_dataset
test_dataset_ids = self.test_dataset
valid_dataset_ids = self.valid_dataset
data_dir = self.data_dir
train_dir = osp.join(data_dir, self.train_dir)
test_dir = osp.join(data_dir, self.test_dir)
valid_dir = osp.join(data_dir, self.valid_dir)
map_path = osp.join(data_dir, self.map_path, self.species)

# Load raw data
train_feat_paths, train_label_paths = self._get_data_paths(train_dir, species, tissue, train_dataset_ids)
valid_feat_paths, valid_label_paths = self._get_data_paths(valid_dir, species, tissue, valid_dataset_ids)
test_feat_paths, test_label_paths = self._get_data_paths(test_dir, species, tissue, test_dataset_ids)
train_feat, valid_feat, test_feat = (self._load_dfs(paths, transpose=True)
for paths in (train_feat_paths, valid_feat_paths, test_feat_paths))
train_label, valid_label, test_label = (self._load_dfs(paths)
for paths in (train_label_paths, valid_label_paths,
test_label_paths))
else:
train_dataset_ids = self.train_dataset
test_dataset_ids = self.test_dataset
data_dir = self.data_dir
train_dir = osp.join(data_dir, self.train_dir)
test_dir = osp.join(data_dir, self.test_dir)
map_path = osp.join(data_dir, self.map_path, self.species)
train_feat_paths, train_label_paths = self._get_data_paths(train_dir, species, tissue, train_dataset_ids)
test_feat_paths, test_label_paths = self._get_data_paths(test_dir, species, tissue, test_dataset_ids)
train_feat, test_feat = (self._load_dfs(paths, transpose=True)
for paths in (train_feat_paths, test_feat_paths))
train_label, test_label = (self._load_dfs(paths) for paths in (train_label_paths, test_label_paths))
if self.val_size > 0:
train_feat, valid_feat, train_label, valid_label = train_test_split(train_feat, train_label,
test_size=self.val_size)
if valid_feat is not None:
# Combine features (only use features that are present in the training data)
train_size = train_feat.shape[0]
valid_size = valid_feat.shape[0]
feat_df = pd.concat(
train_feat.align(valid_feat, axis=1, join="left", fill_value=0) +
train_feat.align(test_feat, axis=1, join="left", fill_value=0)[1:]).fillna(0)
adata = ad.AnnData(feat_df, dtype=np.float32)

# Convert cell type labels and map test cell type names to train
cell_types = set(train_label[ct_col].unique())
idx_to_label = sorted(cell_types)
cell_type_mappings: Dict[str, Set[str]] = self.get_map_dict(map_path, tissue)
train_labels, valid_labels, test_labels = train_label[ct_col].tolist(), [], []
for i in valid_label[ct_col]:
valid_labels.append(i if i in cell_types else cell_type_mappings.get(i))
for i in test_label[ct_col]:
test_labels.append(i if i in cell_types else cell_type_mappings.get(i))
labels: List[Set[str]] = train_labels + valid_labels + test_labels

logger.debug("Mapped valid cell-types:")
for i, j, k in zip(valid_label.index, valid_label[ct_col], valid_labels):
logger.debug(f"{i}:{j}\t-> {k}")

logger.debug("Mapped test cell-types:")
for i, j, k in zip(test_label.index, test_label[ct_col], test_labels):
logger.debug(f"{i}:{j}\t-> {k}")

logger.info(f"Loaded expression data: {adata}")
logger.info(f"Number of training samples: {train_feat.shape[0]:,}")
logger.info(f"Number of valid samples: {valid_feat.shape[0]:,}")
logger.info(f"Number of testing samples: {test_feat.shape[0]:,}")
logger.info(f"Cell-types (n={len(idx_to_label)}):\n{pprint.pformat(idx_to_label)}")

return adata, labels, idx_to_label, train_size, valid_size
else:
# Combine features (only use features that are present in the training data)
train_size = train_feat.shape[0]
feat_df = pd.concat(train_feat.align(test_feat, axis=1, join="left", fill_value=0)).fillna(0)
adata = ad.AnnData(feat_df, dtype=np.float32)

# Convert cell type labels and map test cell type names to train
cell_types = set(train_label[ct_col].unique())
idx_to_label = sorted(cell_types)
cell_type_mappings: Dict[str, Set[str]] = self.get_map_dict(map_path, tissue)
train_labels, test_labels = train_label[ct_col].tolist(), []
for i in test_label[ct_col]:
test_labels.append(i if i in cell_types else cell_type_mappings.get(i))
labels: List[Set[str]] = train_labels + test_labels

logger.debug("Mapped test cell-types:")
for i, j, k in zip(test_label.index, test_label[ct_col], test_labels):
logger.debug(f"{i}:{j}\t-> {k}")

logger.info(f"Loaded expression data: {adata}")
logger.info(f"Number of training samples: {train_feat.shape[0]:,}")
logger.info(f"Number of testing samples: {test_feat.shape[0]:,}")
logger.info(f"Cell-types (n={len(idx_to_label)}):\n{pprint.pformat(idx_to_label)}")

return adata, labels, idx_to_label, train_size, 0

def _raw_to_dance(self, raw_data):
adata, cell_labels, idx_to_label, train_size, valid_size = raw_data
Expand Down Expand Up @@ -290,9 +357,10 @@ def is_complete(self):
return osp.exists(self.data_path)

def _load_raw_data(self) -> Tuple[ad.AnnData, np.ndarray]:
with h5py.File(self.data_path, "r") as f:
x = np.array(f["X"])
y = np.array(f["Y"])
with open(self.data_path, "rb") as f_o:
with h5py.File(f_o, "r") as f:
x = np.array(f["X"])
y = np.array(f["Y"])
adata = ad.AnnData(x, dtype=np.float32)
return adata, y

Expand Down
11 changes: 8 additions & 3 deletions dance/datasets/spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class SpatialLIBDDataset(BaseDataset):

_DISPLAY_ATTRS = ("data_id", )
URL_DICT = {
"151510": "https://www.dropbox.com/sh/41h9brsk6my546x/AADa18mkJge-KQRTndRelTpMa?dl=0",
"151510": "https://www.dropbox.com/sh/41h9brsk6my546x/AADa18mkJge-KQRTndRelTpMa?dl=1",
"151507": "https://www.dropbox.com/sh/m3554vfrdzbwv2c/AACGsFNVKx8rjBgvF7Pcm2L7a?dl=1",
"151508": "https://www.dropbox.com/sh/tm47u3fre8692zt/AAAJJf8-za_Lpw614ft096qqa?dl=1",
"151509": "https://www.dropbox.com/sh/hihr7906vyirjet/AACslV5mKIkF2CF5QqE1LE6ya?dl=1",
Expand All @@ -47,11 +47,12 @@ class SpatialLIBDDataset(BaseDataset):
}
AVAILABLE_DATA = sorted(URL_DICT)

def __init__(self, root=".", full_download=False, data_id="151673", data_dir="data/spatial"):
def __init__(self, root=".", full_download=False, data_id="151673", data_dir="data/spatial", sample_file=None):
super().__init__(root, full_download)

self.data_id = data_id
self.data_dir = data_dir + "/{}".format(data_id)
self.sample_file = sample_file

def download_all(self):
logger.info(f"All data includes {len(self.URL_DICT)} datasets: {list(self.URL_DICT)}")
Expand Down Expand Up @@ -147,7 +148,11 @@ def _raw_to_dance(self, raw_data):
adata.obsm["spatial"] = xy.set_index(adata.obs_names)
adata.obsm["spatial_pixel"] = xy_pixel.set_index(adata.obs_names)
adata.uns["image"] = img

if self.sample_file is not None:
sample_file = osp.join(self.data_dir, self.sample_file)
with open(sample_file) as file:
sample_index = [int(line.strip()) for line in file]
adata = adata[sample_index]
data = Data(adata, train_size="all")
return data

Expand Down
8 changes: 4 additions & 4 deletions dance/metadata/clustering.csv
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
10X_PBMC,https://www.dropbox.com/s/pfunm27qzgfpj3u/10X_PBMC.h5?dl=1
mouse_lung_cell,https://dl.dropboxusercontent.com/scl/fi/6h4ewvj1n64mrppz56v7s/mouse_lung_cell.h5?rlkey=6snhzzkv6f7vmshkvne9leimu&dl=0
human_pbmc2_cell,https://dl.dropboxusercontent.com/scl/fi/c69gv5btxfvpcmkcqj4zx/human_pbmc_cell.h5?rlkey=jzi6u9qs2qf4nixr6a0n48mc0&dl=0
human_pbmc_cell,https://dl.dropboxusercontent.com/scl/fi/2by36reg6wjq6hxlytljx/human_ILCS_cell.h5?rlkey=mu4pz7quxspf9qgzx5wc4aaet&dl=0
human_ILCS_cell,https://dl.dropboxusercontent.com/scl/fi/2by36reg6wjq6hxlytljx/human_ILCS_cell.h5?rlkey=mu4pz7quxspf9qgzx5wc4aaet&dl=0
human_skin_cell,https://dl.dropboxusercontent.com/scl/fi/5gd3kcz307r42s7u3di3q/human_skin_cell.h5?rlkey=2hat0jeze2cn2uqnu4p7g7yhw&dl=0
mouse_ES_cell,https://www.dropbox.com/s/zbuku7oznvji8jk/mouse_ES_cell.h5?dl=1
mouse_bladder_cell,https://www.dropbox.com/s/xxtnomx5zrifdwi/mouse_bladder_cell.h5?dl=1
mouse_kidney_10x,https://dl.dropboxusercontent.com/scl/fi/b9b4dr82hcdwxykv8e53f/mouse_kidney_10x.h5?rlkey=aniqqz731klpmekl82db7k2pu&dl=
mouse_kidney_cell,https://dl.dropboxusercontent.com/scl/fi/b9b4dr82hcdwxykv8e53f/mouse_kidney_10x.h5?rlkey=aniqqz731klpmekl82db7k2pu&dl=0
mouse_kidney_cl2,https://dl.dropboxusercontent.com/scl/fi/d0uh8qqw4q4f0748yq5db/mouse_kidney_drop.h5?rlkey=3onfglh6sv6q91c5e1ns5lc5h&dl=0
mouse_kidney_10x,https://dl.dropboxusercontent.com/scl/fi/b9b4dr82hcdwxykv8e53f/mouse_kidney_10x.h5?rlkey=aniqqz731klpmekl82db7k2pu&dl=1
mouse_kidney_cell,https://www.dropbox.com/scl/fi/qrkyu9qhfcj43smlfqygq/mouse_kidney_cell.h5?rlkey=a0uyhgxfty4iti0k83xx9gtsc&dl=1
mouse_kidney_cl2,https://www.dropbox.com/scl/fi/g60cr1t6dqvtv5zei4h3m/mouse_kidney_cl2.h5?rlkey=gth7bakq4tugztiv1r1akgy8l&dl=1
mouse_kidney_drop,https://dl.dropboxusercontent.com/scl/fi/d0uh8qqw4q4f0748yq5db/mouse_kidney_drop.h5?rlkey=3onfglh6sv6q91c5e1ns5lc5h&dl=0
worm_neuron_cell,https://www.dropbox.com/s/58fkgemi2gcnp2k/worm_neuron_cell.h5?dl=1
Loading

0 comments on commit 3620ce1

Please sign in to comment.