@@ -103,8 +103,28 @@ def converts_to(source: Dtype, target: Dtype) -> bool:
103103 source = without_const (source )
104104 if isinstance (source , List ):
105105 return isinstance (target , List ) and converts_to (source .inner , target .inner )
106- if isinstance (source , Enum ):
107- return target == source or target == String ()
106+ if isinstance (source , Enum | String ):
107+ return (
108+ target == source
109+ or target == String ()
110+ or (
111+ type (target ) is String
112+ and source .max_length is not None
113+ and target .max_length > source .max_length
114+ )
115+ )
116+ if isinstance (source , Decimal ):
117+ return (
118+ target == source
119+ or target in FLOAT_SUBTYPES
120+ or target == Float ()
121+ or target == Decimal ()
122+ or (
123+ isinstance (target , Decimal )
124+ and target .scale >= source .scale
125+ and (target .precision - target .scale >= source .precision - source .scale )
126+ )
127+ )
108128 return target in IMPLICIT_CONVS [source ]
109129
110130
@@ -117,7 +137,7 @@ def to_python(dtype: Dtype):
117137 return float
118138 elif isinstance (dtype , List ):
119139 return list
120- elif isinstance (dtype , Enum ):
140+ elif isinstance (dtype , Enum | String ):
121141 return str
122142
123143 return {
@@ -187,13 +207,23 @@ def lca_type(dtypes: list[Dtype]) -> Dtype:
187207
188208 return List (lca_type ([dtype .inner for dtype in dtypes ]))
189209
190- if any (isinstance (dtype , Enum ) for dtype in dtypes ):
210+ if any (isinstance (dtype , Enum | String ) for dtype in dtypes ):
191211 if all (dtype == dtypes [0 ] for dtype in dtypes ):
192212 return copy .copy (dtypes [0 ])
193213 if all (isinstance (dtype , Enum | String ) for dtype in dtypes ):
194214 return String ()
195215 raise DataTypeError (f"incompatible types `{ ', ' .join (str (d ) for d in dtypes )} `" )
196216
217+ if any (isinstance (dtype , Decimal ) for dtype in dtypes ):
218+ if all (dtype == dtypes [0 ] for dtype in dtypes ):
219+ return copy .copy (dtypes [0 ])
220+ if all (isinstance (dtype , Decimal ) for dtype in dtypes ):
221+ precision_diff = max (dtype .precision - dtype .scale for dtype in dtypes )
222+ scale = max (dtype .scale for dtype in dtypes )
223+ precision = precision_diff + scale
224+ return Decimal (precision , scale )
225+ raise DataTypeError (f"incompatible types `{ ', ' .join (str (d ) for d in dtypes )} `" )
226+
197227 if not (
198228 common_ancestors := functools .reduce (
199229 operator .and_ ,
@@ -253,8 +283,12 @@ def is_subtype(dtype: Dtype) -> bool:
253283def implicit_conversions (dtype : Dtype ) -> list [Dtype ]:
254284 if isinstance (dtype , List ):
255285 return [List (inner ) for inner in implicit_conversions (dtype .inner )]
256- if isinstance (dtype , Enum ):
257- return [dtype , String ()]
286+ if isinstance (dtype , Enum | String ):
287+ return [String ()] + ([dtype ] if dtype .max_length is not None else [])
288+ if isinstance (dtype , Decimal ):
289+ return (
290+ list (FLOAT_SUBTYPES ) + [Float ()] + ([dtype ] if dtype != Decimal () else [])
291+ )
258292 return list (IMPLICIT_CONVS [dtype ].keys ())
259293
260294
@@ -303,8 +337,14 @@ def conversion_cost(dtype: Dtype, target: Dtype) -> tuple[int, int]:
303337 dtype = without_const (dtype )
304338 if isinstance (dtype , List ):
305339 return conversion_cost (dtype .inner , target .inner )
306- if isinstance (dtype , Enum ):
307- return (0 , 0 ) if dtype == target else (0 , 1 )
340+ if isinstance (dtype , Enum | String | Decimal ):
341+ return (
342+ (0 , 0 )
343+ if dtype == target
344+ else (0 , 1 )
345+ if type (dtype ) is type (target )
346+ else (0 , 2 )
347+ )
308348 return IMPLICIT_CONVS [dtype ][target ]
309349
310350
0 commit comments