Add 70b f16 tp1 benchmarking tests
Signed-off-by: aviator19941 <[email protected]>
aviator19941 committed Feb 26, 2025
1 parent d36cd5e commit 46ef7df
Showing 1 changed file with 147 additions and 29 deletions.
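
For orientation, the TP1 tests added here can be selected by name with pytest's standard -k filter. A minimal sketch, assuming it is run from the repository root and that the test harness supplies its own device and compile options through fixtures; nothing below beyond the file path and test-name prefix comes from this commit:

# Minimal sketch: select only the TP1 benchmarks added by this commit.
# Stock pytest only; no claim is made about repo-specific CLI options.
import pytest

exit_code = pytest.main(
    [
        "sharktank/tests/models/llama/benchmark_amdgpu_test.py",
        "-k",
        "testBenchmark70B_f16_TP1",
        "-v",
    ]
)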
176 changes: 147 additions & 29 deletions sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -119,7 +119,7 @@ def setUp(self):
             f"--input=@{self.prefill_args_bs4_128_stride_32_f16}/cs_f16.npy",
             "--benchmark_repetitions=3",
         ]
-        self.iree_run_decode_nondecomposed_args_f16 = [
+        self.iree_run_decode_nondecomposed_args_fp16 = [
             "--function=decode_bs4",
             f"--input=@{self.decode_args_bs4_128_stride_32_f16}/next_tokens.npy",
             f"--input=@{self.decode_args_bs4_128_stride_32_f16}/seq_lens.npy",
@@ -198,7 +198,7 @@ def testBenchmark8B_f16_Non_Decomposed_Input_Len_128(self):
             hip_device_id=self.iree_device,
             vmfb_name=output_vmfb,
             irpa_path=self.irpa_path,
-            args=self.iree_run_decode_nondecomposed_args_f16,
+            args=self.iree_run_decode_nondecomposed_args_fp16,
             cwd=self.repo_root,
         )
 
@@ -289,55 +289,101 @@ class BenchmarkLlama3_1_70B(BaseBenchmarkTest):
     def setUp(self):
         super().setUp()
         # TODO: add numpy files to Azure and download from it
-        self.artifacts_dir = Path("/shark-dev/data/llama3.1/weights/70b")
-        self.artifacts_dir_2048 = Path("/shark-dev/70b")
-        self.irpa_path = self.artifacts_dir / "fp16/llama3.1_70b_f16.irpa"
-        self.irpa_path_fp8 = self.artifacts_dir / "f8/llama70b_fp8.irpa"
+        self.artifacts_dir = Path("/shark-dev/70b")
+        self.weights_dir = self.artifacts_dir / "instruct/weights"
+        self.irpa_path = self.weights_dir / "llama3.1_70b_instruct_fp16.irpa"
+        self.irpa_path_fp8 = self.artifacts_dir / "fp8/llama70b_fp8.irpa"
         self.tensor_parallelism_size = 8
         self.dir_path_70b = self.dir_path / "llama-70b"
         self.temp_dir_70b = Path(self.dir_path_70b)
         self.temp_dir_70b.mkdir(parents=True, exist_ok=True)
-        self.llama70b_f16_torch_sdpa_artifacts = ExportArtifacts(
+        self.llama70b_f16_torch_sdpa_artifacts_tp1 = ExportArtifacts(
             irpa_path=str(self.irpa_path),
             batch_size=4,
             iree_hip_target="gfx942",
             iree_hal_target_device="hip",
             attention_kernel="torch",
-            tensor_parallelism_size=self.tensor_parallelism_size,
+            tensor_parallelism_size=1,
             block_seq_stride=32,
         )
-        self.llama70b_fp8_decomposed_artifacts = ExportArtifacts(
-            irpa_path=str(self.irpa_path_fp8),
+        self.llama70b_f16_torch_sdpa_artifacts_tp8 = ExportArtifacts(
+            irpa_path=str(self.irpa_path),
             batch_size=4,
             iree_hip_target="gfx942",
             iree_hal_target_device="hip",
-            attention_kernel="decomposed",
+            attention_kernel="torch",
             tensor_parallelism_size=self.tensor_parallelism_size,
             block_seq_stride=32,
         )
-        self.llama70b_fp8_torch_sdpa_artifacts = ExportArtifacts(
+        self.llama70b_fp8_torch_sdpa_artifacts_tp1 = ExportArtifacts(
             irpa_path=str(self.irpa_path_fp8),
             batch_size=4,
             iree_hip_target="gfx942",
             iree_hal_target_device="hip",
             attention_kernel="torch",
-            tensor_parallelism_size=self.tensor_parallelism_size,
+            tensor_parallelism_size=1,
             block_seq_stride=32,
         )
+        self.prefill_args_bs4_128_stride_32_tp1_f16 = (
+            self.artifacts_dir / "prefill_args_bs4_128_stride_32"
+        )
+        self.decode_args_bs4_128_stride_32_tp1_f16 = (
+            self.artifacts_dir / "decode_args_bs4_128_stride_32"
+        )
+        self.prefill_args_bs4_2048_stride_32_tp1_f16 = (
+            self.artifacts_dir / "prefill_args_bs4_2048_stride_32"
+        )
+        self.decode_args_bs4_2048_stride_32_tp1_f16 = (
+            self.artifacts_dir / "decode_args_bs4_2048_stride_32"
+        )
         self.prefill_args_bs4_128_stride_32_tp8_f16 = (
             self.artifacts_dir / "prefill_args_bs4_128_stride_32_tp8"
         )
         self.decode_args_bs4_128_stride_32_tp8_f16 = (
             self.artifacts_dir / "decode_args_bs4_128_stride_32_tp8"
         )
         self.prefill_args_bs4_2048_stride_32_tp8_f16 = (
-            self.artifacts_dir_2048 / "prefill_args_bs4_2048_stride_32_tp8"
+            self.artifacts_dir / "prefill_args_bs4_2048_stride_32_tp8"
         )
         self.decode_args_bs4_2048_stride_32_tp8_f16 = (
-            self.artifacts_dir_2048 / "decode_args_bs4_2048_stride_32_tp8"
+            self.artifacts_dir / "decode_args_bs4_2048_stride_32_tp8"
         )
         self.prefill_args_fp8 = self.artifacts_dir / "prefill_args_fp8"
         self.decode_args_fp8 = self.artifacts_dir / "decode_args_fp8"
+        self.iree_run_prefill_nondecomposed_args_128_tp1_fp16 = [
+            "--function=prefill_bs4",
+            f"--input=@{self.prefill_args_bs4_128_stride_32_tp1_f16}/tokens.npy",
+            f"--input=@{self.prefill_args_bs4_128_stride_32_tp1_f16}/seq_lens.npy",
+            f"--input=@{self.prefill_args_bs4_128_stride_32_tp1_f16}/seq_block_ids.npy",
+            f"--input=@{self.prefill_args_bs4_128_stride_32_tp1_f16}/cs_f16.npy",
+            "--benchmark_repetitions=3",
+        ]
+        self.iree_run_decode_nondecomposed_args_128_tp1_fp16 = [
+            "--function=decode_bs4",
+            f"--input=@{self.decode_args_bs4_128_stride_32_tp1_f16}/next_tokens.npy",
+            f"--input=@{self.decode_args_bs4_128_stride_32_tp1_f16}/seq_lens.npy",
+            f"--input=@{self.decode_args_bs4_128_stride_32_tp1_f16}/start_positions.npy",
+            f"--input=@{self.decode_args_bs4_128_stride_32_tp1_f16}/seq_block_ids.npy",
+            f"--input=@{self.decode_args_bs4_128_stride_32_tp1_f16}/cs_f16.npy",
+            "--benchmark_repetitions=3",
+        ]
+        self.iree_run_prefill_nondecomposed_args_2048_tp1_fp16 = [
+            "--function=prefill_bs4",
+            f"--input=@{self.prefill_args_bs4_2048_stride_32_tp1_f16}/tokens.npy",
+            f"--input=@{self.prefill_args_bs4_2048_stride_32_tp1_f16}/seq_lens.npy",
+            f"--input=@{self.prefill_args_bs4_2048_stride_32_tp1_f16}/seq_block_ids.npy",
+            f"--input=@{self.prefill_args_bs4_2048_stride_32_tp1_f16}/cs_f16.npy",
+            "--benchmark_repetitions=3",
+        ]
+        self.iree_run_decode_nondecomposed_args_2048_tp1_fp16 = [
+            "--function=decode_bs4",
+            f"--input=@{self.decode_args_bs4_2048_stride_32_tp1_f16}/next_tokens.npy",
+            f"--input=@{self.decode_args_bs4_2048_stride_32_tp1_f16}/seq_lens.npy",
+            f"--input=@{self.decode_args_bs4_2048_stride_32_tp1_f16}/start_positions.npy",
+            f"--input=@{self.decode_args_bs4_2048_stride_32_tp1_f16}/seq_block_ids.npy",
+            f"--input=@{self.decode_args_bs4_2048_stride_32_tp1_f16}/cs_f16.npy",
+            "--benchmark_repetitions=3",
+        ]
         self.iree_run_prefill_nondecomposed_args_128_tp8_fp16 = (
             [
                 "--function=prefill_bs4",
@@ -418,6 +464,84 @@ def setUp(self):
             "--benchmark_repetitions=3",
         ]
 
+    def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_128(self):
+        output_file_name = self.dir_path_70b / "f16_torch_128"
+        output_mlir = self.llama70b_f16_torch_sdpa_artifacts_tp1.create_file(
+            suffix=".mlir", prefix=output_file_name
+        )
+        output_json = self.llama70b_f16_torch_sdpa_artifacts_tp1.create_file(
+            suffix=".json", prefix=output_file_name
+        )
+        output_vmfb = self.llama70b_f16_torch_sdpa_artifacts_tp1.create_file(
+            suffix=".vmfb", prefix=output_file_name
+        )
+        export_return_code = self.llama70b_f16_torch_sdpa_artifacts_tp1.export_to_mlir(
+            mlir_path=output_mlir,
+            json_path=output_json,
+        )
+        self.llama70b_f16_torch_sdpa_artifacts_tp1.compile_to_vmfb(
+            mlir_path=str(output_mlir),
+            vmfb_path=output_vmfb,
+            hal_dump_path=output_file_name,
+            cwd=self.repo_root,
+            args=self.compile_args,
+        )
+        # benchmark prefill
+        self.llama70b_f16_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
+            hip_device_id=self.iree_device,
+            vmfb_name=output_vmfb,
+            irpa_path=self.irpa_path,
+            args=self.iree_run_prefill_nondecomposed_args_128_tp1_fp16,
+            cwd=self.repo_root,
+        )
+        # benchmark decode
+        self.llama70b_f16_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
+            hip_device_id=self.iree_device,
+            vmfb_name=output_vmfb,
+            irpa_path=self.irpa_path,
+            args=self.iree_run_decode_nondecomposed_args_128_tp1_fp16,
+            cwd=self.repo_root,
+        )
+
+    def testBenchmark70B_f16_TP1_Non_Decomposed_Input_Len_2048(self):
+        output_file_name = self.dir_path_70b / "f16_torch_2048"
+        output_mlir = self.llama70b_f16_torch_sdpa_artifacts_tp1.create_file(
+            suffix=".mlir", prefix=output_file_name
+        )
+        output_json = self.llama70b_f16_torch_sdpa_artifacts_tp1.create_file(
+            suffix=".json", prefix=output_file_name
+        )
+        output_vmfb = self.llama70b_f16_torch_sdpa_artifacts_tp1.create_file(
+            suffix=".vmfb", prefix=output_file_name
+        )
+        export_return_code = self.llama70b_f16_torch_sdpa_artifacts_tp1.export_to_mlir(
+            mlir_path=output_mlir,
+            json_path=output_json,
+        )
+        self.llama70b_f16_torch_sdpa_artifacts_tp1.compile_to_vmfb(
+            mlir_path=str(output_mlir),
+            vmfb_path=output_vmfb,
+            hal_dump_path=output_file_name,
+            cwd=self.repo_root,
+            args=self.compile_args,
+        )
+        # benchmark prefill
+        self.llama70b_f16_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
+            hip_device_id=self.iree_device,
+            vmfb_name=output_vmfb,
+            irpa_path=self.irpa_path,
+            args=self.iree_run_prefill_nondecomposed_args_2048_tp1_fp16,
+            cwd=self.repo_root,
+        )
+        # benchmark decode
+        self.llama70b_f16_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
+            hip_device_id=self.iree_device,
+            vmfb_name=output_vmfb,
+            irpa_path=self.irpa_path,
+            args=self.iree_run_decode_nondecomposed_args_2048_tp1_fp16,
+            cwd=self.repo_root,
+        )
+
     @pytest.mark.xfail(
         reason="Benchmarking Error", strict=True, raises=IreeBenchmarkException
     )
@@ -517,44 +641,38 @@ def testBenchmark70B_f16_TP8_Non_Decomposed_Input_Len_2048(self):
     @pytest.mark.xfail(
         reason="70b fp8 irpa does not exist", strict=True, raises=ExportMlirException
     )
-    def testBenchmark70B_fp8_TP8_Non_Decomposed(self):
+    def testBenchmark70B_fp8_TP1_Non_Decomposed(self):
         output_file_name = self.dir_path_70b / "fp8_torch"
-        output_mlir = self.llama70b_fp8_torch_sdpa_artifacts.create_file(
+        output_mlir = self.llama70b_fp8_torch_sdpa_artifacts_tp1.create_file(
             suffix=".mlir", prefix=output_file_name
         )
-        output_json = self.llama70b_fp8_torch_sdpa_artifacts.create_file(
+        output_json = self.llama70b_fp8_torch_sdpa_artifacts_tp1.create_file(
             suffix=".json", prefix=output_file_name
         )
-        output_vmfb = self.llama70b_fp8_torch_sdpa_artifacts.create_file(
+        output_vmfb = self.llama70b_fp8_torch_sdpa_artifacts_tp1.create_file(
             suffix=".vmfb", prefix=output_file_name
        )
-        output_shard_file_name = (
-            self.artifacts_dir
-            / f"f8/tp8/llama3.1_70b_fp8_tp{self.tensor_parallelism_size}_parameters.irpa"
-        )
-        if output_shard_file_name.exists():
-            self.llama70b_fp8_torch_sdpa_artifacts.irpa_path = output_shard_file_name
-        export_return_code = self.llama70b_fp8_torch_sdpa_artifacts.export_to_mlir(
+        export_return_code = self.llama70b_fp8_torch_sdpa_artifacts_tp1.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
         )
-        self.llama70b_fp8_torch_sdpa_artifacts.compile_to_vmfb(
+        self.llama70b_fp8_torch_sdpa_artifacts_tp1.compile_to_vmfb(
             mlir_path=str(output_mlir),
             vmfb_path=output_vmfb,
             hal_dump_path=output_file_name,
             cwd=self.repo_root,
             args=self.compile_args,
         )
         # benchmark prefill
-        self.llama70b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
+        self.llama70b_fp8_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
            hip_device_id=self.iree_device,
             vmfb_name=output_vmfb,
             irpa_path=self.irpa_path_fp8,
             args=self.iree_run_prefill_args,
             cwd=self.repo_root,
         )
         # benchmark decode
-        self.llama70b_fp8_torch_sdpa_artifacts.iree_benchmark_vmfb(
+        self.llama70b_fp8_torch_sdpa_artifacts_tp1.iree_benchmark_vmfb(
             hip_device_id=self.iree_device,
             vmfb_name=output_vmfb,
             irpa_path=self.irpa_path_fp8,
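Every test added by this commit follows the same export, compile, and benchmark pipeline through the ExportArtifacts helper. Below is a condensed sketch of that flow: the calls and keyword arguments are copied from the diff above, while the standalone wrapper function and the local artifacts alias are illustrative only, not code from the repository.

# Condensed shape of the new TP1 tests; calls and keywords come from the
# diff above, re-wrapped here purely for illustration.
def run_tp1_f16_benchmark(self):
    artifacts = self.llama70b_f16_torch_sdpa_artifacts_tp1
    output_file_name = self.dir_path_70b / "f16_torch_128"
    output_mlir = artifacts.create_file(suffix=".mlir", prefix=output_file_name)
    output_json = artifacts.create_file(suffix=".json", prefix=output_file_name)
    output_vmfb = artifacts.create_file(suffix=".vmfb", prefix=output_file_name)
    # 1) Export the model to MLIR (also writes the model config JSON).
    artifacts.export_to_mlir(mlir_path=output_mlir, json_path=output_json)
    # 2) Compile the MLIR to an IREE .vmfb for the gfx942 HIP target.
    artifacts.compile_to_vmfb(
        mlir_path=str(output_mlir),
        vmfb_path=output_vmfb,
        hal_dump_path=output_file_name,
        cwd=self.repo_root,
        args=self.compile_args,
    )
    # 3) Benchmark prefill, then decode, against the same .vmfb.
    for run_args in (
        self.iree_run_prefill_nondecomposed_args_128_tp1_fp16,
        self.iree_run_decode_nondecomposed_args_128_tp1_fp16,
    ):
        artifacts.iree_benchmark_vmfb(
            hip_device_id=self.iree_device,
            vmfb_name=output_vmfb,
            irpa_path=self.irpa_path,
            args=run_args,
            cwd=self.repo_root,
        )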

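The flag lists built in setUp are ultimately forwarded to IREE's benchmark tool. Below is a hedged reconstruction of the resulting invocation for the 128-token TP1 prefill case: --function, --input, and --benchmark_repetitions are verbatim from the diff, while the module, device, and parameter plumbing (--module, --device=hip://..., --parameters=model=...) is an assumption about how iree_benchmark_vmfb wires in vmfb_name, hip_device_id, and irpa_path.

# Hedged reconstruction only: flags marked "assumed" are guesses at the
# harness plumbing, not taken from this diff.
import subprocess

args_dir = "/shark-dev/70b/prefill_args_bs4_128_stride_32"
irpa = "/shark-dev/70b/instruct/weights/llama3.1_70b_instruct_fp16.irpa"
subprocess.run(
    [
        "iree-benchmark-module",
        "--device=hip://0",            # assumed mapping of hip_device_id
        "--module=f16_torch_128.vmfb", # assumed path of the compiled vmfb
        f"--parameters=model={irpa}",  # assumed mapping of irpa_path
        "--function=prefill_bs4",
        f"--input=@{args_dir}/tokens.npy",
        f"--input=@{args_dir}/seq_lens.npy",
        f"--input=@{args_dir}/seq_block_ids.npy",
        f"--input=@{args_dir}/cs_f16.npy",
        "--benchmark_repetitions=3",
    ],
    check=True,
)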