@@ -27,12 +27,48 @@ def __init__(self, target: tvm.target.Target):
27
27
def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
    """Return a clone of ``mod`` with the in-place logit kernels attached.

    Three functions are stored in the cloned module under the keys
    ``apply_logit_bias_inplace``, ``apply_penalty_inplace`` and
    ``apply_bitmask_inplace``. When ``str(self.target.kind)`` equals
    ``"llvm"`` the ``*_cpu`` builders are used; every other target kind
    uses the default builders. Each builder is called with ``self.target``.

    ``_ctx`` is not read here.
    """
    patched = mod.clone()

    if str(self.target.kind) == "llvm":
        kernel_builders = (
            ("apply_logit_bias_inplace", _get_apply_logit_bias_inplace_cpu),
            ("apply_penalty_inplace", _get_apply_penalty_inplace_cpu),
            ("apply_bitmask_inplace", _get_apply_bitmask_inplace_cpu),
        )
    else:
        kernel_builders = (
            ("apply_logit_bias_inplace", _get_apply_logit_bias_inplace),
            ("apply_penalty_inplace", _get_apply_penalty_inplace),
            ("apply_bitmask_inplace", _get_apply_bitmask_inplace),
        )

    # Same insertion order as before: logit bias, penalty, bitmask.
    for symbol, build_kernel in kernel_builders:
        patched[symbol] = build_kernel(self.target)
    return patched
34
39
35
40
41
def _get_apply_logit_bias_inplace_cpu(target: tvm.target.Target):
    """Build the serial-loop TIR function that adds logit biases in place.

    The returned prim_func walks ``num_token`` entries one after another and,
    for entry ``i``, adds ``logit_bias[i]`` to
    ``logits[pos2seq_id[i], token_ids[i]]``.

    ``target`` is accepted so the signature matches the other builders, but
    it is not read anywhere in this function.
    """

    @T.prim_func
    def _apply_logit_bias_inplace(
        var_logits: T.handle,
        var_pos2seq_id: T.handle,
        var_token_ids: T.handle,
        var_logit_bias: T.handle,
    ) -> None:
        """Function that applies logit bias in place."""
        # NOTE(review): "tir.is_scheduled" presumably tells later scheduling
        # passes to leave this function untouched -- confirm against the
        # pipeline that consumes this module.
        T.func_attr(
            {
                "global_symbol": "apply_logit_bias_inplace",
                "tir.noalias": True,
                "tir.is_scheduled": True,
            }
        )
        # Symbolic sizes, bound when the buffers below are matched.
        batch_size = T.int32(is_size_var=True)
        vocab_size = T.int32(is_size_var=True)
        num_token = T.int32(is_size_var=True)
        # float32 logits, shape (batch_size, vocab_size); updated in place.
        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
        # pos2seq_id[i] is used directly as the row of `logits` for entry i.
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        # token_ids[i] is the column of `logits` for entry i.
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        # logit_bias[i] is the value added for entry i.
        logit_bias = T.match_buffer(var_logit_bias, (num_token,), "float32")

        # Plain serial accumulation, one (row, column) entry per iteration.
        for i in range(num_token):
            logits[pos2seq_id[i], token_ids[i]] += logit_bias[i]

    return _apply_logit_bias_inplace
70
+
71
+
36
72
def _get_apply_logit_bias_inplace (target : tvm .target .Target ):
37
73
tx = 1024 # default
38
74
max_num_threads_per_block = get_max_num_threads_per_block (target )
@@ -74,6 +110,50 @@ def _apply_logit_bias_inplace(
74
110
return _apply_logit_bias_inplace
75
111
76
112
113
def _get_apply_penalty_inplace_cpu(target: tvm.target.Target):
    """Build the serial-loop TIR function that applies penalties in place.

    For each of the ``num_token`` entries the returned prim_func updates one
    element of ``logits`` in two steps:

    1. subtract ``penalties[s, 0] + token_cnt[p] * penalties[s, 1]``;
    2. if the result is negative multiply it by ``penalties[s, 2]``,
       otherwise divide it by ``penalties[s, 2]``.

    Here ``p`` is the entry index and ``s = pos2seq_id[p]``. The updated
    element is ``logits[seq_ids[s], token_ids[p]]``.

    ``target`` is accepted so the signature matches the other builders, but
    it is not read anywhere in this function.
    """

    @T.prim_func
    def _apply_penalty_inplace(  # pylint: disable=too-many-arguments,too-many-locals
        var_logits: T.handle,
        var_seq_ids: T.handle,
        var_pos2seq_id: T.handle,
        var_token_ids: T.handle,
        var_token_cnt: T.handle,
        var_penalties: T.handle,
    ) -> None:
        """Function that applies penalties in place."""
        # NOTE(review): "tir.is_scheduled" presumably tells later scheduling
        # passes to leave this function untouched -- confirm against the
        # pipeline that consumes this module.
        T.func_attr(
            {
                "global_symbol": "apply_penalty_inplace",
                "tir.noalias": True,
                "tir.is_scheduled": True,
            }
        )
        # Symbolic sizes, bound when the buffers below are matched.
        batch_size = T.int32(is_size_var=True)
        vocab_size = T.int32(is_size_var=True)
        num_token = T.int32(is_size_var=True)
        num_seq = T.int32(is_size_var=True)
        # float32 logits, shape (batch_size, vocab_size); updated in place.
        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
        # seq_ids[s] gives the row of `logits` for penalty row s.
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        # pos2seq_id[p] maps entry p to an index into `seq_ids` / `penalties`.
        pos2seq_id = T.match_buffer(var_pos2seq_id, (num_token,), "int32")
        # token_ids[p] is the column of `logits` for entry p.
        token_ids = T.match_buffer(var_token_ids, (num_token,), "int32")
        # token_cnt[p] scales column 1 of the penalty row for entry p.
        token_cnt = T.match_buffer(var_token_cnt, (num_token,), "int32")
        # Three penalty coefficients per row.
        # NOTE(review): columns 0/1/2 look like presence / frequency /
        # repetition penalties -- confirm with the caller that fills this buffer.
        penalties = T.match_buffer(var_penalties, (num_seq, 3), "float32")

        for token in T.serial(num_token):
            with T.block("block"):
                vp = T.axis.spatial(num_token, token)
                # Step 1: flat term (column 0) plus count-scaled term (column 1).
                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] -= (
                    penalties[pos2seq_id[vp], 0] + token_cnt[vp] * penalties[pos2seq_id[vp], 1]
                )
                # Step 2 reads the value written by step 1: negative values are
                # multiplied by column 2, non-negative values are divided by it.
                logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] = T.if_then_else(
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] < 0,
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] * penalties[pos2seq_id[vp], 2],
                    logits[seq_ids[pos2seq_id[vp]], token_ids[vp]] / penalties[pos2seq_id[vp], 2],
                )

    return _apply_penalty_inplace
