Skip to content

Commit

Permalink
Fix generation when prompt length == 1
Browse files Browse the repository at this point in the history
  • Loading branch information
tridao committed Dec 20, 2023
1 parent fc80874 commit cc01e2b
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions csrc/selective_scan/selective_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
TORCH_CHECK(B.is_cuda());
TORCH_CHECK(C.is_cuda());

TORCH_CHECK(u.stride(-1) == 1);
TORCH_CHECK(delta.stride(-1) == 1);
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);

const auto sizes = u.sizes();
const int batch_size = sizes[0];
Expand All @@ -268,28 +268,28 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
CHECK_SHAPE(B, dim, dstate);
} else {
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
TORCH_CHECK(B.stride(-1) == 1);
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
}
if (!is_variable_C) {
CHECK_SHAPE(C, dim, dstate);
} else {
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
TORCH_CHECK(C.stride(-1) == 1);
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
}

if (D_.has_value()) {
auto D = D_.value();
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
TORCH_CHECK(D.is_cuda());
TORCH_CHECK(D.stride(-1) == 1);
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
CHECK_SHAPE(D, dim);
}

if (delta_bias_.has_value()) {
auto delta_bias = delta_bias_.value();
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
TORCH_CHECK(delta_bias.is_cuda());
TORCH_CHECK(delta_bias.stride(-1) == 1);
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
CHECK_SHAPE(delta_bias, dim);
}

Expand All @@ -299,7 +299,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta,
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1);
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
CHECK_SHAPE(z, batch_size, dim, seqlen);
out_z = torch::empty_like(z);
}
Expand Down Expand Up @@ -368,9 +368,9 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
TORCH_CHECK(C.is_cuda());
TORCH_CHECK(dout.is_cuda());

TORCH_CHECK(u.stride(-1) == 1);
TORCH_CHECK(delta.stride(-1) == 1);
TORCH_CHECK(dout.stride(-1) == 1);
TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1);
TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1);
TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1);

const auto sizes = u.sizes();
const int batch_size = sizes[0];
Expand All @@ -388,29 +388,29 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
CHECK_SHAPE(B, dim, dstate);
} else {
CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2);
TORCH_CHECK(B.stride(-1) == 1);
TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1);
}
if (!is_variable_C) {
CHECK_SHAPE(C, dim, dstate);
} else {
CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2);
TORCH_CHECK(C.stride(-1) == 1);
TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1);
}
CHECK_SHAPE(dout, batch_size, dim, seqlen);

if (D_.has_value()) {
auto D = D_.value();
TORCH_CHECK(D.scalar_type() == at::ScalarType::Float);
TORCH_CHECK(D.is_cuda());
TORCH_CHECK(D.stride(-1) == 1);
TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1);
CHECK_SHAPE(D, dim);
}

if (delta_bias_.has_value()) {
auto delta_bias = delta_bias_.value();
TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float);
TORCH_CHECK(delta_bias.is_cuda());
TORCH_CHECK(delta_bias.stride(-1) == 1);
TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1);
CHECK_SHAPE(delta_bias, dim);
}

Expand All @@ -420,21 +420,21 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta,
z = z_.value();
TORCH_CHECK(z.scalar_type() == input_type);
TORCH_CHECK(z.is_cuda());
TORCH_CHECK(z.stride(-1) == 1);
TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1);
CHECK_SHAPE(z, batch_size, dim, seqlen);

TORCH_CHECK(out_.has_value());
out = out_.value();
TORCH_CHECK(out.scalar_type() == input_type);
TORCH_CHECK(out.is_cuda());
TORCH_CHECK(out.stride(-1) == 1);
TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1);
CHECK_SHAPE(out, batch_size, dim, seqlen);

if (dz_.has_value()) {
dz = dz_.value();
TORCH_CHECK(dz.scalar_type() == input_type);
TORCH_CHECK(dz.is_cuda());
TORCH_CHECK(dz.stride(-1) == 1);
TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1);
CHECK_SHAPE(dz, batch_size, dim, seqlen);
} else {
dz = torch::empty_like(z);
Expand Down

0 comments on commit cc01e2b

Please sign in to comment.