2 changes: 1 addition & 1 deletion CONTRIBUTING.md
@@ -12,7 +12,7 @@ We encourage contributions in the following areas:
- 🔍 **Testing**: Adding realistic traces and increasing coverage for components that are not thoroughly tested.
- 🚧 **Engineering Improvements**: Enhancing log formatting, improving CLI usability, and performing code cleanup.

-**For specific tasks and upcoming features where we need assistance, please see our [ROADMAP (TBD)](./ROADMAP.md) for planned directions and priorities.**
+**For specific tasks and upcoming features where we need assistance, please see our [ROADMAP](./ROADMAP.md) for planned directions and priorities.**

## ⚠️ Important Information for Contributors

…
61 changes: 41 additions & 20 deletions README.md
@@ -3,53 +3,74 @@
<picture>
<img alt="TrainCheck logo" width="55%" src="./docs/assets/images/traincheck_logo.png">
</picture>
-<h1>Silent Error Detection for Deep Learning Training</h1>
+<h1>TrainCheck: Training with Confidence</h1>

[![format and types](https://github.com/OrderLab/traincheck/actions/workflows/pre-commit-checks.yml/badge.svg)](https://github.com/OrderLab/traincheck/actions/workflows/pre-commit-checks.yml)
[![Chat on Discord](https://img.shields.io/discord/1362661016760090736?label=Discord&logo=discord&style=flat)](https://discord.gg/ZvYewjsQ9D)

</div>

-> ***Training with Confidence***

-TrainCheck is a lightweight, extensible tool for runtime monitoring of “silent” bugs in deep‑learning training pipelines. Instead of waiting for a crash or a bad model, TrainCheck:
-1. **Automatically instruments** your existing training scripts (e.g., from [pytorch/examples](https://github.com/pytorch/examples) or [huggingface/transformers/examples](https://github.com/huggingface/transformers/tree/main/examples)), inserting tracing hooks with minimal code changes.
-2. **Learns precise invariants**–precise properties that should hold during training across API calls and model updates-by analyzing executions of known-good runs.
-3. **Catches silent issues early**–by checking invariants on new or modified training jobs, alerting you immediately if something didn't happen as expected (e.g., model weight inconsistency, mixed precision not applied successfully, unexpected tensor shapes). On violation, TrainCheck flags the point of divergence—so users can diagnose silent issues before they derail your model.
+**TrainCheck** is a lightweight tool for proactively catching **silent errors** in deep learning training runs. It detects correctness issues early, whether they stem from code bugs or faulty hardware, and pinpoints their root cause.

TrainCheck has detected silent errors in a wide range of real-world training scenarios, from large-scale LLM pretraining (such as BLOOM-176B) to small-scale tutorial runs by deep learning beginners.

📌 For a list of successful cases, see: TODO

## What It Does

TrainCheck uses **training invariants**, which are semantic rules that describe expected behavior during training, to detect bugs as they happen. These invariants can be extracted from any correct run, including those produced by official examples and tutorials. There is no need to curate inputs or write manual assertions.
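
For intuition, the sketch below hand-codes the kind of property one invariant can capture. The helper is illustrative only; it is not how TrainCheck represents invariants internally:

```python
import torch

# Hypothetical hand-written analogue of a learned invariant:
# "optimizer.step() must actually change the model's parameters."
def check_step_updates_weights(model, optimizer, x, y, loss_fn):
    before = [p.detach().clone() for p in model.parameters()]
    optimizer.zero_grad()
    loss_fn(model(x), y).backward()
    optimizer.step()
    changed = any(
        not torch.equal(old, p.detach())
        for old, p in zip(before, model.parameters())
    )
    assert changed, "invariant violated: optimizer.step() changed no weights"
```

TrainCheck learns and checks properties of this kind automatically, so you never have to write such assertions yourself.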

TrainCheck performs three core functions:

1. **Instruments your training code**
Inserts lightweight tracing into existing scripts (such as [pytorch/examples](https://github.com/pytorch/examples) or [transformers](https://github.com/huggingface/transformers/tree/main/examples)) with minimal code changes.

2. **Learns invariants from correct runs**
Discovers expected relationships across APIs, tensors, and training steps to build a model of normal behavior.

3. **Checks new or modified runs**
Validates behavior against the learned invariants and flags silent errors, such as missing gradient clipping, weight desynchronization, or broken mixed precision, right when they occur (a manual spot-check for one such error is sketched below).
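
For a concrete sense of one such silent error, this hedged sketch spot-checks weight desynchronization across data-parallel ranks by hand; the helper is illustrative and not part of TrainCheck's API:

```python
import torch
import torch.distributed as dist

def assert_replicas_in_sync(model):
    """Raise if any parameter has drifted from rank 0's copy (call on all ranks)."""
    for name, param in model.named_parameters():
        reference = param.detach().clone()
        dist.broadcast(reference, src=0)  # every rank receives rank 0's tensor
        if not torch.equal(reference, param.detach()):
            raise RuntimeError(
                f"rank {dist.get_rank()}: parameter {name!r} diverged from rank 0"
            )
```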

This picture illustrates the TrainCheck workflow:

![Workflow](docs/assets/images/workflow.png)

Under the hood, TrainCheck decomposes into three CLI tools (an end-to-end sketch follows the list):
- **Instrumentor** (`traincheck-collect`)
Wraps target training programs with lightweight tracing logic. It produces an instrumented version of the target program that logs API calls and model states without altering training semantics.
- **Inference Engine** (`traincheck-infer`)
-Consumes one or more trace logs from successful runs to infer low‑level invariants.
+Consumes one or more trace logs from successful runs to infer training invariants.
- **Checker** (`traincheck-check`)
Runs alongside or after new training jobs to verify that each recorded event satisfies the inferred invariants.
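
A minimal sketch of driving the full pipeline from Python; the positional arguments are hypothetical placeholders, so consult each tool's `--help` for the actual interface:

```python
import subprocess

# 1. Instrument the target script and collect a trace of a known-good run.
subprocess.run(["traincheck-collect", "train.py"], check=True)

# 2. Infer training invariants from the collected trace(s).
subprocess.run(["traincheck-infer", "traces/"], check=True)

# 3. Check a new run's trace against the inferred invariants.
subprocess.run(["traincheck-check", "new_traces/", "invariants.json"], check=True)
```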

-## Status
-
-TrainCheck is under active development. Features may be incomplete and the documentation is evolving—if you give it a try, please join our 💬 [Discord server](https://discord.gg/VwxpJDvB) or file a GitHub issue for support. Currently, the **Checker** operates in a semi‑online mode: you invoke it against the live, growing trace output to catch silent bugs as they appear. Fully automatic monitoring is on the roadmap, and we welcome feedback and contributions from early adopters.

-## Try TrainCheck
+## 🔥 Try TrainCheck

1. **Install**
Follow the [Installation Guide](./docs/installation-guide.md) to get TrainCheck set up on your machine.

2. **Explore**
-Work through our "[5‑Minute Experience with TrainCheck](./docs/5-min-tutorial.md)" tutorial. You’ll learn how to:
+Work through the [5‑Minute Experience with TrainCheck](./docs/5-min-tutorial.md) tutorial. You’ll learn how to:
- Instrument a training script and collect a trace
-   - Automatically infer low‑level invariants
-   - Run the Checker in semi‑online mode to uncover silent bugs
+   - Automatically infer invariants
+   - Uncover silent bugs in the training script

## Documentation

-Please visit [TrainCheck Technical Doc](./docs/technical-doc.md).
+- **[Installation Guide](./docs/installation-guide.md)**
+- **[Usage Guide: Scenarios and Limitations](./docs/usage-guide.md)**
+- **[TrainCheck Technical Doc](./docs/technical-doc.md)**
+- **[TrainCheck Dev Roadmap](./ROADMAP.md)**

## Status

TrainCheck is under active development. Please join our 💬 [Discord server](https://discord.gg/VwxpJDvB) or file a GitHub issue for support.
We welcome feedback and contributions from early adopters.

## Contributing

We welcome and value any contributions and collaborations. Please check out [Contributing to TrainCheck](./CONTRIBUTING.md) for how to get involved.

## License

TrainCheck is licensed under the [Apache License 2.0](./LICENSE).

## Citation

If TrainCheck is relevant to your work, please cite our paper:
…
28 changes: 28 additions & 0 deletions ROADMAP.md
@@ -0,0 +1,28 @@
# TrainCheck Roadmap

This document outlines planned directions for the TrainCheck project. The roadmap is aspirational and subject to change as we gather feedback from the community.

## Short Term

- **Online monitoring** – integrate the checker directly into the collection process so violations are reported immediately during training.
- **Pre-inferred invariant library** – ship a curated set of invariants for common PyTorch and HuggingFace workflows to reduce the need for manual inference.
- **Improved distributed support** – better handling of multi-GPU and multi-node runs, including tracing of distributed backends.
- **High-quality invariants** – publish well-tested invariants for PyTorch, DeepSpeed, and Transformers out of the box.
- **Demo assets** – publish a short demo video and GIFs illustrating the TrainCheck workflow.
- **Expanded documentation** – add guidance on choosing reference runs and diagnosing issues, plus deeper technical docs.
- **Stability fixes and tests** – resolve proxy dump bugs and add end-to-end tests for the full instrumentation → inference → checking pipeline.
- **Call graph updates** – document the call-graph generation process and keep graphs in sync with recent PyTorch versions.
- **Repository cleanup** – remove obsolete files and artifacts.

## Medium Term

- **Extensible instrumentation** – allow plugins for third-party libraries and custom frameworks.
- **Smarter invariant filtering** – tooling to help users manage large numbers of invariants and suppress benign ones.
- **Performance improvements** – explore parallel inference and more efficient trace storage formats.

## Long Term

- **Cross-framework support** – expand beyond PyTorch to additional deep learning frameworks.
- **Automated root-cause analysis** – provide hints or suggested fixes when a violation is detected.

We welcome contributions in any of these areas. If you have ideas or want to help, please check the [CONTRIBUTING guide](./CONTRIBUTING.md) and open an issue to discuss!
Binary file modified docs/assets/images/workflow.png
40 changes: 40 additions & 0 deletions docs/usage-guide.md
@@ -0,0 +1,40 @@
# 🧪 TrainCheck: Usage Guide

TrainCheck helps detect and diagnose silent errors in deep learning training runs—issues that don't crash your code but silently break correctness.

## 🚀 Quick Start

Check out the [5-minute guide](./5-min-tutorial.md) for a minimal working example.

## ✅ Common Use Cases

TrainCheck is useful when your training process doesn’t converge, behaves inconsistently, or silently fails. It can help you:

- **Monitor** long-running training jobs and catch issues early
- **Debug** finished runs and pinpoint where things went wrong
- **Sanity-check** new pipelines, code changes, or infrastructure upgrades

TrainCheck detects a range of correctness issues—like misused APIs, incorrect training logic, or hardware faults—without requiring labels or modifications to your training code.

**While TrainCheck focuses on correctness, it’s also useful for *ruling out bugs* so you can focus on algorithm design with confidence.**

## 🧠 Tips for Effective Use

1. **Use short runs to reduce overhead.**
If your hardware is stable, you can validate just the beginning of training. Use smaller models and fewer iterations to speed up turnaround time.

2. **Choose good reference runs for inference.**
- If you have a past run of the same code that worked well, just use that.
- You can also use small-scale example pipelines that cover different features of the framework (e.g., various optimizers, mixed precision, optional flags).
- If you're debugging a new or niche feature with limited history, try using the official example as a reference. Even if the example is not bug-free, invariant violations can still highlight behavioral differences between your run and the example, helping you debug faster.

3. **Minimize scale when collecting traces.**
- Shrink the pipeline by using a smaller model, running for only ~10 iterations, and using the minimal necessary compute setup (e.g., 2 nodes for distributed training). A minimal sketch follows below.
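
As one example, a shrunken reference run might look like this (the model, data, and iteration count are arbitrary stand-ins, not a recommendation):

```python
import torch

# Tiny stand-in pipeline: a small model, random data, and ~10 iterations
# give TrainCheck enough training activity to collect a useful trace.
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

for step in range(10):
    x, y = torch.randn(4, 8), torch.randn(4, 2)
    optimizer.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()
    optimizer.step()
```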


## 🚧 Current Limitations

- **Eager mode only.** The TrainCheck instrumentor currently works only in PyTorch eager mode. Features like `torch.compile` are disabled during instrumentation.

- **Not fully real-time (yet).** Invariant checking is semi-online. Full real-time support is planned but not yet available.

Empty file added tests/__init__.py
23 changes: 23 additions & 0 deletions tests/test_utils.py
@@ -0,0 +1,23 @@
import importlib.util
from pathlib import Path

import torch

# Load utils module without triggering traincheck package __init__
UTILS_PATH = Path(__file__).resolve().parents[1] / "traincheck" / "utils.py"
_spec = importlib.util.spec_from_file_location("tc_utils", UTILS_PATH)
_utils = importlib.util.module_from_spec(_spec)
_spec.loader.exec_module(_utils)

typename = _utils.typename


def test_typename_builtin_function():
assert typename(len) == "len"
assert typename(print) == "print"

def test_typename_tensor_and_parameter():
t = torch.tensor([1.0])
assert typename(t) == t.type()
p = torch.nn.Parameter(torch.zeros(1))
assert typename(p) == "torch.nn.Parameter"
4 changes: 2 additions & 2 deletions traincheck/instrumentor/tracer.py
@@ -244,7 +244,7 @@ def find_proxy_in_args(args):
dump_trace_API(pre_record)

if handle_proxy and trigger_proxy_state_dump:
"""Mimiking the behavior the observer wrapper: pre-observe"""
"""Mimicking the behavior the observer wrapper: pre-observe"""
get_global_registry().dump_only_modified(
dump_loc=original_function_name, dump_config=proxy_state_dump_config
)
@@ -275,7 +275,7 @@ def find_proxy_in_args(args):
ORIG_EXIT_PERF_TIME = time.perf_counter()

if handle_proxy and trigger_proxy_state_dump:
"""Mimiking the behavior the observer wrapper: post-observe"""
"""Mimicking the behavior the observer wrapper: post-observe"""
get_global_registry().dump_only_modified(
dump_loc=original_function_name, dump_config=proxy_state_dump_config
)
6 changes: 3 additions & 3 deletions traincheck/utils.py
@@ -44,7 +44,7 @@ def typename(o, is_runtime=False):
if isinstance(prefix, ModuleSpec):
# handle the case when prefix is a ModuleSpec object
prefix = prefix.name
-if prefix in ["buitins", "__builtin__", None]:
+if prefix in ["builtins", "__builtin__", None]:
prefix = ""
is_class_name_qualname = True
last_name = safe_getattr(o, "__qualname__", "")
@@ -86,7 +86,7 @@ def handle_excepthook(typ, message, stack):
stack_info = traceback.StackSummary.extract(
traceback.walk_tb(stack), capture_locals=True
).format()
-logger.critical("An exception occured: %s: %s.", typ, message)
+logger.critical("An exception occurred: %s: %s.", typ, message)
for i in stack_info:
logger.critical(i.encode().decode("unicode-escape"))

@@ -109,7 +109,7 @@ def thread_excepthook(args):
stack_info = traceback.StackSummary.extract(
traceback.walk_tb(exc_traceback), capture_locals=True
).format()
-logger.critical("An exception occured: %s: %s.", exc_type, exc_value)
+logger.critical("An exception occurred: %s: %s.", exc_type, exc_value)
for i in stack_info:
logger.critical(i.encode().decode("unicode-escape"))
