diff --git a/lmdeploy/pytorch/engine/logits_process.py b/lmdeploy/pytorch/engine/logits_process.py
index f7ca9c511..c1d313056 100644
--- a/lmdeploy/pytorch/engine/logits_process.py
+++ b/lmdeploy/pytorch/engine/logits_process.py
@@ -361,6 +361,28 @@ async def __call__(self, all_ids: torch.LongTensor,
     def sampling(self, logits: torch.Tensor):
         """sampling."""
         sampling_inputs = self.sampling_inputs
+
+        def _softmax_scores(scores: torch.Tensor):
+            """softmax scores."""
+            # if score has inf, replace it with max or min finite value, then do softmax
+            if torch.isinf(scores).any():
+                dtype = scores.dtype
+
+                if dtype in [torch.float16, torch.float32, torch.float64]:
+                    max_finite_value = torch.finfo(dtype).max
+                    min_finite_value = torch.finfo(dtype).min
+                elif dtype in [torch.int8, torch.int16, torch.int32, torch.int64]:
+                    max_finite_value = torch.iinfo(dtype).max
+                    min_finite_value = torch.iinfo(dtype).min
+                else:
+                    raise TypeError("Unsupported data type")
+
+                device = scores.device
+
+                scores = torch.where(scores == float('inf'), torch.tensor(max_finite_value, dtype=dtype, device=device), scores)
+                scores = torch.where(scores == float('-inf'), torch.tensor(min_finite_value, dtype=dtype, device=device), scores)
+            softmax_scores = scores.softmax(dim=1)
+            return softmax_scores
 
         def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
             """random sampling."""
@@ -383,7 +405,7 @@ def __random_sampling(scores: torch.Tensor, indices: torch.LongTensor):
             if min_p is not None:
                 scores = _filter_minp_sorted_(scores, min_p)
 
-            softmax_scores = scores.softmax(1)
+            softmax_scores = _softmax_scores(scores)
 
             seeds = sampling_inputs.random_seeds
             offsets = sampling_inputs.random_offsets
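
Note (a minimal sketch of the failure mode, not part of the patch): if a row of
scores contains +inf, PyTorch's softmax returns NaN for the entire row, because
the internal max-subtraction stabilization computes inf - inf = nan; a later
torch.multinomial on that row then fails. Replacing infinities with the dtype's
largest/smallest finite values, as _softmax_scores does above, restores a valid
distribution. Standalone reproduction under those assumptions:

    import torch

    scores = torch.tensor([[0.0, float('inf'), -1.0]])
    print(scores.softmax(dim=1))  # tensor([[nan, nan, nan]])

    # Mirror the patch: swap +/-inf for the dtype's finite extremes first.
    fmax = torch.tensor(torch.finfo(scores.dtype).max)
    fmin = torch.tensor(torch.finfo(scores.dtype).min)
    fixed = torch.where(scores == float('inf'), fmax, scores)
    fixed = torch.where(fixed == float('-inf'), fmin, fixed)
    print(fixed.softmax(dim=1))  # tensor([[0., 1., 0.]]), a valid distribution

A row that is entirely -inf (everything filtered out) hits the same NaN path,
which is why the helper clamps both signs rather than only +inf.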