@@ -34,12 +34,8 @@ Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<F
34
34
attrs->scale = scale;
35
35
attrs->causal_mask = causal_mask;
36
36
attrs->window_size = window_size;
37
+ attrs->bias = bias;
37
38
38
- if (bias) {
39
- return Call (Op::Get (" relax.nn.attention_bias" ),
40
- {std::move (query), std::move (key), std::move (value), std::move (bias.value ())},
41
- Attrs (attrs), {});
42
- }
43
39
return Call (Op::Get (" relax.nn.attention" ), {std::move (query), std::move (key), std::move (value)},
44
40
Attrs (attrs), {});
45
41
}
@@ -152,18 +148,6 @@ TVM_REGISTER_OP("relax.nn.attention")
152
148
.set_attr<FInferStructInfo>(" FInferStructInfo" , InferStructInfoAttention)
153
149
.set_attr<Bool>(" FPurity" , Bool(true ));
154
150
155
- TVM_REGISTER_OP (" relax.nn.attention_bias" )
156
- .set_attrs_type<AttentionAttrs>()
157
- .set_num_inputs(4 )
158
- .add_argument(" query" , " Tensor" , " The input queries tensor." )
159
- .add_argument(" key" , " Tensor" , " The input keys tensor." )
160
- .add_argument(" value" , " Tensor" , " The input values tensor." )
161
- .add_argument(" bias" , " Tensor" , " The input bias tensor." )
162
- .set_attr<TMixedPrecisionPolicy>(" TMixedPrecisionPolicy" , MixedPrecisionPolicyKind::kAlways )
163
- .set_attr<FInferMixedPrecision>(" FInferMixedPrecision" , InferMixedPrecisionAttention)
164
- .set_attr<FInferStructInfo>(" FInferStructInfo" , InferStructInfoAttention)
165
- .set_attr<Bool>(" FPurity" , Bool(true ));
166
-
167
151
TVM_REGISTER_OP (" relax.nn.attention_var_len" )
168
152
.set_attrs_type<AttentionAttrs>()
169
153
.set_num_inputs(7 )
0 commit comments