11import copy
2- from typing import Any , Optional
32
43import mlc
54import pytest
5+ from mlc .testing .dataclasses import PyClassForTest
66
77
8- @mlc .py_class
9- class PyClassForTest (mlc .PyClass ):
10- bool_ : bool
11- i64 : int
12- f64 : float
13- raw_ptr : mlc .Ptr
14- dtype : mlc .DataType
15- device : mlc .Device
16- any : Any
17- func : mlc .Func
18- ulist : list [Any ]
19- udict : dict
20- str_ : str
21- ###
22- list_any : list [Any ]
23- list_list_int : list [list [int ]]
24- dict_any_any : dict [Any , Any ]
25- dict_str_any : dict [str , Any ]
26- dict_any_str : dict [Any , str ]
27- dict_str_list_int : dict [str , list [int ]]
28- ###
29- opt_bool : Optional [bool ]
30- opt_i64 : Optional [int ]
31- opt_f64 : Optional [float ]
32- opt_raw_ptr : Optional [mlc .Ptr ]
33- opt_dtype : Optional [mlc .DataType ]
34- opt_device : Optional [mlc .Device ]
35- opt_func : Optional [mlc .Func ]
36- opt_ulist : Optional [list ]
37- opt_udict : Optional [dict [Any , Any ]]
38- opt_str : Optional [str ]
39- ###
40- opt_list_any : Optional [list [Any ]]
41- opt_list_list_int : Optional [list [list [int ]]]
42- opt_dict_any_any : Optional [dict ]
43- opt_dict_str_any : Optional [dict [str , Any ]]
44- opt_dict_any_str : Optional [dict [Any , str ]]
45- opt_dict_str_list_int : Optional [dict [str , list [int ]]]
8+ @mlc .py_class (init = False )
9+ class CustomInit (mlc .PyClass ):
10+ a : int
11+ b : str
4612
47- def i64_plus_one (self ) -> int :
48- return self .i64 + 1
13+ def __init__ (self , * , b : str , a : int ) -> None :
14+ self .a = a
15+ self .b = b
16+
17+
18+ @pytest .fixture
19+ def test_obj () -> CustomInit :
20+ return CustomInit (a = 1 , b = "hello" )
4921
5022
5123@pytest .fixture
5224def mlc_class_for_test () -> PyClassForTest :
5325 return PyClassForTest (
5426 bool_ = True ,
27+ i8 = 8 ,
28+ i16 = 16 ,
29+ i32 = 32 ,
5530 i64 = 64 ,
31+ f32 = 2 ,
5632 f64 = 2.5 ,
5733 raw_ptr = mlc .Ptr (0xDEADBEEF ),
5834 dtype = "float8" ,
@@ -62,6 +38,7 @@ def mlc_class_for_test() -> PyClassForTest:
6238 ulist = [1 , 2.0 , "three" , lambda : 4 ],
6339 udict = {"1" : 1 , "2" : 2.0 , "3" : "three" , "4" : lambda : 4 },
6440 str_ = "world" ,
41+ str_readonly = "world" ,
6542 ###
6643 list_any = [1 , 2.0 , "three" , lambda : 4 ],
6744 list_list_int = [[1 , 2 , 3 ], [4 , 5 , 6 ]],
@@ -95,7 +72,11 @@ def test_copy_shallow(mlc_class_for_test: PyClassForTest) -> None:
9572 dst = copy .copy (src )
9673 assert src != dst
9774 assert src .bool_ == dst .bool_
75+ assert src .i8 == dst .i8
76+ assert src .i16 == dst .i16
77+ assert src .i32 == dst .i32
9878 assert src .i64 == dst .i64
79+ assert src .f32 == dst .f32
9980 assert src .f64 == dst .f64
10081 assert src .raw_ptr .value == dst .raw_ptr .value
10182 assert src .dtype == dst .dtype
@@ -133,7 +114,12 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None:
133114 src = mlc_class_for_test
134115 dst = copy .deepcopy (src )
135116 assert src != dst
117+ assert src .bool_ == dst .bool_
118+ assert src .i8 == dst .i8
119+ assert src .i16 == dst .i16
120+ assert src .i32 == dst .i32
136121 assert src .i64 == dst .i64
122+ assert src .f32 == dst .f32
137123 assert src .f64 == dst .f64
138124 assert src .raw_ptr .value == dst .raw_ptr .value
139125 assert src .dtype == dst .dtype
@@ -268,3 +254,19 @@ def test_copy_deep(mlc_class_for_test: PyClassForTest) -> None:
268254 and tuple (src .opt_dict_str_list_int ["1" ]) == tuple (dst .opt_dict_str_list_int ["1" ]) # type: ignore[index]
269255 and tuple (src .opt_dict_str_list_int ["2" ]) == tuple (dst .opt_dict_str_list_int ["2" ]) # type: ignore[index]
270256 )
257+
258+
259+ def test_copy_shallow_dataclass (test_obj : CustomInit ) -> None :
260+ src = test_obj
261+ dst = copy .copy (src )
262+ assert src != dst
263+ assert src .a == dst .a
264+ assert src .b == dst .b
265+
266+
267+ def test_copy_deep_dataclass (test_obj : CustomInit ) -> None :
268+ src = test_obj
269+ dst = copy .deepcopy (src )
270+ assert src != dst
271+ assert src .a == dst .a
272+ assert src .b == dst .b
0 commit comments