1 change: 1 addition & 0 deletions src/llmcompressor/args/dataset_arguments.py
@@ -203,6 +203,7 @@ class DatasetArguments(CustomDatasetArguments):
"_prepare_fsmt_decoder_inputs",
"_prepare_4d_causal_attention_mask_with_cache_position",
"_update_linear_attn_mask",
"project_per_layer_inputs",
],
metadata={
"help": "List of functions to ignore during tracing, either "
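For context on what this ignore list controls: during sequential tracing, ignored functions are wrapped so that torch.fx records them as opaque calls instead of tracing their data-dependent internals. A minimal sketch of that mechanism, with a hypothetical `project_inputs` helper and `ToyModule` (not the real `project_per_layer_inputs`):

```python
import torch
import torch.fx


@torch.fx.wrap  # traced graphs record a single call_function node instead of the body
def project_inputs(x: torch.Tensor) -> torch.Tensor:
    if x.sum() > 0:  # data-dependent branch that symbolic tracing cannot follow
        return x * 2
    return x


class ToyModule(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return project_inputs(x)


gm = torch.fx.symbolic_trace(ToyModule())
print(gm.graph)  # project_inputs appears as an opaque call, its branch untraced
```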
2 changes: 1 addition & 1 deletion src/llmcompressor/args/utils.py
@@ -70,7 +70,7 @@ def parse_args(

# raise deprecation warnings
if dataset_args.remove_columns is not None:
logger.warn(
logger.warning(
"`remove_columns` argument is deprecated. When tokenizing datasets, all "
"columns which are invalid inputs to the tokenizer will be removed",
DeprecationWarning,
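A side note on the call above: a logging call does not interpret the `DeprecationWarning` positional argument as a warning category, it is simply an extra format argument. If a filterable warning were wanted instead, the standard-library idiom would look like the sketch below; `parse_remove_columns` is a hypothetical stand-in for this code path, not the project's API:

```python
import warnings


def parse_remove_columns(remove_columns) -> None:
    # Hypothetical helper mirroring the deprecation path above
    if remove_columns is not None:
        warnings.warn(
            "`remove_columns` argument is deprecated. When tokenizing datasets, all "
            "columns which are invalid inputs to the tokenizer will be removed",
            DeprecationWarning,
            stacklevel=2,  # attribute the warning to the caller
        )
```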
40 changes: 34 additions & 6 deletions src/llmcompressor/pipelines/sequential/ast_helpers.py
@@ -4,14 +4,15 @@
import linecache
import sys
import textwrap
import traceback
from typing import List

import torch

from llmcompressor.pipelines.sequential.ast_utils.auto_wrapper import AutoWrapper
from llmcompressor.utils import patch_attr

__all__ = ["autowrap_forwards"]
__all__ = ["autowrap_forwards", "append_autowrap_source_on_fail"]


@contextlib.contextmanager
@@ -58,22 +59,49 @@ def autowrap_forward(module: torch.nn.Module, ignore: List[str]):
# autowrap untraceable code
auto_wrapper = AutoWrapper(namespace, ignore)
tree = auto_wrapper.auto_wrap(tree)
source = ast.unparse(tree)

# compile new forward function from autowrapped code
filename = f"{module.__class__.__name__}_{hash(module)}_autowrapped"
code = compile(tree, filename=filename, mode="exec")
filename = f"<Autowrapped {module.__class__.__name__} {id(module)}>"
code = compile(source, filename=filename, mode="exec")
exec(code, namespace) # ensure ns of functions is the same ns as torch.fx.wrap

# enable better tracebacks if autowrapped code fails
source_str = ast.unparse(tree)
linecache.cache[filename] = (
len(source_str),
len(source),
None,
[line + "\n" for line in source_str.splitlines()],
[line + "\n" for line in source.splitlines()],
filename,
)

# patch forward with autowrapped forward
new_forward = namespace["forward"].__get__(module)
with patch_attr(module, "forward", new_forward):
yield


@contextlib.contextmanager
def append_autowrap_source_on_fail():
try:
yield
except Exception as exception:
_exc_type, _exc_value, exc_tb = sys.exc_info()
tb_list = traceback.extract_tb(exc_tb)

for frame in reversed(tb_list):
if "Autowrapped" in frame.filename:
source_lines = linecache.getlines(frame.filename)
lineno = frame.lineno

# annotate failing line
source_lines = [
("> " if i + 1 == lineno else "  ") + line
for i, line in enumerate(source_lines)
]

message = f"{exception}\n\n"
message += f"\n--- {frame.filename}:{lineno} ---\n"
message += "".join(source_lines)
raise RuntimeError(message) from exception

raise exception
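The `linecache.cache` entry registered above is what makes this work: an entry whose mtime field is `None` survives `linecache.checkcache`, which the `traceback` module calls before formatting, so frames whose `co_filename` is the synthetic name can be resolved back to source. A self-contained sketch of the same trick (the filename and `boom` function are invented for the demo):

```python
import linecache
import traceback

source = "def boom():\n    raise ValueError('raised from generated code')\n"
filename = "<generated demo>"  # mirrors the "<Autowrapped ...>" naming above

# Register the generated source; entry format is (size, mtime, lines, fullname)
linecache.cache[filename] = (
    len(source),
    None,  # mtime=None marks the entry as never stale
    [line + "\n" for line in source.splitlines()],
    filename,
)

namespace = {}
exec(compile(source, filename=filename, mode="exec"), namespace)

try:
    namespace["boom"]()
except ValueError:
    traceback.print_exc()  # the traceback now shows the generated source line
```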
16 changes: 13 additions & 3 deletions src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py
@@ -53,7 +53,13 @@ def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
:param node: function definition whose decorators will be stripped
:return: function definition without decorators
"""
node.decorator_list = []
node.decorator_list = [
decorator_name
for decorator_name in node.decorator_list
if isinstance(decorator_name, ast.Name)
and decorator_name.id in ("can_return_tuple",) # modifies func signature
]

if node.name == "forward":
for arg in node.args.args:
self._local_names.add(arg.arg)
@@ -104,6 +110,11 @@ def visit_If(self, node: ast.If) -> Union[ast.If, ast.Assign]:
try:
value = bool(self._eval_expr(node.test))

# force a wrap if any assignments occur within the if statement
for expr in ast.walk(node):
if isinstance(expr, ast.NamedExpr):
raise Exception("If statement contains assignment")

except Exception:
return self._wrap_if_possible(node)

@@ -165,8 +176,7 @@ def _can_wrap(self, node: ast.AST) -> bool:
without its original context. In the future, we can add more checks for module
calls (see `visit_If`)
"""
analyzer = ControlFlowAnalyzer()
return analyzer.is_valid(node)
return ControlFlowAnalyzer().is_valid(node)

def _wrap_if_possible(self, node: ast.AST) -> Union[ast.AST, ast.Assign, ast.Call]:
"""
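The walrus guard added to `visit_If` exists because constant-folding a branch would silently drop any name bound by `:=` inside it, and detecting that case is a plain AST walk. A minimal sketch of the check:

```python
import ast

stmt = ast.parse("if (x := flag()):\n    pass").body[0]  # parsed only, never executed

# Mirrors the check added to visit_If: any NamedExpr forces wrapping
has_walrus = any(isinstance(n, ast.NamedExpr) for n in ast.walk(stmt))
print(has_walrus)  # True -> wrap the statement instead of folding the branch
```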
@@ -74,6 +74,13 @@ def visit_Assign(self, node: ast.Assign):
for target in node.targets:
self.visit(target)

def visit_NamedExpr(self, node: ast.NamedExpr):
# Visit the right side of the assignment first
self.visit(node.value)

# Now visit the left side of the assignment
self.visit(node.target)

def visit_If(self, node: ast.If):
self.visit(node.test)

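Visiting `value` before `target` is the point of this override: names read on the right-hand side must be recorded as loads before the walrus target is recorded as a binding, while the default `generic_visit` would visit `target` first because `NamedExpr._fields` is `('target', 'value')`. A toy visitor illustrating the order:

```python
import ast


class OrderDemo(ast.NodeVisitor):
    """Records the order in which names are read or bound."""

    def __init__(self) -> None:
        self.events = []

    def visit_Name(self, node: ast.Name) -> None:
        kind = "bind" if isinstance(node.ctx, ast.Store) else "read"
        self.events.append((kind, node.id))

    def visit_NamedExpr(self, node: ast.NamedExpr) -> None:
        self.visit(node.value)  # right-hand side first: plain reads
        self.visit(node.target)  # then the walrus target, as a binding


demo = OrderDemo()
demo.visit(ast.parse("(x := y + 1)"))
print(demo.events)  # [('read', 'y'), ('bind', 'x')]
```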
43 changes: 22 additions & 21 deletions src/llmcompressor/pipelines/sequential/helpers.py
@@ -2,6 +2,7 @@
import inspect
from collections import deque
from dataclasses import dataclass
from types import FunctionType, MethodType
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple

import torch
@@ -26,7 +27,7 @@
from llmcompressor.utils.helpers import calibration_forward_context, patch_attr
from llmcompressor.utils.pytorch.module import get_no_split_params

from .ast_helpers import autowrap_forwards
from .ast_helpers import append_autowrap_source_on_fail, autowrap_forwards

if TYPE_CHECKING:
from llmcompressor.args.dataset_arguments import DatasetArguments
@@ -69,15 +70,8 @@ def forward(self, *args, **kwargs) -> Dict[str, Any]:

forward_fn = self._code.globals.get("forward")

try:
outputs = forward_fn(*args, **kwargs)
except Exception as exception:
raise RuntimeError(
"Raised an exception during execution of the following code:\n"
f"```\n{add_line_numbers(self._code.src)}\n```"
) from exception

return outputs
with append_autowrap_source_on_fail():
return forward_fn(*args, **kwargs)


def trace_subgraphs(
@@ -118,19 +112,26 @@ def trace_subgraphs(

# autowrap forwards
stack.enter_context(autowrap_forwards(ancestors, ignore))
stack.enter_context(patch_attr(type(model), "forward", model.forward.__func__))

graph = GraphModule(
model,
tracer.trace(
model,
dummy_inputs=sample_input,
concrete_args=concrete_args,
complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
# bug in trace throws an error for variadic
# args and kwargs in function signature
),
)
# avoid bug where pytorch cannot handle wrapped root functions
unwrapped = inspect.unwrap(model.forward).__get__(model)
stack.enter_context(patch_attr(model, "forward", unwrapped))
stack.enter_context(patch_attr(type(model), "forward", unwrapped.__func__))
assert isinstance(model.forward, MethodType)
assert isinstance(type(model).forward, FunctionType)

with append_autowrap_source_on_fail():
graph = GraphModule(
model,
tracer.trace(
model,
dummy_inputs=sample_input,
concrete_args=concrete_args,
complete_concrete_args_with_inputs_not_in_dummy_inputs=False,
# bug in trace throws an error for variadic
# args and kwargs in function signature
),
)

# copy metadata
graph.config = model.config
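The unwrap-and-rebind sequence above targets decorated forwards such as those wrapped by transformers' `can_return_tuple`, whose `functools.wraps` wrapper confuses torch.fx's root-function handling. A self-contained sketch of the mechanics, using a stand-in decorator rather than the real one:

```python
import functools
import inspect
from types import FunctionType, MethodType


def can_return_tuple(fn):  # stand-in for the transformers decorator
    @functools.wraps(fn)  # sets wrapper.__wrapped__, which inspect.unwrap follows
    def wrapper(self, *args, **kwargs):
        return fn(self, *args, **kwargs)

    return wrapper


class Model:
    @can_return_tuple
    def forward(self, x):
        return x + 1


model = Model()

# Same pattern as trace_subgraphs: unwrap the decorated function, rebind to instance
unwrapped = inspect.unwrap(model.forward).__get__(model)
assert isinstance(unwrapped, MethodType)  # bound method, as asserted above
assert isinstance(inspect.unwrap(Model.forward), FunctionType)
print(unwrapped(1))  # 2
```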
@@ -1,3 +1,4 @@
# flake8: noqa
import ast
import textwrap
from types import SimpleNamespace
@@ -21,13 +22,14 @@ def check_wrapping(

wrapped_lines = ast.unparse(wrapped).splitlines()
output_lines = textwrap.dedent(output).splitlines()[1:]
lines = ("\n".join(wrapped_lines), "\n".join(output_lines))

assert len(wrapped_lines) == len(output_lines)
assert len(wrapped_lines) == len(output_lines), lines
for wrapped_line, output_line in zip(wrapped_lines, output_lines):
if "# skip" in output_line:
continue

assert wrapped_line == output_line
assert wrapped_line == output_line, lines


def test_static_if():
@@ -189,3 +191,24 @@ def forward(a, *b, c=5, **d):
() = wrapped_0(a, b, c, d)
"""
check_wrapping(source, output)


def test_walrus():
"""Checks that if statements containing walrus assignments are wrapped"""

source = """
def forward():
if (x := (1 + 2)):
pass
"""
output = """
@torch.fx.wrap
def wrapped_0():
if (x := (1 + 2)):
pass
return (x,)

def forward():
(x,) = wrapped_0() # skip: some envs use "(x,)" -> "x,"
"""
check_wrapping(source, output)
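The expected output above is itself executable; a pure-Python sketch (keeping `wrapped_0` as the generated name) confirms the walrus binding survives by being returned and unpacked:

```python
def wrapped_0():
    if (x := (1 + 2)):
        pass
    return (x,)


def forward():
    (x,) = wrapped_0()
    return x


print(forward())  # 3
```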
9 changes: 9 additions & 0 deletions tests/llmcompressor/transformers/tracing/test_models.py
@@ -4,6 +4,7 @@
from transformers import (
AutoModelForCausalLM,
Gemma3ForConditionalGeneration,
Gemma3nForConditionalGeneration,
Idefics3ForConditionalGeneration,
Llama4ForConditionalGeneration,
LlavaForConditionalGeneration,
@@ -49,6 +50,7 @@
"text",
[],
),
("google/gemma-3n-E2B-it", AutoModelForCausalLM, None, "text", ["timm"]),
("unsloth/DeepSeek-R1-0528-BF16", AutoModelForCausalLM, None, "text", []),
# --- vision ---
(
@@ -122,6 +124,13 @@
"vision",
[],
),
(
"google/gemma-3n-E2B-it",
Gemma3nForConditionalGeneration,
None,
"vision",
["timm"],
),
# --- audio ---
(
"openai/whisper-large-v3",