Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Remove redundant next_power_of_two() from math.mojo and optimize existing function in bit.mojo #4278

Open
wants to merge 15 commits into
base: main
Choose a base branch
from
Open
173 changes: 173 additions & 0 deletions mojo/stdlib/benchmarks/bit/bench_bit.mojo
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
# ===----------------------------------------------------------------------=== #
# Copyright (c) 2025, Modular Inc. All rights reserved.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions:
# https://llvm.org/LICENSE.txt
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===----------------------------------------------------------------------=== #
# RUN: %mojo-no-debug %s -t
# NOTE: to test changes on the current branch using run-benchmarks.sh, remove
# the -t flag. Remember to replace it again before pushing any code.

from benchmark import Bench, BenchConfig, Bencher, BenchId, Unit, keep, run
from bit import bit_width, count_leading_zeros
from collections import Dict
from random import random_ui64, seed
from sys import bitwidthof
from sys.intrinsics import unlikely, likely


# ===-----------------------------------------------------------------------===#
# Benchmarks
# ===-----------------------------------------------------------------------===#

# ===-----------------------------------------------------------------------===#
# next_power_of_two
# ===-----------------------------------------------------------------------===#


fn next_power_of_two_int_v1(val: Int) -> Int:
if val <= 1:
return 1

if val.is_power_of_two():
return val

return 1 << bit_width(val - 1)


fn next_power_of_two_int_v2(val: Int) -> Int:
if val <= 1:
return 1

return 1 << (bitwidthof[Int]() - count_leading_zeros(val - 1))


fn next_power_of_two_int_v3(val: Int) -> Int:
var v = Scalar[DType.index](val)
return Int(
(v <= 1)
.select(1, 1 << (bitwidthof[Int]() - count_leading_zeros(v - 1)))
.__index__()
)


fn next_power_of_two_int_v4(val: Int) -> Int:
return 1 << (
(bitwidthof[Int]() - count_leading_zeros(val - 1))
& -Int(likely(val > 1))
)


fn next_power_of_two_uint_v1(val: UInt) -> UInt:
if unlikely(val == 0):
return 1

return 1 << (bitwidthof[UInt]() - count_leading_zeros(val - 1))


fn next_power_of_two_uint_v2(val: UInt) -> UInt:
var v = Scalar[DType.index](val)
return UInt(
(v == 0)
.select(1, 1 << (bitwidthof[UInt]() - count_leading_zeros(v - 1)))
.__index__()
)


fn next_power_of_two_uint_v3(val: UInt) -> UInt:
return 1 << (
bitwidthof[UInt]() - count_leading_zeros(val - UInt(likely(val > 0)))
)


fn next_power_of_two_uint_v4(val: UInt) -> UInt:
return 1 << (
bitwidthof[UInt]()
- count_leading_zeros((val | UInt(unlikely(val == 0))) - 1)
)


fn _build_list[start: Int, stop: Int]() -> List[Int]:
var values = List[Int](capacity=10_000)
for _ in range(10_000):
values.append(Int(random_ui64(start, stop)))
return values^


alias width = bitwidthof[Int]()
var int_values = _build_list[-(2 ** (width - 1)), 2 ** (width - 1) - 1]()
var uint_values = _build_list[0, 2**width - 1]()


@parameter
fn bench_next_power_of_two_int[func: fn (Int) -> Int](mut b: Bencher) raises:
@always_inline
@parameter
fn call_fn() raises:
for _ in range(10_000):
for i in range(len(uint_values)):
var result = func(uint_values.unsafe_get(i))
keep(result)

b.iter[call_fn]()


@parameter
fn bench_next_power_of_two_uint[func: fn (UInt) -> UInt](mut b: Bencher) raises:
@always_inline
@parameter
fn call_fn() raises:
for _ in range(10_000):
for i in range(len(uint_values)):
var result = func(uint_values.unsafe_get(i))
keep(result)

b.iter[call_fn]()


# ===-----------------------------------------------------------------------===#
# Benchmark Main
# ===-----------------------------------------------------------------------===#
def main():
seed()
var m = Bench(BenchConfig(num_repetitions=10))
m.bench_function[bench_next_power_of_two_int[next_power_of_two_int_v1]](
BenchId("bench_next_power_of_two_int_v1")
)
m.bench_function[bench_next_power_of_two_int[next_power_of_two_int_v2]](
BenchId("bench_next_power_of_two_int_v2")
)
m.bench_function[bench_next_power_of_two_int[next_power_of_two_int_v3]](
BenchId("bench_next_power_of_two_int_v3")
)
m.bench_function[bench_next_power_of_two_int[next_power_of_two_int_v4]](
BenchId("bench_next_power_of_two_int_v4")
)
m.bench_function[bench_next_power_of_two_uint[next_power_of_two_uint_v1]](
BenchId("bench_next_power_of_two_uint_v1")
)
m.bench_function[bench_next_power_of_two_uint[next_power_of_two_uint_v2]](
BenchId("bench_next_power_of_two_uint_v2")
)
m.bench_function[bench_next_power_of_two_uint[next_power_of_two_uint_v3]](
BenchId("bench_next_power_of_two_uint_v3")
)
m.bench_function[bench_next_power_of_two_uint[next_power_of_two_uint_v4]](
BenchId("bench_next_power_of_two_uint_v4")
)

