Skip to content

Commit

Permalink
Gemm update (#1518)
Browse files Browse the repository at this point in the history
  • Loading branch information
jagrit06 authored Oct 31, 2024
1 parent 884af42 commit 960e3f0
Show file tree
Hide file tree
Showing 9 changed files with 699 additions and 194 deletions.
1 change: 1 addition & 0 deletions mlx/backend/metal/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,7 @@ Device::Device() {
auto pool = new_scoped_memory_pool();
device_ = load_device();
library_map_ = {{"mlx", load_library(device_)}};
arch_ = std::string(device_->architecture()->name()->utf8String());
}

Device::~Device() {
Expand Down
5 changes: 5 additions & 0 deletions mlx/backend/metal/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,10 @@ class Device {
return device_;
};

// Returns the architecture name stored in `arch_`.
//
// `arch_` is filled in once by Device::Device() from
// `device_->architecture()->name()->utf8String()`, so this accessor does no
// querying of its own; it just hands back a reference to the stored string.
// The reference refers to a member, so it is only valid while this Device
// object is alive.
const std::string& get_architecture() {
return arch_;
}

void new_queue(int index);
MTL::CommandBuffer* get_command_buffer(int index);
int get_command_buffer_ops(int index);
Expand Down Expand Up @@ -228,6 +232,7 @@ class Device {
std::shared_mutex library_mtx_;
std::unordered_map<std::string, MTL::Library*> library_map_;
const MTL::ResidencySet* residency_set_{nullptr};
std::string arch_;
};

Device& device(mlx::core::Device);
Expand Down
4 changes: 3 additions & 1 deletion mlx/backend/metal/kernels/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ set(STEEL_HEADERS
steel/gemm/transforms.h
steel/gemm/kernels/steel_gemm_fused.h
steel/gemm/kernels/steel_gemm_masked.h
steel/gemm/kernels/steel_gemm_splitk.h)
steel/gemm/kernels/steel_gemm_splitk.h
steel/utils/type_traits.h
steel/utils/integral_constant.h)

if(NOT MLX_METAL_JIT)
build_kernel(arange arange.h)
Expand Down
20 changes: 10 additions & 10 deletions mlx/backend/metal/kernels/steel/conv/kernels/steel_conv_general.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ implicit_gemm_conv_2d_general(
// Store results to device memory
{
// Adjust for simdgroup and thread location
int offset_m = c_row + mma_op.sm + mma_op.tm;
int offset_n = c_col + mma_op.sn + mma_op.tn;
int offset_m = c_row + mma_op.sm;
int offset_n = c_col + mma_op.sn;
C += offset_n;

if (offset_n >= gemm_params->N)
Expand All @@ -169,17 +169,17 @@ implicit_gemm_conv_2d_general(
STEEL_PRAGMA_UNROLL
for (int j = 0; j < mma_t::TN; j++) {
// Get accumulated result and associated offset in C
thread const auto& accum =
mma_op.results[i * mma_t::TN + j].thread_elements();
thread const auto& accum = mma_op.Ctile.frag_at(i, j);
int offset = offset_cm + (j * mma_t::TN_stride);

// Apply epilogue and output C
if (j * mma_t::TN_stride < diff) {
C[offset] = Epilogue::apply(accum[0]);
}
constexpr short kelems = decltype(mma_op.Ctile)::kElemsPerFrag;

if (j * mma_t::TN_stride + 1 < diff) {
C[offset + 1] = Epilogue::apply(accum[1]);
// Apply epilogue and output C
STEEL_PRAGMA_UNROLL
for (short k = 0; k < kelems; k++) {
if ((j * mma_t::TN_stride + k) < diff) {
C[offset + k] = Epilogue::apply(accum[k]);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,11 @@
instantiate_gemm(tt, true , true , iname, itype, oname, otype, bm, bn, bk, wm, wn)

#define instantiate_gemm_shapes_helper(iname, itype, oname, otype) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 32, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 64, 32, 16, 2, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 2, 2)
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 64, 16, 1, 2) \
instantiate_gemm_transpose_helper(iname, itype, oname, otype, 32, 32, 16, 2, 2)

instantiate_gemm_shapes_helper(float16, half, float16, half);
instantiate_gemm_shapes_helper(bfloat16, bfloat16_t, bfloat16, bfloat16_t);
Expand Down
Loading

0 comments on commit 960e3f0

Please sign in to comment.