PyTorch-Code auf TPU-Pod-Slices ausführen

PyTorch/XLA erfordert, dass alle TPU-VMs auf den Modellcode und die -Daten zugreifen können. Sie können ein Startskript verwenden, um die erforderliche Software um die Modelldaten auf alle TPU-VMs zu verteilen.

Wenn Sie Ihre TPU-VMs mit einer Virtual Private Cloud verbinden (VPC) müssen Sie in Ihrem Projekt eine Firewallregel hinzufügen, um eingehenden Traffic von Ports zuzulassen 8470–8479. Weitere Informationen zum Hinzufügen von Firewallregeln finden Sie unter Firewallregeln verwenden

Umgebung einrichten

  1. Führen Sie in Cloud Shell den folgenden Befehl aus, um sicherzustellen, die aktuelle Version von gcloud ausführen:

    $ gcloud components update
    

    Wenn Sie gcloud installieren müssen, verwenden Sie den folgenden Befehl:

    $ sudo apt install -y google-cloud-sdk
  2. Erstellen Sie einige Umgebungsvariablen:

    $ export PROJECT_ID=project-id
    $ export TPU_NAME=tpu-name
    $ export ZONE=us-central2-b
    $ export RUNTIME_VERSION=tpu-ubuntu2204-base
    $ export ACCELERATOR_TYPE=v4-32
    

TPU-VM erstellen

$ gcloud compute tpus tpu-vm create ${TPU_NAME} \
--zone=${ZONE} \
--project=${PROJECT_ID} \
--accelerator-type=${ACCELERATOR_TYPE} \
--version ${RUNTIME_VERSION}

Trainingsskript konfigurieren und ausführen

  1. Fügen Sie Ihrem Projekt Ihr SSH-Zertifikat hinzu:

    ssh-add ~/.ssh/google_compute_engine
    
  2. PyTorch/XLA auf allen TPU-VM-Workern installieren

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="
      pip install torch~=2.3.0 torch_xla[tpu]~=2.3.0 torchvision -f https://storage.googleapis.com/libtpu-releases/index.html"
    
  3. XLA auf allen TPU-VM-Workern klonen

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all --command="git clone -b r2.3 https://github.com/pytorch/xla.git"
    
  4. Trainingsskript auf allen Workern ausführen

    gcloud compute tpus tpu-vm ssh ${TPU_NAME} \
      --zone=${ZONE} \
      --project=${PROJECT_ID} \
      --worker=all \
      --command="PJRT_DEVICE=TPU python3 ~/xla/test/test_train_mp_imagenet.py  \
      --fake_data \
      --model=resnet50  \
      --num_epochs=1 2>&1 | tee ~/logs.txt"
      

    Das Training dauert etwa 5 Minuten. Nach Abschluss des Vorgangs sollte eine Meldung wie die folgende angezeigt werden:

    Epoch 1 test end 23:49:15, Accuracy=100.00
    10.164.0.11 [0] Max Accuracy: 100.00%
    

Bereinigen

Wenn Sie mit Ihrer TPU-VM fertig sind, führen Sie die folgenden Schritte aus, um Ihre Ressourcen zu bereinigen.

  1. Trennen Sie die Verbindung zur Compute Engine:

    (vm)$ exit
    
  2. Überprüfen Sie mit dem folgenden Befehl, ob die Ressourcen gelöscht wurden. Achten Sie darauf, dass Ihre TPU nicht mehr aufgeführt wird. Der Löschvorgang kann einige Minuten dauern.

    $ gcloud compute tpus tpu-vm list \
      --zone europe-west4-a