Skip to content

Commit 80ceffa

Browse files
committed
add rotation args
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 6af0778 commit 80ceffa

File tree

2 files changed

+57
-43
lines changed

2 files changed

+57
-43
lines changed

examples/transform/quip_example.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,14 @@
1414
# NOTE: because the datafree pipeline is being used in this
1515
# example, you can use additional GPUs to support larger models
1616
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
17-
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype="auto")
17+
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, dtype="auto")
1818
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
1919

2020
# Configure the quantization algorithm to run.
21-
# * apply spinquant transforms to model in order to make quantization easier
21+
# * apply quip transforms to model in order to make quantization easier
2222
# * quantize the weights to 4 bit with a group size 128
2323
recipe = [
24-
QuIPModifier(targets="Linear", transform_type="random-hadamard"),
24+
QuIPModifier(rotations=["v", "u"], transform_type="random-hadamard"),
2525
QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
2626
]
2727

@@ -35,7 +35,7 @@
3535
input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
3636
model.device
3737
)
38-
output = model.generate(input_ids, max_new_tokens=100)
38+
output = model.generate(input_ids, max_new_tokens=50)
3939
print(tokenizer.decode(output[0]))
4040
print("==========================================\n\n")
4141

src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 53 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class QuIPModifier(Modifier):
5858
    :param rotations: which rotations to construct; any subset of ["v", "u"],
        defaults to both — NOTE(review): confirm exact v/u semantics against QuIP docs
    :param transform_config: Optional transform config for overriding provided arguments
    """  # noqa: E501

    # Consumed by _create_config to decide which schemes to build; entries are
    # lowercased by validate_lowercase_list before the Literal check runs.
    rotations: List[Literal["v", "u"]] = Field(default_factory=lambda: ["v", "u"])
    transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
        default="random-hadamard"
    )
@@ -77,6 +78,12 @@ def validate_not_implemented(cls, value, info: ValidationInfo):
7778
raise NotImplementedError(f"{info.field_name} is not supported right now")
7879
return value
7980

81+
@field_validator("rotations", mode="before")
82+
def validate_lowercase_list(cls, value):
83+
if isinstance(value, list):
84+
value = [v.lower() if isinstance(v, str) else v for v in value]
85+
return value
86+
8087
def on_initialize(self, state: State, **kwargs) -> bool:
8188
if self.transform_config is not None:
8289
return True
@@ -111,45 +118,52 @@ def on_finalize(self, state: State, **kwargs) -> bool:
111118
return True
112119

113120
def _create_config(self) -> TransformConfig:
114-
return TransformConfig(
115-
config_groups={
116-
"v": TransformScheme(
117-
type=self.transform_type,
118-
apply=[
119-
TransformArgs(
120-
targets=self.targets,
121-
location="input", # non-mergable
122-
ignore=self.ignore,
123-
),
124-
TransformArgs(
125-
targets=self.targets,
126-
location="weight_input",
127-
inverse=True,
128-
ignore=self.ignore,
129-
),
130-
],
131-
randomize=self.randomize,
132-
requires_grad=self.learnable,
133-
precision=self.precision,
121+
config_groups = dict()
122+
if "v" in self.rotations:
123+
config_groups["v"] = self._create_v_scheme()
124+
if "u" in self.rotations:
125+
config_groups["u"] = self._create_u_scheme()
126+
127+
return TransformConfig(config_groups=config_groups)
128+
129+
def _create_v_scheme(self) -> TransformScheme:
130+
return TransformScheme(
131+
type=self.transform_type,
132+
apply=[
133+
TransformArgs(
134+
targets=self.targets,
135+
location="input", # non-mergable
136+
ignore=self.ignore,
137+
),
138+
TransformArgs(
139+
targets=self.targets,
140+
location="weight_input",
141+
inverse=True,
142+
ignore=self.ignore,
143+
),
144+
],
145+
randomize=self.randomize,
146+
requires_grad=self.learnable,
147+
precision=self.precision,
148+
)
149+
150+
def _create_u_scheme(self) -> TransformScheme:
151+
return TransformScheme(
152+
type=self.transform_type,
153+
apply=[
154+
TransformArgs(
155+
targets=self.targets,
156+
location="weight_output",
157+
ignore=self.ignore,
134158
),
135-
"u": TransformScheme(
136-
type=self.transform_type,
137-
apply=[
138-
TransformArgs(
139-
targets=self.targets,
140-
location="weight_output",
141-
ignore=self.ignore,
142-
),
143-
TransformArgs(
144-
targets=self.targets,
145-
location="output", # non-mergable
146-
inverse=True,
147-
ignore=self.ignore,
148-
),
149-
],
150-
randomize=self.randomize,
151-
requires_grad=self.learnable,
152-
precision=self.precision,
159+
TransformArgs(
160+
targets=self.targets,
161+
location="output", # non-mergable
162+
inverse=True,
163+
ignore=self.ignore,
153164
),
154-
}
165+
],
166+
randomize=self.randomize,
167+
requires_grad=self.learnable,
168+
precision=self.precision,
155169
)

0 commit comments

Comments
 (0)