Merge pull request #144 from computational-cell-analytics/generalist-experiments

Implement generalist model evaluation experiments
constantinpape authored Aug 16, 2023
2 parents 0e19472 + f81b400 commit 4150fd9
Showing 23 changed files with 759 additions and 399 deletions.
124 changes: 124 additions & 0 deletions finetuning/generalists/cellpose_baseline.py
@@ -0,0 +1,124 @@
import argparse
import os
from glob import glob
from subprocess import run

import imageio.v3 as imageio

from tqdm import tqdm

DATA_ROOT = "/scratch/projects/nim00007/sam/datasets"
EXP_ROOT = "/scratch/projects/nim00007/sam/experiments/cellpose"

DATASETS = (
    "covid-if",
    "deepbacs",
    "hpa",
    "livecell",
    "lizard",
    "mouse-embryo",
    "plantseg-ovules",
    "plantseg-root",
    "tissuenet",
)


def load_cellpose_model():
    from cellpose import models

    device, gpu = models.assign_device(True, True)
    model = models.Cellpose(gpu=gpu, model_type="cyto", device=device)
    return model


def run_cellpose_segmentation(datasets, job_id):
    dataset = datasets[job_id]
    experiment_folder = os.path.join(EXP_ROOT, dataset)

    prediction_folder = os.path.join(experiment_folder, "predictions")
    os.makedirs(prediction_folder, exist_ok=True)

    image_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, "test", "image*.tif")))
    model = load_cellpose_model()

    for path in tqdm(image_paths, desc=f"Segmenting {dataset} with cellpose"):
        fname = os.path.basename(path)
        out_path = os.path.join(prediction_folder, fname)
        if os.path.exists(out_path):  # skip images that were already segmented
            continue
        image = imageio.imread(path)
        if image.ndim == 3:  # convert RGB images to grayscale by averaging the channels
            assert image.shape[-1] == 3
            image = image.mean(axis=-1)
        assert image.ndim == 2
        seg = model.eval(image, diameter=None, flow_threshold=None, channels=[0, 0])[0]
        assert seg.shape == image.shape
        imageio.imwrite(out_path, seg, compression=5)


def submit_array_job(datasets):
    n_datasets = len(datasets)
    cmd = ["sbatch", "-a", f"0-{n_datasets-1}", "cellpose_baseline.sbatch"]
    run(cmd)


def evaluate_dataset(dataset):
    from micro_sam.evaluation.evaluation import run_evaluation

    gt_paths = sorted(glob(os.path.join(DATA_ROOT, dataset, "test", "label*.tif")))
    experiment_folder = os.path.join(EXP_ROOT, dataset)
    pred_paths = sorted(glob(os.path.join(experiment_folder, "predictions", "*.tif")))
    assert len(gt_paths) == len(pred_paths), f"{len(gt_paths)}, {len(pred_paths)}"
    result_path = os.path.join(experiment_folder, "cellpose.csv")
    run_evaluation(gt_paths, pred_paths, result_path)


def evaluate_segmentations(datasets):
    for dataset in datasets:
        # we skip livecell, which has already been processed by cellpose
        if dataset == "livecell":
            continue
        evaluate_dataset(dataset)


def check_results(datasets):
    for ds in datasets:
        # we skip livecell, which has already been processed by cellpose
        if ds == "livecell":
            continue
        result_path = os.path.join(EXP_ROOT, ds, "cellpose.csv")
        if not os.path.exists(result_path):
            print("Cellpose results missing for", ds)
    print("All checks passed")


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--segment", "-s", action="store_true")
    parser.add_argument("--evaluate", "-e", action="store_true")
    parser.add_argument("--check", "-c", action="store_true")
    parser.add_argument("--datasets", nargs="+")
    args = parser.parse_args()

    job_id = os.environ.get("SLURM_ARRAY_TASK_ID", None)

    if args.datasets is None:
        datasets = DATASETS
    else:
        datasets = args.datasets
        assert all(ds in DATASETS for ds in datasets)

    if job_id is not None:  # running as a slurm array task: segment one dataset
        run_cellpose_segmentation(datasets, int(job_id))
    elif args.segment:
        submit_array_job(datasets)
    elif args.evaluate:
        evaluate_segmentations(datasets)
    elif args.check:
        check_results(datasets)
    else:
        raise ValueError("Doing nothing")


if __name__ == "__main__":
    main()
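
The three flags correspond to the stages of the baseline. A minimal sketch of the intended workflow, assuming the data layout and the sbatch file below are in place (the flag names are taken from the argument parser in main()):

import subprocess

# 1. submit one slurm array task per dataset; each task re-enters this script
#    with SLURM_ARRAY_TASK_ID set and segments its assigned dataset
subprocess.run(["python", "cellpose_baseline.py", "--segment"], check=True)

# 2. once all array tasks have finished, score the predictions against the labels
subprocess.run(["python", "cellpose_baseline.py", "--evaluate"], check=True)

# 3. verify that every dataset produced a cellpose.csv
subprocess.run(["python", "cellpose_baseline.py", "--check"], check=True)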
10 changes: 10 additions & 0 deletions finetuning/generalists/cellpose_baseline.sbatch
@@ -0,0 +1,10 @@
#! /bin/bash
#SBATCH -c 4
#SBATCH --mem 48G
#SBATCH -t 300
#SBATCH -p grete:shared
#SBATCH -G A100:1
#SBATCH -A nim00007

source activate cellpose
python cellpose_baseline.py "$@"
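
For context: sbatch -a 0-<N-1> (as issued by submit_array_job) starts one copy of this script per index, and SLURM exports that index to each task. A sketch of the mechanism, not project code:

import os

# every task of a slurm array job sees its own index in this environment
# variable; cellpose_baseline.py maps it to an entry of DATASETS
job_id = os.environ.get("SLURM_ARRAY_TASK_ID")
if job_id is not None:
    print("this task segments dataset index", int(job_id))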
104 changes: 104 additions & 0 deletions finetuning/generalists/compile_results.py
@@ -0,0 +1,104 @@
import os
from glob import glob

import pandas as pd

from evaluate_generalist import EXPERIMENT_ROOT
from util import EM_DATASETS, LM_DATASETS


def get_results(model, ds):
    res_folder = os.path.join(EXPERIMENT_ROOT, model, ds, "results")
    res_paths = sorted(glob(os.path.join(res_folder, "box", "*.csv"))) +\
        sorted(glob(os.path.join(res_folder, "points", "*.csv")))

    amg_res = os.path.join(res_folder, "amg.csv")
    if os.path.exists(amg_res):
        res_paths.append(amg_res)

    results = []
    for path in res_paths:
        prompt_res = pd.read_csv(path)
        # use the path relative to the result folder, e.g. box/<name>, as the prompt name
        prompt_name = os.path.splitext(os.path.relpath(path, res_folder))[0]
        prompt_res.insert(0, "prompt", [prompt_name])
        results.append(prompt_res)
    results = pd.concat(results)
    results.insert(0, "dataset", results.shape[0] * [ds])

    return results


def compile_results(models, datasets, out_path, load_results=False):
    results = []

    for model in models:
        model_results = []

        for ds in datasets:
            ds_results = get_results(model, ds)
            model_results.append(ds_results)

        model_results = pd.concat(model_results)
        model_results.insert(0, "model", [model] * model_results.shape[0])
        results.append(model_results)

    results = pd.concat(results)
    if load_results:
        # append to an existing result table instead of overwriting it
        assert os.path.exists(out_path)
        all_results = pd.read_csv(out_path)
        results = pd.concat([all_results, results])

    results.to_csv(out_path, index=False)


def compile_em():
    compile_results(
        ["vit_h", "vit_h_em", "vit_b", "vit_b_em"],
        EM_DATASETS,
        os.path.join(EXPERIMENT_ROOT, "evaluation-em.csv")
    )


def add_cellpose_results(datasets, out_path):
    cp_root = "/scratch/projects/nim00007/sam/experiments/cellpose"

    results = []
    for dataset in datasets:
        # we skip livecell, which has already been processed by cellpose
        if dataset == "livecell":
            continue
        res_path = os.path.join(cp_root, dataset, "cellpose.csv")
        ds_res = pd.read_csv(res_path)
        ds_res.insert(0, "prompt", ["cellpose"] * ds_res.shape[0])
        ds_res.insert(0, "dataset", [dataset] * ds_res.shape[0])
        results.append(ds_res)

    results = pd.concat(results)
    results.insert(0, "model", ["cellpose"] * results.shape[0])

    all_results = pd.read_csv(out_path)
    results = pd.concat([all_results, results])
    results.to_csv(out_path, index=False)


def compile_lm():
    res_path = os.path.join(EXPERIMENT_ROOT, "evaluation-lm.csv")
    compile_results(
        ["vit_h", "vit_h_lm", "vit_b", "vit_b_lm"], LM_DATASETS, res_path
    )

    # add the deepbacs and tissuenet specialist results
    assert os.path.exists(res_path)
    compile_results(["vit_h_tissuenet", "vit_b_tissuenet"], ["tissuenet"], res_path, True)
    compile_results(["vit_h_deepbacs", "vit_b_deepbacs"], ["deepbacs"], res_path, True)

    # add the cellpose results
    add_cellpose_results(LM_DATASETS, res_path)


def main():
    # compile_em()
    compile_lm()


if __name__ == "__main__":
    main()
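
The compiled CSVs stack one block of rows per model/dataset/prompt combination. A hedged sketch of how the table can be inspected, assuming it was written by compile_lm() above (the model, dataset and prompt columns are inserted by the functions in this file; the metric columns are whatever micro_sam's run_evaluation wrote to the per-prompt CSVs):

import pandas as pd

# load the compiled light-microscopy results written by compile_lm()
results = pd.read_csv("evaluation-lm.csv")
print(results.columns.tolist())

# e.g. average the metric columns per model and dataset for the box prompts
box = results[results["prompt"].str.startswith("box")]
print(box.groupby(["model", "dataset"]).mean(numeric_only=True))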
63 changes: 63 additions & 0 deletions finetuning/generalists/create_tissuenet_data.py
@@ -0,0 +1,63 @@

import os
from tqdm import tqdm
import imageio.v2 as imageio
import numpy as np

from torch_em.data import MinInstanceSampler
from torch_em.transform.label import label_consecutive
from torch_em.data.datasets import get_tissuenet_loader
from torch_em.transform.raw import standardize, normalize_percentile


def rgb_to_gray_transform(raw):
    raw = normalize_percentile(raw, axis=(1, 2))
    raw = np.mean(raw, axis=0)  # average over the channels to get a grayscale image
    raw = standardize(raw)
    return raw


def get_tissuenet_loaders(input_path):
    sampler = MinInstanceSampler()
    label_transform = label_consecutive
    raw_transform = rgb_to_gray_transform
    val_loader = get_tissuenet_loader(path=input_path, split="val", raw_channel="rgb", label_channel="cell",
                                      batch_size=1, patch_shape=(256, 256), num_workers=0,
                                      sampler=sampler, label_transform=label_transform, raw_transform=raw_transform)
    test_loader = get_tissuenet_loader(path=input_path, split="test", raw_channel="rgb", label_channel="cell",
                                       batch_size=1, patch_shape=(256, 256), num_workers=0,
                                       sampler=sampler, label_transform=label_transform, raw_transform=raw_transform)
    return val_loader, test_loader


def extract_images(loader, out_folder):
    os.makedirs(out_folder, exist_ok=True)
    for i, (x, y) in tqdm(enumerate(loader), total=len(loader)):
        img_path = os.path.join(out_folder, "image_{:04d}.tif".format(i))
        gt_path = os.path.join(out_folder, "label_{:04d}.tif".format(i))

        img = x.squeeze().detach().cpu().numpy()
        gt = y.squeeze().detach().cpu().numpy()

        imageio.imwrite(img_path, img)
        imageio.imwrite(gt_path, gt)


def main():
    val_loader, test_loader = get_tissuenet_loaders("/scratch-grete/projects/nim00007/data/tissuenet")
    print("Length of val loader is:", len(val_loader))
    print("Length of test loader is:", len(test_loader))

    root_save_dir = "/scratch/projects/nim00007/sam/datasets/tissuenet"

    # we use the val set for test because there are some issues with the test set
    # out_folder = os.path.join(root_save_dir, "test")
    # extract_images(val_loader, out_folder)

    # we use the test folder for val and just use as many images as we can sample
    out_folder = os.path.join(root_save_dir, "val")
    extract_images(test_loader, out_folder)


if __name__ == "__main__":
    main()
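
Since cellpose_baseline.py and evaluate_dataset() glob for image*.tif and label*.tif, the files written by extract_images() slot directly into the evaluation layout. A small sanity check under that assumption (the folder path is the one used above):

import os
from glob import glob

# every extracted image should have a matching label file
folder = "/scratch/projects/nim00007/sam/datasets/tissuenet/val"
images = sorted(glob(os.path.join(folder, "image_*.tif")))
labels = sorted(glob(os.path.join(folder, "label_*.tif")))
assert len(images) == len(labels), "image/label pairs do not match up"
print(len(images), "image/label pairs found")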
84 changes: 0 additions & 84 deletions finetuning/generalists/data/precompute_prompts.py

This file was deleted.

