From cc01e2bd5c4c67442689884ce996a41dcbf7545a Mon Sep 17 00:00:00 2001 From: Tri Dao Date: Tue, 19 Dec 2023 21:33:33 -0800 Subject: [PATCH] Fix generation when prompt length == 1 --- csrc/selective_scan/selective_scan.cpp | 34 +++++++++++++------------- 1 file changed, 17 insertions(+), 17 deletions(-) diff --git a/csrc/selective_scan/selective_scan.cpp b/csrc/selective_scan/selective_scan.cpp index f51af402..cde867cd 100644 --- a/csrc/selective_scan/selective_scan.cpp +++ b/csrc/selective_scan/selective_scan.cpp @@ -249,8 +249,8 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, TORCH_CHECK(B.is_cuda()); TORCH_CHECK(C.is_cuda()); - TORCH_CHECK(u.stride(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1); + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); const auto sizes = u.sizes(); const int batch_size = sizes[0]; @@ -268,20 +268,20 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, CHECK_SHAPE(B, dim, dstate); } else { CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); } if (!is_variable_C) { CHECK_SHAPE(C, dim, dstate); } else { CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); } if (D_.has_value()) { auto D = D_.value(); TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); CHECK_SHAPE(D, dim); } @@ -289,7 +289,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, auto delta_bias = delta_bias_.value(); TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); CHECK_SHAPE(delta_bias, dim); } @@ -299,7 +299,7 @@ selective_scan_fwd(const at::Tensor &u, const at::Tensor &delta, z = z_.value(); TORCH_CHECK(z.scalar_type() == input_type); TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); CHECK_SHAPE(z, batch_size, dim, seqlen); out_z = torch::empty_like(z); } @@ -368,9 +368,9 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, TORCH_CHECK(C.is_cuda()); TORCH_CHECK(dout.is_cuda()); - TORCH_CHECK(u.stride(-1) == 1); - TORCH_CHECK(delta.stride(-1) == 1); - TORCH_CHECK(dout.stride(-1) == 1); + TORCH_CHECK(u.stride(-1) == 1 || u.size(-1) == 1); + TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); + TORCH_CHECK(dout.stride(-1) == 1 || dout.size(-1) == 1); const auto sizes = u.sizes(); const int batch_size = sizes[0]; @@ -388,13 +388,13 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, CHECK_SHAPE(B, dim, dstate); } else { CHECK_SHAPE(B, batch_size, n_groups, dstate, !is_complex ? seqlen : seqlen * 2); - TORCH_CHECK(B.stride(-1) == 1); + TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); } if (!is_variable_C) { CHECK_SHAPE(C, dim, dstate); } else { CHECK_SHAPE(C, batch_size, n_groups, dstate, !is_complex ? seqlen: seqlen * 2); - TORCH_CHECK(C.stride(-1) == 1); + TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); } CHECK_SHAPE(dout, batch_size, dim, seqlen); @@ -402,7 +402,7 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, auto D = D_.value(); TORCH_CHECK(D.scalar_type() == at::ScalarType::Float); TORCH_CHECK(D.is_cuda()); - TORCH_CHECK(D.stride(-1) == 1); + TORCH_CHECK(D.stride(-1) == 1 || D.size(-1) == 1); CHECK_SHAPE(D, dim); } @@ -410,7 +410,7 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, auto delta_bias = delta_bias_.value(); TORCH_CHECK(delta_bias.scalar_type() == at::ScalarType::Float); TORCH_CHECK(delta_bias.is_cuda()); - TORCH_CHECK(delta_bias.stride(-1) == 1); + TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); CHECK_SHAPE(delta_bias, dim); } @@ -420,21 +420,21 @@ selective_scan_bwd(const at::Tensor &u, const at::Tensor &delta, z = z_.value(); TORCH_CHECK(z.scalar_type() == input_type); TORCH_CHECK(z.is_cuda()); - TORCH_CHECK(z.stride(-1) == 1); + TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); CHECK_SHAPE(z, batch_size, dim, seqlen); TORCH_CHECK(out_.has_value()); out = out_.value(); TORCH_CHECK(out.scalar_type() == input_type); TORCH_CHECK(out.is_cuda()); - TORCH_CHECK(out.stride(-1) == 1); + TORCH_CHECK(out.stride(-1) == 1 || out.size(-1) == 1); CHECK_SHAPE(out, batch_size, dim, seqlen); if (dz_.has_value()) { dz = dz_.value(); TORCH_CHECK(dz.scalar_type() == input_type); TORCH_CHECK(dz.is_cuda()); - TORCH_CHECK(dz.stride(-1) == 1); + TORCH_CHECK(dz.stride(-1) == 1 || dz.size(-1) == 1); CHECK_SHAPE(dz, batch_size, dim, seqlen); } else { dz = torch::empty_like(z);