642 train 3d model with lucchi data #650
Draft
lufre1 wants to merge 26 commits into dev from 642-train-3d-model-with-lucchi-data
Changes from all commits (26 commits):
9d75668 (lufre1): implemented 3dsam train routine with lucchi data. still shape mismatch
42f9f36 (lufre1): implemented training routine for 3d sam
9be15d5 (lufre1): tidied up code
b0fc01a (lufre1): changed dataset esp. label shape not depending on num_classes
8ca1326 (lufre1): added check_loader
ca864ed (anwai98): Add mentions for annotating 3D RGB volumes (#629)
a66c09f (lufre1): tidied up code
a5e937a (anwai98): Add SemanticSam3dLogger (#643)
1592988 (lufre1): added new training and predict scripts
b61ee04 (constantinpape): Add simple 3d wrapper and enable freezing the encoder in sam 3d wrapp…
c64944d (anwai98): Minor fix to trainable sam model functionality (#646)
70cf9b7 (constantinpape): Fix dimension order in 3d sam wrappers
09af0a7 (constantinpape): Api cleanup (#648)
3d8d879 (constantinpape): Fix bug in precompute for 3d data (#649)
9bf0d45 (lufre1): Merge branch 'dev' into 642-train-3d-model-with-lucchi-data
b4f7865 (lufre1): merges...
63b4654 (lufre1): added support for vitl and vith
eaacf7a (lufre1): changed training for n iterations to n epochs
e3b2dbb (lufre1): debug train sam without encoder on mitottomo
a19f73d (lufre1): added parameter for raw transform and min_size for label_transform to…
a90ca2e (lufre1): added checkpoint to train_with_lucchi
ad76f2e (constantinpape): Add min-size to training and fix other issues
a550893 (lufre1): removed unused code
908e1c1 (lufre1): merged train routine updated
b6a7ce9 (lufre1): updates on train 3d without decoer
3422041 (lufre1): bash script for sbatch
@@ -0,0 +1,193 @@
import os
import argparse
from tqdm import tqdm
import numpy as np
import imageio.v3 as imageio
from elf.io import open_file
from skimage.measure import label as connected_components

import torch
from glob import glob

from torch_em.util.segmentation import size_filter
from torch_em.util import load_model
from torch_em.transform.raw import normalize
from torch_em.util.prediction import predict_with_halo

from micro_sam import util
from micro_sam.evaluation.inference import _run_inference_with_iterative_prompting_for_image

from segment_anything import SamPredictor

from micro_sam.models.sam_3d_wrapper import get_sam_3d_model
from typing import List, Union, Dict, Optional, Tuple


class RawTrafoFor3dInputs:
    def _normalize_inputs(self, raw):
        raw = normalize(raw)
        raw = raw * 255
        return raw

    def _set_channels_for_inputs(self, raw):
        raw = np.stack([raw] * 3, axis=0)
        return raw

    def __call__(self, raw):
        raw = self._normalize_inputs(raw)
        raw = self._set_channels_for_inputs(raw)
        return raw

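# Illustration (not part of the PR): given a raw volume of shape (32, 512, 512),
# the transform above min-max normalizes it, rescales it to [0, 255] and stacks it
# along a new channel axis, returning an array of shape (3, 32, 512, 512).
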
def _run_semantic_segmentation_for_image_3d(
    model: torch.nn.Module,
    image: np.ndarray,
    prediction_path: Union[os.PathLike, str],
    patch_shape: Tuple[int, int, int],
    halo: Tuple[int, int, int],
):
    device = next(model.parameters()).device
    # The blocks passed to predict_with_halo are the patch shape shrunk by the halo on each side,
    # e.g. patch_shape (32, 512, 512) with halo (6, 64, 64) gives a block shape of (20, 384, 384).
    block_shape = tuple(bs - 2 * ha for bs, ha in zip(patch_shape, halo))

    def preprocess(x):
        x = 255 * normalize(x)
        x = np.stack([x] * 3)
        return x

    def prediction_function(net, inp):
        # Note: we have two singleton axes in front here, I am not quite sure why.
        # Both need to be removed to be compatible with the SAM network.
        batched_input = [{
            "image": inp[0, 0], "original_size": inp.shape[-2:]
        }]
        masks = net(batched_input, multimask_output=True)[0]["masks"]
        masks = torch.argmax(masks, dim=1)
        return masks

    output = np.zeros(image.shape, dtype="float32")
    predict_with_halo(
        image, model, gpu_ids=[device],
        block_shape=block_shape, halo=halo,
        preprocess=preprocess, output=output,
        prediction_function=prediction_function
    )

    # Save the segmentation.
    imageio.imwrite(prediction_path, output, compression="zlib")

def run_semantic_segmentation_3d(
    model: torch.nn.Module,
    image_paths: List[Union[str, os.PathLike]],
    prediction_dir: Union[str, os.PathLike],
    semantic_class_map: Dict[str, int],
    patch_shape: Tuple[int, int, int] = (32, 512, 512),
    halo: Tuple[int, int, int] = (6, 64, 64),
    image_key: Optional[str] = None,
    is_multiclass: bool = False,
):
    """Run semantic segmentation with a 3d SAM model for a list of volumes
    and save the predictions as tif files in the prediction directory.
    """
    for image_path in tqdm(image_paths, desc="Run inference for semantic segmentation with all images"):
        image_name = os.path.basename(image_path)
        assert os.path.exists(image_path), image_path

        # We only perform multiclass segmentation, so all classes are predicted in one go.
        semantic_class_name = "all"

        # Skip images that have already been segmented.
        image_name = os.path.splitext(image_name)[0] + ".tif"
        prediction_path = os.path.join(prediction_dir, semantic_class_name, image_name)
        if os.path.exists(prediction_path):
            continue

        if image_key is None:
            image = imageio.imread(image_path)
        else:
            with open_file(image_path, "r") as f:
                image = f[image_key][:]

        # Create the prediction folder.
        os.makedirs(os.path.join(prediction_dir, semantic_class_name), exist_ok=True)

        _run_semantic_segmentation_for_image_3d(
            model=model, image=image, prediction_path=prediction_path,
            patch_shape=patch_shape, halo=halo,
        )

def transform_labels(y):
    return (y > 0).astype("float32")


def predict(args):
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # A trained checkpoint is required, otherwise prediction would run with untrained weights.
    assert args.checkpoint_path is not None and os.path.exists(args.checkpoint_path), args.checkpoint_path
    cp_path = os.path.join(args.checkpoint_path, "best.pt")

    model = get_sam_3d_model(
        device, n_classes=args.n_classes, image_size=args.patch_shape[1],
        lora_rank=4, model_type=args.model_type,
    )
    # Load the state dictionary from the checkpoint.
    checkpoint = torch.load(cp_path, map_location=device)
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    data_paths = glob(os.path.join(args.input_path, "**/*test.h5"), recursive=True)
    pred_path = args.save_root
    semantic_class_map = {"all": 0}

    run_semantic_segmentation_3d(
        model=model, image_paths=data_paths, prediction_dir=pred_path, semantic_class_map=semantic_class_map,
        patch_shape=args.patch_shape, image_key="raw", is_multiclass=True
    )

def main():
    parser = argparse.ArgumentParser(description="Run semantic segmentation with a finetuned 3d SAM model on the Lucchi dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/lucchi/",
        help="The filepath to the Lucchi data. If the data does not exist yet it will be downloaded."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
    )
    parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple).")
    parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict.")
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size.")
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument(
        "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d",
        help="The filepath to where the predictions will be saved."
    )
    parser.add_argument(
        "--checkpoint_path", "-c", default="/scratch-grete/usr/nimlufre/micro-sam3d/checkpoints/3d-sam-vitb-masamhyp-lucchi",
        help="The filepath to the checkpoint directory from which the trained model is loaded."
    )

    args = parser.parse_args()
    predict(args)


if __name__ == "__main__":
    main()
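
The script is meant to be run from the command line via main(). As a rough orientation, the helpers above could also be used programmatically along the following lines; this is a minimal sketch, not part of the PR, with placeholder checkpoint and data paths, and it assumes the checkpoint stores its weights under a "model_state" key, as in predict above.

# Hypothetical usage sketch with placeholder paths.
device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_sam_3d_model(device, n_classes=3, image_size=512, lora_rank=4, model_type="vit_b")
state = torch.load("./checkpoints/best.pt", map_location=device)  # placeholder checkpoint
model.load_state_dict(state["model_state"])
model.eval()

run_semantic_segmentation_3d(
    model=model,
    image_paths=glob("./data/lucchi/**/*test.h5", recursive=True),  # placeholder data location
    prediction_dir="./predictions",  # placeholder output folder
    semantic_class_map={"all": 0},
    patch_shape=(32, 512, 512),
    halo=(6, 64, 64),
    image_key="raw",
    is_multiclass=True,
)
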
@@ -0,0 +1,201 @@
import os
import argparse
import numpy as np
from math import ceil, floor
import torch

from torch_em.data.datasets import get_lucchi_loader, get_lucchi_dataset
from torch_em.segmentation import SegmentationDataset
import torch_em
from torch_em.util.debug import check_loader
from torch_em.transform.raw import normalize

from micro_sam.models.sam_3d_wrapper import get_sam_3d_model
from micro_sam.training.semantic_sam_trainer import SemanticSamTrainer
import micro_sam.training as sam_training


class RawTrafoFor3dInputs:
    def _normalize_inputs(self, raw):
        raw = normalize(raw)
        raw = raw * 255
        return raw

    def _set_channels_for_inputs(self, raw):
        raw = np.stack([raw] * 3, axis=0)
        return raw

    def __call__(self, raw):
        raw = self._normalize_inputs(raw)
        raw = self._set_channels_for_inputs(raw)
        return raw

# for sega
class RawResizeTrafoFor3dInputs(RawTrafoFor3dInputs):
    def __init__(self, desired_shape, padding="constant"):
        super().__init__()
        self.desired_shape = desired_shape
        self.padding = padding

    def __call__(self, raw):
        raw = self._normalize_inputs(raw)

        # Pad the inputs to the desired shape.
        tmp_ddim = (
            self.desired_shape[0] - raw.shape[0],
            self.desired_shape[1] - raw.shape[1],
            self.desired_shape[2] - raw.shape[2]
        )
        ddim = (tmp_ddim[0] / 2, tmp_ddim[1] / 2, tmp_ddim[2] / 2)
        raw = np.pad(
            raw,
            pad_width=(
                (ceil(ddim[0]), floor(ddim[0])), (ceil(ddim[1]), floor(ddim[1])), (ceil(ddim[2]), floor(ddim[2]))
            ),
            mode=self.padding
        )

        raw = self._set_channels_for_inputs(raw)
        return raw
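
# Worked example for the padding above (illustration only, not part of the PR): with
# desired_shape = (32, 512, 512) and an input of shape (25, 512, 512), the depth difference
# is 7, so ddim[0] = 3.5 and the volume is padded with ceil(3.5) = 4 slices in front and
# floor(3.5) = 3 slices at the back, which restores the desired depth of 32.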


Review comment on the class below: This can now be removed.

class LucchiSegmentationDataset(SegmentationDataset):

    def __init__(self, patch_shape, label_transform=None, **kwargs):
        # Call the parent class constructor.
        super().__init__(patch_shape=patch_shape, label_transform=label_transform, **kwargs)

    def __getitem__(self, index):
        raw, label = super().__getitem__(index)
        # raw shape: (z, color channels, y, x), where the number of channels is fixed to 3.
        image_shape = (self.patch_shape[0], 1) + self.patch_shape[1:]
        raw = raw.unsqueeze(2)
        raw = raw.view(image_shape)
        raw = raw.squeeze(0)
        raw = raw.repeat(1, 3, 1, 1)
        # Wanted label shape: (1, z, y, x).
        label = (label != 0).to(torch.float)
        return raw, label


def transform_labels(y):
    # Use torch_em to get foreground and boundary channels.
    transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)
    one_hot_channels = transform(y)

    # Binarize the foreground channel.
    foreground = np.where(one_hot_channels[0] > 0, 1, 0)

    # Combine foreground and boundaries, giving priority to boundaries (boundary pixels become 2).
    combined = np.where(one_hot_channels[1] > 0, 2, foreground)

    # The background stays 0.
    return combined
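
# Illustration (not part of the PR): applied to a binary mask, the transform above yields
# three semantic classes, 0 = background, 1 = foreground interior and 2 = object boundary,
# which matches the default of --n_classes 3 used for training below.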


def get_loaders(input_path, patch_shape):
    train_loader = get_lucchi_loader(
        input_path, split="train", patch_shape=patch_shape, batch_size=1, download=True,
        raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels,
        n_samples=100
    )
    val_loader = get_lucchi_loader(
        input_path, split="test", patch_shape=patch_shape, batch_size=1,
        raw_transform=RawTrafoFor3dInputs(), label_transform=transform_labels
    )
    return train_loader, val_loader
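
# Optional sanity check (illustration only, with a placeholder data path): the loaders can be
# inspected visually with torch_em's check_loader before starting a training run, e.g.
#   train_loader, val_loader = get_loaders("./data/lucchi/", (32, 512, 512))
#   check_loader(train_loader, n_samples=2)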


def train_on_lucchi(args):
    from micro_sam.training.util import ConvertToSemanticSamInputs

    input_path = args.input_path
    patch_shape = args.patch_shape
    batch_size = args.batch_size      # not yet passed to the data loaders
    num_workers = args.num_workers    # not yet passed to the data loaders
    n_classes = args.n_classes
    model_type = args.model_type
    n_epochs = args.n_epochs
    save_root = args.save_root
    cp_path = args.checkpoint_path

    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.without_lora:
        # Train without LoRA; the image encoder is frozen.
        sam_3d = get_sam_3d_model(
            device, n_classes=n_classes, image_size=patch_shape[1],
            model_type=model_type, lora_rank=None)
    else:
        sam_3d = get_sam_3d_model(
            device, n_classes=n_classes, image_size=patch_shape[1],
            model_type=model_type, lora_rank=4)

    # Optionally resume from a checkpoint.
    if cp_path is not None and os.path.exists(cp_path):
        checkpoint = torch.load(cp_path, map_location=device)
        sam_3d.load_state_dict(checkpoint["model_state"])

    train_loader, val_loader = get_loaders(input_path=input_path, patch_shape=patch_shape)
    # Alternative (MASAM hyperparameters): AdamW with betas=(0.9, 0.999) and weight_decay=0.1.
    optimizer = torch.optim.Adam(sam_3d.parameters(), lr=args.learning_rate)
    # Note: MASAM trains without a scheduler.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.9, patience=15, verbose=True)

    trainer = SemanticSamTrainer(
        name=args.exp_name,
        model=sam_3d,
        convert_inputs=ConvertToSemanticSamInputs(),
        num_classes=n_classes,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        lr_scheduler=scheduler,
        device=device,
        compile_model=False,
        save_root=save_root,
    )
    trainer.fit(epochs=n_epochs)


def main():
    parser = argparse.ArgumentParser(description="Finetune Segment Anything for 3d semantic segmentation on the Lucchi dataset.")
    parser.add_argument(
        "--input_path", "-i", default="/scratch/projects/nim00007/sam/data/lucchi/",
        help="The filepath to the Lucchi data. If the data does not exist yet it will be downloaded."
    )
    parser.add_argument(
        "--model_type", "-m", default="vit_b",
        help="The model type to use for fine-tuning. Either vit_t, vit_b, vit_l or vit_h."
    )
    parser.add_argument(
        "--without_lora", action="store_true",
        help="Finetune without LoRA, i.e. with a frozen image encoder, instead of LoRA-based finetuning."
    )
    parser.add_argument("--patch_shape", type=int, nargs=3, default=(32, 512, 512), help="Patch shape for data loading (3D tuple).")
    parser.add_argument("--n_epochs", type=int, default=400, help="Number of training epochs.")
    parser.add_argument("--n_classes", type=int, default=3, help="Number of classes to predict.")
    parser.add_argument("--batch_size", "-bs", type=int, default=1, help="Batch size.")  # MASAM uses 3.
    parser.add_argument("--num_workers", type=int, default=4, help="Number of workers for data loading.")
    parser.add_argument("--learning_rate", type=float, default=1e-5, help="Base learning rate.")  # MASAM uses 0.0008.
    parser.add_argument(
        "--save_root", "-s", default="/scratch-grete/usr/nimlufre/micro-sam3d",
        help="The filepath to where the logs and the checkpoints will be saved."
    )
    parser.add_argument(
        "--checkpoint_path", default=None,
        help="The filepath to where the checkpoints are loaded from."
    )
    parser.add_argument(
        "--exp_name", default="vitb_3d_lora4-microsam-hypam-lucchi",
        help="The name of the experiment, used for naming the checkpoints and logs."
    )

    args = parser.parse_args()
    train_on_lucchi(args)


if __name__ == "__main__":
    main()
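
For reference, the training entry point above can also be invoked programmatically instead of via the CLI. This is a minimal sketch, not part of the PR; all paths are placeholders and the remaining values mirror the argument defaults.

# Hypothetical programmatic invocation of train_on_lucchi with placeholder paths.
from argparse import Namespace

args = Namespace(
    input_path="./data/lucchi/",     # placeholder, the data is downloaded here if missing
    model_type="vit_b",
    without_lora=False,              # LoRA finetuning with rank 4
    patch_shape=(32, 512, 512),
    n_epochs=400,
    n_classes=3,
    batch_size=1,
    num_workers=4,
    learning_rate=1e-5,
    save_root="./runs/micro-sam3d",  # placeholder output folder
    checkpoint_path=None,            # start from the pretrained SAM weights
    exp_name="vitb_3d_lora4-microsam-hypam-lucchi",
)
train_on_lucchi(args)
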
Review comment: I am not sure why you would ever run prediction without a checkpoint. I would not make this optional.