|
8 | 8 | from datatree import DataTree
|
9 | 9 |
|
10 | 10 |
|
| 11 | +def concat_model_dict(data): |
| 12 | + """Merge multiple Datasets into a single one along a new model dimension.""" |
| 13 | + if isinstance(data, dict): |
| 14 | + ds_list = data.values() |
| 15 | + if not all(isinstance(ds, xr.Dataset) for ds in ds_list): |
| 16 | + raise TypeError("Provided data must be a Dataset or dictionary of Datasets") |
| 17 | + data = xr.concat(ds_list, dim="model").assign_coords(model=list(data)) |
| 18 | + return data |
| 19 | + |
| 20 | + |
11 | 21 | def sel_subset(sel, present_dims):
|
12 | 22 | """Subset a dictionary of dim: coord values.
|
13 | 23 |
|
@@ -97,8 +107,15 @@ def _get_aes_dict_from_dt(aes_dt):
|
97 | 107 | an aes DataTree directly when initializating a PlotCollection object.
|
98 | 108 | This method is used to generate the more basic dictionary from the DataTree.
|
99 | 109 | """
|
| 110 | + child_list = list(aes_dt.children.values()) |
100 | 111 | aes = {}
|
101 |
| - for ds in aes_dt.children.values(): |
| 112 | + aes_in_all_vars = set.intersection(*[set(child.data_vars) for child in child_list]) |
| 113 | + aes = { |
| 114 | + aes_key: ["__variable__"] |
| 115 | + for aes_key in aes_in_all_vars |
| 116 | + if any(child[aes_key].item(0) != child_list[0][aes_key].item(0) for child in child_list) |
| 117 | + } |
| 118 | + for ds in child_list: |
102 | 119 | for aes_key, values in ds.items():
|
103 | 120 | if not values.dims:
|
104 | 121 | continue
|
@@ -211,7 +228,7 @@ def data(self):
|
211 | 228 | @data.setter
|
212 | 229 | def data(self, value):
|
213 | 230 | # might want/be possible to make some checks on the data before setting it
|
214 |
| - self._data = value |
| 231 | + self._data = concat_model_dict(value) |
215 | 232 |
|
216 | 233 | @property
|
217 | 234 | def aes_set(self):
|
@@ -301,31 +318,66 @@ def generate_aes_dt(self, aes=None, **kwargs):
|
301 | 318 | but it will always be possible to set their value manually.
|
302 | 319 | """
|
303 | 320 | if aes is None:
|
304 |
| - aes = {} |
| 321 | + aes = self._aes |
| 322 | + kwargs = self._kwargs |
305 | 323 | self._aes = aes
|
306 | 324 | self._kwargs = kwargs
|
307 |
| - self._aes_dt = DataTree() |
308 |
| - for var_name, da in self.data.items(): |
309 |
| - ds = xr.Dataset() |
310 |
| - for aes_key, dims in aes.items(): |
311 |
| - aes_vals = kwargs.get(aes_key, [None]) |
312 |
| - aes_dims = [dim for dim in dims if dim in da.dims] |
313 |
| - aes_raw_shape = [da.sizes[dim] for dim in aes_dims] |
314 |
| - if not aes_raw_shape: |
315 |
| - ds[aes_key] = aes_vals[0] |
316 |
| - continue |
317 |
| - n_aes = np.prod(aes_raw_shape) |
318 |
| - n_aes_vals = len(aes_vals) |
319 |
| - if n_aes_vals > n_aes: |
320 |
| - aes_vals = aes_vals[:n_aes] |
321 |
| - elif n_aes_vals < n_aes: |
322 |
| - aes_vals = np.tile(aes_vals, (n_aes // n_aes_vals) + 1)[:n_aes] |
323 |
| - ds[aes_key] = xr.DataArray( |
324 |
| - np.array(aes_vals).reshape(aes_raw_shape), |
325 |
| - dims=aes_dims, |
326 |
| - coords={dim: da.coords[dim] for dim in dims if dim in da.coords}, |
| 325 | + if not hasattr(self, "backend"): |
| 326 | + plot_bknd = import_module(".backend", package="arviz_plots") |
| 327 | + else: |
| 328 | + plot_bknd = import_module(f".backend.{self.backend}", package="arviz_plots") |
| 329 | + get_default_aes = plot_bknd.get_default_aes |
| 330 | + ds_dict = {var_name: xr.Dataset() for var_name in self.data.data_vars} |
| 331 | + for aes_key, dims in aes.items(): |
| 332 | + if "__variable__" in dims: |
| 333 | + total_aes_vals = int( |
| 334 | + np.sum( |
| 335 | + [ |
| 336 | + np.prod([size for dim, size in da.sizes.items() if dim in dims]) |
| 337 | + for da in self.data.values() |
| 338 | + ] |
| 339 | + ) |
| 340 | + ) |
| 341 | + aes_vals = get_default_aes(aes_key, total_aes_vals, kwargs) |
| 342 | + aes_cumulative = 0 |
| 343 | + for var_name, da in self.data.items(): |
| 344 | + ds = ds_dict[var_name] |
| 345 | + aes_dims = [dim for dim in dims if dim in da.dims] |
| 346 | + aes_raw_shape = [da.sizes[dim] for dim in aes_dims] |
| 347 | + if not aes_raw_shape: |
| 348 | + ds[aes_key] = np.asarray(aes_vals)[ |
| 349 | + aes_cumulative : aes_cumulative + 1 |
| 350 | + ].squeeze() |
| 351 | + aes_cumulative += 1 |
| 352 | + continue |
| 353 | + n_aes = np.prod(aes_raw_shape) |
| 354 | + ds[aes_key] = xr.DataArray( |
| 355 | + np.array(aes_vals[aes_cumulative : aes_cumulative + n_aes]).reshape( |
| 356 | + aes_raw_shape |
| 357 | + ), |
| 358 | + dims=aes_dims, |
| 359 | + coords={dim: da.coords[dim] for dim in dims if dim in da.coords}, |
| 360 | + ) |
| 361 | + aes_cumulative += n_aes |
| 362 | + else: |
| 363 | + total_aes_vals = int( |
| 364 | + np.prod([self.data.sizes[dim] for dim in self.data.dims if dim in dims]) |
327 | 365 | )
|
328 |
| - DataTree(name=var_name, parent=self._aes_dt, data=ds) |
| 366 | + aes_vals = get_default_aes(aes_key, total_aes_vals, kwargs) |
| 367 | + for var_name, da in self.data.items(): |
| 368 | + ds = ds_dict[var_name] |
| 369 | + aes_dims = [dim for dim in dims if dim in da.dims] |
| 370 | + aes_raw_shape = [da.sizes[dim] for dim in aes_dims] |
| 371 | + if not aes_raw_shape: |
| 372 | + ds[aes_key] = aes_vals[0] |
| 373 | + continue |
| 374 | + n_aes = np.prod(aes_raw_shape) |
| 375 | + ds[aes_key] = xr.DataArray( |
| 376 | + np.array(aes_vals[:n_aes]).reshape(aes_raw_shape), |
| 377 | + dims=aes_dims, |
| 378 | + coords={dim: da.coords[dim] for dim in dims if dim in da.coords}, |
| 379 | + ) |
| 380 | + self._aes_dt = DataTree.from_dict(ds_dict) |
329 | 381 |
|
330 | 382 | @property
|
331 | 383 | def base_loop_dims(self):
|
@@ -385,8 +437,7 @@ def wrap(
|
385 | 437 | plot_grid_kws = {}
|
386 | 438 | if backend is None:
|
387 | 439 | backend = rcParams["plot.backend"]
|
388 |
| - if isinstance(data, dict): |
389 |
| - data = xr.concat(data.values(), dim="model").assign_coords(model=list(data)) |
| 440 | + data = concat_model_dict(data) |
390 | 441 |
|
391 | 442 | n_plots, plots_per_var = _process_facet_dims(data, cols)
|
392 | 443 | if n_plots <= col_wrap:
|
@@ -501,8 +552,7 @@ def grid(
|
501 | 552 | repeated_dims = [col for col in cols if col in rows]
|
502 | 553 | if repeated_dims:
|
503 | 554 | raise ValueError("The same dimension can't be used for both cols and rows.")
|
504 |
| - if isinstance(data, dict): |
505 |
| - data = xr.concat(data.values(), dim="model").assign_coords(model=list(data)) |
| 555 | + data = concat_model_dict(data) |
506 | 556 |
|
507 | 557 | n_cols, cols_per_var = _process_facet_dims(data, cols)
|
508 | 558 | n_rows, rows_per_var = _process_facet_dims(data, rows)
|
|
0 commit comments