Refining masks with micro-sam. #383

Open
constantinpape opened this issue Feb 6, 2024 · 14 comments

constantinpape commented Feb 6, 2024

There are different ways to refine existing masks with micro_sam.

The easiest option would be to derive point prompts from the centers of the masks and then prompt the model with these points.

The function batched_inference can be used for this.

Here is some (untested!) code for this, using skimage to derive the point prompts.

import numpy as np
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_inference
from skimage.measure import regionprops

image = ...  # <- this is the 2D! image
initial_segmentation = ...   # <- this is the segmentation you want to refine

props = regionprops(initial_segmentation)
points = np.array([prop.centroid for prop in props])[:, ::-1]   # The coordinates of the centroids need to be reversed to match the convention of SAM.
point_labels = np.ones(len(points), dtype="int")  # <- All prompts are positive.

# We need to add an extra dimension to provide the correct input for batched_inference.
points = np.expand_dims(points, 1)
point_labels = np.expand_dims(point_labels, 1)

predictor = get_sam_model(model_type="vit_b_lm")  # <- you can control which model is used with the model_type argument.
# See the function signature of get_sam_model for details.

refined_segmentation = batched_inference(
  predictor, image,
  batch_size=32,  # This controls how many points are processed at once, lower it if you get memory issues 
  points=points,
  point_labels=point_labels,
  return_instance_segmentation=True
)

Another possible strategy is to derive bounding boxes from the segmented objects and use these as prompts instead.
This could be done by passing the boxes argument.
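For illustration, here is a minimal (untested) sketch of the box variant, continuing from the snippet above and assuming batched_inference accepts boxes in the (x_min, y_min, x_max, y_max) layout used by SAM:

import numpy as np
from skimage.measure import regionprops

props = regionprops(initial_segmentation)
# regionprops returns bounding boxes as (min_row, min_col, max_row, max_col);
# SAM expects (x_min, y_min, x_max, y_max), so we reorder the coordinates.
boxes = np.array([[p.bbox[1], p.bbox[0], p.bbox[3], p.bbox[2]] for p in props])

refined_segmentation = batched_inference(
    predictor, image,
    batch_size=32,
    boxes=boxes,
    return_instance_segmentation=True
)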

Note that this code will only work in 2D. It is possible to extend it to 3D, but I would suggest starting in 2D first; once this is working well I can give hints for how to extend it to 3D.
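As a very rough illustration of what such an extension could look like (untested; image_3d and initial_segmentation_3d are placeholder names, and this naive slice-by-slice loop does not link objects across slices):

import numpy as np
from skimage.measure import regionprops

refined_3d = np.zeros_like(initial_segmentation_3d)
for z in range(image_3d.shape[0]):
    props = regionprops(initial_segmentation_3d[z])
    if not props:  # skip slices without objects
        continue
    pts = np.expand_dims(np.array([p.centroid for p in props])[:, ::-1], 1)
    lbls = np.ones((len(pts), 1), dtype="int")
    refined_3d[z] = batched_inference(
        predictor, image_3d[z],
        batch_size=32,
        points=pts, point_labels=lbls,
        return_instance_segmentation=True
    )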

cc @Nal44

Nal44 commented Feb 6, 2024

Great!!! Thanks a lot :)

I will try to make it "napari compatible" using viewer.layers as well :)

Thanks!!!

Nal44 commented Feb 8, 2024

My attempt on a 2D image:

import numpy as np
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_inference
from skimage.measure import regionprops

# Get the image and label layers from the napari viewer.
image_layer = viewer.layers[0]
initial_segmentation = viewer.layers[1]

# Get the labels from the napari Labels layer and convert them to a numpy array.
labeled_image = np.asarray(initial_segmentation.data)

# Get the image data as a numpy array (note: layer 0 is the image; using layer 1 here would grab the labels again).
image = np.asarray(image_layer.data)

props = regionprops(labeled_image)
points = np.array([prop.centroid for prop in props])[:, ::-1]  # The coordinates of the centroids need to be reversed to match the convention of SAM.
point_labels = np.ones(len(points), dtype="int")  # <- All prompts are positive.

predictor = get_sam_model(model_type="vit_b_lm")  # <- you can control which model is used with the model_type argument.
# See the function signature of get_sam_model for details.

refined_segmentation = batched_inference(
    predictor, image,
    batch_size=32,  # This controls how many points are processed at once, lower it if you get memory issues.
    points=points,
    point_labels=point_labels,
    return_instance_segmentation=True
)

BUT:

RuntimeError                              Traceback (most recent call last)
Cell In[3], line 29
     26 predictor = get_sam_model(model_type="vit_b_lm")  # <- you can control which model is used with the model type argument. 
     27 # See the function signature of get_sam_model for details.
---> 29 refined_segmentation = batched_inference(
     30   predictor, image,
     31   batch_size=32,  # This controls how many points are processed at once, lower it if you get memory issues 
     32   points=points,
     33   point_labels=point_labels,
     34   return_instance_segmentation=True
     35 )

File ~\AppData\Local\micro_sam\Lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~\AppData\Local\micro_sam\Lib\site-packages\micro_sam\inference.py:115, in batched_inference(predictor, image, batch_size, boxes, points, point_labels, multimasking, embedding_path, return_instance_segmentation, segmentation_ids, reduce_multimasking)
    112 batch_points = points[batch_start:batch_stop] if have_points else None
    113 batch_labels = point_labels[batch_start:batch_stop] if have_points else None
--> 115 batch_masks, batch_ious, _ = predictor.predict_torch(
    116     point_coords=batch_points, point_labels=batch_labels,
    117     boxes=batch_boxes, multimask_output=multimasking
    118 )
    120 # If we expect to reduce the masks from multimasking and use multi-masking,
    121 # then we need to select the most likely mask (according to the predicted IOU) here.
    122 if reduce_multimasking and multimasking:

File ~\AppData\Local\micro_sam\Lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)

File ~\AppData\Local\micro_sam\Lib\site-packages\segment_anything\predictor.py:222, in SamPredictor.predict_torch(self, point_coords, point_labels, boxes, mask_input, multimask_output, return_logits)
    219     points = None
    221 # Embed prompts
--> 222 sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
    223     points=points,
    224     boxes=boxes,
    225     masks=mask_input,
    226 )
    228 # Predict masks
    229 low_res_masks, iou_predictions = self.model.mask_decoder(
    230     image_embeddings=self.features,
    231     image_pe=self.model.prompt_encoder.get_dense_pe(),
   (...)
    234     multimask_output=multimask_output,
    235 )

File ~\AppData\Local\micro_sam\Lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)

File ~\AppData\Local\micro_sam\Lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None

File ~\AppData\Local\micro_sam\Lib\site-packages\segment_anything\modeling\prompt_encoder.py:155, in PromptEncoder.forward(self, points, boxes, masks)
    153 if points is not None:
    154     coords, labels = points
--> 155     point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
    156     sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
    157 if boxes is not None:

File ~\AppData\Local\micro_sam\Lib\site-packages\segment_anything\modeling\prompt_encoder.py:84, in PromptEncoder._embed_points(self, points, labels, pad)
     82     padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
     83     padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
---> 84     points = torch.cat([points, padding_point], dim=1)
     85     labels = torch.cat([labels, padding_label], dim=1)
     86 point_embedding = self.pe_layer.forward_with_coords(points, self.input_image_size)

RuntimeError: Tensors must have same number of dimensions: got 2 and 3

Not sure how to fix the dimension problem...?

thanks
antho

constantinpape commented:

It turns out an extra dimension needs to be added to the inputs of batched_inference:

# We need to add an extra dimension to provide the correct input for batched_inference.
points = np.expand_dims(points, 1)
point_labels = np.expand_dims(point_labels, 1)
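For illustration, with N objects the shapes change as follows, matching the batched per-object layout that predict_torch expects:

# Shapes before and after, for N objects:
#   points:       (N, 2) -> (N, 1, 2)   # one point prompt per object
#   point_labels: (N,)   -> (N, 1)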

(I have updated the pseudo-code on top too, so you can see the full example there.)

Nal44 commented Feb 8, 2024

Thanks a lot,

there is a shape argument missing in the function (mask_data_to_segmentation); I will work on fixing it.
Almost there...!

thanks a lot,
antho

Nal44 commented Feb 8, 2024

Ok, that seems to work (tested roughly!)

I also added this in the inference.py file (the shape argument was missing), passing shape=image_shape:

if return_instance_segmentation:
    masks = mask_data_to_segmentation(masks, with_background=False, min_object_size=0, shape=image_shape)

import numpy as np
import napari
from micro_sam.util import get_sam_model
from micro_sam.inference import batched_inference
from skimage.measure import regionprops

# Get the image and label layers from the napari viewer.
image_layer = viewer.layers[0]
initial_segmentation = viewer.layers[1]

# Get the labels from the napari Labels layer and convert them to a numpy array.
labeled_image = np.asarray(initial_segmentation.data)

# Get the image data as a numpy array (layer 0 is the image; layer 1 holds the labels).
image = np.asarray(image_layer.data)

props = regionprops(labeled_image)
points = np.array([prop.centroid for prop in props])[:, ::-1]  # The coordinates of the centroids need to be reversed to match the convention of SAM.
point_labels = np.ones(len(points), dtype="int")  # <- All prompts are positive.

# We need to add an extra dimension to provide the correct input for batched_inference.
points = np.expand_dims(points, 1)
point_labels = np.expand_dims(point_labels, 1)

predictor = get_sam_model(model_type="vit_b_lm")  # <- you can control which model is used with the model_type argument.
# See the function signature of get_sam_model for details.

refined_segmentation = batched_inference(
    predictor, image,
    batch_size=32,  # This controls how many points are processed at once, lower it if you get memory issues.
    points=points,
    point_labels=point_labels,
    return_instance_segmentation=True
)

# Add the refined segmentation labels as a new layer to the viewer.
viewer.add_labels(refined_segmentation, name='Refined Segmentation Labels')

# Optionally, you can also set the colormap and opacity for the new layer.
viewer.layers[-1].colormap = 'viridis'
viewer.layers[-1].opacity = 0.5

will test more later on,

Thanks a lot !! :)
antho

constantinpape commented:

> Ok, that seems to work (tested roughly!)

Ok, great! Let me know how the quality looks. If there are any issues this can probably be improved by adjusting some parameters.

> I also added this in the inference.py file (the shape argument was missing):

This should not be necessary if you're working off the dev branch:
https://github.com/computational-cell-analytics/micro-sam/blob/dev/micro_sam/instance_segmentation.py#L51
(dev contains the latest version, and we will merge it into the master branch soon).

But that is only a minor thing, just be aware that this might change soon on master too.

Nal44 commented Feb 9, 2024

The results are slightly different, hence I think I have to adjust the parameters as you suggested, but it works in principle :D.

I think points_per_side (default is 32) gives better results at 100 (more granular), but any customizable parameters will be useful :).
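For reference, points_per_side is the prompt-grid density of SAM's automatic mask generator. A minimal sketch with the upstream segment_anything API (assuming the underlying SAM model is reachable as predictor.model):

from segment_anything import SamAutomaticMaskGenerator

# Denser prompt grid: 100 x 100 candidate points instead of the default 32 x 32.
amg = SamAutomaticMaskGenerator(predictor.model, points_per_side=100)
masks = amg.generate(image)  # generate() expects an RGB uint8 image of shape (H, W, 3)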

I can share the image / pre-sam output if that helps?

The main idea is to refine the segmentation of the elongated cells (often the sides are not well segmented), but also to refine doublets and the fine segmentation in general. In addition, it would be really cool to be able to add points (automatically) for any missing cells from the pre-segmentation. That way it would do two things: refine the existing segmentation and add the missing cells (one stone, two birds!).

I am still on the master branch, but will switch to the dev one.

What about 3D (my main interest)?
Thanks a lot :)
antho

constantinpape commented:

> The results are slightly different, hence I think I have to adjust the parameters as you suggested, but it works in principle :D.

That's great!

> I can share the image / pre-sam output if that helps?

Yes, that would be quite helpful!

> In addition, it would be really cool to be able to add points (automatically) for any missing cells from the pre-segmentation.

Do you have a good heuristic for how to add points automatically?

> What about 3D (my main interest)?

I will follow up on that next week. (I am on a retreat this week, so my answers are a bit slower, but I will be working on this next week anyways and share some code.)

Nal44 commented Feb 13, 2024

Hi,

I sent you an invite to share the files to your email. I included the original image, my custom 2D model masks, and the refinements from micro-sam. As mentioned, the main improvement could be with the elongated nuclei; it would be great to refine these :).

For adding points: I was part of the last HTAN jamboree (https://github.com/NCI-HTAN-Jamborees/Improving-cell-segmentation-for-spatial-omics/tree/main), where we worked on similar approaches. I know there are a few papers working on this idea: using the specialized models as prompts (as we are doing now), but adding automatic grid points on top in case the specialized model missed some nuclei (the grid worked better with 100 points, if I remember correctly). I will dig into finding these papers later on. A rough sketch of the idea is below.
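An untested sketch of that combination, continuing from the 2D script above; the grid density and the background test are assumptions, not taken from the papers mentioned:

import numpy as np

def background_grid_points(segmentation, points_per_side=100):
    """Regular grid of candidate points, keeping only those not covered by an existing mask."""
    h, w = segmentation.shape
    ys = np.linspace(0, h - 1, points_per_side).astype(int)
    xs = np.linspace(0, w - 1, points_per_side).astype(int)
    grid = np.stack(np.meshgrid(xs, ys), axis=-1).reshape(-1, 2)  # (x, y) to match SAM's convention
    on_background = segmentation[grid[:, 1], grid[:, 0]] == 0
    return grid[on_background]

# Append the grid points to the centroid prompts; each point prompts one candidate object.
extra_points = background_grid_points(labeled_image, points_per_side=100)
all_points = np.expand_dims(np.concatenate([points.reshape(-1, 2), extra_points]), 1)
all_labels = np.ones((len(all_points), 1), dtype="int")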

No problem about the delay; the 3D is the most time consuming, hence any help will be appreciated :)
Retreat means holidays, hence no work :). Enjoy, and talk next week then!

thanks.
antho

constantinpape commented:

> I sent you an invite to share the files to your email. I included the original image, my custom 2D model masks, and the refinements from micro-sam. As mentioned, the main improvement could be with the elongated nuclei; it would be great to refine these :).

Thanks for sending the data. Unfortunately, the service you used for sharing seems to require a download client that is not available for Linux (and I use a Linux machine). Could you share it via a different service that enables direct download in the browser?

Nal44 commented Feb 13, 2024

I sent a Google Drive link, does this work?

constantinpape commented:

Yes that worked! I have downloaded the data and will take a closer look next week.

constantinpape commented:

Hi @Nal44,
sorry to take a bit longer to follow up; I had a busy few weeks. I am back working on the tool this week and will try to follow up here by the end of the week.

Nal44 commented Mar 5, 2024

Hi,
Sounds good, thanks for the update :)
antho
