Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/Quant/IR/Quant.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//

Expand All @@ -25,6 +27,7 @@ namespace mlir {
namespace quant {

class QuantizedType;
class BlockFloatQuantizedType;
class UniformQuantizedType;
class UniformQuantizedPerAxisType;

Expand Down
13 changes: 12 additions & 1 deletion mlir/include/mlir/Dialect/Quant/IR/QuantDialectBytecode.td
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//
//
Expand Down Expand Up @@ -81,6 +83,14 @@ def UniformQuantizedPerAxisType: DialectType<(type
}];
}

// Bytecode encoding for BlockFloatQuantizedType: the block mode is serialized
// as its raw uint32_t enumerator value, followed by the blocked axis, both as
// variable-length integers.
def BlockFloatQuantizedType: DialectType<(type
  WithGetter<"static_cast<uint32_t>($_attrType.getBlockMode())", VarInt>:$blockMode,
  VarInt:$axis
)> {
  // Reconstruct the C++ type: cast the raw varint back to the BlockMode enum.
  let cBuilder = "get<$_resultType>(context, "
                 " static_cast<BlockFloatQuantizedType::BlockMode>(blockMode), axis)";
}

/// This enum contains marker codes used to indicate which attribute is
/// currently being decoded, and how it should be decoded. The order of these
/// codes should generally be unchanged, as any changes will inevitably break
Expand All @@ -93,7 +103,8 @@ def QuantDialectTypes : DialectTypes<"Quant"> {
AnyQuantizedTypeWithExpressedType,
CalibratedQuantizedType,
UniformQuantizedType,
UniformQuantizedPerAxisType
UniformQuantizedPerAxisType,
BlockFloatQuantizedType
];
}

Expand Down
91 changes: 91 additions & 0 deletions mlir/include/mlir/Dialect/Quant/IR/QuantTypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//

Expand All @@ -23,6 +25,7 @@ namespace detail {

struct QuantizedTypeStorage;
struct AnyQuantizedTypeStorage;
struct BlockFloatQuantizedTypeStorage;
struct UniformQuantizedTypeStorage;
struct UniformQuantizedPerAxisTypeStorage;
struct CalibratedQuantizedTypeStorage;
Expand Down Expand Up @@ -224,6 +227,94 @@ class AnyQuantizedType
int64_t storageTypeMax);
};

