Add torch ops for d2go models #1509
base: main
@@ -2473,6 +2473,10 @@ def upsample_nearest2d(context, node):
def tupleunpack(context, node):
    inputs = _get_inputs(context, node, expected=1)
    values = inputs[0]

    if len(node.outputs) == 1:
        values = [values]

    # Node input could have been turned into constant array in @tupleconstruct
    if not isinstance(values, tuple) and not isinstance(values, list):
        values = values.val
@@ -2972,8 +2976,11 @@ def index(context, node):
    # For the multiple-index-axes case, we now assume that all the indices have equal shape
    for index in valid_indices:
        if not is_compatible_symbolic_vector(index.shape, valid_indices[0].shape):
-           raise NotImplementedError("Broadcastable tensor index not supported.")
+           broadcast_inputs = _broadcast_tensors([valid_indices[0], index])
+           index = broadcast_inputs[1]
+           valid_indices[0] = broadcast_inputs[0]
        valid_indices.append(index)

    # First stack the indices together
    indices_rank = valid_indices[0].rank
    indices = mb.stack(values=valid_indices, axis=indices_rank)
@@ -3273,6 +3280,18 @@ def _slice(context, node):
    context.add(res)


def _num_splits_and_sizes(split_sizes):
    if split_sizes.sym_val is not None:
        return len(split_sizes.sym_val), split_sizes.sym_val

    if any_symbolic(split_sizes.shape):
        raise ValueError("Unable to determine number of splits")

    num_splits = len(split_sizes.shape)
    sizes = [get_new_symbol() for _ in range(num_splits)]
    return num_splits, sizes


@register_torch_op(torch_alias=["split_with_sizes"])
def split(context, node):
    inputs = _get_inputs(context, node, expected=3)
@@ -3300,6 +3319,14 @@ def split(context, node):
    else:
        partial_size = mb.mul(x=tmp, y=remainder)
        split_sizes = mb.concat(values=[whole_sizes, partial_size], axis=0)

    num_splits, sizes = _num_splits_and_sizes(split_sizes=split_sizes)
    if num_splits == 1:
        out = mb.identity(x=x, name=node.name)
        context.add(out, node.name)
        return

    res = mb.split(x=x, split_sizes=split_sizes, axis=dim, name=node.name)
    context.add(res, torch_name=node.name)
@@ -3357,6 +3384,13 @@ def to(context, node):
            "Received invalid arguments for PyTorch conversion of op {}".format(node)
        )

    # We have to handle the case where the dtype is not set; it should be inferred from the Tensor dtype.
    # See https://pytorch.org/docs/stable/generated/torch.Tensor.to.html?highlight=#torch.Tensor.to
    if dtype is None:
        out = mb.identity(x=_input, name=node.name)
        context.add(out, node.name)
        return  # TODO: infer from Tensor (spoiler: in this case we care about its f32 => 6)

    torch_dtype = NUM_TO_TORCH_DTYPE[dtype]
    if isinstance(_input, Var) and _input.val is not None:
        _input = _input.val
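For reference, the dtype-is-None branch added above corresponds to Tensor.to calls that specify no dtype (for example a device-only move), which leave the element type unchanged. A small illustrative trace, not taken from the PR:

import torch

class MoveOnly(torch.nn.Module):
    def forward(self, x):
        # .to() with only a device argument keeps the tensor's dtype,
        # so the converter can safely pass the input through unchanged.
        return x.to(torch.device("cpu")) + 1.0

traced = torch.jit.trace(MoveOnly().eval(), torch.rand(1, 3))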
@@ -3799,8 +3833,20 @@ def ceil(context, node):
@register_torch_op
def clamp(context, node):
    inputs = _get_inputs(context, node, expected=3)
-   min_val = inputs[1] if inputs[1] else _np.finfo(_np.float32).min
-   max_val = inputs[2] if inputs[2] else _np.finfo(_np.float32).max
+   if not inputs[1]:
+       min_val = _np.finfo(_np.float32).min
+   else:
+       min_val = inputs[1]
+       if types.builtin_to_string(min_val.dtype).startswith('int'):
+           min_val = mb.cast(x=min_val, dtype='fp32')
+
+   if not inputs[2]:
+       max_val = _np.finfo(_np.float32).max
+   else:
+       max_val = inputs[2]
+       if types.builtin_to_string(max_val.dtype).startswith('int'):
+           max_val = mb.cast(x=max_val, dtype='fp32')

    context.add(mb.clip(x=inputs[0], alpha=min_val, beta=max_val, name=node.name))


@register_torch_op
@@ -4128,6 +4174,11 @@ def _make_tensor(list_of_tensor, name, rank):
        context.add(mb.identity(x=val, name=node.name))
        return

    if inputs[2] is None:
        res = mb.const(val=[val.val], name=node.name)
        context.add(res, torch_name=node.name)
        return

    # Case 2: Create a tensor filled with a single value
    val = val.val  # element val to fill
    msg_prefix = 'torch::tensor {} '.format(node.name)
@@ -4357,3 +4408,162 @@ def scatter_add(context, node):
    updates = inputs[3]
    result = mb.scatter_along_axis(data=data, indices=indices, updates=updates, axis=axis, mode="add", name=node.name)
    context.add(result)


Reviewer comment on roi_align: Are there unit tests for this method? (A test sketch is included after this op, below.)

@register_torch_op()
def roi_align(context, node):
    """
    https://github.com/apple/coremltools/blob/655b3be5cc0d42c3c4fa49f0f0e4a93a26b3e492/mlmodel/format/NeuralNetwork.proto#L2239
    """
    inputs = _get_inputs(context, node)

    x = context[node.inputs[0]]
    input_shape = x.shape  # (B, h_in, w_in, C)
    if len(input_shape) != 4:
        raise ValueError(
            '"CropResize" op: expected input rank 4, got {}'.format(x.rank)
        )
    Hin, Win = input_shape[1:3]

    const_box_info = True
    if context[node.inputs[1]].val is None or context[node.inputs[2]].val is None:
        const_box_info = False

    extrapolation_value = context[node.inputs[2]].val

    # CoreML index information along with boxes
    if const_box_info:
        boxes = context[node.inputs[1]].val
        # CoreML expects boxes/ROI in [N, 1, 5, 1, 1] format
        boxes = boxes.reshape(boxes.shape[0], 1, boxes.shape[1], 1, 1)
    else:
        boxes = inputs[1]
        boxes = mb.reshape(x=boxes, shape=[boxes.shape[0], 1, boxes.shape[1], 1, 1])

    # Get height and width of the crop
    h_out = inputs[3]
    w_out = inputs[4]

    # Torch input format: [B, C, h_in, w_in]
    # CoreML input format: [B, C, h_in, w_in]

    # Crop Resize
    x = mb.crop_resize(
        x=x,
        roi=boxes,
        target_height=h_out.val,
        target_width=w_out.val,
        normalized_coordinates=True,
        spatial_scale=extrapolation_value,
        box_coordinate_mode="CORNERS_HEIGHT_FIRST",
        sampling_mode='OFFSET_CORNERS',
    )

    # CoreML output format: [N, 1, C, h_out, w_out]
    # Torch output format: [N, C, h_out, w_out]
    x = mb.squeeze(x=x, axes=[1])

    context.add(x, torch_name=node.outputs[0])
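Regarding the reviewer's question above about unit tests for roi_align: the following is only a rough sketch of the kind of end-to-end check that could be added, assuming torchvision is available and using the public coremltools convert API. The wrapper module, shapes, tolerances, and output-key handling are illustrative assumptions, not code from this PR.

import numpy as np
import torch
import coremltools as ct
from torchvision.ops import roi_align

class RoiAlignModule(torch.nn.Module):  # hypothetical test wrapper, not from the PR
    def forward(self, x, boxes):
        # boxes rows are [batch_index, x1, y1, x2, y2]
        return roi_align(x, boxes, output_size=(7, 7), spatial_scale=1.0)

x = torch.rand(1, 3, 32, 32)
boxes = torch.tensor([[0.0, 2.0, 2.0, 20.0, 20.0]])
model = RoiAlignModule().eval()
traced = torch.jit.trace(model, (x, boxes))

mlmodel = ct.convert(
    traced,
    inputs=[
        ct.TensorType(name="x", shape=x.shape),
        ct.TensorType(name="boxes", shape=boxes.shape),
    ],
)

expected = model(x, boxes).detach().numpy()
got = mlmodel.predict({"x": x.numpy(), "boxes": boxes.numpy()})
# The Core ML output name depends on the traced graph, so take the single output value.
np.testing.assert_allclose(list(got.values())[0], expected, rtol=1e-3, atol=1e-3)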
@register_torch_op()
def numel(context, node):
    inputs = _get_inputs(context, node, expected=1)
    context.add(mb.reduce_prod(x=inputs[0], name=node.name), torch_name=node.outputs[0])


@register_torch_op()
def nms(context, node):
    inputs = _get_inputs(context, node)
    boxes = inputs[0]

    num_boxes = boxes.shape[1]
    max_boxes = num_boxes  # we set max_boxes to the number of input boxes

    scores = inputs[1]
    iou_threshold = inputs[2]
    boxes = mb.expand_dims(x=boxes, axes=[0])
    scores = mb.expand_dims(x=scores, axes=[0, -1])

    # Follow the TensorFlow op example: use TensorFlow's default value for score_threshold.
    # Core ML does not support float('-inf'), so use the minimum float32 value instead.
    score_threshold = -3.4e38

    _, _, x, _ = mb.non_maximum_suppression(
        boxes=boxes,
        scores=scores,
        iou_threshold=iou_threshold,
        score_threshold=score_threshold,
        max_boxes=max_boxes
    )

    if not is_symbolic(num_boxes):
        x = mb.squeeze(x=x, axes=[0])
        x = mb.slice_by_index(x=x, begin=[0], end=[max_boxes], name=node.name)
    else:
        x = mb.squeeze(x=x, axes=[0], name=node.name)
    context.add(x, torch_name=node.name)
@register_torch_op
def repeat_interleave(context, node):
    inputs = _get_inputs(context, node)

    x = inputs[0]
    reps = inputs[1]
    dim = inputs[2] if inputs[2] else 0

    perm = [] + [axis for axis in range(x.rank) if axis not in []]

    x = mb.transpose(x=x, perm=perm)  # torch.transpose(x, 0, 1)
    x = mb.tile(x=x, reps=reps.val[0], name=node.name)  # torch.repeat(x, size)
    x = mb.reshape(x=x, shape=(-1, x.shape[0]))  # x.view(-1, 2)
    x = mb.transpose(x=x, perm=(-1, 0))  # torch.transpose(x, 0, 1)
    dims = list(x.shape)

    # Implementation of flatten
    total = 1
    start_val = dim
    end_val = -1
    start = len(dims) + start_val if start_val < 0 else start_val
    end = len(dims) + end_val if end_val < 0 else end_val

    if start > len(dims) or end > len(dims) or start < 0 or end < 0:
        raise ValueError(
            "Invalid start and end. (start, end) == ({}, {})".format(start, end_val)
        )
    if start > end:
        raise ValueError(
            "Start must be before end. (start, end) == ({}, {})".format(start, end_val)
        )
    x_shape = mb.shape(x=x)

    shape1 = mb.slice_by_index(x=x_shape, begin=[0], end=[start])
    shape2 = mb.slice_by_index(x=x_shape, begin=[end + 1], end=[len(dims)])

    flatten_dim = -1
    if not any_symbolic(x.shape):
        flatten_dim = 1
        for dim in dims[start: end + 1]:
            flatten_dim *= dim

    shape = mb.concat(values=(shape1, [flatten_dim], shape2), axis=0)
    shape = mb.cast(x=shape, dtype="int32")
    reshape = mb.reshape(x=x, shape=shape, name=node.name)

    context.add(reshape, node.name)


@register_torch_op(override=True)
def narrow(context, node):
    data, dim, start, length = _get_inputs(context, node, expected=4)
    data_shape = mb.shape(x=data).val
    begin = [0] * len(data_shape)
    end = [x for x in data_shape]
    begin[dim.val] = start.val
    end[dim.val] = start.val + length.val
    out = mb.slice_by_index(x=data, begin=begin, end=end)
    context.add(out, torch_name=node.name)
@register_torch_op(torch_alias=["__and_", '__and__'])
def logicaland(context, node):
    inputs = _get_inputs(context, node, expected=2)
    x, y = inputs
    x = mb.cast(x=x, dtype="bool")
    y = mb.cast(x=y, dtype="bool")
    context.add(mb.logical_and(x=x, y=y, name=node.name))

Reviewer thread on logicaland:

dncnbuck: @fukatani Awesome! Thanks, I've removed the op and used the logical_and version and added the torch alias there. A note on the alias [...]. However [...]. Maybe I should just patch this issue instead and leave only the alias, something like [...]

Reviewer: Making this change as a separate PR sounds good. Could someone provide a toy network with a [...]
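The exact snippet referenced in the thread above was not captured in this view. Purely for illustration, keeping only the aliases on the existing logical_and conversion (reusing the same module-level helpers as the rest of this file) might look roughly like:

@register_torch_op(torch_alias=["__and_", "__and__"])
def logical_and(context, node):
    inputs = _get_inputs(context, node, expected=2)
    x, y = inputs
    # Cast to bool so integer/float masks from traced models are accepted.
    x = mb.cast(x=x, dtype="bool")
    y = mb.cast(x=y, dtype="bool")
    context.add(mb.logical_and(x=x, y=y, name=node.name))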
Reviewer comment: Should this just be an inner method of the split method?
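On the question above (which presumably refers to the _num_splits_and_sizes helper introduced earlier in this diff), a rough sketch of the nested variant, reusing the surrounding module's helpers:

@register_torch_op(torch_alias=["split_with_sizes"])
def split(context, node):
    inputs = _get_inputs(context, node, expected=3)

    def _num_splits_and_sizes(split_sizes):
        # Same logic as the module-level helper above, just scoped to split().
        if split_sizes.sym_val is not None:
            return len(split_sizes.sym_val), split_sizes.sym_val
        if any_symbolic(split_sizes.shape):
            raise ValueError("Unable to determine number of splits")
        num_splits = len(split_sizes.shape)
        sizes = [get_new_symbol() for _ in range(num_splits)]
        return num_splits, sizes

    # ... the rest of split() would stay as in the diff above, calling:
    # num_splits, sizes = _num_splits_and_sizes(split_sizes=split_sizes)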