Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[#1196] implementing rename tree_scalar_mul to tree_scale #1198

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions optax/tree_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from optax.tree_utils._state_utils import tree_map_params
from optax.tree_utils._state_utils import tree_set
from optax.tree_utils._tree_math import tree_add
from optax.tree_utils._tree_math import tree_add_scalar_mul
from optax.tree_utils._tree_math import tree_add_scale
from optax.tree_utils._tree_math import tree_batch_shape
from optax.tree_utils._tree_math import tree_bias_correction
from optax.tree_utils._tree_math import tree_clip
Expand All @@ -41,7 +41,7 @@
from optax.tree_utils._tree_math import tree_mul
from optax.tree_utils._tree_math import tree_ones_like
from optax.tree_utils._tree_math import tree_real
from optax.tree_utils._tree_math import tree_scalar_mul
from optax.tree_utils._tree_math import tree_scale
from optax.tree_utils._tree_math import tree_sub
from optax.tree_utils._tree_math import tree_sum
from optax.tree_utils._tree_math import tree_update_infinity_moment
Expand Down
4 changes: 2 additions & 2 deletions optax/tree_utils/_tree_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def tree_div(tree_x: Any, tree_y: Any) -> Any:
return jax.tree.map(operator.truediv, tree_x, tree_y)


def tree_scalar_mul(
def tree_scale(
scalar: Union[float, jax.Array],
tree: Any,
) -> Any:
Expand All @@ -99,7 +99,7 @@ def tree_scalar_mul(
return jax.tree.map(lambda x: scalar * x, tree)


def tree_add_scalar_mul(
def tree_add_scale(
tree_x: Any, scalar: Union[float, jax.Array], tree_y: Any
) -> Any:
r"""Add two trees, where the second tree is scaled by a scalar.
Expand Down
22 changes: 11 additions & 11 deletions optax/tree_utils/_tree_math_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,21 +101,21 @@ def test_tree_div(self):
got = tu.tree_div(self.tree_a, self.tree_b)
chex.assert_trees_all_close(expected, got)

def test_tree_scalar_mul(self):
def test_tree_scalar(self):
expected = 0.5 * self.array_a
got = tu.tree_scalar_mul(0.5, self.array_a)
got = tu.tree_scalar(0.5, self.array_a)
np.testing.assert_array_almost_equal(expected, got)

expected = (0.5 * self.tree_a[0], 0.5 * self.tree_a[1])
got = tu.tree_scalar_mul(0.5, self.tree_a)
got = tu.tree_scalar(0.5, self.tree_a)
chex.assert_trees_all_close(expected, got)

def test_tree_add_scalar_mul(self):
def test_tree_add_scale(self):
expected = (
self.tree_a[0] + 0.5 * self.tree_b[0],
self.tree_a[1] + 0.5 * self.tree_b[1],
)
got = tu.tree_add_scalar_mul(self.tree_a, 0.5, self.tree_b)
got = tu.tree_add_scale(self.tree_a, 0.5, self.tree_b)
chex.assert_trees_all_close(expected, got)

def test_tree_vdot(self):
Expand Down Expand Up @@ -223,19 +223,19 @@ def test_tree_ones_like(self):
def test_add_multiple_trees(self):
"""Test adding more than 2 trees with tree_add."""
trees = [self.tree_a_dict_jax, self.tree_a_dict_jax, self.tree_a_dict_jax]
expected = tu.tree_scalar_mul(3.0, self.tree_a_dict_jax)
expected = tu.tree_scalar(3.0, self.tree_a_dict_jax)
got = tu.tree_add(*trees)
chex.assert_trees_all_close(expected, got)

def test_tree_clip(self):
"""Clip the tree to range [min_value, max_value]."""
expected = tu.tree_scalar_mul(0.5, self.tree_a_dict_jax)
expected = tu.tree_scalar(0.5, self.tree_a_dict_jax)
got = tu.tree_clip(self.tree_a_dict_jax, min_value=0, max_value=0.5)
chex.assert_trees_all_close(expected, got)
expected = tu.tree_scalar_mul(0.5, self.tree_a_dict_jax)
expected = tu.tree_scalar(0.5, self.tree_a_dict_jax)
got = tu.tree_clip(self.tree_a_dict_jax, min_value=None, max_value=0.5)
chex.assert_trees_all_close(expected, got)
expected = tu.tree_scalar_mul(2.0, self.tree_a_dict_jax)
expected = tu.tree_scalar(2.0, self.tree_a_dict_jax)
got = tu.tree_clip(self.tree_a_dict_jax, min_value=2.0, max_value=None)
chex.assert_trees_all_close(expected, got)

Expand Down Expand Up @@ -287,8 +287,8 @@ def test_empty_tree_reduce(self):

@parameterized.named_parameters(
{
'testcase_name': 'tree_add_scalar_mul',
'operation': lambda m: tu.tree_add_scalar_mul(None, 1, m),
'testcase_name': 'tree_add_scale',
'operation': lambda m: tu.tree_add_scale(None, 1, m),
},
{
'testcase_name': 'tree_update_moment',
Expand Down
Loading