在 TPU 上使用 PyTorch:深入了解 TorchTPU、XLA 和 Cloud TPU

最后更新: 05/03/2026
作者: C 源跟踪
  • Google 的 TorchTPU 和 PyTorch/XLA 使 TPU 成为 PyTorch 的原生高性能后端,而无需强制使用 JAX 式的思维模型。
  • TPU 架构、XLA 编译和 StableHLO 能够实现大规模的高效密集计算和集体计算,尤其适用于分布式训练。
  • 新的 eager 模式、有界动态性以及 easy-torch-tpu 等生态系统工具减少了将以 GPU 为中心的 PyTorch 代码迁移到 TPU 集群时的摩擦。
  • Cloud TPU、GKE 和 Vertex AI 提供了在 TPU 上运行从研究规模到 pod 规模的 PyTorch 工作负载所需的基础设施。

在 TPU 基础设施上运行 PyTorch

在 Google TPU 上运行 PyTorch 不再是少数专家才能涉足的小众实验领域。在谷歌的新…… TorchTPU堆栈借助久经考验的 PyTorch/XLA 项目以及日益壮大的工具和框架生态系统,在 TPU 上训练和部署模型正迅速变得像在 NVIDIA GPU 上工作一样自然。最大的转变在于,现在您可以同时追求高性能、大规模和更加流畅的开发体验。

本文深入探讨了 PyTorch 目前如何利用 TPU 以及该技术栈的未来发展方向。我们将深入剖析 TorchTPU 的架构,它与传统的 PyTorch/XLA 有何不同?分布式训练、编译和硬件细节是如何运作的?如果您正在迁移以 GPU 为中心的 PyTorch 工作流程,这些在实践中意味着什么?如果您从事 LLM、扩散或大规模推荐系统等领域,那么以下细节正是决定您的 TPU 运行速度是飞快还是缓慢的关键所在。

运行时 ia pytorch javascript c++ cuda
相关文章:
AI运行时内部:PyTorch、C++、CUDA及其他

为什么现在在 TPU 上运行 PyTorch 如此重要

现代人工智能工作负载已经超越了简单的“一台机器,少量GPU”时代。最先进的模型现在涵盖了包含数万个加速器的集群,推动软件处理极端规模、可靠的分布式执行以及跨不同芯片和供应商的可移植性能。 人工智能基础设施.

谷歌的张量处理单元(TPU)是这一前沿技术的核心。它们为 Gemini 和 Veo 等内部系统以及 Google Cloud 客户的大部分训练和推理工作负载提供支持。历史上,TPU 与 JAX 和 TensorFlow 紧密结合,但更广泛的生态系统已高度标准化为 PyTorch,这造成了一种痛苦的分裂:GPU 意味着“PyTorch + CUDA”,而 TPU 则意味着“JAX + XLA”。

谷歌的应对之策是全力以赴,力求使 TPU 成为一流的 PyTorch 目标平台。TorchTPU 旨在提供原生、即时响应的 PyTorch 语义,并拥有顶级的性能;而 PyTorch/XLA 仍然是一个功能强大、延迟编译的方案,已被广泛应用于生产环境。围绕这些技术栈,Cloud TPU、GKE、Vertex AI 以及 easy-torch-tpu 等社区框架正在将 TPU 集群转变为简单易用、可脚本化的基础设施,能够处理从 1 亿到 70 亿+ 参数的各种模型。

在 TPU 上训练 PyTorch 模型

TPU硬件内部:不仅仅是更快的芯片

TPU系统本质上是由芯片、主机和互连紧密集成而成的结构。不仅仅是一张加速卡。理解这种布局对于理解 TorchTPU 的设计以及其编译器选择为何与纯 GPU 堆栈不同至关重要。

每个 TPU 主机通过芯片间互连 (ICI) 连接到多个 TPU 芯片。ICI 构建了一个高带宽的二维或三维环面拓扑结构,使得大规模 pod 能够像单个逻辑加速器一样运行。集群无需通过传统的网络协议栈传输梯度,而是直接运行在这个环面上,一旦软件能够正确表达这些集群,横向扩展的效率就会大大提高。

在TPU芯片内部,计算任务被划分到TensorCore和SparseCore之间。TensorCores 是专门的单线程引擎,擅长处理稠密矩阵运算——这正是 Transformer 模型、卷积神经网络 (CNN) 和大多数标准深度学习层的核心所在。SparseCores 则专为内存访问模式不规则的工作负载而设计,例如嵌入、聚集/分散操作以及卸载的集体操作。

这种架构非常适合深度学习,但它对输入数据的要求很高。例如,许多Transformer实现将注意力头维度硬编码为64。而当前几代TPU的最佳性能点通常在128-256之间,这意味着简单地将注意力头维度翻倍就能显著提高矩阵乘法效率和TensorCore利用率。可移植性并不会消除这些硬件限制,它只是让实现这些限制变得更容易。

从 PyTorch/XLA 到 TorchTPU:在 TPU 上运行 PyTorch 的两种互补方式

目前,PyTorch 已经可以通过 PyTorch/XLA (torch_xla) 在 TPU 上运行。它将 TPU 作为标准的 PyTorch 设备呈现,并在底层编译惰性 XLA 图。然而,许多研究人员发现,尽管从理论上讲,代码的改动很小,但与 GPU 即时执行相比,其行为差异可能会让人感到不适。

TorchTPU 是谷歌新推出的原生 PyTorch 后端,旨在提供“真正”的 PyTorch 使用体验,而不是一个封装器。TorchTPU并没有像JAX那样强制PyTorch采用到处使用惰性张量的模型,而是充分利用了PyTorch的即时执行特性和现代编译API,例如 torch.编译。 它使用 私用1 PyTorch 中的设备机制,所以从你的角度来看,你只是在操作常规的设备机制。 torch.Tensor 恰好存在于 TPU 上的对象。

两种方法的主要区别在于执行方式。PyTorch/XLA 默认采用惰性执行:操作会构建一个计算图,当遇到同步障碍(例如训练循环中的某个步骤)时,才会触发 XLA 编译。相比之下,TorchTPU 的架构是“积极优先”,它提供了额外的模式,可以逐步融合操作并将优化后的子图交给 XLA,而无需用户放弃标准的 PyTorch 思维模型。

Cloud TPU、GKE 和 Vertex AI:基础设施骨干

无论你选择哪个基于 TPU 的 PyTorch 技术栈,其底层都是 Cloud TPU 平台。它将定制的 ASIC 作为可扩展的云资源公开,这些资源针对训练和推理进行了优化。这些加速器可用于各种工作负载: 会话代理代码生成、图像和媒体模型、语音、推荐系统和个性化引擎。

云端 TPU 与 Google Kubernetes Engine (GKE) 紧密集成。因此,您可以使用标准的 Kubernetes 原语来调度大规模 PyTorch 作业。动态工作负载调度器允许您一次性请求所需的所有加速器,确保数千个 TPU 芯片同时上线,无需手动编排即可训练或运行模型。

对于希望简化入门流程的团队来说,Vertex AI 可以抽象化大部分集群管理工作。您可以从托管培训和服务工作流中定位 TPU,包括当您使用以下情况时: 基于 PyTorch 的模型Google Cloud 将这种灵活性(TPU 或 GPU,托管或 DIY Kubernetes)定位为直接回应企业和研究实验室对 AI 基础设施日益增长的需求。

TorchTPU 的核心理念:“PyTorch 公民意识”

TorchTPU 的核心设计目标非常明确:它应该感觉像 PyTorch,而不是一个陌生的框架。如果您已经知道如何在 CUDA GPU 上训练模型,那么您应该能够将相同的训练脚本移植到 TPU 上,只需进行最少的代码修改,而无需重写您的心理模型。

实际上,理想的迁移方案看起来几乎简单得有些滑稽。在你通常会写的地方 device = torch.device('cuda')相反,您可以从 TorchTPU 模块获取 TPU 设备——概念上类似于 设备 = tpu.get_device()—并致电 模型.to(设备) 就像在 GPU 上一样。你的前向传播、优化器逻辑以及调用 Hugging Face 模型的方式都可以保持不变。

之前的 TPU 集成经常迫使 PyTorch 模仿 JAX。它们严重依赖惰性张量,迫使你采用静态图的思维模式。这破坏了 PyTorch 的最大优势之一:你无法在正向传播过程中插入打印语句来检查形状或值。TorchTPU 拒绝这种权衡。它以即时加载行为为基础,并围绕其构建性能,而不是要求你放弃它。

这种“PyTorch公民意识”原则也延伸到了错误处理方面。与其查看深藏在 XLA 堆栈中的晦涩难懂的 500 行 C++ 堆栈跟踪信息,不如直接获取清晰的 Py​​thon 回溯信息,直指训练循环或模型定义中出错的行。当你需要处理数十亿参数的模型和数千个 TPU 时,这种效率的提升意味着只需一个下午就能解决问题,而无需耗费数天时间进行漫无目的的调试。