/// Represents block floating point quantization where multiple elements share
/// data along a particular axis (e.g. BFP16). The concrete block format
/// determines the implied storage characteristics and is not exposed in the IR.
/// This class is experimental and may be subject to change.
///
/// Design decisions:
/// - The base class requires an integral storage type. For
///   block-quantized/packed types, the required storage width depends on the
///   number of elements. For example, a single BFP16 element requires 16 bits
///   to be represented, but a block of 8 BFP16 elements can be packed into 9
///   bits per element on average (72 bits total). The storage type for
///   an element from BlockFloatQuantizedType is the "packed" type
///   divided by the number of packed elements, so for BFP16 i9.
///   - As accessing properties like min/max storage values and integral width
///     depend on the block size, the corresponding accessors are shadowed
///     below to trap in debug builds. NOTE(review): these appear to be
///     non-virtual shadows, so calls made through a QuantizedType value still
///     reach the base implementations, and in release (NDEBUG) builds the
///     asserts compile away and the base values are returned.
/// - The expressed type is not stored yet, this may change if there is a use
///   for it.
/// - The axis is signed to match MLIR convention, but enforced to be
///   non-negative.
class BlockFloatQuantizedType
    : public Type::TypeBase<BlockFloatQuantizedType, QuantizedType,
                            detail::BlockFloatQuantizedTypeStorage> {
public:
  using Base::Base;
  using Base::getChecked;

  /// Mnemonic used for this type in textual IR.
  static constexpr StringLiteral name = "quant.block_float";

  /// Supported block formats.
  // MX6 refers to the MicroExponent format, not to the OCP MicroScaling format
  // with the same name. MAX_VALUE tracks the largest enumerator (presumably
  // used to range-check raw decoded values — the check is not visible here).
  enum class BlockMode : uint32_t { BFP16 = 0, MX6 = 1, MAX_VALUE = MX6 };

  /// Returns the BlockMode whose textual name is `name`, or std::nullopt if
  /// the name is not recognized.
  static std::optional<BlockMode> parseBlockMode(StringRef name);
  /// Returns the textual name of `blockMode` (inverse of parseBlockMode).
  static StringRef getBlockModeName(BlockMode blockMode);

  /// Gets an instance of this type; `axis` is the blocked dimension.
  static BlockFloatQuantizedType get(MLIRContext *ctx, BlockMode blockMode,
                                     int32_t axis);
  /// As `get`, but verifies the invariants first, reporting failures through
  /// `emitError` instead of asserting.
  static BlockFloatQuantizedType
  getChecked(function_ref<InFlightDiagnostic()> emitError, MLIRContext *ctx,
             BlockMode blockMode, int32_t axis);

  /// Verifies construction invariants (e.g. a valid raw block mode and a
  /// non-negative axis). NOTE(review): the trailing flags/storage/expressed
  /// parameters appear to mirror the QuantizedType storage fields — confirm
  /// their use in the implementation.
  static LogicalResult
  verifyInvariants(function_ref<InFlightDiagnostic()> emitError,
                   uint32_t blockModeRaw, int32_t axis, unsigned flags,
                   Type storageType, Type expressedType, int64_t storageTypeMin,
                   int64_t storageTypeMax);

  /// Shadows the base accessor: a block float type has no single per-element
  /// storage type, so this asserts in debug builds (see class comment).
  Type getStorageType() const {
    assert(false &&
           "BlockFloatQuantizedType does not have a direct storage type");
    return QuantizedType::getStorageType();
  }

  /// Shadows the base accessor; asserts in debug builds (see class comment).
  int64_t getStorageTypeMin() const {
    assert(false &&
           "BlockFloatQuantizedType does not have a direct storage type");
    return QuantizedType::getStorageTypeMin();
  }

  /// Shadows the base accessor; asserts in debug builds (see class comment).
  int64_t getStorageTypeMax() const {
    assert(false &&
           "BlockFloatQuantizedType does not have a direct storage type");
    return QuantizedType::getStorageTypeMax();
  }

  /// Shadows the base accessor; asserts in debug builds (see class comment).
  bool hasStorageTypeBounds() const {
    assert(false &&
           "BlockFloatQuantizedType does not have a direct storage type");
    return QuantizedType::hasStorageTypeBounds();
  }

  /// Shadows the base accessor; asserts in debug builds (see class comment).
  unsigned getStorageTypeIntegralWidth() const {
    assert(false &&
           "BlockFloatQuantizedType does not have a direct storage type");
    return QuantizedType::getStorageTypeIntegralWidth();
  }

  /// The concrete block format of this type.
  BlockMode getBlockMode() const;
  /// The blocked dimension; signed but enforced non-negative (see above).
  int32_t getAxis() const;

  /// Number of elements in a block
  unsigned getBlockSize() const;
  /// Average number of bits used to represent each element in the block
  unsigned getAverageBitsPerElement() const;
  /// Returns the size in bits required to represent a single, not
  /// blocked/packed element.
  unsigned getSingleElementStorageSize() const;
};

/// Represents a family of uniform, quantized types.
///
/// Each instance of this type expresses a mapping between real values (most
Expand Down
67 changes: 55 additions & 12 deletions mlir/lib/Dialect/Quant/IR/QuantOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// Modifications (c) Copyright 2025 Advanced Micro Devices, Inc. or its
// affiliates
//
//===----------------------------------------------------------------------===//

