Skip to content

Commit

Permalink
Add PyTorch op: tupleindex (#1772)
Browse files Browse the repository at this point in the history
  • Loading branch information
TobyRoseman authored Feb 20, 2023
1 parent a0effdf commit 51b0003
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 5 deletions.
4 changes: 2 additions & 2 deletions coremltools/converters/mil/frontend/torch/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,12 +241,12 @@ def convert(self):
# This will hold the converted model.
prog = self._prog

# Construct placeholder for input to ssa function
# Construct placeholder for input to SSA function
# This is where input renaming occurs
ssa_func_inputs = OrderedDict()
for index, (name, spec) in enumerate(self.graph.inputs.items()):
placeholder = self._create_placeholder(spec)
# Set ssa function input name to user defined name if provided.
# Set SSA function input name to user defined name if provided.
if spec.name is not None:
name = spec.name
self.inputs[index].name = name
Expand Down
6 changes: 6 additions & 0 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -5632,3 +5632,9 @@ def torchvision_nms(context, node):
valid_indices = mb.gather(x=indices, indices=range, axis=0)
valid_indices = mb.cast(x=valid_indices, dtype="int32", name=node.name)
context.add(valid_indices)


@register_torch_op
def tupleindex(context, node):
tuple_input, index_input = _get_inputs(context, node, expected=2)
context.add(tuple_input[index_input.val], node.name)
25 changes: 25 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8172,3 +8172,28 @@ def forward(self, x):
self.run_compare_torch(
input_shape, UnfoldModel(), backend=backend, compute_unit=compute_unit
)


class TestTupleIndex(TorchBaseTest):
@pytest.mark.parametrize(
"compute_unit, backend",
itertools.product(compute_units, backends,),
)
def test_tuple_index(self, compute_unit, backend):
class InnerModel(nn.Module):
def forward(self,x):
return (torch.tensor([0]), torch.tensor([1]))

class OuterModel(nn.Module):
def __init__(self):
super().__init__()
self.innermodel = torch.jit.trace(InnerModel().eval(), x)

def forward(self, x):
inner = self.innermodel(x)
return inner[0]

x = torch.rand(1, 3, 640, 640)
self.run_compare_torch(x, OuterModel(),
input_as_shape=False, use_scripting=True,
backend=backend, compute_unit=compute_unit)
3 changes: 0 additions & 3 deletions coremltools/converters/mil/mil/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,9 +830,6 @@ def get_dot_string(


class Function(Block):
"""
"""

def __init__(self, inputs, opset_version=None):
"""
inputs: str -> placeholder
Expand Down

0 comments on commit 51b0003

Please sign in to comment.