Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/_quarto.yml
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ quartodoc:
- GT.cols_align
- GT.cols_width
- GT.cols_label
- GT.cols_label_with
- GT.cols_move
- GT.cols_move_to_start
- GT.cols_move_to_end
Expand Down
158 changes: 139 additions & 19 deletions great_tables/_boxhead.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from typing import Callable, TYPE_CHECKING

from ._locations import resolve_cols_c
from ._utils import _assert_list_is_subset
from ._utils import _assert_list_is_subset, _handle_units_syntax
from ._tbl_data import SelectExpr
from ._text import BaseText

if TYPE_CHECKING:
from ._types import GTSelf
from polars.selectors import _selector_proxy_

PlSelectExpr = _selector_proxy_


def cols_label(
Expand Down Expand Up @@ -114,8 +117,6 @@ def cols_label(
)
```
"""
from great_tables._helpers import UnitStr

cases = cases if cases is not None else {}
new_cases = cases | kwargs

Expand All @@ -132,26 +133,145 @@ def cols_label(
_assert_list_is_subset(mod_columns, set_list=column_names)

# Handle units syntax in labels (e.g., "Density ({{ppl / mi^2}})")
new_kwargs: dict[str, UnitStr | str | BaseText] = {}
new_kwargs = _handle_units_syntax(new_cases)

for k, v in new_cases.items():
if isinstance(v, str):
unitstr_v = UnitStr.from_str(v)
boxhead = self._boxhead._set_column_labels(new_kwargs)

if len(unitstr_v.units_str) == 1 and isinstance(unitstr_v.units_str[0], str):
new_kwargs[k] = unitstr_v.units_str[0]
else:
new_kwargs[k] = unitstr_v
return self._replace(_boxhead=boxhead)

elif isinstance(v, BaseText):
new_kwargs[k] = v

else:
raise ValueError(
"Column labels must be strings or BaseText objects. Use `md()` or `html()` for formatting."
)
def cols_label_with(
    self: GTSelf,
    columns: SelectExpr = None,
    converter: Callable[[str], str] | PlSelectExpr | list[PlSelectExpr] | None = None,
) -> GTSelf:
    """
    Relabel one or more columns using a function or a Polars expression.

    The `cols_label_with()` function allows for modification of column labels through a supplied
    function. By default, the function will be invoked on all column labels but this can be limited
    to a subset via the `columns=` parameter.

    Alternatively, you can utilize the
    [name](https://docs.pola.rs/api/python/stable/reference/expressions/name.html) attribute of
    Polars expressions.

    :::{.callout-warning}
    If Polars expressions are utilized, the `columns=` parameter will be ignored, as **Great Tables**
    can infer the original column labels from the expression.
    :::

    Parameters
    ----------
    columns
        The columns to target. Can either be a single column name or a series of column names
        provided in a list.
    converter
        A function that takes a column label as input and returns a transformed label (a string,
        or a text object such as the result of `md()`). Alternatively, you can use a Polars
        expression or a list of Polars expressions to describe the transformations.

    Returns
    -------
    GT
        The GT object is returned so that method calls can be chained.

    Raises
    ------
    ValueError
        If `converter=` is not provided, or if a converted label is neither a string nor a
        `BaseText` object.

    Notes
    -----
    GT always selects columns using their name in the underlying data. This means that a column's
    label is purely for final presentation.

    The new labels are processed the same way as in `cols_label()`, so the units syntax (e.g.,
    `"Density ({{ppl / mi^2}})"`) is handled in converted labels as well.

    Examples
    --------
    Let's use a subset of the `sp500` dataset to demonstrate how to convert all column labels to
    uppercase using `str.upper()`.

    ```{python}
    import polars as pl
    from polars import selectors as cs

    from great_tables import GT, md
    from great_tables.data import sp500

    sp500_mini = sp500.head()

    GT(sp500_mini).cols_label_with(converter=str.upper)
    ```

    A common use case is formatting column labels with `md()`, provided by **Great Tables**.
    For example, the following code demonstrates how to make the `date` and `adj_close` column
    labels bold using markdown syntax.

    ```{python}
    GT(sp500_mini).cols_label_with(["date", "adj_close"], lambda x: md(f"**{x}**"))
    ```

    Now, let's see how to use Polars expressions to relabel a table when the underlying dataframe
    comes from Polars. For instance, you can convert all column labels to uppercase using
    `pl.all().name.to_uppercase()`.

    ```{python}
    sp500_mini_pl = pl.from_pandas(sp500_mini)
    GT(sp500_mini_pl).cols_label_with(converter=pl.all().name.to_uppercase())
    ```

    Polars selectors are also supported. The following example demonstrates how to add a "str_"
    prefix to string columns using `cs.string().name.prefix("str_")`.

    ```{python}
    GT(sp500_mini_pl).cols_label_with(converter=cs.string().name.prefix("str_"))
    ```

    Passing a list of Polars expressions is also supported. The following example shows how to
    add a "str_" prefix to string columns using `cs.string().name.prefix("str_")`
    and a "_num" suffix to numerical columns using `cs.numeric().name.suffix("_num")`.

    ```{python}
    GT(sp500_mini_pl).cols_label_with(
        converter=[cs.string().name.prefix("str_"), cs.numeric().name.suffix("_num")]
    )
    ```

    One final note: if a column is selected multiple times in different Polars expressions,
    the last applied transformation takes precedence. For example, applying
    `cs.all().name.to_uppercase()` followed by `cs.all().name.suffix("_all")`
    will result in only the latter being used for relabeling.

    ```{python}
    GT(sp500_mini_pl).cols_label_with(
        converter=[cs.all().name.to_uppercase(), cs.all().name.suffix("_all")]
    )
    ```

    """
    if converter is None:
        raise ValueError("Must provide the `converter=` parameter to use `cols_label_with()`.")

    # Mapping of underlying column name -> new label, filled in by either branch below
    new_cases: dict[str, str | BaseText] = {}

    if callable(converter):
        # Get the full list of column names for the data
        column_names = self._boxhead._get_columns()

        if isinstance(columns, str):
            columns = [columns]
            _assert_list_is_subset(columns, set_list=column_names)
        elif columns is None:
            # No explicit selection: the converter applies to every column
            columns = column_names

        sel_cols = resolve_cols_c(data=self, expr=columns)

        new_cases = {col: converter(col) for col in sel_cols}

    else:  # pl.col().expr.name.method() or selector.name.method() or [...]
        frame = self._tbl_data
        exprs = converter if isinstance(converter, list) else [converter]
        for expr in exprs:
            # Selecting with the aliases undone yields the original column names; selecting
            # with the expression as given yields the renamed ones, in the same order.
            source_names: list[str] = frame.select(expr.meta.undo_aliases()).columns
            label_names: list[str] = frame.select(expr).columns
            # Later expressions overwrite earlier ones for the same column (last one wins)
            new_cases |= dict(zip(source_names, label_names))

    # Handle units syntax in labels (e.g., "Density ({{ppl / mi^2}})") and validate label
    # types, consistent with `cols_label()`
    new_kwargs = _handle_units_syntax(new_cases)

    boxhead = self._boxhead._set_column_labels(new_kwargs)

    return self._replace(_boxhead=boxhead)

Expand Down
24 changes: 24 additions & 0 deletions great_tables/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from types import ModuleType
from typing import TYPE_CHECKING, Any, Iterable, Iterator

from ._helpers import UnitStr
from ._tbl_data import _get_cell, _set_cell, get_column_names, n_rows
from ._text import BaseText, _process_text

Expand Down Expand Up @@ -285,3 +286,26 @@ def _get_visible_cells(data: TblData) -> list[tuple[str, int]]:

def is_valid_http_schema(url: str) -> bool:
return url.startswith("http://") or url.startswith("https://")


def _handle_units_syntax(cases: dict[str, str | BaseText]) -> dict[str, UnitStr | str | BaseText]:
    """
    Resolve units syntax in column labels (e.g., "Density ({{ppl / mi^2}})").

    String labels are parsed with `UnitStr.from_str()`. When the parse result consists of a single
    plain-string piece, that string is used as the label; otherwise the `UnitStr` object itself is
    used. `BaseText` labels are passed through untouched. Any other value raises a `ValueError`.
    """
    resolved: dict[str, UnitStr | str | BaseText] = {}

    for name, label in cases.items():
        if isinstance(label, str):
            parsed = UnitStr.from_str(label)
            pieces = parsed.units_str

            # A lone string piece means there was nothing units-like to keep around
            only_plain_text = len(pieces) == 1 and isinstance(pieces[0], str)
            resolved[name] = pieces[0] if only_plain_text else parsed
            continue

        if isinstance(label, BaseText):
            resolved[name] = label
            continue

        raise ValueError(
            "Column labels must be strings or BaseText objects. Use `md()` or `html()` for formatting."
        )

    return resolved
3 changes: 2 additions & 1 deletion great_tables/gt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

# Main gt imports ----
from ._body import body_reassemble
from ._boxhead import cols_align, cols_label
from ._boxhead import cols_align, cols_label, cols_label_with
from ._data_color import data_color
from ._export import as_latex, as_raw_html, save, show, write_raw_html
from ._formats import (
Expand Down Expand Up @@ -253,6 +253,7 @@ def __init__(
cols_align = cols_align
cols_width = cols_width
cols_label = cols_label
cols_label_with = cols_label_with
cols_move = cols_move
cols_move_to_start = cols_move_to_start
cols_move_to_end = cols_move_to_end
Expand Down
67 changes: 67 additions & 0 deletions tests/test__boxhead.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,73 @@ def test_cols_label_return_self_if_no_kwargs():
assert isinstance(unmodified_table, gt.GT)


def test_cols_label_with_relabel_columns():
    """A plain callable converter is applied to the labels of all columns."""
    # Start from a table whose labels are just the column names
    frame = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

    # Lower-case every label
    relabeled = gt.GT(frame).cols_label_with(converter=str.lower)

    # The labels should now be the lower-cased column names
    labels = relabeled._boxhead._get_column_labels()
    assert labels == ["a", "b"]


@pytest.mark.parametrize(
    "converter",
    [
        pl.col("my_col", "my_col2").name.to_uppercase(),
        pl.col(["my_col", "my_col2"]).name.to_uppercase(),
        cs.by_name("my_col", "my_col2").name.to_uppercase(),
        cs.by_name(["my_col", "my_col2"]).name.to_uppercase(),
        cs.starts_with("my").name.to_uppercase(),
        [cs.by_name("my_col").name.to_uppercase(), cs.by_name("my_col2").name.to_uppercase()],
        [pl.col("my_col").name.to_uppercase(), cs.by_name("my_col2").name.to_uppercase()],
        [cs.all().name.suffix("_suffix"), cs.all().name.to_uppercase()],  # test for last one wins
        pl.col("my_col2", "my_col").name.to_uppercase(),  # test for column positions
    ],
)
def test_cols_label_with_relabel_columns_polars(converter):
    """Polars expressions, selectors, and lists of either all relabel both columns."""
    # Build a Polars-backed table with default labels
    frame = pl.DataFrame({"my_col": [1, 2, 3], "my_col2": [4, 5, 6]})

    # Apply whichever converter form is under test
    relabeled = gt.GT(frame).cols_label_with(converter=converter)

    # Every variant must end up with upper-cased labels in the original column order
    labels = relabeled._boxhead._get_column_labels()
    assert labels == ["MY_COL", "MY_COL2"]


def test_cols_label_with_relabel_columns_with_markdown():
    """A converter returning `md()` text relabels only the targeted column."""
    # Start from a table whose labels are just the column names
    frame = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})

    # Make only the "A" label bold via Markdown
    relabeled = gt.GT(frame).cols_label_with("A", lambda x: gt.md(f"**{x}**"))

    label_a, label_b = relabeled._boxhead._get_column_labels()[:2]

    # "A" carries the Markdown text; "B" keeps its plain default label
    assert label_a.text == "**A**"
    assert label_b == "B"


def test_cols_label_with_raises():
    """Calling `cols_label_with()` without a converter raises a descriptive ValueError."""
    frame = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
    tbl = gt.GT(frame)

    expected_message = "Must provide the `converter=` parameter to use `cols_label_with()`."

    with pytest.raises(ValueError) as caught:
        tbl.cols_label_with()

    actual_message = caught.value.args[0]
    assert expected_message in actual_message


def test_cols_align_default():
df = pd.DataFrame({"A": [1, 2, 3], "B": [4, 5, 6]})
table = gt.GT(df)
Expand Down
21 changes: 20 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import pytest


from great_tables import GT, exibble
from great_tables import GT, exibble, md
from great_tables._tbl_data import is_na
from great_tables._utils import (
_assert_list_is_subset,
_assert_str_in_set,
_assert_str_list,
_assert_str_scalar,
_collapse_list_elements,
_handle_units_syntax,
_insert_into_list,
_match_arg,
_migrate_unformatted_to_output,
Expand Down Expand Up @@ -224,3 +225,21 @@ def test_migrate_unformatted_to_output_html():
)
def test_is_valid_http_schema(url: str):
assert is_valid_http_schema(url)


def test_handle_units_syntax():
    """Plain strings and `md()` text both come back as str or BaseText values."""
    from great_tables._text import BaseText

    labels = {"column_label_1": "abc", "column_label_2": md(text="xyz")}

    resolved = _handle_units_syntax(labels)

    for value in resolved.values():
        assert isinstance(value, (str, BaseText))


def test_handle_units_syntax_raises():
    """A label that is neither a string nor BaseText is rejected with a ValueError."""
    expected_message = (
        "Column labels must be strings or BaseText objects. Use `md()` or `html()` for formatting."
    )

    with pytest.raises(ValueError) as caught:
        _handle_units_syntax({"column_label": 123})

    actual_message = caught.value.args[0]
    assert expected_message in actual_message