Skip to content

Commit 8d20121

Browse files
authored
fix(dataclass): Create proper __init__ method for PyClass (#20)
1 parent a8017ce commit 8d20121

File tree

2 files changed

+57
-43
lines changed

2 files changed

+57
-43
lines changed

python/mlc/dataclasses/py_class.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,9 +141,10 @@ def __new__(cls, *args: typing.Any, **kwargs: typing.Any) -> type[ClsType]:
141141
)
142142
setattr(type_cls, "_mlc_structure", struct)
143143

144-
# Step 5. Attach methods
144+
# Step 5. Add `__init__` method
145+
type_add_method(type_index, "__init__", _method_new(type_cls), 1) # static
146+
# Step 6. Attach methods
145147
fn: Callable[..., typing.Any]
146-
type_add_method(type_index, "__init__", type_cls, 1) # static
147148
if init:
148149
fn = method_init(super_type_cls, d_fields)
149150
attach_method(super_type_cls, type_cls, "__init__", fn, check_exists=True)
@@ -185,3 +186,14 @@ def method(self: ClsType) -> str:
185186
return f"{type_key}({', '.join(fields)})"
186187

187188
return method
189+
190+
191+
def _method_new(
192+
type_cls: type[ClsType],
193+
) -> Callable[..., ClsType]:
194+
def method(*args: typing.Any) -> ClsType:
195+
obj = type_cls.__new__(type_cls)
196+
obj._mlc_init(*args) # type: ignore[attr-defined]
197+
return obj
198+
199+
return method

tests/python/test_dataclasses_copy.py

Lines changed: 43 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,34 @@
11
import copy
2-
from typing import Any, Optional
32

43
import mlc
54
import pytest
5+
from mlc.testing.dataclasses import PyClassForTest
66

77

8-
@mlc.py_class
9-
class PyClassForTest(mlc.PyClass):
10-
bool_: bool
11-
i64: int
12-
f64: float
13-
raw_ptr: mlc.Ptr
14-
dtype: mlc.DataType
15-
device: mlc.Device
16-
any: Any
17-
func: mlc.Func
18-
ulist: list[Any]
19-
udict: dict
20-
str_: str
21-
###
22-
list_any: list[Any]
23-
list_list_int: list[list[int]]
24-
dict_any_any: dict[Any, Any]
25-
dict_str_any: dict[str, Any]
26-
dict_any_str: dict[Any, str]
27-
dict_str_list_int: dict[str, list[int]]
28-
###
29-
opt_bool: Optional[bool]
30-
opt_i64: Optional[int]
31-
opt_f64: Optional[float]
32-
opt_raw_ptr: Optional[mlc.Ptr]
33-
opt_dtype: Optional[mlc.DataType]
34-
opt_device: Optional[mlc.Device]
35-
opt_func: Optional[mlc.Func]
36-
opt_ulist: Optional[list]
37-
opt_udict: Optional[dict[Any, Any]]
38-
opt_str: Optional[str]
39-
###
40-
opt_list_any: Optional[list[Any]]
41-
opt_list_list_int: Optional[list[list[int]]]
42-
opt_dict_any_any: Optional[dict]
43-
opt_dict_str_any: Optional[dict[str, Any]]
44-
opt_dict_any_str: Optional[dict[Any, str]]
45-
opt_dict_str_list_int: Optional[dict[str, list[int]]]
8+
@mlc.py_class(init=False)
9+
class CustomInit(mlc.PyClass):
10+
a: int
11+
b: str
4612

47-
def i64_plus_one(self) -> int:
48-
return self.i64 + 1
13+
def __init__(self, *, b: str, a: int) -> None:
14+
self.a = a
15+
self.b = b
16+
17+
18+
@pytest.fixture
19+
def test_obj() -> CustomInit:
20+
return CustomInit(a=1, b="hello")
4921

5022

5123
@pytest.fixture
5224
def mlc_class_for_test() -> PyClassForTest:
5325
return PyClassForTest(
5426
bool_=True,
27+
i8=8,
28+
i16=16,
29+
i32=32,
5530
i64=64,
31+
f32=2,
5632
f64=2.5,
5733
raw_ptr=mlc.Ptr(0xDEADBEEF),
5834
dtype="float8",
@@ -62,6 +38,7 @@ def mlc_class_for_test() -> PyClassForTest:
6238
ulist=[1, 2.0, "three", lambda: 4],
6339
udict={"1": 1, "2": 2.0, "3": "three", "4": lambda: 4},
6440
str_="world",
41+
str_readonly="world",
6542
###
6643
list_any=[1, 2.0, "three", lambda: 4],
6744
list_list_int=[[1, 2, 3], [4, 5, 6]],
@@ -95,7 +72,11 @@ def test_copy_shallow(mlc_class_for_test: PyClassForTest) -> None:
9572
dst = copy.copy(src)
9673
assert src != dst
9774
assert src.bool_ == dst.bool_
75+
assert src.i8 == dst.i8
76+
assert src.i16 == dst.i16
77+
assert src.i32 == dst.i32
9878
assert src.i64 == dst.i64
79+
assert src.f32 == dst.f32
9980
assert src.f64 == dst.f64
10081
assert src.raw_ptr.value == dst.raw_ptr.value
10182
assert src.dtype == dst.dtype
@@ -133,7 +114,12 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None:
133114
src = mlc_class_for_test
134115
dst = copy.deepcopy(src)
135116
assert src != dst
117+
assert src.bool_ == dst.bool_
118+
assert src.i8 == dst.i8
119+
assert src.i16 == dst.i16
120+
assert src.i32 == dst.i32
136121
assert src.i64 == dst.i64
122+
assert src.f32 == dst.f32
137123
assert src.f64 == dst.f64
138124
assert src.raw_ptr.value == dst.raw_ptr.value
139125
assert src.dtype == dst.dtype
@@ -268,3 +254,19 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None:
268254
and tuple(src.opt_dict_str_list_int["1"]) == tuple(dst.opt_dict_str_list_int["1"]) # type: ignore[index]
269255
and tuple(src.opt_dict_str_list_int["2"]) == tuple(dst.opt_dict_str_list_int["2"]) # type: ignore[index]
270256
)
257+
258+
259+
def test_copy_shallow_dataclass(test_obj: CustomInit) -> None:
260+
src = test_obj
261+
dst = copy.copy(src)
262+
assert src != dst
263+
assert src.a == dst.a
264+
assert src.b == dst.b
265+
266+
267+
def test_copy_deep_dataclass(test_obj: CustomInit) -> None:
268+
src = test_obj
269+
dst = copy.deepcopy(src)
270+
assert src != dst
271+
assert src.a == dst.a
272+
assert src.b == dst.b

0 commit comments

Comments
 (0)