enable onnx inference
Y-T-G committed Aug 11, 2023
1 parent d8305f8 commit d685958
Showing 3 changed files with 31 additions and 17 deletions.
2 changes: 1 addition & 1 deletion app.py
@@ -56,7 +56,7 @@ def get_prompt(click_state, click_input):
"prompt_type": ["click"],
"input_point": click_state[0],
"input_label": click_state[1],
"multimask_output": "True",
"multimask_output": "False",
}
return prompt

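For context: the prompt dictionary built here is consumed by SAM's SamPredictor, whose predict call returns three candidate masks when multimask_output is True and a single mask when it is False. A minimal sketch of that call follows; the checkpoint path, image, and click coordinates are placeholder assumptions, not values from this repository.

import numpy as np
from segment_anything import sam_model_registry, SamPredictor

# Assumed checkpoint path and a placeholder RGB image, for illustration only.
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)
image = np.zeros((480, 640, 3), dtype=np.uint8)
predictor.set_image(image)

masks, scores, logits = predictor.predict(
    point_coords=np.array([[320, 240]], dtype=np.float32),  # placeholder click (x, y)
    point_labels=np.array([1]),                             # 1 = foreground point
    multimask_output=False,                                  # single mask, as in the change above
)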
13 changes: 10 additions & 3 deletions utils/base_segmenter.py
@@ -24,8 +24,10 @@ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device="c
from mobile_sam import sam_model_registry, SamPredictor
from onnxruntime import InferenceSession
self.ort_session = InferenceSession(sam_onnx_checkpoint)
self.predict = self.predict_onnx
else:
from segment_anything import sam_model_registry, SamPredictor
self.predict = self.predict_pt

self.model = sam_model_registry[model_type](checkpoint=sam_pt_checkpoint)
self.model.to(device=self.device)
@@ -51,7 +53,7 @@ def reset_image(self):
self.predictor.reset_image()
self.embedded = False

def predict(self, prompts, mode, multimask=True):
def predict_pt(self, prompts, mode, multimask=True):
"""
image: numpy array, h, w, 3
prompts: dictionary, 3 keys: 'point_coords', 'point_labels', 'mask_input'
@@ -115,17 +117,20 @@ def predict_onnx(self, prompts, mode, multimask=True):
"orig_im_size": prompts["orig_im_size"],
}
masks, scores, logits = self.ort_session.run(None, ort_inputs)
masks = masks > self.predictor.model.mask_threshold

elif mode == "mask":
ort_inputs = {
"image_embeddings": self.image_embedding,
"point_coords": prompts["point_coords"],
"point_coords": np.zeros((len(prompts["point_labels"]), 2), dtype=np.float32),
"point_labels": prompts["point_labels"],
"mask_input": prompts["mask_input"],
"has_mask_input": np.ones(1, dtype=np.float32),
"orig_im_size": prompts["orig_im_size"],
}
masks, scores, logits = self.ort_session.run(None, ort_inputs)
masks = masks > self.predictor.model.mask_threshold

elif mode == "both": # both
ort_inputs = {
"image_embeddings": self.image_embedding,
Expand All @@ -136,7 +141,9 @@ def predict_onnx(self, prompts, mode, multimask=True):
"orig_im_size": prompts["orig_im_size"],
}
masks, scores, logits = self.ort_session.run(None, ort_inputs)
masks = masks > self.predictor.model.mask_threshold

else:
raise NotImplementedError("Mode not implemented yet!")
# masks (n, h, w), scores (n,), logits (n, 256, 256)
return masks, scores, logits
return masks[0], scores[0], logits[0]
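The ONNX branch above feeds a SAM decoder exported to ONNX through onnxruntime and then binarises the output with the model's mask threshold. A standalone sketch of the point-prompt case follows; the decoder filename, embedding, image size, and click are assumptions, and the point coordinates are assumed to be already mapped into the resized-image frame (as utils/interact_tools.py does with apply_coords).

import numpy as np
from onnxruntime import InferenceSession

# Assumed decoder path, exported in the style of segment-anything's ONNX export script.
session = InferenceSession("sam_onnx_decoder.onnx")

# Placeholder embedding; the real one comes from the image encoder
# (shape (1, 256, 64, 64) for the standard 1024-pixel input).
image_embedding = np.zeros((1, 256, 64, 64), dtype=np.float32)

ort_inputs = {
    "image_embeddings": image_embedding,
    # One foreground click plus the dummy padding point labelled -1.
    "point_coords": np.array([[[320.0, 240.0], [0.0, 0.0]]], dtype=np.float32),
    "point_labels": np.array([[1, -1]], dtype=np.float32),
    # No prior mask: zero mask_input and has_mask_input = 0.
    "mask_input": np.zeros((1, 1, 256, 256), dtype=np.float32),
    "has_mask_input": np.zeros(1, dtype=np.float32),
    "orig_im_size": np.array([480, 640], dtype=np.float32),
}

masks, scores, low_res_logits = session.run(None, ort_inputs)
masks = masks > 0.0  # the reference SAM predictor uses mask_threshold = 0.0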
33 changes: 20 additions & 13 deletions utils/interact_tools.py
@@ -23,6 +23,7 @@ def __init__(self, sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device):
"""

self.sam_controler = BaseSegmenter(sam_pt_checkpoint, sam_onnx_checkpoint, model_type, device)
self.onnx = model_type == "vit_t"

def first_frame_click(
self,
@@ -38,32 +39,38 @@ def first_frame_click(
"""
# self.sam_controler.set_image(image)
neg_flag = labels[-1]
if neg_flag == 1:
# find neg

if self.onnx:
onnx_coord = np.concatenate([points, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
onnx_label = np.concatenate([labels, np.array([-1])], axis=0)[None, :].astype(np.float32)
onnx_coord = self.sam_controler.predictor.transform.apply_coords(onnx_coord, image.shape[:2]).astype(np.float32)
prompts = {
"point_coords": onnx_coord,
"point_labels": onnx_label,
"orig_im_size": np.array(image.shape[:2], dtype=np.float32),
}

else:
prompts = {
"point_coords": points,
"point_labels": labels,
"orig_im_size": image.shape[:2],
}

if neg_flag == 1:
# find positive
masks, scores, logits = self.sam_controler.predict(
prompts, "point", multimask
)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]
prompts = {
"point_coords": points,
"point_labels": labels,
"mask_input": logit[None, :, :],
}

prompts["mask_input"] = np.expand_dims(logit[None, :, :], 0)
masks, scores, logits = self.sam_controler.predict(
prompts, "both", multimask
)
mask, logit = masks[np.argmax(scores)], logits[np.argmax(scores), :, :]

else:
# find positive
prompts = {
"point_coords": points,
"point_labels": labels,
}
# find neg
masks, scores, logits = self.sam_controler.predict(
prompts, "point", multimask
)

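The if self.onnx branch above pads the clicked points with a dummy (0, 0) point labelled -1 and maps them into the resized-image frame before they reach the exported decoder. A self-contained sketch of that preparation follows, using segment-anything's ResizeLongestSide (the MobileSAM package provides an equivalent transform); the 1024-pixel target length and the example click are assumptions for illustration.

import numpy as np
from segment_anything.utils.transforms import ResizeLongestSide

def prepare_onnx_points(points, labels, orig_hw, target_length=1024):
    # Pad with a dummy point and a -1 label, as the exported decoder expects
    # when no box prompt is supplied.
    coords = np.concatenate([points, np.array([[0.0, 0.0]])], axis=0)[None, :, :]
    lbls = np.concatenate([labels, np.array([-1])], axis=0)[None, :].astype(np.float32)
    # Map pixel coordinates into the frame of the longest-side-resized image.
    transform = ResizeLongestSide(target_length)
    coords = transform.apply_coords(coords, orig_hw).astype(np.float32)
    return coords, lbls

# Example: a single foreground click at (x=320, y=240) on a 480x640 frame.
onnx_coord, onnx_label = prepare_onnx_points(
    np.array([[320.0, 240.0]]), np.array([1]), orig_hw=(480, 640)
)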