Skip to content

Commit

Permalink
Support torch.amax and torch.amin (#1797)
Browse files Browse the repository at this point in the history
Support torch.amax and torch.amin
  • Loading branch information
nikalra authored Mar 10, 2023
1 parent 9065fdc commit fa190de
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 0 deletions.
20 changes: 20 additions & 0 deletions coremltools/converters/mil/frontend/torch/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -4307,6 +4307,26 @@ def max(context, node):
context.add(values, torch_name=values_name)
context.add(indices, torch_name=indices_name)

def _add_amax_amin(context, node, reduce_op):
    """Convert a torch.amax / torch.amin node into the given MIL reduce op.

    mimic functionality from
    https://pytorch.org/docs/stable/generated/torch.amax.html and
    https://pytorch.org/docs/stable/generated/torch.amin.html

    :param context: conversion context used to look up and register vars.
    :param node: the torch node being converted; must have exactly one output.
    :param reduce_op: the MIL builder op to apply (mb.reduce_max or mb.reduce_min).
    """
    # Both amax and amin produce a single reduced tensor.
    assert len(node.outputs) == 1

    # Inputs: (input tensor, dim(s), optional keepdim).
    all_inputs = _get_inputs(context, node, expected=[2, 3])
    _input = all_inputs[0]
    # Normalize `dim` to a list of axes: a scalar int becomes a one-element
    # list, an iterable of dims is materialized as a list.
    dim_val = all_inputs[1].val
    dim = [dim_val] if isinstance(dim_val, int) else list(dim_val)
    # keepdim defaults to False when the optional third input is absent;
    # when present it is passed through as-is (a Var) to the MIL op.
    keepdim = all_inputs[2] if len(all_inputs) == 3 else False

    context.add(reduce_op(x=_input, axes=dim, keep_dims=keepdim), torch_name=node.outputs[0])

@register_torch_op
def amax(context, node):
    """Lower torch.amax to a MIL reduce_max via the shared helper."""
    _add_amax_amin(context, node, mb.reduce_max)

@register_torch_op
def amin(context, node):
    """Lower torch.amin to a MIL reduce_min via the shared helper."""
    _add_amax_amin(context, node, mb.reduce_min)


@register_torch_op
def argsort(context, node):
Expand Down
35 changes: 35 additions & 0 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2865,6 +2865,41 @@ def forward(self, x, y):
input_shapes, model, backend=backend, compute_unit=compute_unit
)

class TestAMaxAMin(TorchBaseTest):
    """Conversion tests for torch.amax and torch.amin across shapes, dims, and keepdim."""

    @pytest.mark.parametrize(
        "compute_unit, backend, input_shapes, mode, reduce_dim, keepdim",
        itertools.product(
            compute_units,
            backends,
            [
                [(2, 5, 7, 3)],
                [(3, 2, 9)],
                [(1,)],
            ],
            ["minimum", "maximum"],
            # Dims are clamped inside the model so they stay valid for
            # lower-rank inputs in the parametrized shape list.
            [0, 1, 2, 3, [0, 1], [0, 1, 2], [0, 1, 2, 3]],
            [True, False],
        ),
    )
    def test_minimum_maximum(self, compute_unit, backend, input_shapes, mode, reduce_dim, keepdim):
        class TestModel(torch.nn.Module):
            def forward(self, input):
                # Clamp the parametrized dim(s) so every combination is
                # valid for the current input rank.
                if isinstance(reduce_dim, int):
                    reduce_dim_clamped = min(input.dim() - 1, reduce_dim)
                else:
                    reduce_dim_clamped = reduce_dim[:input.dim()]
                if mode == "minimum":
                    return torch.amin(input, reduce_dim_clamped, keepdim)
                elif mode == "maximum":
                    return torch.amax(input, reduce_dim_clamped, keepdim)
                else:
                    raise ValueError("Unsupported mode: {mode}".format(mode=mode))

        model = TestModel()
        self.run_compare_torch(
            input_shapes, model, backend=backend, compute_unit=compute_unit
        )


class TestPoolSymbolicInput(TorchBaseTest):
def test_max_pool(self):
Expand Down

0 comments on commit fa190de

Please sign in to comment.