Commit d6d6dad

fix triton multiple gpus and some other issues (#539)
1 parent 88e6e3b commit d6d6dad

File tree

5 files changed: +55, -52 lines


auto_round/inference/auto_quantizer.py (+2)

@@ -281,6 +281,8 @@ def from_dict(cls, config_dict, return_unused_kwargs=False, **kwargs):

         if "auto-round" not in quant_method:
             config_dict["packing_format"] = f"auto_round:{quant_method}"
+
+
         return super().from_dict(config_dict, return_unused_kwargs=return_unused_kwargs, **kwargs)
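For context on the hunk above: a checkpoint whose quant_method is not auto-round (e.g. plain GPTQ or AWQ) gets its method folded into packing_format as auto_round:<method> before the parent from_dict runs. A minimal standalone sketch of that behavior, using a plain dict as a stand-in for the real config object (the key names come straight from the diff; the helper name is invented here):

```python
# Sketch only: a plain dict stands in for the quantization config that the
# from_dict override above receives.
def tag_packing_format(config_dict: dict) -> dict:
    quant_method = config_dict.get("quant_method", "")
    if "auto-round" not in quant_method:
        # e.g. quant_method "gptq" -> packing_format "auto_round:gptq"
        config_dict["packing_format"] = f"auto_round:{quant_method}"
    return config_dict


print(tag_packing_format({"quant_method": "gptq"}))
# {'quant_method': 'gptq', 'packing_format': 'auto_round:gptq'}
```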

auto_round/inference/convert_model.py (+6, -5)

@@ -383,11 +383,12 @@ def _import_exllamav2_kernels():
     """Attempts to import ExLlamaV2 kernels for performance optimization."""
     try:
         from exllamav2_kernels import gemm_half_q_half, make_q_matrix  # pylint: disable=E0611, E0401
-    except ImportError:
-        raise ImportError(
+    except:
+        logger.warning_once(
             "AutoGPTQ ExLlamaV2 has not been installed, Please install it using the following command: "
             "`pip install git+https://github.com/AutoGPTQ/AutoGPTQ.git@b8b4127`"
         )
+        logger.warning_once("try to fallback to other autogptq backends for now")


 def _create_quant_layer(layer, layer_backend, config, in_features, out_features):
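The hunk above turns a missing or broken ExLlamaV2 kernel build from a hard ImportError into a one-time warning, so conversion can continue on another AutoGPTQ backend. The bare except in the diff also covers failures other than ImportError (ABI mismatch, missing CUDA libraries). A minimal sketch of that soft-import pattern, with the standard logging module standing in for auto_round's logger.warning_once:

```python
import logging

logger = logging.getLogger("auto_round.sketch")  # stand-in, not the real module logger


def exllamav2_kernels_available() -> bool:
    """Return True only if the optional ExLlamaV2 kernels import cleanly."""
    try:
        from exllamav2_kernels import gemm_half_q_half, make_q_matrix  # noqa: F401
        return True
    except Exception:
        # Warn and fall back instead of aborting model conversion.
        logger.warning(
            "AutoGPTQ ExLlamaV2 kernels unavailable; falling back to other autogptq backends."
        )
        return False
```
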
@@ -520,19 +521,19 @@ def convert_hf_model(model: nn.Module, target_device="cpu"):
     else:
         backend = "auto"

-
     ##target_backend could be None
     _, backend = parse_target_device_and_backend(backend)

-    if hasattr(quantization_config, "packing_format"):  # pragma: no cover
+    if hasattr(quantization_config,
+               "packing_format") and "auto-round" in quantization_config.quant_method:  # pragma: no cover
         packing_format = quantization_config.packing_format
     elif 'gptq' in quantization_config.quant_method:  # pragma: no cover
         packing_format = "auto_gptq"
     elif "awq" in quantization_config.quant_method:
         packing_format = "auto_awq"
     else:  # pragma: no cover
         packing_format = "auto_gptq"
-        logger.warning("Quantization backend must be specified. Set it to 'auto_gptq' by default.")
+        logger.warning("quantization backend must be specified. Set it to 'auto_gptq' by default.")
     if packing_format == "auto":
         packing_format = "auto_gptq"
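The hunk above also tightens packing-format selection: a packing_format field on the config is only trusted when quant_method contains auto-round; otherwise the format is derived from quant_method, with auto_gptq as the fallback. A standalone sketch of that resolution order, using SimpleNamespace as a stand-in for the Hugging Face quantization config object:

```python
from types import SimpleNamespace


def resolve_packing_format(quantization_config) -> str:
    # Only auto-round configs carry a packing_format worth trusting directly.
    if hasattr(quantization_config, "packing_format") and \
            "auto-round" in quantization_config.quant_method:
        packing_format = quantization_config.packing_format
    elif "gptq" in quantization_config.quant_method:
        packing_format = "auto_gptq"
    elif "awq" in quantization_config.quant_method:
        packing_format = "auto_awq"
    else:
        packing_format = "auto_gptq"  # default when nothing matches
    return "auto_gptq" if packing_format == "auto" else packing_format


print(resolve_packing_format(SimpleNamespace(quant_method="awq")))  # auto_awq
print(resolve_packing_format(
    SimpleNamespace(quant_method="auto-round", packing_format="auto_round:gptq")))  # auto_round:gptq
```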

auto_round/version.py (+1, -1)

@@ -14,4 +14,4 @@
 """Intel® auto-round: An open-source Python library
 supporting popular model weight only compression based on signround."""

-__version__ = "0.5.0"
+__version__ = "0.5.1"

auto_round_extension/cuda/triton_utils/dequant.py (+23, -23)

@@ -123,29 +123,29 @@ def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None, input_dtype=torc
     """
     Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8
     """
-
-    num_groups = scales.shape[0]
-    outfeatures = scales.shape[1]
-    infeatures = g_idx.shape[0]
-
-    out = torch.empty((infeatures, outfeatures), device="cuda", dtype=input_dtype)
-    numels = out.numel()
-    maxq = 2 ** bits - 1 if maxq is None else maxq
-    grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731
-
-    dequant_kernel_248[grid](
-        g_idx,
-        scales,
-        qweight,
-        qzeros,
-        out,
-        numels,
-        maxq=maxq,
-        bits=bits,
-        outfeatures=outfeatures,
-        num_groups=num_groups,
-    )
-    return out
+    with torch.cuda.device(qweight.device):
+        num_groups = scales.shape[0]
+        outfeatures = scales.shape[1]
+        infeatures = g_idx.shape[0]
+
+        out = torch.empty((infeatures, outfeatures), device=qweight.device, dtype=input_dtype)
+        numels = out.numel()
+        maxq = 2 ** bits - 1 if maxq is None else maxq
+        grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731
+
+        dequant_kernel_248[grid](
+            g_idx,
+            scales,
+            qweight,
+            qzeros,
+            out,
+            numels,
+            maxq=maxq,
+            bits=bits,
+            outfeatures=outfeatures,
+            num_groups=num_groups,
+        )
+        return out


 def quant_matmul_248(

auto_round_extension/cuda/triton_utils_zp/dequant.py (+23, -23)

@@ -123,29 +123,29 @@ def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None, input_dtype=torc
     """
     Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8
     """
-
-    num_groups = scales.shape[0]
-    outfeatures = scales.shape[1]
-    infeatures = g_idx.shape[0]
-
-    out = torch.empty((infeatures, outfeatures), device="cuda", dtype=input_dtype)
-    numels = out.numel()
-    maxq = 2 ** bits - 1 if maxq is None else maxq
-    grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731
-
-    dequant_kernel_248[grid](
-        g_idx,
-        scales,
-        qweight,
-        qzeros,
-        out,
-        numels,
-        maxq=maxq,
-        bits=bits,
-        outfeatures=outfeatures,
-        num_groups=num_groups,
-    )
-    return out
+    with torch.cuda.device(qweight.device):
+        num_groups = scales.shape[0]
+        outfeatures = scales.shape[1]
+        infeatures = g_idx.shape[0]
+
+        out = torch.empty((infeatures, outfeatures), device=qweight.device, dtype=input_dtype)
+        numels = out.numel()
+        maxq = 2 ** bits - 1 if maxq is None else maxq
+        grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),)  # noqa: E731
+
+        dequant_kernel_248[grid](
+            g_idx,
+            scales,
+            qweight,
+            qzeros,
+            out,
+            numels,
+            maxq=maxq,
+            bits=bits,
+            outfeatures=outfeatures,
+            num_groups=num_groups,
+        )
+        return out


 def quant_matmul_248(
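Both Triton dequant launchers (the zero-point and non-zero-point variants) receive the same fix: previously the output buffer was always allocated on the hard-coded "cuda" device (cuda:0) and the kernel launched in whatever CUDA device context was current, which can fail or misplace the output when a quantized layer is sharded onto another GPU, e.g. with device_map="auto". Entering torch.cuda.device(qweight.device) and allocating on qweight.device keeps the allocation and the kernel launch on the GPU that holds the weight. A minimal sketch of the pattern, with the Triton launch elided and the helper name invented for illustration:

```python
import torch


def dequant_on_weight_device(qweight: torch.Tensor, out_shape, dtype=torch.float16):
    # Activate the GPU that actually holds the quantized weight so both the
    # allocation below and (in the real launcher) the Triton kernel launch
    # target that device instead of the default cuda:0.
    with torch.cuda.device(qweight.device):
        out = torch.empty(out_shape, device=qweight.device, dtype=dtype)
        # dequant_kernel_248[grid](...) would run here, on qweight.device
        return out


if torch.cuda.device_count() >= 2:
    w = torch.zeros((32, 32), dtype=torch.int32, device="cuda:1")
    y = dequant_on_weight_device(w, (32, 32))
    assert y.device == w.device  # stays on cuda:1 instead of drifting to cuda:0
```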
