Deploying a Google TPU instance on Google Cloud Platform (GCP)
Context
We assume the reader has already created a Google Cloud Platform (GCP) user or organisation account and an associated project.
We also assume the reader has the Google Cloud CLI installed. If not, please follow the links right after to install and set it up.
Creating the initial TPU VM on GCP
In order to create your initial TPU instance, you will need to provide some information:
- The GCP zone in which you would like the instance to be deployed (for instance, close to you for development purposes, or close to the end users for production)
- Which kind of TPU you would like to target
- Which version of the TPU runtime you want to leverage on the instance
- A custom instance name, to quickly identify and refer back to the instance
Overall the end command looks like this:
gcloud compute tpus tpu-vm create <ref_instance_name> \
  --zone=<deployment_zone> \
  --accelerator-type=<target_tpu_generation> \
  --version=<runtime_version>
Deploying a TPU v5litepod-8 instance
In our case we will be deploying a v5litepod-8 instance named optimum-tpu-get-started in the GCP zone us-west4-a, using the latest v2-alpha-tpuv5-lite runtime version.
Of course, feel free to adjust all these parameters to the ones that match your usage and quotas.
Before creating the instance, please make sure to install the gcloud alpha component, as it is required to target TPUv5 VMs:
gcloud components install alpha
gcloud alpha compute tpus tpu-vm create optimum-tpu-get-started \
  --zone=us-west4-a \
  --accelerator-type=v5litepod-8 \
  --version=v2-alpha-tpuv5-lite
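If you expect to adjust the zone, TPU type, or runtime version, it can help to keep them in variables and review the final command before running it. The following is a minimal sketch that reuses the values from the command above. The variable names, the create_tpu_vm function, and the DRY_RUN switch are illustrative assumptions for this sketch and are not part of gcloud.

```shell
#!/bin/sh
# Values taken from the example above; adjust them to your project and quotas.
INSTANCE_NAME="optimum-tpu-get-started"
ZONE="us-west4-a"
ACCELERATOR_TYPE="v5litepod-8"
RUNTIME_VERSION="v2-alpha-tpuv5-lite"

# DRY_RUN is a hypothetical switch for this sketch, not a gcloud feature.
# Leave it at 1 to only print the command; set DRY_RUN=0 to execute it.
DRY_RUN="${DRY_RUN:-1}"

create_tpu_vm() {
  # Build the argument list once so the printed and executed commands match.
  set -- gcloud alpha compute tpus tpu-vm create "$INSTANCE_NAME" \
    "--zone=$ZONE" \
    "--accelerator-type=$ACCELERATOR_TYPE" \
    "--version=$RUNTIME_VERSION"
  if [ "$DRY_RUN" = "1" ]; then
    # Dry run: show the command on a single line for review.
    echo "$*"
  else
    "$@"
  fi
}

create_tpu_vm
```

With the default DRY_RUN=1 the script only prints the assembled command on one line. Running it with DRY_RUN=0 in the environment executes the same command, which requires an authenticated gcloud with the alpha component installed.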
Connecting to the instance
gcloud compute tpus tpu-vm ssh <ref_instance_name> --zone=<deployment_zone>
$ >
For the v5litepod-8 instance deployed in the example above, it would be something like:
gcloud compute tpus tpu-vm ssh optimum-tpu-get-started --zone=us-west4-a
$ >
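Besides opening an interactive shell, the same ssh subcommand accepts a --command flag to run a single command on the instance and return. The sketch below only prints the command lines so they can be reviewed or copied. The tpu_ssh_cmd helper is a hypothetical name introduced here, not part of gcloud, and its quoting is deliberately naive.

```shell
#!/bin/sh
# tpu_ssh_cmd is a hypothetical helper for this sketch, not part of gcloud.
# It prints the gcloud command that opens a shell on a TPU VM, or the one
# that runs a single remote command when a third argument is given.
# Note: the remote command is wrapped in double quotes as-is, so it must
# not itself contain double quotes.
tpu_ssh_cmd() {
  name="$1"
  zone="$2"
  remote="${3:-}"
  if [ -n "$remote" ]; then
    printf 'gcloud compute tpus tpu-vm ssh %s --zone=%s --command="%s"\n' \
      "$name" "$zone" "$remote"
  else
    printf 'gcloud compute tpus tpu-vm ssh %s --zone=%s\n' "$name" "$zone"
  fi
}

# Interactive session on the example instance.
tpu_ssh_cmd optimum-tpu-get-started us-west4-a
# One-off remote command on the example instance.
tpu_ssh_cmd optimum-tpu-get-started us-west4-a "python3 --version"
```

The first call prints the interactive form shown above. The second prints the same command with --command="python3 --version" appended, which is convenient for quick checks from scripts.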
Setting up the instance to run AI workloads on TPUs
Optimum-TPU with PyTorch/XLA
If you want to leverage PyTorch/XLA through Optimum-TPU, it should be as simple as:
$ python3 -m pip install optimum-tpu -f https://storage.googleapis.com/libtpu-releases/index.html
$ export PJRT_DEVICE=TPU
Now you can validate the installation with the following command, which should print xla:0 as we have a single TPU device bound to this instance.
$ python -c "import torch_xla.core.xla_model as xm; print(xm.xla_device())"
xla:0
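Keep in mind that export only affects the current shell session, so PJRT_DEVICE has to be set again after reconnecting to the instance. As a minimal sketch, a small guard can confirm the variable before running the validation. The check_pjrt_device function is a hypothetical helper name used only here.

```shell
#!/bin/sh
# check_pjrt_device is a hypothetical helper name used only in this sketch.
# It succeeds when PJRT_DEVICE is set to TPU and fails with a message on
# stderr otherwise.
check_pjrt_device() {
  if [ "${PJRT_DEVICE:-}" = "TPU" ]; then
    echo "PJRT_DEVICE=TPU"
  else
    echo "PJRT_DEVICE is '${PJRT_DEVICE:-unset}', expected TPU" >&2
    return 1
  fi
}

# Report the current state and remind how to fix it if needed.
check_pjrt_device || echo "hint: run 'export PJRT_DEVICE=TPU' first"
```

On the instance, the guard can be chained in front of the validation command, for example with check_pjrt_device && followed by the python command shown above, so that the check runs only when the variable is set as expected.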
Optimum-TPU with JAX
JAX is coming soon - stay tuned!