Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 43 additions & 13 deletions optuna/visualization/_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,49 @@ def _get_rank_subplot_info(


def _get_axis_info(trials: list[FrozenTrial], param_name: str) -> _AxisInfo:
values: list[str | float | None]
# Single-pass extraction of required statistics
values: list[str | float | None] = []
min_value = None
max_value = None
unique_values = set()
has_none = False

is_numerical = _is_numerical(trials, param_name)
if is_numerical:
values = [t.params.get(param_name) for t in trials]
for t in trials:
v = t.params.get(param_name)
values.append(v)
if v is not None:
if min_value is None or v < min_value:
min_value = v
if max_value is None or v > max_value:
max_value = v
else:
values = [
str(t.params.get(param_name)) if param_name in t.params else None for t in trials
]
for t in trials:
if param_name in t.params:
v = str(t.params[param_name])
else:
v = None
values.append(v)
unique_values.add(v)
if v is None:
has_none = True

# For categorical, compute min/max indices after
# (keeping values as category labels)

min_value = min([v for v in values if v is not None])
max_value = max([v for v in values if v is not None])
if is_numerical:
# At this point, min_value and max_value are computed in previous loop
# Ensure filtered for None
if min_value is None or max_value is None:
# This raises ValueError like min/max with an empty sequence
raise ValueError("No non-None values in values for numerical parameter.")
else:
# unique_values already collected, has_none flags if None is present
unique_values_count = len(unique_values)
span = unique_values_count - 1
if has_none:
span -= 1

if _is_log_scale(trials, param_name):
min_value = float(min_value)
Expand All @@ -239,13 +271,11 @@ def _get_axis_info(trials: list[FrozenTrial], param_name: str) -> _AxisInfo:
is_cat = False

else:
unique_values = set(values)
span = len(unique_values) - 1
if None in unique_values:
span -= 1
padding = span * PADDING_RATIO
min_value = -padding
max_value = span + padding
min_axis = -padding
max_axis = span + padding
min_value = min_axis
max_value = max_axis
is_log = False
is_cat = True

Expand Down