82 changes: 82 additions & 0 deletions EXPERIMENTS.md
@@ -0,0 +1,82 @@
# Experiments Overview

This repository currently contains four experiment categories.

## 1. Prediction Demos

Location: `examples/`

Purpose:
- Verify that pretrained Kronos models can run end-to-end inference.
- Compare model sizes under the same inputs.

Typical inputs:
- Local CSV K-line data
- Optional online A-share data from `akshare`

Typical outputs:
- Forecast CSV
- Plot PNG
- Summary JSON

Ground-truth availability:
- `prediction_example.py`: no future truth, visualization only
- `prediction_wo_vol_example.py`: no future truth, visualization only
- `prediction_batch_example.py`: no future truth, visualization only
- `prediction_cn_markets_day.py --mode eval`: includes future truth and error metrics
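
The `eval` mode's error metrics can be sketched with plain pandas/numpy. The frames and the `close` column below are illustrative assumptions, not the script's actual variables:

```python
import numpy as np
import pandas as pd

# Hypothetical frames: a forecast window vs. the known historical truth.
forecast = pd.DataFrame({"close": [10.0, 10.2, 10.1]})
truth = pd.DataFrame({"close": [10.1, 10.0, 10.3]})

err = forecast["close"].to_numpy() - truth["close"].to_numpy()
metrics = {
    "mae": float(np.mean(np.abs(err))),  # mean absolute error
    "mse": float(np.mean(err ** 2)),     # mean squared error
}
```

In `eval` mode the "future" window is simply held out of the model's input, so both frames exist and metrics like these can be saved alongside the plots.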

## 2. Regression Tests

Location: `tests/test_kronos_regression.py`

Purpose:
- Ensure code changes do not silently drift model outputs.
- Check model outputs against stored regression fixtures, using exact-match and MSE-threshold comparisons pinned to fixed model revisions.

Typical outputs:
- `pytest` pass/fail result
- Printed absolute/relative differences and MSE summary in the terminal
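
The drift check follows a common pattern; the helper below is a minimal sketch of that pattern, with an assumed tolerance rather than the suite's actual thresholds:

```python
import numpy as np

MSE_THRESHOLD = 1e-6  # assumed tolerance; the real tests pin their own values


def assert_no_drift(current: np.ndarray, fixture: np.ndarray,
                    threshold: float = MSE_THRESHOLD) -> float:
    """Fail when current outputs drift from the stored fixture beyond threshold."""
    mse = float(np.mean((current - fixture) ** 2))
    if mse > threshold:
        raise AssertionError(f"output drift detected: MSE {mse:.3e} > {threshold:.3e}")
    return mse


# Identical arrays pass with zero MSE.
drift = assert_no_drift(np.ones(8), np.ones(8))
```

Pinning the model revision matters here: without it, an upstream checkpoint update would be indistinguishable from a local code regression.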

## 3. Fine-Tuning Pipelines

Locations:
- `finetune/`
- `finetune_csv/`

Purpose:
- Adapt Kronos to domain-specific datasets.
- Support either Qlib-based A-share workflows or direct CSV-based training.

Notes:
- The `finetune/` training scripts are currently written for CUDA + NCCL + DDP.
- The training scripts now treat `comet_ml` as optional.

## 4. Qlib Backtesting

Location: `finetune/qlib_test.py`

Purpose:
- Convert model forecasts into cross-sectional stock-selection signals.
- Evaluate those signals with Qlib's `TopkDropoutStrategy`.

Typical inputs:
- Processed `train/val/test` pickle datasets from `qlib_data_preprocess.py`
- A tokenizer/model pair, specified either as local checkpoints or as explicit Hugging Face IDs

Typical outputs:
- `predictions.pkl`
- Per-signal curve CSV
- Per-signal metrics JSON
- Backtest summary CSV/JSON
- Backtest plot PNG

Ground-truth availability:
- Yes, via historical replay using the test split
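
The signal-construction idea can be sketched without Qlib: rank per-date scores cross-sectionally and keep the top k names. This is a simplified stand-in for `TopkDropoutStrategy`, which additionally manages turnover via its drop rule; the scores below are made up:

```python
import pandas as pd

# Hypothetical per-(date, instrument) prediction scores.
scores = pd.DataFrame({
    "date": ["2024-01-02"] * 4,
    "instrument": ["SH600000", "SH600004", "SZ000001", "SZ000002"],
    "score": [0.8, 0.1, 0.5, 0.3],
})

topk = 2
picks = (
    scores.sort_values(["date", "score"], ascending=[True, False])
    .groupby("date")
    .head(topk)["instrument"]  # keep the k highest-scoring names per date
    .tolist()
)
```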

## Recommended Validation Sequence

1. Run `examples/` to verify inference and output formats.
2. Run `tests/test_kronos_regression.py` to detect output drift.
3. Start `webui` to verify the interactive inference path.
4. Run `finetune/qlib_test.py` with an explicit config file for smoke backtesting.
5. Compare `Kronos-small` vs `Kronos-base` on the same prediction and backtest setups.
12 changes: 8 additions & 4 deletions README.md
@@ -211,10 +211,14 @@ Running this script will generate a plot comparing the ground truth data against
<img src="figures/prediction_example.png" alt="Forecast Example" align="center" width="600px" />
</p>

Additionally, we provide a script that makes predictions without Volume and Amount data, which can be found in [`examples/prediction_wo_vol_example.py`](examples/prediction_wo_vol_example.py).

For a concise guide to all example scripts, recommended run order, Apple Silicon device selection, and optional `akshare` setup, see [`examples/README.md`](examples/README.md).

For an overview of the repository's prediction demos, regression tests, finetuning pipelines, and Qlib backtesting workflows, see [`EXPERIMENTS.md`](EXPERIMENTS.md).


## 🔧 Finetuning on Your Own Data (A-Share Market Example)

We provide a complete pipeline for finetuning Kronos on your own datasets. As an example, we demonstrate how to use [Qlib](https://github.com/microsoft/qlib) to prepare data from the Chinese A-share market and conduct a simple backtest.

62 changes: 62 additions & 0 deletions examples/README.md
@@ -0,0 +1,62 @@
# Example Scripts

All scripts in this directory can be run from the repository root and save their outputs to `examples/outputs/`.

## Recommended Order

1. `python examples/prediction_example.py --model small`
2. `python examples/prediction_wo_vol_example.py --model small`
3. `python examples/prediction_batch_example.py --model small`
4. `python examples/prediction_cn_markets_day.py --model small --mode forward --symbol 000001`
5. `python examples/prediction_cn_markets_day.py --model small --mode eval --symbol 000001`

## Model Selection

The examples support:

- `--model small`
- `--model base`

This keeps the input data and evaluation mode fixed while allowing direct `Kronos-small` vs `Kronos-base` comparisons.
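
A shared flag like this is typically a one-liner with `argparse`; the function name below is illustrative, not the scripts' actual helper:

```python
import argparse


def parse_model_flag(argv: list[str]) -> str:
    """Parse the --model flag shared by the example scripts."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", choices=["small", "base"], default="small")
    return parser.parse_args(argv).model


model_key = parse_model_flag(["--model", "base"])
```

Restricting `choices` to `small`/`base` makes an unsupported model name fail at argument parsing rather than at model-download time.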

## Script Semantics

- `prediction_example.py`
Runs one OHLCV forecast on a single historical window and saves CSV, PNG, and summary JSON.
- `prediction_wo_vol_example.py`
Runs one forecast without `volume`/`amount` inputs and saves CSV, PNG, and summary JSON.
- `prediction_batch_example.py`
Demonstrates batch inference on multiple sequential time windows from the same instrument.
It is not a multi-asset cross-sectional example.
By default it saves CSV for every window and one representative PNG.
- `prediction_cn_markets_day.py`
Supports two explicit modes:
- `forward`: future prediction only, so there is no ground-truth window yet.
- `eval`: historical replay mode with known future truth and saved error metrics.

## Outputs

Each run saves structured outputs in `examples/outputs/`:

- Forecast CSV files
- Plot PNG files
- Summary JSON files describing the model, device, time ranges, and output paths
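
A summary JSON of that shape could be produced as follows; the exact field names here are assumptions for illustration, not the scripts' actual schema:

```python
import json

# Hypothetical summary payload mirroring the fields described above.
summary = {
    "model": "Kronos-small",
    "device": "mps",
    "input_range": ["2024-01-02", "2024-06-28"],
    "forecast_csv": "examples/outputs/forecast.csv",
    "plot_png": "examples/outputs/forecast.png",
}
text = json.dumps(summary, indent=2, ensure_ascii=False)
```

Recording the device and time ranges in the summary makes runs on different machines directly comparable after the fact.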

## Apple Silicon Devices

The scripts auto-select a compute device in this order:

1. `cuda:0`
2. `mps`
3. `cpu`

On Apple Silicon Macs, `mps` is used when PyTorch reports that Metal is available.

## Dependencies

- Local CSV examples only need the root `requirements.txt`.
- `prediction_cn_markets_day.py` also requires `akshare` and network access:

```bash
uv pip install --python .venv/bin/python akshare
```
82 changes: 82 additions & 0 deletions examples/common.py
@@ -0,0 +1,82 @@
import json
from pathlib import Path
import sys
from typing import Any

import matplotlib
import pandas as pd
import torch

matplotlib.use("Agg")

EXAMPLES_ROOT = Path(__file__).resolve().parent
PROJECT_ROOT = EXAMPLES_ROOT.parent
DATA_ROOT = EXAMPLES_ROOT / "data"
OUTPUT_ROOT = EXAMPLES_ROOT / "outputs"
OUTPUT_ROOT.mkdir(exist_ok=True)

MODEL_SPECS = {
    "small": {
        "name": "Kronos-small",
        "model_id": "NeoQuasar/Kronos-small",
        "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-base",
        "max_context": 512,
    },
    "base": {
        "name": "Kronos-base",
        "model_id": "NeoQuasar/Kronos-base",
        "tokenizer_id": "NeoQuasar/Kronos-Tokenizer-base",
        "max_context": 512,
    },
}

if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))


def data_path(filename: str) -> Path:
    return DATA_ROOT / filename


def output_path(filename: str) -> Path:
    return OUTPUT_ROOT / filename


def detect_device() -> str:
    if torch.cuda.is_available():
        return "cuda:0"
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def get_model_spec(model_key: str) -> dict[str, Any]:
    if model_key not in MODEL_SPECS:
        raise ValueError(f"Unsupported model '{model_key}'. Expected one of: {sorted(MODEL_SPECS)}")
    return MODEL_SPECS[model_key]


def save_json(path: Path, payload: dict[str, Any]) -> None:
    with path.open("w", encoding="utf-8") as handle:
        json.dump(to_jsonable(payload), handle, indent=2, ensure_ascii=False)


def to_jsonable(value: Any) -> Any:
    if isinstance(value, Path):
        try:
            return str(value.relative_to(PROJECT_ROOT))
        except ValueError:
            return str(value)
    if isinstance(value, pd.Timestamp):
        return value.isoformat()
    if isinstance(value, pd.Series):
        return [to_jsonable(item) for item in value.tolist()]
    if isinstance(value, pd.DataFrame):
        return [{key: to_jsonable(item) for key, item in row.items()} for row in value.to_dict(orient="records")]
    if isinstance(value, dict):
        return {str(key): to_jsonable(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [to_jsonable(item) for item in value]
    if hasattr(value, "item"):
        return value.item()
    return value