Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add size-time warning on cpu #247

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions src/tabpfn/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
default_classifier_preprocessor_configs,
)
from tabpfn.utils import (
_create_time_usage_tracker,
_fix_dtypes,
_get_embeddings,
_get_ordinal_encoder,
Expand Down Expand Up @@ -386,6 +387,11 @@ def fit(self, X: XType, y: YType) -> Self:
"""
static_seed, rng = infer_random_state(self.random_state)

# Track time usage
TimeUsageTracker = _create_time_usage_tracker()
time_tracker = TimeUsageTracker()
time_tracker.start("TabPFNClassifier fit")

# Load the model and config
self.model_, self.config_, _ = initialize_tabpfn_model(
model_path=self.model_path,
Expand Down Expand Up @@ -425,6 +431,7 @@ def fit(self, X: XType, y: YType) -> Self:
max_num_samples=self.interface_config_.MAX_NUMBER_OF_SAMPLES,
max_num_features=self.interface_config_.MAX_NUMBER_OF_FEATURES,
ignore_pretraining_limits=self.ignore_pretraining_limits,
device=self.device_,
)
if feature_names_in is not None:
self.feature_names_in_ = feature_names_in
Expand Down Expand Up @@ -513,6 +520,8 @@ def fit(self, X: XType, y: YType) -> Self:
use_autocast_=self.use_autocast_,
)

time_tracker.stop("TabPFNClassifier fit")

return self

def predict(self, X: XType) -> np.ndarray:
Expand Down Expand Up @@ -540,6 +549,11 @@ def predict_proba(self, X: XType) -> np.ndarray:
"""
check_is_fitted(self)

# Track time usage
TimeUsageTracker = _create_time_usage_tracker()
time_tracker = TimeUsageTracker()
time_tracker.start("TabPFNClassifier predict_proba")

X = validate_X_predict(X, self)
X = _fix_dtypes(X, cat_indices=self.categorical_features_indices)

Expand Down Expand Up @@ -586,6 +600,7 @@ def predict_proba(self, X: XType) -> np.ndarray:
output = np.around(output, decimals=SKLEARN_16_DECIMAL_PRECISION)
output = np.where(output < PROBABILITY_EPSILON_ROUND_ZERO, 0.0, output)

time_tracker.stop("TabPFNClassifier predict_proba")
# Normalize to guarantee proba sum to 1, required due to precision issues and
# going from torch to numpy
return output / output.sum(axis=1, keepdims=True) # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions src/tabpfn/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@

NA_PLACEHOLDER = "__MISSING__"

# Row-count threshold: fitting on a CPU device with more rows than this emits a
# "may be slow" UserWarning during input validation.
X_SIZE_WARNING_CPU = 1000 # samples
# Delay used by the time tracker's background timer: a tracked segment that is
# still running after this long triggers a UserWarning.
TIME_WARNING_CPU = 10 * 60 # seconds

SKLEARN_16_DECIMAL_PRECISION = 16
PROBABILITY_EPSILON_ROUND_ZERO = 1e-3
REGRESSION_NAN_BORDER_LIMIT_UPPER = 1e3
Expand Down
112 changes: 112 additions & 0 deletions src/tabpfn/model/time_tracker.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright (c) Prior Labs GmbH 2025.

from __future__ import annotations

import threading
import time
import warnings

from tabpfn.constants import TIME_WARNING_CPU


class TimeUsageTracker:
    """Track the elapsed time of labelled code segments.

    A segment is opened with :meth:`start` and closed with :meth:`stop`.
    While a segment is open, a daemon ``threading.Timer`` is armed; if the
    segment is still open (and has not been restarted) when the timer fires,
    a ``UserWarning`` is emitted suggesting faster alternatives.

    Durations are measured with ``time.monotonic()`` so they are immune to
    wall-clock adjustments. The values stored in ``active_segments`` are
    therefore monotonic-clock readings, not epoch timestamps.

    Args:
        warn_after: Number of seconds a segment may run before the warning
            fires. ``None`` (the default) uses ``TIME_WARNING_CPU``.

    Attributes:
        active_segments (dict[str, float]):
            Maps the label of each open segment to its start reading.

        completed_segments (list[tuple[str, float]]):
            ``(label, elapsed_seconds)`` for every stopped segment, in the
            order they were stopped.

        _timers (dict[str, threading.Timer]):
            The pending warning timer of each open segment.

    Raises:
        RuntimeError: From :meth:`start` if the label is already active and
            from :meth:`stop` if it is not.
    """

    def __init__(self, warn_after: float | None = None) -> None:
        self.active_segments: dict[str, float] = {}
        self.completed_segments: list[tuple[str, float]] = []
        self._timers: dict[str, threading.Timer] = {}
        # Resolve the module-level default lazily so an explicit threshold
        # never depends on the constant.
        self._warn_after: float = (
            TIME_WARNING_CPU if warn_after is None else warn_after
        )

    def _warn_if_over_limit(self, label: str, start_time: float) -> None:
        """Timer callback: warn if ``label`` is still running since ``start_time``.

        The start-time comparison prevents a stale timer from warning about a
        segment that was stopped, or stopped and restarted, in the meantime.
        """
        if self.active_segments.get(label) != start_time:
            return
        # Derive the figure from the threshold so the text cannot drift from
        # the configured limit (renders as "10" for the 600 s default).
        minutes = self._warn_after / 60
        warnings.warn(
            f"{label} is taking > {minutes:g} minutes to run. "
            "Use GPU for faster processing, or if unavailable, "
            "try tabpfn-client API https://github.com/PriorLabs/tabpfn-client",
            UserWarning,
            stacklevel=2,
        )

    def _start_timer(self, label: str, start_time: float) -> None:
        """Arm the warning timer for ``label`` and remember it for cancellation."""
        timer = threading.Timer(
            self._warn_after, self._warn_if_over_limit, args=(label, start_time)
        )
        timer.daemon = True  # Ensure timer doesn't block program exit
        timer.start()
        self._timers[label] = timer

    def _cancel_timer(self, label: str) -> None:
        """Cancel and forget the pending timer for ``label``, if there is one."""
        timer = self._timers.pop(label, None)
        if timer is not None:
            timer.cancel()

    def _begin(self, label: str) -> None:
        """Record a fresh start reading for ``label`` and arm its timer."""
        start_time = time.monotonic()
        self.active_segments[label] = start_time
        self._start_timer(label, start_time)

    def start(self, label: str) -> None:
        """Start a new segment with the given label.

        Raises:
            RuntimeError: If a segment with this label is already active.
        """
        if label in self.active_segments:
            raise RuntimeError(f"Segment '{label}' is already active.")
        self._begin(label)

    def stop(self, label: str) -> None:
        """Stop the segment with the given label and record its duration.

        Raises:
            RuntimeError: If no active segment exists with the given label.
        """
        if label not in self.active_segments:
            raise RuntimeError(f"No active segment found with label '{label}'.")

        start_time = self.active_segments.pop(label)
        self._cancel_timer(label)
        self.completed_segments.append((label, time.monotonic() - start_time))

    def reset(self, label: str) -> None:
        """Restart the clock for ``label``.

        If the segment is active, a warning is issued and the time accumulated
        so far is discarded (nothing is added to ``completed_segments``). If
        it is not active, a new segment is simply started without a warning.
        """
        if label in self.active_segments:
            warnings.warn(f"Resetting {label} time segment.", stacklevel=2)
        # Drop the old timer before arming a new one.
        self._cancel_timer(label)
        self._begin(label)

    def total_time(self, labels: list[str] | None = None) -> float:
        """Return the summed duration of completed segments.

        Args:
            labels: Only segments whose label is in this collection are
                counted. ``None`` counts every completed segment.

        Returns:
            Total elapsed seconds; ``0`` when nothing matches.
        """
        if labels is None:
            return sum(duration for _, duration in self.completed_segments)
        return sum(
            duration for seg, duration in self.completed_segments if seg in labels
        )
16 changes: 16 additions & 0 deletions src/tabpfn/regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
default_regressor_preprocessor_configs,
)
from tabpfn.utils import (
_create_time_usage_tracker,
_fix_dtypes,
_get_embeddings,
_get_ordinal_encoder,
Expand Down Expand Up @@ -414,6 +415,11 @@ def fit(self, X: XType, y: YType) -> Self:
"""
static_seed, rng = infer_random_state(self.random_state)

# Track time usage
TimeUsageTracker = _create_time_usage_tracker()
time_tracker = TimeUsageTracker()
time_tracker.start("TabPFNRegressor fit")

# Load the model and config
self.model_, self.config_, self.bardist_ = initialize_tabpfn_model(
model_path=self.model_path,
Expand Down Expand Up @@ -453,6 +459,7 @@ def fit(self, X: XType, y: YType) -> Self:
max_num_samples=self.interface_config_.MAX_NUMBER_OF_SAMPLES,
max_num_features=self.interface_config_.MAX_NUMBER_OF_FEATURES,
ignore_pretraining_limits=self.ignore_pretraining_limits,
device=self.device_,
)
assert isinstance(X, np.ndarray)

Expand Down Expand Up @@ -540,6 +547,8 @@ def fit(self, X: XType, y: YType) -> Self:
use_autocast_=self.use_autocast_,
)

time_tracker.stop("TabPFNRegressor fit")

return self

@overload
Expand Down Expand Up @@ -624,6 +633,11 @@ def predict( # noqa: C901, PLR0912
"""
check_is_fitted(self)

# Track time usage
TimeUsageTracker = _create_time_usage_tracker()
time_tracker = TimeUsageTracker()
time_tracker.start("TabPFNRegressor predict")

X = validate_X_predict(X, self)
X = _fix_dtypes(X, cat_indices=self.categorical_features_indices)
X = _process_text_na_dataframe(X, ord_encoder=self.preprocessor_) # type: ignore
Expand Down Expand Up @@ -707,6 +721,8 @@ def predict( # noqa: C901, PLR0912
logits = logits.float()
logits = logits.cpu()

time_tracker.stop("TabPFNRegressor predict")

# Determine and return intended output type
logit_to_output = partial(
_logits_to_output,
Expand Down
17 changes: 17 additions & 0 deletions src/tabpfn/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
NA_PLACEHOLDER,
REGRESSION_NAN_BORDER_LIMIT_LOWER,
REGRESSION_NAN_BORDER_LIMIT_UPPER,
X_SIZE_WARNING_CPU,
)
from tabpfn.misc._sklearn_compat import check_array, validate_data
from tabpfn.model.bar_distribution import FullSupportBarDistribution
from tabpfn.model.loading import download_model, load_model
from tabpfn.model.time_tracker import TimeUsageTracker

if TYPE_CHECKING:
from sklearn.base import TransformerMixin
Expand Down Expand Up @@ -528,6 +530,7 @@ def validate_Xy_fit(
max_num_samples: int,
ensure_y_numeric: bool = False,
ignore_pretraining_limits: bool = False,
device: torch.device | None = None,
) -> tuple[np.ndarray, np.ndarray, npt.NDArray[Any] | None, int]:
"""Validate the input data for fitting."""
# Calls `validate_data()` with specification
Expand Down Expand Up @@ -577,6 +580,15 @@ def validate_Xy_fit(
stacklevel=2,
)

if device == torch.device("cpu") and X.shape[0] > X_SIZE_WARNING_CPU:
warnings.warn(
"Training on CPU with >1000 samples maybe slow."
" Use GPU for faster processing, or if unavailable,"
" try tabpfn-client API https://github.com/PriorLabs/tabpfn-client ",
UserWarning,
stacklevel=2,
)

if is_classifier(estimator):
check_classification_targets(y)
# Annoyingly, the `ensure_all_finite` above only applies to `X` and
Expand Down Expand Up @@ -917,3 +929,8 @@ class _MEMORYSTATUSEX(ctypes.Structure):
except (AttributeError, OSError):
# Fall back if not on Windows or if the function fails
return 0.0


def _create_time_usage_tracker() -> type[TimeUsageTracker]:
    """Return the ``TimeUsageTracker`` class itself (not an instance).

    Callers instantiate the returned class to obtain a tracker.
    """
    tracker_cls = TimeUsageTracker
    return tracker_cls
Loading