From 16948c1e3be99d414b931420f44948eb1e700d0e Mon Sep 17 00:00:00 2001 From: Suhaani Agarwal Date: Wed, 26 Feb 2025 02:04:04 +0530 Subject: [PATCH 1/3] Initial work on autocorrelation plot refactor --- src/arviz_plots/plot_collection.py | 20 +- src/arviz_plots/plots/autocorrplot.py | 254 ++++++++++++++++++++++++++ 2 files changed, 273 insertions(+), 1 deletion(-) create mode 100644 src/arviz_plots/plots/autocorrplot.py diff --git a/src/arviz_plots/plot_collection.py b/src/arviz_plots/plot_collection.py index d857506a..08766a38 100644 --- a/src/arviz_plots/plot_collection.py +++ b/src/arviz_plots/plot_collection.py @@ -3,6 +3,7 @@ import warnings from importlib import import_module +import matplotlib.pyplot as plt import numpy as np import xarray as xr from arviz_base import rcParams @@ -873,9 +874,26 @@ def allocate_artist(self, fun_label, data, all_loop_dims, artist_dims=None): coords={dim: data[dim] for dim in inherited_dims}, ) + # def get_target(self, var_name, selection): + # """Get the target that corresponds to the given variable and selection.""" + # return subset_ds(self.get_viz(var_name), "plot", selection) + def get_target(self, var_name, selection): """Get the target that corresponds to the given variable and selection.""" - return subset_ds(self.get_viz(var_name), "plot", selection) + selection_key = tuple(sorted(selection.items())) + target = self.viz[var_name].get(selection_key, None) + # Ensuring the target is always a Matplotlib Axes object + if target is None: + print(f"Warning: Selection {selection_key} not found in viz[{var_name}]") + fig, target = plt.subplots() # Creating new figure + plt.close(fig) # Closing figure to prevent display issues + + if not isinstance(target, plt.Axes): + print("Warning: Target is not an Axes object. Creating a new figure.") + fig, target = plt.subplots() + plt.close(fig) + + return target def get_aes_kwargs(self, aes, var_name, selection): """Get the aesthetic mappings for the given variable and selection as a dictionary. diff --git a/src/arviz_plots/plots/autocorrplot.py b/src/arviz_plots/plots/autocorrplot.py new file mode 100644 index 00000000..5c25cfc1 --- /dev/null +++ b/src/arviz_plots/plots/autocorrplot.py @@ -0,0 +1,254 @@ +"""Autocorrelation plot code.""" + +from importlib import import_module + +import numpy as np +import xarray as xr +from arviz_base import convert_to_dataset, rcParams +from arviz_base.labels import BaseLabeller +from arviz_stats.base.core import _CoreBase + +from arviz_plots.plot_collection import PlotCollection, process_facet_dims +from arviz_plots.plots.utils import filter_aes, process_group_variables_coords +from arviz_plots.visuals import labelled_title, line_x, remove_axis + + +def plot_autocorr( + dt, + var_names=None, + filter_vars=None, + group="posterior", + coords=None, + max_lag=None, + combined=False, + sample_dims=None, + plot_collection=None, + backend=None, + labeller=None, + aes_map=None, + plot_kwargs=None, + pc_kwargs=None, +): + """Generate autocorrelation plots for the given dataset. + + Parameters + ---------- + dt : xarray.Dataset + The dataset containing the variables to plot. + var_names : list of str, optional + Names of variables to include in the plot. + filter_vars : str, optional + Filter to apply to variable names. + group : str, default="posterior" + The group in the dataset to use for plotting. + coords : dict, optional + Coordinates to subset the dataset. + max_lag : int, optional + Maximum lag to compute autocorrelation for. + combined : bool, default=False + Whether to combine chains when computing autocorrelation. + sample_dims : list of str, optional + Dimensions to treat as sample dimensions. + plot_collection : PlotCollection, optional + Existing plot collection to use. + backend : str, optional + Backend to use for plotting. + labeller : BaseLabeller, optional + Labeller to use for plot labels. + aes_map : dict, optional + Mapping of aesthetics to variables. + plot_kwargs : dict, optional + Additional keyword arguments for the plot. + pc_kwargs : dict, optional + Additional keyword arguments for the plot collection. + + Returns + ------- + plot_collection : PlotCollection + The plot collection containing the autocorrelation plots. + """ + dt = convert_to_dataset(dt, group="posterior") + + if sample_dims is None: + sample_dims = rcParams["data.sample_dims"] + if isinstance(sample_dims, str): + sample_dims = [sample_dims] + if plot_kwargs is None: + plot_kwargs = {} + if pc_kwargs is None: + pc_kwargs = {} + + # Default max lag to 100 or max length of chain + if max_lag is None: + max_lag = 100 + + distribution = process_group_variables_coords( + dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords + ) + + # Convert xarray.Dataset to numpy array + distribution_array = distribution.to_array().values + + # Calculate lags up to max_lag + lags = np.arange(max_lag) + + # Calculate autocorrelation from arviz_stats autocorr computation + core_base = _CoreBase() + acf_data = core_base.autocorr(distribution_array, axis=-1)[..., :max_lag] + + dims = list(distribution.dims) + ["lag"] + + # Ensure correct dimension sizes + coords = {} + for dim in distribution.dims: + expected_size = acf_data.shape[dims.index(dim)] + if dim in distribution.coords and len(distribution.coords[dim]) == expected_size: + coords[dim] = distribution.coords[dim] + else: + coords[dim] = np.arange(expected_size) + + # Add lag coordinate + coords["lag"] = np.arange(acf_data.shape[-1]) + + # Adjust the 'lag' coordinate to match the size of the 'lag' dimension in acf_data + lag_size = acf_data.shape[-1] # Size of the 'lag' dimension in acf_data + coords["lag"] = np.arange(lag_size) # Update 'lag' coordinate to match the size + + # Handle the case where the dimension sizes have changed + for dim in distribution.dims: + if dim in coords and acf_data.shape[dims.index(dim)] != len(coords[dim]): + # Adjust the dimension coordinate to match the size of the dimension in acf_data + coords[dim] = np.arange(acf_data.shape[dims.index(dim)]) + + acf_data = xr.DataArray( + acf_data, + dims=dims, + coords=coords, + ) + + if backend is None: + if plot_collection is None: + backend = rcParams["plot.backend"] + else: + backend = plot_collection.backend + plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") + + if plot_collection is None: + # Convert DataArray to Dataset for compatibility with process_facet_dims + pc_data = acf_data.to_dataset(name="autocorr") + if "column" not in pc_data.dims: + pc_data = pc_data.expand_dims(column=["autocorr"]) + print(pc_data) + + # Set default columns and rows for faceting + pc_kwargs.setdefault("cols", ["__variable__"]) + pc_kwargs.setdefault("rows", list(set(pc_data.dims) - {"__variable__", "lag", "chain"})) + + # Calculate the number of plots + n_plots, plots_per_var = process_facet_dims(pc_data, pc_kwargs["cols"]) + + # Set up figure size + figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", None) + if figsize is None: + col_wrap = pc_kwargs.get("col_wrap", 4) + if n_plots <= col_wrap: + n_rows, n_cols = 1, n_plots + else: + div_mod = divmod(n_plots, col_wrap) + n_rows = div_mod[0] + (div_mod[1] != 0) + n_cols = col_wrap + figsize = plot_bknd.scale_fig_size( + figsize, + rows=n_rows, + cols=n_cols, + ) + + pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() + pc_kwargs["plot_grid_kws"]["figsize"] = figsize + pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() + pc_kwargs["aes"].setdefault("y", ["lag"]) + + if not combined and "chain" in distribution.dims: + pc_kwargs["aes"].setdefault("color", ["chain"]) + + plot_collection = PlotCollection.grid( + pc_data, + backend=backend, + **pc_kwargs, + ) + + if aes_map is None: + aes_map = {} + aes_map.setdefault("line", plot_collection.aes_set) + + if labeller is None: + labeller = BaseLabeller() + + # Convert acf_data to Dataset for compatibility with plot_collection.map + acf_dataset = acf_data.to_dataset(name="autocorr") + print("acf data is : ", acf_data) + + # Plot autocorrelation lines + line_kwargs = plot_kwargs.get("line", {}).copy() + line_kwargs.setdefault("linewidth", 1.5) # Use 'linewidth' instead of 'width' + line_kwargs.setdefault("color", "#1f77b4") + print("line_kwargs:", line_kwargs) + print("aes_map:", aes_map) + + acf_dims, acf_aes, acf_ignore = filter_aes(plot_collection, aes_map, "line", sample_dims) + + print(acf_dataset) + print(acf_ignore) + + plot_collection.map( + line_x, + "autocorr", + data=acf_dataset, + ignore_aes=acf_ignore, + **line_kwargs, + ) + + # Add reference line at 0 + ref_line_kwargs = plot_kwargs.get("reference_line", {}).copy() + ref_line_kwargs.setdefault("color", "gray") + ref_line_kwargs.setdefault("linewidth", 1) # Use 'linewidth' instead of 'width' + ref_line_kwargs.setdefault("linestyle", "--") + + # Create zero line as a proper DataArray + zero_line = xr.DataArray(np.zeros(max_lag), dims=["lag"], coords={"lag": np.arange(max_lag)}) + + # Convert to dataset with the same variable name as the main data + zero_line = zero_line.to_dataset(name="autocorr") + + if "column" not in zero_line.dims: + zero_line = zero_line.expand_dims(column=["autocorr"]) + + plot_collection.map( + line_x, + "reference_line", + data=zero_line, + ignore_aes=plot_collection.aes_set, + **ref_line_kwargs, + ) + + # Add titles for each plot + title_kwargs = plot_kwargs.get("title", {}).copy() + if title_kwargs is not False: + _, title_aes, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims) + if "color" not in title_aes: + title_kwargs.setdefault("color", "black") + plot_collection.map( + labelled_title, + "title", + ignore_aes=title_ignore, + subset_info=True, + labeller=labeller, + **title_kwargs, + ) + + if plot_kwargs.get("remove_axis", True) is not False: + plot_collection.map( + remove_axis, store_artist=False, axis="y", ignore_aes=plot_collection.aes_set + ) + + return plot_collection From 749fbca748393c3316c6470966d571cc746a2580 Mon Sep 17 00:00:00 2001 From: Suhaani Agarwal Date: Wed, 12 Mar 2025 15:58:19 +0530 Subject: [PATCH 2/3] progress --- src/arviz_plots/plots/autocorrplot.py | 180 +++++++++----------------- 1 file changed, 63 insertions(+), 117 deletions(-) diff --git a/src/arviz_plots/plots/autocorrplot.py b/src/arviz_plots/plots/autocorrplot.py index 5c25cfc1..dfbe8022 100644 --- a/src/arviz_plots/plots/autocorrplot.py +++ b/src/arviz_plots/plots/autocorrplot.py @@ -6,10 +6,9 @@ import xarray as xr from arviz_base import convert_to_dataset, rcParams from arviz_base.labels import BaseLabeller -from arviz_stats.base.core import _CoreBase -from arviz_plots.plot_collection import PlotCollection, process_facet_dims -from arviz_plots.plots.utils import filter_aes, process_group_variables_coords +from arviz_plots.plot_collection import PlotCollection +from arviz_plots.plots.utils import process_group_variables_coords from arviz_plots.visuals import labelled_title, line_x, remove_axis @@ -29,45 +28,9 @@ def plot_autocorr( plot_kwargs=None, pc_kwargs=None, ): - """Generate autocorrelation plots for the given dataset. - - Parameters - ---------- - dt : xarray.Dataset - The dataset containing the variables to plot. - var_names : list of str, optional - Names of variables to include in the plot. - filter_vars : str, optional - Filter to apply to variable names. - group : str, default="posterior" - The group in the dataset to use for plotting. - coords : dict, optional - Coordinates to subset the dataset. - max_lag : int, optional - Maximum lag to compute autocorrelation for. - combined : bool, default=False - Whether to combine chains when computing autocorrelation. - sample_dims : list of str, optional - Dimensions to treat as sample dimensions. - plot_collection : PlotCollection, optional - Existing plot collection to use. - backend : str, optional - Backend to use for plotting. - labeller : BaseLabeller, optional - Labeller to use for plot labels. - aes_map : dict, optional - Mapping of aesthetics to variables. - plot_kwargs : dict, optional - Additional keyword arguments for the plot. - pc_kwargs : dict, optional - Additional keyword arguments for the plot collection. - - Returns - ------- - plot_collection : PlotCollection - The plot collection containing the autocorrelation plots. - """ - dt = convert_to_dataset(dt, group="posterior") + """Generate autocorrelation plots for the given dataset.""" + dt = convert_to_dataset(dt, group=group) + print("Input dataset:", dt) if sample_dims is None: sample_dims = rcParams["data.sample_dims"] @@ -85,46 +48,50 @@ def plot_autocorr( distribution = process_group_variables_coords( dt, group=group, var_names=var_names, filter_vars=filter_vars, coords=coords ) - - # Convert xarray.Dataset to numpy array - distribution_array = distribution.to_array().values - - # Calculate lags up to max_lag - lags = np.arange(max_lag) - - # Calculate autocorrelation from arviz_stats autocorr computation - core_base = _CoreBase() - acf_data = core_base.autocorr(distribution_array, axis=-1)[..., :max_lag] - - dims = list(distribution.dims) + ["lag"] - - # Ensure correct dimension sizes - coords = {} - for dim in distribution.dims: - expected_size = acf_data.shape[dims.index(dim)] - if dim in distribution.coords and len(distribution.coords[dim]) == expected_size: - coords[dim] = distribution.coords[dim] + print("Processed distribution:", distribution) + + # Compute autocorrelation for each variable and chain + acf_data = [] + for var in distribution.data_vars: + var_data = distribution[var] + print(f"Processing variable: {var}") + if "chain" in var_data.dims and not combined: + for chain in var_data.chain.values: + chain_data = var_data.sel(chain=chain) + print(f"Processing chain: {chain}") + # Ensure sample_dims are valid for the current data + valid_sample_dims = [dim for dim in sample_dims if dim in chain_data.dims] + if not valid_sample_dims: + raise ValueError( + f"None of the sample_dims {sample_dims} present in data for {var}" + ) + # Compute autocorrelation + acf = chain_data.azstats.autocorr(dims=valid_sample_dims) + print(f"Autocorrelation result for {var}, chain {chain}: {acf}") + # Add chain and variable as coordinates + acf = acf.assign_coords({"chain": chain, "variable": var}) + acf_data.append(acf) else: - coords[dim] = np.arange(expected_size) - - # Add lag coordinate - coords["lag"] = np.arange(acf_data.shape[-1]) - - # Adjust the 'lag' coordinate to match the size of the 'lag' dimension in acf_data - lag_size = acf_data.shape[-1] # Size of the 'lag' dimension in acf_data - coords["lag"] = np.arange(lag_size) # Update 'lag' coordinate to match the size - - # Handle the case where the dimension sizes have changed - for dim in distribution.dims: - if dim in coords and acf_data.shape[dims.index(dim)] != len(coords[dim]): - # Adjust the dimension coordinate to match the size of the dimension in acf_data - coords[dim] = np.arange(acf_data.shape[dims.index(dim)]) - - acf_data = xr.DataArray( - acf_data, - dims=dims, - coords=coords, - ) + # Ensure sample_dims are valid for the current data + valid_sample_dims = [dim for dim in sample_dims if dim in var_data.dims] + if not valid_sample_dims: + raise ValueError(f"None of the sample_dims {sample_dims} present in data for {var}") + # Compute autocorrelation + acf = var_data.azstats.autocorr(dims=valid_sample_dims) + print(f"Autocorrelation result for {var}: {acf}") + # Add variable as a coordinate + acf = acf.assign_coords({"variable": var}) + acf_data.append(acf) + + # Combine all autocorrelation results into a single DataArray + acf_data = xr.concat(acf_data, dim="variable") + print("Combined acf_data:", acf_data) + print("Shape of acf_data:", acf_data.shape) + + # Convert acf_data to Dataset with the correct variable name + acf_dataset = acf_data.to_dataset(name="autocorr") + print("acf_dataset:", acf_dataset) + print("Variables in acf_dataset:", list(acf_dataset.data_vars)) # Should include 'autocorr' if backend is None: if plot_collection is None: @@ -134,18 +101,16 @@ def plot_autocorr( plot_bknd = import_module(f".backend.{backend}", package="arviz_plots") if plot_collection is None: - # Convert DataArray to Dataset for compatibility with process_facet_dims - pc_data = acf_data.to_dataset(name="autocorr") - if "column" not in pc_data.dims: - pc_data = pc_data.expand_dims(column=["autocorr"]) - print(pc_data) - - # Set default columns and rows for faceting + # Set up faceting + pc_kwargs.setdefault("col_wrap", 4) pc_kwargs.setdefault("cols", ["__variable__"]) - pc_kwargs.setdefault("rows", list(set(pc_data.dims) - {"__variable__", "lag", "chain"})) + pc_kwargs.setdefault("rows", ["chain"] if "chain" in acf_dataset.dims else []) # Calculate the number of plots - n_plots, plots_per_var = process_facet_dims(pc_data, pc_kwargs["cols"]) + n_plots = len(acf_dataset.variable) * ( + len(acf_dataset.chain) if "chain" in acf_dataset.dims else 1 + ) + print("Number of plots:", n_plots) # Set up figure size figsize = pc_kwargs.get("plot_grid_kws", {}).get("figsize", None) @@ -166,13 +131,14 @@ def plot_autocorr( pc_kwargs["plot_grid_kws"] = pc_kwargs.get("plot_grid_kws", {}).copy() pc_kwargs["plot_grid_kws"]["figsize"] = figsize pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy() - pc_kwargs["aes"].setdefault("y", ["lag"]) + pc_kwargs["aes"].setdefault("x", ["lag"]) # lag is on the x-axis + pc_kwargs["aes"].setdefault("y", ["autocorr"]) # autocorr is on the y-axis - if not combined and "chain" in distribution.dims: + if not combined and "chain" in acf_dataset.dims: pc_kwargs["aes"].setdefault("color", ["chain"]) plot_collection = PlotCollection.grid( - pc_data, + acf_dataset, backend=backend, **pc_kwargs, ) @@ -184,45 +150,28 @@ def plot_autocorr( if labeller is None: labeller = BaseLabeller() - # Convert acf_data to Dataset for compatibility with plot_collection.map - acf_dataset = acf_data.to_dataset(name="autocorr") - print("acf data is : ", acf_data) - # Plot autocorrelation lines line_kwargs = plot_kwargs.get("line", {}).copy() - line_kwargs.setdefault("linewidth", 1.5) # Use 'linewidth' instead of 'width' + line_kwargs.setdefault("linewidth", 1.5) line_kwargs.setdefault("color", "#1f77b4") - print("line_kwargs:", line_kwargs) - print("aes_map:", aes_map) - - acf_dims, acf_aes, acf_ignore = filter_aes(plot_collection, aes_map, "line", sample_dims) - - print(acf_dataset) - print(acf_ignore) plot_collection.map( line_x, "autocorr", data=acf_dataset, - ignore_aes=acf_ignore, + ignore_aes=plot_collection.aes_set - {"x", "y", "color"}, **line_kwargs, ) # Add reference line at 0 ref_line_kwargs = plot_kwargs.get("reference_line", {}).copy() ref_line_kwargs.setdefault("color", "gray") - ref_line_kwargs.setdefault("linewidth", 1) # Use 'linewidth' instead of 'width' + ref_line_kwargs.setdefault("linewidth", 1) ref_line_kwargs.setdefault("linestyle", "--") - # Create zero line as a proper DataArray zero_line = xr.DataArray(np.zeros(max_lag), dims=["lag"], coords={"lag": np.arange(max_lag)}) - - # Convert to dataset with the same variable name as the main data zero_line = zero_line.to_dataset(name="autocorr") - if "column" not in zero_line.dims: - zero_line = zero_line.expand_dims(column=["autocorr"]) - plot_collection.map( line_x, "reference_line", @@ -234,13 +183,10 @@ def plot_autocorr( # Add titles for each plot title_kwargs = plot_kwargs.get("title", {}).copy() if title_kwargs is not False: - _, title_aes, title_ignore = filter_aes(plot_collection, aes_map, "title", sample_dims) - if "color" not in title_aes: - title_kwargs.setdefault("color", "black") plot_collection.map( labelled_title, "title", - ignore_aes=title_ignore, + ignore_aes=plot_collection.aes_set, subset_info=True, labeller=labeller, **title_kwargs, From f4928f6744670c9b0052d5da543eda5121a1eff2 Mon Sep 17 00:00:00 2001 From: Suhaani Agarwal Date: Wed, 12 Mar 2025 16:09:46 +0530 Subject: [PATCH 3/3] removed changes from plot_collection --- src/arviz_plots/plot_collection.py | 23 ++++------------------- 1 file changed, 4 insertions(+), 19 deletions(-) diff --git a/src/arviz_plots/plot_collection.py b/src/arviz_plots/plot_collection.py index 08766a38..2a447848 100644 --- a/src/arviz_plots/plot_collection.py +++ b/src/arviz_plots/plot_collection.py @@ -3,7 +3,6 @@ import warnings from importlib import import_module -import matplotlib.pyplot as plt import numpy as np import xarray as xr from arviz_base import rcParams @@ -538,6 +537,7 @@ def generate_aes_dt(self, aes=None, **kwargs): total_aes_vals, {aes_key: aes_vals}, ) + print("aes values : ", aes_vals) aes_da = xr.DataArray( np.array(aes_vals).reshape(aes_shape), dims=dims, @@ -845,6 +845,8 @@ def grid( coords={dim: da[dim] for dim in dims}, ) viz_dt = xr.DataTree.from_dict(viz_dict) + print("viz_dt is : ", viz_dt) + print("grid function output is : ", cls(data, viz_dt, backend=backend, **kwargs)) return cls(data, viz_dt, backend=backend, **kwargs) def update_aes(self, ignore_aes=frozenset(), coords=None): @@ -874,26 +876,9 @@ def allocate_artist(self, fun_label, data, all_loop_dims, artist_dims=None): coords={dim: data[dim] for dim in inherited_dims}, ) - # def get_target(self, var_name, selection): - # """Get the target that corresponds to the given variable and selection.""" - # return subset_ds(self.get_viz(var_name), "plot", selection) - def get_target(self, var_name, selection): """Get the target that corresponds to the given variable and selection.""" - selection_key = tuple(sorted(selection.items())) - target = self.viz[var_name].get(selection_key, None) - # Ensuring the target is always a Matplotlib Axes object - if target is None: - print(f"Warning: Selection {selection_key} not found in viz[{var_name}]") - fig, target = plt.subplots() # Creating new figure - plt.close(fig) # Closing figure to prevent display issues - - if not isinstance(target, plt.Axes): - print("Warning: Target is not an Axes object. Creating a new figure.") - fig, target = plt.subplots() - plt.close(fig) - - return target + return subset_ds(self.get_viz(var_name), "plot", selection) def get_aes_kwargs(self, aes, var_name, selection): """Get the aesthetic mappings for the given variable and selection as a dictionary.