Skip to content

Commit

Permalink
Fix ONNX export for MobileViT for segmentation (#1128)
Browse files Browse the repository at this point in the history
  • Loading branch information
regisss authored Jul 3, 2023
1 parent 1939df0 commit bc5f825
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 13 deletions.
9 changes: 1 addition & 8 deletions optimum/exporters/onnx/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,7 @@ class OnnxConfig(ExportConfig, ABC):
"feature-extraction": OrderedDict({"last_hidden_state": {0: "batch_size", 1: "sequence_length"}}),
"fill-mask": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"image-classification": OrderedDict({"logits": {0: "batch_size"}}),
# TODO: Is this the same thing as semantic-segmentation?
"image-segmentation": OrderedDict(
{
"logits": {0: "batch_size", 1: "num_queries"},
"pred_boxes": {0: "batch_size", 1: "num_queries"},
"pred_masks": {0: "batch_size", 1: "num_queries"},
}
),
"image-segmentation": OrderedDict({"logits": {0: "batch_size", 1: "num_labels", 2: "height", 3: "width"}}),
"image-to-text": OrderedDict({"logits": {0: "batch_size", 1: "sequence_length"}}),
"mask-generation": OrderedDict({"logits": {0: "batch_size"}}),
"masked-im": OrderedDict(
Expand Down
13 changes: 9 additions & 4 deletions optimum/exporters/onnx/model_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ class ConvNextOnnxConfig(ViTOnnxConfig):


class MobileViTOnnxConfig(ViTOnnxConfig):
pass
ATOL_FOR_VALIDATION = 1e-4


class RegNetOnnxConfig(ViTOnnxConfig):
Expand All @@ -588,9 +588,14 @@ class DetrOnnxConfig(ViTOnnxConfig):
DEFAULT_ONNX_OPSET = 12

@property
def inputs(self) -> Dict[str, Dict[int, str]]:
# TODO: is pixel mask needed?
return {**super().inputs, "pixel_mask": {0: "batch_size"}}
def outputs(self) -> Dict[str, Dict[int, str]]:
if self.task == "image-segmentation":
return {
"logits": {0: "batch_size", 1: "num_queries"},
"pred_masks": {0: "batch_size", 1: "num_queries"},
}
else:
return super().outputs


class YolosOnnxConfig(ViTOnnxConfig):
Expand Down
4 changes: 3 additions & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ class TasksManager:
"object-detection": "AutoModelForObjectDetection",
"question-answering": "AutoModelForQuestionAnswering",
"image-classification": "AutoModelForImageClassification",
"image-segmentation": "AutoModelForImageSegmentation",
"image-segmentation": ("AutoModelForImageSegmentation", "AutoModelForSemanticSegmentation"),
"mask-generation": "AutoModel",
"masked-im": "AutoModelForMaskedImageModeling",
"semantic-segmentation": "AutoModelForSemanticSegmentation",
Expand Down Expand Up @@ -632,6 +632,7 @@ class TasksManager:
"mobilevit": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-segmentation",
onnx="MobileViTOnnxConfig",
),
"mobilenet-v1": supported_tasks_mapping(
Expand Down Expand Up @@ -766,6 +767,7 @@ class TasksManager:
"segformer": supported_tasks_mapping(
"feature-extraction",
"image-classification",
"image-segmentation",
"semantic-segmentation",
onnx="SegformerOnnxConfig",
),
Expand Down

0 comments on commit bc5f825

Please sign in to comment.