This repository has been archived by the owner on Jul 1, 2024. It is now read-only.

Unable to export decoder in onnx format for GPU context #747

Open
Greg7000 opened this issue May 7, 2024 · 0 comments


Greg7000 commented May 7, 2024

I am trying to run block 10 of this notebook: https://github.com/AndreyGermanov/sam_onnx_full_export/blob/main/sam_onnx_export.ipynb

Which is:

# Export mask decoder from SAM model to ONNX
onnx_model = SamOnnxModel(sam, return_single_mask=True)
embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
    "has_mask_input": torch.tensor([1], dtype=torch.float),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
torch.onnx.export(
    f="vit_b_decoder.onnx",
    model=onnx_model,
    args=tuple(dummy_inputs.values()),
    input_names=list(dummy_inputs.keys()),
    output_names=output_names,
    dynamic_axes={
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"}
    },
    export_params=True,
    opset_version=17,
    do_constant_folding=True
)

This works perfectly fine in a CPU context. However, when I try the same export in a GPU context using:

import torch
from segment_anything import sam_model_registry
from segment_anything.utils.onnx import SamOnnxModel

# Load SAM model
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")

# sam = sam.cuda()
sam.to(device="cuda")


# Export mask decoder from SAM model to ONNX
onnx_model = SamOnnxModel(sam, return_single_mask=True)


embed_dim = sam.prompt_encoder.embed_dim
embed_size = sam.prompt_encoder.image_embedding_size
mask_input_size = [4 * x for x in embed_size]
dummy_inputs = {
    "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float).cuda(),
    "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float).cuda(),
    "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float).cuda(),
    "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float).cuda(),
    "has_mask_input": torch.tensor([1], dtype=torch.float).cuda(),
    "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float).cuda(),
}
output_names = ["masks", "iou_predictions", "low_res_masks"]
torch.onnx.export(
    f="bob/vit_h_decoder.onnx",
    model=onnx_model,
    args=tuple(dummy_inputs.values()),
    input_names=list(dummy_inputs.keys()),
    output_names=output_names,
    dynamic_axes={"point_coords": {1: "num_points"}, "point_labels": {1: "num_points"}},
    export_params=True,
    opset_version=17,
    do_constant_folding=True,
)

I get:

Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument index in method wrapper_CUDA__index_select)

The error is raised at line 126 of segment_anything.modeling.mask_decoder, in MaskDecoder.predict_masks, by the call torch.repeat_interleave(image_embeddings, tokens.shape[0], dim=0).

I tried many different things, but the only way I got it to work was by slightly modifying mask_decoder.py, which is undesirable.

Does anybody have a suggestion that avoids any modification to mask_decoder.py?

I seem to have a CPU remnant somewhere. Maybe I need to convert embed_size and mask_input_size to torch.Size(), but that does not seem to be enough.
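
One way to narrow down where a CPU remnant lives is to list every registered parameter and buffer that is not on the expected device. The sketch below is illustrative only: find_off_device is a hypothetical helper, not part of segment_anything or PyTorch, and it relies only on the standard named_parameters() and named_buffers() methods of torch.nn.Module. It would not catch tensors created inside forward, or shape values that end up as CPU tensors during ONNX tracing, which may be what happens with tokens.shape[0] here.

```python
def find_off_device(module, expected_type="cuda"):
    """Return (kind, name, device_type) for each registered tensor
    that is not on a device of type `expected_type`.

    `module` is assumed to be a torch.nn.Module, for example the
    SamOnnxModel wrapper. `find_off_device` is a hypothetical helper.
    """
    off_device = []
    # Learnable weights
    for name, param in module.named_parameters():
        if param.device.type != expected_type:
            off_device.append(("parameter", name, param.device.type))
    # Non-learnable state registered with register_buffer()
    for name, buf in module.named_buffers():
        if buf.device.type != expected_type:
            off_device.append(("buffer", name, buf.device.type))
    return off_device


# Usage with the wrapper from the snippet above (assumes CUDA is available):
# for kind, name, dev in find_off_device(onnx_model, "cuda"):
#     print(kind, name, dev)
```

If this returns an empty list, the weights and buffers are all on the GPU and the stray CPU tensor is more likely created at run time or during tracing than stored in the model.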
