1
- from typing import Union
1
+ from typing import Any , Dict , Union
2
2
3
3
import torch
4
4
from neuronx_distributed .operators .argmax import argmax as nxd_argmax
5
5
from neuronx_distributed .operators .topk import topk as nxd_topk
6
6
from neuronx_distributed .parallel_layers import parallel_state
7
7
from torch_neuronx .xla_impl .ops import xla_hlo_call
8
8
9
- from neuronx_distributed_inference .models .config import NeuronConfig
9
+ from neuronx_distributed_inference .models .config import NeuronConfig , OnDeviceSamplingConfig
10
10
11
11
12
12
@xla_hlo_call
@@ -18,6 +18,62 @@ def rand_like(tensor):
18
18
return dtype [shape ].Rng (minimum , maximum , distribution = 1 ) # Uniform distribution
19
19
20
20
21
def validate_sampling_params(
    params: torch.Tensor,
    on_device_sampling_config: "Union[Dict[str, Any], OnDeviceSamplingConfig]",
) -> None:
    """
    Validate per-request sampling parameters for language models.

    Args:
        params (torch.Tensor): Tensor of shape (batch_size, 3) containing sampling
            parameters in the order: top-k, top-p, temperature. Any numeric dtype
            is accepted; values are cast to float32 before validation.
        on_device_sampling_config: Either an ``OnDeviceSamplingConfig`` instance or
            a dict with a ``"global_topk"`` entry; supplies the maximum allowed top-k.

    Raises:
        ValueError: If the tensor shape is wrong or any of the parameters are invalid.
    """
    # Guard ndim as well as the trailing dim so a 1-D/0-D tensor raises the
    # documented ValueError instead of an IndexError on params.shape[1].
    if params.ndim != 2 or params.shape[1] != 3:
        raise ValueError(f"Expected tensor of shape (batch_size, 3), but got {params.shape}")

    # Autocast params tensor to float32 so comparisons behave uniformly
    # regardless of the caller-supplied dtype.
    params = params.to(torch.float32)

    # Unpack parameters (column order is fixed: top-k, top-p, temperature).
    top_k, top_p, temperature = params[:, 0], params[:, 1], params[:, 2]

    # Dispatch on dict first so the config type is duck-typed: any non-dict
    # config object just needs a `global_topk` attribute.
    if isinstance(on_device_sampling_config, dict):
        global_top_k = on_device_sampling_config["global_topk"]
    else:
        global_top_k = on_device_sampling_config.global_topk

    # Validate top-k value range: -1 means "top-k disabled", otherwise the
    # value must lie in (0, global_top_k].
    valid_top_k = (top_k == -1) | ((top_k > 0) & (top_k <= global_top_k))
    if not torch.all(valid_top_k):
        raise ValueError(
            f"Invalid top-k values found. top-k must be -1 or greater than 0 but less than or equal to {global_top_k=}. Found {top_k=}."
        )

    # top-k arrives as float but must represent whole numbers exactly.
    if not torch.equal(top_k, top_k.floor()):
        raise ValueError(
            f"Invalid top-k values found. top-k values should be able to be represented as integer values, but found decimal parts. Found {top_k=}."
        )

    # Validate top-p: must lie in the half-open interval (0.0, 1.0].
    valid_top_p = (top_p > 0.0) & (top_p <= 1.0)
    if not torch.all(valid_top_p):
        raise ValueError(
            f"Invalid top-p values found. top-p must be in the range (0.0, 1.0]. Found {top_p=}."
        )

    # Validate temperature: must be strictly positive.
    valid_temp = temperature > 0.0
    if not torch.all(valid_temp):
        raise ValueError(
            f"Invalid temperature values found. Temperature must be strictly greater than 0.0. Found {temperature=}."
        )
75
+
76
+
21
77
def prepare_sampling_params (batch_size , top_k = [1 ], top_p = [1.0 ], temperature = [1.0 ]):
22
78
top_k = prepare_tensor (top_k )
23
79
top_p = prepare_tensor (top_p )
0 commit comments