Commit 7d4b52a

Merge pull request #1694 from bitsandbytes-foundation/cherrypick-v046-1
Cherry picks for v0.46.1 patch release
2 parents: d35b170 + 39d7505 (commit 7d4b52a)

File tree: 11 files changed, +115 -52 lines

.github/scripts/build-cuda.sh

Lines changed: 4 additions & 4 deletions

@@ -11,14 +11,14 @@ if [[ -v cuda_targets ]]; then
 elif [ "${build_arch}" = "aarch64" ]; then
     build_capability="75;80;90"

-    # CUDA 12.8: Add sm100
-    [[ "${cuda_version}" == 12.8.* ]] && build_capability="75;80;90;100"
+    # CUDA 12.8+: Add sm100/sm120
+    [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;90;100;120"
 else
     # By default, target Maxwell through Hopper.
     build_capability="50;52;60;61;70;75;80;86;89;90"

-    # CUDA 12.8: Add sm100 and sm120; remove < sm75 to align with PyTorch 2.7+cu128 minimum
-    [[ "${cuda_version}" == 12.8.* ]] && build_capability="75;80;86;89;90;100;120"
+    # CUDA 12.8+: Add sm100 and sm120; remove < sm75 to align with PyTorch 2.7+cu128 minimum
+    [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;86;89;90;100;120"
 fi

 [[ "${build_os}" = windows-* ]] && python3 -m pip install ninja

.github/workflows/python-package.yml

Lines changed: 4 additions & 3 deletions

@@ -72,16 +72,17 @@ jobs:
       - os: windows-latest
         arch: x86_64
     cuda_version:
-      ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"]
+      ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1", "12.9.1"]
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v4
       # Windows: We install Cuda on the agent (slow)
-      - uses: Jimver/[email protected]
+      - uses: Jimver/cuda-toolkit@c35baa1a18fd1fc9dcf47c5bd839bf30559c0bc3 # v0.2.24
        if: startsWith(matrix.os, 'windows')
        id: cuda-toolkit
        with:
-         cuda: ${{ matrix.cuda_version }}
+         # Temporary: Use CUDA 12.9.0 for Windows until 12.9.1 is supported with this action.
+         cuda: ${{ matrix.cuda_version == '12.9.1' && '12.9.0' || matrix.cuda_version }}
         method: "network"
         sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]'
        linux-local-args: '["--toolkit"]'
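Note: the new cuda: value uses the GitHub Actions `cond && a || b` idiom, which behaves like a ternary as long as `a` is truthy (the non-empty string '12.9.0' is). A purely illustrative Python equivalent of the selection:

    # Illustrative only: mirrors `matrix.cuda_version == '12.9.1' && '12.9.0' || matrix.cuda_version`.
    def effective_windows_cuda(matrix_cuda_version: str) -> str:
        # Temporary pin: install 12.9.0 until the action supports 12.9.1.
        return "12.9.0" if matrix_cuda_version == "12.9.1" else matrix_cuda_version

    assert effective_windows_cuda("12.9.1") == "12.9.0"
    assert effective_windows_cuda("12.8.1") == "12.8.1"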

.github/workflows/tests.yml

Lines changed: 12 additions & 6 deletions

@@ -49,22 +49,23 @@ jobs:
   build-cuda:
     strategy:
       matrix:
-        cuda_version: ["11.8.0", "12.6.3", "12.8.1"]
-        os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025]
+        cuda_version: ["11.8.0", "12.6.3", "12.8.1", "12.9.1"]
+        os: [ubuntu-22.04, ubuntu-22.04-arm]
         include:
           - os: ubuntu-22.04
             arch: x86_64
           - os: ubuntu-22.04-arm
             arch: aarch64
           - os: windows-2025
             arch: x86_64
+            cuda_version: "11.8.0"
     runs-on: ${{ matrix.os }}

     steps:
       - uses: actions/checkout@v4

       - name: Install CUDA Toolkit
-        uses: Jimver/[email protected]
+        uses: Jimver/cuda-toolkit@c35baa1a18fd1fc9dcf47c5bd839bf30559c0bc3 # v0.2.24
         if: startsWith(matrix.os, 'windows')
         id: cuda-toolkit
         with:
@@ -193,7 +194,7 @@ jobs:
         os: [ubuntu-22.04, windows-2025]
         arch: [x86_64]
         gpu: [T4, L40S]
-        cuda_version: ["11.8.0", "12.6.3", "12.8.1"]
+        cuda_version: ["11.8.0", "12.6.3", "12.8.1", "12.9.1"]
         include:
           - cuda_version: "11.8.0"
             torch_version: "2.2.2"
@@ -204,6 +205,9 @@ jobs:
           - cuda_version: "12.8.1"
             torch_version: "2.7.0"
             pypi_index: "https://download.pytorch.org/whl/cu128"
+          - cuda_version: "12.9.1"
+            torch_version: "2.8.0"
+            pypi_index: "https://download.pytorch.org/whl/nightly/cu129"


       # Linux L40S runners
@@ -236,12 +240,14 @@ jobs:
             gpu: T4
             runner: CUDA-Windows-x64
             cuda_version: "11.8.0"
-            torch_version: "2.7.0"
+            torch_version: "2.7.1" # Note: this is the last PyTorch release supporting CUDA 11.8.
             pypi_index: "https://download.pytorch.org/whl/cu118"

         exclude:
           # Our current T4 Windows runner has a driver too old (471.11)
           # and cannot support CUDA 12+. Skip for now.
+          - os: windows-2025
+            cuda_version: "12.9.1"
           - os: windows-2025
             cuda_version: "12.8.1"
           - os: windows-2025
@@ -273,7 +279,7 @@ jobs:

       - name: Install dependencies
         run: |
-          pip install torch==${{ matrix.torch_version }} --index-url ${{ matrix.pypi_index }}
+          pip install --pre torch~=${{ matrix.torch_version }}.dev0 --index-url ${{ matrix.pypi_index }}
           pip install -e ".[test]"
           pip install pytest-cov
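Note on the new install line: with torch_version 2.8.0 it expands to `pip install --pre torch~=2.8.0.dev0`, i.e. a compatible-release range (>=2.8.0.dev0, ==2.8.*) that also admits nightly/dev wheels from the cu129 index. A quick check of the specifier with the packaging library (the dev version string below is a made-up example):

    from packaging.specifiers import SpecifierSet

    spec = SpecifierSet("~=2.8.0.dev0")  # what torch~=${torch_version}.dev0 expands to for 2.8.0

    print("2.8.0.dev20250615" in spec)  # True  -- dev/nightly builds match (the specifier names a dev release)
    print("2.8.0" in spec)              # True  -- the final release also matches
    print("2.9.0" in spec)              # False -- outside the ==2.8.* compatible range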

MANIFEST.in

Lines changed: 3 additions & 0 deletions

@@ -0,0 +1,3 @@
+include CMakeLists.txt
+graft csrc
+graft include
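These rules pull CMakeLists.txt plus the csrc/ and include/ trees into the source distribution, so a from-sdist build can compile the native library. A quick way to verify, assuming an sdist has already been built into dist/ (the filename glob is illustrative):

    import glob
    import tarfile

    # Assumes e.g. `python -m build --sdist` has produced dist/bitsandbytes-<version>.tar.gz
    sdist = sorted(glob.glob("dist/bitsandbytes-*.tar.gz"))[-1]
    with tarfile.open(sdist) as tf:
        names = tf.getnames()

    print(any(n.endswith("CMakeLists.txt") for n in names))  # expect True
    print(any("/csrc/" in n for n in names))                 # expect True
    print(any("/include/" in n for n in names))              # expect True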

bitsandbytes/nn/modules.py

Lines changed: 1 addition & 7 deletions

@@ -290,13 +290,6 @@ def from_prequantized(

         return self

-    @classmethod
-    def __torch_function__(cls, func, types, args=(), kwargs=None):
-        if kwargs is None:
-            kwargs = {}
-        with torch._C.DisableTorchFunctionSubclass():
-            return func(*args, **kwargs)
-
     def _quantize(self, device):
         w = self.data.contiguous().to(device)
         w_4bit, quant_state = bnb.functional.quantize_4bit(
@@ -353,6 +346,7 @@ def to(self, *args, **kwargs):
                 compress_statistics=self.compress_statistics,
                 quant_type=self.quant_type,
                 quant_storage=self.quant_storage,
+                bnb_quantized=self.bnb_quantized,
             )

         return new_param
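The added `bnb_quantized=self.bnb_quantized` matters because `Params4bit.to()` constructs a new parameter; without carrying the flag over, already-packed 4-bit data could be treated as raw weights on a later move. A usage sketch under that assumption (needs a CUDA device; the layer size is arbitrary):

    import bitsandbytes as bnb

    layer = bnb.nn.Linear4bit(128, 128, bias=False)  # weight is a Params4bit
    layer = layer.to("cuda")  # first move quantizes the weight and sets bnb_quantized=True

    # With the fix, the flag survives further moves, so the packed 4-bit data
    # is not mistaken for unquantized weights and quantized a second time.
    layer = layer.to("cuda")
    print(layer.weight.bnb_quantized)  # expected: True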

bitsandbytes/optim/adam.py

Lines changed: 24 additions & 3 deletions

@@ -3,7 +3,6 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.

-
 from bitsandbytes.optim.optimizer import Optimizer2State


@@ -100,8 +99,10 @@ def __init__(
                 The weight decay value for the optimizer.
             amsgrad (`bool`, defaults to `False`):
                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in Adam8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in Adam8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -113,14 +114,23 @@ def __init__(
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("Adam8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since Adam8bit always uses 8-bit optimization
+            raise ValueError("Adam8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
             lr,
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
@@ -283,8 +293,10 @@ def __init__(
                 The weight decay value for the optimizer.
             amsgrad (`bool`, defaults to `False`):
                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in PagedAdam8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in PagedAdam8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -296,14 +308,23 @@ def __init__(
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("PagedAdam8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since PagedAdam8bit always uses 8-bit optimization
+            raise ValueError("PagedAdam8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
             lr,
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
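With these checks, misuse now fails fast instead of being silently ignored. A short usage sketch (the printed message matches the ValueError string added above):

    import torch
    import bitsandbytes as bnb

    model = torch.nn.Linear(32, 32)

    # Fine: the default optim_bits=32 is accepted purely for signature compatibility.
    opt = bnb.optim.Adam8bit(model.parameters(), lr=1e-3)

    # Now rejected explicitly:
    try:
        bnb.optim.Adam8bit(model.parameters(), amsgrad=True)
    except ValueError as err:
        print(err)  # Adam8bit does not support amsgrad=True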

bitsandbytes/optim/adamw.py

Lines changed: 31 additions & 14 deletions

@@ -2,6 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
+
 from bitsandbytes.optim.optimizer import Optimizer2State


@@ -25,7 +26,7 @@ def __init__(
         Base AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -86,7 +87,7 @@ def __init__(
         8-bit AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -98,8 +99,10 @@ def __init__(
                 The weight decay value for the optimizer.
             amsgrad (`bool`, defaults to `False`):
                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in AdamW8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in AdamW8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -111,14 +114,23 @@ def __init__(
             is_paged (`bool`, defaults to `False`):
                 Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("AdamW8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since AdamW8bit always uses 8-bit optimization
+            raise ValueError("AdamW8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
             lr,
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
@@ -147,7 +159,7 @@ def __init__(
         32-bit AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -207,7 +219,7 @@ def __init__(
         Paged AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -229,8 +241,6 @@ def __init__(
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.
         """
         super().__init__(
             "adam",
@@ -267,7 +277,7 @@ def __init__(
         Paged 8-bit AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -279,8 +289,10 @@ def __init__(
                 The weight decay value for the optimizer.
             amsgrad (`bool`, defaults to `False`):
                 Whether to use the [AMSGrad](https://hf.co/papers/1904.09237) variant of Adam that uses the maximum of past squared gradients instead.
+                Note: This parameter is not supported in PagedAdamW8bit and must be False.
             optim_bits (`int`, defaults to 32):
                 The number of bits of the optimizer state.
+                Note: This parameter is not used in PagedAdamW8bit as it always uses 8-bit optimization.
             args (`object`, defaults to `None`):
                 An object with additional arguments.
             min_8bit_size (`int`, defaults to 4096):
@@ -289,17 +301,24 @@ def __init__(
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.
         """
+        # Validate unsupported parameters
+        if amsgrad:
+            raise ValueError("PagedAdamW8bit does not support amsgrad=True")
+
+        if optim_bits != 32:
+            # We allow the default value of 32 to maintain compatibility with the function signature,
+            # but any other value is invalid since PagedAdamW8bit always uses 8-bit optimization
+            raise ValueError("PagedAdamW8bit only supports optim_bits=32 (default value for compatibility)")
+
         super().__init__(
             "adam",
             params,
             lr,
             betas,
             eps,
             weight_decay,
-            8,
+            8,  # Hardcoded to 8 bits
             args,
             min_8bit_size,
             percentile_clipping,
@@ -327,7 +346,7 @@ def __init__(
         Paged 32-bit AdamW optimizer.

         Arguments:
-            params (`torch.tensor`):
+            params (`torch.Tensor`):
                 The input parameters to optimize.
             lr (`float`, defaults to 1e-3):
                 The learning rate.
@@ -349,8 +368,6 @@ def __init__(
                 Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
             block_wise (`bool`, defaults to `True`):
                 Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
-            is_paged (`bool`, defaults to `False`):
-                Whether the optimizer is a paged optimizer or not.
         """
         super().__init__(
             "adam",
