Skip to content

Commit

Permalink
Fix GPU-CPU tensor manipulation. Small performance boost (#178)
Browse files Browse the repository at this point in the history
* transfer mask to gpu

* add cpu mask

* slice tensor on gpu side

* add device suffix for clarity

* create mask on cpu once

* transfer tensors to device

* typo fix
  • Loading branch information
vvchernov authored Jan 30, 2024
1 parent 998357b commit e1bd866
Showing 1 changed file with 13 additions and 7 deletions.
20 changes: 13 additions & 7 deletions serve/mlc_serve/model/model_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,13 +75,19 @@ def _is_safe_to_sample(prob_like):
logits = torch.from_dlpack(logits)
num_seq = len(sampling_params)

mask_random = torch.tensor(
mask_random_cpu = torch.tensor(
[p.sampling_type == SamplingType.RANDOM for p in sampling_params],
dtype=torch.bool,
)
mask_greedy = torch.logical_not(mask_random)
mask_greedy_cpu = torch.logical_not(mask_random_cpu)
if logits.device == torch.device("cpu"):
mask_random_dvc = mask_random_cpu
mask_greedy_dvc = mask_greedy_cpu
else: # gpu
mask_random_dvc = mask_random_cpu.to(logits.device)
mask_greedy_dvc = mask_greedy_cpu.to(logits.device)

logits_greedy = logits[mask_greedy]
logits_greedy = logits[mask_greedy_dvc]

if logits_greedy.shape[0] > 0:
res_greedy = torch.argmax(logits_greedy, -1).cpu().numpy()
Expand Down Expand Up @@ -140,7 +146,7 @@ def _is_safe_to_sample(prob_like):
.to(device=logits.device)
)

logits_random = logits[mask_random]
logits_random = logits[mask_random_dvc]

if divide_by_temperature:
t = torch.tensor(temperatures, dtype=logits.dtype, device=logits.device)
Expand All @@ -155,17 +161,17 @@ def _is_safe_to_sample(prob_like):
torch.cuda.nvtx.range_pop()
return None

res_random = torch.multinomial(probs, 1, True).cpu().numpy()[:, 0]
res_random = torch.multinomial(probs, 1, True)[:, 0].cpu().numpy()

if logits_random.shape[0] == num_seq:
torch.cuda.nvtx.range_pop()
return res_random

res = np.empty((num_seq,), dtype=np.int32)
res[mask_random] = res_random
res[mask_random_cpu] = res_random

if logits_greedy.shape[0] > 0:
res[mask_greedy] = res_greedy
res[mask_greedy_cpu] = res_greedy

torch.cuda.nvtx.range_pop()
return res
Expand Down

0 comments on commit e1bd866

Please sign in to comment.