Skip to content

Commit

Permalink
Fix compile (#1501)
Browse files Browse the repository at this point in the history
* fix compile

* fix space
  • Loading branch information
awni authored Oct 18, 2024
1 parent 50d8bed commit 92d7cb7
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 14 deletions.
4 changes: 2 additions & 2 deletions mlx/backend/common/compiled_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@ void* compile(
source_file.close();

std::ostringstream build_command;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared "
<< source_file_path << " -o " << shared_lib_path;
build_command << "g++ -std=c++17 -O3 -Wall -fPIC -shared '"
<< source_file_path << "' -o '" << shared_lib_path << "'";
std::string build_command_str = build_command.str();
auto return_code = system(build_command_str.c_str());
if (return_code) {
Expand Down
21 changes: 13 additions & 8 deletions mlx/backend/metal/compiled.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -93,10 +93,10 @@ inline void build_kernel(
// a third grid dimension
os << " size_t index = pos.x + grid.x * size_t(pos.y);\n";
} else if (work_per_thread > 1) {
os << " constexpr int N = " << std::to_string(work_per_thread) << ";\n"
os << " constexpr int N_ = " << std::to_string(work_per_thread) << ";\n"
<< " int xshape = output_shape["
<< (dynamic_dims ? "ndim - 1" : std::to_string(ndim - 1)) << "];\n"
<< " size_t index = N * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
<< " size_t index = N_ * pos.x + xshape * (pos.y + size_t(grid.y) * pos.z);\n";
} else {
os << " size_t index = pos.x + grid.x * (pos.y + size_t(grid.y) * pos.z);\n";
}
Expand Down Expand Up @@ -141,11 +141,11 @@ inline void build_kernel(
<< "in_strides + " << offset << ");\n";
} else if (!dynamic_dims) {
int offset = i * ndim;
os << " size_t index_" << xname << " = N * pos.x * in_strides["
os << " size_t index_" << xname << " = N_ * pos.x * in_strides["
<< offset + ndim - 1 << "]"
<< " + pos.y * in_strides[" << offset + ndim - 2 << "];\n";
} else {
os << " size_t index_" << xname << " = N * pos.x * in_strides[ndim * "
os << " size_t index_" << xname << " = N_ * pos.x * in_strides[ndim * "
<< i << " + ndim - 1]"
<< " + pos.y * in_strides[ndim * " << i << " + ndim - 2];\n";
}
Expand All @@ -172,7 +172,7 @@ inline void build_kernel(

// Open per-thread loop
if (work_per_thread > 1) {
os << " for (int i = 0; i < N && (int(N * pos.x) + i) < xshape; ++i) {\n";
os << " for (int i = 0; i < N_ && (int(N_ * pos.x) + i) < xshape; ++i) {\n";
}

// Read non-contiguous inputs into tmps
Expand Down Expand Up @@ -434,10 +434,15 @@ void Compiled::eval_gpu(
int work_per_thread = ndim > 3 ? WORK_PER_THREAD : 1;
dim0 = (dim0 + work_per_thread - 1) / work_per_thread;
NS::UInteger thread_group_size = kernel->maxTotalThreadsPerThreadgroup();
if (thread_group_size != 1024) {
throw std::runtime_error("[Metal::binary] Must use 1024 sized block");
int pow2;
if (thread_group_size == 1024) {
pow2 = 10;
} else if (thread_group_size > 512) {
pow2 = 9;
} else {
throw std::runtime_error("[Metal::compiled] Must use > 512 sized block");
}
auto group_dims = get_block_dims(dim0, dim1, rest);
auto group_dims = get_block_dims(dim0, dim1, rest, pow2);
MTL::Size grid_dims = MTL::Size(dim0, dim1, rest);
compute_encoder.dispatchThreads(grid_dims, group_dims);
}
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/metal/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ std::string type_to_name(const array& a) {
return tname;
}

MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 /* = 10 */) {
int pows[3] = {0, 0, 0};
int sum = 0;
while (true) {
Expand All @@ -76,7 +76,7 @@ MTL::Size get_block_dims(int dim0, int dim1, int dim2) {
pows[2]++;
sum++;
}
if (sum == presum || sum == 10) {
if (sum == presum || sum == pow2) {
break;
}
}
Expand Down
4 changes: 2 additions & 2 deletions mlx/backend/metal/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ std::string type_to_name(const array& a);
// Compute the thread block dimensions which fit the given
// input dimensions.
// - The thread block dimensions will be powers of two
// - The thread block size will be less than 1024
MTL::Size get_block_dims(int dim0, int dim1, int dim2);
// - The thread block size will be less than 2^pow2
MTL::Size get_block_dims(int dim0, int dim1, int dim2, int pow2 = 10);

// Computes a 2D grid where each element is < UINT_MAX
// Assumes:
Expand Down
17 changes: 17 additions & 0 deletions python/tests/test_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -772,6 +772,23 @@ def fn(a, b):
print((out - expected).abs().max())
self.assertTrue(mx.allclose(out, expected))

def test_compile_many_inputs(self):
inputs = [mx.ones((2, 2, 2, 2)) for _ in range(20)]
inputs[0] = inputs[0].T

@mx.compile
def fun(*inputs):
x = inputs[0]
for y in inputs[1:10]:
x = x + y
a = inputs[10]
for b in inputs[11:]:
a = a + b
return x + a

out = fun(*inputs)
self.assertTrue(mx.allclose(out, mx.full((2, 2), 20)))


if __name__ == "__main__":
unittest.main()

0 comments on commit 92d7cb7

Please sign in to comment.