-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #144 from computational-cell-analytics/generalist-experiments
Implement generalist model evaluation experiments
- Loading branch information
Showing
23 changed files
with
759 additions
and
399 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,124 @@ | ||
import argparse | ||
import os | ||
from glob import glob | ||
from subprocess import run | ||
|
||
import imageio.v3 as imageio | ||
|
||
from tqdm import tqdm | ||
|
||
# Root folder that contains the test images and labels for all datasets.
DATA_ROOT = "/scratch/projects/nim00007/sam/datasets"
# Root folder where the cellpose predictions and evaluation results are stored.
EXP_ROOT = "/scratch/projects/nim00007/sam/experiments/cellpose"

# The datasets on which the cellpose baseline is run and evaluated.
DATASETS = (
    "covid-if",
    "deepbacs",
    "hpa",
    "livecell",
    "lizard",
    "mouse-embryo",
    "plantseg-ovules",
    "plantseg-root",
    "tissuenet",
)
|
||
|
||
def load_cellpose_model():
    """Create the default cellpose 'cyto' model on the best available device."""
    from cellpose import models

    compute_device, use_gpu = models.assign_device(True, True)
    return models.Cellpose(gpu=use_gpu, model_type="cyto", device=compute_device)
|
||
|
||
def run_cellpose_segmentation(datasets, job_id):
    """Segment the test images of the dataset selected by ``job_id`` with cellpose.

    Segmentations are written as compressed tif files to
    ``<EXP_ROOT>/<dataset>/predictions``; existing predictions are not recomputed.
    """
    dataset = datasets[job_id]
    prediction_folder = os.path.join(EXP_ROOT, dataset, "predictions")
    os.makedirs(prediction_folder, exist_ok=True)

    input_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, "test", "image*.tif")))
    model = load_cellpose_model()

    for input_path in tqdm(input_paths, desc=f"Segmenting {dataset} with cellpose"):
        output_path = os.path.join(prediction_folder, os.path.basename(input_path))
        if os.path.exists(output_path):  # this image was already segmented
            continue

        image = imageio.imread(input_path)
        # Average the channels so that cellpose receives a single-channel image.
        if image.ndim == 3:
            assert image.shape[-1] == 3
            image = image.mean(axis=-1)
        assert image.ndim == 2

        segmentation = model.eval(image, diameter=None, flow_threshold=None, channels=[0, 0])[0]
        assert segmentation.shape == image.shape
        imageio.imwrite(output_path, segmentation, compression=5)
|
||
|
||
def submit_array_job(datasets):
    """Submit a slurm array job with one task per dataset."""
    last_task = len(datasets) - 1
    run(["sbatch", "-a", f"0-{last_task}", "cellpose_baseline.sbatch"])
|
||
|
||
def evaluate_dataset(dataset):
    """Evaluate the cellpose predictions for one dataset against its labels.

    The scores are written to ``<EXP_ROOT>/<dataset>/cellpose.csv``.
    """
    from micro_sam.evaluation.evaluation import run_evaluation

    exp_folder = os.path.join(EXP_ROOT, dataset)
    label_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, "test", "label*.tif")))
    prediction_paths = sorted(glob(os.path.join(exp_folder, "predictions", "*.tif")))
    assert len(label_paths) == len(prediction_paths), f"{len(label_paths)}, {len(prediction_paths)}"

    result_path = os.path.join(exp_folder, "cellpose.csv")
    run_evaluation(label_paths, prediction_paths, result_path)
|
||
|
||
def evaluate_segmentations(datasets):
    """Evaluate the cellpose predictions for all given datasets."""
    # livecell is excluded because it has already been processed by cellpose.
    for ds in (d for d in datasets if d != "livecell"):
        evaluate_dataset(ds)
|
||
|
||
def check_results(datasets):
    """Check that the cellpose result file exists for all given datasets.

    Prints every dataset with a missing result file and reports success
    only when no result is missing.
    """
    missing = []
    for ds in datasets:
        # we skip livecell, which has already been processed by cellpose
        if ds == "livecell":
            continue
        result_path = os.path.join(EXP_ROOT, ds, "cellpose.csv")
        if not os.path.exists(result_path):
            print("Cellpose results missing for", ds)
            missing.append(ds)
    # Bug fix: previously "All checks passed" was printed unconditionally,
    # even when result files were missing.
    if not missing:
        print("All checks passed")