TorchTPU 中的 Eager 模式:调试模式、严格模式和融合模式

在专为大型融合图构建的硬件上提供原生的即时体验并非易事。TorchTPU 通过提供多个由共享编译和执行管道支持的 eager 模式来解决这个问题,因此您可以平滑地从“使其工作”过渡到“使其快速”。

调试渴望 这是速度最慢但最透明的模式。它会分发 一次只做一项手术 每次操作后,数据都会同步到 TPU 并与 CPU 同步。性能方面有意有所牺牲,以便您可以轻松追踪 NaN 值、形状不匹配或内存不足错误,并获得即时反馈和清晰的堆栈跟踪。

严格渴望 保持这种单操作分发语义,但执行 异步TPU 和 CPU 可以并行运行,直到用户代码达到同步点,从而提供更接近标准 GPU 支持的 eager PyTorch 的体验,但仍然不需要大量的图编译。

从性能角度来看,Fused Eager 才是真正有趣的地方。TorchTPU 会观察您执行的操作流,并在通过 XLA 将其发送到 TPU 之前,自动将它们融合为更大、更密集的计算块。这种动态融合步骤显著提高了 TensorCore 的利用率,并降低了内存带宽开销,通常可带来显著的性能提升。 比严格渴望模式快 50-100% 以上 无需对模型代码进行任何更改。

这三种即时模式共享一个公共编译缓存。 它可以运行在单个主机上,也可以在分布式环境中跨多个主机持久运行。随着时间的推移,当训练循环趋于稳定,系统识别出相同的模式时,编译成本会降低,您将把更多的实际运行时间用于处理张量,而不是构建可执行文件。

静态编译:torch.compile、XLA 和 StableHLO

当您需要在 TPU 上获得绝对峰值性能时,TorchTPU 可以直接集成到现代 PyTorch 编译流程中。你可以用以下方式包装模型或函数: torch.compile()它使用 Torch Dynamo 捕获 FX 图形,然后绕过通常的 TorchInductor 后端,并将控制权交给 XLA。

选择 XLA 作为主要后端是基于 TPU 实际情况的深思熟虑的决定。XLA 经过多年在 TPU pod 上的部署,性能已得到充分验证,并且它深刻理解了密集数学运算和 ICI 环面上的集体通信的交集。TorchTPU 将 PyTorch 算子直接映射到 StableHLOOpenXLA 理解的张量 IR,然后让 XLA 的降阶过程生成优化的 TPU 二进制文件,尽可能地重用与 eager 模式相同的运行时路径。

对自定义运算符的扩展性并非事后考虑的因素TorchTPU 支持在 Pallas 和 JAX 中定义的自定义内核:通过类似以下方式装饰 JAX 函数。 @torch_tpu.pallas.custom_jax_kernel这样,您就可以将底层硬件优化代码注入到编译路径中,而不会损失全局优化器的优势。此外,我们还在努力支持更多领域特定语言(DSL),例如 Helion,以实现更灵活的内核编写。

在 TPU 上实现分布式 PyTorch:DDP、FSDP、DTensor 和 MPMD

大规模模型并非在单一加速器上进行训练,而 TorchTPU 的设计正是基于这一现实。它直接与 PyTorch 的标准分布式 API 集成,包括 分布式数据并行(DDP), FSDPv2DTensor并且已经通过基于这些抽象概念构建的第三方库进行了验证。

PyTorch/XLA 的一个主要历史痛点是其严格的 SPMD(单程序多数据)偏好。许多实际的 PyTorch 训练脚本在不同层级之间都存在细微差异——例如,层级 0 可能负责日志记录、检查点维护或指标管理,而其他层级则只进行纯粹的计算。对于 XLA 的全局图视图而言,这种行为显得十分繁琐,常常迫使开发者重写代码以避免层级差异。

TorchTPU明确支持MPMD(多程序、多数据)场景。它精心隔离并限定了通信原语的范围,以避免行为差异破坏正确性或降低性能。在可能的情况下,它仍然允许 XLA 了解分布式计算的全局情况,从而将通信与计算重叠,但它不再强制您采用不切实际的纯粹 SPMD 风格。

