-
Notifications
You must be signed in to change notification settings - Fork 497
Feature fast cast-only mxfp8 #2062
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
00fbe4f
to
7295b7d
Compare
Steps to reproduce performance numbers
ncu --section=MemoryWorkloadAnalysis --section=SpeedOfLight --clock-control=none --nvtx --nvtx-include="Update Quantized/" --nvtx-include="reference kernel/" python quantize.py # quantize.py
# Benchmark/verification script for the MXFP8 cast-only quantization kernels.
# Cache locations are pinned before importing torch so TorchInductor/Triton
# pick them up, and the FX graph cache is disabled for reproducible runs.
import os
os.environ["TORCHINDUCTOR_CACHE_DIR"] = "/tmp"
os.environ["TRITON_CACHE_DIR"] = "/tmp"
os.environ["TORCHINDUCTOR_FX_GRAPH_CACHE"] = "0"
# Some tooling (e.g. Triton's cache keying) expects USER to exist; provide a
# placeholder when running in stripped-down container environments.
if "USER" not in os.environ:
    os.environ["USER"] = "you"
import torch
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer
import transformer_engine_torch as tex
import argparse
# Fixed seed so the random input (and thus the quantized output) is
# reproducible across runs.
torch.cuda.manual_seed(233376)
def run(args):
    """Profile one MXFP8 quantization configuration on the GPU.

    Creates a random input tensor, quantizes it with ``MXFP8Quantizer``
    while timing the kernel with CUDA events, and reports the achieved
    memory bandwidth.  When ``args.check_consistency`` is set, the
    reference kernel is run first and the optimized (cast-only) kernel's
    outputs are compared against it with ``torch.testing.assert_close``.

    Args:
        args: Parsed CLI namespace with fields ``check_consistency``,
            ``no_rowwise``, ``no_colwise``, ``src_dtype`` (a torch dtype),
            ``dst_dtype`` (a ``tex.DType``), ``size_h`` and ``size_w``.

    Raises:
        ValueError: if both rowwise and colwise quantization are disabled.
    """
    check_consistency = args.check_consistency
    direction = {"rowwise": not args.no_rowwise, "colwise": not args.no_colwise}
    src_dtype = args.src_dtype
    dst_dtype = args.dst_dtype
    size_h, size_w = args.size_h, args.size_w
    msg_candidates = {
        "TrueTrue": "rowwise and colwise",
        "TrueFalse": "rowwise",
        "FalseTrue": "colwise",
        "FalseFalse": None,
    }
    msg = msg_candidates[f"{direction['rowwise']}{direction['colwise']}"]
    if msg is None:
        raise ValueError(f"Invalid direction: {direction}")
    print("=" * 120)
    print(f"checking {msg}, "
          f"src_dtype: {src_dtype}, dst_dtype: {dst_dtype}, size_h: {size_h}, size_w: {size_w}")
    print("=" * 120)
    with torch.cuda.nvtx.range("Ctor"):
        quantizer = MXFP8Quantizer(
            fp8_dtype=dst_dtype,
            rowwise=direction["rowwise"],
            columnwise=direction["colwise"],
        )
    with torch.cuda.nvtx.range("Create Input"):
        bf16_tensor = torch.randn(size_h, size_w, dtype=src_dtype, device="cuda")
    if check_consistency:
        # Run the reference kernel first (ENABLE_CAST_ONLY is still unset
        # here), so its outputs can be compared against the optimized path.
        with torch.cuda.nvtx.range("reference"):
            fp8_tensor_ref = quantizer.make_empty(
                bf16_tensor.shape,
                dtype=bf16_tensor.dtype,
                device=bf16_tensor.device,
            )
        with torch.cuda.nvtx.range("reference kernel"):
            quantizer.update_quantized(bf16_tensor, fp8_tensor_ref)
    with torch.cuda.nvtx.range("Make Empty"):
        fp8_tensor = quantizer.make_empty(
            bf16_tensor.shape,
            dtype=bf16_tensor.dtype,
            device=bf16_tensor.device,
        )
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    if check_consistency:
        # NOTE(review): the optimized cast-only kernel is only enabled when
        # consistency checking is on; a pure performance run (without
        # --check_consistency) times the original kernel unless the caller
        # exports ENABLE_CAST_ONLY=1 themselves — confirm this is intended.
        os.environ["ENABLE_CAST_ONLY"] = "1"
    start.record()
    with torch.cuda.nvtx.range("Update Quantized"):
        quantizer.update_quantized(bf16_tensor, fp8_tensor)
    end.record()
    torch.cuda.synchronize()
    ms = start.elapsed_time(end)
    # Bytes moved through global memory: one read of the source tensor plus,
    # for each enabled direction, one FP8 data write (1 byte/element) and one
    # scale-inv write (1 byte per 32 elements along the reduction dimension).
    # Fix: the original always counted rowwise-only traffic, which understated
    # io_bytes/bandwidth whenever colwise output was also produced, and
    # hard-coded 2 bytes/element for the input instead of the src dtype size.
    io_bytes = size_h * size_w * bf16_tensor.element_size()
    if direction["rowwise"]:
        io_bytes += size_h * size_w
        io_bytes += size_h * (size_w // 32)
    if direction["colwise"]:
        io_bytes += size_h * size_w
        io_bytes += (size_h // 32) * size_w
    print(f"Io Bytes: {io_bytes / 1e6} MB")
    print(f"Duration: {ms} ms")
    print(f"Bandwidth: {(io_bytes * 1e-9) / (ms * 1e-3)} GB/s")
    if check_consistency:
        if direction["rowwise"]:
            torch.testing.assert_close(fp8_tensor._rowwise_data,
                                       fp8_tensor_ref._rowwise_data)
            print("rowwise data passed")
            torch.testing.assert_close(fp8_tensor._rowwise_scale_inv,
                                       fp8_tensor_ref._rowwise_scale_inv)
            print("rowwise scale_inv passed")
        if direction["colwise"]:
            torch.testing.assert_close(fp8_tensor._columnwise_data,
                                       fp8_tensor_ref._columnwise_data)
            print("colwise data passed")
            torch.testing.assert_close(fp8_tensor._columnwise_scale_inv,
                                       fp8_tensor_ref._columnwise_scale_inv)
            print("colwise scale_inv passed")
        # Full-tensor comparison as a final catch-all.
        torch.testing.assert_close(fp8_tensor, fp8_tensor_ref)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Profile MXFP8 quantization")
    # Fix: help texts previously said "Enable ..." for flags that disable.
    parser.add_argument("--no_rowwise", action="store_true", default=False,
                        help="Disable rowwise quantization")
    parser.add_argument("--no_colwise", action="store_true", default=False,
                        help="Disable colwise quantization")
    # Fix: "float32" was handled below but missing from choices, making that
    # branch unreachable from the CLI.
    parser.add_argument("--src_dtype", type=str, default="bfloat16",
                        choices=["bfloat16", "float16", "float32"],
                        help="Source dtype")
    parser.add_argument("--dst_dtype", type=str, default="kFloat8E4M3",
                        choices=["kFloat8E4M3", "kFloat8E5M2"],
                        help="Destination dtype")
    parser.add_argument("--size_h", type=int, default=4096,
                        help="Input tensor height")
    parser.add_argument("--size_w", type=int, default=7168,
                        help="Input tensor width")
    # Fix: action="store_true" with default=True could never be turned off;
    # BooleanOptionalAction keeps --check_consistency working and adds
    # --no-check_consistency to disable it.
    parser.add_argument("--check_consistency",
                        action=argparse.BooleanOptionalAction, default=True,
                        help="Check consistency")
    args = parser.parse_args()

    # Map CLI dtype strings to their torch / tex enum values in place.
    src_dtype_map = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    if args.src_dtype not in src_dtype_map:
        raise ValueError(f"Unsupported src_dtype: {args.src_dtype}")
    args.src_dtype = src_dtype_map[args.src_dtype]

    dst_dtype_map = {
        "kFloat8E4M3": tex.DType.kFloat8E4M3,
        "kFloat8E5M2": tex.DType.kFloat8E5M2,
    }
    if args.dst_dtype not in dst_dtype_map:
        raise ValueError(f"Unsupported dst_dtype: {args.dst_dtype}")
    args.dst_dtype = dst_dtype_map[args.dst_dtype]

    run(args)
623f960
to
0089f01
Compare
Signed-off-by: Jianbing Dong <[email protected]> Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]> Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]> Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]> Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]> Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
Signed-off-by: Jianbing Dong <[email protected]>
0089f01
to
c0ae662
Compare
for more information, see https://pre-commit.ci
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Solid kernel optimizations — the performance improvement comes right when we need it.
|
||
break; | ||
} | ||
case ScalingType::COLWISE: { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please modify the logic such that the column-wise scaling is not skipped and the original implementation is used as fall back?
|
||
#include "state_counter.cuh" | ||
#include "swizzle.cuh" | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you please include "ptx.cuh", as the code isn't compiling
|
||
namespace { | ||
#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) | ||
template <typename IType, typename OType> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it be possible to use the members of the original Quantized_Limits class and derive their bit representation in compile time using, e.g. bit_cast, to ensure the limits defined in one place?
coords.x = block_coords.x + iter_n * CastTraits::blockIterDimN; | ||
|
||
if (coords.y < rows && coords.x < cols) { | ||
int32_t offset = coords.y * cols + coords.x; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please use size_t data type for offsets, especially when global memory is addressed, as int32_t overflows when tensors are really large. (We fixed this problem just recently, as some TE kernels used 4-byte ints previously for offsets)
smem_alignment + smem_rowwise_scale + smem_colwise_reduce); | ||
}; | ||
|
||
#define ALIGN_TO(x, align) (((x) + (align) - 1) & ~((align) - 1)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please replace the macro with the device function.
|
||
extern __shared__ char smem[]; | ||
char *smemAligned = | ||
reinterpret_cast<char *>(ALIGN_TO((intptr_t)smem, CastTraits::smem_alignment)); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please replace C-style casting with static_cast<>.
float2 amaxs; | ||
#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ | ||
(__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) | ||
asm volatile("redux.sync.max.abs.f32 %0, %1, 0xFFFFFFFF;" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please wrap the inline PTX into device functions with some reasonable names (e.g., which resemble the operation, like redux_sync_max_abs_f32() here), and move them to ptx.cuh
: "=f"(amaxs.y) | ||
: "f"(values.y)); | ||
#else | ||
asm volatile( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Here, and in other places with "raw" inline ptx.
extern __shared__ char smem[]; | ||
char *smemAligned = | ||
reinterpret_cast<char *>(ALIGN_TO((intptr_t)smem, CastTraits::smem_alignment)); | ||
typename CastTraits::IType *sInput = reinterpret_cast<typename CastTraits::IType *>(smemAligned); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: as long types are used few times in the kernel, and it may better to use type aliasing to improve readability of the code. E.g.:
using IType = typename CastTraits::IType;
using OType = typename CastTraits::rowOutputUnitType;
using ColwiseReduceType = typename CastTraits::ColwiseReduceDataType;
Description
This pull request involves efficient implementations for mxfp8 quantize on
casting only
cases. It can increase the casting performance by 5%–20%. It supports:
BF16
orFP16
as inputsE5M2
orE4M3
as outputssm_100
rowwise
orrow- & col-wise
Performance gain:




Type of change
Changes
Please list the changes introduced in this PR:
ENABLE_CAST_ONLY
to select the optimized kernel. If the optimized kernel does not support the provided inputs, it will fall back to the original kernels automatically. ENABLE_CAST_ONLY
is not set or is set to 0
, then the original kernels will be used. Checklist: