diff --git a/optuna/visualization/_rank.py b/optuna/visualization/_rank.py index 15f74ea6ae..d7e4597af4 100644 --- a/optuna/visualization/_rank.py +++ b/optuna/visualization/_rank.py @@ -208,17 +208,49 @@ def _get_rank_subplot_info( def _get_axis_info(trials: list[FrozenTrial], param_name: str) -> _AxisInfo: - values: list[str | float | None] + # Single-pass extraction of required statistics + values: list[str | float | None] = [] + min_value = None + max_value = None + unique_values = set() + has_none = False + is_numerical = _is_numerical(trials, param_name) if is_numerical: - values = [t.params.get(param_name) for t in trials] + for t in trials: + v = t.params.get(param_name) + values.append(v) + if v is not None: + if min_value is None or v < min_value: + min_value = v + if max_value is None or v > max_value: + max_value = v else: - values = [ - str(t.params.get(param_name)) if param_name in t.params else None for t in trials - ] + for t in trials: + if param_name in t.params: + v = str(t.params[param_name]) + else: + v = None + values.append(v) + unique_values.add(v) + if v is None: + has_none = True + + # For categorical, compute min/max indices after + # (keeping values as category labels) - min_value = min([v for v in values if v is not None]) - max_value = max([v for v in values if v is not None]) + if is_numerical: + # At this point, min_value and max_value are computed in previous loop + # Ensure filtered for None + if min_value is None or max_value is None: + # This raises ValueError like min/max with an empty sequence + raise ValueError("No non-None values in values for numerical parameter.") + else: + # unique_values already collected, has_none flags if None is present + unique_values_count = len(unique_values) + span = unique_values_count - 1 + if has_none: + span -= 1 if _is_log_scale(trials, param_name): min_value = float(min_value) @@ -239,13 +271,11 @@ def _get_axis_info(trials: list[FrozenTrial], param_name: str) -> _AxisInfo: is_cat = False else: - unique_values = set(values) - span = len(unique_values) - 1 - if None in unique_values: - span -= 1 padding = span * PADDING_RATIO - min_value = -padding - max_value = span + padding + min_axis = -padding + max_axis = span + padding + min_value = min_axis + max_value = max_axis is_log = False is_cat = True