diff --git a/src/tabpfn/classifier.py b/src/tabpfn/classifier.py
index e8f26898..790cea2e 100644
--- a/src/tabpfn/classifier.py
+++ b/src/tabpfn/classifier.py
@@ -49,6 +49,7 @@
     default_classifier_preprocessor_configs,
 )
 from tabpfn.utils import (
+    _create_time_usage_tracker,
     _fix_dtypes,
     _get_embeddings,
     _get_ordinal_encoder,
@@ -386,6 +387,11 @@ def fit(self, X: XType, y: YType) -> Self:
         """
         static_seed, rng = infer_random_state(self.random_state)
 
+        # Track time usage
+        TimeUsageTracker = _create_time_usage_tracker()
+        time_tracker = TimeUsageTracker()
+        time_tracker.start("TabPFNClassifier fit")
+
         # Load the model and config
         self.model_, self.config_, _ = initialize_tabpfn_model(
             model_path=self.model_path,
@@ -425,6 +431,7 @@ def fit(self, X: XType, y: YType) -> Self:
             max_num_samples=self.interface_config_.MAX_NUMBER_OF_SAMPLES,
             max_num_features=self.interface_config_.MAX_NUMBER_OF_FEATURES,
             ignore_pretraining_limits=self.ignore_pretraining_limits,
+            device=self.device_,
         )
         if feature_names_in is not None:
             self.feature_names_in_ = feature_names_in
@@ -513,6 +520,8 @@ def fit(self, X: XType, y: YType) -> Self:
             use_autocast_=self.use_autocast_,
         )
 
+        time_tracker.stop("TabPFNClassifier fit")
+
         return self
 
     def predict(self, X: XType) -> np.ndarray:
@@ -540,6 +549,11 @@ def predict_proba(self, X: XType) -> np.ndarray:
         """
         check_is_fitted(self)
 
+        # Track time usage
+        TimeUsageTracker = _create_time_usage_tracker()
+        time_tracker = TimeUsageTracker()
+        time_tracker.start("TabPFNClassifier predict_proba")
+
         X = validate_X_predict(X, self)
         X = _fix_dtypes(X, cat_indices=self.categorical_features_indices)
 
@@ -586,6 +600,7 @@ def predict_proba(self, X: XType) -> np.ndarray:
         output = np.around(output, decimals=SKLEARN_16_DECIMAL_PRECISION)
         output = np.where(output < PROBABILITY_EPSILON_ROUND_ZERO, 0.0, output)
 
+        time_tracker.stop("TabPFNClassifier predict_proba")
         # Normalize to guarantee proba sum to 1, required due to precision issues and
         # going from torch to numpy
         return output / output.sum(axis=1, keepdims=True)  # type: ignore
diff --git a/src/tabpfn/constants.py b/src/tabpfn/constants.py
index 6d59fc7c..80b4fae5 100644
--- a/src/tabpfn/constants.py
+++ b/src/tabpfn/constants.py
@@ -25,6 +25,9 @@
 
 NA_PLACEHOLDER = "__MISSING__"
 
+X_SIZE_WARNING_CPU = 1000  # samples
+TIME_WARNING_CPU = 10 * 60  # seconds
+
 SKLEARN_16_DECIMAL_PRECISION = 16
 PROBABILITY_EPSILON_ROUND_ZERO = 1e-3
 REGRESSION_NAN_BORDER_LIMIT_UPPER = 1e3
diff --git a/src/tabpfn/model/time_tracker.py b/src/tabpfn/model/time_tracker.py
new file mode 100644
index 00000000..90b00213
--- /dev/null
+++ b/src/tabpfn/model/time_tracker.py
@@ -0,0 +1,112 @@
+# Copyright (c) Prior Labs GmbH 2025.
+
+from __future__ import annotations
+
+import threading
+import time
+import warnings
+
+from tabpfn.constants import TIME_WARNING_CPU
+
+
+class TimeUsageTracker:
+    """Tracks time usage for code segments.
+
+    Supports starting, stopping, and resetting segments. It will
+    warn if any segment runs longer than the threshold defined by
+    `TIME_WARNING_CPU` (default: 10 minutes). Note that resetting an
+    active segment issues a warning.
+
+    Attributes:
+        active_segments (dict[str, float]):
+            Maps segment labels to their start times.
+
+        completed_segments (list[tuple[str, float]]):
+            Records completed segments with their elapsed times.
+
+        _timers (dict[str, threading.Timer]):
+            Timers that trigger warnings when a segment exceeds the threshold.
+
+    Raises:
+        RuntimeError: If attempting to start a segment that's
+            already active or stop one that isn't active.
+    """
+
+    def __init__(self) -> None:
+        self.active_segments: dict[str, float] = {}
+        self.completed_segments: list[tuple[str, float]] = []
+        self._timers: dict[str, threading.Timer] = {}
+
+    def _warn_if_over_limit(self, label: str, start_time: float) -> None:
+        # Verify the segment is still active and its start time hasn't changed.
+        if label in self.active_segments and self.active_segments[label] == start_time:
+            warnings.warn(
+                f"{label} is taking > 10 minutes to run. "
+                "Use GPU for faster processing, or if unavailable, "
+                "try tabpfn-client API https://github.com/PriorLabs/tabpfn-client",
+                UserWarning,
+                stacklevel=2,
+            )
+
+    def _start_timer(self, label: str, start_time: float) -> None:
+        # Schedule a timer to check the segment after the limit
+        timer = threading.Timer(
+            TIME_WARNING_CPU, self._warn_if_over_limit, args=(label, start_time)
+        )
+        timer.daemon = True  # Ensure timer doesn't block program exit
+        timer.start()
+        self._timers[label] = timer
+
+    def start(self, label: str) -> None:
+        """Start a new segment with the given label.
+
+        Raises:
+            RuntimeError: If a segment with this label is already active.
+        """
+        if label in self.active_segments:
+            raise RuntimeError(f"Segment '{label}' is already active.")
+        start_time = time.time()
+        self.active_segments[label] = start_time
+        self._start_timer(label, start_time)
+
+    def stop(self, label: str) -> None:
+        """Stop the segment with the given label.
+
+        Raises:
+            RuntimeError: If no active segment exists with the given label.
+        """
+        if label not in self.active_segments:
+            raise RuntimeError(f"No active segment found with label '{label}'.")
+
+        start_time = self.active_segments.pop(label)
+        # Cancel the timer for this segment, if it exists
+        if label in self._timers:
+            timer = self._timers.pop(label)
+            timer.cancel()
+        self.completed_segments.append((label, time.time() - start_time))
+
+    def reset(self, label: str) -> None:
+        """Reset an active segment with the given label.
+
+        Issues a warning if resetting an active segment.
+        """
+        if label in self.active_segments:
+            warnings.warn(f"Resetting {label} time segment.", stacklevel=2)
+            # Cancel the current timer, if it exists
+            if label in self._timers:
+                self._timers[label].cancel()
+
+        # Restart the timer with the new start time
+        start_time = time.time()
+        self.active_segments[label] = start_time
+        self._start_timer(label, start_time)
+
+    def total_time(self, labels: list[str] | None = None) -> float:
+        """Return the total elapsed time for segments with the given labels.
+        If no labels are provided, returns the total for all completed segments.
+ """ + if labels is None: + return sum(duration for _, duration in self.completed_segments) + return sum( + duration for seg, duration in self.completed_segments if seg in labels + ) diff --git a/src/tabpfn/regressor.py b/src/tabpfn/regressor.py index 71410b0d..7af7026f 100644 --- a/src/tabpfn/regressor.py +++ b/src/tabpfn/regressor.py @@ -51,6 +51,7 @@ default_regressor_preprocessor_configs, ) from tabpfn.utils import ( + _create_time_usage_tracker, _fix_dtypes, _get_embeddings, _get_ordinal_encoder, @@ -414,6 +415,11 @@ def fit(self, X: XType, y: YType) -> Self: """ static_seed, rng = infer_random_state(self.random_state) + # Track time usage + TimeUsageTracker = _create_time_usage_tracker() + time_tracker = TimeUsageTracker() + time_tracker.start("TabPFNRegressor fit") + # Load the model and config self.model_, self.config_, self.bardist_ = initialize_tabpfn_model( model_path=self.model_path, @@ -453,6 +459,7 @@ def fit(self, X: XType, y: YType) -> Self: max_num_samples=self.interface_config_.MAX_NUMBER_OF_SAMPLES, max_num_features=self.interface_config_.MAX_NUMBER_OF_FEATURES, ignore_pretraining_limits=self.ignore_pretraining_limits, + device=self.device_, ) assert isinstance(X, np.ndarray) @@ -540,6 +547,8 @@ def fit(self, X: XType, y: YType) -> Self: use_autocast_=self.use_autocast_, ) + time_tracker.stop("TabPFNRegressor fit") + return self @overload @@ -624,6 +633,11 @@ def predict( # noqa: C901, PLR0912 """ check_is_fitted(self) + # Track time usage + TimeUsageTracker = _create_time_usage_tracker() + time_tracker = TimeUsageTracker() + time_tracker.start("TabPFNRegressor predict") + X = validate_X_predict(X, self) X = _fix_dtypes(X, cat_indices=self.categorical_features_indices) X = _process_text_na_dataframe(X, ord_encoder=self.preprocessor_) # type: ignore @@ -707,6 +721,8 @@ def predict( # noqa: C901, PLR0912 logits = logits.float() logits = logits.cpu() + time_tracker.stop("TabPFNRegressor predict") + # Determine and return intended output type logit_to_output = partial( _logits_to_output, diff --git a/src/tabpfn/utils.py b/src/tabpfn/utils.py index 47971251..7ce3bf77 100644 --- a/src/tabpfn/utils.py +++ b/src/tabpfn/utils.py @@ -28,10 +28,12 @@ NA_PLACEHOLDER, REGRESSION_NAN_BORDER_LIMIT_LOWER, REGRESSION_NAN_BORDER_LIMIT_UPPER, + X_SIZE_WARNING_CPU, ) from tabpfn.misc._sklearn_compat import check_array, validate_data from tabpfn.model.bar_distribution import FullSupportBarDistribution from tabpfn.model.loading import download_model, load_model +from tabpfn.model.time_tracker import TimeUsageTracker if TYPE_CHECKING: from sklearn.base import TransformerMixin @@ -528,6 +530,7 @@ def validate_Xy_fit( max_num_samples: int, ensure_y_numeric: bool = False, ignore_pretraining_limits: bool = False, + device: torch.device | None = None, ) -> tuple[np.ndarray, np.ndarray, npt.NDArray[Any] | None, int]: """Validate the input data for fitting.""" # Calls `validate_data()` with specification @@ -577,6 +580,15 @@ def validate_Xy_fit( stacklevel=2, ) + if device == torch.device("cpu") and X.shape[0] > X_SIZE_WARNING_CPU: + warnings.warn( + "Training on CPU with >1000 samples maybe slow." 
+ " Use GPU for faster processing, or if unavailable," + " try tabpfn-client API https://github.com/PriorLabs/tabpfn-client ", + UserWarning, + stacklevel=2, + ) + if is_classifier(estimator): check_classification_targets(y) # Annoyingly, the `ensure_all_finite` above only applies to `X` and @@ -917,3 +929,8 @@ class _MEMORYSTATUSEX(ctypes.Structure): except (AttributeError, OSError): # Fall back if not on Windows or if the function fails return 0.0 + + +def _create_time_usage_tracker() -> type[TimeUsageTracker]: + """Creates and returns the TimeUsageTracker class.""" + return TimeUsageTracker