@@ -64,6 +64,13 @@ def tensor_parameters(self) -> List[ConfigParameterDef]:
64
64
ConfigParameterDef (name = "random_mask" , required = False , default_value = 0.0 ),
65
65
ConfigParameterDef (name = "random_mask_seed" , required = False , default_value = None ),
66
66
]
67
+ if self .sparsification_method == SparsificationMethod .magnitude_outliers :
68
+ res .append (
69
+ ConfigParameterDef (
70
+ name = "gamma" ,
71
+ default_value = 0.01 ,
72
+ )
73
+ )
67
74
return res
68
75
69
76
def make_task (
@@ -83,15 +90,15 @@ def make_task(
83
90
normalize = parameters ["normalize" ],
84
91
rescale = parameters ["rescale" ],
85
92
swapping = parameters ["swapping" ],
86
- out_tensor_name = output_weight . name ,
93
+ weight_info = output_weight ,
87
94
)
88
95
89
96
90
97
class GTATask (Task [torch .Tensor ]):
91
98
method : GeneralizedTaskArithmeticMerge
92
99
tensors : GatherTensors
93
100
base_model : ModelReference
94
- out_tensor_name : str
101
+ weight_info : WeightInfo
95
102
tensor_parameters : ImmutableMap [ModelReference , Any ]
96
103
int8_mask : bool
97
104
normalize : bool
@@ -111,7 +118,7 @@ def execute(
111
118
) -> torch .Tensor :
112
119
# collect task vectors
113
120
tvs , base = get_task_vectors (
114
- self .out_tensor_name ,
121
+ self .weight_info ,
115
122
self .base_model ,
116
123
tensors ,
117
124
tensor_parameters = self .tensor_parameters .data ,
@@ -123,11 +130,15 @@ def execute(
123
130
# sparsify
124
131
if self .method .sparsification_method :
125
132
for tv_info in tvs :
133
+ kwargs = {}
134
+ if "gamma" in tv_info :
135
+ kwargs ["gamma" ] = tv_info ["gamma" ]
126
136
tv_info ["delta" ] = sparsify (
127
137
tv_info ["delta" ],
128
138
density = tv_info ["density" ],
129
139
method = self .method .sparsification_method ,
130
140
rescale = self .rescale ,
141
+ ** kwargs ,
131
142
)
132
143
133
144
deltas = torch .stack ([tv ["delta" ] for tv in tvs ], dim = 0 )
@@ -218,14 +229,15 @@ def rand_mask(base, x, percent, seed=None):
218
229
219
230
220
231
def get_task_vectors (
221
- parameter_name : str ,
232
+ weight_info : WeightInfo ,
222
233
base_model : ModelReference ,
223
234
tensors : ImmutableMap [ModelReference , torch .Tensor ],
224
235
tensor_parameters : ImmutableMap [ModelReference , ImmutableMap [str , Any ]],
225
236
swapping : bool ,
226
237
) -> Tuple [List [Dict [str , Any ]], torch .Tensor ]:
227
238
keys = list (tensors .keys ())
228
239
base = tensors [base_model ]
240
+ parameter_name = weight_info .name
229
241
230
242
res = []
231
243
for model in keys :
@@ -235,7 +247,7 @@ def get_task_vectors(
235
247
x = tensors [model ].to (base .dtype )
236
248
237
249
if x .shape != base .shape :
238
- if "lm_head" in parameter_name or "embed_tokens" in parameter_name :
250
+ if weight_info . is_embed :
239
251
x = x [: base .shape [0 ], : base .shape [1 ]]
240
252
logging .warning (f"Using submatrix of { model } :{ parameter_name } " )
241
253
else :
0 commit comments