diff --git a/src/onnx_ir/tensor_adapters.py b/src/onnx_ir/tensor_adapters.py index 62ec402..9947309 100644 --- a/src/onnx_ir/tensor_adapters.py +++ b/src/onnx_ir/tensor_adapters.py @@ -72,6 +72,8 @@ def from_torch_dtype(dtype: torch.dtype) -> ir.DataType: torch.int32: ir.DataType.INT32, torch.int64: ir.DataType.INT64, torch.int8: ir.DataType.INT8, + torch.int4: ir.DataType.INT4, + torch.uint4: ir.DataType.UINT4, torch.uint8: ir.DataType.UINT8, torch.uint16: ir.DataType.UINT16, torch.uint32: ir.DataType.UINT32, @@ -108,6 +110,8 @@ def to_torch_dtype(dtype: ir.DataType) -> torch.dtype: ir.DataType.INT32: torch.int32, ir.DataType.INT64: torch.int64, ir.DataType.INT8: torch.int8, + ir.DataType.INT4: torch.int4, + ir.DataType.UINT4: torch.uint4, ir.DataType.UINT8: torch.uint8, ir.DataType.UINT16: torch.uint16, ir.DataType.UINT32: torch.uint32,