Support flexible shape fill_ for torch. (#1897)
fukatani authored Jul 20, 2023
1 parent ce4f85b commit 57270f3
Showing 2 changed files with 90 additions and 10 deletions.
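For context, a minimal end-to-end sketch of the pattern this commit enables: converting a traced model that calls in-place fill_ on a tensor whose shape tracks a flexible input. It mirrors the new tests below; the model name, example shapes, and upper_bound value are illustrative. Note that the mlprogram backend requires a finite upper_bound on ct.RangeDim (the tests pass -1, i.e. unbounded, only for the neuralnetwork backend).

import coremltools as ct
import torch
import torch.nn as nn

class FillModel(nn.Module):
    def forward(self, x):
        # Allocate a tensor with x's (runtime-determined) shape and
        # fill it in place with a constant scalar.
        y = torch.empty(x.shape)
        y.fill_(0.2)
        return y

traced = torch.jit.trace(FillModel().eval(), torch.rand(4))

# The single input dimension stays flexible at prediction time.
mlmodel = ct.convert(
    traced,
    inputs=[ct.TensorType(shape=(ct.RangeDim(upper_bound=10),))],
)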
10 changes: 10 additions & 0 deletions coremltools/converters/mil/frontend/torch/ops.py
@@ -3376,6 +3376,16 @@ def _internal_op_tensor_inplace_fill(context, node):
     data = context[node.inputs[0]]
     fill_scalar = context[node.inputs[1]]
 
+    if len(node.inputs) == 2 and fill_scalar.val is not None:
+        shape = mb.shape(x=data)
+        if isinstance(fill_scalar.val, _np.ndarray):
+            fill = mb.fill(shape=shape, value=fill_scalar.val.item())
+        else:
+            fill = mb.fill(shape=shape, value=fill_scalar)
+        casted = mb.cast(x=fill, dtype=TYPE_TO_DTYPE_STRING[data.dtype], name=node.name)
+        context.add(casted)
+        return
+
     begin, end, stride, begin_mask, end_mask, squeeze_mask = _get_slice_params(
         context, data, node.inputs[2:]
     )
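As a standalone illustration (not part of the diff), the branch above lowers a no-slice fill_ to roughly this MIL pattern; the program below is a hand-written sketch using the public Builder API, not actual converter output:

from coremltools.converters.mil import Builder as mb

@mb.program(input_specs=[mb.TensorSpec(shape=(2, 3))])
def prog(x):
    shape = mb.shape(x=x)                   # input shape resolved at runtime
    fill = mb.fill(shape=shape, value=0.2)  # tensor of that shape, filled with the scalar
    return mb.cast(x=fill, dtype="fp32")    # cast to the filled tensor's dtype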
90 changes: 80 additions & 10 deletions coremltools/converters/mil/frontend/torch/test/test_torch_ops.py
@@ -6005,46 +6005,116 @@ def forward(self, x):

 class TestFill(TorchBaseTest):
     @pytest.mark.parametrize(
-        "compute_unit, backend, rank",
+        "compute_unit, backend, rank, dynamic, fill_scalar, src_dtype",
         itertools.product(
             compute_units,
             backends,
             [1, 3],
+            [False, True],
+            [0.2, torch.tensor(float("-inf")), torch.tensor(2)],
+            [torch.int32, torch.float32],
         ),
     )
-    def test_fill_(self, compute_unit, backend, rank):
+    def test_fill_(self, compute_unit, backend, rank, dynamic, fill_scalar, src_dtype):
+        if src_dtype == torch.int32 and fill_scalar == torch.tensor(float("-inf")):
+            pytest.skip("float(-inf) cannot be cast to int.")
+
         input_shape = np.random.randint(low=2, high=6, size=rank)
         input_shape = tuple(input_shape)
 
         class FillModel(nn.Module):
             def forward(self, x):
-                y = torch.empty(x.shape)
-                y.fill_(0.2)
+                y = torch.empty(x.shape, dtype=src_dtype)
+                y.fill_(fill_scalar)
                 return y
 
         model = FillModel()
-        self.run_compare_torch(input_shape, model, backend=backend, compute_unit=compute_unit)
+        if dynamic:
+            upper_bound = 10 if backend[0] == "mlprogram" else -1
+            if rank == 1:
+                converter_input_type = [
+                    ct.TensorType(
+                        shape=(
+                            ct.RangeDim(upper_bound=upper_bound),
+                        )
+                    ),
+                ]
+            else:
+                converter_input_type = [
+                    ct.TensorType(
+                        shape=(
+                            ct.RangeDim(upper_bound=upper_bound),
+                            ct.RangeDim(upper_bound=upper_bound),
+                            ct.RangeDim(upper_bound=upper_bound),
+                        )
+                    ),
+                ]
+        else:
+            converter_input_type = None
+
+        self.run_compare_torch(
+            input_shape,
+            model,
+            converter_input_type=converter_input_type,
+            backend=backend,
+            compute_unit=compute_unit
+        )

     @pytest.mark.parametrize(
-        "compute_unit, backend, rank",
+        "compute_unit, backend, rank, dynamic, fill_scalar, src_dtype",
         itertools.product(
             compute_units,
             backends,
             [1, 3],
+            [False, True],
+            [0.2, torch.tensor(float("-inf")), torch.tensor(2)],
+            [torch.int32, torch.float32],
         ),
     )
-    def test_fill__2(self, compute_unit, backend, rank):
+    def test_fill__2(self, compute_unit, backend, rank, dynamic, fill_scalar, src_dtype):
+        if src_dtype == torch.int32 and fill_scalar == torch.tensor(float("-inf")):
+            pytest.skip("float(-inf) cannot be cast to int.")
+
         input_shape = np.random.randint(low=2, high=6, size=rank)
         input_shape = tuple(input_shape)
 
         class FillModel(nn.Module):
             def forward(self, x):
-                y = torch.empty(x.shape)
-                y.fill_(0.2)
+                y = torch.empty(x.shape, dtype=src_dtype)
+                y.fill_(fill_scalar)
                 return y + 1
 
         model = FillModel()
-        self.run_compare_torch(input_shape, model, backend=backend, compute_unit=compute_unit)
+        if dynamic:
+            upper_bound = 10 if backend[0] == "mlprogram" else -1
+            if rank == 1:
+                converter_input_type = [
+                    ct.TensorType(
+                        shape=(
+                            ct.RangeDim(upper_bound=upper_bound),
+                        )
+                    ),
+                ]
+            else:
+                converter_input_type = [
+                    ct.TensorType(
+                        shape=(
+                            ct.RangeDim(upper_bound=upper_bound),
+                            ct.RangeDim(upper_bound=upper_bound),
+                            ct.RangeDim(upper_bound=upper_bound),
+                        )
+                    ),
+                ]
+        else:
+            converter_input_type = None
+
+        self.run_compare_torch(
+            input_shape,
+            model,
+            converter_input_type=converter_input_type,
+            backend=backend,
+            compute_unit=compute_unit
+        )


 class TestCopy(TorchBaseTest):
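To run just these tests locally, a typical invocation (standard pytest -k filtering; adjust the path to your checkout) is:

python -m pytest coremltools/converters/mil/frontend/torch/test/test_torch_ops.py -k TestFill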
