Skip to content

Commit

Permalink
Clean Ruff settings
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 2, 2024
1 parent b41376d commit 5fef108
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 10 deletions.
4 changes: 0 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ module = ['networkx', 'optax', 'optax.contrib', 'array_api_compat']
ignore_missing_imports = true

[tool.ruff]
target-version = 'py310'
line-length = 100

[tool.ruff.lint]
Expand Down Expand Up @@ -258,9 +257,6 @@ ignore = [
'TID252', # Relative imports from parent modules are banned.
]

[tool.ruff.lint.flake8-annotations]
mypy-init-return = true

[tool.ruff.lint.flake8-errmsg]
max-string-length = 40

Expand Down
2 changes: 1 addition & 1 deletion tests/test_flax_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
@pytest.mark.skip
def test_dataclass_module() -> None:
class SomeModule(nnx.Module):
def __init__(self, epsilon: Array):
def __init__(self, epsilon: Array) -> None:
super().__init__()
self.epsilon = epsilon

Expand Down
2 changes: 1 addition & 1 deletion tjax/_src/rng.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


class RngStream:
def __init__(self, key: KeyArray, count: JaxIntegralArray | None = None):
def __init__(self, key: KeyArray, count: JaxIntegralArray | None = None) -> None:
super().__init__()
if count is None:
count = jnp.zeros((), dtype=jnp.uint32)
Expand Down
12 changes: 8 additions & 4 deletions tjax/_src/shims.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ class custom_vjp(Generic[P, R_co]): # noqa: N801
def __init__(self,
func: Callable[P, R_co],
*,
static_argnums: tuple[int, ...] = ()):
static_argnums: tuple[int, ...] = ()
) -> None:
super().__init__()
static_argnums = tuple(sorted(static_argnums))
self.vjp = jax.custom_vjp(func, nondiff_argnums=static_argnums)
Expand Down Expand Up @@ -72,7 +73,8 @@ class custom_vjp_method(Generic[U, P, R_co]): # noqa: N801
def __init__(self,
func: Callable[Concatenate[U, P], R_co],
*,
static_argnums: tuple[int, ...] = ()):
static_argnums: tuple[int, ...] = ()
) -> None:
super().__init__()
static_argnums = tuple(sorted(static_argnums))
self.vjp = jax.custom_vjp(func, nondiff_argnums=static_argnums)
Expand Down Expand Up @@ -117,7 +119,8 @@ class custom_jvp(Generic[P, R_co]): # noqa: N801
def __init__(self,
func: Callable[P, R_co],
*,
nondiff_argnums: tuple[int, ...] = ()):
nondiff_argnums: tuple[int, ...] = ()
) -> None:
super().__init__()
nondiff_argnums = tuple(sorted(nondiff_argnums))
self.jvp = jax.custom_jvp(func, nondiff_argnums=nondiff_argnums)
Expand Down Expand Up @@ -147,7 +150,8 @@ class custom_jvp_method(Generic[U, P, R_co]): # noqa: N801
def __init__(self,
func: Callable[Concatenate[U, P], R_co],
*,
nondiff_argnums: tuple[int, ...] = ()):
nondiff_argnums: tuple[int, ...] = ()
) -> None:
super().__init__()
nondiff_argnums = tuple(sorted(nondiff_argnums))
self.jvp = jax.custom_jvp(func, nondiff_argnums=nondiff_argnums)
Expand Down

0 comments on commit 5fef108

Please sign in to comment.