Add tracing option to quantize bench #3651

Closed · 4 commits
17 changes: 12 additions & 5 deletions fbgemm_gpu/experimental/gemm/test/fp8_gemm_test.py
@@ -60,16 +60,23 @@ def _test_quantize_fp8_row(
         # Apply sparsification if specified.
         zero_start_index_M = None
         if use_jagged:
+            # View input as [G, M, K] where G is the number of groups.
+            grouped_input = input_a.view(
+                -1, input_a.shape[-2], input_a.shape[-1]
+            )
             m_vals = torch.randint(
-                0, input_a.shape[-1] + 1, (input_a.shape[:-1])
+                0, grouped_input.shape[1] + 1, (grouped_input.shape[0],)
             )
-            mask = torch.arange(input_a.shape[-1]).expand(
-                input_a.shape[:-1] + (input_a.shape[-1],)
+            mask = torch.arange(grouped_input.shape[-2]).expand(
+                (grouped_input.shape[0], grouped_input.shape[1])
             ) >= m_vals.unsqueeze(-1)
             # Set corresponding values to 0.
-            input_a[mask] = 0.0
+            grouped_input[mask] = 0.0
             # Generate nonzero tensor in same layout as input.
-            zero_start_index_M = torch.count_nonzero(input_a, dim=-1)
+            zero_start_index_M = torch.count_nonzero(
+                torch.sum(grouped_input, dim=-1), dim=-1
+            )

         a_fp8, a_scale = quantize_fp8_row(
             input_a,
             scale_ub=scale_ub,
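Reviewer note: the jagged setup above is easier to follow in isolation. A minimal sketch, assuming `input_a` folds into `G` groups of `M`-row matrices (shapes below are illustrative, not from the PR):

```python
import torch

G, M, K = 2, 4, 8  # illustrative: groups, rows per group, row length
input_a = torch.randn(G, M, K)

# View input as [G, M, K]; mutating the view also sparsifies input_a.
grouped_input = input_a.view(-1, input_a.shape[-2], input_a.shape[-1])
# Draw one valid-row count per group, in [0, M].
m_vals = torch.randint(0, grouped_input.shape[1] + 1, (grouped_input.shape[0],))
# mask[g, m] is True for rows at or beyond the group's valid count.
mask = torch.arange(grouped_input.shape[-2]).expand(
    (grouped_input.shape[0], grouped_input.shape[1])
) >= m_vals.unsqueeze(-1)
grouped_input[mask] = 0.0
# One entry per group: how many rows remain nonzero.
zero_start_index_M = torch.count_nonzero(torch.sum(grouped_input, dim=-1), dim=-1)
assert zero_start_index_M.shape == (G,)
```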
34 changes: 14 additions & 20 deletions fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py
@@ -2316,7 +2316,6 @@ def _kernel_quantize_fp8_row(
     stride_ok,
     stride_zb,
     stride_zm,
-    stride_zn,
     TL_FP8_DTYPE: tl.constexpr,
     MAX_FP8: tl.constexpr,
     EPS: tl.constexpr,
@@ -2354,7 +2353,6 @@ def _kernel_quantize_fp8_row(
         stride_ok (int): Stride of k dimension of output.
         stride_zb (int): Stride of b dimension of jagged index.
         stride_zm (int): Stride of m dimension of jagged index.
-        stride_zn (int): Stride of n dimension of jagged index.
         TL_FP8_DTYPE (tl.dtype): Target fp8 datatype.
         MAX_FP8 (float): Maximum expressible value for FP8.
         EPS (float): Epsilon value for numerical stability.
@@ -2380,24 +2378,22 @@
         + (pid % (M * N)) % N * stride_on
     )

-    if JAGGED:
-        z_offset_base = (
-            pid // (M * N) * stride_zb
-            + (pid % (M * N)) // N * stride_zm
-            + (pid % (M * N)) % N * stride_zn
-        )
-        row_size = tl.load(zero_start_index_M + z_offset_base)
-    else:
-        row_size = K
+    K_in = K

-    blocks = tl.cdiv(row_size, BLOCK_SIZE)
+    if JAGGED:
+        z_offset_base = pid // (M * N) * stride_zb + (pid % (M * N)) // N * stride_zm
+        group_rows = tl.load(zero_start_index_M + z_offset_base)
+        current_row = pid % N
+        # If this row is empty, don't process any of it.
+        if current_row >= group_rows:
+            K_in = 0

     # Calculate max.
     cur_max = 0.0
-    for _k in range(0, blocks):
+    for _k in range(0, tl.cdiv(K_in, BLOCK_SIZE)):
         a = tl.load(
             A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < row_size,
+            mask=n_offset < K_in,
             other=0.0,
         )
         tile_max = tl.max(tl.abs(a))
@@ -2418,15 +2414,14 @@
     for _k in range(0, tl.cdiv(K, BLOCK_SIZE)):
         a = tl.load(
             A + a_offset_base + n_offset * stride_ak,
-            mask=n_offset < row_size,
+            mask=n_offset < K_in,
             other=0.0,
         )
         a_fp8 = a * a_scale
         # Clamp A to fp8 range to make sure there's no overflow.
         # This is required for AMD. Nvidia's default saturation
         # handles it, but it's nice to have anyway.
-        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8)
-        a_fp8.to(TL_FP8_DTYPE)
+        a_fp8 = tl.clamp(a_fp8, -MAX_FP8, MAX_FP8).to(TL_FP8_DTYPE)
         tl.store(
             A_fp8 + a_fp8_offset_base + n_offset * stride_ok,
             a_fp8,
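Reviewer note: the gating logic reads more clearly outside Triton. A plain-Python model of the index math (names and layout are my reading of the kernel, not part of the PR):

```python
def modeled_k_in(pid, M, N, K, zero_start_index_M):
    """Model of the kernel's jagged gating: one program per row of a
    [B, M, N, K] tensor (kernel naming: M = groups, N = rows per group).
    zero_start_index_M[b][m] is the number of valid rows in group (b, m);
    rows at or past that count do no work."""
    b = pid // (M * N)         # batch index
    m = (pid % (M * N)) // N   # group index within the batch
    current_row = pid % N      # row index within the group
    K_in = K
    if current_row >= zero_start_index_M[b][m]:
        K_in = 0               # empty row: max/quantize loops run zero blocks
    return K_in
```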
@@ -2481,7 +2476,6 @@ def triton_quantize_fp8_row(
         a_fp8.stride(3),
         zero_start_index_M.stride(0) if zero_start_index_M is not None else None,
         zero_start_index_M.stride(1) if zero_start_index_M is not None else None,
-        zero_start_index_M.stride(2) if zero_start_index_M is not None else None,
         TL_FP8_DTYPE=tl_dtype,
         MAX_FP8=max_fp8,
         EPS=eps,
@@ -2527,8 +2521,8 @@ def quantize_fp8_row(
     while a.dim() < 4:
         a = a.unsqueeze(0)
     if zero_start_index_M is not None:
-        while zero_start_index_M.dim() < 3:
-            zero_start_index_M = zero_start_index_M.unsqueeze(0)
+        # There should be one value of zero_start_index_M per NxK matrix.
+        zero_start_index_M = zero_start_index_M.view(a.shape[0], a.shape[1])
     a_fp8, a_scale = triton_quantize_fp8_row(a, scale_ub, zero_start_index_M)
     return a_fp8.view(a_shape), a_scale.view(a_shape[:-1])
     # else use pytorch implementation.
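Reviewer note: a sketch of the new calling convention for jagged inputs (import path inferred from this file's location; shapes illustrative):

```python
import torch
from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import quantize_fp8_row

G, M, K = 2, 4, 8
a = torch.randn(G, M, K, dtype=torch.bfloat16, device="cuda")
# One count of valid rows per [M, K] matrix; trailing rows are treated as empty.
zero_start_index_M = torch.tensor([3, 1], device="cuda")

a_fp8, a_scale = quantize_fp8_row(a, zero_start_index_M=zero_start_index_M)
assert a_fp8.shape == (G, M, K)   # fp8 payload
assert a_scale.shape == (G, M)    # one scale per row
```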
80 changes: 52 additions & 28 deletions fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -16,6 +16,16 @@
 import seaborn as sns
 import torch

+try:
+    from accelerators.pytorch.lib.utils.torch_profiler import profiler_or_nullcontext
+except ImportError:
+    from contextlib import nullcontext
+
+    class profiler_or_nullcontext(nullcontext):
+        def __init__(self, *args, **kwargs):
+            super().__init__()
+
+
 from .quantize_ops import get_quantize_ops, QuantizeOpBase


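Reviewer note: the fallback class accepts and discards the profiler's keyword arguments, so call sites can be written one way regardless of environment; e.g.:

```python
# No-op without the internal profiler; emits a trace with it.
with profiler_or_nullcontext(enabled=True, with_stack=True):
    run_benchmark()  # hypothetical workload
```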
@@ -96,6 +106,7 @@ def benchmark_grouped(
     bench_quantize: bool = False,
     use_rotating_buffer_bench: bool = False,
     use_cuda_graph: bool = True,
+    trace: bool = False,
 ) -> Dict[str, Any]:
     num_groups = len(m)
     # Create input tensors.
@@ -127,7 +138,8 @@
         # Also check if the operator is supported.
         if kernel_requested and quantize_op.supported:
             # Get the quantized tensors for this operator.
-            quantized_vals = quantize_op.quantize(A, B)
+            preprocessed_args = quantize_op.preprocess(A, B)
+            quantized_vals = quantize_op.quantize(*preprocessed_args)
             # Compute the output given quantized values.
             output = quantize_op.compute(*quantized_vals)
             # Some kernels may pad output, just take the first m values of each row.
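Reviewer note: `preprocess` slots in ahead of `quantize` in the op protocol. A hypothetical sketch of how a `QuantizeOpBase` subclass satisfies the calls used above (method names from this file; bodies are placeholders):

```python
class ExampleFP8Op(QuantizeOpBase):  # hypothetical subclass
    def preprocess(self, *args):
        # Optional packing/reshaping before quantization; identity by default.
        return args

    def quantize(self, *preprocessed_args):
        # Produce quantized tensors plus any scales the kernel needs.
        raise NotImplementedError

    def compute(self, *quantized_vals):
        # Run the actual kernel on the quantized inputs.
        raise NotImplementedError
```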
@@ -142,20 +154,21 @@
             # Now perform benchmark.
             if bench_quantize:
                 # Benchmark both quantize and compute.
-                ms_runtime = quantize_op.benchmark(
-                    A,
-                    B,
-                    bench_quantize=True,
-                    use_rotating_buffer_bench=use_rotating_buffer_bench,
-                    use_cuda_graph=use_cuda_graph,
-                )
+                with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                    ms_runtime = quantize_op.benchmark(
+                        *preprocessed_args,
+                        bench_quantize=True,
+                        use_rotating_buffer_bench=use_rotating_buffer_bench,
+                        use_cuda_graph=use_cuda_graph,
+                    )
             else:
-                ms_runtime = quantize_op.benchmark(
-                    *quantized_vals,
-                    bench_quantize=False,
-                    use_rotating_buffer_bench=use_rotating_buffer_bench,
-                    use_cuda_graph=use_cuda_graph,
-                )
+                with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                    ms_runtime = quantize_op.benchmark(
+                        *quantized_vals,
+                        bench_quantize=False,
+                        use_rotating_buffer_bench=use_rotating_buffer_bench,
+                        use_cuda_graph=use_cuda_graph,
+                    )

             # Print out results for this op.
             tflops = 0
@@ -197,6 +210,7 @@ def benchmark(
     bench_quantize: bool = False,
     use_rotating_buffer_bench: bool = False,
     use_cuda_graph: bool = True,
+    trace: bool = False,
 ) -> Dict[str, Any]:
     # Create input tensors.
     if b > 1:
@@ -218,8 +232,10 @@
         )
         # Also check if the operator is supported.
         if kernel_requested and quantize_op.supported:
+            # Preprocess data if needed.
+            preprocessed_args = quantize_op.preprocess(A, B)
             # Get the quantized tensors for this operator.
-            quantized_vals = quantize_op.quantize(A, B)
+            quantized_vals = quantize_op.quantize(*preprocessed_args)
             # Compute the output given quantized values.
             output = quantize_op.compute(*quantized_vals)
             # Compare the quantize op output to reference as a sanity check.
@@ -228,20 +244,21 @@
         # Now perform benchmark.
         if bench_quantize:
             # Benchmark both quantize and compute.
-            ms_runtime = quantize_op.benchmark(
-                A,
-                B,
-                bench_quantize=True,
-                use_rotating_buffer_bench=use_rotating_buffer_bench,
-                use_cuda_graph=use_cuda_graph,
-            )
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                ms_runtime = quantize_op.benchmark(
+                    *preprocessed_args,
+                    bench_quantize=True,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
+                )
         else:
-            ms_runtime = quantize_op.benchmark(
-                *quantized_vals,
-                bench_quantize=False,
-                use_rotating_buffer_bench=use_rotating_buffer_bench,
-                use_cuda_graph=use_cuda_graph,
-            )
+            with profiler_or_nullcontext(enabled=trace, with_stack=True):
+                ms_runtime = quantize_op.benchmark(
+                    *quantized_vals,
+                    bench_quantize=False,
+                    use_rotating_buffer_bench=use_rotating_buffer_bench,
+                    use_cuda_graph=use_cuda_graph,
+                )

         # Print out results for this op.
         tflops = 2 * b * m * n * k / (ms_runtime / 1e3) / 1e12
@@ -369,6 +386,7 @@ def main(args: Any):
             args.bench_quantize,
             args.use_rotating_buffer_bench,
             not args.no_cuda_graph,
+            args.trace,
         )
         benchmark_results.append(quantize_measurements)
     if args.export_csv:
@@ -459,6 +477,12 @@ def invoke_main() -> None:
         action="store_true",
         help="If set, benchmark using fixed shapes relevant to ldm workloads.",
     )
+    parser.add_argument(
+        "--trace",
+        default=False,
+        action="store_true",
+        help="If set, produce a performance trace of the benchmark.",
+    )

     args = parser.parse_args()
     main(args)
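Reviewer note: an example invocation with the new flag (module path and companion flag assumed from this file's location and `args.bench_quantize`, not verified):

```
python -m fbgemm_gpu.experimental.gen_ai.bench.quantize_bench --bench_quantize --trace
```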