Skip to content

Commit 60f934f

Browse files
committed
Resolving inconsistency between attention/attention_bias
1 parent a65d3d4 commit 60f934f

File tree

2 files changed: +3 −17 lines changed

include/tvm/relax/attrs/nn.h

+2
Original file line numberDiff line numberDiff line change
@@ -546,11 +546,13 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
546546

547547
/*! \brief Attributes used in Attention operator */
548548
struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
549+
Optional<Expr> bias;
549550
Optional<FloatImm> scale;
550551
Optional<String> causal_mask;
551552
Optional<IntImm> window_size;
552553

553554
TVM_DECLARE_ATTRS(AttentionAttrs, "relax.attrs.AttentionAttrs") {
555+
TVM_ATTR_FIELD(bias).describe("The input bias tensor.");
554556
TVM_ATTR_FIELD(scale).describe(
555557
"The custom scale applied before the softmax. The default value is 1 / sqrt(head_dim).");
556558
TVM_ATTR_FIELD(causal_mask)

src/relax/op/nn/attention.cc

+1 −17
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,8 @@ Expr attention(Expr query, Expr key, Expr value, Optional<Expr> bias, Optional<F
3434
attrs->scale = scale;
3535
attrs->causal_mask = causal_mask;
3636
attrs->window_size = window_size;
37+
attrs->bias = bias;
3738

38-
if (bias) {
39-
return Call(Op::Get("relax.nn.attention_bias"),
40-
{std::move(query), std::move(key), std::move(value), std::move(bias.value())},
41-
Attrs(attrs), {});
42-
}
4339
return Call(Op::Get("relax.nn.attention"), {std::move(query), std::move(key), std::move(value)},
4440
Attrs(attrs), {});
4541
}
@@ -152,18 +148,6 @@ TVM_REGISTER_OP("relax.nn.attention")
152148
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention)
153149
.set_attr<Bool>("FPurity", Bool(true));
154150

155-
TVM_REGISTER_OP("relax.nn.attention_bias")
156-
.set_attrs_type<AttentionAttrs>()
157-
.set_num_inputs(4)
158-
.add_argument("query", "Tensor", "The input queries tensor.")
159-
.add_argument("key", "Tensor", "The input keys tensor.")
160-
.add_argument("value", "Tensor", "The input values tensor.")
161-
.add_argument("bias", "Tensor", "The input bias tensor.")
162-
.set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
163-
.set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionAttention)
164-
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAttention)
165-
.set_attr<Bool>("FPurity", Bool(true));
166-
167151
TVM_REGISTER_OP("relax.nn.attention_var_len")
168152
.set_attrs_type<AttentionAttrs>()
169153
.set_num_inputs(7)

0 commit comments

Comments (0)