介紹 JobSet

作者:Daniel Vega-Myhre (Google),Abdullah Gharaibeh (Google),Kevin Hannon (Red Hat)

在本文中,我們介紹 JobSet,這是一個用於表示分散式作業的開源 API。JobSet 的目標是為 Kubernetes 上的分散式機器學習訓練和 HPC 負載提供統一的 API。

為什麼需要 JobSet?

Kubernetes 社群最近對 Kubernetes 批處理生態系統的增強吸引了機器學習工程師,他們發現這非常適合執行分散式訓練負載的需求。

大型機器學習模型(特別是 LLM)無法裝入單個主機上的 GPU 或 TPU 晶片記憶體中,通常會分佈在數萬個加速器晶片上,而這些晶片又可能跨越數千個主機。

因此,模型訓練程式碼通常被容器化並在所有這些主機上同時執行,執行分散式計算,通常將模型引數和/或訓練資料集分片到目標加速器晶片上,使用 all-gather 和 all-reduce 等通訊集合原語來執行分散式計算並在主機之間同步梯度。

這些工作負載特性使 Kubernetes 非常適合此類工作負載,因為它擅長在計算資源叢集中高效地排程和管理容器化應用程式的生命週期。

它還具有很強的可擴充套件性,允許開發人員定義自己的 Kubernetes API、物件和控制器來管理這些物件的行為和生命週期,從而使工程師能夠開發定製的分散式訓練編排解決方案以滿足其需求。

然而,隨著分散式機器學習訓練技術的不斷發展,現有的 Kubernetes 原語已不足以單獨對其進行建模。

此外,Kubernetes 分散式訓練編排 API 的格局變得分散,而且這個分散格局中的每個現有解決方案都存在某些侷限性,使其不適用於分散式機器學習訓練。

例如,KubeFlow 訓練 Operator 為不同的框架(例如 PyTorchJob、TFJob、MPIJob 等)定義了自定義 API;然而,這些作業型別實際上是專門針對目標框架的解決方案,每種都有不同的語義和行為。

另一方面,Job API 修復了執行批處理工作負載的許多差距,包括 Indexed 完成模式、更高的可擴充套件性、Pod 失敗策略和 Pod 退避策略等等,這只是最近的一些增強功能。然而,使用上游 Job API 執行機器學習訓練和 HPC 工作負載需要額外的編排來填補以下差距

多模板 Pod:大多數 HPC 或機器學習訓練作業包含不止一種型別的 Pod。不同的 Pod 是同一工作負載的一部分,但它們需要執行不同的容器、請求不同的資源或具有不同的失敗策略。一個常見的例子是驅動程式-工作程式模式。

作業組:大規模訓練工作負載跨越多個網路拓撲,例如跨多個機架執行。此類工作負載對網路延遲敏感,旨在將通訊本地化並最小化跨越高延遲網路鏈路的流量。為了實現這一點,需要將工作負載拆分為多個 Pod 組,每個組分配給一個網路拓撲。

Pod 間通訊:建立和管理建立作業 Pod 之間通訊所需的資源(例如 無頭服務)。

啟動順序:一些作業需要特定的 Pod 啟動順序;有時驅動程式需要先啟動(如 Ray 或 Spark),在其他情況下,工作程式需要在驅動程式啟動前準備好(如 MPI)。

JobSet 旨在利用 Job API 作為構建塊,為大規模分散式 HPC 和機器學習用例構建更豐富的 API 來解決這些差距。

JobSet 的工作原理

JobSet 將分散式批處理工作負載建模為一組 Kubernetes Job。這使使用者可以輕鬆地為不同的 Pod 組(例如 leader、worker、引數伺服器等)指定不同的 Pod 模板。

它使用 ReplicatedJob 的抽象來管理子 Job,其中 ReplicatedJob 本質上是一個指定了所需 Job 副本數量的 Job 模板。這提供了一種宣告性的方式來輕鬆建立相同的子 Job 以在不同的加速器孤島上執行,而無需藉助指令碼或 Helm chart 來生成具有不同名稱的同一 Job 的多個版本。

JobSet Architecture

JobSet 的其他一些解決上述問題的關鍵特性包括

Replicated Jobs:在現代資料中心中,像 GPU 和 TPU 這樣的硬體加速器分配在同構加速器的孤島中,透過專門的高頻寬網路連結連線。例如,使用者可能會配置包含一組位於同一機架上的主機的節點,每個主機都配備 H100 GPU,其中每個主機內的 GPU 晶片透過 NVLink 連線,並由 NVLink Switch 連線多個 NVLink。TPU Pod 是另一個例子:TPU ViperLitePods 由 64 個主機組成,每個主機連線 4 個 TPU v5e 晶片,所有晶片透過 ICI 網格連線。當在多個這樣的孤島上執行分散式訓練作業時,我們通常希望將工作負載劃分為一組較小的相同作業,每個孤島一個,其中每個 Pod 主要與同一孤島內的 Pod 通訊以進行分散式計算的分段,並將 DCN(資料中心網路,其頻寬低於 ICI)上的梯度同步保持在最低限度。