|
||
|
||
def main():
    """CLI entry point: dispatch between segmentation, evaluation and checks."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--segment", "-s", action="store_true")
    parser.add_argument("--evaluate", "-e", action="store_true")
    parser.add_argument("--check", "-c", action="store_true")
    parser.add_argument("--datasets", nargs="+")
    args = parser.parse_args()

    # Inside a slurm array job this id selects the dataset to segment.
    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)

    if args.datasets is None:
        selected = DATASETS
    else:
        selected = args.datasets
        assert all(ds in DATASETS for ds in selected)

    if job_id is not None:
        run_cellpose_segmentation(selected, int(job_id))
    elif args.segment:
        submit_array_job(selected)
    elif args.evaluate:
        evaluate_segmentations(selected)
    elif args.check:
        check_results(selected)
    else:
        raise ValueError("Doing nothing")


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
#! /bin/bash
#SBATCH -c 4
#SBATCH --mem 48G
#SBATCH -t 300
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007

source activate cellpose
# Quote "$@" so arguments containing spaces are forwarded to the script intact.
python cellpose_baseline.py "$@"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
import os | ||
from glob import glob | ||
|
||
import pandas as pd | ||
|
||
from evaluate_generalist import EXPERIMENT_ROOT | ||
from util import EM_DATASETS, LM_DATASETS | ||
|
||
|
||
def get_results(model, ds):
    """Load all per-prompt result csv files for one model / dataset combination.

    Collects the box, point and (if present) amg results and returns them as a
    single dataframe with additional "prompt" and "dataset" columns.
    """
    res_folder = os.path.join(EXPERIMENT_ROOT, model, ds, "results")
    res_paths = sorted(glob(os.path.join(res_folder, "box", "*.csv"))) +\
        sorted(glob(os.path.join(res_folder, "points", "*.csv")))

    # The amg result is optional.
    amg_res = os.path.join(res_folder, "amg.csv")
    if os.path.exists(amg_res):
        res_paths.append(amg_res)

    results = []
    for path in res_paths:
        prompt_res = pd.read_csv(path)
        # The prompt name is the csv path relative to the result folder,
        # e.g. "box/p1-n0" or "amg".
        prompt_name = os.path.splitext(os.path.relpath(path, res_folder))[0]
        # Fix: repeat the prompt name per row so that this also works for
        # result files with more than one row (the original inserted a
        # length-1 list, which raises for multi-row csvs).
        prompt_res.insert(0, "prompt", [prompt_name] * len(prompt_res))
        results.append(prompt_res)
    results = pd.concat(results)
    results.insert(0, "dataset", results.shape[0] * [ds])

    return results
|
||
|
||
def compile_results(models, datasets, out_path, load_results=False):
    """Collect the results for all model / dataset combinations into one csv.

    If ``load_results`` is set, the csv at ``out_path`` must already exist
    and the newly collected results are appended to it.
    """
    all_frames = []
    for model in models:
        per_dataset = [get_results(model, ds) for ds in datasets]
        model_frame = pd.concat(per_dataset)
        model_frame.insert(0, "model", [model] * model_frame.shape[0])
        all_frames.append(model_frame)

    combined = pd.concat(all_frames)
    if load_results:
        assert os.path.exists(out_path)
        combined = pd.concat([pd.read_csv(out_path), combined])

    combined.to_csv(out_path, index=False)
|
||
|
||
def compile_em():
    """Compile the evaluation results for the EM generalist models."""
    out_path = os.path.join(EXPERIMENT_ROOT, "evaluation-em.csv")
    em_models = ["vit_h", "vit_h_em", "vit_b", "vit_b_em"]
    compile_results(em_models, EM_DATASETS, out_path)
|
||
|
||
def add_cellpose_results(datasets, out_path):
    """Append the cellpose baseline results for the given datasets to ``out_path``."""
    cp_root = "/scratch/projects/nim00007/sam/experiments/cellpose"

    frames = []
    # livecell is skipped because it has already been processed by cellpose.
    for dataset in datasets:
        if dataset == "livecell":
            continue
        ds_res = pd.read_csv(os.path.join(cp_root, dataset, "cellpose.csv"))
        n_rows = ds_res.shape[0]
        ds_res.insert(0, "prompt", ["cellpose"] * n_rows)
        ds_res.insert(0, "dataset", [dataset] * n_rows)
        frames.append(ds_res)

    cp_results = pd.concat(frames)
    cp_results.insert(0, "model", ["cellpose"] * cp_results.shape[0])

    combined = pd.concat([pd.read_csv(out_path), cp_results])
    combined.to_csv(out_path, index=False)
|
||
|
||
def compile_lm():
    """Compile the evaluation results for the LM generalist models.

    Also appends the tissuenet / deepbacs specialist results and the
    cellpose baseline results.
    """
    res_path = os.path.join(EXPERIMENT_ROOT, "evaluation-lm.csv")
    lm_models = ["vit_h", "vit_h_lm", "vit_b", "vit_b_lm"]
    compile_results(lm_models, LM_DATASETS, res_path)

    # add the deepbacs and tissuenet specialist results
    assert os.path.exists(res_path)
    compile_results(["vit_h_tissuenet", "vit_b_tissuenet"], ["tissuenet"], res_path, True)
    compile_results(["vit_h_deepbacs", "vit_b_deepbacs"], ["deepbacs"], res_path, True)

    # add the cellpose results
    add_cellpose_results(LM_DATASETS, res_path)
|
||
|
||
def main():
    """Compile the generalist evaluation results."""
    # compile_em()
    compile_lm()


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
|
||
import os | ||
from tqdm import tqdm | ||
import imageio.v2 as imageio | ||
import numpy as np | ||
|
||
from torch_em.data import MinInstanceSampler | ||
from torch_em.transform.label import label_consecutive | ||
from torch_em.data.datasets import get_tissuenet_loader | ||
from torch_em.transform.raw import standardize, normalize_percentile | ||
|
||
|
||
def rgb_to_gray_transform(raw):
    """Convert a channel-first RGB image into a standardized gray-scale image."""
    normalized = normalize_percentile(raw, axis=(1, 2))
    gray = np.mean(normalized, axis=0)
    return standardize(gray)
|
||
|
||
def get_tissuenet_loaders(input_path, patch_shape=(256, 256), num_workers=0):
    """Create the tissuenet validation and test loaders.

    Args:
        input_path: Folder with the tissuenet data.
        patch_shape: The patch shape sampled from the images.
        num_workers: The number of workers for the data loaders.

    Returns:
        The validation loader and the test loader.
    """
    # Both loaders share the same configuration; only the split differs.
    loader_kwargs = dict(
        path=input_path, raw_channel="rgb", label_channel="cell",
        batch_size=1, patch_shape=patch_shape, num_workers=num_workers,
        sampler=MinInstanceSampler(),
        label_transform=label_consecutive,
        raw_transform=rgb_to_gray_transform,
    )
    val_loader = get_tissuenet_loader(split="val", **loader_kwargs)
    test_loader = get_tissuenet_loader(split="test", **loader_kwargs)
    return val_loader, test_loader
|
||
|
||
def extract_images(loader, out_folder):
    """Save every image / label pair from the loader as tif files in ``out_folder``."""
    os.makedirs(out_folder, exist_ok=True)
    for idx, (image, labels) in tqdm(enumerate(loader), total=len(loader)):
        image_np = image.squeeze().detach().cpu().numpy()
        labels_np = labels.squeeze().detach().cpu().numpy()

        imageio.imwrite(os.path.join(out_folder, "image_{:04d}.tif".format(idx)), image_np)
        imageio.imwrite(os.path.join(out_folder, "label_{:04d}.tif".format(idx)), labels_np)
|
||
|
||
def main():
    """Extract tissuenet val / test images as tif files for the evaluation."""
    data_path = "/scratch-grete/projects/nim00007/data/tissuenet"
    val_loader, test_loader = get_tissuenet_loaders(data_path)
    print("Length of val loader is:", len(val_loader))
    print("Length of test loader is:", len(test_loader))

    root_save_dir = "/scratch/projects/nim00007/sam/datasets/tissuenet"

    # we use the val set for test because there are some issues with the test set
    # out_folder = os.path.join(root_save_dir, "test")
    # extract_images(val_loader, out_folder)

    # we use the test folder for val and just use as many images as we can sample
    out_folder = os.path.join(root_save_dir, "val")
    extract_images(test_loader, out_folder)


if __name__ == "__main__":
    main()
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.