Fix compilation with clang on ARM64 (Dao-AILab#1285)
sclarkson authored Dec 4, 2024
1 parent 0823cf7 commit 1feb711
Showing 12 changed files with 26 additions and 52 deletions.
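
Every hunk below makes the same substitution: the braced initialization of at::cuda::CUDAGuard stops passing (char)tensor.get_device() and passes tensor.device() instead. A plausible reading of why the old form broke under clang on ARM64, offered as an inference rather than something stated in the commit: Tensor::get_device() returns an int64_t, and the old code cast it to char to silence a narrowing warning. That works where char is signed (x86), but on ARM64 char defaults to unsigned, so brace-initializing CUDAGuard's DeviceIndex parameter (a signed 8-bit integer in the PyTorch versions involved, also an assumption) from an unsigned char is itself a narrowing conversion, which clang rejects as an error. Passing the c10::Device returned by Tensor::device() selects CUDAGuard's Device-taking constructor and avoids the integer conversion altogether. The sketch below is a minimal, PyTorch-free reproduction of that reasoning; Guard, Device, DeviceIndex, get_device, and device are hypothetical stand-ins, not the real c10/ATen types.

```cpp
// Minimal sketch (stand-in types, not the real c10/ATen API) of the suspected narrowing issue.
#include <cstdint>

using DeviceIndex = int8_t;                 // stand-in for c10::DeviceIndex (assumed to be int8_t)
struct Device { DeviceIndex index; };       // stand-in for c10::Device

struct Guard {                              // stand-in for at::cuda::CUDAGuard
    explicit Guard(DeviceIndex idx) : idx_(idx) {}
    explicit Guard(Device dev) : idx_(dev.index) {}
    DeviceIndex idx_;
};

int64_t get_device() { return 0; }          // Tensor::get_device() returns int64_t
Device  dev()        { return Device{0}; }  // Tensor::device() returns a Device

int main() {
    // Old pattern: on x86, `char` is signed, so char -> int8_t inside braces is fine.
    // On ARM64, `char` is unsigned, and unsigned char -> int8_t is a narrowing
    // conversion inside braced initialization, which clang rejects:
    //   Guard old_guard{(char)get_device()};   // fails on ARM64 clang (suspected cause)
    //
    // New pattern: hand over a Device object, so no integer narrowing is involved.
    Guard new_guard{dev()};
    return static_cast<int>(new_guard.idx_);
}
```

If that reading is right, the Device overload also keeps every call site independent of how wide DeviceIndex happens to be.
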
15 changes: 5 additions & 10 deletions csrc/flash_attn/flash_api.cpp
@@ -436,8 +436,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x round_mult
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();

@@ -656,8 +655,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
const int seqlen_k_rounded = round_multiple(max_seqlen_k, 128);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
@@ -898,8 +896,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x multipl
bool loop = true;

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q_rounded}, opts.dtype(at::kFloat));
@@ -1126,8 +1123,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads, x head_size
bool loop = true;

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_d = torch::empty({num_heads, total_q + 128 * batch_size}, opts.dtype(at::kFloat));
@@ -1363,8 +1359,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();

3 changes: 1 addition & 2 deletions csrc/flash_attn_ck/mha_bwd.cpp
@@ -309,8 +309,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num
dv = torch::empty_like(v);
}

-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, seqlen_q}, opts.dtype(at::kFloat));
3 changes: 1 addition & 2 deletions csrc/flash_attn_ck/mha_fwd.cpp
@@ -234,8 +234,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num
}

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
bool has_lse = true;
3 changes: 1 addition & 2 deletions csrc/flash_attn_ck/mha_fwd_kvcache.cpp
@@ -399,8 +399,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_siz
const int head_size_8x = round_multiple(head_size_og, 8);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();

3 changes: 1 addition & 2 deletions csrc/flash_attn_ck/mha_varlen_bwd.cpp
@@ -327,8 +327,7 @@ mha_varlen_bwd(const at::Tensor &dout, // total_q x num_heads
dv = torch::empty_like(v);
}

-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_d = torch::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
3 changes: 1 addition & 2 deletions csrc/flash_attn_ck/mha_varlen_fwd.cpp
@@ -253,8 +253,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_si
}

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
bool has_lse = true;
3 changes: 1 addition & 2 deletions csrc/ft_attention/ft_attention.cpp
@@ -190,8 +190,7 @@ torch::Tensor single_query_attention(const torch::Tensor q,
}

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

torch::Tensor out = torch::empty_like(q);

9 changes: 3 additions & 6 deletions csrc/fused_dense_lib/fused_dense.cpp
@@ -53,8 +53,7 @@ std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output,
CHECK_SHAPE(d_output, batch_size, out_features);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+at::cuda::CUDAGuard device_guard{input.device()};

// create output/workspace tensor
auto opts = input.options();
@@ -115,8 +114,7 @@ std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
}

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)input.get_device()};
+at::cuda::CUDAGuard device_guard{input.device()};

// create output/workspace tensor
auto opts = input.options();
@@ -176,8 +174,7 @@ std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)weight.get_device()};
+at::cuda::CUDAGuard device_guard{weight.device()};

// create output/workspace tensor
auto opts = weight.options();
12 changes: 4 additions & 8 deletions csrc/layer_norm/ln_api.cpp
@@ -194,8 +194,7 @@ std::vector<at::Tensor> dropout_add_ln_fwd(const at::Tensor &x0, // Input:
TORCH_CHECK(epsilon >= 0.f);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
+at::cuda::CUDAGuard device_guard{x0.device()};

auto opts = x0.options();

@@ -398,8 +397,7 @@ std::vector<at::Tensor> dropout_add_ln_bwd(const at::Tensor &dz, // BxSxhidd
TORCH_CHECK(gamma.numel() == cols);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)dz.get_device()};
+at::cuda::CUDAGuard device_guard{dz.device()};

auto opts = x.options();

@@ -558,8 +556,7 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_fwd(
TORCH_CHECK(epsilon >= 0.f);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)x0.get_device()};
+at::cuda::CUDAGuard device_guard{x0.device()};

auto opts = x0.options();

@@ -744,8 +741,7 @@ std::vector<at::Tensor> dropout_add_ln_parallel_residual_bwd(
TORCH_CHECK(mu.sizes() == rsigma.sizes());

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)dz0.get_device()};
+at::cuda::CUDAGuard device_guard{dz0.device()};

auto opts = x.options();

3 changes: 1 addition & 2 deletions csrc/rotary/rotary.cpp
@@ -30,8 +30,7 @@ void apply_rotary(const torch::Tensor x1, const torch::Tensor x2,
TORCH_CHECK(out1.sizes() == out2.sizes());

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)x1.get_device()};
+at::cuda::CUDAGuard device_guard{x1.device()};

apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
}
6 changes: 2 additions & 4 deletions csrc/xentropy/xentropy_kernel.cu
@@ -631,8 +631,7 @@ std::vector<Tensor> host_softmax_xentropy(
AT_ASSERTM(labels_.scalar_type() == ScalarType::Long,"Label type should be CUDA Long");

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)input_.get_device()};
+at::cuda::CUDAGuard device_guard{input_.device()};

auto input = input_.contiguous();
Tensor max_log_sum_exp = at::empty_like(labels_, input.options().dtype(ScalarType::Float));
@@ -690,8 +689,7 @@ Tensor host_softmax_xentropy_backward(
bool inplace,
const int total_classes) {
// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)grad_loss.get_device()};
+at::cuda::CUDAGuard device_guard{grad_loss.device()};

const int64_t dim = 1;
Tensor gI = inplace ? logits_ : at::empty_like(logits_);
15 changes: 5 additions & 10 deletions hopper/flash_api.cpp
@@ -551,8 +551,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
if (is_causal) { window_size_right = 0; }

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();

@@ -758,8 +757,7 @@ mha_varlen_fwd(at::Tensor &q, // total_q x num_heads x head_size, total_q := \s
if (is_causal) { window_size_right = 0; }

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
auto softmax_lse = torch::empty({num_heads, total_q}, opts.dtype(at::kFloat));
@@ -948,8 +946,7 @@ mha_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x head_si
}

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
// Need softmax_d to have seqlen_q_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
@@ -1168,8 +1165,7 @@ mha_varlen_bwd(const at::Tensor &dout, // batch_size x seqlen_q x num_heads, x
if (is_causal) { window_size_right = 0; }

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();
// Need softmax_d to have total_q_padded_rounded since we want its address to be aligned by 16/8 bytes for TMA / LDG.64
@@ -1393,8 +1389,7 @@ mha_fwd_kvcache(at::Tensor &q, // batch_size x seqlen_q x num_he
const int seqlen_k_rounded = round_multiple(seqlen_k, 128);

// Otherwise the kernel will be launched from cuda:0 device
-// Cast to char to avoid compiler warning about narrowing
-at::cuda::CUDAGuard device_guard{(char)q.get_device()};
+at::cuda::CUDAGuard device_guard{q.device()};

auto opts = q.options();

