diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cd3e51d..2f5c354 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,7 +1,10 @@ repos: - - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.11.5 - hooks: - - id: ruff - args: ["--fix"] +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.13.1 + hooks: + # Run the linter. + - id: ruff-check + args: [ --fix ] + # Run the formatter. - id: ruff-format \ No newline at end of file diff --git a/posteriors/ekf/dense_fisher.py b/posteriors/ekf/dense_fisher.py index 02ef77c..9e41132 100644 --- a/posteriors/ekf/dense_fisher.py +++ b/posteriors/ekf/dense_fisher.py @@ -2,7 +2,7 @@ from functools import partial import torch from torch.func import grad_and_value -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from tensordict import TensorClass from posteriors.tree_utils import tree_size, tree_insert_ diff --git a/posteriors/laplace/dense_fisher.py b/posteriors/laplace/dense_fisher.py index e2f7220..3117fe1 100644 --- a/posteriors/laplace/dense_fisher.py +++ b/posteriors/laplace/dense_fisher.py @@ -2,7 +2,7 @@ from functools import partial import torch from optree import tree_map -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from tensordict import TensorClass from posteriors.types import TensorTree, Transform, LogProbFn from posteriors.tree_utils import tree_size, tree_insert_ diff --git a/posteriors/laplace/dense_ggn.py b/posteriors/laplace/dense_ggn.py index 1639345..9873a19 100644 --- a/posteriors/laplace/dense_ggn.py +++ b/posteriors/laplace/dense_ggn.py @@ -2,7 +2,7 @@ from typing import Any import torch from optree import tree_map -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from tensordict import TensorClass from posteriors.types import ( diff --git a/posteriors/laplace/dense_hessian.py b/posteriors/laplace/dense_hessian.py index 6c7d329..d6c8096 100644 --- a/posteriors/laplace/dense_hessian.py +++ b/posteriors/laplace/dense_hessian.py @@ -2,7 +2,7 @@ from functools import partial import torch from optree import tree_map -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from tensordict import TensorClass from posteriors.types import TensorTree, Transform, LogProbFn diff --git a/posteriors/sgmcmc/sgnht.py b/posteriors/sgmcmc/sgnht.py index 198c02d..30b43df 100644 --- a/posteriors/sgmcmc/sgnht.py +++ b/posteriors/sgmcmc/sgnht.py @@ -3,7 +3,7 @@ import torch from torch.func import grad_and_value from optree import tree_map -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from tensordict import TensorClass from posteriors.types import TensorTree, Transform, LogProbFn, Schedule from posteriors.tree_utils import flexi_tree_map, tree_insert_ diff --git a/posteriors/utils.py b/posteriors/utils.py index e229c1e..dea1f05 100644 --- a/posteriors/utils.py +++ b/posteriors/utils.py @@ -6,7 +6,7 @@ from torch.func import grad, jvp, vjp, functional_call, jacrev, jacfwd from torch.distributions import Normal from optree import tree_map, tree_reduce, tree_flatten, tree_leaves -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from posteriors.types import TensorTree, ForwardFn, Tensor from posteriors.tree_utils import tree_size diff --git a/posteriors/vi/dense.py b/posteriors/vi/dense.py index 72ba070..2c3c129 100644 --- a/posteriors/vi/dense.py +++ b/posteriors/vi/dense.py @@ -3,7 +3,7 @@ import torch from torch.func import grad_and_value, vmap from optree import tree_map -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel import torchopt from tensordict import TensorClass diff --git a/pyproject.toml b/pyproject.toml index db9ce1a..556c876 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", "License :: OSI Approved :: Apache Software License", ] -dependencies = ["torch>=2.0.0", "torchopt>=0.7.3", "optree>=0.10.0", "tensordict>=0.7.0"] +dependencies = ["torch>=2.0.0", "torchopt>=0.7.3", "optree>=0.17.0", "tensordict>=0.7.0"] [project.optional-dependencies] test = ["pre-commit", "pytest-cov", "pytest-xdist", "ruff"] diff --git a/tests/ekf/test_diag_fisher.py b/tests/ekf/test_diag_fisher.py index c34b937..5d80c3b 100644 --- a/tests/ekf/test_diag_fisher.py +++ b/tests/ekf/test_diag_fisher.py @@ -1,6 +1,6 @@ from functools import partial import torch -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from posteriors import ekf from tests.scenarios import get_multivariate_normal_log_prob from tests.utils import verify_inplace_update diff --git a/tests/ekf/utils.py b/tests/ekf/utils.py index eb3448d..4bdb667 100644 --- a/tests/ekf/utils.py +++ b/tests/ekf/utils.py @@ -1,6 +1,6 @@ from typing import Callable import torch -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from tests.scenarios import get_multivariate_normal_log_prob from posteriors.types import LogProbFn, Transform diff --git a/tests/laplace/test_dense_ggn.py b/tests/laplace/test_dense_ggn.py index 5772267..7a1ac72 100644 --- a/tests/laplace/test_dense_ggn.py +++ b/tests/laplace/test_dense_ggn.py @@ -3,7 +3,7 @@ from torch.distributions import Normal from torch.utils.data import DataLoader, TensorDataset from torch.func import functional_call -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from posteriors.laplace import dense_ggn from tests.utils import verify_inplace_update diff --git a/tests/laplace/test_dense_hessian.py b/tests/laplace/test_dense_hessian.py index 0856758..c3e9a54 100644 --- a/tests/laplace/test_dense_hessian.py +++ b/tests/laplace/test_dense_hessian.py @@ -2,7 +2,7 @@ from torch.distributions import Normal from torch.utils.data import DataLoader, TensorDataset from torch.func import functional_call, hessian -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from posteriors import tree_size, diag_normal_log_prob from posteriors.laplace import dense_hessian diff --git a/tests/laplace/test_diag_fisher.py b/tests/laplace/test_diag_fisher.py index dfd0bbd..9abc607 100644 --- a/tests/laplace/test_diag_fisher.py +++ b/tests/laplace/test_diag_fisher.py @@ -3,7 +3,7 @@ from torch.distributions import Normal from torch.utils.data import DataLoader, TensorDataset from torch.func import functional_call -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from optree import tree_map from posteriors.laplace import diag_fisher diff --git a/tests/laplace/test_diag_ggn.py b/tests/laplace/test_diag_ggn.py index d92ebf0..09531f3 100644 --- a/tests/laplace/test_diag_ggn.py +++ b/tests/laplace/test_diag_ggn.py @@ -4,7 +4,7 @@ from torch.utils.data import DataLoader, TensorDataset from torch.func import functional_call from optree import tree_map -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from posteriors.laplace import diag_ggn diff --git a/tests/scenarios.py b/tests/scenarios.py index f8e7cb8..ebfff43 100644 --- a/tests/scenarios.py +++ b/tests/scenarios.py @@ -1,7 +1,7 @@ import torch import torch.nn as nn from torch.distributions import MultivariateNormal -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from posteriors.types import LogProbFn diff --git a/tests/test_utils.py b/tests/test_utils.py index d039f83..414bb81 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,7 +2,7 @@ from functools import partial import torch from optree import tree_map, tree_flatten, tree_reduce -from optree.integration.torch import tree_ravel +from optree.integrations.torch import tree_ravel from posteriors import ( CatchAuxError,