
Commit f53071b

[Transform] Support separating v and u transforms of quip (#1782)
## Purpose ##

* Support a `rotations` argument on `QuIPModifier`, allowing users to select the V and U rotations separately.
* According to the research, the default can be changed to V once a basic recovery evaluation has been done.

## Changes ##

* Add the `rotations` argument.
* Update docstrings and comments to remove references to SpinQuant.

## Testing ##

* Tested model coherence with V and U applied separately.

---------

Signed-off-by: Kyle Sayers <[email protected]>
1 parent 727513c commit f53071b
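
For illustration, a minimal sketch of the new argument (not part of this commit; the import paths follow the package layout shown in the diff and may differ between releases):

```python
from llmcompressor.modifiers.quantization import QuantizationModifier
from llmcompressor.modifiers.transform import QuIPModifier

# Select only the input-side (V) rotation; add "u" to also rotate output sides.
# Omitting `rotations` keeps the default of ["v", "u"].
recipe = [
    QuIPModifier(rotations=["v"], transform_type="random-hadamard"),
    QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
]
```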

2 files changed: 60 additions & 46 deletions

examples/transform/quip_example.py

Lines changed: 3 additions & 3 deletions
@@ -18,10 +18,10 @@
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

 # Configure the quantization algorithm to run.
-# * apply spinquant transforms to model in order to make quantization easier
+# * apply quip transforms to model in order to make quantization easier
 # * quantize the weights to 4 bit with a group size 128
 recipe = [
-    QuIPModifier(targets="Linear", transform_type="random-hadamard"),
+    QuIPModifier(rotations=["v", "u"], transform_type="random-hadamard"),
     QuantizationModifier(targets="Linear", scheme="W4A16", ignore=["lm_head"]),
 ]

@@ -35,7 +35,7 @@
 input_ids = tokenizer("Hello my name is", return_tensors="pt").input_ids.to(
     model.device
 )
-output = model.generate(input_ids, max_new_tokens=100)
+output = model.generate(input_ids, max_new_tokens=50)
 print(tokenizer.decode(output[0]))
 print("==========================================\n\n")


src/llmcompressor/modifiers/transform/quip/base.py

Lines changed: 57 additions & 43 deletions
@@ -32,18 +32,18 @@ class QuIPModifier(Modifier):

     Lifecycle:
         - on_initialize
-            - infer SpinQuantMappings & NormMappings
-            - as needed, create transform schemes for R1, R2, R3, & R4
+            - as needed, create transform schemes for V (input) and U (output)
         - on_start
-            - normalize embeddings
-            - fuse norm layers into subsequent Linear layers
             - apply TransformConfig
                 - fuse transforms into weights for mergeable transforms
                 - add hooks for online transforms
         - on sequential epoch end
         - on_end
         - on_finalize

+    :param rotations: which rotation schemes to apply to the model. Including `"v"` will
+        rotate the input side of weights, and including `"u"` will rotate the output
+        side of weights (note that v does not require u and vice-versa)
     :param transform_type: The type of transform to apply to the model.
         `"hadamard"` has the least performance cost but only supports sizes which are
         powers of power of two.
@@ -58,6 +58,7 @@
     :param transform_config: Optional transform config for overriding provided arguments
     """  # noqa: E501

+    rotations: List[Literal["v", "u"]] = Field(default_factory=lambda: ["v", "u"])
     transform_type: Literal["hadamard", "random-hadamard", "random-matrix"] = Field(
         default="random-hadamard"
     )
@@ -77,6 +78,12 @@ def validate_not_implemented(cls, value, info: ValidationInfo):
             raise NotImplementedError(f"{info.field_name} is not supported right now")
         return value

+    @field_validator("rotations", mode="before")
+    def validate_lowercase_list(cls, value):
+        if isinstance(value, list):
+            value = [v.lower() if isinstance(v, str) else v for v in value]
+        return value
+
     def on_initialize(self, state: State, **kwargs) -> bool:
         if self.transform_config is not None:
             return True
@@ -111,45 +118,52 @@ def on_finalize(self, state: State, **kwargs) -> bool:
         return True

     def _create_config(self) -> TransformConfig:
-        return TransformConfig(
-            config_groups={
-                "v": TransformScheme(
-                    type=self.transform_type,
-                    apply=[
-                        TransformArgs(
-                            targets=self.targets,
-                            location="input",  # non-mergable
-                            ignore=self.ignore,
-                        ),
-                        TransformArgs(
-                            targets=self.targets,
-                            location="weight_input",
-                            inverse=True,
-                            ignore=self.ignore,
-                        ),
-                    ],
-                    randomize=self.randomize,
-                    requires_grad=self.learnable,
-                    precision=self.precision,
+        config_groups = dict()
+        if "v" in self.rotations:
+            config_groups["v"] = self._create_v_scheme()
+        if "u" in self.rotations:
+            config_groups["u"] = self._create_u_scheme()
+
+        return TransformConfig(config_groups=config_groups)
+
+    def _create_v_scheme(self) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            apply=[
+                TransformArgs(
+                    targets=self.targets,
+                    location="input",  # non-mergable
+                    ignore=self.ignore,
+                ),
+                TransformArgs(
+                    targets=self.targets,
+                    location="weight_input",
+                    inverse=True,
+                    ignore=self.ignore,
+                ),
+            ],
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
+        )
+
+    def _create_u_scheme(self) -> TransformScheme:
+        return TransformScheme(
+            type=self.transform_type,
+            apply=[
+                TransformArgs(
+                    targets=self.targets,
+                    location="weight_output",
+                    ignore=self.ignore,
                 ),
-                "u": TransformScheme(
-                    type=self.transform_type,
-                    apply=[
-                        TransformArgs(
-                            targets=self.targets,
-                            location="weight_output",
-                            ignore=self.ignore,
-                        ),
-                        TransformArgs(
-                            targets=self.targets,
-                            location="output",  # non-mergable
-                            inverse=True,
-                            ignore=self.ignore,
-                        ),
-                    ],
-                    randomize=self.randomize,
-                    requires_grad=self.learnable,
-                    precision=self.precision,
+                TransformArgs(
+                    targets=self.targets,
+                    location="output",  # non-mergable
+                    inverse=True,
+                    ignore=self.ignore,
                 ),
-            }
+            ],
+            randomize=self.randomize,
+            requires_grad=self.learnable,
+            precision=self.precision,
         )
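
As a quick illustration of the new behavior (a sketch, not part of this commit; it calls the private `_create_config` helper directly and assumes `TransformConfig.config_groups` is a mapping keyed by scheme name):

```python
from llmcompressor.modifiers.transform import QuIPModifier

# The "before" validator lowercases string entries, so "V" and "v" are equivalent.
modifier = QuIPModifier(rotations=["V"], transform_type="random-hadamard")
assert modifier.rotations == ["v"]

# Only the requested rotations produce config groups: here just the input-side
# ("v") scheme is built, and no output-side ("u") transform is registered.
config = modifier._create_config()
assert set(config.config_groups) == {"v"}
```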
