[dicp]Support some ops for stable-diffusion. #467

Merged (7 commits) on Dec 6, 2023
2 changes: 1 addition & 1 deletion dicp/dicp/dynamo_bridge/op_transformer.py
@@ -39,7 +39,7 @@ def placeholder(self, target: 'Target', args: Tuple[Argument, ...], kwargs: Dict
for idx, dim in enumerate(fake_tensor.shape):
if isinstance(dim, torch.SymInt):
st = dim.node.str()
if not st in self.sym_in_args:
if st not in self.sym_in_args:
self.sym_in_args[st] = (proxy, idx)
return proxy

61 changes: 37 additions & 24 deletions dicp/dicp/vendor/TopsGraph/codegen/enflame.py
@@ -496,8 +496,7 @@ def gen_main_func(self):
main_body.writeline("")
for i in range(0, len(self.input_args)):
itensor = self.input_args[i].meta['val']
main_body.writeline('arg' + str(i) + ' = ' +
self.gen_random_tensor(itensor))
main_body.writeline('arg' + str(i) + ' = ' + self.gen_random_tensor(itensor))

args = []
for i in range(len(self.input_args)):
@@ -584,16 +583,18 @@ def Abs(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Abs({x});"

@staticmethod
def make_const_if_scalar(op_var, value, dtype=torch.float32, count=0):
def make_const(op_var, value, dtype=torch.float32, count=0):
assert isinstance(value, (numbers.Number, list, tuple, str))
src_code = ""
if isinstance(value, numbers.Number):
src_code = f"{cxx_type_set[dtype]} {op_var}_const_value{count} = static_cast<{cxx_type_set[dtype]}>({value});\n"
value = f"{op_var}_const{count}"
const_type = dtype if dtype != torch.float16 else torch.float32
src_code += f"builder::Op {value} = builder::Const(hlir_builder, static_cast<void *>(&{op_var}_const_value{count}), builder::Type({{1}}, {type_set[const_type]}));\n"
if dtype == torch.float16:
src_code += f"{value} = builder::Convert({value}, builder::Type({{1}}, {type_set[dtype]}));\n"
return src_code, value
if isinstance(value, str):
return src_code, value
elif isinstance(value, numbers.Number):
src_code += f"builder::Op {op_var}_const{count} = builder::Const<{cxx_type_set[dtype]}>(hlir_builder, static_cast<{cxx_type_set[dtype]}>({value}), builder::Type({{1}}, {type_set[dtype]}));\n"
elif isinstance(value, (list, tuple)):
src_code += f"std::vector<{cxx_type_set[dtype]}> {op_var}_const_value{count} = {{{', '.join(map(str, value))}}};\n"
src_code += f"builder::Op {op_var}_const{count} = builder::Const<{cxx_type_set[dtype]}>(hlir_builder, {op_var}_const_value{count}, builder::Type({{{len(value)}}}, {type_set[dtype]}));\n"

return src_code, f"{op_var}_const{count}"
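
For reference, here is a minimal, self-contained sketch of what the new `make_const` helper generates. The `cxx_type_set` / `type_set` tables live elsewhere in enflame.py and are not shown in this diff, so the entries below are assumptions for illustration only:

```python
import numbers
import torch

# Hypothetical stand-ins for the real type tables in enflame.py (not part of this diff).
cxx_type_set = {torch.float32: "float"}
type_set = {torch.float32: "builder::PrimitiveType::F32()"}

def make_const_sketch(op_var, value, dtype=torch.float32, count=0):
    """Mirror of the new make_const branching: strings pass through untouched,
    scalars become a one-element builder::Const, lists/tuples a vector-backed one."""
    if isinstance(value, str):
        return "", value
    if isinstance(value, numbers.Number):
        code = (f"builder::Op {op_var}_const{count} = builder::Const<{cxx_type_set[dtype]}>"
                f"(hlir_builder, static_cast<{cxx_type_set[dtype]}>({value}), "
                f"builder::Type({{1}}, {type_set[dtype]}));\n")
    else:  # list or tuple
        code = (f"std::vector<{cxx_type_set[dtype]}> {op_var}_const_value{count} = "
                f"{{{', '.join(map(str, value))}}};\n"
                f"builder::Op {op_var}_const{count} = builder::Const<{cxx_type_set[dtype]}>"
                f"(hlir_builder, {op_var}_const_value{count}, "
                f"builder::Type({{{len(value)}}}, {type_set[dtype]}));\n")
    return code, f"{op_var}_const{count}"

print(make_const_sketch("v0", 2.0)[0])         # scalar -> one-element constant
print(make_const_sketch("v1", [2.0, 3.0])[0])  # list   -> vector-backed constant
```

Compared with the old `make_const_if_scalar`, the templated `builder::Const<T>` call drops the separate float16 create-then-Convert path and adds list/tuple (plus pass-through string) handling.
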

@staticmethod
def make_type(op_var, dtype, shape=[1], count=0):
@@ -605,7 +606,7 @@ def make_type(op_var, dtype, shape=[1], count=0):
@staticmethod
# TODO: mul + add scaled_y should be handled in conversion
def Add(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Add({x}, {y});"
return src_code

@@ -617,21 +618,21 @@ def Convert(op_var, shape, dtype, x, y, **kwargs_list):

@staticmethod
def Div(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Div({x}, {y});"
return src_code

@staticmethod
def Sub(op_var, shape, dtype, x, y, **kwargs_list):
src_code_x, x = EnflameOverrides.make_const_if_scalar(op_var, x, dtype)
src_code_y, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code_x, x = EnflameOverrides.make_const(op_var, x, dtype)
src_code_y, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code = src_code_x + src_code_y
src_code += f"builder::Op {op_var} = builder::Sub({x}, {y});"
return src_code

@staticmethod
def Mul(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Mul({x}, {y});"
return src_code

@@ -663,19 +664,19 @@ def Less(op_var, shape, dtype, x, y, **kwargs_list):

@staticmethod
def Equal(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y)
src_code, y = EnflameOverrides.make_const(op_var, y)
src_code += f"builder::Op {op_var} = builder::Equal({x}, {y});"
return src_code

@staticmethod
def LessEqual(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y)
src_code, y = EnflameOverrides.make_const(op_var, y)
src_code += f"builder::Op {op_var} = builder::LessEqual({x}, {y});"
return src_code

@staticmethod
def NotEqual(op_var, shape, dtype, data_type, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(
src_code, y = EnflameOverrides.make_const(
op_var, y, data_type)
src_code += f"builder::Op {op_var} = builder::NotEqual({x}, {y});"
return src_code
@@ -690,7 +691,7 @@ def Neg(op_var, shape, dtype, x, **kwargs_list):

@staticmethod
def Pow(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(op_var, y, dtype)
src_code, y = EnflameOverrides.make_const(op_var, y, dtype)
src_code += f"builder::Op {op_var} = builder::Pow({x}, {y});"
return src_code

Expand All @@ -702,10 +703,22 @@ def Exp(op_var, shape, dtype, x, **kwargs_list):
def Sqrt(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Sqrt({x});"

@staticmethod
def Sin(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Sin({x});"

@staticmethod
def Cos(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Cos({x});"

@staticmethod
def Relu(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Relu({x});"

@staticmethod
def Erf(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Erf({x});"

@staticmethod
def Sigmoid(op_var, shape, dtype, x, **kwargs_list):
return f"builder::Op {op_var} = builder::Sigmoid({x});"
@@ -802,14 +815,14 @@ def Expand(op_var, shape, dtype, x, new_shape, **kwargs_list):

@staticmethod
def Squeeze(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(
src_code, y = EnflameOverrides.make_const(
op_var, y, torch.int64)
src_code += f"builder::Op {op_var} = builder::Squeeze({x}, {y});"
return src_code

@staticmethod
def Unsqueeze(op_var, shape, dtype, x, y, **kwargs_list):
src_code, y = EnflameOverrides.make_const_if_scalar(
src_code, y = EnflameOverrides.make_const(
op_var, y, torch.int64)
src_code += f"builder::Op {op_var} = builder::Unsqueeze({x}, {y});"
return src_code
@@ -847,9 +860,9 @@ def SliceInDim(op_var, shape, dtype, x, dim, start, end, step, **kwargs_list):

@staticmethod
def SliceScatter(op_var, shape, dtype, x, y, dim, start, end, step, **kwargs_list):
src_code_index, op_start_index = EnflameOverrides.make_const_if_scalar(
src_code_index, op_start_index = EnflameOverrides.make_const(
op_var, 0, torch.int64, 0)
src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const_if_scalar(
src_code_index_dim, op_start_index_dim = EnflameOverrides.make_const(
op_var, start, torch.int64, 1)
src_code = src_code_index + src_code_index_dim
src_code += f"builder::Op {op_var} = builder::DynamicUpdateSlice({x}, {y}, {{{', '.join([op_start_index_dim if i == dim else op_start_index for i in range(len(shape))])}}});"
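
As a side note, a tiny sketch (plain Python, with illustrative placeholder names) of how the start indices for `builder::DynamicUpdateSlice` are assembled in `SliceScatter`: the `start` constant goes in the scattered dimension and the zero constant everywhere else.

```python
def start_index_ops(rank, dim, zero_op="op_const0", start_op="op_const1"):
    # Mirrors the join expression in SliceScatter: start_op at position `dim`, zero_op elsewhere.
    return ", ".join(start_op if i == dim else zero_op for i in range(rank))

# e.g. a rank-3 update along dim=1:
assert start_index_ops(3, 1) == "op_const0, op_const1, op_const0"
```
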
20 changes: 20 additions & 0 deletions dicp/dicp/vendor/TopsGraph/conversion.py
@@ -148,10 +148,30 @@ def Rsqrt(self, *args, **kwargs):
def Exp(self, *args, **kwargs):
return self.get_proxy(tops_op.Exp, args, kwargs)

@register_conversion(aten.sin)
def Sin(self, *args, **kwargs):
return self.get_proxy(tops_op.Sin, args, kwargs)

@register_conversion(aten.cos)
def Cos(self, *args, **kwargs):
return self.get_proxy(tops_op.Cos, args, kwargs)

@register_conversion(aten.relu)
def Relu(self, *args, **kwargs):
return self.get_proxy(tops_op.Relu, args, kwargs)

@register_conversion(aten.erf)
def Erf(self, *args, **kwargs):
return self.get_proxy(tops_op.Erf, args, kwargs)

@register_conversion(aten.split.Tensor)
def Split(self, a, size, dim=0, **kwargs):
in_shape = a.node.meta["val"].shape
dim = dim % len(in_shape)
sections = (in_shape[dim] + size - 1) // size
splits = (self.get_proxy(tops_op.SliceInDim, (a, dim, i * size, min((i + 1) * size, in_shape[dim]), 1)) for i in range(sections))
return self.get_proxy(tops_op.MakeTuple, tuple(splits))
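
A quick, standalone illustration of the slicing arithmetic used by the new Split conversion (pure Python; the helper name is illustrative, not part of the codebase): `aten.split.Tensor` with chunk size `size` is decomposed into ceil(length / size) SliceInDim ranges along `dim`, with the last chunk clipped to the dimension's length.

```python
def split_ranges(length, size):
    # Same arithmetic as the conversion above: ceiling division for the section count,
    # then [start, end) bounds per chunk, clipping the final chunk.
    sections = (length + size - 1) // size
    return [(i * size, min((i + 1) * size, length)) for i in range(sections)]

# Splitting a dimension of length 10 into chunks of 4 yields three slices:
assert split_ranges(10, 4) == [(0, 4), (4, 8), (8, 10)]
```
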

@register_conversion(aten.sum)
def ReduceSum(self, a, *args, **kwargs):
in_dtype = a.node.meta["val"].dtype
29 changes: 23 additions & 6 deletions dicp/dicp/vendor/TopsGraph/tops_op.py
@@ -142,13 +142,34 @@ def __init__(self, a):
self.torch_op = aten.exp


class Sin(Operator):
def __init__(self, a):
super().__init__("Sin")
self.a = a
self.torch_op = aten.sin


class Cos(Operator):
def __init__(self, a):
super().__init__("Cos")
self.a = a
self.torch_op = aten.cos


class Relu(Operator):
def __init__(self, a):
super().__init__("Relu")
self.a = a
self.torch_op = aten.relu


class Erf(Operator):
def __init__(self, a):
super().__init__("Erf")
self.a = a
self.torch_op = aten.erf


class ReduceSum(Operator):
def __init__(self, *args, **kwargs):
super().__init__("ReduceSum")
@@ -742,12 +763,8 @@ def __init__(self, a, b):
super().__init__("MakeTuple")
self.torch_op = torch.empty_like

def __call__(self, a, b):
if hasattr(a, 'meta'):
a = a.meta['val']
if hasattr(b, 'meta'):
b = b.meta['val']
return a, b
def __call__(self, *args):
return (arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args)
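
A minimal sketch of how the variadic `__call__` behaves (the `_FakeProxy` stand-in below is hypothetical, not the real fx proxy class): each argument carrying a `.meta` dict contributes its fake value, everything else passes through. Note that the result is a generator expression, so callers are expected to iterate it (or wrap it in `tuple()`).

```python
class _FakeProxy:
    # Hypothetical stand-in for an fx node that carries a fake value in .meta["val"].
    def __init__(self, val):
        self.meta = {"val": val}

def make_tuple_call(*args):
    # Mirrors the new MakeTuple.__call__: meta-carrying args yield meta["val"], others pass through.
    return (arg.meta["val"] if hasattr(arg, "meta") else arg for arg in args)

values = tuple(make_tuple_call(_FakeProxy(1.0), _FakeProxy(2.0), 3))
assert values == (1.0, 2.0, 3)
```
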


class XlaGather(Operator):
40 changes: 40 additions & 0 deletions dicp/test/op/test_cos.py
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, a):
res_default = torch.ops.aten.cos.default(a)
return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestCos():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_cos(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
input1 = torch.randn(size, dtype=dtype)

dicp_input1 = input1.to(device)

output = model(input1)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1)

assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
40 changes: 40 additions & 0 deletions dicp/test/op/test_erf.py
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, a):
res_default = torch.ops.aten.erf.default(a)
return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestErf():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_erf(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
input1 = torch.randn(size, dtype=dtype)

dicp_input1 = input1.to(device)

output = model(input1)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1)

assert torch.allclose(output, dicp_output.cpu(), atol=1e-04, equal_nan=True)
40 changes: 40 additions & 0 deletions dicp/test/op/test_sin.py
@@ -0,0 +1,40 @@
import pytest
from ..common.utils import (
torch,
dynamo,
parse_args,
compile_model,
get_device,
Size,
update_dynamo_config,
)


class OpModule(torch.nn.Module):
def forward(self, a):
res_default = torch.ops.aten.sin.default(a)
return res_default


model = OpModule()
args = parse_args()
compiled_model = compile_model(model, args.backend, args.dynamic)


class TestSin():
@pytest.mark.parametrize("dtype", [torch.float32])
@pytest.mark.parametrize("sizes", [Size((5,), (5, 3)), Size((3, 5), (5, 3)), Size((2, 3, 4), (2, 4))])
@pytest.mark.parametrize("compiled_model", compiled_model)
def test_torch_sin(self, sizes, dtype, compiled_model):
device = get_device()
size = sizes.dynamic if compiled_model.dynamic else sizes.static
input1 = torch.randn(size, dtype=dtype)

dicp_input1 = input1.to(device)

output = model(input1)
dynamo.reset()
update_dynamo_config(compiled_model.dynamic)
dicp_output = compiled_model.model(dicp_input1)

assert torch.allclose(output, dicp_output.cpu(), equal_nan=True)