Expand Down Expand Up @@ -37,30 +39,62 @@ namespace {
LogicalResult verifyPerAxisQuantization(Operation *op,
QuantizedType quantizedType,
Type containerType) {
auto quantizedPerAxisType = dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
auto quantizedPerAxisType =
dyn_cast<UniformQuantizedPerAxisType>(quantizedType);
if (!quantizedPerAxisType)
return success();

auto tensorType = dyn_cast<TensorType>(containerType);
if (!tensorType)
auto shapedType = dyn_cast<ShapedType>(containerType);
if (!shapedType)
return op->emitError("scalar types may not use per-axis quantization");

if (!tensorType.hasRank())
if (!shapedType.hasRank())
return success();

int64_t quantizedDimension = quantizedPerAxisType.getQuantizedDimension();
if (quantizedDimension >= tensorType.getRank())
if (quantizedDimension >= shapedType.getRank())
return op->emitError("quantized dimension must be less than tensor rank");

int64_t quantizedDimensionSize = tensorType.getDimSize(quantizedDimension);
int64_t quantizedDimensionSize = shapedType.getDimSize(quantizedDimension);
if (quantizedDimensionSize != ShapedType::kDynamic &&
quantizedDimensionSize != (int64_t)quantizedPerAxisType.getScales().size())
quantizedDimensionSize !=
(int64_t)quantizedPerAxisType.getScales().size())
return op->emitError(
"quantized dimension size does not match number of scales");

return success();
}

// Verify the integrity of block float quantization information, if present.
//
// - quantizedType
// Any quantized type. Any quantized type with no block float quantization is
// ignored.
//
// - containerType
// Original input or result type of the operation using the provided quantized
// type. Used to ensure that the quantized type appears within a shaped type
// and that its rank is compatible with the block float quantization
// information.
//
LogicalResult verifyBlockFloatQuantization(Operation *op,
                                           QuantizedType quantizedType,
                                           Type containerType) {
  // Only block float quantized types are subject to these checks.
  auto blockFloatType = dyn_cast<BlockFloatQuantizedType>(quantizedType);
  if (!blockFloatType)
    return success();

  auto containerShapedType = dyn_cast<ShapedType>(containerType);
  if (!containerShapedType)
    return op->emitError("scalar types may not use block float quantization");

  // Without a rank there is nothing further to validate.
  if (!containerShapedType.hasRank())
    return success();

  // We could also check that the tensor is a multiple of the block size, but
  // that requires that all padding is visible in MLIR
  if (blockFloatType.getAxis() >= containerShapedType.getRank())
    return op->emitError("block axis must be less than tensor rank");

  return success();
}

// Common verification logic for 'quant.dcast' and 'quant.qcast' ops.
//
// - quantizedType
Expand All @@ -76,12 +110,18 @@ LogicalResult verifyPerAxisQuantization(Operation *op,
//
LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
FloatType floatType, Type containerType) {
if (quantizedType.getExpressedType() != floatType)
if (!isa<BlockFloatQuantizedType>(quantizedType) &&
quantizedType.getExpressedType() != floatType)
return op->emitError(
"expressed type in quantized type expected to match float type");

// Verify integrity of per-axis quantization information, if present.
return verifyPerAxisQuantization(op, quantizedType, containerType);
if (failed(verifyPerAxisQuantization(op, quantizedType, containerType)))
return failure();

if (failed(verifyBlockFloatQuantization(op, quantizedType, containerType)))
return failure();

return success();
}

} // namespace
Expand All @@ -92,8 +132,8 @@ LogicalResult verifyQuantizationOp(Operation *op, QuantizedType quantizedType,
//===----------------------------------------------------------------------===//

void QuantDialect::initialize() {
addTypes<AnyQuantizedType, CalibratedQuantizedType, UniformQuantizedType,
UniformQuantizedPerAxisType>();
addTypes<AnyQuantizedType, BlockFloatQuantizedType, CalibratedQuantizedType,
UniformQuantizedType, UniformQuantizedPerAxisType>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Quant/IR/QuantOps.cpp.inc"
Expand Down Expand Up @@ -167,6 +207,9 @@ QuantizedType QuantizeCastOp::getQuantizedType() {

LogicalResult StorageCastOp::verify() {
auto quantizedType = getQuantizedType();
if (isa<BlockFloatQuantizedType>(quantizedType))
return getOperation()->emitError(
"storage cast not supported for block float quantized types");
auto integerType = getIntegerType();
if (quantizedType.getStorageType() != integerType)
return emitError(
Expand Down
Loading