Skip to content

Commit

Permalink
Transform columnar data for interactive (#1042)
Browse files Browse the repository at this point in the history
Co-authored-by: maximlt <[email protected]>
  • Loading branch information
hoxbro and maximlt authored Mar 16, 2023
1 parent 0e0add8 commit 4e9e59b
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 15 deletions.
12 changes: 3 additions & 9 deletions hvplot/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
is_streamz, is_ibis, is_xarray, is_xarray_dataarray, process_crs,
process_intake, process_xarray, check_library, is_geodataframe,
process_derived_datetime_xarray, process_derived_datetime_pandas,
_convert_col_names_to_str,
)
from .utilities import hvplot_extension

Expand Down Expand Up @@ -655,12 +656,6 @@ def _process_crs(self, data, crs):
"'{}' must be either a valid crs or an reference to "
"a `data.attr` containing a valid crs.".format(crs))

def _transform_columnar_data(self, data):
renamed = {c: str(c) for c in data.columns if not isinstance(c, str)}
if renamed:
data = data.rename(columns=renamed)
return data

def _process_data(self, kind, data, x, y, by, groupby, row, col,
use_dask, persist, backlog, label, group_label,
value_label, hover_cols, attr_labels, transforms,
Expand All @@ -681,8 +676,7 @@ def _process_data(self, kind, data, x, y, by, groupby, row, col,
# update the `_dataset` property (of the hv object its __call__ method
# returns) with a hv Dataset created from the source data, which
# is done for optimizating some operations in HoloViews.
if hasattr(data, 'columns'):
data = self._transform_columnar_data(data)
data = _convert_col_names_to_str(data)

self.source_data = data

Expand Down Expand Up @@ -1571,7 +1565,7 @@ def _process_chart_args(self, data, x, y, single_y=False, categories=None):
if data is None:
data = self.data
elif not self.gridded_data:
data = self._transform_columnar_data(data)
data = _convert_col_names_to_str(data)

x = self._process_chart_x(data, x, y, single_y, categories=categories)
y = self._process_chart_y(data, x, y, single_y)
Expand Down
9 changes: 6 additions & 3 deletions hvplot/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,10 @@
from panel.widgets.base import Widget

from .converter import HoloViewsConverter
from .util import _flatten, is_tabular, is_xarray, is_xarray_dataarray
from .util import (
_flatten, is_tabular, is_xarray, is_xarray_dataarray,
_convert_col_names_to_str,
)


def _find_widgets(op):
Expand Down Expand Up @@ -272,7 +275,7 @@ def __init__(self, obj, transform=None, fn=None, plot=False, depth=0,
self._inherit_kwargs = inherit_kwargs
self._max_rows = max_rows
self._kwargs = kwargs
ds = hv.Dataset(self._obj)
ds = hv.Dataset(_convert_col_names_to_str(self._obj))
if _current is not None:
self._current_ = _current
else:
Expand Down Expand Up @@ -682,7 +685,7 @@ def eval(self):
"""
if self._dirty:
obj = self._obj
ds = hv.Dataset(obj)
ds = hv.Dataset(_convert_col_names_to_str(obj))
transform = self._transform
if ds.interface.datatype == 'xarray' and is_xarray_dataarray(obj):
transform = transform.clone(obj.name)
Expand Down
4 changes: 2 additions & 2 deletions hvplot/plotting/scatter_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from ..backend_transforms import _transfer_opts_cur_backend
from ..converter import HoloViewsConverter
from ..util import with_hv_extension
from ..util import with_hv_extension, _convert_col_names_to_str


@with_hv_extension
Expand Down Expand Up @@ -79,7 +79,7 @@ def scatter_matrix(data, c=None, chart='scatter', diagonal='hist',
:func:`pandas.plotting.scatter_matrix` : Equivalent pandas function.
"""

data = _hv.Dataset(data)
data = _hv.Dataset(_convert_col_names_to_str(data))
supported = list(HoloViewsConverter._kind_mapping)
if diagonal not in supported:
raise ValueError('diagonal type must be one of: %s, found %s' %
Expand Down
15 changes: 15 additions & 0 deletions hvplot/tests/testinteractive.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,3 +1381,18 @@ def piped(df, msg):
df.interactive.pipe(piped, msg="1").pipe(piped, msg="2")

assert len(msgs) == 3


def test_interactive_accept_non_str_columnar_data():
df = pd.DataFrame(np.random.random((10, 2)))
assert all(not isinstance(col, str) for col in df.columns)
dfi = Interactive(df)

w = pn.widgets.FloatSlider(start=0, end=1, step=0.05)

# Column names converted as string so can no longer use dfi[1]
dfi = dfi['1'] + w.param.value

w.value = 0.5

pytest.approx(dfi.eval().sum(), (df[1] + 0.5).sum())
12 changes: 11 additions & 1 deletion hvplot/tests/testutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,10 @@

from unittest import TestCase, SkipTest

from hvplot.util import check_crs, is_list_like, process_crs, process_xarray
from hvplot.util import (
check_crs, is_list_like, process_crs, process_xarray,
_convert_col_names_to_str,
)


class TestProcessXarray(TestCase):
Expand Down Expand Up @@ -314,3 +317,10 @@ def test_is_list_like():
assert is_list_like(pd.Series(['a', 'b']))
assert is_list_like(pd.Index(['a', 'b']))
assert is_list_like(np.array(['a', 'b']))


def test_convert_col_names_to_str():
df = pd.DataFrame(np.random.random((10, 2)))
assert all(not isinstance(col, str) for col in df.columns)
df = _convert_col_names_to_str(df)
assert all(isinstance(col, str) for col in df.columns)
21 changes: 21 additions & 0 deletions hvplot/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

import sys

from collections.abc import Hashable

from functools import wraps
from packaging.version import Version
from types import FunctionType
Expand Down Expand Up @@ -548,3 +550,22 @@ def _flatten(line):
yield from _flatten(element)
else:
yield element


def _convert_col_names_to_str(data):
"""
Convert column names to string.
"""
# There's no generic way to rename columns across tabular object types.
# `columns` could refer to anything else on the object, e.g. a dim
# on an xarray DataArray. So this may need to be stricter.
if not hasattr(data, 'columns') or not hasattr(data, 'rename'):
return data
renamed = {
c: str(c)
for c in data.columns
if not isinstance(c, str) and isinstance(c, Hashable)
}
if renamed:
data = data.rename(columns=renamed)
return data

0 comments on commit 4e9e59b

Please sign in to comment.