results = Dict[String, (Float64, Int)]()
for info in m.info_vec:
n = info[].name
time = info[].result.mean("ms")
avg, amnt = results.get(n, (Float64(0), 0))
results[n] = ((avg * amnt + time) / (amnt + 1), amnt + 1)
print("")
for k_v in results.items():
print(k_v[].key, k_v[].value[0], sep=",")
35 changes: 27 additions & 8 deletions mojo/stdlib/src/bit/bit.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ from bit import count_leading_zeros

from sys import llvm_intrinsic, sizeof
from sys.info import bitwidthof
from utils._select import _select_register_value as select

# ===-----------------------------------------------------------------------===#
# count_leading_zeros
Expand Down Expand Up @@ -375,21 +376,39 @@ fn next_power_of_two(val: Int) -> Int:
"""Computes the smallest power of 2 that is greater than or equal to the
input value. Any integral value less than or equal to 1 will be ceiled to 1.

This operation is called `bit_ceil()` in C++.

Args:
val: The input value.

Returns:
The smallest power of 2 that is greater than or equal to the input value.
The smallest power of 2 that is greater than or equal to the input
value.

Notes:
This operation is called `bit_ceil()` in C++.
"""
if val <= 1:
return 1
return select(
val <= 1, 1, 1 << (bitwidthof[Int]() - count_leading_zeros(val - 1))
)


@always_inline
fn next_power_of_two(val: UInt) -> UInt:
"""Computes the smallest power of 2 that is greater than or equal to the
input value. Any integral value less than or equal to 1 will be ceiled to 1.

if val.is_power_of_two():
return val
Args:
val: The input value.

return 1 << bit_width(val - 1)
Returns:
The smallest power of 2 that is greater than or equal to the input
value.

Notes:
This operation is called `bit_ceil()` in C++.
"""
return select(
val == 0, 1, 1 << (bitwidthof[UInt]() - count_leading_zeros(val - 1))
)


@always_inline
Expand Down
1 change: 0 additions & 1 deletion mojo/stdlib/src/math/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,6 @@ from .math import (
log10,
logb,
modf,
next_power_of_two,
recip,
remainder,
scalb,
Expand Down
20 changes: 0 additions & 20 deletions mojo/stdlib/src/math/math.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -2357,26 +2357,6 @@ fn clamp(
return val.clamp(lower_bound, upper_bound)


# ===----------------------------------------------------------------------=== #
# next_power_of_two
# ===----------------------------------------------------------------------=== #


fn next_power_of_two(n: Int) -> Int:
"""Computes the next power of two greater than or equal to the input.

Args:
n: The input value.

Returns:
The next power of two greater than or equal to the input.
"""
if n <= 1:
return 1

return 1 << (bitwidthof[Int]() - count_leading_zeros(n - 1))


# ===----------------------------------------------------------------------=== #
# utilities
# ===----------------------------------------------------------------------=== #
Expand Down
25 changes: 18 additions & 7 deletions mojo/stdlib/test/bit/test_bit.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -314,13 +314,24 @@ def test_bit_width_simd():


def test_next_power_of_two():
assert_equal(next_power_of_two(-(2**59)), 1)
assert_equal(next_power_of_two(-2), 1)
assert_equal(next_power_of_two(1), 1)
assert_equal(next_power_of_two(2), 2)
assert_equal(next_power_of_two(4), 4)
assert_equal(next_power_of_two(5), 8)
assert_equal(next_power_of_two(2**59 - 3), 2**59)
# test for Int
assert_equal(next_power_of_two(Int(-(2**59))), 1)
assert_equal(next_power_of_two(Int(-2)), 1)
assert_equal(next_power_of_two(Int(-1)), 1)
assert_equal(next_power_of_two(Int(0)), 1)
assert_equal(next_power_of_two(Int(1)), 1)
assert_equal(next_power_of_two(Int(2)), 2)
assert_equal(next_power_of_two(Int(4)), 4)
assert_equal(next_power_of_two(Int(5)), 8)
assert_equal(next_power_of_two(Int(2**59 - 3)), 2**59)

# test for UInt
assert_equal(next_power_of_two(UInt(0)), 1)
assert_equal(next_power_of_two(UInt(1)), 1)
assert_equal(next_power_of_two(UInt(2)), 2)
assert_equal(next_power_of_two(UInt(4)), 4)
assert_equal(next_power_of_two(UInt(5)), 8)
assert_equal(next_power_of_two(UInt(2**59 - 3)), 2**59)


def test_next_power_of_two_simd():
Expand Down