 aten = torch.ops.aten


-def get_no_random_sign_vector() -> torch.Tensor:
+def get_no_random_sign_vector(device: int) -> torch.Tensor:
     """Non-random sign vector for Hadamard transform."""
-    return torch.tensor([1], dtype=torch.float32, device="cuda")
+    return torch.tensor([1], dtype=torch.float32, device=device)


 def get_sign_from_vector(vector: torch.Tensor) -> int:
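Note (not part of the commit): the single-element `[1]` returned above works because it broadcasts against the 16x16 identity inside `get_rht_matrix` below, so the non-random path flips no signs. A minimal sketch of that broadcast:

```python
# Sketch: a 1-element sign vector broadcasts across the identity,
# leaving it unchanged, i.e. the "no random sign" path is a no-op mask.
import torch

signs = torch.tensor([1], dtype=torch.float32)
sign_matrix = signs * torch.eye(16, dtype=torch.float32)
assert torch.equal(sign_matrix, torch.eye(16))
```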
@@ -45,7 +45,7 @@ def get_sign_from_vector(vector: torch.Tensor) -> int:
     return mask.item()


-def get_wgrad_sign_vector() -> torch.Tensor:
+def get_wgrad_sign_vector(device: int) -> torch.Tensor:
     """Hard-coded random signs for Hadamard transform.

     https://xkcd.com/221/
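Aside: the diff elides the body of `get_sign_from_vector`, so the exact packing convention is not visible here. A hypothetical sketch of one plausible convention (bit `i` set where element `i` is negative; the name `pack_signs` and the bit layout are illustrative only):

```python
# Hypothetical sketch: pack a +/-1 vector into an integer bitmask.
# The real convention lives in the elided body of get_sign_from_vector.
import torch

def pack_signs(vector: torch.Tensor) -> int:
    mask = 0
    for i, negative in enumerate((vector < 0).tolist()):
        mask |= int(negative) << i  # set bit i for each -1 entry
    return mask
```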
@@ -54,11 +54,11 @@ def get_wgrad_sign_vector() -> torch.Tensor:
     return torch.tensor(
         [1, 1, 1, -1, 1, -1, -1, -1, -1, -1, -1, 1, -1, 1, -1, -1],
         dtype=torch.float32,
-        device="cuda",
+        device=device,
     )


-def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
+def get_hadamard_matrix(hadamard_dimension: int, device: int) -> torch.Tensor:
     """Construct a 16x16 Hadamard matrix."""
     assert hadamard_dimension == 16, "Only hadamard dimension 16 is supported."
     hadamard_scale = 1 / math.sqrt(hadamard_dimension)
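Context for the hard-coded literal in the next hunk: the visible rows agree with the standard Sylvester Hadamard matrix, scaled by `1/sqrt(16)` so the transform is orthonormal. A minimal sketch (not this repo's code) that reproduces such a matrix via Kronecker products:

```python
# Sketch: Sylvester's construction, H_16 = H_2 (x) H_2 (x) H_2 (x) H_2,
# scaled so that the resulting transform is orthonormal.
# Assumes dim is a power of two.
import math
import torch

def sylvester_hadamard(dim: int, device: torch.device) -> torch.Tensor:
    h2 = torch.tensor([[1.0, 1.0], [1.0, -1.0]], device=device)
    h = torch.tensor([[1.0]], device=device)
    while h.shape[0] < dim:
        h = torch.kron(h, h2)  # double the order each iteration
    return h / math.sqrt(dim)
```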
@@ -83,30 +83,30 @@ def get_hadamard_matrix(hadamard_dimension: int) -> torch.Tensor:
                 [1, -1, -1, 1, -1, 1, 1, -1, -1, 1, 1, -1, 1, -1, -1, 1],
             ],
             dtype=torch.float32,
-            device="cuda",
+            device=device,
         )
         * hadamard_scale
     )


 @functools.lru_cache(maxsize=None)
-def get_rht_matrix(with_random_sign_mask: bool) -> torch.Tensor:
+def get_rht_matrix(with_random_sign_mask: bool, device: int) -> torch.Tensor:
     """Construct matrix used in random Hadamard transform."""
     hadamard_dimension = 16
     if with_random_sign_mask:
-        signs = get_wgrad_sign_vector()
+        signs = get_wgrad_sign_vector(device=device)
     else:
-        signs = get_no_random_sign_vector()
-    sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device="cuda")
-    rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension)
+        signs = get_no_random_sign_vector(device=device)
+    sign_matrix = signs * torch.eye(hadamard_dimension, dtype=torch.float32, device=device)
+    rht_matrix = sign_matrix @ get_hadamard_matrix(hadamard_dimension, device=device)
     return rht_matrix.to(dtype=torch.bfloat16)


 @functools.lru_cache(maxsize=None)
-def get_random_sign_mask_for_rht(with_random_sign_mask: bool) -> int:
+def get_random_sign_mask_for_rht(with_random_sign_mask: bool, device: int) -> int:
     """Sign mask for random Hadamard transform."""
     if with_random_sign_mask:
-        return get_sign_from_vector(get_wgrad_sign_vector())
+        return get_sign_from_vector(get_wgrad_sign_vector(device=device))
     return 0

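Why `device` is threaded into these helpers rather than resolved inside them: both are memoized with `functools.lru_cache`, so every argument is part of the cache key. With the old signatures, the first call pinned the cached `"cuda"` tensor to whichever GPU was current at that moment; passing the index gives each device its own entry. A sketch of the resulting behavior, assuming a CUDA machine:

```python
# Sketch: device index as part of the lru_cache key.
import torch

m0 = get_rht_matrix(True, 0)           # built on cuda:0, then memoized
assert m0 is get_rht_matrix(True, 0)   # cache hit: same tensor object

if torch.cuda.device_count() > 1:
    m1 = get_rht_matrix(True, 1)       # distinct cache entry for cuda:1
    assert m1.device == torch.device("cuda:1")
```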
@@ -152,8 +152,10 @@ def __init__(
         self.amax_reduction_group = amax_reduction_group
         self.with_2d_quantization = with_2d_quantization
         self.stochastic_rounding = stochastic_rounding
-        self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(with_random_sign_mask)
-        self.rht_matrix = get_rht_matrix(with_random_sign_mask)
+        self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(
+            with_random_sign_mask, torch.cuda.current_device()
+        )
+        self.rht_matrix = get_rht_matrix(with_random_sign_mask, torch.cuda.current_device())

     def update_quantized(
         self,
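On the caller side, `torch.cuda.current_device()` returns a plain `int` (the active GPU's ordinal), which is hashable for `lru_cache` and accepted directly by torch factory functions, matching the new `device: int` annotations. A small sketch, assuming at least two GPUs:

```python
# Sketch: the current-device index composes with both lru_cache and
# torch factory functions.
import torch

with torch.cuda.device(1):
    idx = torch.cuda.current_device()   # -> 1
    t = torch.zeros(4, device=idx)      # lands on cuda:1
    assert t.device == torch.device("cuda:1")
```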