Skip to content

Commit

Permalink
Fix issue in AIS precomputation and extend training functionality (#673)
Browse files Browse the repository at this point in the history
  • Loading branch information
constantinpape committed Jul 30, 2024
1 parent 1d33862 commit 1869115
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 4 deletions.
1 change: 0 additions & 1 deletion micro_sam/instance_segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -956,7 +956,6 @@ def to_bbox_3d(bbox):
]
return masks

# TODO find good default values (empirically)
def generate(
self,
center_distance_threshold: float = 0.5,
Expand Down
3 changes: 2 additions & 1 deletion micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ def initialize_predictor(
prefer_decoder=True,
pbar_init=None,
pbar_update=None,
skip_load=True,
):
assert ndim in (2, 3)

Expand Down Expand Up @@ -127,7 +128,7 @@ def initialize_predictor(
raise RuntimeError("Require a save path to precompute the amg state")

cache_state = cache_amg_state if self.decoder is None else partial(
cache_is_state, decoder=self.decoder, skip_load=True,
cache_is_state, decoder=self.decoder, skip_load=skip_load,
)

if ndim == 2:
Expand Down
1 change: 1 addition & 0 deletions micro_sam/sam_annotator/annotator_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def annotator_2d(
image, model_type=model_type, save_path=embedding_path,
halo=halo, tile_shape=tile_shape, precompute_amg_state=precompute_amg_state,
ndim=2, checkpoint_path=checkpoint_path, device=device, prefer_decoder=prefer_decoder,
skip_load=False,
)

if viewer is None:
Expand Down
3 changes: 2 additions & 1 deletion micro_sam/sam_annotator/image_series_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def image_series_annotator(
image, model_type=model_type, save_path=image_embedding_path, halo=halo, tile_shape=tile_shape,
predictor=predictor, decoder=decoder,
ndim=3 if is_volumetric else 2, precompute_amg_state=precompute_amg_state,
checkpoint_path=checkpoint_path, device=device,
checkpoint_path=checkpoint_path, device=device, skip_load=False,
)
state.image_shape = _get_input_shape(image, is_volumetric)

Expand Down Expand Up @@ -237,6 +237,7 @@ def next_image(*args):
tile_shape=tile_shape, halo=halo,
predictor=predictor, decoder=decoder,
precompute_amg_state=precompute_amg_state, device=device,
skip_load=False,
)
state.image_shape = _get_input_shape(image, is_volumetric)

Expand Down
14 changes: 13 additions & 1 deletion micro_sam/training/training.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import time
import warnings
from glob import glob
from typing import Any, Dict, List, Optional, Tuple, Union

Expand Down Expand Up @@ -419,6 +420,7 @@ def train_sam_for_configuration(
val_loader: DataLoader,
checkpoint_path: Optional[Union[str, os.PathLike]] = None,
with_segmentation_decoder: bool = True,
model_type: Optional[str] = None,
**kwargs,
) -> None:
"""Run training for a SAM model with the configuration for a given hardware resource.
Expand All @@ -435,16 +437,26 @@ def train_sam_for_configuration(
checkpoint_path: Path to checkpoint for initializing the SAM model.
with_segmentation_decoder: Whether to train additional UNETR decoder
for automatic instance segmentation.
model_type: Over-ride the default model type.
This can be used to use one of the micro_sam models as starting point
instead of a default sam model.
kwargs: Additional keyword parameterts that will be passed to `train_sam`.
"""
if configuration in CONFIGURATIONS:
train_kwargs = CONFIGURATIONS[configuration]
else:
raise ValueError(f"Invalid configuration {configuration} expect one of {list(CONFIGURATIONS.keys())}")

if model_type is None:
model_type = train_kwargs.pop("model_type")
else:
expected_model_type = train_kwargs.pop("model_type")
if model_type[:5] != expected_model_type:
warnings.warn("You have specified a different model type.")

train_kwargs.update(**kwargs)
train_sam(
name=name, train_loader=train_loader, val_loader=val_loader,
checkpoint_path=checkpoint_path, with_segmentation_decoder=with_segmentation_decoder,
**train_kwargs
model_type=model_type, **train_kwargs
)

0 comments on commit 1869115

Please sign in to comment.