@@ -242,7 +242,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
242
242
assert e + p == total_bits - has_sign
243
243
# the exponent is biased to 2^(e-1) -1 == 0
244
244
evalues = []
245
- pvalues = []
246
245
for i , val in enumerate (range (- (2 ** (exponent_bits - has_sign )), 2 ** (exponent_bits - has_sign ), 1 )):
247
246
evalues .append (2 ** val )
248
247
@@ -1365,8 +1364,6 @@ def optimizer_update_8bit_blockwise(
1365
1364
gnorm_scale : float = 1.0 ,
1366
1365
skip_zeros = False ,
1367
1366
) -> None :
1368
- optim_func = None
1369
-
1370
1367
is_on_gpu ([p , g , state1 , state2 , qmap1 , qmap2 , absmax1 , absmax2 ])
1371
1368
1372
1369
torch .ops .bitsandbytes .optimizer_update_8bit_blockwise (
@@ -2116,7 +2113,7 @@ def spmm_coo(
2116
2113
assert cooA .values .numel () == nnz
2117
2114
assert cooA .cols == B .shape [0 ]
2118
2115
2119
- transposed_B = False if B .is_contiguous () else True
2116
+ transposed_B = not B .is_contiguous ()
2120
2117
2121
2118
ldb = B .stride ()[(1 if transposed_B else 0 )]
2122
2119
ldc = B .shape [1 ]
@@ -2165,12 +2162,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
2165
2162
assert cooA .values .numel () == nnz
2166
2163
assert cooA .cols == B .shape [0 ], f"{ cooA .cols } vs { B .shape } "
2167
2164
2168
- transposed_B = False if B .is_contiguous () else True
2169
-
2170
- ldb = B .stride ()[(1 if transposed_B else 0 )]
2171
- ldc = B .shape [1 ]
2172
-
2173
- values , counts = torch .unique (cooA .rowidx , return_counts = True )
2165
+ _ , counts = torch .unique (cooA .rowidx , return_counts = True )
2174
2166
offset = counts .cumsum (0 ).int ()
2175
2167
max_count , max_idx = torch .sort (counts , descending = True )
2176
2168
max_idx = max_idx .int ()
@@ -2190,11 +2182,8 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
2190
2182
cnnz_rows = ct .c_int32 (counts .numel ())
2191
2183
cnnz = ct .c_int32 (cooA .nnz )
2192
2184
crowsA = ct .c_int32 (cooA .rows )
2193
- ccolsA = ct .c_int32 (cooA .cols )
2194
2185
crowsB = ct .c_int32 (B .shape [1 ])
2195
2186
ccolsB = ct .c_int32 (B .shape [1 ])
2196
- cldb = ct .c_int32 (ldb )
2197
- cldc = ct .c_int32 (ldc )
2198
2187
2199
2188
with _cuda_device_of (B ):
2200
2189
is_on_gpu ([cooA .rowidx , cooA .colidx , cooA .values , B , out , dequant_stats ])
0 commit comments