Commit 93d5791

Add further hipgraph support
1 parent 35fac8e commit 93d5791

4 files changed: +80 -46 lines changed

byte_infer_perf/llm_perf/backends/GPU/gpu_mp_engine.py

+46 -11
@@ -5,7 +5,7 @@
 import pathlib
 from multiprocessing import Queue
 from typing import List
-
+import time
 import torch
 import torch.nn as nn
 import torch.distributed as dist
@@ -92,16 +92,30 @@ def build_inputs(self, forward_inputs):
         ).cuda()

         is_context = forward_inputs["is_context"]
+
+        batch_offset = forward_inputs["cache_batch_offset"]
+
         if is_context:
             forward_inputs["full_attention_mask"] = get_context_masks(
                 forward_inputs["input_ids"],
                 forward_inputs["attention_mask"]
             )
+            slot_offset = torch.tensor([forward_inputs["valid_slot_ids"][0] * batch_offset],
+                                       device = forward_inputs["position_ids"].device,
+                                       dtype = forward_inputs["position_ids"].dtype).unsqueeze(1)
         else:
+            bsz = forward_inputs["input_ids"].shape[0]
             forward_inputs["full_attention_mask"] = get_decode_masks(
                 forward_inputs["input_ids"],
                 forward_inputs["all_kv_len"]
             )
+            forward_inputs["seq_lens"] = torch.tensor([x + y for x, y in zip(forward_inputs["all_q_len"], forward_inputs["all_kv_len"])],
+                                                      dtype=torch.int,
+                                                      device=forward_inputs["position_ids"].device)
+            slot_offset = torch.arange(0, bsz * batch_offset, batch_offset,
+                                       device = forward_inputs["position_ids"].device,
+                                       dtype = forward_inputs["position_ids"].dtype).unsqueeze(1)
+        forward_inputs["slot_mapping"] = forward_inputs["position_ids"] + slot_offset
         return forward_inputs
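Note on the build_inputs changes above: slot_mapping flattens each token position into an index in the paged KV cache, where every sequence owns a fixed stride of cache_batch_offset slots (context requests use the single valid slot id, decode requests use one stride per batch entry). A minimal sketch of the decode-side arithmetic, with a hypothetical helper name and illustrative shapes:

    import torch

    # Sketch only; make_slot_mapping is not a function in this repository.
    def make_slot_mapping(position_ids: torch.Tensor, cache_batch_offset: int) -> torch.Tensor:
        # position_ids: [batch, q_len] token positions within each sequence
        bsz = position_ids.shape[0]
        # batch entry b owns the flat slot range [b * cache_batch_offset, (b + 1) * cache_batch_offset)
        slot_offset = torch.arange(0, bsz * cache_batch_offset, cache_batch_offset,
                                   device=position_ids.device,
                                   dtype=position_ids.dtype).unsqueeze(1)
        return position_ids + slot_offset

    # e.g. positions [[7], [3]] with cache_batch_offset=4096 map to slots [[7], [4099]]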

@@ -192,8 +206,8 @@ def mp_forward(self, *args):
         self._input_queues.put(args, block=True)

         # wait for one subprocess send result back to main process
+        #for _ in range(self.world_size):
         output_dict = self._output_queues.get(block=True)
-
         return output_dict

     # ROCM_HIPGRAPH modify
@@ -240,22 +254,43 @@ def signal_handler(signum, frame):
         logger.info(f"{local_rank}/{world_size} rank is ready")

         graph = torch.cuda.CUDAGraph()
-
+        s = torch.cuda.Stream()
         # model process loop
         while True:
             (
                 forward_inputs,
             ) = input_queue.get(block=True)

+            if 'replay' not in forward_inputs:
+                forward_inputs["cache_batch_offset"] = model.cache_batch_offset
+                inputs_dict = self.build_inputs(forward_inputs)
             # this is the capture phase of graph
             if 'capture' in forward_inputs:
-                graph.reset() # reset cuda graph each time
-                inputs_dict = self.build_inputs(forward_inputs)
+                graph = torch.cuda.CUDAGraph() # reset cuda graph each time
+
                 # model.forward(inputs_dict)
+                _NUM_WARMUP_ITERS=2
+                with torch.cuda.stream(s):
+                    for _ in range(_NUM_WARMUP_ITERS):
+                        logits = model.forward(inputs_dict)
+
+                torch.cuda.current_stream().wait_stream(s)
                 torch.cuda.synchronize()
+                # graph.enable_debug_mode()
                 with torch.cuda.graph(graph):
-                    model.forward(inputs_dict)
+                    logits = model.forward(inputs_dict)
+
                 torch.cuda.synchronize()
+                #graph.debug_dump(f"/src/cuda_graph{str(local_rank)}.dot")
+                output_dict = dict()
+                #inputs_dict["input_ids"][0][0] = 128
+                # graph.replay()
+                # torch.cuda.synchronize()
+
+                output_dict["duration_ms"] = 0
+                # TP realization: rank0 send result back to main process
+                if local_rank == 0:
+                    output_queue.put(output_dict)
                 continue

             log = forward_inputs.get("log", False)
@@ -267,13 +302,13 @@ def signal_handler(signum, frame):
                 workspace_dir.mkdir(exist_ok=True, parents=True)
                 forward_inputs["log_file"] = open(workspace_dir / "run.log", "w")

-
-            inputs_dict = self.build_inputs(forward_inputs)
             start_time = time.perf_counter_ns()

-            # output_dict = model.forward(inputs_dict)
-            graph.replay()
-
+            if 'replay' in forward_inputs:
+                with torch.cuda.stream(s):
+                    graph.replay()
+            else:
+                output_dict = model.forward(inputs_dict)
             torch.cuda.synchronize()
             end_time = time.perf_counter_ns()
             duration_ms = round((end_time - start_time) / 1e6, 3)
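For reference, the capture branch above follows the standard PyTorch CUDA graph recipe: warm up on a side stream, capture a single forward pass into a torch.cuda.CUDAGraph, then replay the graph on later calls after refreshing the captured input tensors in place. A self-contained sketch of that general pattern, independent of this engine (the model and tensor names here are illustrative):

    import torch

    model = torch.nn.Linear(1024, 1024).cuda()
    static_input = torch.randn(8, 1024, device="cuda")

    # warmup iterations on a side stream before capture, so lazy init work stays out of the graph
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(2):
            static_output = model(static_input)
    torch.cuda.current_stream().wait_stream(s)

    # capture one forward pass into the graph
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    # replay: copy new data into the captured input tensor, then replay the recorded work
    static_input.copy_(torch.randn(8, 1024, device="cuda"))
    graph.replay()
    torch.cuda.synchronize()
    print(static_output.shape)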

byte_infer_perf/llm_perf/backends/ROCM/model_impl/modeling_mixtral.py

+14 -19
@@ -875,7 +875,7 @@ def forward(
             None,
             1.0,
             1.0,
-        ).contiguous()
+        )

         attn_output = attn_output.view(bsz, q_len, self.num_heads * self.head_dim)
         attn_output = self.o_proj(attn_output)
@@ -1234,31 +1234,24 @@ def forward(
         **kwargs,
     ) -> Union[Tuple, MoeModelOutputWithPast]:
         residual = None
-        bsz = input_ids.shape[0]
-        is_context = kwargs.get("is_context")
-        valid_slot_ids = kwargs.get("valid_slot_ids")
-        batch_offset = kwargs.get("cache_batch_offset")
-        if is_context:
-            slot_offset = torch.tensor([valid_slot_ids[0] * batch_offset],
-                                       device = position_ids.device,
-                                       dtype = position_ids.dtype).unsqueeze(1)
-        else:
-            slot_offset = torch.arange(0, bsz * batch_offset, batch_offset,
-                                       device = position_ids.device,
-                                       dtype = position_ids.dtype).unsqueeze(1)
-        kwargs["slot_mapping"] = position_ids + slot_offset
+
         if kwargs.pop("override_hidden_states", False):
             random_seed = kwargs.pop("random_seed", None)
             layer_index = kwargs.pop("fixed_layer_index", -1)
             layer_index = layer_index % len(self.layers)
-
             # create random input ids on cpu and copy to device
+<<<<<<< HEAD
             if random_seed is not None:
                 # RuntimeError: Cannot call CUDAGeneratorImpl::set_current_seed during CUDA graph capture.
                 torch.manual_seed(random_seed)
             random_input_ids = torch.randint(10, self.vocab_size, input_ids.shape, dtype=torch.int64, device="cpu").to(input_ids.device)

             hidden_states = self.embed_tokens(random_input_ids)
+=======
+            #torch.manual_seed(random_seed)
+            #random_input_ids = torch.randint(10, self.vocab_size, input_ids.shape, dtype=torch.int64, device="cpu").to(input_ids.device)
+            hidden_states = self.embed_tokens(input_ids)
+>>>>>>> 4494d72 (Add further hipgraph support)

             for _ in self.layers:
                 layer_outputs, residual = self.layers[layer_index](
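A note on the hunk above: the comment retained on the HEAD side of the conflict records that the CUDA generator cannot be reseeded while a CUDA graph is being captured, which is presumably why the random ids are drawn on the CPU and copied over (and why the other side of the conflict comments out the reseeding and feeds input_ids directly). A minimal CPU-only sketch of that preparation step, with an illustrative vocab size:

    import torch

    vocab_size = 32000                                  # illustrative value, not the model config
    input_ids = torch.zeros(1, 16, dtype=torch.int64)   # stands in for the real batch
    torch.manual_seed(1234)                             # seeding happens outside any graph capture
    random_input_ids = torch.randint(10, vocab_size, input_ids.shape, dtype=torch.int64, device="cpu")
    # a .to(input_ids.device) copy would follow for a CUDA batch, keeping RNG state off the device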
@@ -1271,7 +1264,7 @@ def forward(
                     output_router_logits=False,
                     use_cache=False,
                     **kwargs,
-                    )
+                )
         else:
             hidden_states = self.embed_tokens(input_ids)
             for decoder_layer in self.layers:
@@ -1289,9 +1282,8 @@ def forward(

         hidden_states, _ = self.norm(hidden_states, residual)

-        return MoeModelOutputWithPast(
-            last_hidden_state=hidden_states
-        )
+        return hidden_states
+


 class MixtralForCausalLM(MixtralPreTrainedModel):
@@ -1387,8 +1379,11 @@ def forward(

         # print(f'{os.environ.get("LOCAL_RANK", "0")} {outputs=}')
         hidden_states = outputs[0]
+
         logits = self.lm_head(hidden_states)
         logits = logits.float()
+        # if(os.environ.get("LOCAL_RANK") == "0"):
+        # print(f'>>>logits2 {logits.shape}, {logits} {logits.data_ptr()}')
         # print(f'{os.environ.get("LOCAL_RANK", "0")}:{hidden_states.shape=}')
         # print(f'{os.environ.get("LOCAL_RANK", "0")}:{hidden_states=}')
         # print(f'{os.environ.get("LOCAL_RANK", "0")}:{logits.shape=}')

byte_infer_perf/llm_perf/backends/ROCM/model_impl/rocm_mixtral.py

+4 -6
@@ -213,6 +213,7 @@ def init_kvcache(self, dtype):
         while max_num_blocks * self.block_size < max_seq_len * max_batch_size:
             max_num_blocks += 4096
         self.max_num_blocks_per_seq = (max_seq_len + self.block_size - 1) // self.block_size
+        self.cache_batch_offset = self.block_size * self.max_num_blocks_per_seq
         block_tables_lst: List[List[int]] = []
         for batch_idx in range(max_batch_size):
             block_start = self.max_num_blocks_per_seq * batch_idx
@@ -231,17 +232,14 @@ def init_kvcache(self, dtype):
         return block_tables, past_key_values

     def forward(self, inputs : Dict[str, torch.Tensor]):
-        inputs["cache_batch_offset"] = self.block_size * self.max_num_blocks_per_seq
+
         model_outputs = self.transformer_model.forward(
             **inputs,
             past_key_values=(self.block_tables, self.kv_cache)
         )

+
         # context: [1, seq_len] --> [1, seq_len, vocab_size] or [1, 1, vocab_size]
         # decode: [max_batch_size, 1]
-        logits = model_outputs.logits

-        output_dict = {
-            "logits": logits
-        }
-        return output_dict
+        return model_outputs.logits
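cache_batch_offset is now computed once during KV-cache initialization rather than on every forward call: it is the number of cache slots reserved per sequence. A small worked example with illustrative numbers (not taken from the repository's config):

    block_size = 16
    max_seq_len = 4096
    max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size   # 256
    cache_batch_offset = block_size * max_num_blocks_per_seq                # 4096 slots per sequence
    # batch entry b therefore owns the flat cache slot range [b * 4096, (b + 1) * 4096)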

byte_infer_perf/llm_perf/bench_model.py

+16 -10
@@ -131,19 +131,12 @@ def update_template(mode, batch_size, seq_len):
 input_template = update_template("context", 1, 1024)
 is_graph = int(os.environ.get("ENABLE_GRAPH", "0"))

-if is_graph:
-    #ROCM_HIPGRAPH modify
-    input_template['capture'] = 1
-    engine.mp_forward(input_template)
-    input_template.pop('capture')
-
 start_time = time.perf_counter_ns()
 for _ in range(num_warm_iter):
     engine.mp_forward(input_template)
 duration_s = round((time.perf_counter_ns() - start_time) / 1e9, 3)
 logger.info(f"warmup cost: {duration_s}s")

-
 def results_to_csv(file_path, results):
     batch_size_set = sorted(results.keys())
     seq_len_set = set()
@@ -167,6 +160,8 @@ def results_to_csv(file_path, results):

 log_results = []
 if xpu_config["perf_config"]["perf_context"]:
+    print(f'>>> Beginning Context', flush=True)
+
     batch_size_list = [1]
     seq_len_list = xpu_config["perf_config"]["seq_len_list"]

@@ -187,7 +182,12 @@ def results_to_csv(file_path, results):
         test_iter = 0
         duration_ms = 0.
         while test_iter < total_test_iter:
-            result = engine.mp_forward(input_template)
+            if is_graph:
+                input_template['replay'] = 1
+                result = engine.mp_forward(input_template)
+                input_template.pop('replay')
+            else:
+                result = engine.mp_forward(input_template)
             if start_iters > 0:
                 start_iters -= 1
                 continue
@@ -203,10 +203,11 @@ def results_to_csv(file_path, results):
             lines = workspace.joinpath("rank_0", "run.log").read_text().splitlines()
             log_results[-1] += f", {lines[0]}"
         print(log_results[-1])
+        print(f'>>> End of sequence length', flush=True)
     results_to_csv(workspace.joinpath("context_perf.csv"), context_results)

-
 if xpu_config["perf_config"]["perf_decode"]:
+    print(f'>>> Beginning Context', flush=True)
     batch_size_list = xpu_config["perf_config"]["batch_size_list"]
     seq_len_list = xpu_config["perf_config"]["seq_len_list"]

@@ -230,7 +231,12 @@ def results_to_csv(file_path, results):

             duration_ms = 0.
             while test_iter < total_test_iter:
-                result = engine.mp_forward(input_template)
+                if is_graph:
+                    input_template['replay'] = 1
+                    result = engine.mp_forward(input_template)
+                    input_template.pop('replay')
+                else:
+                    result = engine.mp_forward(input_template)
                 if start_iters > 0:
                     start_iters -= 1
                     continue
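With these changes the measured loops toggle replay per call, while the one-shot capture trigger was dropped from the warmup section above. The calling convention appears to be: send one request with 'capture' set so the worker builds inputs, warms up, and captures the graph, then mark each timed request with 'replay' so the worker replays the captured graph instead of running an eager forward. A hedged, runnable sketch of that convention with a stand-in engine (DummyEngine is not part of the repository):

    import os

    class DummyEngine:
        # stands in for the real multi-process engine; only the key protocol matters here
        def mp_forward(self, inputs):
            if 'capture' in inputs:
                return {"duration_ms": 0}      # worker warms up and captures the CUDA graph
            return {"duration_ms": 1.0}        # worker replays the graph or runs eager forward

    engine = DummyEngine()
    input_template = {"input_ids": [[1] * 1024]}
    is_graph = int(os.environ.get("ENABLE_GRAPH", "0"))

    if is_graph:
        input_template['capture'] = 1          # capture once before timing
        engine.mp_forward(input_template)
        input_template.pop('capture')

    for _ in range(5):
        if is_graph:
            input_template['replay'] = 1       # timed iterations replay the captured graph
            result = engine.mp_forward(input_template)
            input_template.pop('replay')
        else:
            result = engine.mp_forward(input_template)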
