diff --git a/EXPERIMENTS.md b/EXPERIMENTS.md new file mode 100644 index 00000000..fddf4367 --- /dev/null +++ b/EXPERIMENTS.md @@ -0,0 +1,82 @@ +# Experiments Overview + +This repository currently contains four experiment categories. + +## 1. Prediction Demos + +Location: `examples/` + +Purpose: +- Verify that pretrained Kronos models can run end-to-end inference. +- Compare model sizes under the same inputs. + +Typical inputs: +- Local CSV K-line data +- Optional online A-share data from `akshare` + +Typical outputs: +- Forecast CSV +- Plot PNG +- Summary JSON + +Ground-truth availability: +- `prediction_example.py`: no future truth, visualization only +- `prediction_wo_vol_example.py`: no future truth, visualization only +- `prediction_batch_example.py`: no future truth, visualization only +- `prediction_cn_markets_day.py --mode eval`: includes future truth and error metrics + +## 2. Regression Tests + +Location: `tests/test_kronos_regression.py` + +Purpose: +- Ensure code changes do not silently drift model outputs. +- Check exact regression fixtures and MSE thresholds against fixed model revisions. + +Typical outputs: +- `pytest` pass/fail result +- Printed absolute/relative differences and MSE summary in the terminal + +## 3. Fine-Tuning Pipelines + +Locations: +- `finetune/` +- `finetune_csv/` + +Purpose: +- Adapt Kronos to domain-specific datasets. +- Support either Qlib-based A-share workflows or direct CSV-based training. + +Notes: +- The `finetune/` training scripts are currently written for CUDA + NCCL + DDP. +- The training scripts now treat `comet_ml` as optional. + +## 4. Qlib Backtesting + +Location: `finetune/qlib_test.py` + +Purpose: +- Convert model forecasts into cross-sectional stock-selection signals. +- Evaluate those signals with Qlib's `TopkDropoutStrategy`. 
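+The signal-construction step can be sketched in a few lines. This is a minimal, hypothetical illustration — the symbols and predicted returns below are invented, and the real pipeline derives scores from Kronos forecasts before handing the ranked cross-section to Qlib's `TopkDropoutStrategy`:

```python
# Hedged sketch: turning per-stock forecasts into a cross-sectional top-k signal.
# The symbols and scores are made up for illustration; in the actual workflow,
# qlib_test.py produces predicted returns from Kronos and Qlib handles selection.

def topk_signal(predicted_returns: dict[str, float], k: int) -> list[str]:
    """Rank instruments by predicted return and keep the top k."""
    ranked = sorted(predicted_returns, key=predicted_returns.get, reverse=True)
    return ranked[:k]

if __name__ == "__main__":
    # Hypothetical one-day-ahead return forecasts for four A-share codes.
    scores = {"600977": 0.012, "000001": -0.004, "002594": 0.031, "600519": 0.008}
    print(topk_signal(scores, k=2))  # prints ['002594', '600977']
```

+In the real backtest, this ranking is recomputed every trading day over the test split, which is what makes ground truth available via historical replay.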
+ +Typical inputs: +- Processed `train/val/test` pickle datasets from `qlib_data_preprocess.py` +- A tokenizer/model pair, which can be local checkpoints or explicit Hugging Face IDs + +Typical outputs: +- `predictions.pkl` +- Per-signal curve CSV +- Per-signal metrics JSON +- Backtest summary CSV/JSON +- Backtest plot PNG + +Ground-truth availability: +- Yes, via historical replay using the test split + +## Recommended Validation Sequence + +1. Run `examples/` to verify inference and output formats. +2. Run `tests/test_kronos_regression.py` to detect output drift. +3. Start `webui` to verify the interactive inference path. +4. Run `finetune/qlib_test.py` with an explicit config file for smoke backtesting. +5. Compare `Kronos-small` vs `Kronos-base` on the same prediction and backtest setups. diff --git a/README.md b/README.md index faffa157..556f4adc 100644 --- a/README.md +++ b/README.md @@ -211,10 +211,14 @@ Running this script will generate a plot comparing the ground truth data against Forecast Example

-Additionally, we provide a script that makes predictions without Volume and Amount data, which can be found in [`examples/prediction_wo_vol_example.py`](examples/prediction_wo_vol_example.py). - - -## ๐Ÿ”ง Finetuning on Your Own Data (A-Share Market Example) +Additionally, we provide a script that makes predictions without Volume and Amount data, which can be found in [`examples/prediction_wo_vol_example.py`](examples/prediction_wo_vol_example.py). + +For a concise guide to all example scripts, recommended run order, Apple Silicon device selection, and optional `akshare` setup, see [`examples/README.md`](examples/README.md). + +For an overview of the repository's prediction demos, regression tests, finetuning pipelines, and Qlib backtesting workflows, see [`EXPERIMENTS.md`](EXPERIMENTS.md). + + +## ๐Ÿ”ง Finetuning on Your Own Data (A-Share Market Example) We provide a complete pipeline for finetuning Kronos on your own datasets. As an example, we demonstrate how to use [Qlib](https://github.com/microsoft/qlib) to prepare data from the Chinese A-share market and conduct a simple backtest. diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 00000000..9eb04e8d --- /dev/null +++ b/examples/README.md @@ -0,0 +1,62 @@ +# Example Scripts + +All scripts in this directory can be run from the repository root and save their outputs to `examples/outputs/`. + +## Recommended Order + +1. `python examples/prediction_example.py --model small` +2. `python examples/prediction_wo_vol_example.py --model small` +3. `python examples/prediction_batch_example.py --model small` +4. `python examples/prediction_cn_markets_day.py --model small --mode forward --symbol 000001` +5. 
`python examples/prediction_cn_markets_day.py --model small --mode eval --symbol 000001` + +## Model Selection + +The examples support: + +- `--model small` +- `--model base` + +This keeps the input data and evaluation mode fixed while allowing direct `Kronos-small` vs `Kronos-base` comparisons. + +## Script Semantics + +- `prediction_example.py` + Runs one OHLCV forecast on a single historical window and saves CSV, PNG, and summary JSON. +- `prediction_wo_vol_example.py` + Runs one forecast without `volume`/`amount` inputs and saves CSV, PNG, and summary JSON. +- `prediction_batch_example.py` + Demonstrates batch inference on multiple sequential time windows from the same instrument. + It is not a multi-asset cross-sectional example. + By default it saves CSV for every window and one representative PNG. +- `prediction_cn_markets_day.py` + Supports two explicit modes: + - `forward`: future prediction only, so there is no ground-truth window yet. + - `eval`: historical replay mode with known future truth and saved error metrics. + +## Outputs + +Each run saves structured outputs in `examples/outputs/`: + +- Forecast CSV files +- Plot PNG files +- Summary JSON files describing the model, device, time ranges, and output paths + +## Apple Silicon Devices + +The scripts auto-select a compute device in this order: + +1. `cuda:0` +2. `mps` +3. `cpu` + +On Apple Silicon Macs, `mps` is used when PyTorch reports that Metal is available. + +## Dependencies + +- Local CSV examples only need the root `requirements.txt`. 
+- `prediction_cn_markets_day.py` also requires `akshare` and network access: + +```bash +uv pip install --python .venv/bin/python akshare +``` diff --git a/examples/common.py b/examples/common.py new file mode 100644 index 00000000..988d0097 --- /dev/null +++ b/examples/common.py @@ -0,0 +1,82 @@ +import json +from pathlib import Path +import sys +from typing import Any + +import matplotlib +import pandas as pd +import torch + +matplotlib.use("Agg") + +EXAMPLES_ROOT = Path(__file__).resolve().parent +PROJECT_ROOT = EXAMPLES_ROOT.parent +DATA_ROOT = EXAMPLES_ROOT / "data" +OUTPUT_ROOT = EXAMPLES_ROOT / "outputs" +OUTPUT_ROOT.mkdir(exist_ok=True) + +MODEL_SPECS = { + "small": { + "name": "Kronos-small", + "model_id": "NeoQuasar/Kronos-small", + "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-base", + "max_context": 512, + }, + "base": { + "name": "Kronos-base", + "model_id": "NeoQuasar/Kronos-base", + "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-base", + "max_context": 512, + }, +} + +if str(PROJECT_ROOT) not in sys.path: + sys.path.insert(0, str(PROJECT_ROOT)) + + +def data_path(filename: str) -> Path: + return DATA_ROOT / filename + + +def output_path(filename: str) -> Path: + return OUTPUT_ROOT / filename + + +def detect_device() -> str: + if torch.cuda.is_available(): + return "cuda:0" + if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + return "mps" + return "cpu" + + +def get_model_spec(model_key: str) -> dict[str, Any]: + if model_key not in MODEL_SPECS: + raise ValueError(f"Unsupported model '{model_key}'. 
Expected one of: {sorted(MODEL_SPECS)}") + return MODEL_SPECS[model_key] + + +def save_json(path: Path, payload: dict[str, Any]) -> None: + with path.open("w", encoding="utf-8") as handle: + json.dump(to_jsonable(payload), handle, indent=2, ensure_ascii=False) + + +def to_jsonable(value: Any) -> Any: + if isinstance(value, Path): + try: + return str(value.relative_to(PROJECT_ROOT)) + except ValueError: + return str(value) + if isinstance(value, pd.Timestamp): + return value.isoformat() + if isinstance(value, pd.Series): + return [to_jsonable(item) for item in value.tolist()] + if isinstance(value, pd.DataFrame): + return [{key: to_jsonable(item) for key, item in row.items()} for row in value.to_dict(orient="records")] + if isinstance(value, dict): + return {str(key): to_jsonable(item) for key, item in value.items()} + if isinstance(value, (list, tuple)): + return [to_jsonable(item) for item in value] + if hasattr(value, "item"): + return value.item() + return value diff --git a/examples/prediction_batch_example.py b/examples/prediction_batch_example.py index 29a74334..d77010ec 100644 --- a/examples/prediction_batch_example.py +++ b/examples/prediction_batch_example.py @@ -1,72 +1,176 @@ -import pandas as pd -import matplotlib.pyplot as plt -import sys -sys.path.append("../") -from model import Kronos, KronosTokenizer, KronosPredictor - - -def plot_prediction(kline_df, pred_df): - pred_df.index = kline_df.index[-pred_df.shape[0]:] - sr_close = kline_df['close'] - sr_pred_close = pred_df['close'] - sr_close.name = 'Ground Truth' - sr_pred_close.name = "Prediction" +import argparse +import time +from pathlib import Path - sr_volume = kline_df['volume'] - sr_pred_volume = pred_df['volume'] - sr_volume.name = 'Ground Truth' - sr_pred_volume.name = "Prediction" +import matplotlib.pyplot as plt +import pandas as pd - close_df = pd.concat([sr_close, sr_pred_close], axis=1) - volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1) +from common import data_path, 
detect_device, get_model_spec, output_path, save_json +from model import Kronos, KronosPredictor, KronosTokenizer + +LOOKBACK = 400 +PRED_LEN = 120 +DEFAULT_BATCH_SIZE = 5 + + +def plot_prediction(kline_df: pd.DataFrame, pred_df: pd.DataFrame, plot_file: Path) -> None: + plot_df = pred_df.copy() + plot_df.index = kline_df.index[-plot_df.shape[0]:] + + close_df = pd.concat( + [ + kline_df["close"].rename("Ground Truth"), + plot_df["close"].rename("Prediction"), + ], + axis=1, + ) + volume_df = pd.concat( + [ + kline_df["volume"].rename("Ground Truth"), + plot_df["volume"].rename("Prediction"), + ], + axis=1, + ) fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) - - ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) - ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) - ax1.set_ylabel('Close Price', fontsize=14) - ax1.legend(loc='lower left', fontsize=12) + ax1.plot(close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5) + ax1.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5) + ax1.set_ylabel("Close Price", fontsize=14) + ax1.legend(loc="lower left", fontsize=12) ax1.grid(True) - ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) - ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5) - ax2.set_ylabel('Volume', fontsize=14) - ax2.legend(loc='upper left', fontsize=12) + ax2.plot(volume_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5) + ax2.plot(volume_df["Prediction"], label="Prediction", color="red", linewidth=1.5) + ax2.set_ylabel("Volume", fontsize=14) + ax2.legend(loc="upper left", fontsize=12) ax2.grid(True) plt.tight_layout() - plt.show() - - -# 1. Load Model and Tokenizer -tokenizer = KronosTokenizer.from_pretrained('/home/csc/huggingface/Kronos-Tokenizer-base/') -model = Kronos.from_pretrained("/home/csc/huggingface/Kronos-base/") - -# 2. 
Instantiate Predictor -predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) - -# 3. Prepare Data -df = pd.read_csv("./data/XSHG_5min_600977.csv") -df['timestamps'] = pd.to_datetime(df['timestamps']) - -lookback = 400 -pred_len = 120 - -dfs = [] -xtsp = [] -ytsp = [] -for i in range(5): - idf = df.loc[(i*400):(i*400+lookback-1), ['open', 'high', 'low', 'close', 'volume', 'amount']] - i_x_timestamp = df.loc[(i*400):(i*400+lookback-1), 'timestamps'] - i_y_timestamp = df.loc[(i*400+lookback):(i*400+lookback+pred_len-1), 'timestamps'] - - dfs.append(idf) - xtsp.append(i_x_timestamp) - ytsp.append(i_y_timestamp) - -pred_df = predictor.predict_batch( - df_list=dfs, - x_timestamp_list=xtsp, - y_timestamp_list=ytsp, - pred_len=pred_len, -) + plt.savefig(plot_file, dpi=150) + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser( + description="Run batch prediction on multiple time windows from the same instrument." + ) + parser.add_argument("--model", choices=["small", "base"], default="small", help="Model size to run.") + parser.add_argument( + "--batch-size", + type=int, + default=DEFAULT_BATCH_SIZE, + help="Number of sequential time windows to batch together.", + ) + parser.add_argument( + "--plot-series", + type=int, + default=0, + help="Which window index to render as the representative plot.", + ) + args = parser.parse_args() + + spec = get_model_spec(args.model) + device = detect_device() + data_file = data_path("XSHG_5min_600977.csv") + summary_file = output_path(f"prediction_batch_example_{args.model}_summary.json") + + print( + "Running batch prediction on multiple sequential windows from the same source file, " + "not on multiple instruments." 
+ ) + print(f"Loading {spec['name']} on device: {device}") + start_time = time.perf_counter() + tokenizer = KronosTokenizer.from_pretrained(spec["tokenizer_id"]) + model = Kronos.from_pretrained(spec["model_id"]) + predictor = KronosPredictor(model, tokenizer, device=device, max_context=spec["max_context"]) + + df = pd.read_csv(data_file) + df["timestamps"] = pd.to_datetime(df["timestamps"]) + + df_list = [] + x_timestamp_list = [] + y_timestamp_list = [] + window_summaries = [] + + for idx in range(args.batch_size): + start = idx * LOOKBACK + end = start + LOOKBACK - 1 + pred_start = end + 1 + pred_end = pred_start + PRED_LEN - 1 + + df_list.append(df.loc[start:end, ["open", "high", "low", "close", "volume", "amount"]]) + x_timestamp_list.append(df.loc[start:end, "timestamps"]) + y_timestamp_list.append(df.loc[pred_start:pred_end, "timestamps"]) + window_summaries.append( + { + "series_index": idx, + "instrument_note": "Same instrument, different historical time window.", + "history_row_start": start, + "history_row_end": end, + "forecast_row_start": pred_start, + "forecast_row_end": pred_end, + "history_start": df.loc[start, "timestamps"], + "history_end": df.loc[end, "timestamps"], + "forecast_start": df.loc[pred_start, "timestamps"], + "forecast_end": df.loc[pred_end, "timestamps"], + } + ) + + pred_df_list = predictor.predict_batch( + df_list=df_list, + x_timestamp_list=x_timestamp_list, + y_timestamp_list=y_timestamp_list, + pred_len=PRED_LEN, + verbose=False, + ) + elapsed_seconds = time.perf_counter() - start_time + + print(f"Generated batch predictions: {len(pred_df_list)} windows") + representative_plot = None + + for idx, pred_df in enumerate(pred_df_list): + csv_file = output_path(f"prediction_batch_example_{args.model}_series_{idx}.csv") + pred_df.to_csv(csv_file, index_label="timestamps") + window_summaries[idx]["forecast_csv"] = csv_file + + print(f"Series {idx} forecast head:") + print(pred_df.head()) + print(f"Saved series {idx} CSV: {csv_file}") + 
+ if idx == args.plot_series: + plot_file = output_path(f"prediction_batch_example_{args.model}_representative_series_{idx}.png") + kline_df = df.loc[idx * LOOKBACK : idx * LOOKBACK + LOOKBACK + PRED_LEN - 1] + plot_prediction(kline_df, pred_df, plot_file) + representative_plot = plot_file + window_summaries[idx]["representative_plot"] = plot_file + print(f"Saved representative plot for series {idx}: {plot_file}") + + save_json( + summary_file, + { + "experiment": "prediction_batch_example", + "description": ( + "Batch prediction on multiple sequential time windows from a single instrument. " + "This example demonstrates parallel inference, not cross-sectional multi-asset prediction." + ), + "model_key": args.model, + "model_name": spec["name"], + "model_id": spec["model_id"], + "tokenizer_id": spec["tokenizer_id"], + "device": device, + "elapsed_seconds": elapsed_seconds, + "lookback": LOOKBACK, + "pred_len": PRED_LEN, + "batch_size": args.batch_size, + "plot_series": args.plot_series, + "data_file": data_file, + "representative_plot": representative_plot, + "windows": window_summaries, + }, + ) + + print(f"Saved summary JSON: {summary_file}") + + +if __name__ == "__main__": + main() diff --git a/examples/prediction_cn_markets_day.py b/examples/prediction_cn_markets_day.py index 93b28f9f..1dd50a80 100644 --- a/examples/prediction_cn_markets_day.py +++ b/examples/prediction_cn_markets_day.py @@ -2,170 +2,207 @@ """ prediction_cn_markets_day.py -Description: - Predicts future daily K-line (1D) data for A-share markets using Kronos model and akshare. - The script automatically downloads the latest historical data, cleans it, and runs model inference. - -Usage: - python prediction_cn_markets_day.py --symbol 000001 - -Arguments: - --symbol Stock code (e.g. 
002594 for BYD, 000001 for SSE Index) - -Output: - - Saves the prediction results to ./outputs/pred__data.csv and ./outputs/pred__chart.png - - Logs and progress are printed to console - -Example: - bash> python prediction_cn_markets_day.py --symbol 000001 - python3 prediction_cn_markets_day.py --symbol 002594 +Two modes are supported: +1. forward: predict the next unknown window from the latest available history. +2. eval: replay a historical cutoff and compare the forecast against known future data. """ -import os import argparse import time -import pandas as pd -import akshare as ak +from pathlib import Path + import matplotlib.pyplot as plt -import sys -sys.path.append("../") -from model import Kronos, KronosTokenizer, KronosPredictor +import pandas as pd + +from common import OUTPUT_ROOT, detect_device, get_model_spec, save_json +from model import Kronos, KronosPredictor, KronosTokenizer -save_dir = "./outputs" -os.makedirs(save_dir, exist_ok=True) +try: + import akshare as ak +except ImportError as exc: + raise ImportError( + "akshare is required for prediction_cn_markets_day.py. Install it with " + "`uv pip install --python .venv/bin/python akshare`." 
+ ) from exc -# Setting -TOKENIZER_PRETRAINED = "NeoQuasar/Kronos-Tokenizer-base" -MODEL_PRETRAINED = "NeoQuasar/Kronos-base" -DEVICE = "cpu" # "cuda:0" MAX_CONTEXT = 512 LOOKBACK = 400 PRED_LEN = 120 T = 1.0 TOP_P = 0.9 SAMPLE_COUNT = 1 +PLOT_COLUMNS = ["close"] +METRIC_COLUMNS = ["open", "high", "low", "close"] + def load_data(symbol: str) -> pd.DataFrame: - print(f"๐Ÿ“ฅ Fetching {symbol} daily data from akshare ...") + print(f"Fetching {symbol} daily data from akshare...") max_retries = 3 df = None - - # Retry mechanism for attempt in range(1, max_retries + 1): try: df = ak.stock_zh_a_hist(symbol=symbol, period="daily", adjust="") if df is not None and not df.empty: break - except Exception as e: - print(f"โš ๏ธ Attempt {attempt}/{max_retries} failed: {e}") + except Exception as exc: + print(f"Attempt {attempt}/{max_retries} failed: {exc}") time.sleep(1.5) - # If still empty after retries if df is None or df.empty: - print(f"โŒ Failed to fetch data for {symbol} after {max_retries} attempts. 
Exiting.") - sys.exit(1) - - df.rename(columns={ - "ๆ—ฅๆœŸ": "date", - "ๅผ€็›˜": "open", - "ๆ”ถ็›˜": "close", - "ๆœ€้ซ˜": "high", - "ๆœ€ไฝŽ": "low", - "ๆˆไบค้‡": "volume", - "ๆˆไบค้ข": "amount" - }, inplace=True) + raise RuntimeError(f"Failed to fetch data for {symbol} after {max_retries} attempts.") + + df.rename( + columns={ + "ๆ—ฅๆœŸ": "date", + "ๅผ€็›˜": "open", + "ๆ”ถ็›˜": "close", + "ๆœ€้ซ˜": "high", + "ๆœ€ไฝŽ": "low", + "ๆˆไบค้‡": "volume", + "ๆˆไบค้ข": "amount", + }, + inplace=True, + ) df["date"] = pd.to_datetime(df["date"]) df = df.sort_values("date").reset_index(drop=True) - # Convert numeric columns numeric_cols = ["open", "high", "low", "close", "volume", "amount"] for col in numeric_cols: df[col] = ( - df[col] - .astype(str) - .str.replace(",", "", regex=False) - .replace({"--": None, "": None}) + df[col].astype(str).str.replace(",", "", regex=False).replace({"--": None, "": None}) ) df[col] = pd.to_numeric(df[col], errors="coerce") - # Fix invalid open values open_bad = (df["open"] == 0) | (df["open"].isna()) if open_bad.any(): - print(f"โš ๏ธ Fixed {open_bad.sum()} invalid open values.") + print(f"Fixed {open_bad.sum()} invalid open values.") df.loc[open_bad, "open"] = df["close"].shift(1) - df["open"].fillna(df["close"], inplace=True) + df["open"] = df["open"].fillna(df["close"]) - # Fix missing amount if df["amount"].isna().all() or (df["amount"] == 0).all(): df["amount"] = df["close"] * df["volume"] - print(f"โœ… Data loaded: {len(df)} rows, range: {df['date'].min()} ~ {df['date'].max()}") - - print("Data Head:") + print(f"Loaded {len(df)} rows from {df['date'].min()} to {df['date'].max()}") print(df.head()) - return df -def prepare_inputs(df): - x_df = df.iloc[-LOOKBACK:][["open","high","low","close","volume","amount"]] - x_timestamp = df.iloc[-LOOKBACK:]["date"] - y_timestamp = pd.bdate_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=PRED_LEN) - return x_df, pd.Series(x_timestamp), pd.Series(y_timestamp) +def 
resolve_eval_end_index(df: pd.DataFrame, as_of_date: str | None) -> int: + min_end_idx = LOOKBACK - 1 + max_end_idx = len(df) - PRED_LEN - 1 + if max_end_idx < min_end_idx: + raise ValueError("Not enough rows to run historical evaluation with the requested lookback and pred_len.") + + if as_of_date is None: + return max_end_idx -def apply_price_limits(pred_df, last_close, limit_rate=0.1): - print(f"๐Ÿ”’ Applying ยฑ{limit_rate*100:.0f}% price limit ...") + target = pd.Timestamp(as_of_date) + eligible = df.index[df["date"] <= target] + if len(eligible) == 0: + raise ValueError(f"No data is available on or before {target.date()}.") + + end_idx = int(eligible[-1]) + if end_idx < min_end_idx: + raise ValueError(f"Evaluation cutoff {df.loc[end_idx, 'date'].date()} is before the required lookback.") + if end_idx > max_end_idx: + raise ValueError( + f"Evaluation cutoff {df.loc[end_idx, 'date'].date()} leaves no future truth window of length {PRED_LEN}." + ) + return end_idx - # Ensure integer index - pred_df = pred_df.reset_index(drop=True) - # Ensure float64 dtype for safe assignment +def prepare_forward_inputs(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series, pd.Series]: + x_df = df.iloc[-LOOKBACK:][["open", "high", "low", "close", "volume", "amount"]] + x_timestamp = pd.Series(df.iloc[-LOOKBACK:]["date"]) + y_timestamp = pd.Series( + pd.bdate_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=PRED_LEN) + ) + return x_df, x_timestamp, y_timestamp + + +def prepare_eval_inputs(df: pd.DataFrame, as_of_date: str | None) -> tuple[pd.DataFrame, pd.Series, pd.Series, pd.DataFrame, pd.Timestamp]: + end_idx = resolve_eval_end_index(df, as_of_date) + x_df = df.iloc[end_idx - LOOKBACK + 1 : end_idx + 1][["open", "high", "low", "close", "volume", "amount"]] + x_timestamp = pd.Series(df.iloc[end_idx - LOOKBACK + 1 : end_idx + 1]["date"]) + actual_df = df.iloc[end_idx + 1 : end_idx + 1 + PRED_LEN][["date", "open", "high", "low", "close", "volume", 
"amount"]].reset_index(drop=True) + y_timestamp = pd.Series(actual_df["date"]) + return x_df, x_timestamp, y_timestamp, actual_df, df.loc[end_idx, "date"] + + +def apply_price_limits(pred_df: pd.DataFrame, last_close: float, limit_rate: float = 0.1) -> pd.DataFrame: + clipped_df = pred_df.reset_index(drop=True).copy() cols = ["open", "high", "low", "close"] - pred_df[cols] = pred_df[cols].astype("float64") + clipped_df[cols] = clipped_df[cols].astype("float64") - for i in range(len(pred_df)): + for idx in range(len(clipped_df)): limit_up = last_close * (1 + limit_rate) limit_down = last_close * (1 - limit_rate) - for col in cols: - value = pred_df.at[i, col] + value = clipped_df.at[idx, col] if pd.notna(value): - clipped = max(min(value, limit_up), limit_down) - pred_df.at[i, col] = float(clipped) + clipped_df.at[idx, col] = float(max(min(value, limit_up), limit_down)) + last_close = float(clipped_df.at[idx, "close"]) - last_close = float(pred_df.at[i, "close"]) # ensure float type + return clipped_df - return pred_df +def compute_eval_metrics(pred_df: pd.DataFrame, actual_df: pd.DataFrame) -> dict[str, float]: + metrics: dict[str, float] = {} + for col in METRIC_COLUMNS: + diff = pred_df[col] - actual_df[col] + metrics[f"{col}_mae"] = float(diff.abs().mean()) + metrics[f"{col}_mse"] = float((diff ** 2).mean()) + return metrics -def plot_result(df_hist, df_pred, symbol): + +def plot_result( + df_hist: pd.DataFrame, + df_pred: pd.DataFrame, + symbol: str, + plot_path: Path, + actual_df: pd.DataFrame | None = None, +) -> None: plt.figure(figsize=(12, 6)) plt.plot(df_hist["date"], df_hist["close"], label="Historical", color="blue") plt.plot(df_pred["date"], df_pred["close"], label="Predicted", color="red", linestyle="--") + if actual_df is not None: + plt.plot(actual_df["date"], actual_df["close"], label="Actual Future", color="green", linestyle=":") plt.title(f"Kronos Prediction for {symbol}") plt.xlabel("Date") plt.ylabel("Close Price") plt.legend() plt.grid(True) 
plt.tight_layout() - plot_path = os.path.join(save_dir, f"pred_{symbol.replace('.', '_')}_chart.png") - plt.savefig(plot_path) + plt.savefig(plot_path, dpi=150) plt.close() - print(f"๐Ÿ“Š Chart saved: {plot_path}") -def predict_future(symbol): - print(f"๐Ÿš€ Loading Kronos tokenizer:{TOKENIZER_PRETRAINED} model:{MODEL_PRETRAINED} ...") - tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_PRETRAINED) - model = Kronos.from_pretrained(MODEL_PRETRAINED) - predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=MAX_CONTEXT) +def run_prediction(symbol: str, model_key: str, mode: str, as_of_date: str | None) -> None: + spec = get_model_spec(model_key) + device = detect_device() + prefix = f"pred_{symbol.replace('.', '_')}_{model_key}_{mode}" + csv_path = OUTPUT_ROOT / f"{prefix}_data.csv" + plot_path = OUTPUT_ROOT / f"{prefix}_chart.png" + summary_path = OUTPUT_ROOT / f"{prefix}_summary.json" + + print(f"Loading {spec['name']} on device: {device}") + start_time = time.perf_counter() + tokenizer = KronosTokenizer.from_pretrained(spec["tokenizer_id"]) + model = Kronos.from_pretrained(spec["model_id"]) + predictor = KronosPredictor(model, tokenizer, device=device, max_context=MAX_CONTEXT) df = load_data(symbol) - x_df, x_timestamp, y_timestamp = prepare_inputs(df) + actual_df = None + eval_cutoff = None - print("๐Ÿ”ฎ Generating predictions ...") + if mode == "forward": + x_df, x_timestamp, y_timestamp = prepare_forward_inputs(df) + history_df = df.copy() + else: + x_df, x_timestamp, y_timestamp, actual_df, eval_cutoff = prepare_eval_inputs(df, as_of_date) + history_df = df[df["date"] <= eval_cutoff].copy() pred_df = predictor.predict( df=x_df, @@ -175,34 +212,88 @@ def predict_future(symbol): T=T, top_p=TOP_P, sample_count=SAMPLE_COUNT, + verbose=False, ) + elapsed_seconds = time.perf_counter() - start_time pred_df["date"] = y_timestamp.values + pred_df = apply_price_limits(pred_df, float(x_df["close"].iloc[-1]), limit_rate=0.1) + + if actual_df is None: 
+ output_df = pd.concat( + [ + history_df[["date", "open", "high", "low", "close", "volume", "amount"]], + pred_df[["date", "open", "high", "low", "close", "volume", "amount"]], + ] + ).reset_index(drop=True) + metrics = None + else: + comparison_df = pred_df[["date", "open", "high", "low", "close", "volume", "amount"]].copy() + for col in ["open", "high", "low", "close", "volume", "amount"]: + comparison_df[f"actual_{col}"] = actual_df[col] + output_df = comparison_df + metrics = compute_eval_metrics(pred_df, actual_df) + + output_df.to_csv(csv_path, index=False) + plot_result(history_df, pred_df, symbol, plot_path, actual_df=actual_df) + + save_json( + summary_path, + { + "experiment": "prediction_cn_markets_day", + "description": ( + "Forward mode predicts unseen future dates. Eval mode replays a historical cutoff and compares " + "the forecast against known future truth." + ), + "symbol": symbol, + "mode": mode, + "has_ground_truth": actual_df is not None, + "evaluation_cutoff": eval_cutoff, + "model_key": model_key, + "model_name": spec["name"], + "model_id": spec["model_id"], + "tokenizer_id": spec["tokenizer_id"], + "device": device, + "elapsed_seconds": elapsed_seconds, + "lookback": LOOKBACK, + "pred_len": PRED_LEN, + "history_start": x_timestamp.iloc[0], + "history_end": x_timestamp.iloc[-1], + "prediction_start": y_timestamp.iloc[0], + "prediction_end": y_timestamp.iloc[-1], + "output_csv": csv_path, + "plot_file": plot_path, + "metrics": metrics, + "prediction_head": pred_df.head(), + "actual_head": actual_df.head() if actual_df is not None else None, + }, + ) - # Apply ยฑ10% price limit - last_close = df["close"].iloc[-1] - pred_df = apply_price_limits(pred_df, last_close, limit_rate=0.1) - - # Merge historical and predicted data - df_out = pd.concat([ - df[["date", "open", "high", "low", "close", "volume", "amount"]], - pred_df[["date", "open", "high", "low", "close", "volume", "amount"]] - ]).reset_index(drop=True) - - # Save CSV - out_file = 
os.path.join(save_dir, f"pred_{symbol.replace('.', '_')}_data.csv") - df_out.to_csv(out_file, index=False) - print(f"โœ… Prediction completed and saved: {out_file}") - - # Plot - plot_result(df, pred_df, symbol) + print(f"Saved forecast CSV: {csv_path}") + print(f"Saved forecast plot: {plot_path}") + print(f"Saved summary JSON: {summary_path}") + if metrics is not None: + print("Evaluation metrics:") + for key, value in metrics.items(): + print(f" {key}: {value:.6f}") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Kronos stock prediction script") parser.add_argument("--symbol", type=str, default="000001", help="Stock code") + parser.add_argument("--model", choices=["small", "base"], default="small", help="Model size to run.") + parser.add_argument( + "--mode", + choices=["forward", "eval"], + default="forward", + help="forward predicts unknown future dates, eval compares against known future truth.", + ) + parser.add_argument( + "--as-of-date", + type=str, + default=None, + help="Historical cutoff date for eval mode, e.g. 2020-06-30. 
Defaults to the latest eligible date.", + ) args = parser.parse_args() - predict_future( - symbol=args.symbol, - ) + run_prediction(symbol=args.symbol, model_key=args.model, mode=args.mode, as_of_date=args.as_of_date) diff --git a/examples/prediction_example.py b/examples/prediction_example.py index 880f22b8..312ccae0 100644 --- a/examples/prediction_example.py +++ b/examples/prediction_example.py @@ -1,80 +1,127 @@ -import pandas as pd -import matplotlib.pyplot as plt -import sys -sys.path.append("../") -from model import Kronos, KronosTokenizer, KronosPredictor - - -def plot_prediction(kline_df, pred_df): - pred_df.index = kline_df.index[-pred_df.shape[0]:] - sr_close = kline_df['close'] - sr_pred_close = pred_df['close'] - sr_close.name = 'Ground Truth' - sr_pred_close.name = "Prediction" - - sr_volume = kline_df['volume'] - sr_pred_volume = pred_df['volume'] - sr_volume.name = 'Ground Truth' - sr_pred_volume.name = "Prediction" - - close_df = pd.concat([sr_close, sr_pred_close], axis=1) - volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1) - - fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) - - ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) - ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) - ax1.set_ylabel('Close Price', fontsize=14) - ax1.legend(loc='lower left', fontsize=12) - ax1.grid(True) - - ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) - ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5) - ax2.set_ylabel('Volume', fontsize=14) - ax2.legend(loc='upper left', fontsize=12) - ax2.grid(True) - - plt.tight_layout() - plt.show() - - -# 1. Load Model and Tokenizer -tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") -model = Kronos.from_pretrained("NeoQuasar/Kronos-small") - -# 2. 
Instantiate Predictor -predictor = KronosPredictor(model, tokenizer, max_context=512) - -# 3. Prepare Data -df = pd.read_csv("./data/XSHG_5min_600977.csv") -df['timestamps'] = pd.to_datetime(df['timestamps']) - -lookback = 400 -pred_len = 120 - -x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']] -x_timestamp = df.loc[:lookback-1, 'timestamps'] -y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] - -# 4. Make Prediction -pred_df = predictor.predict( - df=x_df, - x_timestamp=x_timestamp, - y_timestamp=y_timestamp, - pred_len=pred_len, - T=1.0, - top_p=0.9, - sample_count=1, - verbose=True -) - -# 5. Visualize Results -print("Forecasted Data Head:") -print(pred_df.head()) - -# Combine historical and forecasted data for plotting -kline_df = df.loc[:lookback+pred_len-1] - -# visualize -plot_prediction(kline_df, pred_df) - +import argparse +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +from common import data_path, detect_device, get_model_spec, output_path, save_json +from model import Kronos, KronosPredictor, KronosTokenizer + +LOOKBACK = 400 +PRED_LEN = 120 + + +def plot_prediction(kline_df: pd.DataFrame, pred_df: pd.DataFrame, plot_file: Path) -> None: + plot_df = pred_df.copy() + plot_df.index = kline_df.index[-plot_df.shape[0]:] + + close_df = pd.concat( + [ + kline_df["close"].rename("Ground Truth"), + plot_df["close"].rename("Prediction"), + ], + axis=1, + ) + volume_df = pd.concat( + [ + kline_df["volume"].rename("Ground Truth"), + plot_df["volume"].rename("Prediction"), + ], + axis=1, + ) + + fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True) + + ax1.plot(close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5) + ax1.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5) + ax1.set_ylabel("Close Price", fontsize=14) + ax1.legend(loc="lower left", fontsize=12) + ax1.grid(True) + + ax2.plot(volume_df["Ground 
Truth"], label="Ground Truth", color="blue", linewidth=1.5) + ax2.plot(volume_df["Prediction"], label="Prediction", color="red", linewidth=1.5) + ax2.set_ylabel("Volume", fontsize=14) + ax2.legend(loc="upper left", fontsize=12) + ax2.grid(True) + + plt.tight_layout() + plt.savefig(plot_file, dpi=150) + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run a single-series Kronos prediction example.") + parser.add_argument("--model", choices=["small", "base"], default="small", help="Model size to run.") + args = parser.parse_args() + + spec = get_model_spec(args.model) + device = detect_device() + data_file = data_path("XSHG_5min_600977.csv") + csv_file = output_path(f"prediction_example_{args.model}_forecast.csv") + plot_file = output_path(f"prediction_example_{args.model}_plot.png") + summary_file = output_path(f"prediction_example_{args.model}_summary.json") + + print(f"Loading {spec['name']} on device: {device}") + start_time = time.perf_counter() + tokenizer = KronosTokenizer.from_pretrained(spec["tokenizer_id"]) + model = Kronos.from_pretrained(spec["model_id"]) + predictor = KronosPredictor(model, tokenizer, device=device, max_context=spec["max_context"]) + + df = pd.read_csv(data_file) + df["timestamps"] = pd.to_datetime(df["timestamps"]) + + x_df = df.loc[: LOOKBACK - 1, ["open", "high", "low", "close", "volume", "amount"]] + x_timestamp = df.loc[: LOOKBACK - 1, "timestamps"] + y_timestamp = df.loc[LOOKBACK : LOOKBACK + PRED_LEN - 1, "timestamps"] + + pred_df = predictor.predict( + df=x_df, + x_timestamp=x_timestamp, + y_timestamp=y_timestamp, + pred_len=PRED_LEN, + T=1.0, + top_p=0.9, + sample_count=1, + verbose=False, + ) + elapsed_seconds = time.perf_counter() - start_time + + print("Forecasted Data Head:") + print(pred_df.head()) + + pred_df.to_csv(csv_file, index_label="timestamps") + kline_df = df.loc[: LOOKBACK + PRED_LEN - 1] + plot_prediction(kline_df, pred_df, plot_file) + + save_json( + summary_file, + { + 
"experiment": "prediction_example", + "description": "Single-series OHLCV prediction on one historical window.", + "model_key": args.model, + "model_name": spec["name"], + "model_id": spec["model_id"], + "tokenizer_id": spec["tokenizer_id"], + "device": device, + "elapsed_seconds": elapsed_seconds, + "lookback": LOOKBACK, + "pred_len": PRED_LEN, + "input_window_start": x_timestamp.iloc[0], + "input_window_end": x_timestamp.iloc[-1], + "prediction_window_start": y_timestamp.iloc[0], + "prediction_window_end": y_timestamp.iloc[-1], + "data_file": data_file, + "forecast_csv": csv_file, + "plot_file": plot_file, + "prediction_head": pred_df.head(), + }, + ) + + print(f"Saved forecast CSV: {csv_file}") + print(f"Saved forecast plot: {plot_file}") + print(f"Saved summary JSON: {summary_file}") + + +if __name__ == "__main__": + main() diff --git a/examples/prediction_wo_vol_example.py b/examples/prediction_wo_vol_example.py index 3e5624e1..2202261c 100644 --- a/examples/prediction_wo_vol_example.py +++ b/examples/prediction_wo_vol_example.py @@ -1,68 +1,113 @@ -import pandas as pd -import matplotlib.pyplot as plt -import sys -sys.path.append("../") -from model import Kronos, KronosTokenizer, KronosPredictor - - -def plot_prediction(kline_df, pred_df): - pred_df.index = kline_df.index[-pred_df.shape[0]:] - sr_close = kline_df['close'] - sr_pred_close = pred_df['close'] - sr_close.name = 'Ground Truth' - sr_pred_close.name = "Prediction" - - close_df = pd.concat([sr_close, sr_pred_close], axis=1) - - fig, ax = plt.subplots(1, 1, figsize=(8, 4)) - - ax.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5) - ax.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5) - ax.set_ylabel('Close Price', fontsize=14) - ax.legend(loc='lower left', fontsize=12) - ax.grid(True) - - plt.tight_layout() - plt.show() - - -# 1. 
Load Model and Tokenizer -tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base") -model = Kronos.from_pretrained("NeoQuasar/Kronos-small") - -# 2. Instantiate Predictor -predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512) - -# 3. Prepare Data -df = pd.read_csv("./data/XSHG_5min_600977.csv") -df['timestamps'] = pd.to_datetime(df['timestamps']) - -lookback = 400 -pred_len = 120 - -x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']] -x_timestamp = df.loc[:lookback-1, 'timestamps'] -y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps'] - -# 4. Make Prediction -pred_df = predictor.predict( - df=x_df, - x_timestamp=x_timestamp, - y_timestamp=y_timestamp, - pred_len=pred_len, - T=1.0, - top_p=0.9, - sample_count=1, - verbose=True -) - -# 5. Visualize Results -print("Forecasted Data Head:") -print(pred_df.head()) - -# Combine historical and forecasted data for plotting -kline_df = df.loc[:lookback+pred_len-1] - -# visualize -plot_prediction(kline_df, pred_df) - +import argparse +import time +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd + +from common import data_path, detect_device, get_model_spec, output_path, save_json +from model import Kronos, KronosPredictor, KronosTokenizer + +LOOKBACK = 400 +PRED_LEN = 120 + + +def plot_prediction(kline_df: pd.DataFrame, pred_df: pd.DataFrame, plot_file: Path) -> None: + plot_df = pred_df.copy() + plot_df.index = kline_df.index[-plot_df.shape[0]:] + + close_df = pd.concat( + [ + kline_df["close"].rename("Ground Truth"), + plot_df["close"].rename("Prediction"), + ], + axis=1, + ) + + fig, ax = plt.subplots(1, 1, figsize=(8, 4)) + ax.plot(close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5) + ax.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5) + ax.set_ylabel("Close Price", fontsize=14) + ax.legend(loc="lower left", fontsize=12) + ax.grid(True) + + plt.tight_layout() + 
plt.savefig(plot_file, dpi=150) + plt.close(fig) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run an OHLC-only Kronos prediction example.") + parser.add_argument("--model", choices=["small", "base"], default="small", help="Model size to run.") + args = parser.parse_args() + + spec = get_model_spec(args.model) + device = detect_device() + data_file = data_path("XSHG_5min_600977.csv") + csv_file = output_path(f"prediction_wo_vol_example_{args.model}_forecast.csv") + plot_file = output_path(f"prediction_wo_vol_example_{args.model}_plot.png") + summary_file = output_path(f"prediction_wo_vol_example_{args.model}_summary.json") + + print(f"Loading {spec['name']} on device: {device}") + start_time = time.perf_counter() + tokenizer = KronosTokenizer.from_pretrained(spec["tokenizer_id"]) + model = Kronos.from_pretrained(spec["model_id"]) + predictor = KronosPredictor(model, tokenizer, device=device, max_context=spec["max_context"]) + + df = pd.read_csv(data_file) + df["timestamps"] = pd.to_datetime(df["timestamps"]) + + x_df = df.loc[: LOOKBACK - 1, ["open", "high", "low", "close"]] + x_timestamp = df.loc[: LOOKBACK - 1, "timestamps"] + y_timestamp = df.loc[LOOKBACK : LOOKBACK + PRED_LEN - 1, "timestamps"] + + pred_df = predictor.predict( + df=x_df, + x_timestamp=x_timestamp, + y_timestamp=y_timestamp, + pred_len=PRED_LEN, + T=1.0, + top_p=0.9, + sample_count=1, + verbose=False, + ) + elapsed_seconds = time.perf_counter() - start_time + + print("Forecasted Data Head:") + print(pred_df.head()) + + pred_df.to_csv(csv_file, index_label="timestamps") + kline_df = df.loc[: LOOKBACK + PRED_LEN - 1] + plot_prediction(kline_df, pred_df, plot_file) + + save_json( + summary_file, + { + "experiment": "prediction_wo_vol_example", + "description": "Single-series OHLC-only prediction. 
Volume and amount are not provided as model inputs.", + "model_key": args.model, + "model_name": spec["name"], + "model_id": spec["model_id"], + "tokenizer_id": spec["tokenizer_id"], + "device": device, + "elapsed_seconds": elapsed_seconds, + "lookback": LOOKBACK, + "pred_len": PRED_LEN, + "input_window_start": x_timestamp.iloc[0], + "input_window_end": x_timestamp.iloc[-1], + "prediction_window_start": y_timestamp.iloc[0], + "prediction_window_end": y_timestamp.iloc[-1], + "data_file": data_file, + "forecast_csv": csv_file, + "plot_file": plot_file, + "prediction_head": pred_df.head(), + }, + ) + + print(f"Saved forecast CSV: {csv_file}") + print(f"Saved forecast plot: {plot_file}") + print(f"Saved summary JSON: {summary_file}") + + +if __name__ == "__main__": + main() diff --git a/finetune/config.py b/finetune/config.py index 04cc3ee3..009a4ba6 100644 --- a/finetune/config.py +++ b/finetune/config.py @@ -1,131 +1,127 @@ -import os - -class Config: - """ - Configuration class for the entire project. - """ - - def __init__(self): - # ================================================================= - # Data & Feature Parameters - # ================================================================= - # TODO: Update this path to your Qlib data directory. - self.qlib_data_path = "~/.qlib/qlib_data/cn_data" - self.instrument = 'csi300' - - # Overall time range for data loading from Qlib. - self.dataset_begin_time = "2011-01-01" - self.dataset_end_time = '2025-06-05' - - # Sliding window parameters for creating samples. - self.lookback_window = 90 # Number of past time steps for input. - self.predict_window = 10 # Number of future time steps for prediction. - self.max_context = 512 # Maximum context length for the model. - - # Features to be used from the raw data. - self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt'] - # Time-based features to be generated. 
- self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month'] - - # ================================================================= - # Dataset Splitting & Paths - # ================================================================= - # Note: The validation/test set starts earlier than the training/validation set ends - # to account for the `lookback_window`. - self.train_time_range = ["2011-01-01", "2022-12-31"] - self.val_time_range = ["2022-09-01", "2024-06-30"] - self.test_time_range = ["2024-04-01", "2025-06-05"] - self.backtest_time_range = ["2024-07-01", "2025-06-05"] - - # TODO: Directory to save the processed, pickled datasets. - self.dataset_path = "./data/processed_datasets" - - # ================================================================= - # Training Hyperparameters - # ================================================================= - self.clip = 5.0 # Clipping value for normalized data to prevent outliers. - - self.epochs = 30 - self.log_interval = 100 # Log training status every N batches. - self.batch_size = 50 # Batch size per GPU. - - # Number of samples to draw for one "epoch" of training/validation. - # This is useful for large datasets where a true epoch is too long. - self.n_train_iter = 2000 * self.batch_size - self.n_val_iter = 400 * self.batch_size - - # Learning rates for different model components. - self.tokenizer_learning_rate = 2e-4 - self.predictor_learning_rate = 4e-5 - - # Gradient accumulation to simulate a larger batch size. - self.accumulation_steps = 1 - - # AdamW optimizer parameters. - self.adam_beta1 = 0.9 - self.adam_beta2 = 0.95 - self.adam_weight_decay = 0.1 - - # Miscellaneous - self.seed = 100 # Global random seed for reproducibility. 
- - # ================================================================= - # Experiment Logging & Saving - # ================================================================= - self.use_comet = True # Set to False if you don't want to use Comet ML - self.comet_config = { - # It is highly recommended to load secrets from environment variables - # for security purposes. Example: os.getenv("COMET_API_KEY") - "api_key": "YOUR_COMET_API_KEY", - "project_name": "Kronos-Finetune-Demo", - "workspace": "your_comet_workspace" # TODO: Change to your Comet ML workspace name - } - self.comet_tag = 'finetune_demo' - self.comet_name = 'finetune_demo' - - # Base directory for saving model checkpoints and results. - # Using a general 'outputs' directory is a common practice. - self.save_path = "./outputs/models" - self.tokenizer_save_folder_name = 'finetune_tokenizer_demo' - self.predictor_save_folder_name = 'finetune_predictor_demo' - self.backtest_save_folder_name = 'finetune_backtest_demo' - - # Path for backtesting results. - self.backtest_result_path = "./outputs/backtest_results" - - # ================================================================= - # Model & Checkpoint Paths - # ================================================================= - # TODO: Update these paths to your pretrained model locations. - # These can be local paths or Hugging Face Hub model identifiers. - self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base" - self.pretrained_predictor_path = "path/to/your/Kronos-small" - - # Paths to the fine-tuned models, derived from the save_path. - # These will be generated automatically during training. 
- self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model" - self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model" - - # ================================================================= - # Backtesting Parameters - # ================================================================= - self.backtest_n_symbol_hold = 50 # Number of symbols to hold in the portfolio. - self.backtest_n_symbol_drop = 5 # Number of symbols to drop from the pool. - self.backtest_hold_thresh = 5 # Minimum holding period for a stock. - self.inference_T = 0.6 - self.inference_top_p = 0.9 - self.inference_top_k = 0 - self.inference_sample_count = 5 - self.backtest_batch_size = 1000 - self.backtest_benchmark = self._set_benchmark(self.instrument) - - def _set_benchmark(self, instrument): - dt_benchmark = { - 'csi800': "SH000906", - 'csi1000': "SH000852", - 'csi300': "SH000300", - } - if instrument in dt_benchmark: - return dt_benchmark[instrument] - else: - raise ValueError(f"Benchmark not defined for instrument: {instrument}") +import copy +import json +from pathlib import Path + + +class Config: + """ + Configuration class for finetuning and Qlib experiments. + + Defaults preserve the repository's original intended experiment setup. + Use a config file or explicit overrides for local smoke tests and other + environment-specific runs. 
+ """ + + DEFAULTS = { + "qlib_data_path": "~/.qlib/qlib_data/cn_data", + "instrument": "csi300", + "dataset_begin_time": "2011-01-01", + "dataset_end_time": "2025-06-05", + "lookback_window": 90, + "predict_window": 10, + "max_context": 512, + "feature_list": ["open", "high", "low", "close", "vol", "amt"], + "time_feature_list": ["minute", "hour", "weekday", "day", "month"], + "train_time_range": ["2011-01-01", "2022-12-31"], + "val_time_range": ["2022-09-01", "2024-06-30"], + "test_time_range": ["2024-04-01", "2025-06-05"], + "backtest_time_range": ["2024-07-01", "2025-06-05"], + "dataset_path": "./data/processed_datasets", + "clip": 5.0, + "epochs": 30, + "log_interval": 100, + "batch_size": 50, + "n_train_iter": 2000 * 50, + "n_val_iter": 400 * 50, + "tokenizer_learning_rate": 2e-4, + "predictor_learning_rate": 4e-5, + "accumulation_steps": 1, + "adam_beta1": 0.9, + "adam_beta2": 0.95, + "adam_weight_decay": 0.1, + "seed": 100, + "use_comet": True, + "comet_config": { + "api_key": "YOUR_COMET_API_KEY", + "project_name": "Kronos-Finetune-Demo", + "workspace": "your_comet_workspace", + }, + "comet_tag": "finetune_demo", + "comet_name": "finetune_demo", + "save_path": "./outputs/models", + "tokenizer_save_folder_name": "finetune_tokenizer_demo", + "predictor_save_folder_name": "finetune_predictor_demo", + "backtest_save_folder_name": "finetune_backtest_demo", + "backtest_result_path": "./outputs/backtest_results", + "pretrained_tokenizer_path": "path/to/your/Kronos-Tokenizer-base", + "pretrained_predictor_path": "path/to/your/Kronos-small", + "finetuned_tokenizer_path": "__AUTO__", + "finetuned_predictor_path": "__AUTO__", + "backtest_n_symbol_hold": 50, + "backtest_n_symbol_drop": 5, + "backtest_hold_thresh": 5, + "inference_T": 0.6, + "inference_top_p": 0.9, + "inference_top_k": 0, + "inference_sample_count": 5, + "backtest_batch_size": 1000, + } + + def __init__(self, config_file: str | None = None, overrides: dict | None = None): + defaults = 
copy.deepcopy(self.DEFAULTS) + for key, value in defaults.items(): + setattr(self, key, value) + self._explicit_finetuned_tokenizer_path = False + self._explicit_finetuned_predictor_path = False + + if config_file is not None: + self.apply_overrides(self._load_config_file(config_file)) + if overrides: + self.apply_overrides(overrides) + + self._refresh_derived_fields() + + @staticmethod + def _load_config_file(config_file: str) -> dict: + path = Path(config_file).expanduser() + with path.open("r", encoding="utf-8") as handle: + return json.load(handle) + + def apply_overrides(self, overrides: dict) -> None: + for key, value in overrides.items(): + if not hasattr(self, key): + raise AttributeError(f"Unknown config key: {key}") + if key == "finetuned_tokenizer_path": + self._explicit_finetuned_tokenizer_path = True + if key == "finetuned_predictor_path": + self._explicit_finetuned_predictor_path = True + setattr(self, key, copy.deepcopy(value)) + self._refresh_derived_fields() + + def _refresh_derived_fields(self) -> None: + default_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model" + default_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model" + + if not self._explicit_finetuned_tokenizer_path: + self.finetuned_tokenizer_path = default_tokenizer_path + if not self._explicit_finetuned_predictor_path: + self.finetuned_predictor_path = default_predictor_path + + self.backtest_benchmark = self._set_benchmark(self.instrument) + + def to_dict(self) -> dict: + return { + key: copy.deepcopy(value) + for key, value in self.__dict__.items() + if not key.startswith("_") + } + + def _set_benchmark(self, instrument): + dt_benchmark = { + "csi800": "SH000906", + "csi1000": "SH000852", + "csi300": "SH000300", + } + if instrument in dt_benchmark: + return dt_benchmark[instrument] + raise ValueError(f"Benchmark not defined for instrument: {instrument}") diff --git a/finetune/dataset.py 
b/finetune/dataset.py index f955ec12..641d4cd8 100644 --- a/finetune/dataset.py +++ b/finetune/dataset.py @@ -6,7 +6,7 @@ from config import Config -class QlibDataset(Dataset): +class QlibDataset(Dataset): """ A PyTorch Dataset for handling Qlib financial time series data. @@ -20,8 +20,8 @@ class QlibDataset(Dataset): ValueError: If `data_type` is not 'train' or 'val'. """ - def __init__(self, data_type: str = 'train'): - self.config = Config() + def __init__(self, data_type: str = 'train', config: Config | None = None): + self.config = config or Config() if data_type not in ['train', 'val']: raise ValueError("data_type must be 'train' or 'val'") self.data_type = data_type diff --git a/finetune/local_configs/qlib_smoke_backtest_base.example.json b/finetune/local_configs/qlib_smoke_backtest_base.example.json new file mode 100644 index 00000000..8ae26dcf --- /dev/null +++ b/finetune/local_configs/qlib_smoke_backtest_base.example.json @@ -0,0 +1,25 @@ +{ + "dataset_end_time": "2020-09-25", + "train_time_range": [ + "2011-01-01", + "2018-12-31" + ], + "val_time_range": [ + "2018-09-01", + "2020-03-31" + ], + "test_time_range": [ + "2020-04-01", + "2020-09-25" + ], + "backtest_time_range": [ + "2020-07-01", + "2020-09-24" + ], + "pretrained_tokenizer_path": "NeoQuasar/Kronos-Tokenizer-base", + "pretrained_predictor_path": "NeoQuasar/Kronos-base", + "finetuned_tokenizer_path": "NeoQuasar/Kronos-Tokenizer-base", + "finetuned_predictor_path": "NeoQuasar/Kronos-base", + "inference_sample_count": 1, + "backtest_save_folder_name": "qlib_smoke_backtest_base" +} diff --git a/finetune/local_configs/qlib_smoke_backtest_small.example.json b/finetune/local_configs/qlib_smoke_backtest_small.example.json new file mode 100644 index 00000000..763f6242 --- /dev/null +++ b/finetune/local_configs/qlib_smoke_backtest_small.example.json @@ -0,0 +1,25 @@ +{ + "dataset_end_time": "2020-09-25", + "train_time_range": [ + "2011-01-01", + "2018-12-31" + ], + "val_time_range": [ + "2018-09-01", + 
"2020-03-31" + ], + "test_time_range": [ + "2020-04-01", + "2020-09-25" + ], + "backtest_time_range": [ + "2020-07-01", + "2020-09-24" + ], + "pretrained_tokenizer_path": "NeoQuasar/Kronos-Tokenizer-base", + "pretrained_predictor_path": "NeoQuasar/Kronos-small", + "finetuned_tokenizer_path": "NeoQuasar/Kronos-Tokenizer-base", + "finetuned_predictor_path": "NeoQuasar/Kronos-small", + "inference_sample_count": 1, + "backtest_save_folder_name": "qlib_smoke_backtest_small" +} diff --git a/finetune/qlib_data_preprocess.py b/finetune/qlib_data_preprocess.py index 7fe6147b..5335a566 100644 --- a/finetune/qlib_data_preprocess.py +++ b/finetune/qlib_data_preprocess.py @@ -1,7 +1,8 @@ -import os -import pickle -import numpy as np -import pandas as pd +import os +import pickle +import argparse +import numpy as np +import pandas as pd import qlib from qlib.config import REG_CN from qlib.data import D @@ -11,16 +12,16 @@ from config import Config -class QlibDataPreprocessor: +class QlibDataPreprocessor: """ A class to handle the loading, processing, and splitting of Qlib financial data. """ - def __init__(self): - """Initializes the preprocessor with configuration and data fields.""" - self.config = Config() - self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap'] - self.data = {} # A dictionary to store processed data for each symbol. + def __init__(self, config_file: str | None = None): + """Initializes the preprocessor with configuration and data fields.""" + self.config = Config(config_file=config_file) + self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap'] + self.data = {} # A dictionary to store processed data for each symbol. def initialize_qlib(self): """Initializes the Qlib environment.""" @@ -121,10 +122,13 @@ def prepare_dataset(self): print("Datasets prepared and saved successfully.") -if __name__ == '__main__': - # This block allows the script to be run directly to perform data preprocessing. 
- preprocessor = QlibDataPreprocessor() - preprocessor.initialize_qlib() - preprocessor.load_qlib_data() - preprocessor.prepare_dataset() +if __name__ == '__main__': + parser = argparse.ArgumentParser(description="Prepare Qlib train/val/test datasets.") + parser.add_argument("--config-file", type=str, default=None, help="Optional JSON config override file.") + args = parser.parse_args() + + preprocessor = QlibDataPreprocessor(config_file=args.config_file) + preprocessor.initialize_qlib() + preprocessor.load_qlib_data() + preprocessor.prepare_dataset() diff --git a/finetune/qlib_test.py b/finetune/qlib_test.py index 97aebe41..0702001e 100644 --- a/finetune/qlib_test.py +++ b/finetune/qlib_test.py @@ -1,362 +1,382 @@ -import os -import sys -import argparse -import pickle -from collections import defaultdict - -import numpy as np -import pandas as pd -import torch -from torch.utils.data import Dataset, DataLoader -from tqdm import trange, tqdm -from matplotlib import pyplot as plt - -import qlib -from qlib.config import REG_CN -from qlib.backtest import backtest, executor, CommonInfrastructure -from qlib.contrib.evaluate import risk_analysis -from qlib.contrib.strategy import TopkDropoutStrategy -from qlib.utils import flatten_dict -from qlib.utils.time import Freq - -# Ensure project root is in the Python path -sys.path.append("../") -from config import Config -from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference - - -# ================================================================================= -# 1. Data Loading and Processing for Inference -# ================================================================================= - -class QlibTestDataset(Dataset): - """ - PyTorch Dataset for handling Qlib test data, specifically for inference. - - This dataset iterates through all possible sliding windows sequentially. 
It also - yields metadata like symbol and timestamp, which are crucial for mapping - predictions back to the original time series. - """ - - def __init__(self, data: dict, config: Config): - self.data = data - self.config = config - self.window_size = config.lookback_window + config.predict_window - self.symbols = list(self.data.keys()) - self.feature_list = config.feature_list - self.time_feature_list = config.time_feature_list - self.indices = [] - - print("Preprocessing and building indices for test dataset...") - for symbol in self.symbols: - df = self.data[symbol].reset_index() - # Generate time features on-the-fly - df['minute'] = df['datetime'].dt.minute - df['hour'] = df['datetime'].dt.hour - df['weekday'] = df['datetime'].dt.weekday - df['day'] = df['datetime'].dt.day - df['month'] = df['datetime'].dt.month - self.data[symbol] = df # Store preprocessed dataframe - - num_samples = len(df) - self.window_size + 1 - if num_samples > 0: - for i in range(num_samples): - timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime'] - self.indices.append((symbol, i, timestamp)) - - def __len__(self) -> int: - return len(self.indices) - - def __getitem__(self, idx: int): - symbol, start_idx, timestamp = self.indices[idx] - df = self.data[symbol] - - context_end = start_idx + self.config.lookback_window - predict_end = context_end + self.config.predict_window - - context_df = df.iloc[start_idx:context_end] - predict_df = df.iloc[context_end:predict_end] - - x = context_df[self.feature_list].values.astype(np.float32) - x_stamp = context_df[self.time_feature_list].values.astype(np.float32) - y_stamp = predict_df[self.time_feature_list].values.astype(np.float32) - - # Instance-level normalization, consistent with training - x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0) - x = (x - x_mean) / (x_std + 1e-5) - x = np.clip(x, -self.config.clip, self.config.clip) - - return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, 
timestamp - - -# ================================================================================= -# 2. Backtesting Logic -# ================================================================================= - -class QlibBacktest: - """ - A wrapper class for conducting backtesting experiments using Qlib. - """ - - def __init__(self, config: Config): - self.config = config - self.initialize_qlib() - - def initialize_qlib(self): - """Initializes the Qlib environment.""" - print("Initializing Qlib for backtesting...") - qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN) - - def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame: - """ - Runs a single backtest for a given prediction signal. - - Args: - signal_series (pd.Series): A pandas Series with a MultiIndex - (instrument, datetime) and prediction scores. - Returns: - pd.DataFrame: A DataFrame containing the performance report. - """ - strategy = TopkDropoutStrategy( - topk=self.config.backtest_n_symbol_hold, - n_drop=self.config.backtest_n_symbol_drop, - hold_thresh=self.config.backtest_hold_thresh, - signal=signal_series, - ) - executor_config = { - "time_per_step": "day", - "generate_portfolio_metrics": True, - "delay_execution": True, - } - backtest_config = { - "start_time": self.config.backtest_time_range[0], - "end_time": self.config.backtest_time_range[1], - "account": 100_000_000, - "benchmark": self.config.backtest_benchmark, - "exchange_kwargs": { - "freq": "day", "limit_threshold": 0.095, "deal_price": "open", - "open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5, - }, - "executor": executor.SimulatorExecutor(**executor_config), - } - - portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config) - analysis_freq = "{0}{1}".format(*Freq.parse("day")) - report, _ = portfolio_metric_dict.get(analysis_freq) - - # --- Analysis and Reporting --- - analysis = { - "excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq), 
-            "excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq),
-        }
-        print("\n--- Backtest Analysis ---")
-        print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n')
-        print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n')
-        print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n')
-
-        report_df = pd.DataFrame({
-            "cum_bench": report["bench"].cumsum(),
-            "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(),
-            "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(),
-        })
-        return report_df
-
-    def run_and_plot_results(self, signals: dict[str, pd.DataFrame]):
-        """
-        Runs backtests for multiple signals and plots the cumulative return curves.
-
-        Args:
-            signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names
-                and values are prediction DataFrames.
-        """
-        return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
-
-        for signal_name, pred_df in signals.items():
-            print(f"\nBacktesting signal: {signal_name}...")
-            pred_series = pred_df.stack()
-            pred_series.index.names = ['datetime', 'instrument']
-            pred_series = pred_series.swaplevel().sort_index()
-            report_df = self.run_single_backtest(pred_series)
-
-            return_df[signal_name] = report_df['cum_return_w_cost']
-            ex_return_df[signal_name] = report_df['cum_ex_return_w_cost']
-            if 'return' not in bench_df:
-                bench_df['return'] = report_df['cum_bench']
-
-        # Plotting results
-        fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
-        return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True)
-        axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--')
-        axes[0].legend()
-        axes[0].set_ylabel("Cumulative Return")
-
-        ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True)
-        axes[1].legend()
-        axes[1].set_xlabel("Date")
-        axes[1].set_ylabel("Cumulative Excess Return")
-
-        plt.tight_layout()
-        plt.savefig("../figures/backtest_result_example.png", dpi=200)
-        plt.show()
-
-
-# =================================================================================
-# 3. Inference Logic
-# =================================================================================
-
-def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]:
-    """Loads the fine-tuned tokenizer and predictor model."""
-    device = torch.device(config['device'])
-    print(f"Loading models onto device: {device}...")
-    tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path']).to(device).eval()
-    model = Kronos.from_pretrained(config['model_path']).to(device).eval()
-    return tokenizer, model
-
-
-def collate_fn_for_inference(batch):
-    """
-    Custom collate function to handle batches containing Tensors, strings, and Timestamps.
-
-    Args:
-        batch (list): A list of samples, where each sample is the tuple returned by
-            QlibTestDataset.__getitem__.
-
-    Returns:
-        A single tuple containing the batched data.
-    """
-    # Unzip the list of samples into separate lists for each data type
-    x, x_stamp, y_stamp, symbols, timestamps = zip(*batch)
-
-    # Stack the tensors to create a batch
-    x_batch = torch.stack(x, dim=0)
-    x_stamp_batch = torch.stack(x_stamp, dim=0)
-    y_stamp_batch = torch.stack(y_stamp, dim=0)
-
-    # Return the strings and timestamps as lists
-    return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps)
-
-
-def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]:
-    """
-    Runs inference on the test dataset to generate prediction signals.
-
-    Args:
-        config (dict): A dictionary containing inference parameters.
-        test_data (dict): The raw test data loaded from a pickle file.
-
-    Returns:
-        A dictionary where keys are signal types (e.g., 'mean', 'last') and
-        values are DataFrames of predictions (datetime index, symbol columns).
-    """
-    tokenizer, model = load_models(config)
-    device = torch.device(config['device'])
-
-    # Use the Dataset and DataLoader for efficient batching and processing
-    dataset = QlibTestDataset(data=test_data, config=Config())
-    loader = DataLoader(
-        dataset,
-        batch_size=config['batch_size'] // config['sample_count'],
-        shuffle=False,
-        num_workers=os.cpu_count() // 2,
-        collate_fn=collate_fn_for_inference
-    )
-
-    results = defaultdict(list)
-    with torch.no_grad():
-        for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
-            preds = auto_regressive_inference(
-                tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device),
-                max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'],
-                T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count']
-            )
-            # You can try commenting on this line to keep the history data
-            preds = preds[:, -config['pred_len']:, :]
-
-            # The 'close' price is at index 3 in `feature_list`
-            last_day_close = x[:, -1, 3].numpy()
-            signals = {
-                'last': preds[:, -1, 3] - last_day_close,
-                'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
-                'max': np.max(preds[:, :, 3], axis=1) - last_day_close,
-                'min': np.min(preds[:, :, 3], axis=1) - last_day_close,
-            }
-
-            for i in range(len(symbols)):
-                for sig_type, sig_values in signals.items():
-                    results[sig_type].append((timestamps[i], symbols[i], sig_values[i]))
-
-    print("Post-processing predictions into DataFrames...")
-    prediction_dfs = {}
-    for sig_type, records in results.items():
-        df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
-        pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
-        prediction_dfs[sig_type] = pivot_df.sort_index()
-
-    return prediction_dfs
-
-
-# =================================================================================
-# 4. Main Execution
-# =================================================================================
-
-def main():
-    """Main function to set up config, run inference, and execute backtesting."""
-    parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting")
-    parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')")
-    args = parser.parse_args()
-
-    # --- 1. Configuration Setup ---
-    base_config = Config()
-
-    # Create a dedicated dictionary for this run's configuration
-    run_config = {
-        'device': args.device,
-        'data_path': base_config.dataset_path,
-        'result_save_path': base_config.backtest_result_path,
-        'result_name': base_config.backtest_save_folder_name,
-        'tokenizer_path': base_config.finetuned_tokenizer_path,
-        'model_path': base_config.finetuned_predictor_path,
-        'max_context': base_config.max_context,
-        'pred_len': base_config.predict_window,
-        'clip': base_config.clip,
-        'T': base_config.inference_T,
-        'top_k': base_config.inference_top_k,
-        'top_p': base_config.inference_top_p,
-        'sample_count': base_config.inference_sample_count,
-        'batch_size': base_config.backtest_batch_size,
-    }
-
-    print("--- Running with Configuration ---")
-    for key, val in run_config.items():
-        print(f"{key:>20}: {val}")
-    print("-" * 35)
-
-    # --- 2. Load Data ---
-    test_data_path = os.path.join(run_config['data_path'], "test_data.pkl")
-    print(f"Loading test data from {test_data_path}...")
-    with open(test_data_path, 'rb') as f:
-        test_data = pickle.load(f)
-    print(test_data)
-    # --- 3. Generate Predictions ---
-    model_preds = generate_predictions(run_config, test_data)
-
-    # --- 4. Save Predictions ---
-    save_dir = os.path.join(run_config['result_save_path'], run_config['result_name'])
-    os.makedirs(save_dir, exist_ok=True)
-    predictions_file = os.path.join(save_dir, "predictions.pkl")
-    print(f"Saving prediction signals to {predictions_file}...")
-    with open(predictions_file, 'wb') as f:
-        pickle.dump(model_preds, f)
-
-    # --- 5. Run Backtesting ---
-    with open(predictions_file, 'rb') as f:
-        model_preds = pickle.load(f)
-
-    backtester = QlibBacktest(base_config)
-    backtester.run_and_plot_results(model_preds)
-
-
-if __name__ == '__main__':
-    main()
-
-
+import argparse
+import json
+import os
+import pickle
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+import qlib
+from qlib.backtest import backtest, executor
+from qlib.config import REG_CN
+from qlib.contrib.evaluate import risk_analysis
+from qlib.contrib.strategy import TopkDropoutStrategy
+from qlib.utils.time import Freq
+
+# Ensure project root is in the Python path
+sys.path.append("../")
+from config import Config
+from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference
+
+
+def save_json(path: Path, payload: dict) -> None:
+    with path.open("w", encoding="utf-8") as handle:
+        json.dump(payload, handle, indent=2, ensure_ascii=False)
+
+
+def analysis_frame_to_dict(frame: pd.DataFrame) -> dict[str, float]:
+    if "risk" in frame.columns:
+        series = frame["risk"]
+    else:
+        series = frame.iloc[:, 0]
+    return {str(key): float(value) for key, value in series.items()}
+
+
+class QlibTestDataset(Dataset):
+    """
+    PyTorch Dataset for handling Qlib test data, specifically for inference.
+    """
+
+    def __init__(self, data: dict, config: Config):
+        self.data = data
+        self.config = config
+        self.window_size = config.lookback_window + config.predict_window
+        self.symbols = list(self.data.keys())
+        self.feature_list = config.feature_list
+        self.time_feature_list = config.time_feature_list
+        self.indices = []
+
+        print("Preprocessing and building indices for test dataset...")
+        for symbol in self.symbols:
+            df = self.data[symbol].reset_index()
+            df["minute"] = df["datetime"].dt.minute
+            df["hour"] = df["datetime"].dt.hour
+            df["weekday"] = df["datetime"].dt.weekday
+            df["day"] = df["datetime"].dt.day
+            df["month"] = df["datetime"].dt.month
+            self.data[symbol] = df
+
+            num_samples = len(df) - self.window_size + 1
+            if num_samples > 0:
+                for idx in range(num_samples):
+                    timestamp = df.iloc[idx + self.config.lookback_window - 1]["datetime"]
+                    self.indices.append((symbol, idx, timestamp))
+
+    def __len__(self) -> int:
+        return len(self.indices)
+
+    def __getitem__(self, idx: int):
+        symbol, start_idx, timestamp = self.indices[idx]
+        df = self.data[symbol]
+
+        context_end = start_idx + self.config.lookback_window
+        predict_end = context_end + self.config.predict_window
+
+        context_df = df.iloc[start_idx:context_end]
+        predict_df = df.iloc[context_end:predict_end]
+
+        x = context_df[self.feature_list].values.astype(np.float32)
+        x_stamp = context_df[self.time_feature_list].values.astype(np.float32)
+        y_stamp = predict_df[self.time_feature_list].values.astype(np.float32)
+
+        x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
+        x = (x - x_mean) / (x_std + 1e-5)
+        x = np.clip(x, -self.config.clip, self.config.clip)
+
+        return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp
+
+
+class QlibBacktest:
+    """
+    A wrapper class for conducting Qlib backtests and saving structured results.
+    """
+
+    def __init__(self, config: Config, save_dir: Path, show_plot: bool):
+        self.config = config
+        self.save_dir = save_dir
+        self.show_plot = show_plot
+        self.initialize_qlib()
+
+    def initialize_qlib(self):
+        print("Initializing Qlib for backtesting...")
+        qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
+
+    def run_single_backtest(self, signal_series: pd.Series) -> tuple[pd.DataFrame, dict[str, dict[str, float]]]:
+        strategy = TopkDropoutStrategy(
+            topk=self.config.backtest_n_symbol_hold,
+            n_drop=self.config.backtest_n_symbol_drop,
+            hold_thresh=self.config.backtest_hold_thresh,
+            signal=signal_series,
+        )
+        executor_config = {
+            "time_per_step": "day",
+            "generate_portfolio_metrics": True,
+            "delay_execution": True,
+        }
+        backtest_config = {
+            "start_time": self.config.backtest_time_range[0],
+            "end_time": self.config.backtest_time_range[1],
+            "account": 100_000_000,
+            "benchmark": self.config.backtest_benchmark,
+            "exchange_kwargs": {
+                "freq": "day",
+                "limit_threshold": 0.095,
+                "deal_price": "open",
+                "open_cost": 0.001,
+                "close_cost": 0.0015,
+                "min_cost": 5,
+            },
+            "executor": executor.SimulatorExecutor(**executor_config),
+        }
+
+        portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config)
+        analysis_freq = "{0}{1}".format(*Freq.parse("day"))
+        report, _ = portfolio_metric_dict.get(analysis_freq)
+
+        benchmark_metrics = risk_analysis(report["bench"], freq=analysis_freq)
+        excess_without_cost = risk_analysis(report["return"] - report["bench"], freq=analysis_freq)
+        excess_with_cost = risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq)
+
+        print("\n--- Backtest Analysis ---")
+        print("Benchmark Return:", benchmark_metrics, sep="\n")
+        print("\nExcess Return (w/o cost):", excess_without_cost, sep="\n")
+        print("\nExcess Return (w/ cost):", excess_with_cost, sep="\n")
+
+        report_df = pd.DataFrame(
+            {
+                "cum_bench": report["bench"].cumsum(),
+                "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(),
+                "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(),
+            }
+        )
+        analysis = {
+            "benchmark": analysis_frame_to_dict(benchmark_metrics),
+            "excess_return_without_cost": analysis_frame_to_dict(excess_without_cost),
+            "excess_return_with_cost": analysis_frame_to_dict(excess_with_cost),
+        }
+        return report_df, analysis
+
+    def run_and_plot_results(self, signals: dict[str, pd.DataFrame]):
+        return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
+        summary_rows = []
+
+        for signal_name, pred_df in signals.items():
+            print(f"\nBacktesting signal: {signal_name}...")
+            pred_series = pred_df.stack()
+            pred_series.index.names = ["datetime", "instrument"]
+            pred_series = pred_series.swaplevel().sort_index()
+            report_df, analysis = self.run_single_backtest(pred_series)
+
+            report_df.to_csv(self.save_dir / f"{signal_name}_curve.csv", index_label="datetime")
+            save_json(self.save_dir / f"{signal_name}_metrics.json", analysis)
+
+            return_df[signal_name] = report_df["cum_return_w_cost"]
+            ex_return_df[signal_name] = report_df["cum_ex_return_w_cost"]
+            if "return" not in bench_df:
+                bench_df["return"] = report_df["cum_bench"]
+
+            summary_rows.append(
+                {
+                    "signal": signal_name,
+                    "benchmark_annualized_return": analysis["benchmark"]["annualized_return"],
+                    "benchmark_information_ratio": analysis["benchmark"]["information_ratio"],
+                    "benchmark_max_drawdown": analysis["benchmark"]["max_drawdown"],
+                    "excess_annualized_return_without_cost": analysis["excess_return_without_cost"]["annualized_return"],
+                    "excess_information_ratio_without_cost": analysis["excess_return_without_cost"]["information_ratio"],
+                    "excess_max_drawdown_without_cost": analysis["excess_return_without_cost"]["max_drawdown"],
+                    "excess_annualized_return_with_cost": analysis["excess_return_with_cost"]["annualized_return"],
+                    "excess_information_ratio_with_cost": analysis["excess_return_with_cost"]["information_ratio"],
+                    "excess_max_drawdown_with_cost": analysis["excess_return_with_cost"]["max_drawdown"],
+                }
+            )
+
+        summary_df = pd.DataFrame(summary_rows)
+        summary_df.to_csv(self.save_dir / "backtest_summary.csv", index=False)
+        save_json(self.save_dir / "backtest_summary.json", {"signals": summary_rows})
+
+        fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
+        return_df.plot(ax=axes[0], title="Cumulative Return with Cost", grid=True)
+        axes[0].plot(bench_df["return"], label=self.config.instrument.upper(), color="black", linestyle="--")
+        axes[0].legend()
+        axes[0].set_ylabel("Cumulative Return")
+
+        ex_return_df.plot(ax=axes[1], title="Cumulative Excess Return with Cost", grid=True)
+        axes[1].legend()
+        axes[1].set_xlabel("Date")
+        axes[1].set_ylabel("Cumulative Excess Return")
+
+        plt.tight_layout()
+        plot_path = self.save_dir / "backtest_curves.png"
+        plt.savefig(plot_path, dpi=200)
+        if self.show_plot:
+            plt.show()
+        else:
+            plt.close(fig)
+
+        print(f"Saved backtest summary CSV: {self.save_dir / 'backtest_summary.csv'}")
+        print(f"Saved backtest plot: {plot_path}")
+
+
+def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]:
+    device = torch.device(config["device"])
+    print(f"Loading models onto device: {device}...")
+    tokenizer = KronosTokenizer.from_pretrained(config["tokenizer_path"]).to(device).eval()
+    model = Kronos.from_pretrained(config["model_path"]).to(device).eval()
+    return tokenizer, model
+
+
+def collate_fn_for_inference(batch):
+    x, x_stamp, y_stamp, symbols, timestamps = zip(*batch)
+    x_batch = torch.stack(x, dim=0)
+    x_stamp_batch = torch.stack(x_stamp, dim=0)
+    y_stamp_batch = torch.stack(y_stamp, dim=0)
+    return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps)
+
+
+def generate_predictions(run_config: dict, test_data: dict) -> dict[str, pd.DataFrame]:
+    tokenizer, model = load_models(run_config)
+    device = torch.device(run_config["device"])
+
+    dataset = QlibTestDataset(data=test_data, config=run_config["config_obj"])
+    loader = DataLoader(
+        dataset,
+        batch_size=max(1, run_config["batch_size"] // run_config["sample_count"]),
+        shuffle=False,
+        num_workers=run_config["num_workers"],
+        collate_fn=collate_fn_for_inference,
+    )
+
+    results = defaultdict(list)
+    with torch.no_grad():
+        for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
+            preds = auto_regressive_inference(
+                tokenizer,
+                model,
+                x.to(device),
+                x_stamp.to(device),
+                y_stamp.to(device),
+                max_context=run_config["max_context"],
+                pred_len=run_config["pred_len"],
+                clip=run_config["clip"],
+                T=run_config["T"],
+                top_k=run_config["top_k"],
+                top_p=run_config["top_p"],
+                sample_count=run_config["sample_count"],
+            )
+            preds = preds[:, -run_config["pred_len"] :, :]
+            last_day_close = x[:, -1, 3].numpy()
+            signals = {
+                "last": preds[:, -1, 3] - last_day_close,
+                "mean": np.mean(preds[:, :, 3], axis=1) - last_day_close,
+                "max": np.max(preds[:, :, 3], axis=1) - last_day_close,
+                "min": np.min(preds[:, :, 3], axis=1) - last_day_close,
+            }
+
+            for idx in range(len(symbols)):
+                for sig_type, sig_values in signals.items():
+                    results[sig_type].append((timestamps[idx], symbols[idx], sig_values[idx]))
+
+    print("Post-processing predictions into DataFrames...")
+    prediction_dfs = {}
+    for sig_type, records in results.items():
+        df = pd.DataFrame(records, columns=["datetime", "instrument", "score"])
+        pivot_df = df.pivot_table(index="datetime", columns="instrument", values="score")
+        prediction_dfs[sig_type] = pivot_df.sort_index()
+
+    return prediction_dfs
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Run Kronos inference and Qlib backtesting.")
+    parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference (e.g. cuda:0, mps, cpu).")
+    parser.add_argument("--config-file", type=str, default=None, help="Optional JSON config override file.")
+    parser.add_argument("--tokenizer-path", type=str, default=None, help="Explicit tokenizer path or Hugging Face model ID.")
+    parser.add_argument("--model-path", type=str, default=None, help="Explicit predictor path or Hugging Face model ID.")
+    parser.add_argument("--run-name", type=str, default=None, help="Optional override for the backtest result folder.")
+    parser.add_argument("--save-only", action="store_true", help="Save plots to disk without opening a GUI window.")
+    parser.add_argument("--num-workers", type=int, default=0, help="DataLoader worker count for inference.")
+    args = parser.parse_args()
+
+    overrides = {}
+    if args.run_name:
+        overrides["backtest_save_folder_name"] = args.run_name
+    if args.tokenizer_path:
+        overrides["finetuned_tokenizer_path"] = args.tokenizer_path
+    if args.model_path:
+        overrides["finetuned_predictor_path"] = args.model_path
+
+    base_config = Config(config_file=args.config_file, overrides=overrides if overrides else None)
+    run_config = {
+        "config_obj": base_config,
+        "device": args.device,
+        "data_path": base_config.dataset_path,
+        "result_save_path": base_config.backtest_result_path,
+        "result_name": base_config.backtest_save_folder_name,
+        "tokenizer_path": base_config.finetuned_tokenizer_path,
+        "model_path": base_config.finetuned_predictor_path,
+        "max_context": base_config.max_context,
+        "pred_len": base_config.predict_window,
+        "clip": base_config.clip,
+        "T": base_config.inference_T,
+        "top_k": base_config.inference_top_k,
+        "top_p": base_config.inference_top_p,
+        "sample_count": base_config.inference_sample_count,
+        "batch_size": base_config.backtest_batch_size,
+        "num_workers": args.num_workers,
+    }
+
+    print("--- Running with Configuration ---")
+    for key, val in run_config.items():
+        if key == "config_obj":
+            continue
+        print(f"{key:>20}: {val}")
+    print("-" * 35)
+
+    test_data_path = Path(run_config["data_path"]) / "test_data.pkl"
+    print(f"Loading test data from {test_data_path}...")
+    with test_data_path.open("rb") as handle:
+        test_data = pickle.load(handle)
+    non_empty_symbols = sum(1 for df in test_data.values() if not df.empty)
+    print(f"Loaded {len(test_data)} symbols from test data, {non_empty_symbols} non-empty.")
+    if non_empty_symbols == 0:
+        raise RuntimeError(
+            "test_data.pkl contains no non-empty symbols. Adjust the config ranges or local Qlib dataset coverage."
+        )
+
+    model_preds = generate_predictions(run_config, test_data)
+
+    save_dir = Path(run_config["result_save_path"]) / run_config["result_name"]
+    save_dir.mkdir(parents=True, exist_ok=True)
+    predictions_file = save_dir / "predictions.pkl"
+    print(f"Saving prediction signals to {predictions_file}...")
+    with predictions_file.open("wb") as handle:
+        pickle.dump(model_preds, handle)
+
+    save_json(
+        save_dir / "run_config.json",
+        {
+            "device": args.device,
+            "config_file": args.config_file,
+            "tokenizer_path": run_config["tokenizer_path"],
+            "model_path": run_config["model_path"],
+            "result_save_path": str(save_dir),
+            "save_only": args.save_only,
+            "num_workers": args.num_workers,
+            "config": base_config.to_dict(),
+            "non_empty_symbols": non_empty_symbols,
+        },
+    )
+
+    backtester = QlibBacktest(base_config, save_dir=save_dir, show_plot=not args.save_only)
+    backtester.run_and_plot_results(model_preds)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/finetune/train_predictor.py b/finetune/train_predictor.py
index 1e425872..6eaf96ec 100644
--- a/finetune/train_predictor.py
+++ b/finetune/train_predictor.py
@@ -1,15 +1,19 @@
-import os
-import sys
-import json
-import time
-from time import gmtime, strftime
-import torch.distributed as dist
+import os
+import sys
+import json
+import time
+from time import gmtime, strftime
+import argparse
+import torch.distributed as dist
 import torch
 from torch.utils.data import DataLoader
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-import comet_ml
+try:
+    import comet_ml
+except ImportError:
+    comet_ml = None
 
 # Ensure project root is in path
 sys.path.append('../')
@@ -26,7 +30,7 @@
 )
 
 
-def create_dataloaders(config: dict, rank: int, world_size: int):
+def create_dataloaders(config: dict, rank: int, world_size: int):
     """
     Creates and returns distributed dataloaders for training and validation.
 
@@ -39,8 +43,9 @@ def create_dataloaders(config: dict, rank: int, world_size: int):
         tuple: (train_loader, val_loader, train_dataset, valid_dataset).
     """
     print(f"[Rank {rank}] Creating distributed dataloaders...")
-    train_dataset = QlibDataset('train')
-    valid_dataset = QlibDataset('val')
+    dataset_config = Config(overrides=config)
+    train_dataset = QlibDataset('train', config=dataset_config)
+    valid_dataset = QlibDataset('val', config=dataset_config)
     print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")
 
     train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
@@ -196,16 +201,18 @@ def main(config: dict):
         'save_directory': save_dir,
         'world_size': world_size,
     }
-    if config['use_comet']:
-        comet_logger = comet_ml.Experiment(
-            api_key=config['comet_config']['api_key'],
-            project_name=config['comet_config']['project_name'],
-            workspace=config['comet_config']['workspace'],
-        )
-        comet_logger.add_tag(config['comet_tag'])
-        comet_logger.set_name(config['comet_name'])
-        comet_logger.log_parameters(config)
-        print("Comet Logger Initialized.")
+    if config['use_comet'] and comet_ml is not None:
+        comet_logger = comet_ml.Experiment(
+            api_key=config['comet_config']['api_key'],
+            project_name=config['comet_config']['project_name'],
+            workspace=config['comet_config']['workspace'],
+        )
+        comet_logger.add_tag(config['comet_tag'])
+        comet_logger.set_name(config['comet_name'])
+        comet_logger.log_parameters(config)
+        print("Comet Logger Initialized.")
+    elif config['use_comet']:
+        print("comet_ml is not installed; continuing without experiment logging.")
 
     dist.barrier()
 
@@ -235,10 +242,14 @@ def main(config: dict):
 
     cleanup_ddp()
 
-if __name__ == '__main__':
-    # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py
-    if "WORLD_SIZE" not in os.environ:
-        raise RuntimeError("This script must be launched with `torchrun`.")
-
-    config_instance = Config()
-    main(config_instance.__dict__)
+if __name__ == '__main__':
+    # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py
+    if "WORLD_SIZE" not in os.environ:
+        raise RuntimeError("This script must be launched with `torchrun`.")
+
+    parser = argparse.ArgumentParser(description="Train the Kronos predictor with torchrun.")
+    parser.add_argument("--config-file", type=str, default=None, help="Optional JSON config override file.")
+    args = parser.parse_args()
+
+    config_instance = Config(config_file=args.config_file)
+    main(config_instance.to_dict())
diff --git a/finetune/train_tokenizer.py b/finetune/train_tokenizer.py
index 2fe28cf1..2b22cff3 100644
--- a/finetune/train_tokenizer.py
+++ b/finetune/train_tokenizer.py
@@ -12,7 +12,10 @@
 from torch.utils.data.distributed import DistributedSampler
 from torch.nn.parallel import DistributedDataParallel as DDP
 
-import comet_ml
+try:
+    import comet_ml
+except ImportError:
+    comet_ml = None
 
 # Ensure project root is in path
 sys.path.append("../")
@@ -29,7 +32,7 @@
 )
 
 
-def create_dataloaders(config: dict, rank: int, world_size: int):
+def create_dataloaders(config: dict, rank: int, world_size: int):
     """
     Creates and returns distributed dataloaders for training and validation.
 
@@ -42,8 +45,9 @@ def create_dataloaders(config: dict, rank: int, world_size: int):
         tuple: A tuple containing (train_loader, val_loader, train_dataset, valid_dataset).
     """
     print(f"[Rank {rank}] Creating distributed dataloaders...")
-    train_dataset = QlibDataset('train')
-    valid_dataset = QlibDataset('val')
+    dataset_config = Config(overrides=config)
+    train_dataset = QlibDataset('train', config=dataset_config)
+    valid_dataset = QlibDataset('val', config=dataset_config)
     print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")
 
     train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
@@ -234,16 +238,18 @@ def main(config: dict):
         'save_directory': save_dir,
         'world_size': world_size,
     }
-    if config['use_comet']:
-        comet_logger = comet_ml.Experiment(
-            api_key=config['comet_config']['api_key'],
-            project_name=config['comet_config']['project_name'],
-            workspace=config['comet_config']['workspace'],
-        )
-        comet_logger.add_tag(config['comet_tag'])
-        comet_logger.set_name(config['comet_name'])
-        comet_logger.log_parameters(config)
-        print("Comet Logger Initialized.")
+    if config['use_comet'] and comet_ml is not None:
+        comet_logger = comet_ml.Experiment(
+            api_key=config['comet_config']['api_key'],
+            project_name=config['comet_config']['project_name'],
+            workspace=config['comet_config']['workspace'],
+        )
+        comet_logger.add_tag(config['comet_tag'])
+        comet_logger.set_name(config['comet_name'])
+        comet_logger.log_parameters(config)
+        print("Comet Logger Initialized.")
+    elif config['use_comet']:
+        print("comet_ml is not installed; continuing without experiment logging.")
 
     dist.barrier() # Ensure save directory is created before proceeding
 
@@ -272,10 +278,14 @@ def main(config: dict):
 
     cleanup_ddp()
 
-if __name__ == '__main__':
-    # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py
-    if "WORLD_SIZE" not in os.environ:
-        raise RuntimeError("This script must be launched with `torchrun`.")
-
-    config_instance = Config()
-    main(config_instance.__dict__)
+if __name__ == '__main__':
+    # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py
+    if "WORLD_SIZE" not in os.environ:
+        raise RuntimeError("This script must be launched with `torchrun`.")
+
+    parser = argparse.ArgumentParser(description="Train the Kronos tokenizer with torchrun.")
+    parser.add_argument("--config-file", type=str, default=None, help="Optional JSON config override file.")
+    args = parser.parse_args()
+
+    config_instance = Config(config_file=args.config_file)
+    main(config_instance.to_dict())
diff --git a/webui/README.md b/webui/README.md
index 571d93eb..a3ac2243 100644
--- a/webui/README.md
+++ b/webui/README.md
@@ -14,6 +14,19 @@ Web user interface for Kronos financial prediction model, providing intuitive gr
 
 ## 🚀 Quick Start
 
+Install the Web UI dependencies from the `webui` directory before starting:
+
+```bash
+pip install -r requirements.txt
+```
+
+If you use a shared environment from the repository root, install the root requirements first and then install the Web UI extras:
+
+```bash
+pip install -r ../requirements.txt
+pip install -r requirements.txt
+```
+
 ### Method 1: Start with Python script
 ```bash
 cd webui
@@ -112,7 +125,7 @@ The system automatically provides comparison analysis between prediction results
 
 ### Common Issues
 1. **Port occupied**: Modify port number in app.py
-2. **Missing dependencies**: Run `pip install -r requirements.txt`
+2. **Missing dependencies**: Run `pip install -r requirements.txt` from the `webui` directory after installing the root project requirements
 3. **Model loading failed**: Check network connection and model ID
 4. **Data format error**: Ensure data column names and format are correct
diff --git a/webui/app.py b/webui/app.py
index d240a372..886e02d9 100644
--- a/webui/app.py
+++ b/webui/app.py
@@ -59,10 +59,17 @@
 
 def load_data_files():
     """Scan data directory and return available data files"""
-    data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
     data_files = []
-
-    if os.path.exists(data_dir):
+
+    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+    candidate_dirs = [
+        os.path.join(project_root, 'data'),
+        os.path.join(project_root, 'examples', 'data'),
+    ]
+
+    for data_dir in candidate_dirs:
+        if not os.path.exists(data_dir):
+            continue
         for file in os.listdir(data_dir):
             if file.endswith(('.csv', '.feather')):
                 file_path = os.path.join(data_dir, file)
@@ -70,6 +77,7 @@ def load_data_files():
             data_files.append({
                 'name': file,
                 'path': file_path,
+                'source_dir': os.path.relpath(data_dir, project_root),
                 'size': f"{file_size / 1024:.1f} KB" if file_size < 1024*1024 else f"{file_size / (1024*1024):.1f} MB"
             })