Skip to content

Commit cbbd2cc

Browse files
authored
Add files via upload
1 parent 09a4fb9 commit cbbd2cc

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

mergekit/merge_methods/generalized_task_arithmetic.py

+17-5
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
6464
ConfigParameterDef(name="random_mask", required=False, default_value= 0.0),
6565
ConfigParameterDef(name="random_mask_seed", required=False, default_value= None),
6666
]
67+
if self.sparsification_method == SparsificationMethod.magnitude_outliers:
68+
res.append(
69+
ConfigParameterDef(
70+
name="gamma",
71+
default_value=0.01,
72+
)
73+
)
6774
return res
6875

6976
def make_task(
@@ -83,15 +90,15 @@ def make_task(
8390
normalize=parameters["normalize"],
8491
rescale=parameters["rescale"],
8592
swapping=parameters["swapping"],
86-
out_tensor_name=output_weight.name,
93+
weight_info=output_weight,
8794
)
8895

8996

9097
class GTATask(Task[torch.Tensor]):
9198
method: GeneralizedTaskArithmeticMerge
9299
tensors: GatherTensors
93100
base_model: ModelReference
94-
out_tensor_name: str
101+
weight_info: WeightInfo
95102
tensor_parameters: ImmutableMap[ModelReference, Any]
96103
int8_mask: bool
97104
normalize: bool
@@ -111,7 +118,7 @@ def execute(
111118
) -> torch.Tensor:
112119
# collect task vectors
113120
tvs, base = get_task_vectors(
114-
self.out_tensor_name,
121+
self.weight_info,
115122
self.base_model,
116123
tensors,
117124
tensor_parameters=self.tensor_parameters.data,
@@ -123,11 +130,15 @@ def execute(
123130
# sparsify
124131
if self.method.sparsification_method:
125132
for tv_info in tvs:
133+
kwargs = {}
134+
if "gamma" in tv_info:
135+
kwargs["gamma"] = tv_info["gamma"]
126136
tv_info["delta"] = sparsify(
127137
tv_info["delta"],
128138
density=tv_info["density"],
129139
method=self.method.sparsification_method,
130140
rescale=self.rescale,
141+
**kwargs,
131142
)
132143

133144
deltas = torch.stack([tv["delta"] for tv in tvs], dim=0)
@@ -218,14 +229,15 @@ def rand_mask(base, x, percent, seed=None):
218229

219230

220231
def get_task_vectors(
221-
parameter_name: str,
232+
weight_info: WeightInfo,
222233
base_model: ModelReference,
223234
tensors: ImmutableMap[ModelReference, torch.Tensor],
224235
tensor_parameters: ImmutableMap[ModelReference, ImmutableMap[str, Any]],
225236
swapping: bool,
226237
) -> Tuple[List[Dict[str, Any]], torch.Tensor]:
227238
keys = list(tensors.keys())
228239
base = tensors[base_model]
240+
parameter_name = weight_info.name
229241

230242
res = []
231243
for model in keys:
@@ -235,7 +247,7 @@ def get_task_vectors(
235247
x = tensors[model].to(base.dtype)
236248

237249
if x.shape != base.shape:
238-
if "lm_head" in parameter_name or "embed_tokens" in parameter_name:
250+
if weight_info.is_embed:
239251
x = x[: base.shape[0], : base.shape[1]]
240252
logging.warning(f"Using submatrix of {model}:{parameter_name}")
241253
else:

0 commit comments

Comments (0)