diff --git a/mlir/include/mlir/Dialect/Quant/IR/Quant.h b/mlir/include/mlir/Dialect/Quant/IR/Quant.h index 11a969a3ee51..fdfac0990ed0 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/Quant.h +++ b/mlir/include/mlir/Dialect/Quant/IR/Quant.h @@ -3,6 +3,8 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its +// affiliates // //===----------------------------------------------------------------------===// @@ -25,6 +27,7 @@ namespace mlir { namespace quant { class QuantizedType; +class BlockFloatQuantizedType; class UniformQuantizedType; class UniformQuantizedPerAxisType; diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td index bd9cdf823822..d2931620c69e 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td @@ -3,6 +3,8 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its +// affiliates // //===----------------------------------------------------------------------===// // @@ -81,6 +83,14 @@ def UniformQuantizedPerAxisType: DialectType<(type }]; } +def BlockFloatQuantizedType: DialectType<(type + WithGetter<"static_cast($_attrType.getBlockMode())", VarInt>:$blockMode, + VarInt:$axis +)> { + let cBuilder = "get<$_resultType>(context, " + " static_cast(blockMode), axis)"; +} + /// This enum contains marker codes used to indicate which attribute is /// currently being decoded, and how it should be decoded. 
The order of these /// codes should generally be unchanged, as any changes will inevitably break @@ -93,7 +103,8 @@ def QuantDialectTypes : DialectTypes<"Quant"> { AnyQuantizedTypeWithExpressedType, CalibratedQuantizedType, UniformQuantizedType, - UniformQuantizedPerAxisType + UniformQuantizedPerAxisType, + BlockFloatQuantizedType ]; } diff --git a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h index 43440ba623b9..448f518f4a2e 100644 --- a/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h +++ b/mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h @@ -3,6 +3,8 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its +// affiliates // //===----------------------------------------------------------------------===// @@ -23,6 +25,7 @@ namespace detail { struct QuantizedTypeStorage; struct AnyQuantizedTypeStorage; +struct BlockFloatQuantizedTypeStorage; struct UniformQuantizedTypeStorage; struct UniformQuantizedPerAxisTypeStorage; struct CalibratedQuantizedTypeStorage; @@ -224,6 +227,94 @@ class AnyQuantizedType int64_t storageTypeMax); }; +/// Represents block floating point quantization where multiple elements share +/// data along a particular axis (e.g. BFP16). The concrete block format +/// determines the implied storage characteristics and is not exposed in the IR. +/// This class is experimental and may be subject to change. +/// Design decisions: +/// - The base class requires an integral storage type. For +/// block-quantized/packed types, the required storage width depends on the +/// number of elements. For example, a single BFP16 element requires 16 bits to +/// be represented, but a block of 8 BFP16 elements can be packed into 9 bits +/// per element on average (72 bits total). 
The storage type for +/// an element from BlockFloatQuantizedType is the "packed" type +/// divided by the number of packed elements, so for BFP16 i9. +/// -- As accessing properties like min/max storage values and integral width +/// depend on the block size, these methods are overridden to return errors. +/// - The expressed type is not stored yet, this may change if there is a use +/// for it. +/// - The axis is signed to match MLIR convention, but enforced to be +/// non-negative. +class BlockFloatQuantizedType + : public Type::TypeBase {public: + using Base::Base; + using Base::getChecked; + + static constexpr StringLiteral name = "quant.block_float"; + + // MX6 refers to the MicroExponent format, not to the OCP MicroScaling format + // with the same name. + enum class BlockMode : uint32_t { BFP16 = 0, MX6 = 1, MAX_VALUE = MX6 }; + + static std::optional parseBlockMode(StringRef name); + static StringRef getBlockModeName(BlockMode blockMode); + + static BlockFloatQuantizedType get(MLIRContext *ctx, BlockMode blockMode, + int32_t axis); + static BlockFloatQuantizedType + getChecked(function_ref emitError, MLIRContext *ctx, + BlockMode blockMode, int32_t axis); + + static LogicalResult + verifyInvariants(function_ref emitError, + uint32_t blockModeRaw, int32_t axis, unsigned flags, + Type storageType, Type expressedType, int64_t storageTypeMin, + int64_t storageTypeMax); + + Type getStorageType() const { + assert(false && + "BlockFloatQuantizedType does not have a direct storage type"); + return QuantizedType::getStorageType(); + } + + int64_t getStorageTypeMin() const { + assert(false && + "BlockFloatQuantizedType does not have a direct storage type"); + return QuantizedType::getStorageTypeMin(); + } + + int64_t getStorageTypeMax() const { + assert(false && + "BlockFloatQuantizedType does not have a direct storage type"); + return QuantizedType::getStorageTypeMax(); + } + + bool hasStorageTypeBounds() const { + assert(false && + "BlockFloatQuantizedType does 
not have a direct storage type"); + return QuantizedType::hasStorageTypeBounds(); + } + + unsigned getStorageTypeIntegralWidth() const { + assert(false && + "BlockFloatQuantizedType does not have a direct storage type"); + return QuantizedType::getStorageTypeIntegralWidth(); + } + + BlockMode getBlockMode() const; + int32_t getAxis() const; + + /// Number of elements in a block + unsigned getBlockSize() const; + /// Average number of bits used to represent each element in the block + unsigned getAverageBitsPerElement() const; + /// Returns the size in bits required to represent a single, not + /// blocked/packed element. + unsigned getSingleElementStorageSize() const; +}; + /// Represents a family of uniform, quantized types. /// /// Each instance of this type expresses a mapping between real values (most diff --git a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp index c584903f3a15..2dc7203370e4 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantOps.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantOps.cpp @@ -3,6 +3,8 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. 
or its +// affiliates // //===----------------------------------------------------------------------===// @@ -37,30 +39,62 @@ namespace { LogicalResult verifyPerAxisQuantization(Operation *op, QuantizedType quantizedType, Type containerType) { - auto quantizedPerAxisType = dyn_cast(quantizedType); + auto quantizedPerAxisType = + dyn_cast(quantizedType); if (!quantizedPerAxisType) return success(); - auto tensorType = dyn_cast(containerType); - if (!tensorType) + auto shapedType = dyn_cast(containerType); + if (!shapedType) return op->emitError("scalar types may not use per-axis quantization"); - if (!tensorType.hasRank()) + if (!shapedType.hasRank()) return success(); int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension(); - if (quantizedDimension >= tensorType.getRank()) + if (quantizedDimension >= shapedType.getRank()) return op->emitError("quantized dimension must be less than tensor rank"); - int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension); + int64_t quantizedDimensionSize = shapedType.getDimSize(quantizedDimension); if (quantizedDimensionSize != ShapedType::kDynamic && - quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size()) + quantizedDimensionSize != + (int64_t)quantizedPerAxisType.getScales().size()) return op->emitError( "quantized dimension size does not match number of scales"); return success(); } +// Verify the integrity of block float quantization information, if present. +// +// - quantizedType +// Any quantized type. Any quantized type with no block float quantization is +// ignored. +// +// - containerType +// Original input or result type of the operation using the provided quantized +// type. Used to ensure that the quantized type appears within a tensor and +// that the tensor is compatible with block float quantization information. 
+// +LogicalResult verifyBlockFloatQuantization(Operation *op, + QuantizedType quantizedType, + Type containerType) { + auto blockModeType = dyn_cast(quantizedType); + if (!blockModeType) + return success(); + + auto shapedType = dyn_cast(containerType); + if (!shapedType) + return op->emitError("scalar types may not use block float quantization"); + if (!shapedType.hasRank()) + return success(); + // We could also check that the tensor is a multiple of the block size, but + // that requires that all padding is visible in MLIR + if (blockModeType.getAxis() >= shapedType.getRank()) + return op->emitError("block axis must be less than tensor rank"); + return success(); +} + // Common verification logic for 'quant.dcast' and 'quant.qcast' ops. // // - quantizedType @@ -76,12 +110,18 @@ LogicalResult verifyPerAxisQuantization(Operation *op, // LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType, FloatType floatType, Type containerType) { - if (quantizedType.getExpressedType() != floatType) + if (!isa(quantizedType) && + quantizedType.getExpressedType() != floatType) return op->emitError( "expressed type in quantized type expected to match float type"); - // Veriy integrity of per-axis quantization information, if present. 
- return verifyPerAxisQuantization(op, quantizedType, containerType); + if (failed(verifyPerAxisQuantization(op, quantizedType, containerType))) + return failure(); + + if (failed(verifyBlockFloatQuantization(op, quantizedType, containerType))) + return failure(); + + return success(); } } // namespace @@ -92,8 +132,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType, //===----------------------------------------------------------------------===// void QuantDialect::initialize() { - addTypes(); + addTypes(); addOperations< #define GET_OP_LIST #include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc" @@ -167,6 +207,9 @@ QuantizedType QuantizeCastOp::getQuantizedType() { LogicalResult StorageCastOp::verify() { auto quantizedType = getQuantizedType(); + if (isa(quantizedType)) + return getOperation()->emitError( + "storage cast not supported for block float quantized types"); auto integerType = getIntegerType(); if (quantizedType.getStorageType() != integerType) return emitError( diff --git a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp index 7c0d36964865..ab8e00e14b94 100644 --- a/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp +++ b/mlir/lib/Dialect/Quant/IR/QuantTypes.cpp @@ -3,12 +3,14 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. 
or its +// affiliates // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "TypeDetail.h" #include "mlir/Dialect/Quant/IR/Quant.h" -#include "mlir/Dialect/Quant/IR/QuantTypes.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/MLIRContext.h" @@ -34,7 +36,31 @@ double getMaxScale(Type expressedType) { return APFloat::getLargest(floatType.getFloatSemantics()).convertToDouble(); } -} // namespace +struct BlockFloatQuantizedTypeConfig { + StringRef mode; + unsigned singleElementBitWidth; + unsigned averageBitsPerElement; + unsigned blockSize; +}; + +const BlockFloatQuantizedTypeConfig & +getBlockFloatQuantizedTypeConfig(BlockFloatQuantizedType::BlockMode blockMode) { + switch (blockMode) { + case BlockFloatQuantizedType::BlockMode::BFP16: { + static constexpr BlockFloatQuantizedTypeConfig config = { + StringLiteral("BFP16"), 16, 9, 8}; + return config; + } + case BlockFloatQuantizedType::BlockMode::MX6: { + static constexpr BlockFloatQuantizedTypeConfig config = { + StringLiteral("MX6"), 13, 6, 16}; + return config; + } + } + llvm_unreachable("unknown block quantized type"); +} + +} // namespace unsigned QuantizedType::getFlags() const { return static_cast(impl)->flags; @@ -335,6 +361,108 @@ int64_t UniformQuantizedType::getZeroPoint() const { return getImpl()->zeroPoint; } +struct BlockFloatQuantizedParams { + unsigned flags; + Type storageType; + int64_t storageTypeMin; + int64_t storageTypeMax; +}; + +static BlockFloatQuantizedParams +getBlockFloatQuantizedParams(MLIRContext *ctx, + BlockFloatQuantizedType::BlockMode blockMode) { + const BlockFloatQuantizedTypeConfig &config = + getBlockFloatQuantizedTypeConfig(blockMode); + const unsigned flags = 0; + const bool isSigned = + false; // this does not really make sense for block + // types, just fix to unsigned to make the base class happy + Type storageType = IntegerType::get(ctx, config.averageBitsPerElement); + int64_t 
storageTypeMin = QuantizedType::getDefaultMinimumForInteger( + isSigned, config.averageBitsPerElement); + int64_t storageTypeMax = QuantizedType::getDefaultMaximumForInteger( + isSigned, config.averageBitsPerElement); + return {flags, storageType, storageTypeMin, storageTypeMax}; +} + +BlockFloatQuantizedType BlockFloatQuantizedType::get(MLIRContext *ctx, + BlockMode blockMode, + int32_t axis) { + const BlockFloatQuantizedParams params = + getBlockFloatQuantizedParams(ctx, blockMode); + return Base::get(ctx, static_cast(blockMode), axis, params.flags, + params.storageType, /*expressedType*/ Type(), + params.storageTypeMin, params.storageTypeMax); +} + +BlockFloatQuantizedType BlockFloatQuantizedType::getChecked( + function_ref emitError, MLIRContext *ctx, + BlockMode blockMode, int32_t axis) { + const BlockFloatQuantizedParams params = + getBlockFloatQuantizedParams(ctx, blockMode); + return Base::getChecked(emitError, ctx, static_cast(blockMode), + axis, params.flags, params.storageType, + /*expressedType*/ Type(), params.storageTypeMin, + params.storageTypeMax); +} + +LogicalResult BlockFloatQuantizedType::verifyInvariants( + function_ref emitError, uint32_t blockModeRaw, + int32_t axis, unsigned flags, Type storageType, Type expressedType, + int64_t storageTypeMin, int64_t storageTypeMax) { + // storage type, expressed type, storageTypeMin, storageTypeMax and flags can + // be derived from blockMode. But inside this function we do not have access + // to ctx to construct those values, so we need to pass them in to be able + // to forward them to the base class verification. + // MLIR requires the TypeStorage KeyTy and verifyInvariants to share the same + // signature, forcing us to also pass these derivable args to the TypeStorage. 
+ if (failed(QuantizedType::verifyInvariants(emitError, flags, storageType, + expressedType, storageTypeMin, + storageTypeMax))) + return failure(); + + if (blockModeRaw > + static_cast(BlockFloatQuantizedType::BlockMode::MAX_VALUE)) + return emitError() << "invalid block mode: " << blockModeRaw; + + if (axis < 0) + return emitError() << "axis must be non-negative"; + + return success(); +} + +int32_t BlockFloatQuantizedType::getAxis() const { return getImpl()->axis; } + +BlockFloatQuantizedType::BlockMode +BlockFloatQuantizedType::getBlockMode() const { + return static_cast(getImpl()->blockType); +} + +unsigned BlockFloatQuantizedType::getAverageBitsPerElement() const { + return getBlockFloatQuantizedTypeConfig(getBlockMode()).averageBitsPerElement; +} + +unsigned BlockFloatQuantizedType::getSingleElementStorageSize() const { + return getBlockFloatQuantizedTypeConfig(getBlockMode()).singleElementBitWidth; +} + +unsigned BlockFloatQuantizedType::getBlockSize() const { + return getBlockFloatQuantizedTypeConfig(getBlockMode()).blockSize; +} + +std::optional +BlockFloatQuantizedType::parseBlockMode(StringRef name) { + if (name == "BFP16") + return BlockMode::BFP16; + if (name == "MX6") + return BlockMode::MX6; + return std::nullopt; +} + +StringRef BlockFloatQuantizedType::getBlockModeName(BlockMode blockMode) { + return getBlockFloatQuantizedTypeConfig(blockMode).mode; +} + UniformQuantizedPerAxisType UniformQuantizedPerAxisType::get( unsigned flags, Type storageType, Type expressedType, ArrayRef scales, ArrayRef zeroPoints, diff --git a/mlir/lib/Dialect/Quant/IR/TypeDetail.h b/mlir/lib/Dialect/Quant/IR/TypeDetail.h index ef098811927c..a1a16c9b60b5 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeDetail.h +++ b/mlir/lib/Dialect/Quant/IR/TypeDetail.h @@ -3,6 +3,8 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its +// affiliates // //===----------------------------------------------------------------------===// @@ -92,6 +94,59 @@ struct AnyQuantizedTypeStorage : public QuantizedTypeStorage { static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } }; +struct BlockFloatQuantizedTypeStorage : public QuantizedTypeStorage { + struct KeyTy { + KeyTy(uint32_t blockType, int32_t axis, unsigned flags, Type storageType, + Type expressedType, int64_t storageTypeMin, int64_t storageTypeMax) + : blockType(blockType), axis(axis), flags(flags), + storageType(storageType), expressedType(expressedType), + storageTypeMin(storageTypeMin), storageTypeMax(storageTypeMax) {} + uint32_t blockType; + int32_t axis; + + // "inherited" members from QuantizedTypeStorage. These are derivable from + // the blockType and do not contribute to the hash/comparison + unsigned flags; + Type storageType; + Type expressedType; + int64_t storageTypeMin; + int64_t storageTypeMax; + + template + static bool genericIsEqual(const T &lhs, const U &rhs) { + return lhs.axis == rhs.axis && lhs.blockType == rhs.blockType; + } + + bool operator==(const KeyTy &other) const { + return genericIsEqual(*this, other); + } + + unsigned getHashValue() const { + return llvm::hash_combine(axis, blockType); + } + }; + + BlockFloatQuantizedTypeStorage(const KeyTy &key) + : QuantizedTypeStorage(key.flags, key.storageType, key.expressedType, + key.storageTypeMin, key.storageTypeMax), + blockType(key.blockType), axis(key.axis) {} + + bool operator==(const KeyTy &key) const { + return KeyTy::genericIsEqual(*this, key); + } + + static BlockFloatQuantizedTypeStorage * + construct(TypeStorageAllocator &allocator, const KeyTy &key) { + return new (allocator.allocate()) + BlockFloatQuantizedTypeStorage(key); + } + + static unsigned hashKey(const KeyTy &key) { return key.getHashValue(); } + + 
uint32_t blockType; + int32_t axis; +}; + struct UniformQuantizedTypeStorage : public QuantizedTypeStorage { struct KeyTy { KeyTy(unsigned flags, Type storageType, Type expressedType, double scale, diff --git a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp index 851763d8942e..f3aff653dcee 100644 --- a/mlir/lib/Dialect/Quant/IR/TypeParser.cpp +++ b/mlir/lib/Dialect/Quant/IR/TypeParser.cpp @@ -3,6 +3,8 @@ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its +// affiliates // //===----------------------------------------------------------------------===// @@ -316,6 +318,38 @@ static Type parseCalibratedType(DialectAsmParser &parser) { return parser.getChecked(expressedType, min, max); } +static Type parseBlockFloatQuantizedType(DialectAsmParser &parser) { + if (parser.parseLess()) + return nullptr; + + if (parser.parseKeyword("mode") || parser.parseEqual()) + return nullptr; + + StringRef name; + if (failed(parser.parseKeyword(&name))) + return nullptr; + + const auto blockMode = BlockFloatQuantizedType::parseBlockMode(name); + if (!blockMode) { + parser.emitError(parser.getNameLoc()) + << "unknown block quantized mode " << name; + return nullptr; + } + + if (parser.parseComma() || parser.parseKeyword("axis") || parser.parseEqual()) + return nullptr; + + int32_t axis; + if (parser.parseInteger(axis)) + return nullptr; + + if (parser.parseGreater()) + return nullptr; + + return parser.getChecked( + parser.getBuilder().getContext(), *blockMode, axis); +} + /// Parse a type registered to this dialect. Type QuantDialect::parseType(DialectAsmParser &parser) const { // All types start with an identifier that we switch on. 
@@ -329,6 +363,8 @@ Type QuantDialect::parseType(DialectAsmParser &parser) const { return parseAnyType(parser); if (typeNameSpelling == "calibrated") return parseCalibratedType(parser); + if (typeNameSpelling == "block_float") + return parseBlockFloatQuantizedType(parser); parser.emitError(parser.getNameLoc(), "unknown quantized type " + typeNameSpelling); @@ -413,6 +449,13 @@ static void printCalibratedQuantizedType(CalibratedQuantizedType type, out << ">"; } +static void printBlockFloatQuantizedType(BlockFloatQuantizedType type, + DialectAsmPrinter &out) { + out << "block_float"; +} + /// Print a type registered to this dialect. void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { if (auto anyType = llvm::dyn_cast(type)) @@ -423,6 +466,8 @@ void QuantDialect::printType(Type type, DialectAsmPrinter &os) const { printUniformQuantizedPerAxisType(perAxisType, os); else if (auto calibratedType = llvm::dyn_cast(type)) printCalibratedQuantizedType(calibratedType, os); + else if (auto blockType = llvm::dyn_cast(type)) + printBlockFloatQuantizedType(blockType, os); else llvm_unreachable("Unhandled quantized type"); } diff --git a/mlir/test/Dialect/Quant/Bytecode/types.mlir b/mlir/test/Dialect/Quant/Bytecode/types.mlir index 359a58557087..9ea62b33ddee 100644 --- a/mlir/test/Dialect/Quant/Bytecode/types.mlir +++ b/mlir/test/Dialect/Quant/Bytecode/types.mlir @@ -1,3 +1,5 @@ +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. 
or its +// affiliates // RUN: mlir-opt -emit-bytecode %s | mlir-opt | FileCheck %s //===----------------------------------------------------------------------===// @@ -64,3 +66,18 @@ module @parseUniformPerAxisMixed attributes { bytecode.test = !quant.uniform } {} +//===----------------------------------------------------------------------===// +// BlockFloatQuantized +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: parseBlockFloatQuantized +module @parseBlockFloatQuantized attributes { + // CHECK: bytecode.test = !quant.block_float + bytecode.test = !quant.block_float +} {} + +// CHECK-LABEL: parseBlockFloatQuantizedWithExpressed +module @parseBlockFloatQuantizedWithExpressed attributes { + // CHECK: bytecode.test = !quant.block_float + bytecode.test = !quant.block_float +} {} diff --git a/mlir/test/Dialect/Quant/invalid.mlir b/mlir/test/Dialect/Quant/invalid.mlir index ba3a8e312d96..0006a3c94e10 100644 --- a/mlir/test/Dialect/Quant/invalid.mlir +++ b/mlir/test/Dialect/Quant/invalid.mlir @@ -1,3 +1,6 @@ +// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. 
or its +// affiliates + // RUN: mlir-opt %s -split-input-file -verify-diagnostics func.func @dcast_invalid_input(%arg0: f32) { @@ -256,3 +259,39 @@ func.func @scast_per_axis_invalid_rank(%arg0: tensor<2x3x4xi8>) { return } +// ----- + +// expected-error@+1 {{unknown block quantized mode foo}} +!unknown_block = !quant.block_float + +// ----- + +// expected-error@+1 {{axis must be non-negative}} +!negative_axis = !quant.block_float + +// ----- + +!block_scalar = !quant.block_float +func.func @block_quant_scalar(%arg0: !block_scalar) { + // expected-error@+1 {{scalar types may not use block float quantization}} + %0 = quant.dcast %arg0 : !block_scalar to f32 + return +} + +// ----- + +!block_axis = !quant.block_float +func.func @block_quant_axis(%arg0: tensor<2x!block_axis>) { + // expected-error@+1 {{block axis must be less than tensor rank}} + %0 = quant.dcast %arg0 : tensor<2x!block_axis> to tensor<2xf32> + return +} + +// ----- + +!block_axis = !quant.block_float +func.func @block_quant_storage_cast(%arg0: tensor<2x!block_axis>) { + // expected-error@+1 {{storage cast not supported for block float quantized types}} + %0 = quant.scast %arg0 : tensor<2x!block_axis> to tensor<2xi9> + return +}