-
Notifications
You must be signed in to change notification settings - Fork 5.5k
/
onnx.py
144 lines (119 loc) · 5.68 KB
/
onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Tuple
from ..modeling import Sam
from .amg import calculate_stability_score
class SamOnnxModel(nn.Module):
    """
    This model should not be called directly, but is used in ONNX export.
    It combines the prompt encoder, mask decoder, and mask postprocessing of Sam,
    with some functions modified to enable model tracing. Also supports extra
    options controlling what information is returned. See the ONNX export
    script for details.

    Args:
        model: The Sam model whose prompt encoder and mask decoder are wrapped.
        return_single_mask: If True, ``forward`` keeps only one mask (and its
            score) per batch element, chosen by ``select_masks``.
        use_stability_score: If True, the scores returned by ``forward`` are
            replaced with the output of ``calculate_stability_score`` on the
            low-resolution masks.
        return_extra_metrics: If True, ``forward`` additionally returns
            stability scores and thresholded areas of the upscaled masks.
        stability_score_offset: Offset forwarded to
            ``calculate_stability_score``. Defaults to 1.0.
    """

    def __init__(
        self,
        model: "Sam",
        return_single_mask: bool,
        use_stability_score: bool = False,
        return_extra_metrics: bool = False,
        stability_score_offset: float = 1.0,
    ) -> None:
        super().__init__()
        self.mask_decoder = model.mask_decoder
        self.model = model
        self.img_size = model.image_encoder.img_size
        self.return_single_mask = return_single_mask
        self.use_stability_score = use_stability_score
        self.stability_score_offset = stability_score_offset
        self.return_extra_metrics = return_extra_metrics

    @staticmethod
    def resize_longest_image_size(
        input_image_size: torch.Tensor, longest_side: int
    ) -> torch.Tensor:
        """
        Scale a size tensor so that its largest entry equals ``longest_side``.

        Args:
            input_image_size: Tensor holding the image size
                (assumed to be (H, W) -- TODO confirm against the caller).
            longest_side: Target length for the largest entry.

        Returns:
            The scaled size as an int64 tensor, rounded half up.
        """
        input_image_size = input_image_size.to(torch.float32)
        scale = longest_side / torch.max(input_image_size)
        transformed_size = scale * input_image_size
        # floor(x + 0.5) rounds half up using only tensor ops.
        transformed_size = torch.floor(transformed_size + 0.5).to(torch.int64)
        return transformed_size

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """
        Embed point prompts without data-dependent control flow.

        Coordinates are offset by 0.5, divided by ``self.img_size`` and passed
        through the prompt encoder's positional encoding. The label-specific
        embeddings are then added through boolean masks instead of indexing:
        positions labelled -1 have their positional encoding zeroed and receive
        ``not_a_point_embed``; positions labelled ``i`` receive
        ``point_embeddings[i]``.

        Args:
            point_coords: Point coordinates in the model's input frame
                (assumed shape (B, N, 2) -- TODO confirm).
            point_labels: One label per point (assumed shape (B, N)).

        Returns:
            The sparse prompt embedding, same shape as the positional encoding.
        """
        point_coords = point_coords + 0.5
        point_coords = point_coords / self.img_size
        point_embedding = self.model.prompt_encoder.pe_layer._pe_encoding(point_coords)
        # Broadcast labels over the embedding dimension so they can act as masks.
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.model.prompt_encoder.not_a_point_embed.weight * (
            point_labels == -1
        )

        for i in range(self.model.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.model.prompt_encoder.point_embeddings[
                i
            ].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_mask: torch.Tensor, has_mask_input: torch.Tensor) -> torch.Tensor:
        """
        Embed the mask prompt, blending with the "no mask" embedding.

        ``has_mask_input`` acts as a numeric switch (1 selects the downscaled
        input mask, 0 selects ``no_mask_embed``), so that both branches are
        present in the traced graph.

        Args:
            input_mask: Mask prompt fed to the prompt encoder's
                ``mask_downscaling``.
            has_mask_input: Blend weight; expected to be 0 or 1.

        Returns:
            The dense prompt embedding.
        """
        mask_embedding = has_mask_input * self.model.prompt_encoder.mask_downscaling(input_mask)
        mask_embedding = mask_embedding + (
            1 - has_mask_input
        ) * self.model.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        return mask_embedding

    def mask_postprocessing(self, masks: torch.Tensor, orig_im_size: torch.Tensor) -> torch.Tensor:
        """
        Resize masks back to the original image size.

        The masks are upsampled to ``(img_size, img_size)``, cropped to the
        size returned by ``resize_longest_image_size`` for the original image,
        and finally interpolated to ``orig_im_size``.

        Args:
            masks: Masks to resize (4-D, as required by bilinear
                ``F.interpolate``).
            orig_im_size: Two-element tensor with the original image size;
                element 0 is used as height and element 1 as width.

        Returns:
            The masks at the original image resolution.
        """
        masks = F.interpolate(
            masks,
            size=(self.img_size, self.img_size),
            mode="bilinear",
            align_corners=False,
        )

        prepadded_size = self.resize_longest_image_size(orig_im_size, self.img_size).to(torch.int64)
        masks = masks[..., : prepadded_size[0], : prepadded_size[1]]  # type: ignore

        orig_im_size = orig_im_size.to(torch.int64)
        h, w = orig_im_size[0], orig_im_size[1]
        masks = F.interpolate(masks, size=(h, w), mode="bilinear", align_corners=False)
        return masks

    def select_masks(
        self, masks: torch.Tensor, iou_preds: torch.Tensor, num_points: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Keep a single mask (and its score) per batch element.

        Args:
            masks: Candidate masks, indexed as (batch, candidate, H, W).
            iou_preds: One score per candidate, indexed as (batch, candidate).
            num_points: Number of prompt points.

        Returns:
            ``(masks, iou_preds)`` with the candidate dimension reduced to 1.
        """
        # Determine if we should return the multiclick mask or not from the number of points.
        # The reweighting is used to avoid control flow: candidate 0 gets a
        # bonus of 1000 * (num_points - 2.5), which is positive for three or
        # more points and negative otherwise (this assumes the scores are much
        # smaller than that bonus).
        score_reweight = torch.tensor(
            [[1000] + [0] * (self.model.mask_decoder.num_mask_tokens - 1)]
        ).to(iou_preds.device)
        score = iou_preds + (num_points - 2.5) * score_reweight
        best_idx = torch.argmax(score, dim=1)

        # One batch index, built once on the masks' device, serves both lookups.
        batch_idx = torch.arange(masks.shape[0], device=masks.device)
        masks = masks[batch_idx, best_idx, :, :].unsqueeze(1)
        iou_preds = iou_preds[batch_idx, best_idx].unsqueeze(1)

        return masks, iou_preds

    @torch.no_grad()
    def forward(
        self,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        mask_input: torch.Tensor,
        has_mask_input: torch.Tensor,
        orig_im_size: torch.Tensor,
    ) -> Tuple[torch.Tensor, ...]:
        """
        Run prompt encoding, mask decoding and mask postprocessing.

        Args:
            image_embeddings: Image embeddings passed to the mask decoder.
            point_coords: Point prompt coordinates (see ``_embed_points``).
            point_labels: Point prompt labels (see ``_embed_points``).
            mask_input: Mask prompt (see ``_embed_masks``).
            has_mask_input: 0/1 switch for the mask prompt.
            orig_im_size: Original image size used to resize the output masks.

        Returns:
            ``(upscaled_masks, scores, masks)``, or, when
            ``return_extra_metrics`` is set,
            ``(upscaled_masks, scores, stability_scores, areas, masks)``.
            ``masks`` are the low-resolution decoder outputs.
        """
        sparse_embedding = self._embed_points(point_coords, point_labels)
        dense_embedding = self._embed_masks(mask_input, has_mask_input)

        masks, scores = self.model.mask_decoder.predict_masks(
            image_embeddings=image_embeddings,
            image_pe=self.model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embedding,
            dense_prompt_embeddings=dense_embedding,
        )

        if self.use_stability_score:
            scores = calculate_stability_score(
                masks, self.model.mask_threshold, self.stability_score_offset
            )

        if self.return_single_mask:
            masks, scores = self.select_masks(masks, scores, point_coords.shape[1])

        upscaled_masks = self.mask_postprocessing(masks, orig_im_size)

        if self.return_extra_metrics:
            stability_scores = calculate_stability_score(
                upscaled_masks, self.model.mask_threshold, self.stability_score_offset
            )
            # Pixel count above the mask threshold, summed over the last two dims.
            areas = (upscaled_masks > self.model.mask_threshold).sum(-1).sum(-1)
            return upscaled_masks, scores, stability_scores, areas, masks

        return upscaled_masks, scores, masks