
BERTrainer

BERTrainer is designed to make your life easier when training text classification models.

If you can handle Axolotl, you can handle BERTrainer too.

Features

  • Supports BERT, DeBERTa, RoBERTa (probably more)
  • YAML configs, yay!
  • CUDA, MPS, and CPU are supported
  • Training and inference!
  • Weights & Biases Sweeps 🙌
  • Multiple datasets in one training run, shuffled together

Installation

To get started, clone the repository and install it with pip:

git clone https://github.com/kubernetes-bad/BERTrainer
cd BERTrainer
pip3 install -e .

Or use Docker:

docker run -it \
  -e WANDB_API_KEY=abcdef00008888 \
  -v /path/to/config.yaml:/config.yaml \
  -v /path/to/output/:/output \
  -v ~/.cache/huggingface/:/root/.cache/huggingface/ \
  ghcr.io/kubernetes-bad/bertrainer /config.yaml
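
The Docker command above mounts a config file into the container. The real schema lives in the repository's example configurations; the sketch below is purely hypothetical (every key name and value here is an assumption) and only illustrates the general shape of a YAML config:

```yaml
# Hypothetical config sketch -- key names are assumptions,
# consult the repository's example configs for the real schema.
base_model: microsoft/deberta-v3-base
output_dir: /output

datasets:
  - name: my-org/my-classification-dataset   # hypothetical dataset id
    text_field: text
    label_field: label

training:
  epochs: 3
  batch_size: 16
  learning_rate: 2.0e-5
```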

Usage

Using BERTrainer is easy; the design is very human. Just follow these steps:

  1. Create a configuration file (e.g., config.yml) specifying your model, dataset, and training settings. Check out the example configurations for inspiration.

  2. Run the trainer with your configuration file:

     python3 -m bertrainer.train config.yml

  3. Sit back, watch the graphs, and let the trainer do its magic! ✨

  4. Once the training is complete, you'll find your trained model in the specified output directory.

  5. To run inference with your model, run python3 -m bertrainer.serve config.yml. It will load the model from your output_directory and serve it on port 8000. Here's an example of a request to that inference endpoint:

curl --location 'http://localhost:8000/predict' \
--header 'Content-Type: application/json' \
--data '{
    "text": "Quick brown fox jumps over the lazy dog."
}'
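
If you'd rather call the endpoint from Python, the same request can be built with just the standard library. This is a sketch assuming the server from step 5 is running on localhost:8000; only the URL and JSON shape come from the curl example above:

```python
import json
import urllib.request


def build_request(text: str,
                  url: str = "http://localhost:8000/predict") -> urllib.request.Request:
    """Build the POST request the /predict endpoint expects."""
    payload = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=payload,
        headers={"Content-Type": "application/json"},
        method="POST",
    )


if __name__ == "__main__":
    req = build_request("Quick brown fox jumps over the lazy dog.")
    # Requires the bertrainer.serve process to be running.
    with urllib.request.urlopen(req) as resp:
        print(json.load(resp))
```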

And here is an example response:

{
    "class_0": 0.0005910243489779532,
    "class_1": 0.9994089603424072
}
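
The response maps class names to probabilities. Picking the winning class from it is a one-liner; the snippet below uses the example response above (your class names depend on the labels you trained with):

```python
# Example response from the /predict endpoint (copied from above).
response = {
    "class_0": 0.0005910243489779532,
    "class_1": 0.9994089603424072,
}

# The predicted label is simply the key with the highest probability.
predicted = max(response, key=response.get)
confidence = response[predicted]
print(predicted, round(confidence, 4))  # class_1 0.9994
```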

Happy training! 🎓✨
