电子书《How to Scale Your Model》 如何扩展模型:TPU 上 LLM 的系统视角
jax-ml.github.io/scaling-book/
“
训练 LLM 常常感觉像炼金术,但理解和优化模型的性能并非必须如此神秘。本书旨在揭开在 TPU 上扩展语言模型的科学性:TPU 的工作原理以及它们如何相互通信,LLM 如何在真实硬件上运行,以及如何在训练和推理期间并行化模型,以便它们能够在大规模下高效运行。如果您曾想知道“训练这个 LLM 应该花费多少钱?”,“我需要多少内存才能自己部署这个模型?”,或者“什么是 AllGather?”,我们希望本书对您有所帮助。
预期背景: 我们假设您对 LLM 和 Transformer 架构有基本的了解,但不一定了解它们如何大规模运行。您应该了解 LLM 训练的基础知识,并且最好对 JAX 有一些基本的熟悉程度。一些有用的背景阅读材料可能包括这篇关于 Transformer 架构的博客文章,以及这些关于 JAX 中 LLM 扩展的出色幻灯片。
目标与反馈: 到最后,您应该能够轻松地为给定硬件平台上的 Transformer 模型估算最佳并行方案,并大致了解训练和推理所需的时间。 如果您仍然感到困惑,请给我们留言!我们很想知道如何才能让内容更清晰。
本书的目标是解释 TPU(和 GPU)硬件的工作原理,以及 Transformer 架构如何演变为在当前硬件上表现良好。 我们希望这对设计新架构的研究人员和致力于加速当前一代 LLM 运行速度的工程师都有所帮助。
”
章节目录:
第一部分:预备知识
第一章:屋顶线分析简介。 算法受三个因素的限制:计算、通信和内存。 我们可以使用这些来近似估算算法的运行速度。
第二章:如何思考 TPU。 TPU 如何工作? 这如何影响我们可以训练和部署的模型?
第三章:分片矩阵以及如何相乘。 在这里,我们通过我们最喜欢的操作:(分片)矩阵乘法来解释模型分片和多 TPU 并行性。
第二部分:Transformer
第四章:您需要了解的所有 Transformer 数学知识。 Transformer 在前向和后向传播中使用多少 FLOPs? 您可以计算参数数量吗? KV 缓存的大小? 我们在这里研究这些数学知识。
第五章:如何并行化 Transformer 以进行训练。 FSDP。 Megatron 分片。 流水线并行。 给定一定数量的芯片,我如何尽可能高效地训练给定大小和批量大小的模型?
第六章:在 TPU 上训练 LLaMA 3。 我们如何在 TPU 上训练 LLaMA 3? 需要多长时间? 花费多少钱?
第七章:关于 Transformer 推理的一切。 一旦我们训练好模型,我们就必须部署它。 推理增加了一个新的考虑因素 —— 延迟 —— 并改变了内存格局。 我们将讨论解耦服务的工作原理以及如何思考 KV 缓存。
第八章:在 TPU 上部署 LLaMA 3。 在 TPU v5e 上部署 LLaMA 3 需要多少钱? 延迟/吞吐量权衡是什么?
第三部分:实践教程
第九章:如何分析 TPU 代码。 真实的 LLM 永远不像上面的理论那么简单。 在这里,我们解释 JAX + XLA 堆栈,以及如何使用 JAX/TensorBoard 分析器来调试和修复实际问题。
第十章:在 JAX 中编写 TPU 程序。 JAX 提供了一堆用于并行化计算的神奇 API,但您需要知道如何使用它们。 有趣的示例和已解决的问题。
第十一章:结论与进一步阅读。 关于 TPU 和 LLM 的总结性思考和进一步阅读材料。