diff --git a/EXPERIMENTS.md b/EXPERIMENTS.md
new file mode 100644
index 00000000..fddf4367
--- /dev/null
+++ b/EXPERIMENTS.md
@@ -0,0 +1,82 @@
+# Experiments Overview
+
+This repository currently contains four experiment categories.
+
+## 1. Prediction Demos
+
+Location: `examples/`
+
+Purpose:
+- Verify that pretrained Kronos models can run end-to-end inference.
+- Compare model sizes under the same inputs.
+
+Typical inputs:
+- Local CSV K-line data
+- Optional online A-share data from `akshare`
+
+Typical outputs:
+- Forecast CSV
+- Plot PNG
+- Summary JSON
+
+Ground-truth availability:
+- `prediction_example.py`: no future truth, visualization only
+- `prediction_wo_vol_example.py`: no future truth, visualization only
+- `prediction_batch_example.py`: no future truth, visualization only
+- `prediction_cn_markets_day.py --mode eval`: includes future truth and error metrics
+
+## 2. Regression Tests
+
+Location: `tests/test_kronos_regression.py`
+
+Purpose:
+- Ensure code changes do not silently drift model outputs.
+- Compare current outputs against exact regression fixtures and MSE thresholds recorded for fixed model revisions.
+
+Typical outputs:
+- `pytest` pass/fail result
+- Printed absolute/relative differences and MSE summary in the terminal
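
The drift check boils down to comparing fresh model outputs against a frozen fixture under an error budget. A minimal sketch of that pattern (the function name and threshold here are illustrative assumptions, not the repository's actual test code):

```python
import numpy as np

# Illustrative tolerance; the real test pins its own per-fixture thresholds.
MSE_THRESHOLD = 1e-6

def check_output_drift(current: np.ndarray, fixture: np.ndarray,
                       threshold: float = MSE_THRESHOLD) -> float:
    """Raise if model outputs have drifted from the stored fixture."""
    abs_diff = float(np.max(np.abs(current - fixture)))
    mse = float(np.mean((current - fixture) ** 2))
    print(f"max abs diff: {abs_diff:.3e}, mse: {mse:.3e}")
    if mse > threshold:
        raise AssertionError(f"output drift: MSE {mse:.3e} > {threshold:.3e}")
    return mse
```

An unchanged model reproduces the fixture bit-for-bit, so its MSE is exactly 0.0 and the check passes trivially.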
+
+## 3. Fine-Tuning Pipelines
+
+Locations:
+- `finetune/`
+- `finetune_csv/`
+
+Purpose:
+- Adapt Kronos to domain-specific datasets.
+- Support either Qlib-based A-share workflows or direct CSV-based training.
+
+Notes:
+- The `finetune/` training scripts are currently written for CUDA + NCCL + DDP.
+- The training scripts now treat `comet_ml` as optional.
+
+## 4. Qlib Backtesting
+
+Location: `finetune/qlib_test.py`
+
+Purpose:
+- Convert model forecasts into cross-sectional stock-selection signals.
+- Evaluate those signals with Qlib's `TopkDropoutStrategy`.
+
+Typical inputs:
+- Processed `train/val/test` pickle datasets from `qlib_data_preprocess.py`
+- A tokenizer/model pair, which can be local checkpoints or explicit Hugging Face IDs
+
+Typical outputs:
+- `predictions.pkl`
+- Per-signal curve CSV
+- Per-signal metrics JSON
+- Backtest summary CSV/JSON
+- Backtest plot PNG
+
+Ground-truth availability:
+- Yes, via historical replay using the test split
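
Conceptually, the signal-construction step is a per-date ranking of predicted returns, which `TopkDropoutStrategy` then trades on. A simplified pandas sketch (the column names and plain top-k rule are illustrative assumptions, not the actual `qlib_test.py` logic):

```python
import pandas as pd

def topk_signal(pred: pd.DataFrame, k: int = 1) -> pd.DataFrame:
    """Flag the k instruments with the highest predicted return on each date."""
    ranked = pred.copy()
    ranked["rank"] = ranked.groupby("date")["pred_return"].rank(
        ascending=False, method="first"
    )
    ranked["selected"] = ranked["rank"] <= k
    return ranked

pred = pd.DataFrame({
    "date": ["2024-01-02"] * 3 + ["2024-01-03"] * 3,
    "instrument": ["A", "B", "C"] * 2,
    "pred_return": [0.01, 0.03, -0.02, 0.00, -0.01, 0.02],
})
signal = topk_signal(pred, k=1)
# Exactly one instrument is flagged per date: B on 2024-01-02, C on 2024-01-03.
```

Because the test split replays real history, the realized returns of the flagged instruments are known, which is what makes the backtest metrics computable.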
+
+## Recommended Validation Sequence
+
+1. Run `examples/` to verify inference and output formats.
+2. Run `tests/test_kronos_regression.py` to detect output drift.
+3. Start `webui` to verify the interactive inference path.
+4. Run `finetune/qlib_test.py` with an explicit config file for smoke backtesting.
+5. Compare `Kronos-small` vs `Kronos-base` on the same prediction and backtest setups.
diff --git a/README.md b/README.md
index faffa157..556f4adc 100644
--- a/README.md
+++ b/README.md
@@ -211,10 +211,14 @@ Running this script will generate a plot comparing the ground truth data against
-Additionally, we provide a script that makes predictions without Volume and Amount data, which can be found in [`examples/prediction_wo_vol_example.py`](examples/prediction_wo_vol_example.py).
-
-
-## 🔧 Finetuning on Your Own Data (A-Share Market Example)
+Additionally, we provide a script that makes predictions without Volume and Amount data, which can be found in [`examples/prediction_wo_vol_example.py`](examples/prediction_wo_vol_example.py).
+
+For a concise guide to all example scripts, recommended run order, Apple Silicon device selection, and optional `akshare` setup, see [`examples/README.md`](examples/README.md).
+
+For an overview of the repository's prediction demos, regression tests, finetuning pipelines, and Qlib backtesting workflows, see [`EXPERIMENTS.md`](EXPERIMENTS.md).
+
+
+## 🔧 Finetuning on Your Own Data (A-Share Market Example)
We provide a complete pipeline for finetuning Kronos on your own datasets. As an example, we demonstrate how to use [Qlib](https://github.com/microsoft/qlib) to prepare data from the Chinese A-share market and conduct a simple backtest.
diff --git a/examples/README.md b/examples/README.md
new file mode 100644
index 00000000..9eb04e8d
--- /dev/null
+++ b/examples/README.md
@@ -0,0 +1,62 @@
+# Example Scripts
+
+All scripts in this directory can be run from the repository root and save their outputs to `examples/outputs/`.
+
+## Recommended Order
+
+1. `python examples/prediction_example.py --model small`
+2. `python examples/prediction_wo_vol_example.py --model small`
+3. `python examples/prediction_batch_example.py --model small`
+4. `python examples/prediction_cn_markets_day.py --model small --mode forward --symbol 000001`
+5. `python examples/prediction_cn_markets_day.py --model small --mode eval --symbol 000001`
+
+## Model Selection
+
+The examples support:
+
+- `--model small`
+- `--model base`
+
+This keeps the input data and evaluation mode fixed while allowing direct `Kronos-small` vs `Kronos-base` comparisons.
+
+## Script Semantics
+
+- `prediction_example.py`
+ Runs one OHLCV forecast on a single historical window and saves CSV, PNG, and summary JSON.
+- `prediction_wo_vol_example.py`
+ Runs one forecast without `volume`/`amount` inputs and saves CSV, PNG, and summary JSON.
+- `prediction_batch_example.py`
+ Demonstrates batch inference on multiple sequential time windows from the same instrument.
+ It is not a multi-asset cross-sectional example.
+ By default it saves CSV for every window and one representative PNG.
+- `prediction_cn_markets_day.py`
+ Supports two explicit modes:
+ - `forward`: future prediction only, so there is no ground-truth window yet.
+ - `eval`: historical replay mode with known future truth and saved error metrics.
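
The two modes differ only in where the lookback window is allowed to end. A minimal index sketch of that split (simplified; the real script also validates dates and handles the `--as-of-date` cutoff):

```python
def split_windows(n_rows: int, lookback: int, pred_len: int, mode: str):
    """Return (history_range, truth_range) as half-open row ranges.

    forward: history ends at the last row, so no truth rows exist yet.
    eval: history stops pred_len rows early, leaving known truth rows.
    """
    if mode == "forward":
        end = n_rows  # use the freshest rows; the future is unknown
        return (end - lookback, end), None
    if mode == "eval":
        end = n_rows - pred_len  # hold out the final pred_len rows as truth
        return (end - lookback, end), (end, end + pred_len)
    raise ValueError(f"unknown mode: {mode!r}")

# With 1000 rows, lookback=400, pred_len=120:
# forward -> history rows 600..999, no truth window
# eval    -> history rows 480..879, truth rows 880..999
```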
+
+## Outputs
+
+Each run saves structured outputs in `examples/outputs/`:
+
+- Forecast CSV files
+- Plot PNG files
+- Summary JSON files describing the model, device, time ranges, and output paths
+
+## Apple Silicon Devices
+
+The scripts auto-select a compute device in this order:
+
+1. `cuda:0`
+2. `mps`
+3. `cpu`
+
+On Apple Silicon Macs, `mps` is selected when PyTorch reports that the Metal Performance Shaders (MPS) backend is available.
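
In other words, selection is just first-available in a fixed preference list. A torch-free sketch of that order (the actual helper in `examples/common.py` queries `torch.cuda.is_available()` and `torch.backends.mps.is_available()`):

```python
def pick_device(cuda_ok: bool, mps_ok: bool) -> str:
    """Mirror the cuda:0 -> mps -> cpu preference used by the examples."""
    for name, available in (("cuda:0", cuda_ok), ("mps", mps_ok), ("cpu", True)):
        if available:
            return name
    return "cpu"  # unreachable; cpu is always available

# On an Apple Silicon Mac without CUDA, MPS wins.
assert pick_device(cuda_ok=False, mps_ok=True) == "mps"
```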
+
+## Dependencies
+
+- Local CSV examples only need the root `requirements.txt`.
+- `prediction_cn_markets_day.py` also requires `akshare` and network access:
+
+```bash
+uv pip install --python .venv/bin/python akshare
+```
diff --git a/examples/common.py b/examples/common.py
new file mode 100644
index 00000000..988d0097
--- /dev/null
+++ b/examples/common.py
@@ -0,0 +1,82 @@
+import json
+from pathlib import Path
+import sys
+from typing import Any
+
+import matplotlib
+import pandas as pd
+import torch
+
+matplotlib.use("Agg")
+
+EXAMPLES_ROOT = Path(__file__).resolve().parent
+PROJECT_ROOT = EXAMPLES_ROOT.parent
+DATA_ROOT = EXAMPLES_ROOT / "data"
+OUTPUT_ROOT = EXAMPLES_ROOT / "outputs"
+OUTPUT_ROOT.mkdir(exist_ok=True)
+
+MODEL_SPECS = {
+ "small": {
+ "name": "Kronos-small",
+ "model_id": "NeoQuasar/Kronos-small",
+ "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-base",
+ "max_context": 512,
+ },
+ "base": {
+ "name": "Kronos-base",
+ "model_id": "NeoQuasar/Kronos-base",
+ "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-base",
+ "max_context": 512,
+ },
+}
+
+if str(PROJECT_ROOT) not in sys.path:
+ sys.path.insert(0, str(PROJECT_ROOT))
+
+
+def data_path(filename: str) -> Path:
+ return DATA_ROOT / filename
+
+
+def output_path(filename: str) -> Path:
+ return OUTPUT_ROOT / filename
+
+
+def detect_device() -> str:
+ if torch.cuda.is_available():
+ return "cuda:0"
+ if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+ return "mps"
+ return "cpu"
+
+
+def get_model_spec(model_key: str) -> dict[str, Any]:
+ if model_key not in MODEL_SPECS:
+ raise ValueError(f"Unsupported model '{model_key}'. Expected one of: {sorted(MODEL_SPECS)}")
+ return MODEL_SPECS[model_key]
+
+
+def save_json(path: Path, payload: dict[str, Any]) -> None:
+ with path.open("w", encoding="utf-8") as handle:
+ json.dump(to_jsonable(payload), handle, indent=2, ensure_ascii=False)
+
+
+def to_jsonable(value: Any) -> Any:
+ if isinstance(value, Path):
+ try:
+ return str(value.relative_to(PROJECT_ROOT))
+ except ValueError:
+ return str(value)
+ if isinstance(value, pd.Timestamp):
+ return value.isoformat()
+ if isinstance(value, pd.Series):
+ return [to_jsonable(item) for item in value.tolist()]
+ if isinstance(value, pd.DataFrame):
+ return [{key: to_jsonable(item) for key, item in row.items()} for row in value.to_dict(orient="records")]
+ if isinstance(value, dict):
+ return {str(key): to_jsonable(item) for key, item in value.items()}
+ if isinstance(value, (list, tuple)):
+ return [to_jsonable(item) for item in value]
+ if hasattr(value, "item"):
+ return value.item()
+ return value
diff --git a/examples/prediction_batch_example.py b/examples/prediction_batch_example.py
index 29a74334..d77010ec 100644
--- a/examples/prediction_batch_example.py
+++ b/examples/prediction_batch_example.py
@@ -1,72 +1,176 @@
-import pandas as pd
-import matplotlib.pyplot as plt
-import sys
-sys.path.append("../")
-from model import Kronos, KronosTokenizer, KronosPredictor
-
-
-def plot_prediction(kline_df, pred_df):
- pred_df.index = kline_df.index[-pred_df.shape[0]:]
- sr_close = kline_df['close']
- sr_pred_close = pred_df['close']
- sr_close.name = 'Ground Truth'
- sr_pred_close.name = "Prediction"
+import argparse
+import time
+from pathlib import Path
- sr_volume = kline_df['volume']
- sr_pred_volume = pred_df['volume']
- sr_volume.name = 'Ground Truth'
- sr_pred_volume.name = "Prediction"
+import matplotlib.pyplot as plt
+import pandas as pd
- close_df = pd.concat([sr_close, sr_pred_close], axis=1)
- volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)
+from common import data_path, detect_device, get_model_spec, output_path, save_json
+from model import Kronos, KronosPredictor, KronosTokenizer
+
+LOOKBACK = 400
+PRED_LEN = 120
+DEFAULT_BATCH_SIZE = 5
+
+
+def plot_prediction(kline_df: pd.DataFrame, pred_df: pd.DataFrame, plot_file: Path) -> None:
+ plot_df = pred_df.copy()
+ plot_df.index = kline_df.index[-plot_df.shape[0]:]
+
+ close_df = pd.concat(
+ [
+ kline_df["close"].rename("Ground Truth"),
+ plot_df["close"].rename("Prediction"),
+ ],
+ axis=1,
+ )
+ volume_df = pd.concat(
+ [
+ kline_df["volume"].rename("Ground Truth"),
+ plot_df["volume"].rename("Prediction"),
+ ],
+ axis=1,
+ )
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
-
- ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
- ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
- ax1.set_ylabel('Close Price', fontsize=14)
- ax1.legend(loc='lower left', fontsize=12)
+ ax1.plot(close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5)
+ ax1.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
+ ax1.set_ylabel("Close Price", fontsize=14)
+ ax1.legend(loc="lower left", fontsize=12)
ax1.grid(True)
- ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
- ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
- ax2.set_ylabel('Volume', fontsize=14)
- ax2.legend(loc='upper left', fontsize=12)
+ ax2.plot(volume_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5)
+ ax2.plot(volume_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
+ ax2.set_ylabel("Volume", fontsize=14)
+ ax2.legend(loc="upper left", fontsize=12)
ax2.grid(True)
plt.tight_layout()
- plt.show()
-
-
-# 1. Load Model and Tokenizer
-tokenizer = KronosTokenizer.from_pretrained('/home/csc/huggingface/Kronos-Tokenizer-base/')
-model = Kronos.from_pretrained("/home/csc/huggingface/Kronos-base/")
-
-# 2. Instantiate Predictor
-predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
-
-# 3. Prepare Data
-df = pd.read_csv("./data/XSHG_5min_600977.csv")
-df['timestamps'] = pd.to_datetime(df['timestamps'])
-
-lookback = 400
-pred_len = 120
-
-dfs = []
-xtsp = []
-ytsp = []
-for i in range(5):
- idf = df.loc[(i*400):(i*400+lookback-1), ['open', 'high', 'low', 'close', 'volume', 'amount']]
- i_x_timestamp = df.loc[(i*400):(i*400+lookback-1), 'timestamps']
- i_y_timestamp = df.loc[(i*400+lookback):(i*400+lookback+pred_len-1), 'timestamps']
-
- dfs.append(idf)
- xtsp.append(i_x_timestamp)
- ytsp.append(i_y_timestamp)
-
-pred_df = predictor.predict_batch(
- df_list=dfs,
- x_timestamp_list=xtsp,
- y_timestamp_list=ytsp,
- pred_len=pred_len,
-)
+ plt.savefig(plot_file, dpi=150)
+ plt.close(fig)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Run batch prediction on multiple time windows from the same instrument."
+ )
+ parser.add_argument("--model", choices=["small", "base"], default="small", help="Model size to run.")
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=DEFAULT_BATCH_SIZE,
+ help="Number of sequential time windows to batch together.",
+ )
+ parser.add_argument(
+ "--plot-series",
+ type=int,
+ default=0,
+ help="Which window index to render as the representative plot.",
+ )
+ args = parser.parse_args()
+
+ spec = get_model_spec(args.model)
+ device = detect_device()
+ data_file = data_path("XSHG_5min_600977.csv")
+ summary_file = output_path(f"prediction_batch_example_{args.model}_summary.json")
+
+ print(
+ "Running batch prediction on multiple sequential windows from the same source file, "
+ "not on multiple instruments."
+ )
+ print(f"Loading {spec['name']} on device: {device}")
+ start_time = time.perf_counter()
+ tokenizer = KronosTokenizer.from_pretrained(spec["tokenizer_id"])
+ model = Kronos.from_pretrained(spec["model_id"])
+ predictor = KronosPredictor(model, tokenizer, device=device, max_context=spec["max_context"])
+
+ df = pd.read_csv(data_file)
+ df["timestamps"] = pd.to_datetime(df["timestamps"])
+
+ df_list = []
+ x_timestamp_list = []
+ y_timestamp_list = []
+ window_summaries = []
+
+ for idx in range(args.batch_size):
+ start = idx * LOOKBACK
+ end = start + LOOKBACK - 1
+ pred_start = end + 1
+ pred_end = pred_start + PRED_LEN - 1
+
+ df_list.append(df.loc[start:end, ["open", "high", "low", "close", "volume", "amount"]])
+ x_timestamp_list.append(df.loc[start:end, "timestamps"])
+ y_timestamp_list.append(df.loc[pred_start:pred_end, "timestamps"])
+ window_summaries.append(
+ {
+ "series_index": idx,
+ "instrument_note": "Same instrument, different historical time window.",
+ "history_row_start": start,
+ "history_row_end": end,
+ "forecast_row_start": pred_start,
+ "forecast_row_end": pred_end,
+ "history_start": df.loc[start, "timestamps"],
+ "history_end": df.loc[end, "timestamps"],
+ "forecast_start": df.loc[pred_start, "timestamps"],
+ "forecast_end": df.loc[pred_end, "timestamps"],
+ }
+ )
+
+ pred_df_list = predictor.predict_batch(
+ df_list=df_list,
+ x_timestamp_list=x_timestamp_list,
+ y_timestamp_list=y_timestamp_list,
+ pred_len=PRED_LEN,
+ verbose=False,
+ )
+ elapsed_seconds = time.perf_counter() - start_time
+
+ print(f"Generated batch predictions: {len(pred_df_list)} windows")
+ representative_plot = None
+
+ for idx, pred_df in enumerate(pred_df_list):
+ csv_file = output_path(f"prediction_batch_example_{args.model}_series_{idx}.csv")
+ pred_df.to_csv(csv_file, index_label="timestamps")
+ window_summaries[idx]["forecast_csv"] = csv_file
+
+ print(f"Series {idx} forecast head:")
+ print(pred_df.head())
+ print(f"Saved series {idx} CSV: {csv_file}")
+
+ if idx == args.plot_series:
+ plot_file = output_path(f"prediction_batch_example_{args.model}_representative_series_{idx}.png")
+ kline_df = df.loc[idx * LOOKBACK : idx * LOOKBACK + LOOKBACK + PRED_LEN - 1]
+ plot_prediction(kline_df, pred_df, plot_file)
+ representative_plot = plot_file
+ window_summaries[idx]["representative_plot"] = plot_file
+ print(f"Saved representative plot for series {idx}: {plot_file}")
+
+ save_json(
+ summary_file,
+ {
+ "experiment": "prediction_batch_example",
+ "description": (
+ "Batch prediction on multiple sequential time windows from a single instrument. "
+ "This example demonstrates parallel inference, not cross-sectional multi-asset prediction."
+ ),
+ "model_key": args.model,
+ "model_name": spec["name"],
+ "model_id": spec["model_id"],
+ "tokenizer_id": spec["tokenizer_id"],
+ "device": device,
+ "elapsed_seconds": elapsed_seconds,
+ "lookback": LOOKBACK,
+ "pred_len": PRED_LEN,
+ "batch_size": args.batch_size,
+ "plot_series": args.plot_series,
+ "data_file": data_file,
+ "representative_plot": representative_plot,
+ "windows": window_summaries,
+ },
+ )
+
+ print(f"Saved summary JSON: {summary_file}")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/examples/prediction_cn_markets_day.py b/examples/prediction_cn_markets_day.py
index 93b28f9f..1dd50a80 100644
--- a/examples/prediction_cn_markets_day.py
+++ b/examples/prediction_cn_markets_day.py
@@ -2,170 +2,207 @@
"""
prediction_cn_markets_day.py
-Description:
- Predicts future daily K-line (1D) data for A-share markets using Kronos model and akshare.
- The script automatically downloads the latest historical data, cleans it, and runs model inference.
-
-Usage:
- python prediction_cn_markets_day.py --symbol 000001
-
-Arguments:
- --symbol Stock code (e.g. 002594 for BYD, 000001 for SSE Index)
-
-Output:
- Saves the prediction results to ./outputs/pred_<symbol>_data.csv and ./outputs/pred_<symbol>_chart.png
- - Logs and progress are printed to console
-
-Example:
- bash> python prediction_cn_markets_day.py --symbol 000001
- python3 prediction_cn_markets_day.py --symbol 002594
+Two modes are supported:
+1. forward: predict the next unknown window from the latest available history.
+2. eval: replay a historical cutoff and compare the forecast against known future data.
"""
-import os
import argparse
import time
-import pandas as pd
-import akshare as ak
+from pathlib import Path
+
import matplotlib.pyplot as plt
-import sys
-sys.path.append("../")
-from model import Kronos, KronosTokenizer, KronosPredictor
+import pandas as pd
+
+from common import OUTPUT_ROOT, detect_device, get_model_spec, save_json
+from model import Kronos, KronosPredictor, KronosTokenizer
-save_dir = "./outputs"
-os.makedirs(save_dir, exist_ok=True)
+try:
+ import akshare as ak
+except ImportError as exc:
+ raise ImportError(
+ "akshare is required for prediction_cn_markets_day.py. Install it with "
+ "`uv pip install --python .venv/bin/python akshare`."
+ ) from exc
-# Setting
-TOKENIZER_PRETRAINED = "NeoQuasar/Kronos-Tokenizer-base"
-MODEL_PRETRAINED = "NeoQuasar/Kronos-base"
-DEVICE = "cpu" # "cuda:0"
MAX_CONTEXT = 512
LOOKBACK = 400
PRED_LEN = 120
T = 1.0
TOP_P = 0.9
SAMPLE_COUNT = 1
+PLOT_COLUMNS = ["close"]
+METRIC_COLUMNS = ["open", "high", "low", "close"]
+
def load_data(symbol: str) -> pd.DataFrame:
- print(f"📥 Fetching {symbol} daily data from akshare ...")
+ print(f"Fetching {symbol} daily data from akshare...")
max_retries = 3
df = None
-
- # Retry mechanism
for attempt in range(1, max_retries + 1):
try:
df = ak.stock_zh_a_hist(symbol=symbol, period="daily", adjust="")
if df is not None and not df.empty:
break
- except Exception as e:
- print(f"⚠️ Attempt {attempt}/{max_retries} failed: {e}")
+ except Exception as exc:
+ print(f"Attempt {attempt}/{max_retries} failed: {exc}")
time.sleep(1.5)
- # If still empty after retries
if df is None or df.empty:
- print(f"❌ Failed to fetch data for {symbol} after {max_retries} attempts. Exiting.")
- sys.exit(1)
-
- df.rename(columns={
- "日期": "date",
- "开盘": "open",
- "收盘": "close",
- "最高": "high",
- "最低": "low",
- "成交量": "volume",
- "成交额": "amount"
- }, inplace=True)
+ raise RuntimeError(f"Failed to fetch data for {symbol} after {max_retries} attempts.")
+
+ df.rename(
+ columns={
+ "日期": "date",
+ "开盘": "open",
+ "收盘": "close",
+ "最高": "high",
+ "最低": "low",
+ "成交量": "volume",
+ "成交额": "amount",
+ },
+ inplace=True,
+ )
df["date"] = pd.to_datetime(df["date"])
df = df.sort_values("date").reset_index(drop=True)
- # Convert numeric columns
numeric_cols = ["open", "high", "low", "close", "volume", "amount"]
for col in numeric_cols:
df[col] = (
- df[col]
- .astype(str)
- .str.replace(",", "", regex=False)
- .replace({"--": None, "": None})
+ df[col].astype(str).str.replace(",", "", regex=False).replace({"--": None, "": None})
)
df[col] = pd.to_numeric(df[col], errors="coerce")
- # Fix invalid open values
open_bad = (df["open"] == 0) | (df["open"].isna())
if open_bad.any():
- print(f"⚠️ Fixed {open_bad.sum()} invalid open values.")
+ print(f"Fixed {open_bad.sum()} invalid open values.")
df.loc[open_bad, "open"] = df["close"].shift(1)
- df["open"].fillna(df["close"], inplace=True)
+ df["open"] = df["open"].fillna(df["close"])
- # Fix missing amount
if df["amount"].isna().all() or (df["amount"] == 0).all():
df["amount"] = df["close"] * df["volume"]
- print(f"✅ Data loaded: {len(df)} rows, range: {df['date'].min()} ~ {df['date'].max()}")
-
- print("Data Head:")
+ print(f"Loaded {len(df)} rows from {df['date'].min()} to {df['date'].max()}")
print(df.head())
-
return df
-def prepare_inputs(df):
- x_df = df.iloc[-LOOKBACK:][["open","high","low","close","volume","amount"]]
- x_timestamp = df.iloc[-LOOKBACK:]["date"]
- y_timestamp = pd.bdate_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=PRED_LEN)
- return x_df, pd.Series(x_timestamp), pd.Series(y_timestamp)
+def resolve_eval_end_index(df: pd.DataFrame, as_of_date: str | None) -> int:
+ min_end_idx = LOOKBACK - 1
+ max_end_idx = len(df) - PRED_LEN - 1
+ if max_end_idx < min_end_idx:
+ raise ValueError("Not enough rows to run historical evaluation with the requested lookback and pred_len.")
+
+ if as_of_date is None:
+ return max_end_idx
-def apply_price_limits(pred_df, last_close, limit_rate=0.1):
- print(f"📉 Applying ±{limit_rate*100:.0f}% price limit ...")
+ target = pd.Timestamp(as_of_date)
+ eligible = df.index[df["date"] <= target]
+ if len(eligible) == 0:
+ raise ValueError(f"No data is available on or before {target.date()}.")
+
+ end_idx = int(eligible[-1])
+ if end_idx < min_end_idx:
+ raise ValueError(f"Evaluation cutoff {df.loc[end_idx, 'date'].date()} is before the required lookback.")
+ if end_idx > max_end_idx:
+ raise ValueError(
+ f"Evaluation cutoff {df.loc[end_idx, 'date'].date()} leaves no future truth window of length {PRED_LEN}."
+ )
+ return end_idx
- # Ensure integer index
- pred_df = pred_df.reset_index(drop=True)
- # Ensure float64 dtype for safe assignment
+def prepare_forward_inputs(df: pd.DataFrame) -> tuple[pd.DataFrame, pd.Series, pd.Series]:
+ x_df = df.iloc[-LOOKBACK:][["open", "high", "low", "close", "volume", "amount"]]
+ x_timestamp = pd.Series(df.iloc[-LOOKBACK:]["date"])
+ y_timestamp = pd.Series(
+ pd.bdate_range(start=df["date"].iloc[-1] + pd.Timedelta(days=1), periods=PRED_LEN)
+ )
+ return x_df, x_timestamp, y_timestamp
+
+
+def prepare_eval_inputs(df: pd.DataFrame, as_of_date: str | None) -> tuple[pd.DataFrame, pd.Series, pd.Series, pd.DataFrame, pd.Timestamp]:
+ end_idx = resolve_eval_end_index(df, as_of_date)
+ x_df = df.iloc[end_idx - LOOKBACK + 1 : end_idx + 1][["open", "high", "low", "close", "volume", "amount"]]
+ x_timestamp = pd.Series(df.iloc[end_idx - LOOKBACK + 1 : end_idx + 1]["date"])
+ actual_df = df.iloc[end_idx + 1 : end_idx + 1 + PRED_LEN][["date", "open", "high", "low", "close", "volume", "amount"]].reset_index(drop=True)
+ y_timestamp = pd.Series(actual_df["date"])
+ return x_df, x_timestamp, y_timestamp, actual_df, df.loc[end_idx, "date"]
+
+
+def apply_price_limits(pred_df: pd.DataFrame, last_close: float, limit_rate: float = 0.1) -> pd.DataFrame:
+ clipped_df = pred_df.reset_index(drop=True).copy()
cols = ["open", "high", "low", "close"]
- pred_df[cols] = pred_df[cols].astype("float64")
+ clipped_df[cols] = clipped_df[cols].astype("float64")
- for i in range(len(pred_df)):
+ for idx in range(len(clipped_df)):
limit_up = last_close * (1 + limit_rate)
limit_down = last_close * (1 - limit_rate)
-
for col in cols:
- value = pred_df.at[i, col]
+ value = clipped_df.at[idx, col]
if pd.notna(value):
- clipped = max(min(value, limit_up), limit_down)
- pred_df.at[i, col] = float(clipped)
+ clipped_df.at[idx, col] = float(max(min(value, limit_up), limit_down))
+ last_close = float(clipped_df.at[idx, "close"])
- last_close = float(pred_df.at[i, "close"]) # ensure float type
+ return clipped_df
- return pred_df
+def compute_eval_metrics(pred_df: pd.DataFrame, actual_df: pd.DataFrame) -> dict[str, float]:
+ metrics: dict[str, float] = {}
+ for col in METRIC_COLUMNS:
+ diff = pred_df[col] - actual_df[col]
+ metrics[f"{col}_mae"] = float(diff.abs().mean())
+ metrics[f"{col}_mse"] = float((diff ** 2).mean())
+ return metrics
-def plot_result(df_hist, df_pred, symbol):
+
+def plot_result(
+ df_hist: pd.DataFrame,
+ df_pred: pd.DataFrame,
+ symbol: str,
+ plot_path: Path,
+ actual_df: pd.DataFrame | None = None,
+) -> None:
plt.figure(figsize=(12, 6))
plt.plot(df_hist["date"], df_hist["close"], label="Historical", color="blue")
plt.plot(df_pred["date"], df_pred["close"], label="Predicted", color="red", linestyle="--")
+ if actual_df is not None:
+ plt.plot(actual_df["date"], actual_df["close"], label="Actual Future", color="green", linestyle=":")
plt.title(f"Kronos Prediction for {symbol}")
plt.xlabel("Date")
plt.ylabel("Close Price")
plt.legend()
plt.grid(True)
plt.tight_layout()
- plot_path = os.path.join(save_dir, f"pred_{symbol.replace('.', '_')}_chart.png")
- plt.savefig(plot_path)
+ plt.savefig(plot_path, dpi=150)
plt.close()
- print(f"📊 Chart saved: {plot_path}")
-def predict_future(symbol):
- print(f"🚀 Loading Kronos tokenizer:{TOKENIZER_PRETRAINED} model:{MODEL_PRETRAINED} ...")
- tokenizer = KronosTokenizer.from_pretrained(TOKENIZER_PRETRAINED)
- model = Kronos.from_pretrained(MODEL_PRETRAINED)
- predictor = KronosPredictor(model, tokenizer, device=DEVICE, max_context=MAX_CONTEXT)
+def run_prediction(symbol: str, model_key: str, mode: str, as_of_date: str | None) -> None:
+ spec = get_model_spec(model_key)
+ device = detect_device()
+ prefix = f"pred_{symbol.replace('.', '_')}_{model_key}_{mode}"
+ csv_path = OUTPUT_ROOT / f"{prefix}_data.csv"
+ plot_path = OUTPUT_ROOT / f"{prefix}_chart.png"
+ summary_path = OUTPUT_ROOT / f"{prefix}_summary.json"
+
+ print(f"Loading {spec['name']} on device: {device}")
+ start_time = time.perf_counter()
+ tokenizer = KronosTokenizer.from_pretrained(spec["tokenizer_id"])
+ model = Kronos.from_pretrained(spec["model_id"])
+ predictor = KronosPredictor(model, tokenizer, device=device, max_context=MAX_CONTEXT)
df = load_data(symbol)
- x_df, x_timestamp, y_timestamp = prepare_inputs(df)
+ actual_df = None
+ eval_cutoff = None
- print("🔮 Generating predictions ...")
+ if mode == "forward":
+ x_df, x_timestamp, y_timestamp = prepare_forward_inputs(df)
+ history_df = df.copy()
+ else:
+ x_df, x_timestamp, y_timestamp, actual_df, eval_cutoff = prepare_eval_inputs(df, as_of_date)
+ history_df = df[df["date"] <= eval_cutoff].copy()
pred_df = predictor.predict(
df=x_df,
@@ -175,34 +212,88 @@ def predict_future(symbol):
T=T,
top_p=TOP_P,
sample_count=SAMPLE_COUNT,
+ verbose=False,
)
+ elapsed_seconds = time.perf_counter() - start_time
pred_df["date"] = y_timestamp.values
+ pred_df = apply_price_limits(pred_df, float(x_df["close"].iloc[-1]), limit_rate=0.1)
+
+ if actual_df is None:
+ output_df = pd.concat(
+ [
+ history_df[["date", "open", "high", "low", "close", "volume", "amount"]],
+ pred_df[["date", "open", "high", "low", "close", "volume", "amount"]],
+ ]
+ ).reset_index(drop=True)
+ metrics = None
+ else:
+ comparison_df = pred_df[["date", "open", "high", "low", "close", "volume", "amount"]].copy()
+ for col in ["open", "high", "low", "close", "volume", "amount"]:
+ comparison_df[f"actual_{col}"] = actual_df[col]
+ output_df = comparison_df
+ metrics = compute_eval_metrics(pred_df, actual_df)
+
+ output_df.to_csv(csv_path, index=False)
+ plot_result(history_df, pred_df, symbol, plot_path, actual_df=actual_df)
+
+ save_json(
+ summary_path,
+ {
+ "experiment": "prediction_cn_markets_day",
+ "description": (
+ "Forward mode predicts unseen future dates. Eval mode replays a historical cutoff and compares "
+ "the forecast against known future truth."
+ ),
+ "symbol": symbol,
+ "mode": mode,
+ "has_ground_truth": actual_df is not None,
+ "evaluation_cutoff": eval_cutoff,
+ "model_key": model_key,
+ "model_name": spec["name"],
+ "model_id": spec["model_id"],
+ "tokenizer_id": spec["tokenizer_id"],
+ "device": device,
+ "elapsed_seconds": elapsed_seconds,
+ "lookback": LOOKBACK,
+ "pred_len": PRED_LEN,
+ "history_start": x_timestamp.iloc[0],
+ "history_end": x_timestamp.iloc[-1],
+ "prediction_start": y_timestamp.iloc[0],
+ "prediction_end": y_timestamp.iloc[-1],
+ "output_csv": csv_path,
+ "plot_file": plot_path,
+ "metrics": metrics,
+ "prediction_head": pred_df.head(),
+ "actual_head": actual_df.head() if actual_df is not None else None,
+ },
+ )
- # Apply ±10% price limit
- last_close = df["close"].iloc[-1]
- pred_df = apply_price_limits(pred_df, last_close, limit_rate=0.1)
-
- # Merge historical and predicted data
- df_out = pd.concat([
- df[["date", "open", "high", "low", "close", "volume", "amount"]],
- pred_df[["date", "open", "high", "low", "close", "volume", "amount"]]
- ]).reset_index(drop=True)
-
- # Save CSV
- out_file = os.path.join(save_dir, f"pred_{symbol.replace('.', '_')}_data.csv")
- df_out.to_csv(out_file, index=False)
- print(f"✅ Prediction completed and saved: {out_file}")
-
- # Plot
- plot_result(df, pred_df, symbol)
+ print(f"Saved forecast CSV: {csv_path}")
+ print(f"Saved forecast plot: {plot_path}")
+ print(f"Saved summary JSON: {summary_path}")
+ if metrics is not None:
+ print("Evaluation metrics:")
+ for key, value in metrics.items():
+ print(f" {key}: {value:.6f}")
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Kronos stock prediction script")
parser.add_argument("--symbol", type=str, default="000001", help="Stock code")
+ parser.add_argument("--model", choices=["small", "base"], default="small", help="Model size to run.")
+ parser.add_argument(
+ "--mode",
+ choices=["forward", "eval"],
+ default="forward",
+ help="forward predicts unknown future dates, eval compares against known future truth.",
+ )
+ parser.add_argument(
+ "--as-of-date",
+ type=str,
+ default=None,
+ help="Historical cutoff date for eval mode, e.g. 2020-06-30. Defaults to the latest eligible date.",
+ )
args = parser.parse_args()
- predict_future(
- symbol=args.symbol,
- )
+ run_prediction(symbol=args.symbol, model_key=args.model, mode=args.mode, as_of_date=args.as_of_date)
diff --git a/examples/prediction_example.py b/examples/prediction_example.py
index 880f22b8..312ccae0 100644
--- a/examples/prediction_example.py
+++ b/examples/prediction_example.py
@@ -1,80 +1,127 @@
-import pandas as pd
-import matplotlib.pyplot as plt
-import sys
-sys.path.append("../")
-from model import Kronos, KronosTokenizer, KronosPredictor
-
-
-def plot_prediction(kline_df, pred_df):
- pred_df.index = kline_df.index[-pred_df.shape[0]:]
- sr_close = kline_df['close']
- sr_pred_close = pred_df['close']
- sr_close.name = 'Ground Truth'
- sr_pred_close.name = "Prediction"
-
- sr_volume = kline_df['volume']
- sr_pred_volume = pred_df['volume']
- sr_volume.name = 'Ground Truth'
- sr_pred_volume.name = "Prediction"
-
- close_df = pd.concat([sr_close, sr_pred_close], axis=1)
- volume_df = pd.concat([sr_volume, sr_pred_volume], axis=1)
-
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
-
- ax1.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
- ax1.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
- ax1.set_ylabel('Close Price', fontsize=14)
- ax1.legend(loc='lower left', fontsize=12)
- ax1.grid(True)
-
- ax2.plot(volume_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
- ax2.plot(volume_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
- ax2.set_ylabel('Volume', fontsize=14)
- ax2.legend(loc='upper left', fontsize=12)
- ax2.grid(True)
-
- plt.tight_layout()
- plt.show()
-
-
-# 1. Load Model and Tokenizer
-tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
-model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
-
-# 2. Instantiate Predictor
-predictor = KronosPredictor(model, tokenizer, max_context=512)
-
-# 3. Prepare Data
-df = pd.read_csv("./data/XSHG_5min_600977.csv")
-df['timestamps'] = pd.to_datetime(df['timestamps'])
-
-lookback = 400
-pred_len = 120
-
-x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close', 'volume', 'amount']]
-x_timestamp = df.loc[:lookback-1, 'timestamps']
-y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']
-
-# 4. Make Prediction
-pred_df = predictor.predict(
- df=x_df,
- x_timestamp=x_timestamp,
- y_timestamp=y_timestamp,
- pred_len=pred_len,
- T=1.0,
- top_p=0.9,
- sample_count=1,
- verbose=True
-)
-
-# 5. Visualize Results
-print("Forecasted Data Head:")
-print(pred_df.head())
-
-# Combine historical and forecasted data for plotting
-kline_df = df.loc[:lookback+pred_len-1]
-
-# visualize
-plot_prediction(kline_df, pred_df)
-
+import argparse
+import time
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pandas as pd
+
+from common import data_path, detect_device, get_model_spec, output_path, save_json
+from model import Kronos, KronosPredictor, KronosTokenizer
+
+LOOKBACK = 400
+PRED_LEN = 120
+
+
+def plot_prediction(kline_df: pd.DataFrame, pred_df: pd.DataFrame, plot_file: Path) -> None:
+ plot_df = pred_df.copy()
+ plot_df.index = kline_df.index[-plot_df.shape[0]:]
+
+ close_df = pd.concat(
+ [
+ kline_df["close"].rename("Ground Truth"),
+ plot_df["close"].rename("Prediction"),
+ ],
+ axis=1,
+ )
+ volume_df = pd.concat(
+ [
+ kline_df["volume"].rename("Ground Truth"),
+ plot_df["volume"].rename("Prediction"),
+ ],
+ axis=1,
+ )
+
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
+
+ ax1.plot(close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5)
+ ax1.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
+ ax1.set_ylabel("Close Price", fontsize=14)
+ ax1.legend(loc="lower left", fontsize=12)
+ ax1.grid(True)
+
+ ax2.plot(volume_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5)
+ ax2.plot(volume_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
+ ax2.set_ylabel("Volume", fontsize=14)
+ ax2.legend(loc="upper left", fontsize=12)
+ ax2.grid(True)
+
+ plt.tight_layout()
+ plt.savefig(plot_file, dpi=150)
+ plt.close(fig)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Run a single-series Kronos prediction example.")
+ parser.add_argument("--model", choices=["small", "base"], default="small", help="Model size to run.")
+ args = parser.parse_args()
+
+ spec = get_model_spec(args.model)
+ device = detect_device()
+ data_file = data_path("XSHG_5min_600977.csv")
+ csv_file = output_path(f"prediction_example_{args.model}_forecast.csv")
+ plot_file = output_path(f"prediction_example_{args.model}_plot.png")
+ summary_file = output_path(f"prediction_example_{args.model}_summary.json")
+
+ print(f"Loading {spec['name']} on device: {device}")
+ start_time = time.perf_counter()
+ tokenizer = KronosTokenizer.from_pretrained(spec["tokenizer_id"])
+ model = Kronos.from_pretrained(spec["model_id"])
+ predictor = KronosPredictor(model, tokenizer, device=device, max_context=spec["max_context"])
+
+ df = pd.read_csv(data_file)
+ df["timestamps"] = pd.to_datetime(df["timestamps"])
+
+ x_df = df.loc[: LOOKBACK - 1, ["open", "high", "low", "close", "volume", "amount"]]
+ x_timestamp = df.loc[: LOOKBACK - 1, "timestamps"]
+ y_timestamp = df.loc[LOOKBACK : LOOKBACK + PRED_LEN - 1, "timestamps"]
+
+ pred_df = predictor.predict(
+ df=x_df,
+ x_timestamp=x_timestamp,
+ y_timestamp=y_timestamp,
+ pred_len=PRED_LEN,
+ T=1.0,
+ top_p=0.9,
+ sample_count=1,
+ verbose=False,
+ )
+ elapsed_seconds = time.perf_counter() - start_time
+
+ print("Forecasted Data Head:")
+ print(pred_df.head())
+
+ pred_df.to_csv(csv_file, index_label="timestamps")
+ kline_df = df.loc[: LOOKBACK + PRED_LEN - 1]
+ plot_prediction(kline_df, pred_df, plot_file)
+
+ save_json(
+ summary_file,
+ {
+ "experiment": "prediction_example",
+ "description": "Single-series OHLCV prediction on one historical window.",
+ "model_key": args.model,
+ "model_name": spec["name"],
+ "model_id": spec["model_id"],
+ "tokenizer_id": spec["tokenizer_id"],
+ "device": device,
+ "elapsed_seconds": elapsed_seconds,
+ "lookback": LOOKBACK,
+ "pred_len": PRED_LEN,
+ "input_window_start": str(x_timestamp.iloc[0]),
+ "input_window_end": str(x_timestamp.iloc[-1]),
+ "prediction_window_start": str(y_timestamp.iloc[0]),
+ "prediction_window_end": str(y_timestamp.iloc[-1]),
+ "data_file": str(data_file),
+ "forecast_csv": str(csv_file),
+ "plot_file": str(plot_file),
+ "prediction_head": pred_df.head().to_dict(orient="list"),
+ },
+ )
+
+ print(f"Saved forecast CSV: {csv_file}")
+ print(f"Saved forecast plot: {plot_file}")
+ print(f"Saved summary JSON: {summary_file}")
+
+
+if __name__ == "__main__":
+ main()
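The script relies on helpers from `examples/common.py`, which this diff does not show. A plausible `save_json` must cope with values such as `pandas.Timestamp` and `Path` in the summary dict; one minimal sketch, assuming the real helper behaves similarly, falls back to `str` for anything `json` cannot encode natively:

```python
import json
import os
import tempfile
from pathlib import Path

def save_json(path, payload):
    """Write a dict as pretty-printed JSON (sketch of a common.py-style helper).

    default=str coerces values json cannot encode natively (Path,
    pandas.Timestamp, ...) into their string representation.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as handle:
        json.dump(payload, handle, indent=2, default=str)

out_path = os.path.join(tempfile.gettempdir(), "kronos_summary_demo.json")
save_json(out_path, {"experiment": "demo", "plot_file": Path("/tmp/plot.png")})
```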
diff --git a/examples/prediction_wo_vol_example.py b/examples/prediction_wo_vol_example.py
index 3e5624e1..2202261c 100644
--- a/examples/prediction_wo_vol_example.py
+++ b/examples/prediction_wo_vol_example.py
@@ -1,68 +1,113 @@
-import pandas as pd
-import matplotlib.pyplot as plt
-import sys
-sys.path.append("../")
-from model import Kronos, KronosTokenizer, KronosPredictor
-
-
-def plot_prediction(kline_df, pred_df):
- pred_df.index = kline_df.index[-pred_df.shape[0]:]
- sr_close = kline_df['close']
- sr_pred_close = pred_df['close']
- sr_close.name = 'Ground Truth'
- sr_pred_close.name = "Prediction"
-
- close_df = pd.concat([sr_close, sr_pred_close], axis=1)
-
- fig, ax = plt.subplots(1, 1, figsize=(8, 4))
-
- ax.plot(close_df['Ground Truth'], label='Ground Truth', color='blue', linewidth=1.5)
- ax.plot(close_df['Prediction'], label='Prediction', color='red', linewidth=1.5)
- ax.set_ylabel('Close Price', fontsize=14)
- ax.legend(loc='lower left', fontsize=12)
- ax.grid(True)
-
- plt.tight_layout()
- plt.show()
-
-
-# 1. Load Model and Tokenizer
-tokenizer = KronosTokenizer.from_pretrained("NeoQuasar/Kronos-Tokenizer-base")
-model = Kronos.from_pretrained("NeoQuasar/Kronos-small")
-
-# 2. Instantiate Predictor
-predictor = KronosPredictor(model, tokenizer, device="cuda:0", max_context=512)
-
-# 3. Prepare Data
-df = pd.read_csv("./data/XSHG_5min_600977.csv")
-df['timestamps'] = pd.to_datetime(df['timestamps'])
-
-lookback = 400
-pred_len = 120
-
-x_df = df.loc[:lookback-1, ['open', 'high', 'low', 'close']]
-x_timestamp = df.loc[:lookback-1, 'timestamps']
-y_timestamp = df.loc[lookback:lookback+pred_len-1, 'timestamps']
-
-# 4. Make Prediction
-pred_df = predictor.predict(
- df=x_df,
- x_timestamp=x_timestamp,
- y_timestamp=y_timestamp,
- pred_len=pred_len,
- T=1.0,
- top_p=0.9,
- sample_count=1,
- verbose=True
-)
-
-# 5. Visualize Results
-print("Forecasted Data Head:")
-print(pred_df.head())
-
-# Combine historical and forecasted data for plotting
-kline_df = df.loc[:lookback+pred_len-1]
-
-# visualize
-plot_prediction(kline_df, pred_df)
-
+import argparse
+import time
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+import pandas as pd
+
+from common import data_path, detect_device, get_model_spec, output_path, save_json
+from model import Kronos, KronosPredictor, KronosTokenizer
+
+LOOKBACK = 400
+PRED_LEN = 120
+
+
+def plot_prediction(kline_df: pd.DataFrame, pred_df: pd.DataFrame, plot_file: Path) -> None:
+ plot_df = pred_df.copy()
+ plot_df.index = kline_df.index[-plot_df.shape[0]:]
+
+ close_df = pd.concat(
+ [
+ kline_df["close"].rename("Ground Truth"),
+ plot_df["close"].rename("Prediction"),
+ ],
+ axis=1,
+ )
+
+ fig, ax = plt.subplots(1, 1, figsize=(8, 4))
+ ax.plot(close_df["Ground Truth"], label="Ground Truth", color="blue", linewidth=1.5)
+ ax.plot(close_df["Prediction"], label="Prediction", color="red", linewidth=1.5)
+ ax.set_ylabel("Close Price", fontsize=14)
+ ax.legend(loc="lower left", fontsize=12)
+ ax.grid(True)
+
+ plt.tight_layout()
+ plt.savefig(plot_file, dpi=150)
+ plt.close(fig)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Run an OHLC-only Kronos prediction example.")
+ parser.add_argument("--model", choices=["small", "base"], default="small", help="Model size to run.")
+ args = parser.parse_args()
+
+ spec = get_model_spec(args.model)
+ device = detect_device()
+ data_file = data_path("XSHG_5min_600977.csv")
+ csv_file = output_path(f"prediction_wo_vol_example_{args.model}_forecast.csv")
+ plot_file = output_path(f"prediction_wo_vol_example_{args.model}_plot.png")
+ summary_file = output_path(f"prediction_wo_vol_example_{args.model}_summary.json")
+
+ print(f"Loading {spec['name']} on device: {device}")
+ start_time = time.perf_counter()
+ tokenizer = KronosTokenizer.from_pretrained(spec["tokenizer_id"])
+ model = Kronos.from_pretrained(spec["model_id"])
+ predictor = KronosPredictor(model, tokenizer, device=device, max_context=spec["max_context"])
+
+ df = pd.read_csv(data_file)
+ df["timestamps"] = pd.to_datetime(df["timestamps"])
+
+ x_df = df.loc[: LOOKBACK - 1, ["open", "high", "low", "close"]]
+ x_timestamp = df.loc[: LOOKBACK - 1, "timestamps"]
+ y_timestamp = df.loc[LOOKBACK : LOOKBACK + PRED_LEN - 1, "timestamps"]
+
+ pred_df = predictor.predict(
+ df=x_df,
+ x_timestamp=x_timestamp,
+ y_timestamp=y_timestamp,
+ pred_len=PRED_LEN,
+ T=1.0,
+ top_p=0.9,
+ sample_count=1,
+ verbose=False,
+ )
+ elapsed_seconds = time.perf_counter() - start_time
+
+ print("Forecasted Data Head:")
+ print(pred_df.head())
+
+ pred_df.to_csv(csv_file, index_label="timestamps")
+ kline_df = df.loc[: LOOKBACK + PRED_LEN - 1]
+ plot_prediction(kline_df, pred_df, plot_file)
+
+ save_json(
+ summary_file,
+ {
+ "experiment": "prediction_wo_vol_example",
+ "description": "Single-series OHLC-only prediction. Volume and amount are not provided as model inputs.",
+ "model_key": args.model,
+ "model_name": spec["name"],
+ "model_id": spec["model_id"],
+ "tokenizer_id": spec["tokenizer_id"],
+ "device": device,
+ "elapsed_seconds": elapsed_seconds,
+ "lookback": LOOKBACK,
+ "pred_len": PRED_LEN,
+ "input_window_start": str(x_timestamp.iloc[0]),
+ "input_window_end": str(x_timestamp.iloc[-1]),
+ "prediction_window_start": str(y_timestamp.iloc[0]),
+ "prediction_window_end": str(y_timestamp.iloc[-1]),
+ "data_file": str(data_file),
+ "forecast_csv": str(csv_file),
+ "plot_file": str(plot_file),
+ "prediction_head": pred_df.head().to_dict(orient="list"),
+ },
+ )
+
+ print(f"Saved forecast CSV: {csv_file}")
+ print(f"Saved forecast plot: {plot_file}")
+ print(f"Saved summary JSON: {summary_file}")
+
+
+if __name__ == "__main__":
+ main()
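Both example scripts slice the CSV with the same window arithmetic (`LOOKBACK = 400`, `PRED_LEN = 120`): rows 0..399 feed the model, rows 400..519 supply only the future timestamps. A quick standalone check of that arithmetic (note that pandas `.loc` slicing is label-inclusive, hence `LOOKBACK - 1` in the scripts):

```python
LOOKBACK = 400
PRED_LEN = 120

def window_bounds(lookback, pred_len):
    """Return (input_rows, prediction_rows) as half-open row ranges."""
    x_rows = range(0, lookback)                   # .loc[: lookback-1] inclusive
    y_rows = range(lookback, lookback + pred_len)  # .loc[lookback : lookback+pred_len-1]
    return x_rows, y_rows

x_rows, y_rows = window_bounds(LOOKBACK, PRED_LEN)
```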
diff --git a/finetune/config.py b/finetune/config.py
index 04cc3ee3..009a4ba6 100644
--- a/finetune/config.py
+++ b/finetune/config.py
@@ -1,131 +1,127 @@
-import os
-
-class Config:
- """
- Configuration class for the entire project.
- """
-
- def __init__(self):
- # =================================================================
- # Data & Feature Parameters
- # =================================================================
- # TODO: Update this path to your Qlib data directory.
- self.qlib_data_path = "~/.qlib/qlib_data/cn_data"
- self.instrument = 'csi300'
-
- # Overall time range for data loading from Qlib.
- self.dataset_begin_time = "2011-01-01"
- self.dataset_end_time = '2025-06-05'
-
- # Sliding window parameters for creating samples.
- self.lookback_window = 90 # Number of past time steps for input.
- self.predict_window = 10 # Number of future time steps for prediction.
- self.max_context = 512 # Maximum context length for the model.
-
- # Features to be used from the raw data.
- self.feature_list = ['open', 'high', 'low', 'close', 'vol', 'amt']
- # Time-based features to be generated.
- self.time_feature_list = ['minute', 'hour', 'weekday', 'day', 'month']
-
- # =================================================================
- # Dataset Splitting & Paths
- # =================================================================
- # Note: The validation/test set starts earlier than the training/validation set ends
- # to account for the `lookback_window`.
- self.train_time_range = ["2011-01-01", "2022-12-31"]
- self.val_time_range = ["2022-09-01", "2024-06-30"]
- self.test_time_range = ["2024-04-01", "2025-06-05"]
- self.backtest_time_range = ["2024-07-01", "2025-06-05"]
-
- # TODO: Directory to save the processed, pickled datasets.
- self.dataset_path = "./data/processed_datasets"
-
- # =================================================================
- # Training Hyperparameters
- # =================================================================
- self.clip = 5.0 # Clipping value for normalized data to prevent outliers.
-
- self.epochs = 30
- self.log_interval = 100 # Log training status every N batches.
- self.batch_size = 50 # Batch size per GPU.
-
- # Number of samples to draw for one "epoch" of training/validation.
- # This is useful for large datasets where a true epoch is too long.
- self.n_train_iter = 2000 * self.batch_size
- self.n_val_iter = 400 * self.batch_size
-
- # Learning rates for different model components.
- self.tokenizer_learning_rate = 2e-4
- self.predictor_learning_rate = 4e-5
-
- # Gradient accumulation to simulate a larger batch size.
- self.accumulation_steps = 1
-
- # AdamW optimizer parameters.
- self.adam_beta1 = 0.9
- self.adam_beta2 = 0.95
- self.adam_weight_decay = 0.1
-
- # Miscellaneous
- self.seed = 100 # Global random seed for reproducibility.
-
- # =================================================================
- # Experiment Logging & Saving
- # =================================================================
- self.use_comet = True # Set to False if you don't want to use Comet ML
- self.comet_config = {
- # It is highly recommended to load secrets from environment variables
- # for security purposes. Example: os.getenv("COMET_API_KEY")
- "api_key": "YOUR_COMET_API_KEY",
- "project_name": "Kronos-Finetune-Demo",
- "workspace": "your_comet_workspace" # TODO: Change to your Comet ML workspace name
- }
- self.comet_tag = 'finetune_demo'
- self.comet_name = 'finetune_demo'
-
- # Base directory for saving model checkpoints and results.
- # Using a general 'outputs' directory is a common practice.
- self.save_path = "./outputs/models"
- self.tokenizer_save_folder_name = 'finetune_tokenizer_demo'
- self.predictor_save_folder_name = 'finetune_predictor_demo'
- self.backtest_save_folder_name = 'finetune_backtest_demo'
-
- # Path for backtesting results.
- self.backtest_result_path = "./outputs/backtest_results"
-
- # =================================================================
- # Model & Checkpoint Paths
- # =================================================================
- # TODO: Update these paths to your pretrained model locations.
- # These can be local paths or Hugging Face Hub model identifiers.
- self.pretrained_tokenizer_path = "path/to/your/Kronos-Tokenizer-base"
- self.pretrained_predictor_path = "path/to/your/Kronos-small"
-
- # Paths to the fine-tuned models, derived from the save_path.
- # These will be generated automatically during training.
- self.finetuned_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model"
- self.finetuned_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model"
-
- # =================================================================
- # Backtesting Parameters
- # =================================================================
- self.backtest_n_symbol_hold = 50 # Number of symbols to hold in the portfolio.
- self.backtest_n_symbol_drop = 5 # Number of symbols to drop from the pool.
- self.backtest_hold_thresh = 5 # Minimum holding period for a stock.
- self.inference_T = 0.6
- self.inference_top_p = 0.9
- self.inference_top_k = 0
- self.inference_sample_count = 5
- self.backtest_batch_size = 1000
- self.backtest_benchmark = self._set_benchmark(self.instrument)
-
- def _set_benchmark(self, instrument):
- dt_benchmark = {
- 'csi800': "SH000906",
- 'csi1000': "SH000852",
- 'csi300': "SH000300",
- }
- if instrument in dt_benchmark:
- return dt_benchmark[instrument]
- else:
- raise ValueError(f"Benchmark not defined for instrument: {instrument}")
+import copy
+import json
+from pathlib import Path
+
+
+class Config:
+ """
+ Configuration class for finetuning and Qlib experiments.
+
+ Defaults preserve the repository's original intended experiment setup.
+ Use a config file or explicit overrides for local smoke tests and other
+ environment-specific runs.
+ """
+
+ DEFAULTS = {
+ "qlib_data_path": "~/.qlib/qlib_data/cn_data",
+ "instrument": "csi300",
+ "dataset_begin_time": "2011-01-01",
+ "dataset_end_time": "2025-06-05",
+ "lookback_window": 90,
+ "predict_window": 10,
+ "max_context": 512,
+ "feature_list": ["open", "high", "low", "close", "vol", "amt"],
+ "time_feature_list": ["minute", "hour", "weekday", "day", "month"],
+ "train_time_range": ["2011-01-01", "2022-12-31"],
+ "val_time_range": ["2022-09-01", "2024-06-30"],
+ "test_time_range": ["2024-04-01", "2025-06-05"],
+ "backtest_time_range": ["2024-07-01", "2025-06-05"],
+ "dataset_path": "./data/processed_datasets",
+ "clip": 5.0,
+ "epochs": 30,
+ "log_interval": 100,
+ "batch_size": 50,
+ "n_train_iter": 2000 * 50,  # 2000 batches at the default batch_size of 50
+ "n_val_iter": 400 * 50,  # 400 batches at the default batch_size of 50
+ "tokenizer_learning_rate": 2e-4,
+ "predictor_learning_rate": 4e-5,
+ "accumulation_steps": 1,
+ "adam_beta1": 0.9,
+ "adam_beta2": 0.95,
+ "adam_weight_decay": 0.1,
+ "seed": 100,
+ "use_comet": True,
+ "comet_config": {
+ "api_key": "YOUR_COMET_API_KEY",
+ "project_name": "Kronos-Finetune-Demo",
+ "workspace": "your_comet_workspace",
+ },
+ "comet_tag": "finetune_demo",
+ "comet_name": "finetune_demo",
+ "save_path": "./outputs/models",
+ "tokenizer_save_folder_name": "finetune_tokenizer_demo",
+ "predictor_save_folder_name": "finetune_predictor_demo",
+ "backtest_save_folder_name": "finetune_backtest_demo",
+ "backtest_result_path": "./outputs/backtest_results",
+ "pretrained_tokenizer_path": "path/to/your/Kronos-Tokenizer-base",
+ "pretrained_predictor_path": "path/to/your/Kronos-small",
+ "finetuned_tokenizer_path": "__AUTO__",
+ "finetuned_predictor_path": "__AUTO__",
+ "backtest_n_symbol_hold": 50,
+ "backtest_n_symbol_drop": 5,
+ "backtest_hold_thresh": 5,
+ "inference_T": 0.6,
+ "inference_top_p": 0.9,
+ "inference_top_k": 0,
+ "inference_sample_count": 5,
+ "backtest_batch_size": 1000,
+ }
+
+ def __init__(self, config_file: str | None = None, overrides: dict | None = None):
+ defaults = copy.deepcopy(self.DEFAULTS)
+ for key, value in defaults.items():
+ setattr(self, key, value)
+ self._explicit_finetuned_tokenizer_path = False
+ self._explicit_finetuned_predictor_path = False
+
+ if config_file is not None:
+ self.apply_overrides(self._load_config_file(config_file))
+ if overrides:
+ self.apply_overrides(overrides)
+
+ self._refresh_derived_fields()
+
+ @staticmethod
+ def _load_config_file(config_file: str) -> dict:
+ path = Path(config_file).expanduser()
+ with path.open("r", encoding="utf-8") as handle:
+ return json.load(handle)
+
+ def apply_overrides(self, overrides: dict) -> None:
+ for key, value in overrides.items():
+ if not hasattr(self, key):
+ raise AttributeError(f"Unknown config key: {key}")
+ if key == "finetuned_tokenizer_path":
+ self._explicit_finetuned_tokenizer_path = True
+ if key == "finetuned_predictor_path":
+ self._explicit_finetuned_predictor_path = True
+ setattr(self, key, copy.deepcopy(value))
+ self._refresh_derived_fields()
+
+ def _refresh_derived_fields(self) -> None:
+ default_tokenizer_path = f"{self.save_path}/{self.tokenizer_save_folder_name}/checkpoints/best_model"
+ default_predictor_path = f"{self.save_path}/{self.predictor_save_folder_name}/checkpoints/best_model"
+
+ if not self._explicit_finetuned_tokenizer_path:
+ self.finetuned_tokenizer_path = default_tokenizer_path
+ if not self._explicit_finetuned_predictor_path:
+ self.finetuned_predictor_path = default_predictor_path
+
+ self.backtest_benchmark = self._set_benchmark(self.instrument)
+
+ def to_dict(self) -> dict:
+ return {
+ key: copy.deepcopy(value)
+ for key, value in self.__dict__.items()
+ if not key.startswith("_")
+ }
+
+ def _set_benchmark(self, instrument):
+ dt_benchmark = {
+ "csi800": "SH000906",
+ "csi1000": "SH000852",
+ "csi300": "SH000300",
+ }
+ if instrument in dt_benchmark:
+ return dt_benchmark[instrument]
+ raise ValueError(f"Benchmark not defined for instrument: {instrument}")
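`DEFAULTS` holds mutable values (lists, the nested `comet_config` dict), which is why both `__init__` and `apply_overrides` go through `copy.deepcopy`: with a shallow copy, mutating one `Config` instance would silently change the class-level defaults shared by every other instance. A standalone illustration of the hazard (the `Cfg` class here is a toy, not the real `Config`):

```python
import copy

class Cfg:
    DEFAULTS = {"comet_config": {"api_key": "YOUR_COMET_API_KEY"}}

    def __init__(self, deep=True):
        # dict() copies only the top level; deepcopy also clones nested dicts.
        src = copy.deepcopy(self.DEFAULTS) if deep else dict(self.DEFAULTS)
        for key, value in src.items():
            setattr(self, key, value)

# Shallow copy: the nested dict is still shared with DEFAULTS.
a = Cfg(deep=False)
a.comet_config["api_key"] = "leaked"
shared = Cfg.DEFAULTS["comet_config"]["api_key"]  # mutated through the instance

# Reset, then repeat with deepcopy: DEFAULTS stays pristine.
Cfg.DEFAULTS["comet_config"]["api_key"] = "YOUR_COMET_API_KEY"
b = Cfg(deep=True)
b.comet_config["api_key"] = "local-override"
isolated = Cfg.DEFAULTS["comet_config"]["api_key"]  # unchanged
```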
diff --git a/finetune/dataset.py b/finetune/dataset.py
index f955ec12..641d4cd8 100644
--- a/finetune/dataset.py
+++ b/finetune/dataset.py
@@ -6,7 +6,7 @@
from config import Config
-class QlibDataset(Dataset):
+class QlibDataset(Dataset):
"""
A PyTorch Dataset for handling Qlib financial time series data.
@@ -20,8 +20,8 @@ class QlibDataset(Dataset):
ValueError: If `data_type` is not 'train' or 'val'.
"""
- def __init__(self, data_type: str = 'train'):
- self.config = Config()
+ def __init__(self, data_type: str = 'train', config: Config | None = None):
+ self.config = config or Config()
if data_type not in ['train', 'val']:
raise ValueError("data_type must be 'train' or 'val'")
self.data_type = data_type
diff --git a/finetune/local_configs/qlib_smoke_backtest_base.example.json b/finetune/local_configs/qlib_smoke_backtest_base.example.json
new file mode 100644
index 00000000..8ae26dcf
--- /dev/null
+++ b/finetune/local_configs/qlib_smoke_backtest_base.example.json
@@ -0,0 +1,25 @@
+{
+ "dataset_end_time": "2020-09-25",
+ "train_time_range": [
+ "2011-01-01",
+ "2018-12-31"
+ ],
+ "val_time_range": [
+ "2018-09-01",
+ "2020-03-31"
+ ],
+ "test_time_range": [
+ "2020-04-01",
+ "2020-09-25"
+ ],
+ "backtest_time_range": [
+ "2020-07-01",
+ "2020-09-24"
+ ],
+ "pretrained_tokenizer_path": "NeoQuasar/Kronos-Tokenizer-base",
+ "pretrained_predictor_path": "NeoQuasar/Kronos-base",
+ "finetuned_tokenizer_path": "NeoQuasar/Kronos-Tokenizer-base",
+ "finetuned_predictor_path": "NeoQuasar/Kronos-base",
+ "inference_sample_count": 1,
+ "backtest_save_folder_name": "qlib_smoke_backtest_base"
+}
diff --git a/finetune/local_configs/qlib_smoke_backtest_small.example.json b/finetune/local_configs/qlib_smoke_backtest_small.example.json
new file mode 100644
index 00000000..763f6242
--- /dev/null
+++ b/finetune/local_configs/qlib_smoke_backtest_small.example.json
@@ -0,0 +1,25 @@
+{
+ "dataset_end_time": "2020-09-25",
+ "train_time_range": [
+ "2011-01-01",
+ "2018-12-31"
+ ],
+ "val_time_range": [
+ "2018-09-01",
+ "2020-03-31"
+ ],
+ "test_time_range": [
+ "2020-04-01",
+ "2020-09-25"
+ ],
+ "backtest_time_range": [
+ "2020-07-01",
+ "2020-09-24"
+ ],
+ "pretrained_tokenizer_path": "NeoQuasar/Kronos-Tokenizer-base",
+ "pretrained_predictor_path": "NeoQuasar/Kronos-small",
+ "finetuned_tokenizer_path": "NeoQuasar/Kronos-Tokenizer-base",
+ "finetuned_predictor_path": "NeoQuasar/Kronos-small",
+ "inference_sample_count": 1,
+ "backtest_save_folder_name": "qlib_smoke_backtest_small"
+}
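Both smoke-test configs keep the convention noted in `config.py`: the validation window starts before the training window ends so every validation sample still has a full `lookback_window` (90 steps) of history. That overlap can be sanity-checked with plain date arithmetic (calendar days, a loose upper bound on 90 trading steps):

```python
from datetime import date

def overlap_days(prev_end, next_start):
    """Calendar-day overlap between two consecutive splits (ISO date strings)."""
    return (date.fromisoformat(prev_end) - date.fromisoformat(next_start)).days

# train_time_range ends 2018-12-31; val_time_range starts 2018-09-01.
train_val = overlap_days("2018-12-31", "2018-09-01")
```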
diff --git a/finetune/qlib_data_preprocess.py b/finetune/qlib_data_preprocess.py
index 7fe6147b..5335a566 100644
--- a/finetune/qlib_data_preprocess.py
+++ b/finetune/qlib_data_preprocess.py
@@ -1,7 +1,8 @@
-import os
-import pickle
-import numpy as np
-import pandas as pd
+import os
+import pickle
+import argparse
+import numpy as np
+import pandas as pd
import qlib
from qlib.config import REG_CN
from qlib.data import D
@@ -11,16 +12,16 @@
from config import Config
-class QlibDataPreprocessor:
+class QlibDataPreprocessor:
"""
A class to handle the loading, processing, and splitting of Qlib financial data.
"""
- def __init__(self):
- """Initializes the preprocessor with configuration and data fields."""
- self.config = Config()
- self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap']
- self.data = {} # A dictionary to store processed data for each symbol.
+ def __init__(self, config_file: str | None = None):
+ """Initializes the preprocessor with configuration and data fields."""
+ self.config = Config(config_file=config_file)
+ self.data_fields = ['open', 'close', 'high', 'low', 'volume', 'vwap']
+ self.data = {} # A dictionary to store processed data for each symbol.
def initialize_qlib(self):
"""Initializes the Qlib environment."""
@@ -121,10 +122,13 @@ def prepare_dataset(self):
print("Datasets prepared and saved successfully.")
-if __name__ == '__main__':
- # This block allows the script to be run directly to perform data preprocessing.
- preprocessor = QlibDataPreprocessor()
- preprocessor.initialize_qlib()
- preprocessor.load_qlib_data()
- preprocessor.prepare_dataset()
+if __name__ == '__main__':
+ parser = argparse.ArgumentParser(description="Prepare Qlib train/val/test datasets.")
+ parser.add_argument("--config-file", type=str, default=None, help="Optional JSON config override file.")
+ args = parser.parse_args()
+
+ preprocessor = QlibDataPreprocessor(config_file=args.config_file)
+ preprocessor.initialize_qlib()
+ preprocessor.load_qlib_data()
+ preprocessor.prepare_dataset()
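The preprocessing pipeline generates the five calendar features listed in `Config.time_feature_list` (`minute`, `hour`, `weekday`, `day`, `month`) from each bar's timestamp, and `QlibTestDataset` later rebuilds the same features on the fly. The mapping is plain calendar extraction; a stdlib sketch of the same transformation:

```python
from datetime import datetime

def time_features(ts):
    """Extract the five calendar features used across the pipeline.

    weekday follows the pandas/Python convention: Monday == 0.
    """
    return {
        "minute": ts.minute,
        "hour": ts.hour,
        "weekday": ts.weekday(),
        "day": ts.day,
        "month": ts.month,
    }

feats = time_features(datetime(2020, 7, 1, 10, 35))  # a Wednesday
```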
diff --git a/finetune/qlib_test.py b/finetune/qlib_test.py
index 97aebe41..0702001e 100644
--- a/finetune/qlib_test.py
+++ b/finetune/qlib_test.py
@@ -1,362 +1,382 @@
-import os
-import sys
-import argparse
-import pickle
-from collections import defaultdict
-
-import numpy as np
-import pandas as pd
-import torch
-from torch.utils.data import Dataset, DataLoader
-from tqdm import trange, tqdm
-from matplotlib import pyplot as plt
-
-import qlib
-from qlib.config import REG_CN
-from qlib.backtest import backtest, executor, CommonInfrastructure
-from qlib.contrib.evaluate import risk_analysis
-from qlib.contrib.strategy import TopkDropoutStrategy
-from qlib.utils import flatten_dict
-from qlib.utils.time import Freq
-
-# Ensure project root is in the Python path
-sys.path.append("../")
-from config import Config
-from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference
-
-
-# =================================================================================
-# 1. Data Loading and Processing for Inference
-# =================================================================================
-
-class QlibTestDataset(Dataset):
- """
- PyTorch Dataset for handling Qlib test data, specifically for inference.
-
- This dataset iterates through all possible sliding windows sequentially. It also
- yields metadata like symbol and timestamp, which are crucial for mapping
- predictions back to the original time series.
- """
-
- def __init__(self, data: dict, config: Config):
- self.data = data
- self.config = config
- self.window_size = config.lookback_window + config.predict_window
- self.symbols = list(self.data.keys())
- self.feature_list = config.feature_list
- self.time_feature_list = config.time_feature_list
- self.indices = []
-
- print("Preprocessing and building indices for test dataset...")
- for symbol in self.symbols:
- df = self.data[symbol].reset_index()
- # Generate time features on-the-fly
- df['minute'] = df['datetime'].dt.minute
- df['hour'] = df['datetime'].dt.hour
- df['weekday'] = df['datetime'].dt.weekday
- df['day'] = df['datetime'].dt.day
- df['month'] = df['datetime'].dt.month
- self.data[symbol] = df # Store preprocessed dataframe
-
- num_samples = len(df) - self.window_size + 1
- if num_samples > 0:
- for i in range(num_samples):
- timestamp = df.iloc[i + self.config.lookback_window - 1]['datetime']
- self.indices.append((symbol, i, timestamp))
-
- def __len__(self) -> int:
- return len(self.indices)
-
- def __getitem__(self, idx: int):
- symbol, start_idx, timestamp = self.indices[idx]
- df = self.data[symbol]
-
- context_end = start_idx + self.config.lookback_window
- predict_end = context_end + self.config.predict_window
-
- context_df = df.iloc[start_idx:context_end]
- predict_df = df.iloc[context_end:predict_end]
-
- x = context_df[self.feature_list].values.astype(np.float32)
- x_stamp = context_df[self.time_feature_list].values.astype(np.float32)
- y_stamp = predict_df[self.time_feature_list].values.astype(np.float32)
-
- # Instance-level normalization, consistent with training
- x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
- x = (x - x_mean) / (x_std + 1e-5)
- x = np.clip(x, -self.config.clip, self.config.clip)
-
- return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp
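The instance-level normalization above mirrors training: each window is standardized with its own mean and std, then clipped to `±config.clip`. The same arithmetic for a single feature column in plain Python (no numpy; `np.std` defaults to the population standard deviation, matched here):

```python
def normalize_window(values, clip=5.0, eps=1e-5):
    """Per-window z-score with outlier clipping, as in QlibTestDataset."""
    n = len(values)
    mean = sum(values) / n
    var = sum((v - mean) ** 2 for v in values) / n  # population variance, like np.std
    std = var ** 0.5
    return [max(-clip, min(clip, (v - mean) / (std + eps))) for v in values]

normed = normalize_window([1.0, 2.0, 3.0, 100.0], clip=5.0)
```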
-
-
-# =================================================================================
-# 2. Backtesting Logic
-# =================================================================================
-
-class QlibBacktest:
- """
- A wrapper class for conducting backtesting experiments using Qlib.
- """
-
- def __init__(self, config: Config):
- self.config = config
- self.initialize_qlib()
-
- def initialize_qlib(self):
- """Initializes the Qlib environment."""
- print("Initializing Qlib for backtesting...")
- qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
-
- def run_single_backtest(self, signal_series: pd.Series) -> pd.DataFrame:
- """
- Runs a single backtest for a given prediction signal.
-
- Args:
- signal_series (pd.Series): A pandas Series with a MultiIndex
- (instrument, datetime) and prediction scores.
- Returns:
- pd.DataFrame: A DataFrame containing the performance report.
- """
- strategy = TopkDropoutStrategy(
- topk=self.config.backtest_n_symbol_hold,
- n_drop=self.config.backtest_n_symbol_drop,
- hold_thresh=self.config.backtest_hold_thresh,
- signal=signal_series,
- )
- executor_config = {
- "time_per_step": "day",
- "generate_portfolio_metrics": True,
- "delay_execution": True,
- }
- backtest_config = {
- "start_time": self.config.backtest_time_range[0],
- "end_time": self.config.backtest_time_range[1],
- "account": 100_000_000,
- "benchmark": self.config.backtest_benchmark,
- "exchange_kwargs": {
- "freq": "day", "limit_threshold": 0.095, "deal_price": "open",
- "open_cost": 0.001, "close_cost": 0.0015, "min_cost": 5,
- },
- "executor": executor.SimulatorExecutor(**executor_config),
- }
-
- portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config)
- analysis_freq = "{0}{1}".format(*Freq.parse("day"))
- report, _ = portfolio_metric_dict.get(analysis_freq)
-
- # --- Analysis and Reporting ---
- analysis = {
- "excess_return_without_cost": risk_analysis(report["return"] - report["bench"], freq=analysis_freq),
- "excess_return_with_cost": risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq),
- }
- print("\n--- Backtest Analysis ---")
- print("Benchmark Return:", risk_analysis(report["bench"], freq=analysis_freq), sep='\n')
- print("\nExcess Return (w/o cost):", analysis["excess_return_without_cost"], sep='\n')
- print("\nExcess Return (w/ cost):", analysis["excess_return_with_cost"], sep='\n')
-
- report_df = pd.DataFrame({
- "cum_bench": report["bench"].cumsum(),
- "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(),
- "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(),
- })
- return report_df
-
- def run_and_plot_results(self, signals: dict[str, pd.DataFrame]):
- """
- Runs backtests for multiple signals and plots the cumulative return curves.
-
- Args:
- signals (dict[str, pd.DataFrame]): A dictionary where keys are signal names
- and values are prediction DataFrames.
- """
- return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
-
- for signal_name, pred_df in signals.items():
- print(f"\nBacktesting signal: {signal_name}...")
- pred_series = pred_df.stack()
- pred_series.index.names = ['datetime', 'instrument']
- pred_series = pred_series.swaplevel().sort_index()
- report_df = self.run_single_backtest(pred_series)
-
- return_df[signal_name] = report_df['cum_return_w_cost']
- ex_return_df[signal_name] = report_df['cum_ex_return_w_cost']
- if 'return' not in bench_df:
- bench_df['return'] = report_df['cum_bench']
-
- # Plotting results
- fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
- return_df.plot(ax=axes[0], title='Cumulative Return with Cost', grid=True)
- axes[0].plot(bench_df['return'], label=self.config.instrument.upper(), color='black', linestyle='--')
- axes[0].legend()
- axes[0].set_ylabel("Cumulative Return")
-
- ex_return_df.plot(ax=axes[1], title='Cumulative Excess Return with Cost', grid=True)
- axes[1].legend()
- axes[1].set_xlabel("Date")
- axes[1].set_ylabel("Cumulative Excess Return")
-
- plt.tight_layout()
- plt.savefig("../figures/backtest_result_example.png", dpi=200)
- plt.show()
-
-
-# =================================================================================
-# 3. Inference Logic
-# =================================================================================
-
-def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]:
- """Loads the fine-tuned tokenizer and predictor model."""
- device = torch.device(config['device'])
- print(f"Loading models onto device: {device}...")
- tokenizer = KronosTokenizer.from_pretrained(config['tokenizer_path']).to(device).eval()
- model = Kronos.from_pretrained(config['model_path']).to(device).eval()
- return tokenizer, model
-
-
-def collate_fn_for_inference(batch):
- """
- Custom collate function to handle batches containing Tensors, strings, and Timestamps.
-
- Args:
- batch (list): A list of samples, where each sample is the tuple returned by
- QlibTestDataset.__getitem__.
-
- Returns:
- A single tuple containing the batched data.
- """
- # Unzip the list of samples into separate lists for each data type
- x, x_stamp, y_stamp, symbols, timestamps = zip(*batch)
-
- # Stack the tensors to create a batch
- x_batch = torch.stack(x, dim=0)
- x_stamp_batch = torch.stack(x_stamp, dim=0)
- y_stamp_batch = torch.stack(y_stamp, dim=0)
-
- # Return the strings and timestamps as lists
- return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps)
-
-
-def generate_predictions(config: dict, test_data: dict) -> dict[str, pd.DataFrame]:
- """
- Runs inference on the test dataset to generate prediction signals.
-
- Args:
- config (dict): A dictionary containing inference parameters.
- test_data (dict): The raw test data loaded from a pickle file.
-
- Returns:
- A dictionary where keys are signal types (e.g., 'mean', 'last') and
- values are DataFrames of predictions (datetime index, symbol columns).
- """
- tokenizer, model = load_models(config)
- device = torch.device(config['device'])
-
- # Use the Dataset and DataLoader for efficient batching and processing
- dataset = QlibTestDataset(data=test_data, config=Config())
- loader = DataLoader(
- dataset,
- batch_size=config['batch_size'] // config['sample_count'],
- shuffle=False,
- num_workers=os.cpu_count() // 2,
- collate_fn=collate_fn_for_inference
- )
-
- results = defaultdict(list)
- with torch.no_grad():
- for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
- preds = auto_regressive_inference(
- tokenizer, model, x.to(device), x_stamp.to(device), y_stamp.to(device),
- max_context=config['max_context'], pred_len=config['pred_len'], clip=config['clip'],
- T=config['T'], top_k=config['top_k'], top_p=config['top_p'], sample_count=config['sample_count']
- )
- # You can try commenting on this line to keep the history data
- preds = preds[:, -config['pred_len']:, :]
-
- # The 'close' price is at index 3 in `feature_list`
- last_day_close = x[:, -1, 3].numpy()
- signals = {
- 'last': preds[:, -1, 3] - last_day_close,
- 'mean': np.mean(preds[:, :, 3], axis=1) - last_day_close,
- 'max': np.max(preds[:, :, 3], axis=1) - last_day_close,
- 'min': np.min(preds[:, :, 3], axis=1) - last_day_close,
- }
-
- for i in range(len(symbols)):
- for sig_type, sig_values in signals.items():
- results[sig_type].append((timestamps[i], symbols[i], sig_values[i]))
-
- print("Post-processing predictions into DataFrames...")
- prediction_dfs = {}
- for sig_type, records in results.items():
- df = pd.DataFrame(records, columns=['datetime', 'instrument', 'score'])
- pivot_df = df.pivot_table(index='datetime', columns='instrument', values='score')
- prediction_dfs[sig_type] = pivot_df.sort_index()
-
- return prediction_dfs
-
-
-# =================================================================================
-# 4. Main Execution
-# =================================================================================
-
-def main():
- """Main function to set up config, run inference, and execute backtesting."""
- parser = argparse.ArgumentParser(description="Run Kronos Inference and Backtesting")
- parser.add_argument("--device", type=str, default="cuda:1", help="Device for inference (e.g., 'cuda:0', 'cpu')")
- args = parser.parse_args()
-
- # --- 1. Configuration Setup ---
- base_config = Config()
-
- # Create a dedicated dictionary for this run's configuration
- run_config = {
- 'device': args.device,
- 'data_path': base_config.dataset_path,
- 'result_save_path': base_config.backtest_result_path,
- 'result_name': base_config.backtest_save_folder_name,
- 'tokenizer_path': base_config.finetuned_tokenizer_path,
- 'model_path': base_config.finetuned_predictor_path,
- 'max_context': base_config.max_context,
- 'pred_len': base_config.predict_window,
- 'clip': base_config.clip,
- 'T': base_config.inference_T,
- 'top_k': base_config.inference_top_k,
- 'top_p': base_config.inference_top_p,
- 'sample_count': base_config.inference_sample_count,
- 'batch_size': base_config.backtest_batch_size,
- }
-
- print("--- Running with Configuration ---")
- for key, val in run_config.items():
- print(f"{key:>20}: {val}")
- print("-" * 35)
-
- # --- 2. Load Data ---
- test_data_path = os.path.join(run_config['data_path'], "test_data.pkl")
- print(f"Loading test data from {test_data_path}...")
- with open(test_data_path, 'rb') as f:
- test_data = pickle.load(f)
- print(test_data)
- # --- 3. Generate Predictions ---
- model_preds = generate_predictions(run_config, test_data)
-
- # --- 4. Save Predictions ---
- save_dir = os.path.join(run_config['result_save_path'], run_config['result_name'])
- os.makedirs(save_dir, exist_ok=True)
- predictions_file = os.path.join(save_dir, "predictions.pkl")
- print(f"Saving prediction signals to {predictions_file}...")
- with open(predictions_file, 'wb') as f:
- pickle.dump(model_preds, f)
-
- # --- 5. Run Backtesting ---
- with open(predictions_file, 'rb') as f:
- model_preds = pickle.load(f)
-
- backtester = QlibBacktest(base_config)
- backtester.run_and_plot_results(model_preds)
-
-
-if __name__ == '__main__':
- main()
-
-
+import argparse
+import json
+import pickle
+import sys
+from collections import defaultdict
+from pathlib import Path
+
+import numpy as np
+import pandas as pd
+import torch
+from matplotlib import pyplot as plt
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+
+import qlib
+from qlib.backtest import backtest, executor
+from qlib.config import REG_CN
+from qlib.contrib.evaluate import risk_analysis
+from qlib.contrib.strategy import TopkDropoutStrategy
+from qlib.utils.time import Freq
+
+# Ensure project root is in the Python path
+sys.path.append("../")
+from config import Config
+from model.kronos import Kronos, KronosTokenizer, auto_regressive_inference
+
+
+def save_json(path: Path, payload: dict) -> None:
+ with path.open("w", encoding="utf-8") as handle:
+ json.dump(payload, handle, indent=2, ensure_ascii=False)
+
+
+def analysis_frame_to_dict(frame: pd.DataFrame) -> dict[str, float]:
+ if "risk" in frame.columns:
+ series = frame["risk"]
+ else:
+ series = frame.iloc[:, 0]
+ return {str(key): float(value) for key, value in series.items()}
+
+
+class QlibTestDataset(Dataset):
+ """
+ PyTorch Dataset for handling Qlib test data, specifically for inference.
+ """
+
+ def __init__(self, data: dict, config: Config):
+ self.data = data
+ self.config = config
+ self.window_size = config.lookback_window + config.predict_window
+ self.symbols = list(self.data.keys())
+ self.feature_list = config.feature_list
+ self.time_feature_list = config.time_feature_list
+ self.indices = []
+
+ print("Preprocessing and building indices for test dataset...")
+ for symbol in self.symbols:
+ df = self.data[symbol].reset_index()
+ df["minute"] = df["datetime"].dt.minute
+ df["hour"] = df["datetime"].dt.hour
+ df["weekday"] = df["datetime"].dt.weekday
+ df["day"] = df["datetime"].dt.day
+ df["month"] = df["datetime"].dt.month
+ self.data[symbol] = df
+
+ num_samples = len(df) - self.window_size + 1
+ if num_samples > 0:
+ for idx in range(num_samples):
+ timestamp = df.iloc[idx + self.config.lookback_window - 1]["datetime"]
+ self.indices.append((symbol, idx, timestamp))
+
+ def __len__(self) -> int:
+ return len(self.indices)
+
+ def __getitem__(self, idx: int):
+ symbol, start_idx, timestamp = self.indices[idx]
+ df = self.data[symbol]
+
+ context_end = start_idx + self.config.lookback_window
+ predict_end = context_end + self.config.predict_window
+
+ context_df = df.iloc[start_idx:context_end]
+ predict_df = df.iloc[context_end:predict_end]
+
+ x = context_df[self.feature_list].values.astype(np.float32)
+ x_stamp = context_df[self.time_feature_list].values.astype(np.float32)
+ y_stamp = predict_df[self.time_feature_list].values.astype(np.float32)
+
+ x_mean, x_std = np.mean(x, axis=0), np.std(x, axis=0)
+ x = (x - x_mean) / (x_std + 1e-5)
+ x = np.clip(x, -self.config.clip, self.config.clip)
+
+ return torch.from_numpy(x), torch.from_numpy(x_stamp), torch.from_numpy(y_stamp), symbol, timestamp
+
+
+class QlibBacktest:
+ """
+ A wrapper class for conducting Qlib backtests and saving structured results.
+ """
+
+ def __init__(self, config: Config, save_dir: Path, show_plot: bool):
+ self.config = config
+ self.save_dir = save_dir
+ self.show_plot = show_plot
+ self.initialize_qlib()
+
+ def initialize_qlib(self):
+ print("Initializing Qlib for backtesting...")
+ qlib.init(provider_uri=self.config.qlib_data_path, region=REG_CN)
+
+ def run_single_backtest(self, signal_series: pd.Series) -> tuple[pd.DataFrame, dict[str, dict[str, float]]]:
+ strategy = TopkDropoutStrategy(
+ topk=self.config.backtest_n_symbol_hold,
+ n_drop=self.config.backtest_n_symbol_drop,
+ hold_thresh=self.config.backtest_hold_thresh,
+ signal=signal_series,
+ )
+ executor_config = {
+ "time_per_step": "day",
+ "generate_portfolio_metrics": True,
+ "delay_execution": True,
+ }
+ backtest_config = {
+ "start_time": self.config.backtest_time_range[0],
+ "end_time": self.config.backtest_time_range[1],
+ "account": 100_000_000,
+ "benchmark": self.config.backtest_benchmark,
+ "exchange_kwargs": {
+ "freq": "day",
+ "limit_threshold": 0.095,
+ "deal_price": "open",
+ "open_cost": 0.001,
+ "close_cost": 0.0015,
+ "min_cost": 5,
+ },
+ "executor": executor.SimulatorExecutor(**executor_config),
+ }
+
+ portfolio_metric_dict, _ = backtest(strategy=strategy, **backtest_config)
+ analysis_freq = "{0}{1}".format(*Freq.parse("day"))
+ report, _ = portfolio_metric_dict.get(analysis_freq)
+
+ benchmark_metrics = risk_analysis(report["bench"], freq=analysis_freq)
+ excess_without_cost = risk_analysis(report["return"] - report["bench"], freq=analysis_freq)
+ excess_with_cost = risk_analysis(report["return"] - report["bench"] - report["cost"], freq=analysis_freq)
+
+ print("\n--- Backtest Analysis ---")
+ print("Benchmark Return:", benchmark_metrics, sep="\n")
+ print("\nExcess Return (w/o cost):", excess_without_cost, sep="\n")
+ print("\nExcess Return (w/ cost):", excess_with_cost, sep="\n")
+
+ report_df = pd.DataFrame(
+ {
+ "cum_bench": report["bench"].cumsum(),
+ "cum_return_w_cost": (report["return"] - report["cost"]).cumsum(),
+ "cum_ex_return_w_cost": (report["return"] - report["bench"] - report["cost"]).cumsum(),
+ }
+ )
+ analysis = {
+ "benchmark": analysis_frame_to_dict(benchmark_metrics),
+ "excess_return_without_cost": analysis_frame_to_dict(excess_without_cost),
+ "excess_return_with_cost": analysis_frame_to_dict(excess_with_cost),
+ }
+ return report_df, analysis
+
+ def run_and_plot_results(self, signals: dict[str, pd.DataFrame]):
+ return_df, ex_return_df, bench_df = pd.DataFrame(), pd.DataFrame(), pd.DataFrame()
+ summary_rows = []
+
+ for signal_name, pred_df in signals.items():
+ print(f"\nBacktesting signal: {signal_name}...")
+ pred_series = pred_df.stack()
+ pred_series.index.names = ["datetime", "instrument"]
+ pred_series = pred_series.swaplevel().sort_index()
+ report_df, analysis = self.run_single_backtest(pred_series)
+
+ report_df.to_csv(self.save_dir / f"{signal_name}_curve.csv", index_label="datetime")
+ save_json(self.save_dir / f"{signal_name}_metrics.json", analysis)
+
+ return_df[signal_name] = report_df["cum_return_w_cost"]
+ ex_return_df[signal_name] = report_df["cum_ex_return_w_cost"]
+ if "return" not in bench_df:
+ bench_df["return"] = report_df["cum_bench"]
+
+ summary_rows.append(
+ {
+ "signal": signal_name,
+ "benchmark_annualized_return": analysis["benchmark"]["annualized_return"],
+ "benchmark_information_ratio": analysis["benchmark"]["information_ratio"],
+ "benchmark_max_drawdown": analysis["benchmark"]["max_drawdown"],
+ "excess_annualized_return_without_cost": analysis["excess_return_without_cost"]["annualized_return"],
+ "excess_information_ratio_without_cost": analysis["excess_return_without_cost"]["information_ratio"],
+ "excess_max_drawdown_without_cost": analysis["excess_return_without_cost"]["max_drawdown"],
+ "excess_annualized_return_with_cost": analysis["excess_return_with_cost"]["annualized_return"],
+ "excess_information_ratio_with_cost": analysis["excess_return_with_cost"]["information_ratio"],
+ "excess_max_drawdown_with_cost": analysis["excess_return_with_cost"]["max_drawdown"],
+ }
+ )
+
+ summary_df = pd.DataFrame(summary_rows)
+ summary_df.to_csv(self.save_dir / "backtest_summary.csv", index=False)
+ save_json(self.save_dir / "backtest_summary.json", {"signals": summary_rows})
+
+ fig, axes = plt.subplots(2, 1, figsize=(12, 8), sharex=True)
+ return_df.plot(ax=axes[0], title="Cumulative Return with Cost", grid=True)
+ axes[0].plot(bench_df["return"], label=self.config.instrument.upper(), color="black", linestyle="--")
+ axes[0].legend()
+ axes[0].set_ylabel("Cumulative Return")
+
+ ex_return_df.plot(ax=axes[1], title="Cumulative Excess Return with Cost", grid=True)
+ axes[1].legend()
+ axes[1].set_xlabel("Date")
+ axes[1].set_ylabel("Cumulative Excess Return")
+
+ plt.tight_layout()
+ plot_path = self.save_dir / "backtest_curves.png"
+ plt.savefig(plot_path, dpi=200)
+ if self.show_plot:
+ plt.show()
+ else:
+ plt.close(fig)
+
+ print(f"Saved backtest summary CSV: {self.save_dir / 'backtest_summary.csv'}")
+ print(f"Saved backtest plot: {plot_path}")
+
+
+def load_models(config: dict) -> tuple[KronosTokenizer, Kronos]:
+ device = torch.device(config["device"])
+ print(f"Loading models onto device: {device}...")
+ tokenizer = KronosTokenizer.from_pretrained(config["tokenizer_path"]).to(device).eval()
+ model = Kronos.from_pretrained(config["model_path"]).to(device).eval()
+ return tokenizer, model
+
+
+def collate_fn_for_inference(batch):
+ x, x_stamp, y_stamp, symbols, timestamps = zip(*batch)
+ x_batch = torch.stack(x, dim=0)
+ x_stamp_batch = torch.stack(x_stamp, dim=0)
+ y_stamp_batch = torch.stack(y_stamp, dim=0)
+ return x_batch, x_stamp_batch, y_stamp_batch, list(symbols), list(timestamps)
+
+
+def generate_predictions(run_config: dict, test_data: dict) -> dict[str, pd.DataFrame]:
+ tokenizer, model = load_models(run_config)
+ device = torch.device(run_config["device"])
+
+ dataset = QlibTestDataset(data=test_data, config=run_config["config_obj"])
+ loader = DataLoader(
+ dataset,
+ batch_size=max(1, run_config["batch_size"] // run_config["sample_count"]),
+ shuffle=False,
+ num_workers=run_config["num_workers"],
+ collate_fn=collate_fn_for_inference,
+ )
+
+ results = defaultdict(list)
+ with torch.no_grad():
+ for x, x_stamp, y_stamp, symbols, timestamps in tqdm(loader, desc="Inference"):
+ preds = auto_regressive_inference(
+ tokenizer,
+ model,
+ x.to(device),
+ x_stamp.to(device),
+ y_stamp.to(device),
+ max_context=run_config["max_context"],
+ pred_len=run_config["pred_len"],
+ clip=run_config["clip"],
+ T=run_config["T"],
+ top_k=run_config["top_k"],
+ top_p=run_config["top_p"],
+ sample_count=run_config["sample_count"],
+ )
+ preds = preds[:, -run_config["pred_len"] :, :]
+            # The 'close' price is at index 3 in feature_list
+ last_day_close = x[:, -1, 3].numpy()
+ signals = {
+ "last": preds[:, -1, 3] - last_day_close,
+ "mean": np.mean(preds[:, :, 3], axis=1) - last_day_close,
+ "max": np.max(preds[:, :, 3], axis=1) - last_day_close,
+ "min": np.min(preds[:, :, 3], axis=1) - last_day_close,
+ }
+
+ for idx in range(len(symbols)):
+ for sig_type, sig_values in signals.items():
+ results[sig_type].append((timestamps[idx], symbols[idx], sig_values[idx]))
+
+ print("Post-processing predictions into DataFrames...")
+ prediction_dfs = {}
+ for sig_type, records in results.items():
+ df = pd.DataFrame(records, columns=["datetime", "instrument", "score"])
+ pivot_df = df.pivot_table(index="datetime", columns="instrument", values="score")
+ prediction_dfs[sig_type] = pivot_df.sort_index()
+
+ return prediction_dfs
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Run Kronos inference and Qlib backtesting.")
+ parser.add_argument("--device", type=str, default="cuda:0", help="Device for inference (e.g. cuda:0, mps, cpu).")
+ parser.add_argument("--config-file", type=str, default=None, help="Optional JSON config override file.")
+ parser.add_argument("--tokenizer-path", type=str, default=None, help="Explicit tokenizer path or Hugging Face model ID.")
+ parser.add_argument("--model-path", type=str, default=None, help="Explicit predictor path or Hugging Face model ID.")
+ parser.add_argument("--run-name", type=str, default=None, help="Optional override for the backtest result folder.")
+ parser.add_argument("--save-only", action="store_true", help="Save plots to disk without opening a GUI window.")
+ parser.add_argument("--num-workers", type=int, default=0, help="DataLoader worker count for inference.")
+ args = parser.parse_args()
+
+ overrides = {}
+ if args.run_name:
+ overrides["backtest_save_folder_name"] = args.run_name
+ if args.tokenizer_path:
+ overrides["finetuned_tokenizer_path"] = args.tokenizer_path
+ if args.model_path:
+ overrides["finetuned_predictor_path"] = args.model_path
+
+ base_config = Config(config_file=args.config_file, overrides=overrides if overrides else None)
+ run_config = {
+ "config_obj": base_config,
+ "device": args.device,
+ "data_path": base_config.dataset_path,
+ "result_save_path": base_config.backtest_result_path,
+ "result_name": base_config.backtest_save_folder_name,
+ "tokenizer_path": base_config.finetuned_tokenizer_path,
+ "model_path": base_config.finetuned_predictor_path,
+ "max_context": base_config.max_context,
+ "pred_len": base_config.predict_window,
+ "clip": base_config.clip,
+ "T": base_config.inference_T,
+ "top_k": base_config.inference_top_k,
+ "top_p": base_config.inference_top_p,
+ "sample_count": base_config.inference_sample_count,
+ "batch_size": base_config.backtest_batch_size,
+ "num_workers": args.num_workers,
+ }
+
+ print("--- Running with Configuration ---")
+ for key, val in run_config.items():
+ if key == "config_obj":
+ continue
+ print(f"{key:>20}: {val}")
+ print("-" * 35)
+
+ test_data_path = Path(run_config["data_path"]) / "test_data.pkl"
+ print(f"Loading test data from {test_data_path}...")
+ with test_data_path.open("rb") as handle:
+ test_data = pickle.load(handle)
+ non_empty_symbols = sum(1 for df in test_data.values() if not df.empty)
+ print(f"Loaded {len(test_data)} symbols from test data, {non_empty_symbols} non-empty.")
+ if non_empty_symbols == 0:
+ raise RuntimeError(
+ "test_data.pkl contains no non-empty symbols. Adjust the config ranges or local Qlib dataset coverage."
+ )
+
+ model_preds = generate_predictions(run_config, test_data)
+
+ save_dir = Path(run_config["result_save_path"]) / run_config["result_name"]
+ save_dir.mkdir(parents=True, exist_ok=True)
+ predictions_file = save_dir / "predictions.pkl"
+ print(f"Saving prediction signals to {predictions_file}...")
+ with predictions_file.open("wb") as handle:
+ pickle.dump(model_preds, handle)
+
+ save_json(
+ save_dir / "run_config.json",
+ {
+ "device": args.device,
+ "config_file": args.config_file,
+ "tokenizer_path": run_config["tokenizer_path"],
+ "model_path": run_config["model_path"],
+ "result_save_path": str(save_dir),
+ "save_only": args.save_only,
+ "num_workers": args.num_workers,
+ "config": base_config.to_dict(),
+ "non_empty_symbols": non_empty_symbols,
+ },
+ )
+
+ backtester = QlibBacktest(base_config, save_dir=save_dir, show_plot=not args.save_only)
+ backtester.run_and_plot_results(model_preds)
+
+
+if __name__ == "__main__":
+ main()
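The signal-reshaping step in `run_and_plot_results` (stacking a datetime-by-instrument prediction DataFrame into the `(instrument, datetime)` MultiIndex series that `TopkDropoutStrategy` consumes) can be sketched in isolation. The instruments and dates below are illustrative, not taken from any real dataset:

```python
import pandas as pd

# Hypothetical prediction scores shaped like the pivoted output of
# generate_predictions(): datetime index, instrument columns.
pred_df = pd.DataFrame(
    {"SH600000": [0.01, 0.02], "SZ000001": [-0.01, 0.03]},
    index=pd.to_datetime(["2024-01-02", "2024-01-03"]),
)
pred_df.index.name = "datetime"

# Stack to a Series, name both levels, then reorder to (instrument, datetime),
# the layout Qlib strategies consume as a signal.
pred_series = pred_df.stack()
pred_series.index.names = ["datetime", "instrument"]
pred_series = pred_series.swaplevel().sort_index()

print(pred_series)
```

The `swaplevel().sort_index()` pair matters: Qlib expects the instrument level first and a lexicographically sorted index for efficient slicing.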
diff --git a/finetune/train_predictor.py b/finetune/train_predictor.py
index 1e425872..6eaf96ec 100644
--- a/finetune/train_predictor.py
+++ b/finetune/train_predictor.py
@@ -1,15 +1,19 @@
-import os
-import sys
-import json
-import time
-from time import gmtime, strftime
-import torch.distributed as dist
+import os
+import sys
+import json
+import time
+from time import gmtime, strftime
+import argparse
+import torch.distributed as dist
import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
-import comet_ml
+try:
+ import comet_ml
+except ImportError:
+ comet_ml = None
# Ensure project root is in path
sys.path.append('../')
@@ -26,7 +30,7 @@
)
-def create_dataloaders(config: dict, rank: int, world_size: int):
+def create_dataloaders(config: dict, rank: int, world_size: int):
"""
Creates and returns distributed dataloaders for training and validation.
@@ -39,8 +43,9 @@ def create_dataloaders(config: dict, rank: int, world_size: int):
tuple: (train_loader, val_loader, train_dataset, valid_dataset).
"""
print(f"[Rank {rank}] Creating distributed dataloaders...")
- train_dataset = QlibDataset('train')
- valid_dataset = QlibDataset('val')
+ dataset_config = Config(overrides=config)
+ train_dataset = QlibDataset('train', config=dataset_config)
+ valid_dataset = QlibDataset('val', config=dataset_config)
print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
@@ -196,16 +201,18 @@ def main(config: dict):
'save_directory': save_dir,
'world_size': world_size,
}
- if config['use_comet']:
- comet_logger = comet_ml.Experiment(
- api_key=config['comet_config']['api_key'],
- project_name=config['comet_config']['project_name'],
- workspace=config['comet_config']['workspace'],
- )
- comet_logger.add_tag(config['comet_tag'])
- comet_logger.set_name(config['comet_name'])
- comet_logger.log_parameters(config)
- print("Comet Logger Initialized.")
+ if config['use_comet'] and comet_ml is not None:
+ comet_logger = comet_ml.Experiment(
+ api_key=config['comet_config']['api_key'],
+ project_name=config['comet_config']['project_name'],
+ workspace=config['comet_config']['workspace'],
+ )
+ comet_logger.add_tag(config['comet_tag'])
+ comet_logger.set_name(config['comet_name'])
+ comet_logger.log_parameters(config)
+ print("Comet Logger Initialized.")
+ elif config['use_comet']:
+ print("comet_ml is not installed; continuing without experiment logging.")
dist.barrier()
@@ -235,10 +242,14 @@ def main(config: dict):
cleanup_ddp()
-if __name__ == '__main__':
- # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py
- if "WORLD_SIZE" not in os.environ:
- raise RuntimeError("This script must be launched with `torchrun`.")
-
- config_instance = Config()
- main(config_instance.__dict__)
+if __name__ == '__main__':
+ # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_predictor.py
+ if "WORLD_SIZE" not in os.environ:
+ raise RuntimeError("This script must be launched with `torchrun`.")
+
+ parser = argparse.ArgumentParser(description="Train the Kronos predictor with torchrun.")
+ parser.add_argument("--config-file", type=str, default=None, help="Optional JSON config override file.")
+ args = parser.parse_args()
+
+ config_instance = Config(config_file=args.config_file)
+ main(config_instance.to_dict())
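The optional-import guard introduced here for `comet_ml` is a general pattern: bind the module name to `None` on `ImportError` so call sites can branch on availability instead of crashing at import time. A minimal sketch, using a deliberately nonexistent stand-in module rather than `comet_ml` itself:

```python
# Optional-dependency guard: optional_logger is None when the package is
# absent, so downstream code can degrade gracefully.
try:
    import nonexistent_logging_lib as optional_logger  # stand-in for comet_ml
except ImportError:
    optional_logger = None


def log_params(params: dict) -> str:
    """Log to the optional backend when present, else fall back to stdout."""
    if optional_logger is not None:
        optional_logger.log_parameters(params)
        return "remote"
    print(f"(no experiment logger installed) params={params}")
    return "stdout"


backend = log_params({"lr": 1e-4})
```

This mirrors the `config['use_comet'] and comet_ml is not None` check in the training scripts: the feature flag and the availability check are evaluated together, and the fallback path only prints a notice.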
diff --git a/finetune/train_tokenizer.py b/finetune/train_tokenizer.py
index 2fe28cf1..2b22cff3 100644
--- a/finetune/train_tokenizer.py
+++ b/finetune/train_tokenizer.py
@@ -12,7 +12,10 @@
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
-import comet_ml
+try:
+ import comet_ml
+except ImportError:
+ comet_ml = None
# Ensure project root is in path
sys.path.append("../")
@@ -29,7 +32,7 @@
)
-def create_dataloaders(config: dict, rank: int, world_size: int):
+def create_dataloaders(config: dict, rank: int, world_size: int):
"""
Creates and returns distributed dataloaders for training and validation.
@@ -42,8 +45,9 @@ def create_dataloaders(config: dict, rank: int, world_size: int):
tuple: A tuple containing (train_loader, val_loader, train_dataset, valid_dataset).
"""
print(f"[Rank {rank}] Creating distributed dataloaders...")
- train_dataset = QlibDataset('train')
- valid_dataset = QlibDataset('val')
+ dataset_config = Config(overrides=config)
+ train_dataset = QlibDataset('train', config=dataset_config)
+ valid_dataset = QlibDataset('val', config=dataset_config)
print(f"[Rank {rank}] Train dataset size: {len(train_dataset)}, Validation dataset size: {len(valid_dataset)}")
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank, shuffle=True)
@@ -234,16 +238,18 @@ def main(config: dict):
'save_directory': save_dir,
'world_size': world_size,
}
- if config['use_comet']:
- comet_logger = comet_ml.Experiment(
- api_key=config['comet_config']['api_key'],
- project_name=config['comet_config']['project_name'],
- workspace=config['comet_config']['workspace'],
- )
- comet_logger.add_tag(config['comet_tag'])
- comet_logger.set_name(config['comet_name'])
- comet_logger.log_parameters(config)
- print("Comet Logger Initialized.")
+ if config['use_comet'] and comet_ml is not None:
+ comet_logger = comet_ml.Experiment(
+ api_key=config['comet_config']['api_key'],
+ project_name=config['comet_config']['project_name'],
+ workspace=config['comet_config']['workspace'],
+ )
+ comet_logger.add_tag(config['comet_tag'])
+ comet_logger.set_name(config['comet_name'])
+ comet_logger.log_parameters(config)
+ print("Comet Logger Initialized.")
+ elif config['use_comet']:
+ print("comet_ml is not installed; continuing without experiment logging.")
dist.barrier() # Ensure save directory is created before proceeding
@@ -272,10 +278,14 @@ def main(config: dict):
cleanup_ddp()
-if __name__ == '__main__':
- # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py
- if "WORLD_SIZE" not in os.environ:
- raise RuntimeError("This script must be launched with `torchrun`.")
-
- config_instance = Config()
- main(config_instance.__dict__)
+if __name__ == '__main__':
+ # Usage: torchrun --standalone --nproc_per_node=NUM_GPUS train_tokenizer.py
+ if "WORLD_SIZE" not in os.environ:
+ raise RuntimeError("This script must be launched with `torchrun`.")
+
+ parser = argparse.ArgumentParser(description="Train the Kronos tokenizer with torchrun.")
+ parser.add_argument("--config-file", type=str, default=None, help="Optional JSON config override file.")
+ args = parser.parse_args()
+
+ config_instance = Config(config_file=args.config_file)
+ main(config_instance.to_dict())
diff --git a/webui/README.md b/webui/README.md
index 571d93eb..a3ac2243 100644
--- a/webui/README.md
+++ b/webui/README.md
@@ -14,6 +14,19 @@ Web user interface for Kronos financial prediction model, providing intuitive gr
## 🚀 Quick Start
+Install the Web UI dependencies from the `webui` directory before starting:
+
+```bash
+pip install -r requirements.txt
+```
+
+If you use a shared environment from the repository root, install the root requirements first and then install the Web UI extras:
+
+```bash
+pip install -r ../requirements.txt
+pip install -r requirements.txt
+```
+
### Method 1: Start with Python script
```bash
cd webui
@@ -112,7 +125,7 @@ The system automatically provides comparison analysis between prediction results
### Common Issues
1. **Port occupied**: Modify port number in app.py
-2. **Missing dependencies**: Run `pip install -r requirements.txt`
+2. **Missing dependencies**: Run `pip install -r requirements.txt` from the `webui` directory after installing the root project requirements
3. **Model loading failed**: Check network connection and model ID
4. **Data format error**: Ensure data column names and format are correct
diff --git a/webui/app.py b/webui/app.py
index d240a372..886e02d9 100644
--- a/webui/app.py
+++ b/webui/app.py
@@ -59,10 +59,17 @@
def load_data_files():
"""Scan data directory and return available data files"""
- data_dir = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), 'data')
data_files = []
-
- if os.path.exists(data_dir):
+
+ project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+ candidate_dirs = [
+ os.path.join(project_root, 'data'),
+ os.path.join(project_root, 'examples', 'data'),
+ ]
+
+ for data_dir in candidate_dirs:
+ if not os.path.exists(data_dir):
+ continue
for file in os.listdir(data_dir):
if file.endswith(('.csv', '.feather')):
file_path = os.path.join(data_dir, file)
@@ -70,6 +77,7 @@ def load_data_files():
data_files.append({
'name': file,
'path': file_path,
+ 'source_dir': os.path.relpath(data_dir, project_root),
'size': f"{file_size / 1024:.1f} KB" if file_size < 1024*1024 else f"{file_size / (1024*1024):.1f} MB"
})
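The revised `load_data_files` scans an ordered list of candidate directories, skips missing ones, and records each file's `source_dir` relative to the project root. The core loop can be sketched against a throwaway directory tree (paths and file contents below are illustrative):

```python
import os
import tempfile

# Build a throwaway project tree where only examples/data exists, mirroring
# the two candidate directories the revised load_data_files() scans.
root = tempfile.mkdtemp()
os.makedirs(os.path.join(root, "examples", "data"))
with open(os.path.join(root, "examples", "data", "sample.csv"), "w") as fh:
    fh.write("open,close\n1,2\n")

candidate_dirs = [os.path.join(root, "data"), os.path.join(root, "examples", "data")]
found = []
for data_dir in candidate_dirs:
    if not os.path.exists(data_dir):
        continue  # root-level data/ is absent in this sketch, so it is skipped
    for name in os.listdir(data_dir):
        if name.endswith((".csv", ".feather")):
            found.append({"name": name, "source_dir": os.path.relpath(data_dir, root)})

print(found)
```

Recording `source_dir` lets the UI disambiguate files with the same name coming from different candidate directories.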