它与现有 PyTorch 分布式范式的契合方式尤为重要。诸如 FSDP、DTensor 之类的框架以及 TorchTitan 等生态系统工具都依赖于此。 进程组 用于诸如 all-reduce、all-gather 和 broadcast 等集体操作的 API。在 GPU 上,这些调用通常会解析为 NCCL。TorchTPU 会在 ProcessGroup 层拦截这些集体操作,并将它们转换为 StableHLO 集体操作,TPU 硬件和 ICI 环面可以原生执行这些操作。从 FSDP 或 DTensor 的角度来看,没有任何变化——它们只是看到了不同的后端。

PyTorch/XLA:惰性执行、同步点和实用技巧

虽然 TorchTPU 是长远的、完全原生化的实现路径,但 PyTorch/XLA 目前仍然是 TPU 上运行 PyTorch 的关键工具。如果您习惯了 CUDA 的立即执行,那么 PyTorch/XLA 最大的概念转变在于张量是…… 懒惰操作会记录一个图表;实际的执行和编译会在显式或隐式同步时发生。

同步点是指 PyTorch/XLA 将构建好的计算图交给 XLA 进行编译和执行的地方。典型的障碍包括诸如此类的电话 torch_xla.sync() 或更高级别的实用程序,例如 xm.optimizer_step(optimizer),这既可以逐步优化你的优化器,也可以在分布式设置中跨设备同步梯度。

这种惰性模型会对性能产生重大影响。首次执行给定图(或具有新输入形状的图)时,需要付出编译成本,但只要结构保持稳定,后续迭代的运行速度就会快得多。这就是为什么形状稳定性(固定的序列长度、一致的批处理大小)对于 PyTorch/XLA 工作负载如此重要,以及为什么 将输入框填充到固定大小 这是一种非常常见的模式。

PyTorch/XLA 上的多进程训练使用其自身的便捷工具。通常情况下,你会将核心训练功能封装起来(例如, _mp_mnist_fn)并在所有设备上启动它 torch_xla.launch数据加载通过以下方式管理 torch_xla.distributed.parallel_loader.MpDeviceLoader它采用标准的 PyTorch DataLoader,并确保每个进程都能看到唯一的数据分片,同时将批次预取到相应的 TPU 设备。

数据加载、分布式执行和TPU上的AMP

高效的输入流水线对于 TPU 来说与对于 GPU 来说同样重要。在 PyTorch/XLA 上, MpDeviceLoader 它将主机端数据加载和设备端执行重叠起来,直接将批次数据提供给 TPU,从而帮助您避免在加速器等待新数据时出现长时间的空闲期。

对于分布式训练,`xm.optimizer_step(optimizer)` 执行的操作比普通的优化器步骤更多。它会在所有设备上执行梯度全归约,计算平均值,应用权重更新,并处理必要的同步,因此通常不需要在每次迭代中单独显式调用同步函数。日志辅助函数 xm.is_master_ordinal(local=False) 确保只有一个进程负责处理指标和检查点,以避免重复工作。

自动混合精度(AMP)在TPU上的表现与在GPU上的表现略有不同。TPU 原生支持 bfloat16 (BF16)它提供的指数范围比 float16 大得多,而且通常不需要显式地调整损失函数就能保证稳定性。PyTorch/XLA 扩展了 PyTorch AMP,使其能够在需要时自动在 BF16 和 FP32 之间进行映射,从而使得在 TPU 上进行混合精度训练既简单又稳健。

保存模型也有针对 TPU 的最佳实践虽然你可以打电话 手电筒保存 对于设备张量,通常建议 在序列化之前将状态字典移至 CPU 使用 PyTorch/XLA 时,它们更容易在非 TPU 硬件(例如标准 GPU 机器)上重新加载。

简易的Torch-TPU和真实世界的TPU训练框架

除了官方技术栈之外,社区还在构建更高级别的框架,以使 TPU 更容易被采用。。 一个例子是 aklein4/easy-torch-tpu,一个轻量级的训练框架,专门用于简化 Google Cloud TPU 集群上的 PyTorch/XLA 工作流程。

Easy-torch-tpu 将自身定位为 Hypercomputer/torchprime 等大型、僵化代码库的更简单、更灵活的替代方案。它的设计重点很明确:易于设置、可直接定制以及与系统的无缝集成。 gcloud ssh驱动集群工作流程。它专门针对“学术规模”的实验——参数范围在 1-10B 到 10B 之间的模型,大约需要 32-64 个 TPU 芯片。

可扩展性是通过子类化和配置文件来实现的。通过添加新的子类,您可以插入自己的架构、训练循环、优化器、数据加载器,甚至自定义分片和重存储策略。这使您可以在重用框架的分布式和日志脚手架的同时,自由地进行实验。