155
+
156
+
77
157
def _get_apply_penalty_inplace (target : tvm .target .Target ):
78
158
tx = 1024 # default
79
159
max_num_threads_per_block = get_max_num_threads_per_block (target )
@@ -129,6 +209,42 @@ def _apply_penalty_inplace( # pylint: disable=too-many-arguments,too-many-local
129
209
return _apply_penalty_inplace
130
210
131
211
212
def _get_apply_bitmask_inplace_cpu(target: tvm.target.Target):
    """Build the serial-loop TIR function that applies a vocabulary bitmask.

    The returned prim_func visits every (sequence, vocabulary index) pair for
    the ``num_seq`` rows listed in ``seq_ids``. A logit is kept when its bit
    in ``bitmask`` is 1 and is overwritten with ``T.min_value("float32")``
    otherwise. The bitmask packs 32 vocabulary entries into each int32 word:
    entry ``v`` lives in word ``v // 32`` at bit position ``v % 32``.

    ``target`` is accepted so the signature matches the other builders, but
    it is not read anywhere in this function.
    """

    @T.prim_func
    def _apply_bitmask_inplace(
        var_logits: T.handle,
        var_seq_ids: T.handle,
        var_bitmask: T.handle,
    ) -> None:
        """Function that applies vocabulary masking in place."""
        # NOTE(review): "tir.is_scheduled" presumably tells later scheduling
        # passes to leave this function untouched -- confirm against the
        # pipeline that consumes this module.
        T.func_attr(
            {
                "global_symbol": "apply_bitmask_inplace",
                "tir.noalias": True,
                "tir.is_scheduled": True,
            }
        )
        # Symbolic sizes, bound when the buffers below are matched.
        batch_size = T.int32(is_size_var=True)
        vocab_size = T.int32(is_size_var=True)
        num_seq = T.int32(is_size_var=True)
        # float32 logits, shape (batch_size, vocab_size); updated in place.
        logits = T.match_buffer(var_logits, (batch_size, vocab_size), "float32")
        # seq_ids[s] selects the row of both `logits` and `bitmask` to process.
        seq_ids = T.match_buffer(var_seq_ids, (num_seq,), "int32")
        # ceil(vocab_size / 32) int32 words per row, one bit per vocabulary entry.
        bitmask = T.match_buffer(var_bitmask, (batch_size, (vocab_size + 31) // 32), "int32")

        # One flattened loop over num_seq * vocab_size, split back into
        # (sequence slot, vocabulary index) by division and remainder.
        for token in T.serial(num_seq * vocab_size):
            with T.block("block"):
                vs = T.axis.spatial(num_seq, (token) // vocab_size)
                vv = T.axis.spatial(vocab_size, (token) % vocab_size)

                # Python's `&` binds tighter than `==`, so the condition reads
                # ((word >> bit) & 1) == 1: keep the logit only if its bit is set.
                logits[seq_ids[vs], vv] = T.if_then_else(
                    (bitmask[seq_ids[vs], vv // 32] >> (vv % 32)) & 1 == 1,
                    logits[seq_ids[vs], vv],
                    T.min_value("float32"),
                )

    return _apply_bitmask_inplace
246
+
247
+
132
248
def _get_apply_bitmask_inplace (target : tvm .target .Target ):
133
249
tx = 1024 # default
134
250
max_num_threads_per_block = get_max_num_threads_per_block (target )
0 commit comments