自動建立、配置和生命週期管理無頭服務:預設啟用透過 Pod 主機名進行 Pod 間通訊,並自動配置和管理實現此功能的無頭服務的生命週期。

可配置的成功策略:JobSet 具有可配置的成功策略,這些策略針對特定的 ReplicatedJob,並帶有運算子以針對其子作業的“Any”或“All”。例如,您可以配置 JobSet,使其僅在“worker” ReplicatedJob 的所有 Pod 都完成時才標記為完成。

可配置的失敗策略:JobSet 具有可配置的失敗策略,允許使用者指定在發生故障時 JobSet 應重新啟動的最大次數。如果任何作業被標記為失敗,整個 JobSet 將被重新建立,從而允許工作負載從最後一個檢查點恢復。當未指定失敗策略時,如果任何作業失敗,JobSet 將直接失敗。

每個拓撲域的獨佔放置:JobSet 允許使用者表示子作業與拓撲域(通常是像機架一樣的加速器孤島)具有 1:1 的獨佔分配。例如,如果 JobSet 建立兩個子作業,那麼此功能將強制每個子作業的 Pod 位於同一孤島上,並且每個孤島只允許排程一個子作業。這對於我們希望使用分散式資料並行(DDP)訓練策略來使用多個計算資源孤島(GPU 機架或 TPU 切片)訓練模型的場景很有用,在每個加速器孤島中執行 1 個模型副本,確保前向和後向傳播本身在單個模型副本內透過連線孤島內加速器晶片的高頻寬互連進行,並且只有模型副本之間的梯度同步透過較低頻寬的資料中心網路跨加速器孤島進行。

與 Kueue 整合:使用者可以透過 Kueue 提交 JobSet 來超額訂閱其叢集、將工作負載排隊以在容量可用時執行、防止部分排程和死鎖、啟用多租戶等。

用例示例

使用 Jax 在多個 TPU 切片上進行分散式機器學習訓練

以下示例是一個 JobSet 規範,用於在 4 個 TPU v5e 切片上執行 TPU 多切片工作負載。要了解有關 TPU 概念和術語的更多資訊,請參閱這些文件

本示例使用 Jax,這是一個透過 OpenXLA 原生支援針對 TPU 晶片進行即時(JIT)編譯的機器學習框架。不過,您也可以使用 PyTorch/XLA 在 TPU 上進行機器學習訓練。

此示例利用了多個 JobSet 特性(顯式和隱式)來支援 TPU 多切片訓練的獨特排程要求,開箱即用,使用者只需很少的配置。

# Run a simple Jax workload on 
apiVersion: jobset.x-k8s.io/v1alpha2
kind: JobSet
metadata:
  name: multislice
  annotations:
    # Give each child Job exclusive usage of a TPU slice 
    alpha.jobset.sigs.k8s.io/exclusive-topology: cloud.google.com/gke-nodepool
spec:
  failurePolicy:
    maxRestarts: 3
  replicatedJobs:
  - name: workers
    replicas: 4 # Set to number of TPU slices
    template:
      spec:
        parallelism: 2 # Set to number of VMs per TPU slice
        completions: 2 # Set to number of VMs per TPU slice
        backoffLimit: 0
        template:
          spec:
            hostNetwork: true
            dnsPolicy: ClusterFirstWithHostNet
            nodeSelector:
              cloud.google.com/gke-tpu-accelerator: tpu-v5-lite-podslice
              cloud.google.com/gke-tpu-topology: 2x4
            containers:
            - name: jax-tpu
              image: python:3.8
              ports:
              - containerPort: 8471
              - containerPort: 8080
              securityContext:
                privileged: true
              command:
              - bash
              - -c
              - |
                pip install "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
                python -c 'import jax; print("Global device count:", jax.device_count())'
                sleep 60                
              resources:
                limits:
                  google.com/tpu: 4

未來工作和參與

我們在 JobSet 路線圖上計劃了今年要開發的許多功能,可以在 JobSet 路線圖中找到。

歡迎隨時提供任何形式的反饋。我們也歡迎更多的貢獻者,無論是修復或報告錯誤,還是幫助新增新功能或編寫文件。

您可以透過我們的 倉庫郵件列表或在 Slack 上與我們聯絡。

最後但同樣重要的是,感謝所有使這個專案成為可能的貢獻者