Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions conda-envs/environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
dependencies:
# Base dependencies
- arviz>=0.13.0
- arviz-base
- blas
- cachetools>=4.2.1
- cloudpickle
Expand Down
1 change: 1 addition & 0 deletions conda-envs/environment-docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ channels:
dependencies:
# Base dependencies
- arviz>=0.13.0
- arviz-base
- cachetools>=4.2.1
- cloudpickle
- numpy>=1.25.0
Expand Down
37 changes: 23 additions & 14 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@
import xarray

from arviz import InferenceData, concat, rcParams
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
from arviz.data.base import CoordSpec, DimSpec, requires
from arviz_base import dict_to_dataset
from pytensor.graph import ancestors
from pytensor.tensor.sharedvar import SharedVariable
from rich.progress import Console
Expand Down Expand Up @@ -305,14 +306,14 @@ def posterior_to_xarray(self):
return (
dict_to_dataset(
data,
library=pymc,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
),
dict_to_dataset(
data_warmup,
library=pymc,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
attrs=self.attrs,
Expand Down Expand Up @@ -347,14 +348,14 @@ def sample_stats_to_xarray(self):
return (
dict_to_dataset(
data,
library=pymc,
inference_library=pymc,
dims=None,
coords=self.coords,
attrs=self.attrs,
),
dict_to_dataset(
data_warmup,
library=pymc,
inference_library=pymc,
dims=None,
coords=self.coords,
attrs=self.attrs,
Expand All @@ -367,7 +368,11 @@ def posterior_predictive_to_xarray(self):
data = self.posterior_predictive
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
return dict_to_dataset(
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
data,
inference_library=pymc,
coords=self.coords,
dims=dims,
sample_dims=self.sample_dims,
)

@requires(["predictions"])
Expand All @@ -376,7 +381,11 @@ def predictions_to_xarray(self):
data = self.predictions
dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
return dict_to_dataset(
data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
data,
inference_library=pymc,
coords=self.coords,
dims=dims,
sample_dims=self.sample_dims,
)

def priors_to_xarray(self):
Expand All @@ -399,7 +408,7 @@ def priors_to_xarray(self):
if var_names is None
else dict_to_dataset_drop_incompatible_coords(
{k: np.expand_dims(self.prior[k], 0) for k in var_names},
library=pymc,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
)
Expand All @@ -414,10 +423,10 @@ def observed_data_to_xarray(self):
return None
return dict_to_dataset(
self.observations,
library=pymc,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
default_dims=[],
sample_dims=[],
)

@requires("model")
Expand All @@ -429,10 +438,10 @@ def constant_data_to_xarray(self):

xarray_dataset = dict_to_dataset(
constant_data,
library=pymc,
inference_library=pymc,
coords=self.coords,
dims=self.dims,
default_dims=[],
sample_dims=[],
)

# provisional handling of scalars in constant
Expand Down Expand Up @@ -707,9 +716,9 @@ def apply_function_over_dataset(

return dict_to_dataset(
out_trace,
library=pymc,
inference_library=pymc,
dims=dims,
coords=coords,
default_dims=list(sample_dims),
sample_dims=list(sample_dims),
skip_event_dims=True,
)
4 changes: 3 additions & 1 deletion pymc/smc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,9 @@ def _save_sample_stats(
sample_stats = dict_to_dataset(
sample_stats_dict,
attrs=sample_settings_dict,
library=pymc,
inference_library=pymc,
sample_dims=["chain"],
check_conventions=False,
)

ikwargs: dict[str, Any] = {"model": model}
Expand Down
1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# This file is auto-generated by scripts/generate_pip_deps_from_conda.py, do not modify.
# See that file for comments about the need/usage of each dependency.

arviz-base
arviz>=0.13.0
cachetools>=4.2.1
cloudpickle
Expand Down