Skip to content

Commit

Permalink
Fix boolean all reduce bug (#1355)
Browse files Browse the repository at this point in the history
  • Loading branch information
angeloskath authored Aug 24, 2024
1 parent 64bec4f commit 8081df7
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 3 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)

if(NOT MLX_VERSION)
set(MLX_VERSION 0.17.0)
set(MLX_VERSION 0.17.1)
endif()

# --------------------- Processor tests -------------------------
Expand Down
6 changes: 5 additions & 1 deletion mlx/backend/metal/reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,7 +308,11 @@ void all_reduce_dispatch(
compute_encoder.dispatchThreads(grid_dims, group_dims);

// 2nd pass
compute_encoder->setComputePipelineState(kernel);
std::ostringstream kname_2nd_pass;
kname_2nd_pass << "all_reduce_" << op_name << type_to_name(intermediate);
auto kernel_2nd_pass =
get_reduce_kernel(d, kname_2nd_pass.str(), op_name, intermediate, out);
compute_encoder->setComputePipelineState(kernel_2nd_pass);
size_t intermediate_size = n_rows;
grid_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
group_dims = MTL::Size(threadgroup_2nd_pass, 1, 1);
Expand Down
7 changes: 7 additions & 0 deletions python/tests/test_reduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,13 @@ def test_edge_case(self):
z = np.array(x).sum((0, 2, 3))
self.assertTrue(np.all(z == y))

def test_sum_bool(self):
x = np.random.uniform(0, 1, size=(10, 10, 10)) > 0.5
y = mx.array(x)
npsum = x.sum().item()
mxsum = y.sum().item()
self.assertEqual(npsum, mxsum)


if __name__ == "__main__":
unittest.main(failfast=True)
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ def run(self) -> None:

setup(
name="mlx",
version=get_version("0.17.0"),
version=get_version("0.17.1"),
author="MLX Contributors",
author_email="[email protected]",
description="A framework for machine learning on Apple silicon.",
Expand Down

0 comments on commit 8081df7

Please sign in to comment.