该框架与关键生态系统工具紧密集成。权重和偏差支持使实验跟踪变得轻而易举,而 Hugging Face 集成则简化了数据集加载、预训练检查点提取以及模型保存,这些模型随后可在标准的基于 GPU 的 PyTorch 上运行。该代码库包含安装文档、入门示例,并会根据社区反馈积极更新。

局限性、调试和性能缺陷

即使有了这些改进,在 TPU 上运行 PyTorch 仍然不够流畅。了解哪些环节可能出错,可以为您在部署大型模型或动态工作负载时节省大量时间。

图重新编译仍然是最大的隐形性能杀手之一。每当计算图或输入形状在同步点之间发生变化时,XLA 可能需要重新编译,这会导致明显的停顿。这种情况在处理可变长度序列或自适应批处理大小时尤为常见,而这些情况在语言建模和生成工作负载中十分普遍。

不受支持或部分受支持的运算符可能会悄无声息地降低性能。虽然 PyTorch/XLA 和 TorchTPU 的目标是覆盖广泛的运算符,但某些 ATen 运算符可能尚未有原生的 XLA 降级机制。在这种情况下,执行可能会回退到 CPU,这在技术上是正确的,但速度可能会慢几个数量级。内置的调试工具和指标(例如 torch_xla.debug.metrics)帮助您发现 CPU 回退或意外重新编译发生的位置。

像Nsight和nvprof这样的传统GPU性能分析工具无法查看TPU内核内部结构。相反,你需要依赖 XLA 特有的性能分析钩子、TPU 运行时指标和更高级别的日志记录来了解性能瓶颈。许多团队发现,一旦他们采用最佳实践(例如,采用相对静态的数据结构、谨慎的数据加载、监控重新编译过程),就能迅速获得可预测的性能表现。

谷歌的编译器路线图明确针对这些痛点。XLA 中高级有界动态性的研究旨在使模型能够处理不同的序列长度和批处理大小,而无需触发新的编译。不断增长的预编译 TPU 内核库旨在大幅降低新图首次迭代的冷启动延迟。

路线图和生态系统:迈向在 TPU 上实现无摩擦 PyTorch

展望未来,谷歌的 TorchTPU 路线图雄心勃勃,并且与更广泛的 PyTorch 生态系统紧密契合。我们计划建立一个公开的 GitHub 存储库,其中包含详尽的文档、架构教程以及涵盖训练和服务场景的可复现示例。

与 PyTorch 的 Helion DSL 集成指日可待这将扩展开发者编写自定义 TPU 内核的选择范围,而无需深入 XLA 或硬件特定代码的底层。通过原生、一流的方式支持动态形状。 torch.编译 这也是一项优先事项,反映了现代基于序列的模型的现实情况。

多队列支持是另一个重点关注领域。许多生产环境中的 PyTorch 代码库都大量依赖异步执行模式和解耦的内存/计算流。如果能够让这些模式无需进行重大重构就能无缝映射到 TPU,将显著降低大型成熟项目的迁移阻力。

深度生态系统整合已经开始。目前正在努力验证其能否扩展到完整的 TPU Pod 规模,并与 vLLM 和 TorchTitan 等主流的基于 PyTorch 的系统进行集成。与此同时,谷歌正与 Meta 和 PyTorch 社区紧密合作,并探索将 TorchTPU 的关键部分开源,以加速其普及和提高透明度。

所有这一切都发生在一个更大的商业背景下,即TPU产能正在急剧扩张。谷歌云正在签署更多价值数十亿美元的人工智能基础设施协议,Anthropic 计划接入多达一百万个 TPU(容量约为 1 吉瓦),谷歌甚至直接向其内部数据中心出售 TPU。TPU 曾经是谷歌内部专属的小众资源,如今早已成为历史。

综上所述,PyTorch on TPU 的发展历程正以惊人的速度从“另类路径”转变为“标准选项”。凭借 TorchTPU 原生的即时执行体验、PyTorch/XLA 久经考验的惰性执行、easy-torch-tpu 等框架以及围绕它们构建的丰富的 Cloud TPU 基础设施,您现在可以将主流的 PyTorch 模型(通常只需更改设备字符串)高效地运行在一些最大的 AI 超级计算机上。随着技术栈越来越趋向于熟悉的 PyTorch 惯用法,而不是强行引入新的思维模式,将硬件选择视为实现细节而非根本设计约束就变得更加现实。

相关文章: