Deploying a Text-Generation Inference server on a Google Cloud TPU instance

Context

Text-Generation-Inference (TGI) is a highly optimized serving engine for Large Language Models (LLMs) that makes better use of the underlying hardware, Cloud TPU in this case.

Deploy TGI on a Cloud TPU instance

We assume you already have a Cloud TPU instance up and running. If this is not the case, please see our guide on deploying one.

Docker Container Build

Optimum-TPU provides a make tpu-tgi command at the root of the repository to help you build a local Docker image.
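
The build step above can be sketched as a small shell helper. This is a minimal sketch, not the project's official procedure: build_tgi_image is a hypothetical helper name, and the clone URL assumes the repository is fetched from its GitHub location.

```shell
# Hypothetical helper: fetch optimum-tpu if needed, then build the TGI image.
build_tgi_image() {
    repo_dir="${1:-optimum-tpu}"
    if [ ! -d "$repo_dir" ]; then
        # Assumption: the sources are cloned from GitHub.
        git clone https://github.com/huggingface/optimum-tpu.git "$repo_dir" || return 1
    fi
    # The tpu-tgi target is invoked from the repository root.
    (cd "$repo_dir" && make tpu-tgi)
}
```

Calling build_tgi_image with no argument clones into ./optimum-tpu. Passing the path of an existing checkout skips the clone.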

Docker Container Run

HF_TOKEN=<your_hf_token_here>
MODEL_ID=google/gemma-2b

sudo docker run --net=host \
                --privileged \
                -v $(pwd)/data:/data \
                -e HF_TOKEN=${HF_TOKEN} \
                huggingface/optimum-tpu:latest \
                --model-id ${MODEL_ID} \
                --max-concurrent-requests 4 \
                --max-input-length 32 \
                --max-total-tokens 64 \
                --max-batch-size 1
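
The length limits in the command above interact. TGI rejects a request when the prompt exceeds --max-input-length, or when the prompt tokens plus max_new_tokens exceed --max-total-tokens. The sketch below works out the remaining generation budget for the values used here. max_new_tokens_for is a hypothetical helper, and the prompt token count is assumed to be known, since counting tokens requires the model's tokenizer.

```shell
# Limits copied from the docker run command above.
MAX_INPUT_LENGTH=32
MAX_TOTAL_TOKENS=64

# Hypothetical helper: largest max_new_tokens that fits the total budget
# for a prompt of the given token count.
max_new_tokens_for() {
    prompt_tokens="$1"
    if [ "$prompt_tokens" -gt "$MAX_INPUT_LENGTH" ]; then
        echo "prompt exceeds --max-input-length" >&2
        return 1
    fi
    echo $((MAX_TOTAL_TOKENS - prompt_tokens))
}

max_new_tokens_for 10   # prints 54
```

With these settings, a prompt that uses the full 32-token input limit still leaves room for 32 generated tokens.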

Executing requests against the service

You can query the model using either the /generate or /generate_stream routes:

curl localhost/generate \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'

curl localhost/generate_stream \
    -X POST \
    -d '{"inputs":"What is Deep Learning?","parameters":{"max_new_tokens":20}}' \
    -H 'Content-Type: application/json'
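
For repeated queries, the request body can be built by a small helper instead of being written by hand. The following is a minimal sketch under stated assumptions: TGI_URL, json_escape, build_payload, and tgi_generate are hypothetical names introduced here, and the escaping only covers backslashes and double quotes.

```shell
# Base URL of the TGI server; the examples above reach it at plain localhost.
# TGI_URL is a hypothetical variable used only in this sketch.
TGI_URL="${TGI_URL:-http://localhost}"

# Escape backslashes and double quotes so the prompt is a valid JSON string.
# Newlines and other control characters are not handled here.
json_escape() {
    printf '%s' "$1" | sed 's/\\/\\\\/g; s/"/\\"/g'
}

# Build the request body used by the /generate and /generate_stream examples.
build_payload() {
    printf '{"inputs":"%s","parameters":{"max_new_tokens":%d}}' \
        "$(json_escape "$1")" "$2"
}

# Send a prompt to /generate; --fail makes curl exit non-zero on HTTP errors.
tgi_generate() {
    curl --silent --fail "${TGI_URL}/generate" \
        -X POST \
        -H 'Content-Type: application/json' \
        -d "$(build_payload "$1" "${2:-20}")"
}
```

Once the container is running, tgi_generate "What is Deep Learning?" 20 sends the same request as the first curl command above.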

Jetstream PyTorch and PyTorch XLA backends

Jetstream PyTorch is a highly optimized PyTorch engine for serving LLMs on Cloud TPU. It is selected by default when the dependency is available. To use the PyTorch/XLA backend instead, set the JETSTREAM_PT_DISABLE=1 environment variable.

With the Jetstream PyTorch engine, you can enable quantization to reduce the memory footprint and increase throughput. To do so, set the QUANTIZATION=1 environment variable. For instance, with quantization enabled on a 2x4 TPU v5e, you can serve models of up to 70B parameters, such as Llama 3.3-70B.

Note: Quantization is still experimental and may produce lower quality results compared to the non-quantized version.
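
One way to pass these variables to the container is through docker's -e flag, as done for HF_TOKEN in the run command. The helper below is a sketch only: tgi_env_flags and its two arguments are hypothetical, and it assumes quantization is only meaningful with the Jetstream backend, as the text above suggests.

```shell
# Hypothetical helper: print the docker -e flag for the chosen backend options.
#   $1: backend, "jetstream" (default) or "xla"
#   $2: "1" to request quantization (Jetstream only in this sketch)
tgi_env_flags() {
    backend="${1:-jetstream}"
    quantize="${2:-0}"
    case "$backend" in
        xla)
            if [ "$quantize" = "1" ]; then
                echo "quantization requires the Jetstream backend" >&2
                return 1
            fi
            echo "-e JETSTREAM_PT_DISABLE=1" ;;
        jetstream)
            if [ "$quantize" = "1" ]; then echo "-e QUANTIZATION=1"; fi ;;
        *)
            echo "unknown backend: $backend" >&2
            return 1 ;;
    esac
}

# Usage (left unquoted on purpose so the flag splits into two words):
# sudo docker run --net=host --privileged \
#     -v "$(pwd)/data:/data" -e HF_TOKEN="${HF_TOKEN}" \
#     $(tgi_env_flags jetstream 1) \
#     huggingface/optimum-tpu:latest --model-id "${MODEL_ID}"
```

With no arguments the helper prints nothing, which keeps the default backend selection.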