<?xml version="1.0" encoding="UTF-8"?>
<rss version="2.0"
  xmlns:atom="http://www.w3.org/2005/Atom">
  <channel>
    <title>harleyszhang Blog</title>
    <description>一窍不通大混子,想要成为AI infra糕手。</description>
    <image>
      <url>https://gw.alipayobjects.com/zos/k/hj/sLnZnY.jpg</url>
    </image>
    <follow_challenge>
      <feedId>41147805272531976</feedId>
      <userId>42909600318350336</userId>
    </follow_challenge>
    <link>https://c1lovez1.github.io</link>
    <atom:link href="https://c1lovez1.github.io/feed.xml" rel="self" type="application/rss+xml" />
		
    
    <item>
      <title>Pytorch显存管理机制与显存占用分析方法</title>
      <description>&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#vgg&quot;&gt;VGG&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#resnet&quot;&gt;ResNet&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#inceptionv3&quot;&gt;Inceptionv3&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#resnetv2&quot;&gt;Resnetv2&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#resnext&quot;&gt;ResNeXt&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#darknet53&quot;&gt;Darknet53&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#densenet&quot;&gt;DenseNet&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#cspnet&quot;&gt;CSPNet&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#vovnet&quot;&gt;VoVNet&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#一些结论&quot;&gt;一些结论&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#参考资料&quot;&gt;参考资料&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;vgg&quot;&gt;VGG&lt;/h2&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;VGG&lt;/code&gt;网络结构参数表如下图所示。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/VGG.png&quot; width=&quot;60%&quot; alt=&quot;VGG&quot; /&gt;
&lt;/div&gt;

&lt;h2 id=&quot;resnet&quot;&gt;ResNet&lt;/h2&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNet&lt;/code&gt; 模型比 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;VGG&lt;/code&gt; 网络具有更少的滤波器数量和更低的复杂性。 比如 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Resnet34&lt;/code&gt; 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;FLOPs&lt;/code&gt; 为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;3.6G&lt;/code&gt;，仅为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;VGG-19&lt;/code&gt; &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;19.6G&lt;/code&gt; 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;18%&lt;/code&gt;。&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;注意，论文中算的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;FLOPs&lt;/code&gt;，把乘加当作 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;1&lt;/code&gt; 次计算。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNet&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;VGG&lt;/code&gt; 的网络结构连接对比图，如下图所示。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/resnet.png&quot; width=&quot;60%&quot; alt=&quot;resnet&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;不同层数的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Resnet&lt;/code&gt; 网络参数表如下图所示。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/resnet_network_parameter_table.png&quot; width=&quot;60%&quot; alt=&quot;resnet网络参数表&quot; /&gt;
&lt;/div&gt;

&lt;blockquote&gt;
  &lt;p&gt;看了后续的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNeXt&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNetv2&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Densenet&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CSPNet&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;VOVNet&lt;/code&gt; 等论文，越发觉得 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNet&lt;/code&gt; 真的算是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Backone&lt;/code&gt; 领域划时代的工作了，因为它让&lt;strong&gt;深层&lt;/strong&gt;神经网络可以训练，基本解决了深层神经网络训练过程中的梯度消失问题，并给出了系统性的解决方案（两种残差结构），即系统性的让网络变得更“深”了。而让网络变得更“宽”的工作，至今也没有一个公认的最佳方案（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Inception&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNeXt&lt;/code&gt; 等后续没有广泛应用），难道是因为网络变得“宽”不如“深”更重要，亦或是我们还没有找到一个更有效的方案。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h2 id=&quot;inceptionv3&quot;&gt;Inceptionv3&lt;/h2&gt;

&lt;p&gt;Inception v3 是一种图像识别模型，经证实可以对 ImageNet 数据集实现 78.1% 以上的准确率。该模型是数年来多位研究人员提出的诸多想法积淀的成果。它以 Szegedy 等人发表的《Rethinking the Inception Architecture for Computer Vision》原创性论文为理论依据。&lt;/p&gt;

&lt;p&gt;模型本身由对称和非对称构建块组成，包括卷积、平均池化、最大池化、串联、丢弃、全连接层。批量归一化也在模型中广泛应用，同时用于激活输入。损失是通过 Softmax 计算的。&lt;/p&gt;

&lt;p&gt;以下是该模型的简要图示：&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/inceptionv3onc--oview.png&quot; width=&quot;60%&quot; alt=&quot;inceptionv3 model&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;常见的一种 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Inception Modules&lt;/code&gt; 结构如下：&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/inception_module.jpg&quot; width=&quot;60%&quot; alt=&quot;Inception模块&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;论文地址 &lt;a href=&quot;https://arxiv.org/pdf/1512.00567.pdf&quot;&gt;https://arxiv.org/pdf/1512.00567.pdf&lt;/a&gt;。&lt;/p&gt;
&lt;h2 id=&quot;resnetv2&quot;&gt;Resnetv2&lt;/h2&gt;

&lt;p&gt;作者总结出&lt;strong&gt;恒等映射形式的快捷连接和预激活对于信号在网络中的顺畅传播至关重要&lt;/strong&gt;的结论。&lt;/p&gt;

&lt;h2 id=&quot;resnext&quot;&gt;ResNeXt&lt;/h2&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNeXt&lt;/code&gt; 的卷积block 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Resnet&lt;/code&gt; 对比图如下所示。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/comparison_of_resnext_s_convolutional_block_and_resnet.png&quot; width=&quot;60%&quot; alt=&quot;resnext的卷积block和resnet的对比图&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;ResNeXt 和 Resnet 的模型结构参数对比图如下图所示。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/comparison_of_resnext_s_structural_parameters_and_resnet.png&quot; width=&quot;60%&quot; alt=&quot;resnext的结构参数和resnet的对比图&quot; /&gt;
&lt;/div&gt;

&lt;h2 id=&quot;darknet53&quot;&gt;Darknet53&lt;/h2&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Darknet53&lt;/code&gt; 模型结构连接图，如下图所示。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/darknet53.png&quot; width=&quot;60%&quot; alt=&quot;darknet53&quot; /&gt;
&lt;/div&gt;

&lt;h2 id=&quot;densenet&quot;&gt;DenseNet&lt;/h2&gt;

&lt;blockquote&gt;
  &lt;p&gt;作者 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Gao Huang&lt;/code&gt; 于 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;2018&lt;/code&gt; 年发表的论文 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Densely Connected Convolutional Networks&lt;/code&gt;。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;&lt;strong&gt;在密集块（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DenseBlock&lt;/code&gt;）结构中，每一层都会将前面所有层 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;concate&lt;/code&gt; 后作为输入&lt;/strong&gt;。&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DenseBlock&lt;/code&gt;（类似于残差块的密集块结构）结构的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;3&lt;/code&gt; 画法图如下所示：&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/3_densenet_structure_drawing_methods.png&quot; width=&quot;60%&quot; alt=&quot;3种DenseNet结构画法&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;可以看出 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DenseNet&lt;/code&gt; 论文更侧重的是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DenseBlock&lt;/code&gt; 内各个卷积层之间的密集连接（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;dense connection&lt;/code&gt;）关系，另外两个则是强调每层的输入是前面所有层 feature map 的叠加，反映了 feature map 数量的变化。&lt;/p&gt;

&lt;h2 id=&quot;cspnet&quot;&gt;CSPNet&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CSPDenseNet&lt;/code&gt; 的一个阶段是由局部密集块和局部过渡层组成（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;a partial dense block and a partial transition layer&lt;/code&gt;）&lt;/strong&gt;。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/figure_3_several_different_forms_of_csp.png&quot; width=&quot;60%&quot; alt=&quot;Figure3几种不同形式的CSP&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CSP&lt;/code&gt; 方法可以减少模型计算量和提高运行速度的同时，还不降低模型的精度，是一种更高效的网络设计方法，同时还能和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Resnet&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Densenet&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Darknet&lt;/code&gt; 等 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;backbone&lt;/code&gt; 结合在一起。&lt;/p&gt;

&lt;h2 id=&quot;vovnet&quot;&gt;VoVNet&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;One-Shot Aggregation（只聚集一次）是指 OSA 模块的 concat 操作只进行一次，即只有最后一层($1\times 1$ 卷积)的输入是前面所有层 feature map 的 concat（叠加）&lt;/strong&gt;。&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;OSA&lt;/code&gt; 模块的结构图如图 1(b) 所示。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/VoVNet.png&quot; width=&quot;60%&quot; alt=&quot;VoVNet&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;在 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;OSA module&lt;/code&gt; 中，每一层产生两种连接，一种是通过 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;conv&lt;/code&gt; 和下一层连接，产生 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;receptive field&lt;/code&gt; 更大的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;feature map&lt;/code&gt;，另一种是和最后的输出层相连，以聚合足够好的特征。通过使用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;OSA module&lt;/code&gt;，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;5&lt;/code&gt; 层 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;43&lt;/code&gt; &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;channels&lt;/code&gt; 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DenseNet-40&lt;/code&gt; 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MAC&lt;/code&gt; 可以被减少 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;30%&lt;/code&gt;（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;3.7M -&amp;gt; 2.5M&lt;/code&gt;）。&lt;/p&gt;

&lt;p&gt;基于 OSA 模块构建的各种 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;VoVNet&lt;/code&gt; 结构参数表如下。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/various_vovnet_structures.png&quot; width=&quot;60%&quot; alt=&quot;各种VoVNet结构&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;作者认为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DenseNet&lt;/code&gt; 用更少的参数与 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Flops&lt;/code&gt; 而性能却比 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNet&lt;/code&gt; 更好，主要是因为&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;concat&lt;/code&gt; 比 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;add&lt;/code&gt; 能保留更多的信息。但是，实际上 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DenseNet&lt;/code&gt; 却比 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ResNet&lt;/code&gt;要慢且消耗更多资源。&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;GPU&lt;/code&gt; 的计算效率：&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;GPU&lt;/code&gt; 特性是擅长 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;parallel computation&lt;/code&gt;，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tensor&lt;/code&gt; 越大，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;GPU&lt;/code&gt; 使用效率越高。&lt;/li&gt;
  &lt;li&gt;把大的卷积操作拆分成碎片的小操作将不利于 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;GPU&lt;/code&gt; 计算。&lt;/li&gt;
  &lt;li&gt;设计 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;layer&lt;/code&gt; 数量少的网络是更好的选择。&lt;/li&gt;
  &lt;li&gt;1x1 卷积可以减少计算量，但不利于 GPU 计算。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;在 CenterMask 论文提出了 VoVNetv2，其卷积模块结构图如下：&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/backbone/VoVNetv2.png&quot; width=&quot;60%&quot; alt=&quot;VoVNetv2&quot; /&gt;
&lt;/div&gt;

&lt;h2 id=&quot;一些结论&quot;&gt;一些结论&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;当卷积层的输入输出通道数相等时，内存访问代价（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MAC&lt;/code&gt;）最小。&lt;/li&gt;
  &lt;li&gt;影响 CNN 功耗的主要因素在于内存访问代价 MAC，而不是计算量 FLOPs。&lt;/li&gt;
  &lt;li&gt;GPU 擅长并行计算，Tensor 越大，GPU 使用效率越高，把大的卷积操作拆分成碎片的小操作不利于 GPU 计算。&lt;/li&gt;
  &lt;li&gt;1x1 卷积可以减少计算量，但不利于 GPU 计算。&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;VGG/ResNet/Inception/ResNeXt/CSPNet&lt;/code&gt; 论文&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://blog.csdn.net/shanglianlm/article/details/106482678&quot;&gt;深度学习论文: An Energy and GPU-Computation Efficient Backbone Network for Object Detection及其PyTorch&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2025-05-20</pubDate>
      <link>https://c1lovez1.github.io/2025-05-20/Pytorch%E6%98%BE%E5%AD%98%E7%AE%A1%E7%90%86%E6%9C%BA%E5%88%B6%E4%B8%8E%E6%98%BE%E5%AD%98%E5%8D%A0%E7%94%A8%E5%88%86%E6%9E%90%E6%96%B9%E6%B3%95.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2025-05-20/Pytorch%E6%98%BE%E5%AD%98%E7%AE%A1%E7%90%86%E6%9C%BA%E5%88%B6%E4%B8%8E%E6%98%BE%E5%AD%98%E5%8D%A0%E7%94%A8%E5%88%86%E6%9E%90%E6%96%B9%E6%B3%95.html</guid>
    </item>
    
		
    
    <item>
      <title>DeepseekMoE 结构详解和代码实现</title>
      <description>&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#1-基础-moe-结构介绍&quot;&gt;1. 基础 MOE 结构介绍&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#2-deepseekmoe-结构介绍&quot;&gt;2. DeepseekMOE 结构介绍&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#21-gate-网络与-deepseekmoe-计算流程&quot;&gt;2.1 Gate 网络与 DeepseekMOE 计算流程&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#3-deepseekmoe-结构代码实现&quot;&gt;3. DeepseekMOE 结构代码实现&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#31-deepseekv2mlp-实现&quot;&gt;3.1 DeepseekV2MLP 实现&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#32-门控路由网络实现&quot;&gt;3.2 门控/路由网络实现&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#33-deepseekmoe-实现&quot;&gt;3.3 DeepseekMOE 实现&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#参考资料&quot;&gt;参考资料&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;1-基础-moe-结构介绍&quot;&gt;1. 基础 MOE 结构介绍&lt;/h2&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Mixtral&lt;/code&gt; 8x7B (announcement, model card) 是高质量的混合专家模型 (Mixed Expert Models，简称 MoEs) 的 Transformer 模型，或者说是一种稀疏的 mixture-of-experts 模型，采用纯解码器结构，并使用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MOE&lt;/code&gt; 结构，替换原始的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;FFN&lt;/code&gt; 结构。在每一层，对每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;token&lt;/code&gt;，存在一个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;router network&lt;/code&gt; 会挑选两组 “experts”(即参数量更小的 FFN）来分别处理该 token，并通过&lt;strong&gt;加法方式&lt;/strong&gt;融合两组 “experts” 的输出。&lt;/p&gt;

&lt;p&gt;基础版的（稀疏）MOE 结构图如下图所示:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/moe/basic_moe.png&quot; alt=&quot;basic_moe&quot; /&gt;&lt;/p&gt;

&lt;p&gt;MOE 通常由两部分组成：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;门控或 Router 网络&lt;/strong&gt;：模块负责根据输入 token 的特征动态选择激活哪些专家，路由器是由带学习的参数组成的网络。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;“experts” 网络（小型 FFN）&lt;/strong&gt;：每层 MOE 都包含若干个（稀疏）专家网络，其通常是小型的 FFN，&lt;strong&gt;在实际推理中只有部分专家(通常 8 个)会被激活参与计算&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;2-deepseekmoe-结构介绍&quot;&gt;2. DeepseekMOE 结构介绍&lt;/h2&gt;

&lt;p&gt;和基础 MOE 结构的区别是：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;更精细地划分专家网络&lt;/strong&gt;，提升每个专家的专业性，提高知识表达的准确度。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;引入部分共享专家&lt;/strong&gt;，减少不同专家间的知识冗余，提升计算效率；所有 tokens 都会经过的共享专家，每个 token 会用计算的 Router 权重，来选择 topK 个专家，然后和共享的专家的输出一起加权求和。&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;DeepseekMOE 其实是有两类专家的：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;共享专家（Shared Expert）：1 个共享专家，用于捕捉&lt;strong&gt;通用&lt;/strong&gt;、全局的特征信息。&lt;/li&gt;
  &lt;li&gt;路由专家（Routed Experts）：每个 MoE 层都包含 256 个路由专家，负责精细化处理输入 tokens 的专业特征。&lt;/li&gt;
&lt;/ul&gt;

&lt;h3 id=&quot;21-gate-网络与-deepseekmoe-计算流程&quot;&gt;2.1 Gate 网络与 DeepseekMOE 计算流程&lt;/h3&gt;

&lt;p&gt;当一个 token 的向量传入 MoE 层时，首先会经过一个专门的 Gate 网络，该网络负责计算 token 与各个路由专家之间的匹配得分。具体流程如下：&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;计算 tokens 和专家的匹配得分&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;Gate 网络通过&lt;strong&gt;线性变换&lt;/strong&gt;计算每个 token 与所有路由专家的兼容性得分。得分可以反映 token 和各专家“契合”的程度。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;选择 Top-K 专家&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;基于得分，Gate 网络为每个 token 选择 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Top-K&lt;/code&gt; 个最合适的路由专家。在 DeepSeek‐V3 中，每个 token 通常选择 8 个路由专家（在一些实现中还可能对跨节点路由做限制，如最多路由到 4 个不同节点），从而只激活极少数专家进行计算。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;专家处理与加权聚合&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;被选中的专家各自对 token 进行独立处理（专家实际上是小型 FFN 模块，类似于 Transformer 中的 FFN 模块），并产生各自的输出。然后，&lt;strong&gt;这些专家的输出会根据 Gate 网络给出的得分权重进行加权聚合&lt;/strong&gt;，最后再和共享专家的输出进行融合，形成当前 MoE 层的最终输出表示。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;DeepseekV2 模型的 MOE 参数如下：&lt;/p&gt;

&lt;div class=&quot;language-json highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;err&quot;&gt;//&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;err&quot;&gt;部分参数省略&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;hidden_act&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;silu&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;hidden_size&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5120&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;initializer_range&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.02&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;intermediate_size&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;12288&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;model_type&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;deepseek_v2&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;moe_intermediate_size&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1536&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;moe_layer_freq&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;n_group&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;n_routed_experts&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;160&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;n_shared_experts&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;norm_topk_prob&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;kc&quot;&gt;false&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;num_experts_per_tok&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;num_hidden_layers&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;60&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;num_key_value_heads&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;topk_group&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
  &lt;/span&gt;&lt;span class=&quot;nl&quot;&gt;&quot;topk_method&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;w&quot;&gt; &lt;/span&gt;&lt;span class=&quot;s2&quot;&gt;&quot;group_limited_greedy&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;w&quot;&gt;
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;混合专家（MoE）参数说明：&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/moe/moe_params.png&quot; width=&quot;60%&quot; alt=&quot;moe_params&quot; /&gt;
&lt;/div&gt;

&lt;h2 id=&quot;3-deepseekmoe-结构代码实现&quot;&gt;3. DeepseekMOE 结构代码实现&lt;/h2&gt;

&lt;p&gt;这里只考虑推理模式下的 DeepseekMOE 结构实现，且分步实现。&lt;/p&gt;

&lt;h3 id=&quot;31-deepseekv2mlp-实现&quot;&gt;3.1 DeepseekV2MLP 实现&lt;/h3&gt;

&lt;p&gt;专家其实就是参数量更少的 FFN/MLP 结构，和 llama 中结构一样，只是参数量和计算量更少了，DeepseekV2MLP 代码如下所示。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;DeepseekV2MLP&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gate_proj&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;up_proj&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;down_proj&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;act_fn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ACT2FN&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_act&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# silu 激活函数
&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mlp_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;down_proj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;act_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gate_proj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;up_proj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mlp_out&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;32-门控路由网络实现&quot;&gt;3.2 门控/路由网络实现&lt;/h3&gt;

&lt;p&gt;门控网络的作用是，根据输入 tokens 动态的选择 Top-K 个专家，并为每个 Token 分配权重。关键流程如下：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;门控分数计算&lt;/strong&gt;：通过线性层 + Softmax 生成专家选择概率分布。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Top-K 专家选择&lt;/strong&gt;：支持两种模式（贪婪选择 vs 分组限制贪婪选择），贪婪模式直接使用 torch.topk 函数&lt;strong&gt;选取分数张量中的前 k 个分数&lt;/strong&gt;。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;权重归一化&lt;/strong&gt;：对 Top-K 权重进行归一化或缩放。&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;代码实现如下所示:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;math&lt;/span&gt; 
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;dataclasses&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dataclass&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;MoEGate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;top_k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_experts_per_tok&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_routed_experts&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_routed_experts&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;routed_scaling_factor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;routed_scaling_factor&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scoring_func&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scoring_func&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_method&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_method&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_group&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_group&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_group&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_group&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm_topk_prob&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm_topk_prob&lt;/span&gt;
        
        &lt;span class=&quot;c1&quot;&gt;# 静态化推理配置（假设配置固定）
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inference_norm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;norm_topk_prob&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;top_k&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_group_limited&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_method&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;group_limited_greedy&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 门控权重
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gating_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weight&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Parameter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;empty&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_routed_experts&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gating_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reset_parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;reset_parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kaiming_uniform_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;math&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inference_mode&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 禁用梯度与训练逻辑
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;bsz&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;h&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        
        &lt;span class=&quot;c1&quot;&gt;# 门控分数计算（保持原始数据类型）
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# [n_tokens, n_experts]
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 自动推断 dtype
&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Top-K 选择（静态分支）
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;use_group_limited&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;# 分组限制逻辑优化
&lt;/span&gt;            &lt;span class=&quot;n&quot;&gt;group_scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bsz&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;values&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;group_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;group_scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sorted&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;group_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros_like&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;group_scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scatter_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;group_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;score_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;group_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;unsqueeze&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_routed_experts&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bsz&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;masked_fill&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;~&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;score_mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        
        &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;top_k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sorted&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 权重归一化（静态分支）
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inference_norm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;keepdim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1e-20&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;routed_scaling_factor&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# aux_loss 始终为 None
&lt;/span&gt;
&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataclass&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;DeepseekV2Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 1, Position Config
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;max_position_embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;163840&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;vocab_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;102400&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 2, MLA Config
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# down_linear config
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;q_lora_rank&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1536&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_lora_rank&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;512&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# head_dim、heads and hidden_size config
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;v_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;qk_nope_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5120&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_attention_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_key_value_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;
    
    &lt;span class=&quot;n&quot;&gt;attention_bias&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;attention_dropout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# rope config
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;rope_theta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10000&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3, MOE Config
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;n_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;n_routed_experts&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;160&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_experts_per_tok&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;topk_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;routed_scaling_factor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;scoring_func&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;softmax&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;topk_method&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;greedy&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;norm_topk_prob&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 初始化配置
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DeepseekV2Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 模拟输入，CPU 电脑可直接跑，去除了 cuda 设备限制代码
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;cuda&quot;&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_available&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;cpu&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5120&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 创建模块
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;moe_gate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;MoEGate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 半精度推理
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# gate 网络推理
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;moe_gate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;topk_idx shape &quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# 32 * 64 = 2048 个 tokens
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;topk_weight shape&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
# 输出如下，表示每个 token 会激活 6 个专家参与计算
topk_idx shape  torch.Size([2048, 6]) 
topk_weight shape torch.Size([2048, 6])
&quot;&quot;&quot;&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;33-deepseekmoe-实现&quot;&gt;3.3 DeepseekMOE 实现&lt;/h3&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;门控计算&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;调用门控网络（self.gate），对输入 hidden_states 计算得到 top‑k 专家索引（topk_idx）、对应权重（topk_weight）以及辅助损失（aux_loss，推理时不参与梯度计算）。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;数据重排&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;将输入 hidden_states 展平为二维张量（形状 [B * T, d]），并将 topk_idx 也展平。&lt;/li&gt;
      &lt;li&gt;在推理模式下，通常不需要像训练时那样对每个 token 进行 repeat_interleave，因为每个 token 只会由对应专家处理一次。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;专家计算&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;根据展平后的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;topk_idx&lt;/code&gt;，依次对每个专家负责的 token 子集进行计算。&lt;/li&gt;
      &lt;li&gt;由于这里可能存在多个 token 被分配给不同专家，实际实现中需要将每个专家的输出按顺序记录下来。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;输出重构与加权融合&lt;/strong&gt;
    &lt;ul&gt;
      &lt;li&gt;将所有专家计算的输出进行合并。通过将输出重新整理（排序）回原始 token 顺序，并按照 topk_weight 对各个专家输出进行加权求和，从而获得最终输出。&lt;/li&gt;
      &lt;li&gt;整个过程保证最终输出形状与原始输入保持一致，即 [B, T, d]。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;代码实现如下所示：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# 为了单元测试，模拟不使用分布式（ep_size默认为1）
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;DeepseekV2MoE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    A mixed expert module containing shared experts.
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_experts_per_tok&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_experts_per_tok&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;experts&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ModuleList&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;DeepseekV2MLP&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;moe_intermediate_size&lt;/span&gt;
                &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_routed_experts&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;MoEGate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_shared_experts&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;moe_intermediate_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_shared_experts&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shared_experts&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DummyMLP&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intermediate_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 此处为简化实现，仅做推理示例，不涉及分布式通信
&lt;/span&gt;    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;no_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;moe_infer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# x: [batch * seq_len, hidden_size]
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;# 对每个 token 依然采用与训练类似的方式进行专家计算
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;flat_topk_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;expert&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;enumerate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;experts&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flat_topk_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;continue&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expert&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]))&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# 简单拼接，不做复杂排序和 all-to-all 操作
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;outs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;new_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;empty_like&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# 这里直接返回加权求和的结果（实际实现更复杂）
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;final_out&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;topk_weight&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;unsqueeze&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;final_out&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://arxiv.org/pdf/2401.04088&quot;&gt;Mixtral of Experts&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://arxiv.org/pdf/2101.03961&quot;&gt;Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://huggingface.co/blog/zh/moe&quot;&gt;混合专家模型 (MoE) 详解&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/22570639120&quot;&gt;MOE 大模型架构与机制详解 —— 以 DeepSeek‑v3 为例&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2025-02-12</pubDate>
      <link>https://c1lovez1.github.io/2025-02-12/deepseek-moe-code.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2025-02-12/deepseek-moe-code.html</guid>
    </item>
    
		
    
    <item>
      <title>MLA 结构代码实现及优化</title>
      <description>&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#1-mla-实现拆解&quot;&gt;1. MLA 实现拆解&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#11-q-向量计算&quot;&gt;1.1 Q 向量计算&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#12-kv-向量计算&quot;&gt;1.2 KV 向量计算&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#13-self-attention-计算&quot;&gt;1.3 Self-Attention 计算&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#2-标准-mla-模块的代码实现&quot;&gt;2. 标准 MLA 模块的代码实现&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#3-mla-模块的代码优化-projection-absorption&quot;&gt;3 MLA 模块的代码优化-Projection Absorption&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#31-cc-cachecompressed&quot;&gt;3.1 CC (CacheCompressed）&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#32-a_ccabsorbcachecompressed&quot;&gt;3.2 A_CC（AbsorbCacheCompressed）&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#参考资料&quot;&gt;参考资料&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;1-mla-实现拆解&quot;&gt;1. MLA 实现拆解&lt;/h2&gt;

&lt;p&gt;DeepDeekv2 的模型配置如下所示:&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/deepseekv2/deepseekv2_config.png&quot; width=&quot;40%&quot; alt=&quot;deepseekv2_config&quot; /&gt;
&lt;/div&gt;

&lt;h3 id=&quot;11-q-向量计算&quot;&gt;1.1 Q 向量计算&lt;/h3&gt;
&lt;blockquote&gt;
  &lt;p&gt;大部分参考 &lt;a href=&quot;https://github.com/madsys-dev/deepseekv2-profile/blob/main/workspace/blog/optimizing-mla.md&quot;&gt;DeepSeek-V2高性能推理优化笔记：MLA优化&lt;/a&gt;，部分细节做了修改和优化， MLA 结构图以及这章节的公式更多的是给出 MLA 过程和细节，实际的代码实现没有一一对应。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;1，在 DeepSeek-V2 中，Q 向量也采用了低秩压缩的方式。首先，将输入向量投影到一个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;1536&lt;/code&gt;（&lt;strong&gt;对应模型配置文件中的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;q_lora_rank&lt;/code&gt; 参数&lt;/strong&gt;）维的低维空间，得到 Latent $c_t^Q$。&lt;/p&gt;

\[c_t^Q = W^{DQ} h_t \in \mathbb{R}^{B \times L \times 1536}\]

&lt;p&gt;2，然后，再将其投影到 $\mathbb{R}^{H \times 128}$ 的多头向量空间上（其中 $H=128$ 是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;heads&lt;/code&gt; 数，对应配置文件中的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;qk_nope_head_dim&lt;/code&gt; 参数），得到了 Q 向量的第一部分: $q_t^C$。&lt;/p&gt;

\[q_t^C = W^{UQ} c_t^Q \in \mathbb{R}^{B \times L \times H \times 128}\]

&lt;p&gt;3，再将其投影到 $\mathbb{R}^{H \times 64}$（对应模型配置文件中的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;qk_rope_head_dim&lt;/code&gt; 参数）上，并使用 RoPE 嵌入位置信息，得到 Q 向量的第二部分: $q_t^R$。&lt;/p&gt;

\[q_t^R = \mathrm{RoPE}(W^{QR} h_t) \in \mathbb{R}^{B \times L \times H \times 64}\]

&lt;p&gt;4，最后，将这两部分进行 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;concat&lt;/code&gt; 拼接得到最终的 $Q$ 向量：$q_t$。&lt;/p&gt;

\[q_t = [q_t^C, q_t^R] \in \mathbb{R}^{B \times L \times H \times 192}\]

&lt;p&gt;其中：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;$B$: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch_size&lt;/code&gt; 批量大小；&lt;/li&gt;
  &lt;li&gt;$L$: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seq_len&lt;/code&gt; 序列长度；&lt;/li&gt;
  &lt;li&gt;$H$: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;heads&lt;/code&gt; 注意力头数；&lt;/li&gt;
  &lt;li&gt;$\mathbb{R}$ 的最后一维是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;head_dim&lt;/code&gt;。&lt;/li&gt;
&lt;/ul&gt;

&lt;h3 id=&quot;12-kv-向量计算&quot;&gt;1.2 KV 向量计算&lt;/h3&gt;

&lt;p&gt;1，计算 $KV$ 向量时，首先，将输入向量投影到一个 $512$（&lt;strong&gt;对应模型配置文件中的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kv_lora_rank&lt;/code&gt; 参数&lt;/strong&gt;）维的低维空间，得到 Latent $c_t^{KV}$。&lt;/p&gt;

\[c_t^{KV} = W^{DKV} h_t \in \mathbb{R}^{B \times L \times 512}\]

&lt;p&gt;2，然后，和 $Q$ 向量的计算过程类似，再将其投影到 $\mathbb{R}^{H \times 128}$ 的多头向量空间上（其中 $H=128$ 是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;heads&lt;/code&gt; 数，$128$ 对应模型配置文件中的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;qk_rope_head_dim&lt;/code&gt; 参数，得到了 $K$ 向量的第一部分 $k_t^C$。&lt;/p&gt;

\[k_t^C = W^{UK}c_t^{K} \in \mathbb{R}^{B\times L\times H\times 128}\]

&lt;p&gt;3，将输入向量投影到 $64$（对应模型配置文件中的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;qk_rope_head_dim&lt;/code&gt; 参数）维向量空间，并应用 RoPE 嵌入位置信息得到 $K$ 向量的第二部分： $k_t^R$。&lt;/p&gt;

\[k_t^R = \mathrm{RoPE}(W^{KR} h_t) \in \mathbb{R}^{B \times L \times 1 \times 64}\]

&lt;p&gt;4，最后，和 $Q$ 不同的是，完整的 $K$ 是将 $k_t^R$ &lt;strong&gt;广播到每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;head&lt;/code&gt; 后与 $k_t^C$ &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;concate&lt;/code&gt; 拼接得到&lt;/strong&gt;：&lt;/p&gt;

\[k_t = \begin{bmatrix}
    k_{t,1}^C &amp;amp; k_t^R \\ 
    k_{t,2}^C &amp;amp; k_t^R \\
    \vdots &amp;amp; \vdots \\
    \end{bmatrix} \in \mathbb{R}^{B \times L \times H \times 192}\]

&lt;p&gt;上述广播后拼接的方式意味着，&lt;strong&gt;每个 head 的 RoPE 部分是完全相同的&lt;/strong&gt;。&lt;/p&gt;

&lt;p&gt;$V$ 向量因为不需要执行 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ROPE&lt;/code&gt; 操作，所以它的的计算较为简单，直接将 $c_t^{KV}$ 解压缩（升维）到 $\mathbb{R}^{H \times 128}$ 即可：&lt;/p&gt;

\[\mathbf{v}_t = W^{UV} c_t^{KV} \in \mathbb{R}^{B \times L \times H \times 128}\]

&lt;p&gt;注意: $k_t^R$ 和 $c_t^{KV}$ 是需要缓冲的向量。前面计算得到 $q_t$、$k_t$ 和 $\mathbf{v}_t$ 用来执行 self-attention 计算。&lt;/p&gt;

&lt;h3 id=&quot;13-self-attention-计算&quot;&gt;1.3 Self-Attention 计算&lt;/h3&gt;

&lt;p&gt;Self-Attention 的计算过程和传统的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MHA&lt;/code&gt; 一模一样。同样也是首先计算 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;attention score&lt;/code&gt;：&lt;/p&gt;

&lt;!--  --&gt;
&lt;p&gt;\(p = \mathrm{softmax}\left(\frac{q_t^\top k_t + \mathrm{Mask}}{\sqrt{192}}\right) = 
\mathrm{softmax}\left(\frac{{q_t^C}^\top k_t^C + {q_t^R}^\top k_t^R + \mathrm{Mask}}{\sqrt{128 + 64}} \right)
\mathrm{softmax}\left(\frac{{q_t^C}^\top k_t^C + {q_t^R}^\top k_t^R + \mathrm{Mask}} {\sqrt{128 + 64}} \right)
\in \mathbb{R}^{B \times L \times H \times L}\)
&lt;!--  --&gt;&lt;/p&gt;

&lt;p&gt;计算对 $V$的加权和，并将所有 heads 压平（即 heads * head_dim），得到 Attention 输出：&lt;/p&gt;

\[o = p \cdot \mathbf{v}_t \in \mathbb{R}^{B \times L \times H \times 128} \cong \mathbb{R}^{B \times L \times 16384}\]

&lt;p&gt;其中，$16384 = 128 \times 128 = \text{num\;attention\;heads * v\;head\;dim}$。最后，经过另一个注意力输出矩阵的投影（5120 是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;hidden_size&lt;/code&gt;），就能得到 MLA 的最终输出：&lt;/p&gt;

\[u = W^O o \in \mathbb{R}^{B \times L \times 5120}\]

&lt;h2 id=&quot;2-标准-mla-模块的代码实现&quot;&gt;2. 标准 MLA 模块的代码实现&lt;/h2&gt;

&lt;p&gt;transformers 库中的 &lt;a href=&quot;https://huggingface.co/deepseek-ai/DeepSeek-V2-Lite/blob/main/modeling_deepseek.py&quot;&gt;modeling_deepseek.py&lt;/a&gt; 是没有经过推理加速优化的原始实现，我参考其实现给出了一个更为精简和更易看懂的版本，完整代码在&lt;a href=&quot;https://github.com/harleyszhang/llm_note/blob/main/1-transformer_model/src/deepseekv_mla.py&quot;&gt;这里&lt;/a&gt;。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# 从 LlamaAttention 修改而来，适配 DeepseekV2 模型的注意力模块，简单版本不带 kv cache
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;DeepseekV2MLA&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DeepseekV2Config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# MHA 初始化相关
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_attention_heads&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_head_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_head_dim&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;o_proj&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_head_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; 
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attention_bias&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attention_dropout&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attention_dropout&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;training&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_nope_head_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_nope_head_dim&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# MLA 相关 part1: 压缩
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_lora_rank&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_lora_rank&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_lora_rank&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_lora_rank&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_down_proj&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_lora_rank&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_down_rmsnorm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DeepseekV2RMSNorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_lora_rank&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_down_proj&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; 
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_lora_rank&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_down_rmsnorm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DeepseekV2RMSNorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_lora_rank&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        
        &lt;span class=&quot;c1&quot;&gt;# MLA 相关 part2: 解压缩
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_head_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_nope_head_dim&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_up_proj&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_lora_rank&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; 
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# qk_nope_head_dim = q_head_dim - qk_rope_head_dim
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_up_proj&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_lora_rank&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; 
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_head_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        
        &lt;span class=&quot;c1&quot;&gt;# MLA 相关 part3: 切片 q k 张量，以及 rope 旋转位置编码
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;rotary_emb&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DeepseekV2RotaryEmbedding&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;max_position_embeddings&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;rope_theta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; 

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;position_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;casual_mask&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 1，q 压缩和解压缩，以及 split to q_nope, q_rope
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_up_proj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_down_rmsnorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_down_proj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;q_nope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_rope&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_nope_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 2, kv 压缩和解压缩
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;kv_down&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_down_proj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hidden_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        
        &lt;span class=&quot;c1&quot;&gt;# compressed_kv 压缩后的 kv 张量
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;compressed_kv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_rope&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;kv_down&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_lora_rank&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# num_heads = 1 后续广播其它 heads 上
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;k_rope&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_rope_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 对 compressed_kv 解压缩
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;kv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_up_proj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_down_rmsnorm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;compressed_kv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_nope_head_dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;k_nope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;value_states&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;split&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;kv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_nope_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 3, 计算 cos 和 sin，并应用 rope 旋转位置编码
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;kv_seq_len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;value_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# shape (b, nums_head, seq_len, v_head_dim)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sin&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;rotary_emb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;value_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        
        &lt;span class=&quot;n&quot;&gt;q_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_rope&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;apply_rotary_pos_emb&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cos&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;position_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 4, 执行 self-attention 计算
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;query_states&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;concat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_nope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;key_states&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;concat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k_nope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_rope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;expand&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)],&lt;/span&gt; 
            &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# qk^t
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;query_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;math&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;casual_mask&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;masked_fill&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;casual_mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;-inf&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        
        &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scores&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;query_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dropout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attention_dropout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;training&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;training&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# attn_weights shape: [bs, num_heads, seq_len, seq_len]
&lt;/span&gt;        
        &lt;span class=&quot;n&quot;&gt;attn_output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;value_states&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# shape: [bs, num_heads, seq_len, head_dim]
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;attn_output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attn_output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;transpose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;contiguous&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v_head_dim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 5, MLA 输出映射
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;o_proj&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attn_output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;3-mla-模块的代码优化-projection-absorption&quot;&gt;3 MLA 模块的代码优化-Projection Absorption&lt;/h2&gt;

&lt;h3 id=&quot;31-cc-cachecompressed&quot;&gt;3.1 CC (CacheCompressed）&lt;/h3&gt;

&lt;p&gt;在 transformers 的最新开源版本中， MLA 算子改为缓存&lt;strong&gt;压缩后的 KV Cache&lt;/strong&gt;，并将 RoPE 后的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;k_pe&lt;/code&gt; 一并缓存入 KV Cache 中，与缓存完整的 KV Cache 相比，这将大大减少每个 token 的每层Cache 大小。&lt;/p&gt;

&lt;h3 id=&quot;32-a_ccabsorbcachecompressed&quot;&gt;3.2 A_CC（AbsorbCacheCompressed）&lt;/h3&gt;

&lt;p&gt;上述 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CacheCompressed&lt;/code&gt; 的实现代码其实并不能实质减少 KV Cache 过大的问题，因为在计算 MLA 的时候，仍然需要存储解压后的完整的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;KV Cache&lt;/code&gt;（中间激活），这很可能引起 OOM 崩溃。&lt;/p&gt;

&lt;p&gt;DeepSeek-V2 论文中提出，可以将 KV 的解压缩矩阵吸收到Q-projection 和 Out-projection 中，从而可以在不解压缩 KV Cache的 情况下直接计算最终的 Attention 结果。&lt;/p&gt;

&lt;p&gt;1，&lt;strong&gt;对于 K 的吸收&lt;/strong&gt;（吸收进 self-attention 算子中， 相当于算子合并），在 Attention Score 的计算公式中，K 向量的非 RoPE 部分的可以做如下展开：&lt;/p&gt;

\[{q_t^C}^\top k_t^C = (W^{UQ} c_t^Q)^{\top} W^{UK} c_t^{KV} = {c_t^Q}^{\top}{W^{UQ}}^{\top} W^{UK} c_t^{KV} = ({c_t^Q}^{\top}{W^{UQ}}^{\top} W^{UK}) c_t^{KV}\]

&lt;p&gt;即通过矩阵乘法结合律，可以改为计算 $({c_t^Q}^{\top}{W^{UQ}}^{\top} W^{UK})$，避免了解压缩出完整的 $K$ 矩阵。另外，在原始版本的解压缩的过程中，由于每个 token 的 key 都需要与 $W^{UK}$ 相乘才能得到，因此计算量较大；矩阵吸收后，$W^{UK}$ 只需要对 $q_t^C$ 这一个向量相乘，也大大减少了浮点计算量。&lt;/p&gt;

&lt;p&gt;总结：&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;A_CC&lt;/code&gt; 相比于 CC，把原来属于单 kv 的计算量转移到 q 上了，而 q 的 seq_len=1，可减少计算量。&lt;/p&gt;

&lt;p&gt;其中，$c_t^{KV}$ 是我们实际保存的 KV cache。&lt;/p&gt;

&lt;p&gt;2，&lt;strong&gt;$V$ 的吸收&lt;/strong&gt;，其实现更为复杂。为了更方便表述，采用 Einstein 求和约定描述该过程：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;v_t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdc,blc-&amp;gt;blhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# (1) 生成值向量 v_t
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;o&lt;/span&gt;   &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bqhl,blhd-&amp;gt;bqhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;     &lt;span class=&quot;c1&quot;&gt;# (2) 加权求和得到 o
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;u&lt;/span&gt;   &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdD,bhqd-&amp;gt;bhD&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;       &lt;span class=&quot;c1&quot;&gt;# (3) 投影到最终输出 u
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# 将上述三式合并，得到总的计算过程
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;u&lt;/span&gt;   &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdc,blc,bqhl,hdD-&amp;gt;bhD&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 利用结合律改变计算顺序
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;o_&lt;/span&gt;  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bhql,blc-&amp;gt;bhqc&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# (4) 避免显式生成 v_t，减少存储 (b, l, h, d) 的开销。
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;o&lt;/span&gt;   &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bhqc,hdc-&amp;gt;bhqd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# (5) 延迟投影操作，减少计算量。
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;u&lt;/span&gt;   &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdD,bhqd-&amp;gt;bhD&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;     &lt;span class=&quot;c1&quot;&gt;# (6)
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;其中生成值向量 v_t：
输入：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;W_UV：权重矩阵，形状为 (h, d, c)，其中 h 是注意力头数，d 是值向量维度，c 是输入特征维度。&lt;/li&gt;
  &lt;li&gt;c_t_KV：键值上下文向量，形状为 (b, l, c)，其中 b 是批次大小，l 是序列长度。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;操作：&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;将每个位置的 c 维特征通过 W_UV 投影到 d 维，生成多头值向量 v_t，形状为 (b, l, h, d)。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;改变计算顺序的优化: &lt;strong&gt;通过结合律调整计算顺序，减少中间张量的内存占用&lt;/strong&gt;：
先计算加权上下文 o_:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;操作：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;将注意力权重 attn_weights 直接作用于原始上下文 c_t_KV，生成中间结果 o_，形状为 (b, h, q, c)。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;意义：&lt;/strong&gt;&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;避免显式生成 v_t，减少存储 (b, l, h, d) 的开销。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;上述优化方法的实现和对比测试代码如下所示:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;time&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 配置参数
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;256&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 将 h 调整为 64
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_warmup&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;   &lt;span class=&quot;c1&quot;&gt;# 预热次数
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_trials&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 正式测试次数
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# 初始化张量（GPU）
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;cuda&quot;&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_available&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;cpu&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;l&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# h 维度为 64
&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# 预热 GPU
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_warmup&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdc,blc-&amp;gt;blhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bqhl,blhd-&amp;gt;bqhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdD,bhqd-&amp;gt;bhD&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 原始分步实现
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;original_method&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;v_t&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdc,blc-&amp;gt;blhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;o&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bqhl,blhd-&amp;gt;bqhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;u&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdD,bhqd-&amp;gt;bhD&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;permute&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 优化后实现
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;optimized_method&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;o_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bhql,blc-&amp;gt;bhqc&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;permute&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;o&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bhqc,hdc-&amp;gt;bhqd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;u&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdD,bhqd-&amp;gt;bhD&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 测量时间
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;benchmark&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;func&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;times&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_trials&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;start&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;func&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;end&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;times&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;end&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;start&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;times&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_trials&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 执行测试
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time_original&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;benchmark&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;original_method&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1000&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 转换为毫秒
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time_optimized&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;benchmark&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimized_method&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1000&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 打印结果
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;原始方法平均时间: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time_original&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt; ms&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;优化方法平均时间: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time_optimized&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt; ms&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;速度提升: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;time_original&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;time_optimized&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 验证等价性
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;validate_equivalence&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;v_t_orig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdc,blc-&amp;gt;blhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;o_orig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bqhl,blhd-&amp;gt;bqhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_t_orig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;u_orig&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdD,bhqd-&amp;gt;bhD&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o_orig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;permute&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    
    &lt;span class=&quot;n&quot;&gt;v_t_opt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdc,blc-&amp;gt;blhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_UV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_t_KV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;o_opt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;bqhl,blhd-&amp;gt;bqhd&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attn_weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_t_opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;u_opt&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;einsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;hdD,bhqd-&amp;gt;bhD&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_o&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;o_opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;permute&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;# 检查是否等价
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allclose&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;u_orig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;u_opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;atol&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1e-4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;两种方法结果不一致！&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;两种方法结果一致，验证通过。&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 调用验证函数
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;validate_equivalence&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
原始方法平均时间: 28.649 ms
优化方法平均时间: 20.378 ms
速度提升: 40.6%
两种方法结果一致，验证通过。
&quot;&quot;&quot;&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://arxiv.org/pdf/2405.04434&quot;&gt;DeepSeek-V2 论文&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://github.com/madsys-dev/deepseekv2-profile/blob/main/workspace/blog/optimizing-mla.md&quot;&gt;DeepSeek-V2高性能推理优化笔记：MLA优化&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://mp.weixin.qq.com/s/E7NwwMYw14FRT6OKzuVXFA&quot;&gt;再读MLA，还有多少细节是你不知道的&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://bruceyuan.com/llms-zero-to-hero/the-way-of-moe-model-evolution.html#_2-%E7%89%88%E6%9C%AC2-sparsemoe-%E5%A4%A7%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83%E4%BD%BF%E7%94%A8&quot;&gt;LLMs-Zero-to-Hero&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2025-02-10</pubDate>
      <link>https://c1lovez1.github.io/2025-02-10/mla-code.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2025-02-10/mla-code.html</guid>
    </item>
    
		
    
    <item>
      <title>DeepSeekV2 论文解读</title>
      <description>&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#1-介绍&quot;&gt;1. 介绍&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#2-架构&quot;&gt;2. 架构&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#21-多头潜变量注意力mla提升推理效率&quot;&gt;2.1 多头潜变量注意力（MLA）：提升推理效率&lt;/a&gt;
        &lt;ul&gt;
          &lt;li&gt;&lt;a href=&quot;#211-standard-multi-head-attention&quot;&gt;2.1.1 Standard Multi-Head Attention&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#212-low-rank-key-value-joint-compression&quot;&gt;2.1.2 Low-Rank Key-Value Joint Compression&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#213-decoupled-rotary-position-embedding&quot;&gt;2.1.3 Decoupled Rotary Position Embedding&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#214-kv-cache-大小的比较&quot;&gt;2.1.4 kv cache 大小的比较&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#215-总结&quot;&gt;2.1.5 总结&lt;/a&gt;&lt;/li&gt;
        &lt;/ul&gt;
      &lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#22-deepseekmoe以经济成本训练强大的模型&quot;&gt;2.2 DeepSeekMoE：以经济成本训练强大的模型&lt;/a&gt;
        &lt;ul&gt;
          &lt;li&gt;&lt;a href=&quot;#222-设备受限路由device-limited-routing&quot;&gt;2.2.2. 设备受限路由（Device-Limited Routing）&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#223-负载均衡的辅助损失auxiliary-loss-for-load-balance&quot;&gt;2.2.3. 负载均衡的辅助损失（Auxiliary Loss for Load Balance）&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#224-token-dropping-策略&quot;&gt;2.2.4. Token-Dropping 策略&lt;/a&gt;&lt;/li&gt;
        &lt;/ul&gt;
      &lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#参考资料&quot;&gt;参考资料&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;1-介绍&quot;&gt;1. 介绍&lt;/h2&gt;

&lt;p&gt;DeepSeek-V2 是一种高效的开源混合专家（MoE）语言模型，基于创新的 Transformer 架构，实现了经济的训练和高效的推理。DeepSeek-V2 具有 2360 亿个参数(&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;236B&lt;/code&gt;)，每个 token 激活 21 亿个参数，支持 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;128K&lt;/code&gt; tokens 的上下文长度。&lt;/p&gt;

&lt;p&gt;和 DeepSeekV1 模型结构沿用 llama 结构不同，DeepSeekV2 提出了多头潜在注意力（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MLA&lt;/code&gt;）和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DeepSeekMoE&lt;/code&gt;，旨在优化 Transformer 框架中的注意力模块和前馈网络（FFNs）。&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MLA&lt;/code&gt;: Multi-head Latent Attention 结构，通过&lt;strong&gt;低秩键值联合压缩&lt;/strong&gt;，&lt;strong&gt;减少了推理时的 KV 缓存&lt;/strong&gt;，从而提高了推理效率。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DeepSeekMoE&lt;/code&gt;: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;FFN&lt;/code&gt;（标准 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MOE&lt;/code&gt;） 的优化版。
    &lt;ul&gt;
      &lt;li&gt;&lt;strong&gt;细粒度专家划分(Routed Expert)&lt;/strong&gt;：相比标准 MOE，DeepSeekMoE 在保持参数量不变的前提下，通过减小每个 Expert 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;FFN&lt;/code&gt; 维度，来增加 Expert 数量，进行更细粒度专家划分。&lt;/li&gt;
      &lt;li&gt;&lt;strong&gt;共享专家隔离(Shared Expert)&lt;/strong&gt;: 用于表示 Routed Expert 中的共用知识信息，减少 Routed Expert 的知识冗余问题。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;DeepSeek-V2 架构图如下所示：&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/deepseekv2/architecture_of_DeepSeek-V2.png&quot; alt=&quot;architecture_of_DeepSeekv2&quot; /&gt;&lt;/p&gt;

&lt;h2 id=&quot;2-架构&quot;&gt;2. 架构&lt;/h2&gt;

&lt;p&gt;本节介绍 MLA 和 DeepSeekMoE 的详细设计。&lt;/p&gt;

&lt;h3 id=&quot;21-多头潜变量注意力mla提升推理效率&quot;&gt;2.1 多头潜变量注意力（MLA）：提升推理效率&lt;/h3&gt;

&lt;p&gt;传统的 Transformer 模型通常采用多头注意力（MHA），但在生成（generation）过程中，其庞大的 Key-Value（KV）缓存会成为限制推理效率的瓶颈。为减少 KV 缓存占用，研究者提出了&lt;strong&gt;多查询注意力&lt;/strong&gt;（MQA）（Shazeer, 2019）和&lt;strong&gt;分组查询注意力&lt;/strong&gt;（GQA）（Ainslie et al., 2023）。这两种方法虽然减少了 KV 缓存需求，但在性能上仍无法与 MHA 相媲美（关于 MHA、GQA 和 MQA 的消融实验见附录 D.1）。&lt;/p&gt;

&lt;p&gt;DeepSeek-V2 引入了一种全新的注意力机制&lt;strong&gt;多头潜变量注意力&lt;/strong&gt;（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MLA&lt;/code&gt;）。MLA 结合了&lt;strong&gt;低秩键值联合压缩&lt;/strong&gt;（low-rank key-value joint compression,），在推理时大幅降低 KV 缓存需求，同时在性能上超越 MHA。&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;MLA 本质上是通过低秩转换的思路减少 head 的维度，即换为一个压缩的 QKV，存储的KV 的维度显著减小，而不是 GQA 方法减少 kv heads 的数量。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;h4 id=&quot;211-standard-multi-head-attention&quot;&gt;2.1.1 Standard Multi-Head Attention&lt;/h4&gt;

&lt;p&gt;先回顾下标准的&lt;strong&gt;多头注意力（MHA）机制&lt;/strong&gt;。设 $d$ 为嵌入维度，$n_h$ 为注意力头数，$d_h$ 为单个注意力头的维度，$h_t \in R_d$ 表示第 $t$ 个 token 进入注意力层的输入向量。&lt;/p&gt;

&lt;p&gt;在 MHA 机制中，我们通过三个投影矩阵 $W_Q、W_K、W_V \in \mathbb{R}^{n_h d_h\times d}$ 分别计算得到查询向量、键向量和值向量（$q_t、k_t、v_t \in \mathbb{R}^{n_h d_h}$），QKV 向量的线性变换公式如下所示：&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;QKV 的线性变换的权重矩阵的第二个维度大小一定为嵌入维度 $d$。&lt;/p&gt;
&lt;/blockquote&gt;

\[\mathbf{q}_t = W^Q \mathbf{h}_t, \tag{1}\]

\[\mathbf{k}_t = W^K \mathbf{h}_t, \tag{2}\]

\[\mathbf{v}_t = W^V \mathbf{h}_t, \tag{3}\]

&lt;p&gt;然后，$q_t$, $k_t$, $v_t$ 将被切分为 $n_h$ 个头（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;heads&lt;/code&gt;），用于多头注意力计算：&lt;/p&gt;

\[[\mathbf{q}_{t,1}; \mathbf{q}_{t,2}; \dots; \mathbf{q}_{t,n_h}] = \mathbf{q}_t
\tag{4}\]

\[[\mathbf{k}_{t,1}; \mathbf{k}_{t,2}; \dots; \mathbf{k}_{t,n_h}] = \mathbf{k}_t
\tag{5}\]

\[[\mathbf{v}_{t,1}; \mathbf{v}_{t,2}; \dots; \mathbf{v}_{t,n_h}] = \mathbf{v}_t
\tag{6}\]

\[\mathbf{o}_{t,i} = \sum_{j=1}^{t} \text{Softmax}_j \left( \frac{\mathbf{q}_{t,i}^T \mathbf{k}_{j,i}}{\sqrt{d_h}} \right) \mathbf{v}_{j,i}
\tag{7}\]

\[\mathbf{u}_t = W^O [\mathbf{o}_{t,1}; \mathbf{o}_{t,2}; \dots; \mathbf{o}_{t,n_h}]
\tag{8}\]

&lt;p&gt;其中，$q_{t,i}$, $k_{t,i}$, $v_{t,i} \in \mathbb{R}^{d_h}$ 分别表示第 $i$ 个注意力头的查询（query）、键（key）和值（value）；$W_O \in \mathbb{R}^{d \times d_h n_h}$ 表示输出投影矩阵。在推理过程中，key 和 value 需要被缓存，以加速计算，避免重复计算。&lt;/p&gt;

&lt;p&gt;标准 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MHA&lt;/code&gt; 每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;token&lt;/code&gt; 的 kv 缓冲大小 = $2n_hd_h l$，单位为字节 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;byte&lt;/code&gt;；如果使用了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;GQA&lt;/code&gt; 优化技术，每个 token 的 kv 缓冲大小变为 $2n_{kv}d_h l = 2n_hd_h l/\text{groups}$ 个元素。下标 $t$ 表示第几个 token，下标 $[1, n_h]$ 表示注意力头数，$l$ 表示 decoder layers 数目。&lt;/p&gt;

&lt;p&gt;在模型部署时，这种庞大的 KV 缓存 成为了一个主要的瓶颈，限制了&lt;strong&gt;最大批量大小&lt;/strong&gt;（batch size）和&lt;strong&gt;序列长度&lt;/strong&gt;（sequence length）。&lt;/p&gt;

&lt;h4 id=&quot;212-low-rank-key-value-joint-compression&quot;&gt;2.1.2 Low-Rank Key-Value Joint Compression&lt;/h4&gt;

&lt;p&gt;MLA 的核心是对&lt;strong&gt;键&lt;/strong&gt;（keys）和&lt;strong&gt;值&lt;/strong&gt;（values）进行&lt;strong&gt;低秩联合压缩&lt;/strong&gt;（low-rank joint compression），以减少 KV 缓存（KV cache）的占用：&lt;/p&gt;

\[\mathbf{c}_t^{KV} = W^{DKV} \mathbf{h}_t,
\tag{9}\]

\[\mathbf{k}_t^{C} = W^{UK} \mathbf{c}_t^{KV},
\tag{10}\]

\[\mathbf{v}_t^{C} = W^{UV} \mathbf{c}_t^{KV},
\tag{11}\]

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;KV&lt;/code&gt; 向量的生成是先投影到一个&lt;strong&gt;低维&lt;/strong&gt;（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;5120 -&amp;gt; 512&lt;/code&gt;）的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compressed_kv&lt;/code&gt; 向量（$\mathbf{c}_t^{KV}$）再升维展开得到 $\mathbf{k}_t^{C}$ 和 $\mathbf{v}_t^{C}$。上述公式的各个变量定义：&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;$\mathbf{c}_t^{KV}$ 是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;keys&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;values&lt;/code&gt; 的&lt;strong&gt;压缩后的潜在向量&lt;/strong&gt;（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;latent vector&lt;/code&gt;）；&lt;/li&gt;
  &lt;li&gt;$d_c (\ll d_h n_h)$ 代表 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;KV&lt;/code&gt; 压缩维度（KV compression dimension）&lt;/li&gt;
  &lt;li&gt;$W^{DKV} \in \mathbb{R}^{d_c \times d}$ 是&lt;strong&gt;降维投影矩阵&lt;/strong&gt;（down-projection matrix）；&lt;/li&gt;
  &lt;li&gt;$W^{UK}, W^{UV} \in \mathbb{R}^{d_h n_h \times d_c}$ 分别是 keys 和 values 的&lt;strong&gt;升维投影矩阵&lt;/strong&gt;（up-projection matrices）。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;另外，虽然不能减少 KV Cache 的占用，但是为了&lt;strong&gt;减少训练时的激活内存&lt;/strong&gt;（activation memory），同样也对查询（queries）也进行了&lt;strong&gt;低秩压缩&lt;/strong&gt;（low-rank compression）。同样也是先投影到一个&lt;strong&gt;低维&lt;/strong&gt;（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;5120 -&amp;gt; 1536&lt;/code&gt;）的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compressed_kv&lt;/code&gt; 向量（$\mathbf{c}_t^{Q}$）再升维展开得到 $\mathbf{q}_t^{C}$:&lt;/p&gt;

\[\mathbf{c}_t^{Q} = W^{DQ} \mathbf{h}_t,
\tag{12}\]

\[\mathbf{q}_t^{C} = W^{UQ} \mathbf{c}_t^{Q},
\tag{13}\]

&lt;p&gt;类比前面的公式可知:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;$\mathbf{c}_t^{Q} \in \mathbb{R}^{d’_c}$ 是查询的压缩潜在向量（compressed latent vector for queries）；&lt;/li&gt;
  &lt;li&gt;$d’_c (\ll d_h n_h)$ 表示查询的压缩维度（query compression dimension）；&lt;/li&gt;
  &lt;li&gt;$W^{DQ} \in \mathbb{R}^{d’_c \times d}$ 是查询的降维投影矩阵；&lt;/li&gt;
  &lt;li&gt;$W^{UQ} \in \mathbb{R}^{d_h n_h \times d’_c}$ 是查询的升维投影矩阵（up-projection matrix）。&lt;/li&gt;
&lt;/ul&gt;

&lt;h4 id=&quot;213-decoupled-rotary-position-embedding&quot;&gt;2.1.3 Decoupled Rotary Position Embedding&lt;/h4&gt;

&lt;p&gt;和 DeepSeek 67B（DeepSeek-AI, 2024）类似，作者也计划在 DeepSeek-V2 中使用 旋转位置编码（RoPE, Rotary Position Embedding）（Su et al., 2024）。但是，RoPE 与&lt;strong&gt;低秩 KV 压缩（low-rank KV compression）并不兼容&lt;/strong&gt;。&lt;/p&gt;

&lt;p&gt;具体来说，RoPE 使键（Key）和查询（Query）都具备&lt;strong&gt;位置敏感性&lt;/strong&gt;（position sensitivity）。如果我们在&lt;strong&gt;压缩后的键&lt;/strong&gt; $\mathbf{k}_t^{C}$ 上应用 ROPE，那么实际上我们得到的键表示会是这样的形式：&lt;/p&gt;

\[k_t^R = \text{ROPE}(W^{UK} \mathbf{c}_t^{KV})\]

&lt;p&gt;很明显式（10）中的 $W^{UK}$ 和 RoPE 旋转矩阵在计算过程中“耦合”在一起—这意味着 $W^{UK}$ 输出的结果会始终被那个依赖于具体位置的旋转矩阵所“修正”或“调制”。&lt;/p&gt;

&lt;p&gt;这样会导致在执行 atten weight（$QK^T$）的计算优化中，无法像原本设想的那样，把 $W^{UK}$ 吸收到 $W^Q$  中，因为 当前生成 token 相关的 RoPE 矩阵位于 $W^Q$  和 $W^{UK}$ 之间，而矩阵乘法不满足交换律（commutative law）。这直接导致在推理过程中，我们必须&lt;strong&gt;重新计算&lt;/strong&gt;所有 prefix token 的键（keys），这将显著降低推理效率。&lt;/p&gt;

&lt;p&gt;为了解决这个问题，作者提出了一种&lt;strong&gt;解耦 RoPE&lt;/strong&gt;（decoupled RoPE）的策略，通过&lt;strong&gt;额外引入多头查询&lt;/strong&gt;（multi-head queries）$q_{t,i}^R \in \mathbb{R}^{d^R_h}$和&lt;strong&gt;采用一个共享键&lt;/strong&gt;（shared key) $k_t^R \in \mathbb{R}^{d^R_h}$ 来&lt;strong&gt;承载 RoPE 信息&lt;/strong&gt;。其中 $d^R_h$  代表&lt;strong&gt;解耦查询和键的每头维度&lt;/strong&gt;（per-head dimension of the decoupled queries and key）。&lt;/p&gt;

&lt;p&gt;在使用&lt;strong&gt;解耦 RoPE 策略&lt;/strong&gt;后，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MLA&lt;/code&gt; 的计算过程变成如下所示：&lt;/p&gt;

\[\left[ \mathbf{q}_{t,1}^{R}; \mathbf{q}_{t,2}^{R}; \dots; \mathbf{q}_{t,n_h}^{R} \right] = \mathbf{q}_t^{R} = \text{RoPE}(W^{QR} \mathbf{c}_t^{Q}),
\tag{14}\]

\[\mathbf{k}_t^{R} = \text{RoPE}(W^{KR} \mathbf{h}_t),
\tag{15}\]

\[\mathbf{q}_{t,i} = \left[ \mathbf{q}_{t,i}^{C}; \mathbf{q}_{t,i}^{R} \right],
\tag{16}\]

\[\mathbf{k}_{t,i} = \left[ \mathbf{k}_{t,i}^{C}; \mathbf{k}_{t,i}^{R} \right],
\tag{17}\]

\[\mathbf{o}_{t,i} = \sum_{j=1}^{t} \text{Softmax}_j \left( \frac{\mathbf{q}_{t,i}^{T} \mathbf{k}_{j,i}}{\sqrt{d_h + d_h^{R}}} \right) \mathbf{v}_{j,i}^{C},
\tag{18}\]

\[\mathbf{u}_t = W^{O} \left[ \mathbf{o}_{t,1}; \mathbf{o}_{t,2}; \dots; \mathbf{o}_{t,n_h} \right],
\tag{19}\]

&lt;p&gt;其中:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;$W^{QR} \in \mathbb{R}^{d^R_h n_h \times d’_c}$ 表示生成&lt;strong&gt;解耦查询（decoupled queries）矩阵&lt;/strong&gt;&lt;/li&gt;
  &lt;li&gt;$W^{KR} \in \mathbb{R}^{d^R_h \times d}$ 表示&lt;strong&gt;解耦键（decoupled key）的矩阵&lt;/strong&gt;。&lt;/li&gt;
  &lt;li&gt;$\text{RoPE}(\cdot)$ 表示应用 RoPE 矩阵的操作;&lt;/li&gt;
  &lt;li&gt;$\cdot ; \cdot$ 表示拼接（concatenation）操作。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;在推理过程中，解耦后的键（decoupled key）也需要缓存。因此，DeepSeek-V2 的 KV 缓存总大小为 $(d_c + d^R_h)l$ 个元素。&lt;/p&gt;

&lt;p&gt;很明显和前面公式相比，多了 $\mathbf{q}_t^{R}$ 和 $\mathbf{k}_t^{R}$ 两个变量的计算过程，它们用于单独承载 ROPE 信息，并和前面计算得到的 $\mathbf{q}_t^{C}$ 和 $\mathbf{k}_t^{C}$ 做拼接后得到新的 $q、k$，再执行 atten weight 计算（$qk^t$）。&lt;/p&gt;

&lt;p&gt;最后，总结下完成的 MLA 计算过程如下所示：&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/deepseekv2/Formulas_MLA.png&quot; width=&quot;60%&quot; alt=&quot;Formulas_MLA&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MLA&lt;/code&gt; 结构的可视化图如下所示：&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/deepseekv2/architecture_of_MLA.png&quot; width=&quot;80%&quot; alt=&quot;architecture_of_MLA&quot; /&gt;
&lt;/div&gt;

&lt;h4 id=&quot;214-kv-cache-大小的比较&quot;&gt;2.1.4 kv cache 大小的比较&lt;/h4&gt;

&lt;p&gt;多头注意力（MHA）、分组查询注意力（GQA）、多查询注意力（MQA）和多头潜在注意力（MLA）的简化示意图对比如下图 3 所示。通过&lt;strong&gt;将键（Key）和值（Value）压缩到一个潜在向量中，MLA 在推理时大幅减少了对 KV 缓存的需求&lt;/strong&gt;。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/deepseekv2/figure3.png&quot; width=&quot;80%&quot; alt=&quot;figure3&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;下表 1 中对比了不同注意力机制下，每个 token 需要的 KV 缓存大小。MLA 仅需要&lt;strong&gt;少量的 KV 缓存&lt;/strong&gt;，其大小相当于仅有 $2.25$ 组（groups）的 GQA，但其性能却强于 MHA。&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/deepseekv2/table1.png&quot; alt=&quot;table1&quot; /&gt;&lt;/p&gt;

&lt;p&gt;表 1｜不同注意力机制下，每个 token 需要的 KV 缓存对比。其中，&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;$n_h$ 表示&lt;strong&gt;注意力头的数量&lt;/strong&gt;，&lt;/li&gt;
  &lt;li&gt;$d_h$ 表示每个注意力头的维度，&lt;/li&gt;
  &lt;li&gt;$l$ 表示模型层数，&lt;/li&gt;
  &lt;li&gt;$n_g$ 表示 GQA 的组数，&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;$d_c$ 和 $d^R_h$ 分别表示 MLA 中的 KV 压缩维度和解耦查询与键的 per-head 维度&lt;/strong&gt;。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;KV 缓存的数量以&lt;strong&gt;元素(elements)个数&lt;/strong&gt;计算，不考虑存储精度（storage precision）。对于 DeepSeek-V2，&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;$d_c$ 设定为 $4d_h$。&lt;/li&gt;
  &lt;li&gt;$d^R_h$ 设定为 $\frac{d_h}{2}$。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;因此，&lt;strong&gt;DeepSeek-V2 只需要 相当于 GQA $2.25$ 组的 KV 缓存，但相比 MHA 仍能提供更强的性能&lt;/strong&gt;。&lt;/p&gt;

&lt;h4 id=&quot;215-总结&quot;&gt;2.1.5 总结&lt;/h4&gt;

&lt;p&gt;MLA 虽然增大了计算量，但 KV Cache 的减少也降低了显存和带宽的压力，且 llm 推理的 decode 阶段是受限于带宽瓶颈和显存瓶颈，因此 MLA 的引入理论上能明显提高 Generation 的速度。&lt;/p&gt;

&lt;h3 id=&quot;22-deepseekmoe以经济成本训练强大的模型&quot;&gt;2.2 DeepSeekMoE：以经济成本训练强大的模型&lt;/h3&gt;

&lt;p&gt;对于 FFN（前馈网络），我们采用 DeepSeekMoE 架构（Dai et al., 2024）。DeepSeekMoE 主要包含两个关键思想：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;更精细地划分专家网络&lt;/strong&gt;，提升每个专家的专业性，提高知识表达的准确度。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;引入部分共享专家&lt;/strong&gt;，减少不同专家间的知识冗余，从而提升计算效率。&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;相比传统的 MoE 架构（如 GShard，Lepikhin et al., 2021），&lt;strong&gt;DeepSeekMoE 在相同的专家参数量和激活参数量下，能显著提升模型性能&lt;/strong&gt;。&lt;/p&gt;

&lt;p&gt;设 $u_t$ 为第 $t$ 个 token 的 FFN 输入，其 FFN 输出 $h’_t$ 计算如下：&lt;/p&gt;

\[\mathbf{h}_t&apos; = \mathbf{u}_t + \sum_{i=1}^{N_s} \text{FFN}_i^{(s)} (\mathbf{u}_t) + \sum_{i=1}^{N_r} g_{ij,t} \text{FFN}_i^{(r)} (\mathbf{u}_t), \tag{20}\]

\[g_{ij,t} =
\begin{cases}
s_{ij,t}, &amp;amp; s_{ij,t} \in \text{Topk}(\{s_{ij,t}| 1 \leq j \leq N_r\}, K_r), \\
0, &amp;amp; \text{otherwise},
\end{cases} \tag{21}\]

\[s_{ij,t} = \text{Softmax}_i (\mathbf{u}_t^T e_i), \tag{22}\]

&lt;p&gt;上述公式中：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;$N_s$ 和 $N_r$ 分别表示&lt;strong&gt;共享专家&lt;/strong&gt;和&lt;strong&gt;路由专家&lt;/strong&gt;的数量；&lt;/li&gt;
  &lt;li&gt;$FFN(s) i(·)$ 和 $FFN(r) 𝑖(·)$ 分别表示第 $i$ 个共享专家和第 $i$ 个路由专家的计算过程；&lt;/li&gt;
  &lt;li&gt;$K_r$ 表示激活的路由专家数量；&lt;/li&gt;
  &lt;li&gt;$g_{i, t}$ 是第 $i$ 个专家的门控值，用来决定该专家是否激活；&lt;/li&gt;
  &lt;li&gt;$s_{i, t}$ 是 token 到专家的亲和度值，表示 token 和专家之间的相关性；&lt;/li&gt;
  &lt;li&gt;$e_i$ 是该层第 $i$ 个路由专家的质心，用于表示专家的聚合特征；&lt;/li&gt;
  &lt;li&gt;$\text{Topk}(·, K)$ 表示从第 $t$ 个 token 计算的所有路由专家的亲和度分数中，选择出 $K$ 个最高的值，并将这些分数组成一个集合。&lt;/li&gt;
&lt;/ul&gt;

&lt;h4 id=&quot;222-设备受限路由device-limited-routing&quot;&gt;2.2.2. 设备受限路由（Device-Limited Routing）&lt;/h4&gt;

&lt;p&gt;作者设计了一种设备受限的路由机制，用于限制 MoE 相关的通信成本。&lt;/p&gt;

&lt;p&gt;当采用&lt;strong&gt;专家并行&lt;/strong&gt;（expert parallelism）时，路由专家（routed experts）会分布在多个设备上。对于每个 token，它的 MoE &lt;strong&gt;相关通信频率&lt;/strong&gt;与其目标专家所涉及的设备数量成正比。由于 DeepSeekMoE 采用了精细的专家划分策略，激活的专家数量可能较多，如果直接应用专家并行，会导致更高的 MoE 相关通信成本。&lt;/p&gt;

&lt;p&gt;在 DeepSeek-V2 中，除了直接选择&lt;strong&gt;得分最高&lt;/strong&gt;的 $K$ 个路由专家（top-K selection） 之外，我们还确保每个 token 的目标专家最多分布在 $M$ 台设备上。具体而言，对于每个 token，我们首先选择&lt;strong&gt;拥有最高亲和度分数的&lt;/strong&gt; $M$ 台设备，然后在这 $M$ 台设备上的专家中执行 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;top-K&lt;/code&gt; 选择。实践中，我们发现当 $M \geq 3$ 时，设备受限路由的效果可以大致匹配不受限的 top-K 路由。&lt;/p&gt;

&lt;h4 id=&quot;223-负载均衡的辅助损失auxiliary-loss-for-load-balance&quot;&gt;2.2.3. 负载均衡的辅助损失（Auxiliary Loss for Load Balance）&lt;/h4&gt;

&lt;p&gt;作者在自动学习的路由策略中引入了负载均衡机制。&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;负载不均衡会导致路由塌陷（routing collapse）（Shazeer et al., 2017），即部分专家可能无法得到充分训练和利用。&lt;/li&gt;
  &lt;li&gt;在专家并行（expert parallelism）机制下，负载不均衡会降低计算效率。&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;在 DeepSeek-V2 训练过程中，我们设计了三种&lt;strong&gt;辅助损失&lt;/strong&gt;（auxiliary losses），分别用于控制：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;专家级负载均衡（ $L_{\text{ExpBal}}$ ），&lt;/li&gt;
  &lt;li&gt;设备级负载均衡（ $L_{\text{DevBal}}$ ），&lt;/li&gt;
  &lt;li&gt;通信均衡（ $L_{\text{CommBal}}$ ）。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;这三种损失函数协同作用，确保 DeepSeek-V2 在计算资源受限的条件下，仍能高效训练高性能 MoE 模型。&lt;/p&gt;

&lt;h4 id=&quot;224-token-dropping-策略&quot;&gt;2.2.4. Token-Dropping 策略&lt;/h4&gt;

&lt;p&gt;虽然&lt;strong&gt;平衡损&lt;/strong&gt;失有助于实现负载平衡，但它无法完全保证负载的严格平衡。为了解决负载不平衡带来的计算浪费，作者在训练中引入了&lt;strong&gt;设备级的 Token-Dropping 策略&lt;/strong&gt;。该策略首先计算每个设备的平均计算预算，将每个设备的容量因子设为 1.0。然后，借鉴 Riquelme 等人（2021）的思路，我们会丢弃每个设备上亲和度最低的 token，直到达成计算预算。此外，我们还确保约 10% 的训练序列中的 token 不会被丢弃。这样，在推理过程中，我们可以灵活地根据效率需求选择是否丢弃 token，同时保持训练和推理的一致性。&lt;/p&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://arxiv.org/pdf/2405.04434&quot;&gt;DeepSeek-V2 论文&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://github.com/madsys-dev/deepseekv2-profile/blob/main/workspace/blog/optimizing-mla.md&quot;&gt;DeepSeek-V2高性能推理优化笔记：MLA优化&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://mp.weixin.qq.com/s/E7NwwMYw14FRT6OKzuVXFA&quot;&gt;再读MLA，还有多少细节是你不知道的&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://bruceyuan.com/llms-zero-to-hero/the-way-of-moe-model-evolution.html#_2-%E7%89%88%E6%9C%AC2-sparsemoe-%E5%A4%A7%E6%A8%A1%E5%9E%8B%E8%AE%AD%E7%BB%83%E4%BD%BF%E7%94%A8&quot;&gt;LLMs-Zero-to-Hero&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2025-02-07</pubDate>
      <link>https://c1lovez1.github.io/2025-02-07/deepseekv2-paper.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2025-02-07/deepseekv2-paper.html</guid>
    </item>
    
		
    
    <item>
      <title>DeepSeekV3 简单概述</title>
      <description>&lt;h2 id=&quot;1-介绍&quot;&gt;1. 介绍&lt;/h2&gt;

&lt;p&gt;DeepSeek-V3 是一款强大的 Mixture-of-Experts (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MoE&lt;/code&gt;) 语言模型，&lt;strong&gt;总参数量为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;671B&lt;/code&gt;，每个 token 激活的参数为 37B&lt;/strong&gt;。为了实现高效推理和成本效益的训练，DeepSeek-V3 采用了在 DeepSeek-V2 中经过充分验证的 Multi-head Latent Attention (MLA) 和 DeepSeekMoE 架构。此外，&lt;strong&gt;DeepSeek-V3 首创了无辅助损失的负载平衡策略，并引入了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Multi-Token Prediction&lt;/code&gt; 技术来实现多 token 预测的训练目标，用以提高性能&lt;/strong&gt;。我们在 14.8 万亿多样且高质量的 token 上进行了 DeepSeek-V3 的预训练，随后通过监督微调和强化学习阶段，充分发挥其能力。全面评估结果表明，DeepSeek-V3 在性能上超越了其他开源模型，且与领先的闭源模型的性能相当。尽管表现出色，DeepSeek-V3 仅需 2.788M H800 GPU 小时进行全程训练。此外，其训练过程非常稳定。在整个训练过程中，我们没有遇到任何不可恢复的损失峰值，也没有进行任何回滚。&lt;/p&gt;

&lt;h2 id=&quot;2-模型总结&quot;&gt;2. 模型总结&lt;/h2&gt;

&lt;p&gt;&lt;strong&gt;架构：创新的负载平衡策略和训练目标&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;在 DeepSeek-V2 高效架构的基础上，我们&lt;strong&gt;首创了一种无辅助损失的负载平衡策略&lt;/strong&gt;，最大限度减少了因鼓励负载平衡而带来的性能下降。我们研究了一种&lt;strong&gt;多 token 预测（MTP）目标&lt;/strong&gt;，并证明它对模型性能有益。它还可以用于推测性解码，从而加速推理过程。&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;预训练：迈向极致的训练效率&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;我们&lt;strong&gt;设计了一个 FP8 混合精度训练框架&lt;/strong&gt;，并首次验证了在超大规模模型上进行 FP8 训练的可行性和有效性。通过&lt;strong&gt;算法、框架和硬件的联合设计&lt;/strong&gt;，我们克服了跨节点 MoE 训练中的通信瓶颈，几乎实现了计算与通信的完全重叠。这显著提高了我们的训练效率并降低了训练成本，使得我们能够在没有额外开销的情况下进一步扩大模型规模。以仅 2.664M H800 GPU 小时的经济成本，我们完成了 DeepSeek-V3 在 14.8T tokens 上的预训练，产生了目前最强的开源基础模型。预训练后的后续训练阶段仅需要 0.1M GPU 小时。&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;后训练：来自 DeepSeek-R1 的知识蒸馏&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;我们引入了一种创新的方法，将推理能力从长链推理（CoT）模型，特别是来自 DeepSeek R1 系列模型的能力，蒸馏到标准 LLM 中，尤其是 DeepSeek-V3。我们的管道巧妙地将 R1 的验证和反思模式融入 DeepSeek-V3，显著提高了其推理能力。同时，我们还保持对 DeepSeek-V3 输出风格和长度的控制。&lt;/p&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://github.com/deepseek-ai/DeepSeek-V3&quot;&gt;github-DeepSeek-V3&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2025-02-06</pubDate>
      <link>https://c1lovez1.github.io/2025-02-06/deepseekv3-paper.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2025-02-06/deepseekv3-paper.html</guid>
    </item>
    
		
    
    <item>
      <title>triton 内核编译流程解析</title>
      <description>&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#一-triton-概述&quot;&gt;一 Triton 概述&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#二-triton-编译jit入口&quot;&gt;二 Triton 编译（JIT）入口&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#21-jit-函数&quot;&gt;2.1 jit() 函数&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#22-jitfunction-类&quot;&gt;2.2 JITFunction 类&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#23-astsource-类&quot;&gt;2.3 ASTSource 类&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#三-内核编译函数-compile&quot;&gt;三 内核编译函数 compile&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#ast---tritonir&quot;&gt;AST -&amp;gt; TritonIR&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#四-内核-compile-流程解析&quot;&gt;四 内核 compile 流程解析&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#backend-系统&quot;&gt;Backend 系统&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#41-make_ttir&quot;&gt;4.1 make_ttir&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#42-make_ttgir&quot;&gt;4.2 make_ttgir&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#43-make_llir&quot;&gt;4.3 make_llir&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#44-make_ptx&quot;&gt;4.4 make_ptx&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#45-make_cubin&quot;&gt;4.5 make_cubin&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#参考资料&quot;&gt;参考资料&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;一-triton-概述&quot;&gt;一 Triton 概述&lt;/h2&gt;

&lt;p&gt;Triton 既是语言，可用于编写高性能计算 GPU kernel，使用 Triton 编写核函数必须按照其提供的 DSL 进行编写，否则运行失败。也是编译器（根据配置自动生成 kernel），采用 just-in-time 机制进行编译。如果你编写了一个用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;triton.jit&lt;/code&gt; 修饰的核函数，那么在运行时其会被编译成底层 GPU 二进制（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PTX / SASS&lt;/code&gt;），并马上执行。&lt;/p&gt;

&lt;p&gt;Triton编译器的架构，包括&lt;strong&gt;前端、优化器以及后端等部件&lt;/strong&gt;。&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;前端&lt;/strong&gt;（Frontend）： 用于将用户使用 Python 编写的 kernel 或者 Pytorch 2.0 中通过 Inductor 生成的 TritonKernel 转换为对应的 Triton IR。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;优化器&lt;/strong&gt;（Optimizer）：通过各类 pass 将 Triton IR 逐步转换并优化为 TritonGPU IR。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;后端&lt;/strong&gt;（Backend）：将 TritonGPU IR 逐步转换为 LLVM IR，对于 NVIDIA GPU 最终会被编译为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cubin&lt;/code&gt; 文件。&lt;/li&gt;
&lt;/ol&gt;

&lt;h2 id=&quot;二-triton-编译jit入口&quot;&gt;二 Triton 编译（JIT）入口&lt;/h2&gt;

&lt;p&gt;Triton kernel 入口：python 代码中使用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;triton.jit&lt;/code&gt; 装饰器修饰的函数，被视为“Triton Kernel”，函数内部可使用 Triton 的 DSL（域专用语言）来编写 GEMM、reduce 算子等。&lt;/p&gt;

&lt;p&gt;下面以官方给出的经典张量 Add 的 kernel 来简单剖析 triton 编译流程：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;triton&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;add_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# *Pointer* to first input vector.
&lt;/span&gt;               &lt;span class=&quot;n&quot;&gt;y_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# *Pointer* to second input vector.
&lt;/span&gt;               &lt;span class=&quot;n&quot;&gt;output_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# *Pointer* to output vector.
&lt;/span&gt;               &lt;span class=&quot;n&quot;&gt;n_elements&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Size of the vector.
&lt;/span&gt;               &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constexpr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Number of elements each program should process.
&lt;/span&gt;               &lt;span class=&quot;c1&quot;&gt;# NOTE: `constexpr` so it can be used as a shape value.
&lt;/span&gt;               &lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pid&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;program_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;axis&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# We use a 1D launch grid so axis is 0.
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;block_start&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pid&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_start&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Create a mask to guard memory operations against out-of-bounds accesses.
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_elements&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Load x and y from DRAM, masking out any extra elements in case the input is not a multiple of the block size.
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;load&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_ptr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;load&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_ptr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Write x + y back to DRAM.
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;store&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output_ptr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offsets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; 

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;empty_like&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# preallocate the output.
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DEVICE&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DEVICE&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DEVICE&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;n_elements&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;numel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# In this case, we use a 1D grid where the size is the number of blocks:
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;grid&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;meta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;triton&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cdiv&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_elements&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;meta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;BLOCK_SIZE&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# NOTE:
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#  - Each torch.tensor object is implicitly converted into a pointer to its first element.
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#  - `triton.jit`&apos;ed functions can be indexed with a launch grid to obtain a callable GPU kernel.
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#  - Don&apos;t forget to pass meta-parameters as keywords arguments.
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;add_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;](&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_elements&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; 
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;21-jit-函数&quot;&gt;2.1 jit() 函数&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;python/triton/runtime/jit.py&lt;/code&gt;: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;jit()&lt;/code&gt; 装饰器函数将原始函数 fn 封装为 JIT 函数: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;JITFunction&lt;/code&gt;，&lt;strong&gt;JITFunction 被调用的时候，是以 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;fn[grid](*args, **kwargs)&lt;/code&gt; 的形式被调用的&lt;/strong&gt;。jit() 函数执行流程：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;decorator 装饰器定义：检查环境变量 TRITON_INTERPRET：
    &lt;ul&gt;
      &lt;li&gt;如果设置为 “1”，返回 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;InterpretedFunction&lt;/code&gt; （用于调试）&lt;/li&gt;
      &lt;li&gt;否则返回 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;JITFunction&lt;/code&gt; （正常编译模式）&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;装饰器将原始函数封装为 JIT 函数并返回。&lt;/li&gt;
&lt;/ol&gt;

&lt;h3 id=&quot;22-jitfunction-类&quot;&gt;2.2 JITFunction 类&lt;/h3&gt;

&lt;p&gt;JITFunction 类继承自 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;KernelInterface&lt;/code&gt;，KernelInterface 类定义了内核函数的&lt;strong&gt;基本接口和启动机制&lt;/strong&gt;，其代码实现如下所示，很明显，通过 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;__getitem__&lt;/code&gt; 方法允许使用方括号语法 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[grid]&lt;/code&gt; 来设置内核的执行网格，这里的网格参数定义了 &lt;strong&gt;kernel 的并行执行结构&lt;/strong&gt;。&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;__getitem_&lt;/code&gt; 是 Python 的&lt;strong&gt;魔法方法&lt;/strong&gt;，通常用来定义 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;obj[key]&lt;/code&gt; 的行为。KernelInterface 类继承自 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Generic&lt;/code&gt;：泛型类型的抽象基类。Generic 作为所有泛型类的基类，提供类型参数化的基础设施。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;KernelInterface&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Generic&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__getitem__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
        A JIT function is launched with: fn[grid](*args, **kwargs).
        Hence JITFunction.__getitem__ returns a callable proxy that
        memorizes the grid.
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;args&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grid&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;warmup&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;args&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# return cast(T, functools.partial(cast(Callable, self.run), grid=grid))
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;__getitem__&lt;/code&gt; 会返回一个匿名函数（lambda），其调用了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;self.run&lt;/code&gt; 函数，并记住了 grid 参数。self.run() 函数在子类 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;JITFunction&lt;/code&gt; 中实现，函数执行流程包括:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;获取/设置编译、执行环境&lt;/strong&gt;: 通过 driver.active.get_current_device() 与 driver.active.get_current_stream(device) 来获取当前设备和流（stream）。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;调用预处理钩子（pre_run_hooks）&lt;/strong&gt;: 在运行或编译前，可以执行一些用户自定义的操作（例如日志、数据检查、统计）。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;绑定参数、生成内核缓存键&lt;/strong&gt;: 与 binder 交互，将传入的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;*args, **kwargs&lt;/code&gt; 绑定成特定的参数结构并生成 “specialization” 信息。然后使用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;specialization + options&lt;/code&gt; 组合成一个字符串 str 类型的&lt;strong&gt;缓存键 key&lt;/strong&gt;。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;通过 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kernel = kernel_cache.get(key, None)&lt;/code&gt; 查找或编译内核&lt;/strong&gt;: 如果缓存中已有对应内核，则复用；否则就走编译流程并把新编译的内核写回缓存。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;执行内核（可选）&lt;/strong&gt;: 如果 warmup=False，则真正将内核提交到 GPU 执行，传入相应的网格大小（grid）及参数；否则仅完成编译预热（不执行），返回编译后的 kernel。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;返回编译（或缓存）得到的内核对象&lt;/strong&gt;。&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;这里重点分析下 triton 编译内核的过程代码:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# 当内核未缓存时执行编译流程
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# -------------------- 选项处理 --------------------
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 从kwargs解析后端编译选项（如GPU架构参数等）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;backend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parse_options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;# -------------------- 签名构建 --------------------
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 从类参数中获取参数名称列表
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;sigkeys&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 从特化参数中获取参数值列表（第一个元素为参数值）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;sigvals&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;specialization&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 构建参数名-值映射的签名字典
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigkeys&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigvals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)}&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;# -------------------- 参数校验 --------------------
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 检查是否使用了已弃用的设备相关参数
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;device_type&quot;&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;device_type选项已弃用，将使用当前目标设备&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;device&quot;&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;device选项已弃用，将使用当前设备&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;stream&quot;&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;stream选项已弃用，将使用当前流&quot;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 验证所有传入参数是否合法
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__dict__&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sigkeys&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;raise&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;KeyError&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;检测到未识别的关键字参数: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# -------------------- 常量表达式处理 --------------------
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 查找标记为constexpr的常量参数路径
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;constexprs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;find_paths_if&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sigvals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;val&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;val&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;constexpr&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 获取这些常量参数的实际值（用于编译时优化）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;constexprs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;get_iterable_path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bound_args&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;values&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;path&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;constexprs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# -------------------- 属性处理 --------------------
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 从特化参数中获取属性值（第二个元素为属性）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;attrvals&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;specialization&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 查找字符串类型的属性路径（如GPU内存配置参数）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;find_paths_if&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attrvals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;isinstance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 解析属性配置（转换为后端需要的格式）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;backend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parse_attr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_iterable_path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attrvals&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# -------------------- 编译前钩子 --------------------
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 执行预编译钩子函数（如缓存检查、日志记录等）
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_call_hook&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;constexprs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;warmup&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;before&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 如果钩子返回True则中止编译
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# -------------------- 编译流程 --------------------
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 生成抽象语法树（AST）形式的源代码
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ASTSource&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;constexprs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 调用后端编译器生成可执行内核
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__dict__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 将编译结果缓存
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;kernel_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;# -------------------- 编译后钩子 --------------------
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 执行后编译钩子函数（如性能分析、资源注册等）
&lt;/span&gt;    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_call_hook&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;constexprs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;warmup&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;before&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;上述代码看着很复杂，但大部分都是预处理流程代码，真正的内核编译代码就三行：通过抽象语法树 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;AST&lt;/code&gt; &lt;strong&gt;生成中间代码&lt;/strong&gt;，调用编译器函数 self.compile，并缓存结果。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ASTSource&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;constexprs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# print(src.code) 可直接打印生成的 GPU 代码
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__dict__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;kernel_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;23-astsource-类&quot;&gt;2.3 ASTSource 类&lt;/h3&gt;

&lt;p&gt;triton/compiler/compiler.py: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ASTSource&lt;/code&gt; 是 Triton 编译器中的一个关键类，负责&lt;strong&gt;管理和处理源代码的抽象语法树&lt;/strong&gt;(AST)。&lt;/p&gt;

&lt;p&gt;JITFunction.run 函数调用 ASTSource 类的代码是：&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;src = self.ASTSource(self, signature, constexprs, attrs)&lt;/code&gt;。&lt;/p&gt;

&lt;p&gt;这段代码的作用是将内核的参数、常量表达式和属性配置转换为&lt;strong&gt;中间表示（IR）的抽象语法树（AST）&lt;/strong&gt;，&lt;strong&gt;这是实现跨平台代码生成和即时编译（JIT）的关键步骤&lt;/strong&gt;，其步骤包括：语义解析、类型推导、常量替换、属性注入、多后端适配（CUDA/ROCm/Metal）。这里举个例子理解其作用：&lt;/p&gt;

&lt;p&gt;假设用户编写如下 Triton 核函数：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;triton&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;jit&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;matmul_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;a_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;M&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constexpr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constexpr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;K&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constexpr&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# ... 矩阵乘法逻辑
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;上述内核函数对应的输入参数结构：&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;signature（函数签名）&lt;/strong&gt;：描述参数名称和类型。&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;a_ptr&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;pointer&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;b_ptr&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;pointer&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;c_ptr&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;pointer&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;M&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;int32&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;N&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;int32&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;K&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;int32&quot;&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;ol&gt;
  &lt;li&gt;constexprs（编译时常量）: 标记需要在编译时确定的参数值。&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;M&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;N&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;512&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;K&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2048&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 矩阵维度
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;
&lt;ol&gt;
  &lt;li&gt;attrs（硬件属性）: 配置 GPU 核函数的执行参数。&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;num_warps&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;     &lt;span class=&quot;c1&quot;&gt;# 每个线程块的 warp 数量
&lt;/span&gt;    &lt;span class=&quot;s&quot;&gt;&quot;num_stages&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 流水线阶段数（用于异步拷贝优化）
&lt;/span&gt;    &lt;span class=&quot;s&quot;&gt;&quot;static_shared&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4096&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 预分配的静态共享内存大小
&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ASTSource.make_ir&lt;/code&gt; 方法会调用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ast_to_ttir&lt;/code&gt; 函数将 Python AST 转换为&lt;strong&gt;目标平台中间表示（如 MLIR 中间表示）&lt;/strong&gt;，如下所示：&lt;/p&gt;

&lt;div class=&quot;language-cpp highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;// Triton 生成的 MLIR 中间表示（简化版）&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;module&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attributes&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
  &lt;span class=&quot;s&quot;&gt;&quot;triton_gpu.num-warps&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
  &lt;span class=&quot;s&quot;&gt;&quot;triton_gpu.threads-per-warp&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
  &lt;span class=&quot;s&quot;&gt;&quot;triton_gpu.shared&quot;&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4096&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;func&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;func&lt;/span&gt; &lt;span class=&quot;err&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
      &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a_ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f32&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
      &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b_ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f32&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
      &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c_ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f32&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt;
  &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 常量替换后的参数（M=1024, N=512, K=2048）&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;M&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;arith&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constant&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i32&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;arith&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constant&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;512&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i32&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;arith&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constant&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2048&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i32&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;// 计算逻辑（伪代码）&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;row&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_program_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i32&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;col&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_program_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i32&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;load&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;row&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;col&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f32&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;load&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;col&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;row&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f32&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f32&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;store&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c_ptr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;row&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;col&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f32&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt;
  &lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;三-内核编译函数-compile&quot;&gt;三 内核编译函数 compile&lt;/h2&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;triton/compiler/compiler.py&lt;/code&gt;: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compile()&lt;/code&gt; 函数最终返回 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CompiledKernel&lt;/code&gt; 对象，输入参数 src 为 ASTSource 对象，ASTSource 主要作用是将 JITFunction 从 ast tree 转换为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Triton IR&lt;/code&gt;。&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compile()&lt;/code&gt; 函数其简化版代码如下所示:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;target&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;driver&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;active&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_current_target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;isinstance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;GPUTarget&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;target must be of GPUTarget type&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;backend&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;make_backend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ir_source&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;isinstance&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ASTSource&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;#################省略部分代码#################
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# try:
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 1. AST -&amp;gt; TritonIR
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;module&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_ir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;codegen_fns&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;module_map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# except Exception as e:
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#     filter_traceback(e)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#     raise
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;use_ir_loc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;USE_IR_LOC&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 2. 遍历执行 backend 中定义的 stages 的 pass
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ext&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;compile_ir&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;items&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;first_stage&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:]:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;next_module&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;compile_ir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;ir_filename&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;file_name&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ext&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn_override_manager&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;full_name&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fn_override_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_file&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ir_filename&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;Overriding kernel with file &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;full_name&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;next_module&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;parse&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;full_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ext&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;metadata_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ir_filename&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fn_cache_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;put&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;next_module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ir_filename&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fn_dump_manager&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;fn_dump_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;put&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;next_module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ir_filename&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# use an env variable to parse ir from file
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;use_ir_loc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ext&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;ir_full_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fn_cache_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_file&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ir_filename&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;next_module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;create_location_snapshot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ir_full_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Creating new locations for &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ir_full_name&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;module&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;next_module&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# write-back metadata
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;metadata_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;metadata_filename&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fn_cache_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;put&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;json&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dumps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;default&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;vars&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata_filename&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                                             &lt;span class=&quot;n&quot;&gt;binary&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;fn_cache_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;put_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;metadata_filename&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Compilation completed, disabling multithreading in context.
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# This is needed to safely finalize threads pool inside context: if current process forks before
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# python GC deletes context object, thread pool in child process will be invalid, which could
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# lead to child crash or hang.
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# However disabling multithreading causes the code to hang if the ASAN pass is enabled
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# this is likely due to the llvm-symbolizer forking a process
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# TODO: Reconcile the difference here between the ASAN and non-ASAN path with enabling
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# multithreading in the MLIR context
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;TRITON_ENABLE_ASAN&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;0&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;1&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;disable_multithreading&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# return handle to compiled kernel
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CompiledKernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;hash&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;complie 函数的核心流程如下，最终的返回值就是编译完成后的内核，其内部通常包含 GPU 二进制（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PTX/Cubin&lt;/code&gt;）或相应后端需要的可执行形式，也可能包括若干中间表征和完整的元数据。&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;AST -&amp;gt; TritonIR&lt;/li&gt;
  &lt;li&gt;遍历执行 backend 中定义的 stages 的 pass&lt;/li&gt;
&lt;/ol&gt;

&lt;h3 id=&quot;ast---tritonir&quot;&gt;AST -&amp;gt; TritonIR&lt;/h3&gt;

&lt;p&gt;make_ir 函数实现如下所示：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;ASTSource&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;constexprs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;############## 省略代码 ##############
&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_ir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;codegen_fns&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;module_map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ast_to_ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;codegen_fns&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;codegen_fns&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                            &lt;span class=&quot;n&quot;&gt;module_map&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;module_map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;make_ir&lt;/code&gt; 函数的第一个参数 self 实际是 JITFunction 对象，&lt;strong&gt;其作用是将 JITFunction 从 ast tree 转换为 Triton IR&lt;/strong&gt;，通过调用 ast_to_ttir 实现。&lt;/p&gt;

&lt;p&gt;ast_to_ttir 函数定义如下：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;ast_to_ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;codegen_fns&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;module_map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;arg_types&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;str_to_ty&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;values&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;prototype&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ASTFunction&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;arg_types&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constants&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attrs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;file_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;begin_line&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;get_jit_fn_file_line&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# query function representation
&lt;/span&gt;    &lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;collections&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;namedtuple&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;leaves&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;filter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constants&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;constants&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arg_names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constants&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;leaves&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;proxy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;namedtuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;SpecializationProxy&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;constants&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;signature&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constants&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;signature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;generator&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CodeGenerator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;prototype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gscope&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__globals__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;copy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;function_name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;repr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;proxy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;jit_fn&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                              &lt;span class=&quot;n&quot;&gt;is_kernel&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;file_name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;file_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;begin_line&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;begin_line&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                              &lt;span class=&quot;n&quot;&gt;codegen_fns&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;codegen_fns&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;module_map&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;module_map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;generator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;visit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parse&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;generator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;module&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# module takes ownership of the context
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;函数实现最关键的部分是初始化 CodeGenerator 对象，和使用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;JITFunction.parse()&lt;/code&gt; 方法会将 kernel 代码转换为 ast tree，返回ast tree 的 root 节点。&lt;/p&gt;

&lt;h2 id=&quot;四-内核-compile-流程解析&quot;&gt;四 内核 compile 流程解析&lt;/h2&gt;

&lt;h3 id=&quot;backend-系统&quot;&gt;Backend 系统&lt;/h3&gt;

&lt;p&gt;Triton 在高层提供了一个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compile&lt;/code&gt;(src, target, options) 这样的统一编译函数(在 python/triton/compiler/compiler.py 文件中)，但具体如何把 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ASTSource&lt;/code&gt; 转成可执行文件（PTX / Cubin / GCN / etc.）则由各个后端模块在 backends 目录下完成，如英伟达 GPU 的后端实现目录 :vthird_party/nvidia/backend。&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;python/triton/backends&lt;/code&gt; 目录下的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;__init__.py&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compiler.py&lt;/code&gt;（BaseBackend 类）、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;driver.py&lt;/code&gt;（DriverBase 类）代码实现了跨硬件平台支持，提供了一套子类，定义了一系列抽象方法 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;abstractmethod&lt;/code&gt; 统一接口，用于将 Triton 的中间表示（IR）或 AST 编译成不同 GPU 平台所需的可执行内核。&lt;/p&gt;

&lt;p&gt;Triton 定义了一个抽象后端 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BaseBackend&lt;/code&gt; 类，定义所有硬件后端必须实现的接口：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;Backend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ABC&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;abstractmethod&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;add_stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;定义编译流水线阶段（如 TTIR→LLVMIR→PTX）&quot;&quot;&quot;&lt;/span&gt;
        
    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;abstractmethod&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_codegen_implementation&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;返回平台特定的代码生成回调函数&quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;abstractmethod&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_module_map&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;获取依赖的预编译模块（如数学函数库）&quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;################## 省略代码 ####################
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Triton 内核的通用编译流程总结如下图所示:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/triton_compile/kernel_compile2.jpg&quot; alt=&quot;kernel_compile&quot; /&gt;&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;图片来源 &lt;a href=&quot;https://zhuanlan.zhihu.com/p/12789107689&quot;&gt;Triton概念与编程入门笔记（以Matmul为例）&lt;/a&gt;&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;不同的GPU 硬件 backend 定义了不同的编译 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;stages&lt;/code&gt;：&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;third_party/nvidia/backend/compiler.py&lt;/code&gt; 文件: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;class CUDABackend(BaseBackend)&lt;/code&gt;:&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;add_stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_parse_arch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ttir&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ttgir&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_ttgir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;llir&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_llir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ptx&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_ptx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;cubin&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_cubin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;很明显英伟达 cuda 后端的编译流程分为 5 个 stages，将 kernel 函数的&lt;strong&gt;中间表示（IR）&lt;/strong&gt;转换为最终的二进制 cubin 代码。&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;third_party/amd/backend/compiler.py&lt;/code&gt; 文件: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;class HIPBackend(BaseBackend)&lt;/code&gt;:&lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;add_stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ttir&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ttgir&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_ttgir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;llir&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_llir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;amdgcn&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_amdgcn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;hsaco&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_hsaco&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;下面以 CUDABackend 后端类为例简单描述下 kernel 编译流程和对应函数作用。&lt;/p&gt;

&lt;h3 id=&quot;41-make_ttir&quot;&gt;4.1 make_ttir&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;make_ttir&lt;/code&gt; 函数会构建并将 py 文件转换 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Triton IR&lt;/code&gt; 文件，Triton IR 是 &lt;strong&gt;Triton 编译器的高级中间表示&lt;/strong&gt;，用于表示深度学习模型的计算图，&lt;strong&gt;是硬件无关的&lt;/strong&gt;。其特点包括：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;高级抽象&lt;/strong&gt;：使开发者用接近于高级深度学习框架的方式来描述计算图。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;操作表示&lt;/strong&gt;：包含了一系列的操作，如矩阵乘法、卷机、激活函数等。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;优化&lt;/strong&gt;：编译器可以应用一些高级优化，如死代码消除（DCE）、共子表达式消除（CSE）等。&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;转换&lt;/strong&gt;：可以被转换为更接近硬件特性的低阶表示：Triton GPU IR。&lt;/li&gt;
&lt;/ol&gt;

&lt;blockquote&gt;
  &lt;p&gt;ttir 的全称是 Triton Tensor IR。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pass_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enable_debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_inliner&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_rewrite_tensor_pointer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_canonicalizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_combine&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_reorder_broadcast&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_cse&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_licm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_symbol_dce&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_loop_unroll&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;输入：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;mod: 一个 Module（通常是 MLIR 中定义的 Triton IR 模块）。&lt;/li&gt;
  &lt;li&gt;metadata: 编译过程中的一些上下文信息（可能包含目标架构、调试选项、或其他元数据）。&lt;/li&gt;
  &lt;li&gt;options: 进一步控制 Pass 管线行为的编译选项。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;输出：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;优化/转换过的同一 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;mod&lt;/code&gt; 对象（MLIR Module）。&lt;/li&gt;
  &lt;li&gt;返回给调用方后，可继续流向下一步编译或直接生成 GPU 二进制。&lt;/li&gt;
&lt;/ul&gt;

&lt;h3 id=&quot;42-make_ttgir&quot;&gt;4.2 make_ttgir&lt;/h3&gt;

&lt;p&gt;make_ttgir 首先将 TritonIR 转换为 TritonGPUIR（&lt;strong&gt;专门针对 GPU 硬件优化的中间表示&lt;/strong&gt;），然后对 TritonGPUIR 进行优化：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;架构特定优化 根据 GPU 计算能力(capability)启用不同优化:&lt;/li&gt;
  &lt;li&gt;通用优化：内存访问优化（合并内存访问、数据预取、优化线程局部性）、计算相关优化（矩阵乘法加速、点积运算优化、指令流水线优化）等。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;函数返回依然是优化后的同一个 Module 对象。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_ttgir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 1. 准备 cluster_info 结构，用于描述集群（cluster）维度（X, Y, Z）。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;cluster_info&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nvidia&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ClusterInfo&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cluster_dims&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cluster_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clusterDimX&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cluster_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cluster_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clusterDimY&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cluster_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cluster_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clusterDimZ&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cluster_dims&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 2. 如果环境变量里开启了 MLIR_ENABLE_REMARK，则设置源码诊断（Diagnostic）
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;MLIR_ENABLE_REMARK&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;0&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;1&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;srcMgr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;source_mgr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;source_mgr_diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;srcMgr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;printOpOnDiagnostic&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3. 构建一个 PassManager，用于把 TTIR 转成 TTGIR，并对其进行各种优化
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pass_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enable_debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3.1 把 TTIR 转换为 TTGPU IR
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#     - 这里指定了 “cuda:{capability}”，以及 opt.num_warps, 32, opt.num_ctas 等参数。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_convert_to_ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;cuda:&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_warps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_ctas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4. 优化 TTGIR：coalesce（合并张量操作）、f32_dot_tc（在capability&amp;gt;=8.0时启用f32 TensorCore等）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_coalesce&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_f32_dot_tc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.1 PlanCTA：将 CTA（thread block）按照 cluster_info 做布局规划
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#     注意这里由 nvidia.passes.ttnvgpuir 提供
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;nvidia&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttnvgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_plan_cta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cluster_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.2 进一步优化：移除多余 layout 转换、优化线程局部性、加速矩阵乘法等
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_remove_layout_conversions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_optimize_thread_locality&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_accelerate_matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_remove_layout_conversions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.3 优化点乘操作（点积）；capability &amp;gt;= 80（即8.0以上）时可能有专门的硬件指令
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_optimize_dot_operands&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;80&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.4 通用优化：CSE（消除公共子表达式）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_cse&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.5 如果是 8.0（Ampere）或更高：做一些更高级的优化（累加器初始化、select和if合并、loop调度与pipeline等）
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_optimize_accumulator_init&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_combine_tensor_select_and_if&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_loop_scheduling&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_pipeline&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_stages&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.6 继续做预取（prefetch）、coalesce async copy、指令重排等，提升内存访问性能
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_prefetch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_optimize_dot_operands&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;80&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_coalesce_async_copy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_remove_layout_conversions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_reduce_data_duplication&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_reorder_instructions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.7 再做一波公共子表达式消除 &amp;amp; 符号死代码删除
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_cse&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_symbol_dce&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 5. 如果是 9.0（Hopper）或更高架构：插入 fence &amp;amp; TMA（Tensor Memory Accelerator）等相关 pass
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;nvidia&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttnvgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_fence_insertion&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;nvidia&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttnvgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_tma_lowering&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 6. 最后做一次 canonicalize（规范化），并运行整个 Pass 管线
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_canonicalizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 7. 将 cluster_dims 信息记录到 metadata，便于后续阶段查询
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;cluster_dims&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cluster_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clusterDimX&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                &lt;span class=&quot;n&quot;&gt;cluster_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clusterDimY&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                &lt;span class=&quot;n&quot;&gt;cluster_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;clusterDimZ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 8. 返回处理好的 MLIR Module
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;43-make_llir&quot;&gt;4.3 make_llir&lt;/h3&gt;

&lt;p&gt;make_llir 是 Triton 编译流水线中的 LLVM IR 生成与优化阶段，负责将 Triton GPU IR（TTGPUIR）转换为&lt;strong&gt;优化后的 LLVM IR&lt;/strong&gt;，并然后再把 MLIR 表示的 LLVM IR 进一步转化成实际的 LLVM 模块（通过 llvm.to_module），并最终返回一段 LLVM IR 字符串。它同时会根据硬件“compute capability”、编译选项（options）等信息做相应的优化与元数据记录。其核心任务包括：&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;低级优化：将平台无关的 TTGPUIR 转换为 NVIDIA GPU 专用的 LLVM IR&lt;/li&gt;
  &lt;li&gt;内存管理：分配共享内存（Shared Memory）和全局暂存内存（Global Scratch Memory）&lt;/li&gt;
  &lt;li&gt;架构适配：根据 GPU 计算能力（如 sm_80）和 PTX 版本生成目标指令&lt;/li&gt;
  &lt;li&gt;性能调优：设置寄存器限制、链接外部库、应用 O3 级优化&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;函数输出是：一段 LLVM IR 源码（字符串形式）。示例：&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; 优化后的 LLVM IR
target triple &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s2&quot;&gt;&quot;nvptx64-nvidia-cuda&quot;&lt;/span&gt;
target datalayout &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s2&quot;&gt;&quot;e-i64:64-i128:128-v16:16-v32:32-n16:32:64&quot;&lt;/span&gt;

define void @kernel&lt;span class=&quot;o&quot;&gt;(&lt;/span&gt;ptr %arg0&lt;span class=&quot;o&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c&quot;&gt;#0 {&lt;/span&gt;
  %shared &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; alloca &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;128 x float], align 4
  %val &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; load float, ptr %arg0
  store float %val, ptr %shared
  ret void
&lt;span class=&quot;o&quot;&gt;}&lt;/span&gt;

attributes &lt;span class=&quot;c&quot;&gt;#0 = { &quot;target-features&quot;=&quot;+ptx80,+sm_80&quot; &quot;maxnreg&quot;=&quot;128&quot; }&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;make_llir 函数定义：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_llir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    将 Triton GPU IR (TTGIR) 转换为 LLVM IR (MLIR 表示)，
    再进一步转换为实际的 LLVM 模块并返回其字符串形式。
    同时根据编译选项与硬件 capability 做相应处理。
    &quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 1. 从编译选项中获取 PTX 版本（如果是 NVIDIA 后端）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;ptx_version&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;get_ptx_version_from_options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 2. 根据属性 &quot;ttg.num-warp-groups-per-cta&quot; 动态调整 num_warps
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    (warp-specialization 机制)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;num_warp_groups&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_int_attr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ttg.num-warp-groups-per-cta&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_warp_groups&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;num_warps&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_warp_groups&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# mod 就是 TTGIR Module (MLIR)
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 3. 构建 MLIR PassManager 来执行 TTGIR -&amp;gt; LLVM-IR (MLIR) 的转换与优化
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pass_manager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enable_debug&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3.1 若环境变量启用 MLIR 诊断，将在编译时输出额外信息
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;MLIR_ENABLE_REMARK&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;0&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;1&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;srcMgr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;source_mgr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;source_mgr_diag&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;srcMgr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;printOpOnDiagnostic&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3.2 对 TTGIR 进行一些转换与优化 Pass
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_combine_tensor_select_and_if&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;convert&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_scf_to_cf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;convert&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_index_to_llvmir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_allocate_shared_memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_allocate_global_scratch_memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 3.3 将 TTGIR 转换成 MLIR 中的 LLVM Dialect IR (兼顾 GPU capability 和 ptx_version)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;nvidia&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_to_llvmir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ptx_version&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 3.4 如果有 nvidia GPU 定制化 pass，也在此时调用
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;nvidia&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ttnvgpuir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_nvgpu_to_llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;convert&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_arith_to_llvmir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3.5 做一些通用 IR 优化 &amp;amp; 清理
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_canonicalizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_cse&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;common&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_symbol_dce&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3.6 如果环境变量允许行号信息，则添加 debug scope
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;TRITON_DISABLE_LINE_INFO&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;0&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;0&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;passes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;llvmir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_di_scope&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3.7 运行所有 Pass，得到 LLVM-IR (仍是 MLIR 表示)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;pm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4. 初始化 LLVM，准备将 MLIR -&amp;gt; LLVM 原生模块
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init_targets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.1 不支持 ASAN 时的提示
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;TRITON_ENABLE_ASAN&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;0&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;1&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;raise&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;RuntimeError&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;s&quot;&gt;&quot;Address Sanitizer Error: Address sanitizer is only supported on AMD backend&quot;&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.2 把 MLIR 中的 LLVM Dialect Module 转为一个真实的 LLVM 模块 (llvmlite)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;llvm_mod&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to_module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.3 根据 compute capability 设置目标 (sm_80 / sm_90a 等) 与特性
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;proc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;sm_90a&apos;&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;90&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;sm_&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;features&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;get_features&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;triple&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;nvptx64-nvidia-cuda&apos;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attach_datalayout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;llvm_mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;triple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;proc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;features&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.4 在 LLVM 模块上设定 “FTZ (flush-to-zero)” 等配置
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;nvidia&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;set_nvvm_reflect_ftz&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;llvm_mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.5 如果用户在 options 里指定了 maxnreg，则给内核函数附加相应的 NVVM 属性
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maxnreg&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;llvm_mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_functions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_declaration&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_external_linkage&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;set_nvvm_maxnreg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;maxnreg&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.6 若需要链接外部库，则在此操作
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;extern_libs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;paths&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;extern_libs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;link_extern_libs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;llvm_mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4.7 优化 LLVM 模块到 O3
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimize_module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;llvm_mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;OPTIMIZE_O3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 5. 提取更多 metadata 信息 (共享内存大小、全局临时区大小/对齐等)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;shared&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_int_attr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ttg.shared&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;global_scratch_size&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_int_attr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ttg.global_scratch_memory_size&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;global_scratch_align&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_int_attr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ttg.global_scratch_memory_alignment&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 6. 输出最终的 LLVM IR 字符串
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;llvm_mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 7. 释放资源，然后返回 LLVM IR 字符串
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;del&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;llvm_mod&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;del&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;context&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;44-make_ptx&quot;&gt;4.4 make_ptx&lt;/h3&gt;

&lt;p&gt;make_ptx 用于将此前编译产生的 LLVM IR (在字符串形式) 翻译成&lt;strong&gt;PTX 汇编代码&lt;/strong&gt;，并对生成的 PTX 进行一定的后处理（如插入正确的 .version、去除 debug 标记等）。它主要依赖 llvmlite 或 Triton 自带的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;llvm.translate_to_asm&lt;/code&gt; 接口，将 LLVM IR 转换为可在 NVIDIA GPU 上使用的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PTX&lt;/code&gt; 代码。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_ptx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    将 &apos;src&apos; (LLVM IR 字符串) 翻译成 PTX 汇编，并进行必要的后处理和信息提取。
    :param src:         字符串形式的 LLVM IR
    :param metadata:    字典类型的元数据，用于存放编译后的一些信息，如 kernel 名称等
    :param opt:         编译选项对象（包含 arch、fp_fusion 等）
    :param capability:  GPU 架构（如 80=sm_80, 90=sm_90a 等）
    :return:            最终生成的 PTX 字符串
    &quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 1. 从选项 &apos;opt&apos; 中获取所需 PTX 版本 (例如 63 对应 PTX 6.3)
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;ptx_version&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;get_ptx_version_from_options&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 2. 设定编译目标三元组 (triple) 和目标 GPU (proc)，以及特性 (features)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    - triple 固定为 &quot;nvptx64-nvidia-cuda&quot;
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    - 若 capability==90 则用 &quot;sm_90a&quot;，否则 &quot;sm_{capability}&quot;
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;triple&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;nvptx64-nvidia-cuda&apos;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;proc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;sm_90a&apos;&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;90&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;sm_&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;features&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;get_features&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 3. 将 LLVM IR 翻译成 PTX 汇编
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    - 这里传入包括 triple, proc, features, 以及一些编译选项 (如 fp_fusion)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    - 返回值 ret 是 PTX 字符串
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;llvm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;translate_to_asm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;                 &lt;span class=&quot;c1&quot;&gt;# LLVM IR
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;triple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;             &lt;span class=&quot;c1&quot;&gt;# Target triple
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;proc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;               &lt;span class=&quot;c1&quot;&gt;# GPU 架构
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;features&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;           &lt;span class=&quot;c1&quot;&gt;# 特性字符串
&lt;/span&gt;        &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;nvptx-short-ptr&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;&lt;span class=&quot;c1&quot;&gt;# 一些附加选项
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enable_fp_fusion&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;               &lt;span class=&quot;c1&quot;&gt;# 是否以 debug 方式生成
&lt;/span&gt;    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 4. 从生成的 PTX 中找到 kernel 名 (&apos;.visible .entry kernelName&apos;)
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    - 预期只有一个 kernel
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;names&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;re&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;findall&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;.visible .entry ([a-zA-Z_][a-zA-Z0-9_]*)&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;Expected exactly 1 kernel, got &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;name&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;names&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 记录到 metadata
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 5. 将 PTX 文件头中的 &apos;.version x.x&apos; 替换为实际 desired ptx_version
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    - ptx_version 字符串形如 &apos;6.3&apos;, &apos;7.5&apos; 等
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;ptx_version&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptx_version&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptx_version&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;re&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sub&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;\.version \d+\.\d+&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;.version &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptx_version&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;flags&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;re&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MULTILINE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 6. 移除会阻止 ptxas 优化的 &quot;debug&quot; 标记，以取得更高性能
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;re&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sub&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;,\s*debug|debug,\s*&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 7. 若环境变量里设置 NVPTX_ENABLE_DUMP=1，则打印生成的 PTX 以供调试
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;NVPTX_ENABLE_DUMP&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;0&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;1&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;// -----// NVPTX Dump //----- //&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 8. 返回处理好的 PTX 字符串
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ret&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;输入输出示例:&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;输入：LLVM IR 字符串&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;复制
&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; LLVM IR 输入（已优化）
target triple &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s2&quot;&gt;&quot;nvptx64-nvidia-cuda&quot;&lt;/span&gt;
define void @triton_kernel&lt;span class=&quot;o&quot;&gt;(&lt;/span&gt;ptr %arg0&lt;span class=&quot;o&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c&quot;&gt;#0 {&lt;/span&gt;
  %val &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; load float, ptr %arg0
  store float %val, ptr null  &lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; 示例简化存储操作
  ret void
&lt;span class=&quot;o&quot;&gt;}&lt;/span&gt;

attributes &lt;span class=&quot;c&quot;&gt;#0 = { &quot;target-features&quot;=&quot;+ptx80,+sm_80&quot; }&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;输出：PTX 汇编代码&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;// 生成的 PTX 代码
.version 7.8        &lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; 根据选项调整后的版本
.target sm_80       &lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; GPU 计算能力
.address_size 64    &lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; 64 位地址

.visible .entry triton_kernel&lt;span class=&quot;o&quot;&gt;(&lt;/span&gt;      &lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; 元数据中记录内核名称
  .param .u64 %arg0
&lt;span class=&quot;o&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;{&lt;/span&gt;
  ld.param.u64 %rd1, &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;%arg0]&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
  ld.global.f32 %f1, &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;%rd1]&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
  st.global.f32 &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;0], %f1&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;           &lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; 存储操作
  ret&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
&lt;span class=&quot;o&quot;&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;45-make_cubin&quot;&gt;4.5 make_cubin&lt;/h3&gt;

&lt;p&gt;make_cubin 函数负责将PTX（字符串形式）编译成Cubin（NVIDIA GPU 的二进制可执行文件）。它通过在本地创建临时文件并调用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ptxas&lt;/code&gt; 命令行工具完成最终编译输出。如果编译成功，就会返回 cubin（二进制）数据；若编译出错，则会抛出相应的异常。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;make_cubin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    将 PTX 字符串编译成 NVIDIA Cubin 二进制。
    :param src:         要编译的 PTX 源码（字符串）
    :param metadata:    存放编译中产生的一些上下文信息（本函数暂未直接修改其内容）
    :param opt:         编译选项对象 (enable_fp_fusion 等)
    :param capability:  GPU 架构 (如 80 对应 sm_80, 90 对应 sm_90a)
    :return:            编译成功时返回一个 bytes 类型的 cubin（可执行二进制）
    &quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 1. 获取 ptxas 二进制所在路径
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    如果找不到，_path_to_binary(...) 会抛出相应异常
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;ptxas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_path_to_binary&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;ptxas&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 2. 创建临时文件，用于写入 PTX 源码并接受日志输出
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    mode=&apos;w&apos; 以可写方式创建 .ptx 文件
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#    mode=&apos;r&apos; 用来读取日志 .log 文件
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tempfile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTemporaryFile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;delete&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mode&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;w&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;suffix&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;.ptx&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fsrc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; \
         &lt;span class=&quot;n&quot;&gt;tempfile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NamedTemporaryFile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;delete&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mode&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;r&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;suffix&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;.log&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;flog&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 2.1 写入 PTX 源码到临时文件
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;fsrc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;write&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;src&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;fsrc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flush&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 2.2 输出的目标文件名 (.o) 存在于同一路径
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;fbin&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fsrc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;.o&apos;&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 3. 构建 ptxas 命令行选项
&lt;/span&gt;        &lt;span class=&quot;c1&quot;&gt;#    - line_info: 是否启用行号信息
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;line_info&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;TRITON_DISABLE_LINE_INFO&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;-lineinfo&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;#    - fmad: 是否允许 FMA 指令融合 (由 enable_fp_fusion 决定)
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;fmad&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;opt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enable_fp_fusion&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;--fmad=false&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;#    - suffix: 如果capability=90, 则在后面加 &apos;a&apos; =&amp;gt; sm_90a
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;suffix&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;a&apos;&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;90&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;&apos;&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;#    - opt_level: 是否禁用 ptxas 优化
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;opt_level&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;--opt-level&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;0&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;environ&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;DISABLE_PTXAS_OPT&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;0&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;1&quot;&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;ptxas_cmd&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;ptxas&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;line_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fmad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;s&quot;&gt;&apos;-v&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 打印一些编译信息到 stderr
&lt;/span&gt;            &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;opt_level&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;--gpu-name=sm_&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capability&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;suffix&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;fsrc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 输入 PTX 文件
&lt;/span&gt;            &lt;span class=&quot;s&quot;&gt;&apos;-o&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fbin&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 输出目标文件
&lt;/span&gt;        &lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;try&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;# 4. 调用 ptxas 进行编译
&lt;/span&gt;            &lt;span class=&quot;c1&quot;&gt;#    - close_fds=False 可能是为了兼容某些环境
&lt;/span&gt;            &lt;span class=&quot;c1&quot;&gt;#    - stderr=flog 将错误日志写入 flog 文件
&lt;/span&gt;            &lt;span class=&quot;n&quot;&gt;subprocess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptxas_cmd&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;check&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;close_fds&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stderr&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flog&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

            &lt;span class=&quot;c1&quot;&gt;# 4.1 若编译成功，清理生成的中间日志文件
&lt;/span&gt;            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exists&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fsrc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;remove&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fsrc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exists&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flog&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;remove&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flog&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;except&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;subprocess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CalledProcessError&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;# 5. 如果编译出错，读取错误日志并抛出 PTXASError
&lt;/span&gt;            &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;open&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flog&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_file&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;log&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;log_file&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;read&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exists&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flog&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;remove&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;flog&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

            &lt;span class=&quot;c1&quot;&gt;# 根据 e.returncode 的不同，返回不同错误信息
&lt;/span&gt;            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;returncode&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;255&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;error&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;Internal Triton PTX codegen error&apos;&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;elif&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;returncode&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;signal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SIGSEGV&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;error&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;`ptxas` raised SIGSEGV&apos;&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;error&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;`ptxas` failed with error code &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;e&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;returncode&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;&lt;/span&gt;

            &lt;span class=&quot;k&quot;&gt;raise&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;PTXASError&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;error&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
                &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;`ptxas` stderr:&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;log&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
                &lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;Repro command: &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot; &quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;join&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ptxas_cmd&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&apos;&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 6. 若编译成功，则读取生成的 .o 文件即 cubin
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;open&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fbin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&apos;rb&apos;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;cubin&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;read&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# 6.1 清理 .o 文件
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;path&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;exists&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fbin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;remove&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fbin&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 7. 返回最终的 cubin 二进制 (bytes)
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cubin&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;至此，Triton 编译 kernel 的完整流程梳理总结完毕。&lt;/p&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://www.zhihu.com/question/622685131/answer/3497791972&quot;&gt;如何入门 OpenAI Triton 编程?&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/2768070031&quot;&gt;OpenAI Triton分享：Triton编译流程&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/16752382946&quot;&gt;Triton学习：源码结构解析与编译流程&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2025-01-25</pubDate>
      <link>https://c1lovez1.github.io/2025-01-25/triton-kernel-compile.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2025-01-25/triton-kernel-compile.html</guid>
    </item>
    
		
    
    <item>
      <title>vllm 推理流程剖析</title>
      <description>&lt;p&gt;简单来讲，大模型推理一般会经历 4 个流程：Tokenizer -&amp;gt; Model.forward -&amp;gt; Sampler -&amp;gt; DeTokenizer，但对于集成了各种优化技术的 vllm 框架来说，大模型推理服务流程是非常复杂的&lt;/p&gt;

&lt;p&gt;先看下 vLLM 代码整体架构：&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/vllm_infer/vllm_architecture.png&quot; alt=&quot;vllm_architecture&quot; /&gt;
￼
vllm 模型执行流程调用关系可以通过以下代码打印：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vllm_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;llm_engine&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_executor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;driver_worker&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_runner&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;这里先直接总结 vllm 模型推理涉及到的类，及它们之间的调用关系如下，贯穿这些类的是 execute_model 函数。&lt;/p&gt;

&lt;p&gt;LLMEngine —&amp;gt; ExecutorBase —&amp;gt; WorkerBase —&amp;gt; ModelRunnerBase&lt;/p&gt;

&lt;h3 id=&quot;llmengine-类&quot;&gt;LLMEngine 类&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;LLMEngine&lt;/code&gt; 类: vllm/engine/llm_engine.py&lt;/p&gt;

&lt;p&gt;（涉及代码 2000 行）： 用于管理大语言模型（LLM）的推理和生成过程。其中：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;step&lt;/code&gt; 方法：是引擎的核心方法，每调用一次执行一次解码步骤。包括调度、执行模型、处理输出和清理任务。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_initialize_kv_caches&lt;/code&gt; 方法：初始化 KV 缓存，动态调整 GPU 和 CPU 缓存块数量。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;step()&lt;/code&gt; 方法中真正调用模型推理的代码只有一行，通过调用 model_executor 类实例的 execute_model 方法执行模型推理。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;outputs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_executor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;execute_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;execute_model_req&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;execute_model_req&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;executorbase-类&quot;&gt;ExecutorBase 类&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ExecutorBase&lt;/code&gt; 类: vllm/executor/executor_base.py&lt;/p&gt;

&lt;p&gt;模型推理执行器, 针对不同的硬件平台 CPU/GPU/XPU 等，使用工厂方法注册得到不同平台的模型推理执行器类，基类不实现 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;execute_model&lt;/code&gt; 函数，每个平台的执行器类都有各自的实现。&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;execute_model&lt;/code&gt; 函数内部通过调用下述代码，实现模型推理，输入参数只有一个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;execute_model_req&lt;/code&gt;。类型为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ExecuteModelRequest&lt;/code&gt;。初始化和执行模型推理代码如下所示:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;driver_worker&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_create_worker&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;execute_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;driver_worker&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;execute_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;execute_model_req&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h4 id=&quot;_create_worker-函数&quot;&gt;_create_worker 函数&lt;/h4&gt;

&lt;p&gt;_create_worker() 函数由 ExecutorBase 类的初始化函数调用: vllm/executor/gpu_executor.py&lt;/p&gt;

&lt;p&gt;_create_worker 通过 local_rank 和 rank 参数配置工作器（Worker），其本质上是通过调用 create_worker 函数，create_worker 函数的输入参数如下，函数输入参数通过另外一个类函数 _get_worker_module_and_class 获取。&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;worker_module_name&lt;/code&gt;: 工作器模块名字&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;worker_class_name&lt;/code&gt;： 工作起类名&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;worker_class_fn&lt;/code&gt;：工作器函数&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;create_worker&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker_module_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;worker_class_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                  &lt;span class=&quot;n&quot;&gt;worker_class_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Callable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[[],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;WorkerBase&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]],&lt;/span&gt;
                  &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;wrapper&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;WorkerWrapperBase&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_module_name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker_module_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_class_name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker_class_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_class_fn&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker_class_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;wrapper&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init_worker&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wrapper&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_get_worker_module_and_class&lt;/code&gt; 函数，支持根据当前的&lt;strong&gt;调度器配置（scheduler_config）和推测性配置（speculative_config）&lt;/strong&gt;动态地选择适当的 Worker 模块和类。这些 Worker 模块负责执行实际的模型推理工作。函数实现代码如下所示:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_get_worker_module_and_class&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Callable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[[],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;WorkerBase&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]]]:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;worker_class_fn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scheduler_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_multi_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_module_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;vllm.worker.multi_step_worker&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_class_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;MultiStepWorker&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;elif&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;speculative_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_module_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;vllm.spec_decode.spec_decode_worker&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_class_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;create_spec_worker&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_module_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;vllm.worker.worker&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;worker_class_name&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;Worker&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker_module_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;worker_class_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;worker_class_fn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;_get_worker_module_and_class 函数逻辑说明&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;多步调度 (is_multi_step)： 如果启用了多步调度，函数返回：
    &lt;ul&gt;
      &lt;li&gt;模块路径：”vllm.worker.multi_step_worker”&lt;/li&gt;
      &lt;li&gt;类名：”MultiStepWorker”&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;推测性解码 (speculative_config)： 如果启用了推测性解码，函数返回：
    &lt;ul&gt;
      &lt;li&gt;模块路径：”vllm.spec_decode.spec_decode_worker”&lt;/li&gt;
      &lt;li&gt;类名：”create_spec_worker”&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;默认配置： 如果既未启用多步调度，也未启用推测性解码，则返回默认的 Worker：
    &lt;ul&gt;
      &lt;li&gt;模块路径：”vllm.worker.worker”&lt;/li&gt;
      &lt;li&gt;类名：”Worker”&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;函数指针：
    &lt;ul&gt;
      &lt;li&gt;默认返回 None，表示无需额外的类处理逻辑。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;总结：根据 scheduler_config 或 speculative_config 的状态，选择合适的 Worker 模块和类。&lt;/p&gt;

&lt;h3 id=&quot;workerwrapperbase-类&quot;&gt;WorkerWrapperBase 类&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;WorkerWrapperBase&lt;/code&gt; 类： vllm/worker/worker_base.py&lt;/p&gt;

&lt;p&gt;WorkerBase: 定义了所有 Worker 的抽象基类（接口），明确了每个 Worker 必须实现的功能。抽象方法包括设备初始化、缓存管理、模型执行等。&lt;/p&gt;

&lt;p&gt;WorkerWrapperBase 类的关键方法 init_worker()：完成初始化 Worker 的功能，&lt;strong&gt;输入参数是前面 create_worker 的输入参数&lt;/strong&gt;，支持自定义逻辑。init_worker 方法支持两种方式获取&lt;strong&gt;实际 Worker 类&lt;/strong&gt;：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;使用 worker_class_fn 函数返回 Worker 类。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;importlib&lt;/code&gt; 动态导入模块，并从模块中获取 Worker 类。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;其中动态导入模块的实现代码如下所示：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;importlib&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;import_module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker_module_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;worker_class&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;getattr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mod&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker_class_name&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;上述代码是 Python 的动态导入机制，可以在运行时根据字符串形式的模块名称加载模块，而不是在代码编译时导入。getattr() 函数的作用是从一个对象（在这里是模块 mod）中获取指定名称的属性（在这里是 self.worker_class_name）。在这里的代码中，用于动态获取模块中的类。&lt;/p&gt;

&lt;h4 id=&quot;localordistributedworkerbase-类&quot;&gt;LocalOrDistributedWorkerBase 类&lt;/h4&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;LocalOrDistributedWorkerBase&lt;/code&gt;：&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vllm/worker/worker_base.py&lt;/code&gt;&lt;/p&gt;

&lt;p&gt;抽象基类，用于在本地或分布式环境中运行模型推理任务。它通过以下功能实现模型的执行和数据广播：&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;元数据广播：在分布式环境中，主工作器（Driver Worker）将元数据和输入数据广播到其他工作器。辅助工作器从广播数据中提取输入。&lt;/li&gt;
  &lt;li&gt;输入准备和分布式支持：主工作器准备输入并广播。辅助工作器接收广播并提取输入。&lt;/li&gt;
  &lt;li&gt;推理执行：负责调用模型运行器（model_runner）来执行模型推理步骤。支持单步或多步推理。包含 SPMD（单程序多数据）模式执行&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;抽象方法定义&lt;/strong&gt;：定义了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;prepare_worker_input&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;execute_worker&lt;/code&gt; 等抽象方法，需子类实现本地逻辑。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;真正的模型推理执行是通过调用 self.model_runner.execute_model 代码实现。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_runner&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;execute_model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;model_input&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_caches&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;worker_input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;virtual_engine&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_cache&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;intermediate_tensors&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intermediate_tensors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_steps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;self.model_runner&lt;/code&gt; 的初始化在各个不同 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Worker&lt;/code&gt; 子类中实现。如 GPU 的 Worker 类的 &lt;strong&gt;init&lt;/strong&gt; 初始化函数实现了对 self.model_runner 的赋值。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;ModelRunnerClass&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;GPUModelRunnerBase&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ModelRunner&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_runner_cls&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ModelRunnerClass&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_runner_cls&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;elif&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;task&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;embedding&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ModelRunnerClass&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;EmbeddingModelRunner&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;elif&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_encoder_decoder&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;ModelRunnerClass&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;EncoderDecoderModelRunner&lt;/span&gt;
&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_runner&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;GPUModelRunnerBase&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ModelRunnerClass&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;vllm_config&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vllm_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;kv_cache_dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;is_driver_worker&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_driver_worker&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;speculative_args&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h4 id=&quot;worker-类&quot;&gt;Worker 类&lt;/h4&gt;

&lt;p&gt;Worker 类：vllm/worker/worker.py&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/vllm_infer/worker.png&quot; alt=&quot;worker&quot; /&gt;&lt;/p&gt;

&lt;p&gt;默认的 GPU 工作器类，继承自 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;LocalOrDistributedWorkerBase&lt;/code&gt;。在 GPU 上执行模型（分区）的 worker 类。&lt;strong&gt;每个 worker 都与单个 GPU 相关联&lt;/strong&gt;。worker 负责维护 KV 缓存并在 GPU 上执行模型。在分布式推理的情况下，每个 worker 都会被分配一个模型分区。其核心函数作用总结如下：&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;determine_num_available_blocks&lt;/code&gt;：分析 GPU 内存使用情况，确定可用的 KV 缓存块数量。并返回 num_gpu_blocks 和 num_cpu_blocks。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;initialize_cache&lt;/code&gt;：分配 GPU 和 CPU 上的 KV 缓存，准备模型的计算缓存。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;execute_worker&lt;/code&gt;：根据输入执行缓存操作，例如数据块的交换、复制等。&lt;/li&gt;
&lt;/ul&gt;

&lt;h3 id=&quot;modelrunner-类&quot;&gt;ModelRunner 类&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ModelRunner&lt;/code&gt;：vllm/worker/model_runner.py&lt;/p&gt;

&lt;p&gt;类定义：ModelRunner 是基于 GPU 的模型运行器，支持推理和采样。类包含 2 个成员变量：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_model_input_cls&lt;/code&gt;：定义模型输入的类型，ModelInputForGPUWithSamplingMetadata。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_builder_cls&lt;/code&gt;：定义构造模型输入的工具类，ModelInputForGPUBuilder。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;核心函数：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;make_model_input_from_broadcasted_tensor_dict&lt;/code&gt;：从广播的张量字典创建 ModelInputForGPUWithSamplingMetadata 实例。通过调用ModelInputForGPUWithSamplingMetadata 类的 from_broadcasted_tensor_dict 方法解析输入张量，并与注意力后端 (attn_backend) 配置绑定。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;prepare_model_input&lt;/code&gt;：为推理准备模型输入，包括元数据和采样配置。按 prefill 和 decode 的顺序对输入数据进行批处理。使用 dataclasses.replace 方法更新模型输入。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;execute_model&lt;/code&gt;：作用执行一次模型推理。推理执行流程包括： 使用 CUDA 图 (decode_meta.use_cuda_graph) 或常规执行 —&amp;gt; 模型计算 logits —&amp;gt; 如果是最后一级流水线，使用 logits 采样下一个 token。最终返回的是采样输出 (SamplerOutput)。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;调用 model.forward 模型推理的代码如下所示，这里的 model_executable 是应用了 cuda graph 技术后捕获 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;graph_runners[virtual_engine][graph_batch_size]&lt;/code&gt;。&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;set_forward_context&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attn_metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;hidden_or_intermediate_states&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_executable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;input_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_tokens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;positions&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_positions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;kv_caches&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kv_caches&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;attn_metadata&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;attn_metadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;intermediate_tensors&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intermediate_tensors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MultiModalKwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;as_kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;multi_modal_kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                        &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
        &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seqlen_agnostic_kwargs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;self.graph_runners 的赋值是通过 GPUModelRunnerBase 类的 capture_model 函数捕获一系列不同 batch_size 的 capture_models。&lt;/p&gt;

&lt;p&gt;至此，vllm 推理的各个类的调用关系和流程的简单分析完毕了，后续会继续优化细节。&lt;/p&gt;

&lt;h3 id=&quot;参考资料&quot;&gt;参考资料&lt;/h3&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?pli=1#slide=id.g24ad94a0065_0_350&quot;&gt;The first vLLM meetup&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2024-12-02</pubDate>
      <link>https://c1lovez1.github.io/2024-12-02/vllm-infer.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2024-12-02/vllm-infer.html</guid>
    </item>
    
		
    
    <item>
      <title>LLaVA 系列模型结构详解</title>
      <description>&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#前言&quot;&gt;前言&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#llava1&quot;&gt;LLaVA1&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#vit-l14-模型结构&quot;&gt;ViT-L/14 模型结构&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#llava15&quot;&gt;LLaVA1.5&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#llava-15-hd&quot;&gt;LLaVA-1.5-HD&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#llava16llava-next&quot;&gt;LLaVA1.6（LLaVA-NeXT）&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#参考资料&quot;&gt;参考资料&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;前言&quot;&gt;前言&lt;/h2&gt;

&lt;p&gt;视觉语言模型 VIA 或者说多模态大模型 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MLLM&lt;/code&gt; 架构通常都是: LLM + 视觉编码器 + 映射层的组合。英伟达发布的视觉语言模型 VILA 架构和训练流程如下图所示：&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/llava_model/VILA_infer_train.jpg&quot; width=&quot;80%&quot; alt=&quot;VILA_infer_train&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;可以看出视觉语言模型的架构是由视觉 encoder、映射层和语言 decoder 组成。常见的视觉语言模型如下所示，模型架构都很相似。&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;VILA-1.5 (B/8B/13B/40B)&lt;/li&gt;
  &lt;li&gt;LLaVA(1.5,1.6) (7B-34B)&lt;/li&gt;
  &lt;li&gt;InternLM-XComposer2 (7B, 4khd-7B)&lt;/li&gt;
  &lt;li&gt;QWen-VL (7B)&lt;/li&gt;
  &lt;li&gt;DeepSeek-VL (7B)&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;llava1&quot;&gt;LLaVA1&lt;/h2&gt;

&lt;p&gt;Llava1 的模型结构很简洁，&lt;strong&gt;CLIP 模型的视觉编码器 + 映射层 + LLM（Vicuna、LLama）&lt;/strong&gt; ，利用 CLIP 模型的 Vison Encoder 结构对输入图片提取视觉特征，即转换为形状为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[N=1, grid_H x grid_W, hidden_dim]&lt;/code&gt; 的 feature map，然后通过一个映射层（线性层）将图像特征对齐到文本特征维度，即得到形状为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[N=1, grid_H x grid_W, embedding_dim]&lt;/code&gt; 的 image tokens embedding 向量，再然后将图片 tokens 向量和输入文本 tokens 向量 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;concat&lt;/code&gt; 后作为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;LLM&lt;/code&gt; 的输入，生成回答文本。&lt;/p&gt;

&lt;p&gt;LLaVA 模型架构如下图所示吧:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/llava_model/llava_model.png&quot; alt=&quot;llava_model&quot; /&gt;&lt;/p&gt;

&lt;p&gt;具体来说，对于输入图像 $X_v$，采用预训练 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CLIP&lt;/code&gt; 模型的视觉编码器 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ViT-L/14(224²)&lt;/code&gt;，其生成的视觉特征为 $Z_v = g(X_v)$，在作者的实验中，只用最后一个 Transformer 层之前和之后的网格特征。并使用一个&lt;strong&gt;简单的线性层&lt;/strong&gt;将图像特征连接（映射）到词嵌入空间，通过一个可训练的投影矩阵 $W$ 将 $Z_v$ 转换为语言嵌入标记 $H_v$，$Z_v$ 向量的最后一个维度就是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;LLM&lt;/code&gt; 的&lt;strong&gt;词嵌入空间维度&lt;/strong&gt; &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;embedding_dim&lt;/code&gt;。&lt;/p&gt;

\[H_v = W\cdot X_v, with Z_v = g(X_v)\]

&lt;h3 id=&quot;vit-l14-模型结构&quot;&gt;ViT-L/14 模型结构&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ViT-L/14&lt;/code&gt; 模型的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;L&lt;/code&gt; 表示模型的规模，为 “Large”，ViT(Vision Transformer) 模型有不同规模的模型，例如：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ViT-B&lt;/code&gt;（Base）：通常有 12 层 Transformer。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ViT-L&lt;/code&gt;（Large）：通常有 24 层 Transformer。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ViT-H&lt;/code&gt;（Huge）：通常有 32 层 Transformer。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ViT&lt;/code&gt; 会将输入图像分割成固定大小的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;patch&lt;/code&gt;（例如 14x14），&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ViT-L/14&lt;/code&gt; 即表示 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;patch&lt;/code&gt; 大小为 14，&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;CLIP 模型的视觉部分使用 ViT 来编码图像的特征，文本部分使用 Transformer 来编码文本的特征。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;不同版本 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;ViT&lt;/code&gt; 模型的参数总结：&lt;/p&gt;

&lt;table&gt;
  &lt;thead&gt;
    &lt;tr&gt;
      &lt;th&gt;模型版本&lt;/th&gt;
      &lt;th&gt;Transformer 层数&lt;/th&gt;
      &lt;th&gt;隐藏维度&lt;/th&gt;
      &lt;th&gt;参数量&lt;/th&gt;
      &lt;th&gt;Patch 分辨率&lt;/th&gt;
    &lt;/tr&gt;
  &lt;/thead&gt;
  &lt;tbody&gt;
    &lt;tr&gt;
      &lt;td&gt;ViT-B/16&lt;/td&gt;
      &lt;td&gt;12&lt;/td&gt;
      &lt;td&gt;768&lt;/td&gt;
      &lt;td&gt;86M&lt;/td&gt;
      &lt;td&gt;16x16&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;ViT-L/16&lt;/td&gt;
      &lt;td&gt;24&lt;/td&gt;
      &lt;td&gt;1024&lt;/td&gt;
      &lt;td&gt;307M&lt;/td&gt;
      &lt;td&gt;16x16&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;ViT-L/14&lt;/td&gt;
      &lt;td&gt;24&lt;/td&gt;
      &lt;td&gt;1024&lt;/td&gt;
      &lt;td&gt;307M&lt;/td&gt;
      &lt;td&gt;14x14&lt;/td&gt;
    &lt;/tr&gt;
    &lt;tr&gt;
      &lt;td&gt;ViT-H/14&lt;/td&gt;
      &lt;td&gt;32&lt;/td&gt;
      &lt;td&gt;1280&lt;/td&gt;
      &lt;td&gt;632M&lt;/td&gt;
      &lt;td&gt;14x14&lt;/td&gt;
    &lt;/tr&gt;
  &lt;/tbody&gt;
&lt;/table&gt;

&lt;h2 id=&quot;llava15&quot;&gt;LLaVA1.5&lt;/h2&gt;

&lt;p&gt;模型结构上和前作相比，LLaVA1.5 将之前用于维度映射的的简单一层线性层替换为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;2&lt;/code&gt; 层 线性层的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MLP&lt;/code&gt; 结构，并将 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;clip-L/14&lt;/code&gt; 的输入分辨率从 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;224*224&lt;/code&gt; 提升到 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;336*336&lt;/code&gt;，因为作者发现提高输入图像分辨率能够增强模型性能，LLM 换成了 Vicuna1.5（在 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;LLama2&lt;/code&gt; 上微调的模型）&lt;/p&gt;

&lt;h3 id=&quot;llava-15-hd&quot;&gt;LLaVA-1.5-HD&lt;/h3&gt;

&lt;p&gt;目前开源的 CLIP 视觉编码器的分辨率上限为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;336*336&lt;/code&gt;，这意味着无法简单地替换视觉编码器来支持更高分辨率的图像。为了解决这个问题，论文探索了一种方法，既能让多模态语言模型（LMM）处理高分辨率图像，又能保持 LLaVA-1.5 的高效数据使用。&lt;/p&gt;

&lt;p&gt;将输入图像划分为若干小块，每块的分辨率与视觉编码器原始训练时一致，然后分别对这些块进行编码。编码完成后，我们将这些块的特征图合并为目标分辨率的大特征图，并将其输入到 LLM 中。同时，为了给 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;LLM&lt;/code&gt; 提供&lt;strong&gt;全局上下文信息&lt;/strong&gt;并减少图像分割、编码和合并操作带来的不良影响，我们还将一个经过下采样（&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;resize&lt;/code&gt;）的图像特征连接到合并后的特征图中。&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/llava_model/llava1.5_model.png&quot; alt=&quot;llava1.5_model&quot; /&gt;&lt;/p&gt;

&lt;p&gt;这样的设计允许我们处理任意分辨率的图像，同时保持 LLaVA-1.5 的数据高效性。这一新模型被作者命名为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;LLaVA-1.5-HD&lt;/code&gt;。&lt;/p&gt;

&lt;h2 id=&quot;llava16llava-next&quot;&gt;LLaVA1.6（LLaVA-NeXT）&lt;/h2&gt;

&lt;p&gt;模型推理层面新的升级点在于，Vision Encoder 分辨率支持更大的分辨率，包括 672x672, 336x1344, 1344x336 几种分辨率的输入，并且支持通过图片裁切，编码，合并来实现，和前作一样的方法。毕竟，当提供高分辨率图像和保留细节的表征时，模型感知图像中复杂细节的能力会显著提高。它减少了面对低分辨率图像时的模型幻觉，即猜测想象的视觉内容。&lt;/p&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://arxiv.org/pdf/2304.08485&quot;&gt;Visual Instruction Tuning&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://arxiv.org/pdf/2310.03744&quot;&gt;Improved Baselines with Visual Instruction Tuning&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/695100288&quot;&gt;【多模态大模型】llava系列：llava、llava1.5、llava-next&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2024-11-28</pubDate>
      <link>https://c1lovez1.github.io/2024-11-28/llava-structure.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2024-11-28/llava-structure.html</guid>
    </item>
    
		
    
    <item>
      <title>温度系数与 top-p 采样策略详解</title>
      <description>&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#一temperature-温度系数作用&quot;&gt;一、Temperature 温度系数作用&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#二解码策略介绍&quot;&gt;二、解码策略介绍&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#三top-p-采样算法&quot;&gt;三、top-p 采样算法&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#31-top-p-采样算法步骤&quot;&gt;3.1 top-p 采样算法步骤:&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#32-top-p-采样代码&quot;&gt;3.2 top-p 采样代码&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#参考资料&quot;&gt;参考资料&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;一temperature-温度系数作用&quot;&gt;一、Temperature 温度系数作用&lt;/h2&gt;

&lt;p&gt;Temperature 采样的温度系数意义、公式和知识蒸馏很相似，结合 softmax 的公式，都是如下形式:&lt;/p&gt;

\[q_i = \frac{exp(z_i/T)}{\sum_j^K exp(z_j/T)}\]

&lt;p&gt;当 $T$ 趋于无穷大时，输出概率分布将趋于均匀分布，概率为 $1/K$, 此时信息熵是最大的。反过来，$T$ 趋于0时，正确类别的概率接近 $1$，输出结果就是确定的，信息熵为 0，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;softmax&lt;/code&gt; 的效果与 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;argmax&lt;/code&gt; 差不多.&lt;/p&gt;

&lt;p&gt;应用代码如下所示：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# logits 是 llm 推理输出, 形状为 [batch_size, seq_len, vocab_size]
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;softmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;temperature&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;代码详解：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;logits[:, -1]&lt;/code&gt; 表示选择的是最后一个位置（seq_len 维度的最后一项）对应的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;logits&lt;/code&gt;，形状变为 [batch_size, vocab_size]。因为在生成模型中的 prefill 阶段，我们只关心当前生成的最后一个 token 的分布。&lt;/li&gt;
  &lt;li&gt;temperature 作用是调整 logits 的分布，用于控制采样的随机性。总结就是，温度系数 $T$ 越大输出越平滑，结果越不确定，越小则越确定。具体来说，当 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;temperature &amp;lt; 1.0&lt;/code&gt;，分布会变得更加陡峭，更倾向于选择高概率的 token。&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;temperature &amp;gt; 1.0&lt;/code&gt;，分布会变得更加平坦，增加随机性。&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;二解码策略介绍&quot;&gt;二、解码策略介绍&lt;/h2&gt;

&lt;p&gt;首先需要知道，&lt;strong&gt;LLM 的输出结果只是下一个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;token&lt;/code&gt; 的概率分布 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;logits&lt;/code&gt;&lt;/strong&gt;，即对下一个单词的预测概率张量，形状为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[batch_size, seq_len, logits]&lt;/code&gt;。而如何从概率分布中选择下一个单词，就是我要介绍的解码策略，也叫采样策略&lt;/p&gt;

&lt;p&gt;解码策略里，常见的方法是是贪心策略，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Top-K&lt;/code&gt; 采样和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Top_p&lt;/code&gt; 采样，这几个方法的不同点在于候选集的选择策略不同。&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;贪心策略取的是概率最大的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Top1&lt;/code&gt; 的样本作为候选项，也就是永远取概率最大的样本作为一个候选项，但这样只能保证是局部最优，也就是当前步是最优的，达不到全局最优。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Top-K&lt;/code&gt; 采样取的是概率的前 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;TopK&lt;/code&gt; 的样本作为候选项, 也就是每一步都保留有 K 个候选项，能在一定程度上保证全局最优。但 top-k 有个问题就是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;k&lt;/code&gt; 取多少，是最优的，这个难以确定。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Top-p&lt;/code&gt; 采样，针对的就是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;K&lt;/code&gt; 值难确定的问题，通过设定阈值 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;p&lt;/code&gt;, 根据候选集累积概率之和达到阈值 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;p&lt;/code&gt;，来选择候选项的个数，也叫核采样。&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;三top-p-采样算法&quot;&gt;三、top-p 采样算法&lt;/h2&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Top-p&lt;/code&gt; 采样（也称为核采样，Nucleus Sampling）是一种用于自然语言生成模型的解码策略，旨在平衡生成文本的多样性和质量。核心思想是：在每一步生成 next_token 时，都从累积概率超过阈值 p 的tokens 集合中进行随机采样。具体操作是，每个时间步，按照 token出现的概率由高到底排序，当概率之和大于 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;top-p&lt;/code&gt; 的时候，就不考虑后面的低概率 tokens。&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;../images/top-p/Top-P_Visual.webp&quot; alt=&quot;TOP-P_Visual&quot; /&gt;&lt;/p&gt;

&lt;p&gt;上图很好的展示了 Top-p 采样（Nucleus Sampling） 的过程，可以分为两个步骤：&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;1，确定候选集&lt;/strong&gt;：左图显示如何根据累积概率选择候选集：&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;每个单词（或 token）都有一个概率，例如：“United”: 12%，“Netherlands”: 2.7%，按照概率降序排列，逐步累加概率，直到累积概率达到阈值（例如 15%）。&lt;/li&gt;
  &lt;li&gt;一旦达到阈值，忽略其他概率更低的词（如 “Czech” 和 “U” 被排除）。&lt;/li&gt;
  &lt;li&gt;因此，此例中，候选集包括：&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;“United” (12%)&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;“Netherlands” (2.7%)&lt;/code&gt;。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;strong&gt;2，从候选集中采样&lt;/strong&gt;: 右图显示如何从候选集中基于归一化概率进行采样：&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;候选集中的概率重新归一化。例如：
    &lt;ul&gt;
      &lt;li&gt;United”: 原概率 12% 占候选集的 82%（12% / 15%）。&lt;/li&gt;
      &lt;li&gt;“Netherlands”: 原概率 2.7% 占候选集的 18%（2.7% / 15%）。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;根据归一化后的概率进行随机采样。最终生成的词可能是：“United”（较高概率）或 “Netherlands”（较低概率）。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;很明显，top-p 采样方法可以动态调整候选词的数量，避免了固定数量候选词可能带来的问题。另外，可以发现，top_p 越小，则过滤掉的小概率 token 越多，采样时的可选项目就越少，生成结果的多样性也就越小。&lt;/p&gt;

&lt;h3 id=&quot;31-top-p-采样算法步骤&quot;&gt;3.1 top-p 采样算法步骤:&lt;/h3&gt;

&lt;p&gt;Top-p 采样的详细步骤：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;概率排序：对模型在当前时间步生成的所有词汇的概率进行降序排序。&lt;/li&gt;
  &lt;li&gt;确定候选集：从排序后的词汇中，选择累积概率达到或超过设定阈值 p 的最小集合，记为 V_p。例如，若 p=0.9，则选择前几个词，使其概率之和至少为 0.9。&lt;/li&gt;
  &lt;li&gt;归一化概率：对候选集 V_p 中的词汇的概率进行重新归一化，使其和为 1。&lt;/li&gt;
  &lt;li&gt;随机采样：根据归一化后的概率分布，从候选集 V_p 中随机选择下一个生成的词。&lt;/li&gt;
  &lt;li&gt;token 索引映射：使用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torch.gather&lt;/code&gt; 函数将采样的索引映射回原始词汇表索引。&lt;/li&gt;
&lt;/ol&gt;

&lt;h3 id=&quot;32-top-p-采样代码&quot;&gt;3.2 top-p 采样代码&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;top-p&lt;/code&gt; 采样代码详细解释：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;sample_top_p&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    执行 Top-p (Nucleus) 采样, 从概率分布中采样下一个词。
    参数：
        probs (torch.Tensor): 概率分布张量，形状为 `[batch_size, vocab_size]`。
        p (float): 累积概率阈值，取值范围在 0 到 1 之间。
    返回：
        torch.Tensor: 采样得到的词索引，形状为 `[batch_size, 1]`。

    说明：
        Top-p 采样算法: 选择概率累积和超过阈值 p 的最小集合，将这些词的概率重新归一化后进行采样。
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 对概率分布进行降序排序。probs_sort: 排序后的概率值，形状与 probs 相同。probs_idx: 排序后的索引，用于映射回原始词汇表。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;probs_sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;descending&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 计算排序后概率的累积和. 返回的 probs_sum 是累积概率分布。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;probs_sum&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cumsum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs_sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 保留累积概率未超过阈值 p 的词汇的概率，其余词汇的概率被置为 0.0。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs_sum&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;probs_sort&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# 创建掩码，对于每个位置，计算累积概率（不包括当前词）是否超过阈值 p。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;probs_sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.0&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# 将累积概率超过阈值 p 的词的概率置零。
&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 对剩余的概率重新归一化, 确保总和为 1。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;probs_sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;div_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs_sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;keepdim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 从重新归一化的概率分布中采样下一个词. 返回的 next_token 是采样得到的词在排序后概率分布中的索引。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;next_token_sorted_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;multinomial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs_sort&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_samples&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 在 probs_idx 的最后一维（dim=-1）中，使用 next_token_sorted_idx 作为索引，提取对应的值。沿着 dim=1（列）进行索引提取
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# NOTE: torch.gather 函数按照给定的索引张量 index，从输入张量中收集 (获取) 数据，并返回一个与索引张量形状一致的张量。
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gather&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;index&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;next_token_sorted_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;next_token&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# 返回采样得到的下一个词的索引
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;代码运行示例：&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;1. 输入概率分布： probs &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;0.1, 0.3, 0.4, 0.15, 0.05]
2. 按降序排序： probs_sort &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;0.4, 0.3, 0.15, 0.1, 0.05]
   原始索引：   probs_idx  &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;2, 1, 3, 0, 4]
3. 计算累积概率： probs_sum &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;0.4, 0.7, 0.85, 0.95, 1.0]
4. 根据 &lt;span class=&quot;nv&quot;&gt;p&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;0.8 标记掩码： mask &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;False, False, True, True, True]
5. 将超出范围的概率置零： probs_sort &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;0.4, 0.3, 0, 0, 0]
6. 重新归一化： probs_sort &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;0.5714, 0.4286, 0, 0, 0]
7. 根据概率采样： next_token_index &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; 0
8. 从原始索引还原： next_token &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; probs_idx[0] &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; 2
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;strong&gt;重点函数解释&lt;/strong&gt;:&lt;/p&gt;

&lt;p&gt;1，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torch.gather&lt;/code&gt; 函数按照给定的索引张量 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;index&lt;/code&gt;，从输入张量中收集 (获取) 数据，并返回一个与索引张量形状一致的张量。&lt;/p&gt;

&lt;p&gt;示例代码：&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 创建一个 3x4 的输入张量
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_tensor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;20&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;30&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;40&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                             &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;50&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;60&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;70&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;80&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                             &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;90&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;110&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;120&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 创建一个包含索引的张量
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;index_tensor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                             &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                             &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 沿着 dim=1（列）进行索引提取
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output_tensor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gather&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;input_tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;index&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;index_tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output_tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
程序运行后输出:
tensor([[ 40,  30,  20,  10],
        [ 50,  60,  70,  80],
        [100,  90, 120, 110]])
对于 input_tensor 的第二行 [50, 60, 70, 80]，index_tensor 的第二行 [0, 1, 2, 3] 指示提取顺序为第二行的 0 列、第二行的第 1 列、第二行的第 2 列、第二行的第 3 列，结果为 [50, 60, 70, 80]。
&quot;&quot;&quot;&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;2，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torch.multinomial&lt;/code&gt; 用于从概率分布中抽取样本，支持带放回和不带放回两种方式。具体来说，它的功能是基于输入的概率权重进行采样。&lt;/p&gt;

&lt;p&gt;函数签名如下:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;multinomial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;input&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_samples&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;replacement&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;generator&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LongTensor&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;参数解释：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;input&lt;/code&gt;: 1D 或 2D 的张量，表示概率分布。它的值不需要是标准化的概率，但必须是非负的。如果是 2D 张量，每一行会被视为一个单独的分布。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_samples&lt;/code&gt;: 需要采样的样本数量。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;replacement&lt;/code&gt;: 是否是有放回采样。
    &lt;ul&gt;
      &lt;li&gt;如果为 True，可以多次采样同一个索引。&lt;/li&gt;
      &lt;li&gt;如果为 False，采样后不会重复选择。&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;generator&lt;/code&gt;（可选）: 控制采样随机性的生成器。&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;代码示例:&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 定义一个概率分布
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 从分布中采样 1 个样本（不放回）
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sample_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;multinomial&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;probs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_samples&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;采样的索引:&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sample_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;item&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# 输出采样的索引可能为 2，因为其对应概率最大
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/713270088&quot;&gt;如何解释 top_p 和 temperature 参数对 LLM 生成多样性的影响&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/631591713&quot;&gt;ChatGPT 温度系数t与top-p, 超参怎么设置最优&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://docs.cohere.com/docs/controlling-generation-with-top-k-top-p&quot;&gt;Top-k &amp;amp; Top-p&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2024-11-24</pubDate>
      <link>https://c1lovez1.github.io/2024-11-24/top-p-sampling.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2024-11-24/top-p-sampling.html</guid>
    </item>
    
		
    
    <item>
      <title>vllm 优化之 PagedAttention 源码解读</title>
      <description>&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;#一-pagedattention-内核&quot;&gt;一 PagedAttention 内核&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#11-主要函数&quot;&gt;1.1 主要函数&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#12-内核配置定义&quot;&gt;1.2 内核配置定义&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#13-基于-block_tables-读取-kv-cache&quot;&gt;1.3 基于 block_tables 读取 kv cache&lt;/a&gt;&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#二-paged页表原理分析&quot;&gt;二 Paged(页表)原理分析&lt;/a&gt;
    &lt;ul&gt;
      &lt;li&gt;&lt;a href=&quot;#21-block-管理相关类&quot;&gt;2.1 Block 管理相关类&lt;/a&gt;
        &lt;ul&gt;
          &lt;li&gt;&lt;a href=&quot;#blocktable&quot;&gt;BlockTable&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#cpugpublockallocator-类&quot;&gt;CpuGpuBlockAllocator 类&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#naiveblockallocator&quot;&gt;NaiveBlockAllocator&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#blocklist-类&quot;&gt;BlockList 类&lt;/a&gt;&lt;/li&gt;
          &lt;li&gt;&lt;a href=&quot;#逻辑-block-管理类-selfattnblockspacemanager&quot;&gt;逻辑 block 管理类-SelfAttnBlockSpaceManager&lt;/a&gt;&lt;/li&gt;
        &lt;/ul&gt;
      &lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#22-slot-mapping&quot;&gt;2.2 slot mapping&lt;/a&gt;&lt;/li&gt;
      &lt;li&gt;&lt;a href=&quot;#23-物理-block-分配类-cacheengine&quot;&gt;2.3 物理 block 分配类-CacheEngine&lt;/a&gt;
        &lt;ul&gt;
          &lt;li&gt;&lt;a href=&quot;#231-num_gpu_blocks-获取-determine_num_available_blocks-函数&quot;&gt;2.3.1 num_gpu_blocks 获取-determine_num_available_blocks 函数&lt;/a&gt;&lt;/li&gt;
        &lt;/ul&gt;
      &lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;#参考资料&quot;&gt;参考资料&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;PagedAttention 算法的原理可以参考我前面写的文章&lt;a href=&quot;https://www.armcvai.cn/2024-10-26/vllm-optimize.html&quot;&gt;vllm优化技术速览&lt;/a&gt;。从源码的角度来看 PagedAttention，其实可以分为两部分:&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;PagedAttention Kernel 的实现，这里 v1 算法计算逻辑部分和标准 attention 差不多（v2 计算逻辑和 flashattentionv2 一致），但是 kv cache 的分配和管理使用了  kv cache 动态管理、存取优化技术。&lt;/li&gt;
  &lt;li&gt;PagedAttention 的对 kv cache 的内存分配管理技术（KV cache manager），之前的 kv cache 在 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seq&lt;/code&gt; 这个维度都是固定为最大输入尺寸的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;max_seq_len&lt;/code&gt;, 但实际单个请求可能消耗这么多内存，且在请求推理结束后也没有释放对应内存的功能，这必然会造成大量的内存浪费和碎片化。由此 PagedAttention 算法基于操作系统的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;page table&lt;/code&gt; 思想构建了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block table&lt;/code&gt; 来动态分配管理 kv cache 内存，下图很好的展示了相比早前的其他推理框架，LLM 在显存利用率上极高。
    &lt;blockquote&gt;
      &lt;p&gt;这种 kv cache 内存管理分配的算法（思想）是可以应用到其他 llm 推理服务框架中。&lt;/p&gt;
    &lt;/blockquote&gt;
  &lt;/li&gt;
&lt;/ul&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/llm_memory_waste.png&quot; width=&quot;60%&quot; alt=&quot;llm_memory_waste&quot; /&gt;
&lt;/div&gt;

&lt;h2 id=&quot;一-pagedattention-内核&quot;&gt;一 PagedAttention 内核&lt;/h2&gt;

&lt;h3 id=&quot;11-主要函数&quot;&gt;1.1 主要函数&lt;/h3&gt;

&lt;p&gt;看一个文件代码之前，先快速过一下这个文件有哪些主要（模板）类或者函数，vllm 中 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pagedattention&lt;/code&gt; 内核的实现 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;csrc/attention/attention_kernels.cu&lt;/code&gt; 文件中，其主要有以下模板函数。&lt;/p&gt;

&lt;div class=&quot;language-cpp highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;template&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;HEAD_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;vllm&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Fp8KVCacheDataType&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;KV_DTYPE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;IS_BLOCK_SPARSE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;PARTITION_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;&amp;gt;&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// Zero means no partitioning.&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;__device__&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paged_attention_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;// v2 内核的算法实现逻辑对应的就是 Flashattentionv1&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;// Grid: (num_heads, num_seqs, 1).&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;template&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;HEAD_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;vllm&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Fp8KVCacheDataType&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;KV_DTYPE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;IS_BLOCK_SPARSE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;&amp;gt;&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;__global__&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paged_attention_v1_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;// v2 内核的算法实现逻辑对应的就是 Flashattentionv2, 因此并行度多了一个 kv cache seq！并行数量为 max_num_partitions.&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;// Grid: (num_heads, num_seqs, max_num_partitions).&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;template&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;HEAD_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;vllm&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Fp8KVCacheDataType&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;KV_DTYPE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;IS_BLOCK_SPARSE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;PARTITION_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;&amp;gt;&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;__global__&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paged_attention_v2_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;// paged_attention_v1 内核的包装函数&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;template&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;CACHE_T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;vllm&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Fp8KVCacheDataType&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;KV_DTYPE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;IS_BLOCK_SPARSE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;&amp;gt;&lt;/span&gt;
&lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paged_attention_v1_launcher&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;// paged_attention_v2 内核的包装函数&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;template&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;CACHE_T&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;n&quot;&gt;vllm&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Fp8KVCacheDataType&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;KV_DTYPE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;IS_BLOCK_SPARSE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
          &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;PARTITION_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;512&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;&amp;gt;&lt;/span&gt;
&lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paged_attention_v2_launcher&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;// paged_attention_v1 对外提供的接口函数，也是生成 python 调用接口的函数形式，部分参数我省略了&lt;/span&gt;
&lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paged_attention_v1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;// [num_seqs, num_heads, head_size]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;query&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// [num_seqs, num_heads, head_size]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// [num_blocks, num_heads, head_size/x, block_size, x]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;value_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;       &lt;span class=&quot;c1&quot;&gt;// [num_blocks, num_heads, head_size, block_size]&lt;/span&gt;
    &lt;span class=&quot;kt&quot;&gt;int64_t&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_kv_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// [num_heads]&lt;/span&gt;
    &lt;span class=&quot;kt&quot;&gt;double&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scale&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_tables&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// [num_seqs, max_num_blocks_per_seq]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_lens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;      &lt;span class=&quot;c1&quot;&gt;// [num_seqs]&lt;/span&gt;
    &lt;span class=&quot;kt&quot;&gt;int64_t&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int64_t&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;// paged_attention_v2 对外提供的接口函数，参数和 v1 类似&lt;/span&gt;
&lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paged_attention_v2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;PagedAttention 本质上是集合了 kv cache 动态管理技术的优化版 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;flashattention&lt;/code&gt;。PagedAttention 的内核实现有两个版本 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;paged_attention_v1_kernel&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;paged_attention_v2_kernel&lt;/code&gt;，v1 改编自FasterTransformers 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MHA&lt;/code&gt; 实现，适合长度小于 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;8192&lt;/code&gt; 或者 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_seqs * num_heads &amp;gt; 512&lt;/code&gt; 的情况，v2 是参考 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;FlashDecoding&lt;/code&gt; 方式进行实现，对 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sequence&lt;/code&gt; 维度进行切分以增加并行粒度。&lt;/p&gt;

&lt;p&gt;这两个版本的内核都是基于 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;paged_attention_kernel&lt;/code&gt; 内核通过输入不同参数来实现，v2 内核版本多了一个 kv cache &lt;strong&gt;seq 维度分区数量&lt;/strong&gt;的参数，并行度层面多了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kv cache&lt;/code&gt; &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seq&lt;/code&gt; 分区的并行度！&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;flashattention 两种算法实现集成在一个内核里，这还是很考验作者工程功底的！&lt;/p&gt;
&lt;/blockquote&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/paged_attention_v1_v2_kernel.png&quot; width=&quot;60%&quot; alt=&quot;paged_attention_v1_v2_kernel&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;PagedAttention 内核的实现函数和常规 Attention 的实现相比最明显的就是多了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt; 相关参数，以及 k_cache 的尺寸变成了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[num_blocks, num_kv_heads, head_size/x, block_size, x]&lt;/code&gt;，很明显了多了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_blocks&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_size&lt;/code&gt; 维度（v_cache 变量也是），用于表示一个 seq 用多少个 blocks 存储，以及每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block&lt;/code&gt; 存储多少个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tokens&lt;/code&gt;。&lt;/p&gt;

&lt;p&gt;PagedAttention kernel 模板函数签名如下所示:&lt;/p&gt;

&lt;div class=&quot;language-cpp highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;// Grid: (num_heads, num_seqs, 1).&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;// 这里为了方便阅读我删除了块稀疏的参数&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;template&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;HEAD_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;&amp;gt;&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;__global__&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;paged_attention_v1_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__restrict__&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;out&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;           &lt;span class=&quot;c1&quot;&gt;// [num_seqs, num_heads, head_size]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__restrict__&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;       &lt;span class=&quot;c1&quot;&gt;// [num_seqs, num_heads, head_size]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__restrict__&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// [num_blocks, num_kv_heads, head_size/x, block_size, x], 最后一个x 是 vectorize，一个thread fetch一个vector&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__restrict__&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// [num_blocks, num_kv_heads, head_size, block_size]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_kv_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;               &lt;span class=&quot;c1&quot;&gt;// [num_heads]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scale&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__restrict__&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_tables&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// [num_seqs, max_num_blocks_per_seq]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__restrict__&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_lens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;      &lt;span class=&quot;c1&quot;&gt;// [num_seqs]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_num_blocks_per_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;float&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__restrict__&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;alibi_slopes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// [num_heads]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_stride&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_block_stride&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_head_stride&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_scale&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v_scale&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; 
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;12-内核配置定义&quot;&gt;1.2 内核配置定义&lt;/h3&gt;

&lt;p&gt;先阅读 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;paged_attention_v1_kernel()&lt;/code&gt; 内核的调用（包装）函数 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;paged_attention_v1_launcher()&lt;/code&gt; 的 内容来看 kernel 的配置如何。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/grid_block_definition.png&quot; width=&quot;55%&quot; alt=&quot;paged_attention_v1_kernel 配置&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;可以看出 kernel 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;grid&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block&lt;/code&gt; 配置如下所示，即分别定义了二维 grid 和一维 block 配置，其中每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BLOCKS_X&lt;/code&gt; 处理一个 head，每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BLOCKS_Y&lt;/code&gt; 处理一个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seq&lt;/code&gt;，每个 thread 处理最后一个维度  &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;hidden_size&lt;/code&gt;  的计算。&lt;/p&gt;

&lt;div class=&quot;language-cpp highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;dim3&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;grid&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_seqs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;// dim3 grid(BLOCKS_X, BLOCKS_Y)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;dim3&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;知道了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kernel&lt;/code&gt; 的配置，我们再回过头去看 kernel 源码-&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;paged_attention_kernel()&lt;/code&gt; 模板函数，按照 kernel 编写惯例，开头的代码依然是&lt;strong&gt;先计算全局线程 id 和偏移&lt;/strong&gt;，只保留 v1 内核相关且注释后的代码如下所示：&lt;/p&gt;

&lt;div class=&quot;language-cpp highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;// 用于对整数除法结果进行向上取整。&lt;/span&gt;
&lt;span class=&quot;cp&quot;&gt;#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
&lt;/span&gt;
&lt;span class=&quot;kt&quot;&gt;void&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;paged_attention_kernel&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
    &lt;span class=&quot;cm&quot;&gt;/****** seq 维度的线程索引和 blocks 数量计算 ********/&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 获取当前请求序列 seq 的索引，基于网格的 y 维度&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;blockIdx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 获取当前请求的实际长度，即 tokens 数量&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_lens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;];&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 计算序列被分割成多少个块，每块大小为 BLOCK_SIZE&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_seq_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DIVIDE_ROUND_UP&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 未启用分区处理（USE_PARTITIONING 为 false），起始块索引为 0。&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_seq_blocks&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;// 计算当前 block 块的起始和结束索引&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;start_block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;end_block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 计算当前 block 处理的令牌（token）范围&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;start_token_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;start_block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;end_token_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;end_block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;// 计算 warp 数量&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;constexpr&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_WARPS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;WARP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 获取当前线程在块内的索引&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;thread_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;threadIdx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 计算当前线程所在的 warp 索引&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;warp_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;thread_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;WARP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 计算当前线程在 warp 内的 lane 索引&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lane&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;thread_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;WARP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;

    &lt;span class=&quot;cm&quot;&gt;/****** num_heads 维度的 kv head 索引计算 ********/&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 兼容了 GQA 技术的 kv head 地址计算&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 获取当前查询头的索引，基于网格的 x 维度&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;head_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;blockIdx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gridim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 计算每个 Key/Value 头对应的查询头数, 看到没，GQA 是可以集成到 attention 内核里面的！&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_queries_per_kv&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_kv_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 计算当前 Key/Value 头的索引&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_head_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;head_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_queries_per_kv&lt;/span&gt;

    &lt;span class=&quot;cm&quot;&gt;/****** thread group 向量化load&amp;amp;store 相关代码, 不太理解 *****/&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 定义线程组大小和数量&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;constexpr&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;MAX&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;WARP_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;constexpr&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_THREAD_GROUPS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;// 确保 THREAD_GROUP_SIZE 能整除 NUM_THREADS&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;assert&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NUM_THREADS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;// 定义向量类型，用于存储部分 Key 或 Query&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 向量大小配置为线程组中的线程数 * 向量大小保证16字节对齐&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;constexpr&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;VEC_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;MAX&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;sizeof&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)),&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;using&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;K_vec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Vec&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;VEC_SIZE&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;using&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Q_vec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Vec&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;VEC_SIZE&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;using&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Quant_vec&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;typename&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Vec&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;VEC_SIZE&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;// 每个线程处理的元素数&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;constexpr&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_ELEMS_PER_THREAD&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;HEAD_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;// 每个线程处理的向量数&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;constexpr&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_VECS_PER_THREAD&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_ELEMS_PER_THREAD&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;VEC_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;// 计算线程组索引和偏移&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;thread_group_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;thread_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;thread_group_offset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;thread_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;// 查询地址计算&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_ptr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q_stride&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;head_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;HEAD_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;通过注释我们可以发现最前面代码的核心就是计算 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seq&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_heads&lt;/code&gt; 维度的索引以及线程组索引和偏移。&lt;/p&gt;

&lt;h3 id=&quot;13-基于-block_tables-读取-kv-cache&quot;&gt;1.3 基于 block_tables 读取 kv cache&lt;/h3&gt;

&lt;p&gt;这部分代码是真正属于 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pagedattention&lt;/code&gt; 原创性的设计，即如何基于 block_tables 去 token 的 offset。&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;首先就是：获取当前序列 (seq_idx) 的块表。&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_tables&lt;/code&gt; 是一个二维数组，每个序列有最多 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;max_num_blocks_per_seq&lt;/code&gt; 个块。通过 seq_idx * max_num_blocks_per_seq 计算&lt;strong&gt;当前序列的块表起始地址&lt;/strong&gt;。&lt;/li&gt;
  &lt;li&gt;遍历 Key 块。&lt;/li&gt;
  &lt;li&gt;获取物理块号: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_table[block_idx]&lt;/code&gt;。&lt;/li&gt;
  &lt;li&gt;循环加载 Key 向量、计算点积和更新 qk_max。
    &lt;ul&gt;
      &lt;li&gt;最外层循环: 每个 warp 负责计算一个 block key，而每个 block key shape 为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[block_size, num_head, head_size]&lt;/code&gt;&lt;/li&gt;
      &lt;li&gt;第二层循环: 每个thread_group取一个key，即num_head个元素，计算QK dot&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;div class=&quot;language-cpp highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;// x == THREAD_GROUP_SIZE * VEC_SIZE&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;// Each thread group fetches x elements from the key at a time.&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;constexpr&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;sizeof&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;// 获取当前序列的块表，确定要迭代的 Key 块。&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;// block_tables 是函数参数，形状为 [num_seqs, max_num_blocks_per_seq] 的二维数组，每个序列有最多 max_num_blocks_per_seq 个块。&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;// 通过 seq_idx * max_num_blocks_per_seq 计算当前序列的块表起始地址。&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_table&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_tables&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_num_blocks_per_seq&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;// 每个 warp 负责 blocksize * headsize 个元素&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;start_block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;warp_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;end_block_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_WARPS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;// 获取当前块的物理块号，并将其转换为 int64_t 以避免溢出。&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int64_t&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;physical_block_number&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;static_cast&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;kt&quot;&gt;int64_t&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_table&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]);&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;// 加载一个 Key 向量到寄存器。&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 每个线程组中的每个线程处理 Key 的不同部分。&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;// 例如，如果线程组大小为 4，则组中的第一个线程处理 0, 4, 8,... 向量，第二个线程处理 1, 5, 9,... 向量，依此类推。&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_TOKENS_PER_THREAD_GROUP&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;++&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;// 在当前 physical block 中找到当前 thread group 负责的局部 token id&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;physical_block_offset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;thread_group_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;WARP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
        
        &lt;span class=&quot;c1&quot;&gt;// 计算 token 在当前 seq 的所有 block 中的全局索引。&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;token_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;physical_block_offset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
        
        &lt;span class=&quot;c1&quot;&gt;// 声明一个用于存储多个 Key 向量的数组。&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;K_vec&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_vecs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NUM_VECS_PER_THREAD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;];&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;// 遍历每个向量，加载 Key 向量到 k_vecs 数组。&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;// 根据上述 shape 算出当前 seq 的具体 k cache 的 block size 这一维度的 offset&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NUM_VECS_PER_THREAD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;++&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;// k_cache.shape=[num_blocks, num_kv_heads, head_size/x, block_size, x]&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_t&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_ptr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;k_cache&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;physical_block_number&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_block_stride&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;kv_head_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kv_head_stride&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;physical_block_offset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;

            &lt;span class=&quot;c1&quot;&gt;// 因为是向量化 LOAD，还需要计算出 vec 的全局id，和 vec 内元素的局部 offset&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;vec_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;thread_group_offset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;j&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offset1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vec_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;VEC_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offset2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;vec_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;VEC_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;

            &lt;span class=&quot;c1&quot;&gt;// 根据 Key 缓存的数据类型，加载并转换 Key 向量。&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;constexpr&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;KV_DTYPE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Fp8KVCacheDataType&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;kAuto&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
                &lt;span class=&quot;c1&quot;&gt;// 直接加载 Key 向量。&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;k_vecs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;reinterpret_cast&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;K_vec&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&amp;gt;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;k_ptr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offset1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offset2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;}&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
                &lt;span class=&quot;c1&quot;&gt;// 从量化向量转换为 Key 向量。&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;Quant_vec&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_vec_quant&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;reinterpret_cast&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Quant_vec&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&amp;gt;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;k_ptr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offset1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BLOCK_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;offset2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;k_vecs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;j&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fp8&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scaled_convert&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;K_vec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Quant_vec&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;KV_DTYPE&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;k_vec_quant&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_scale&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;// 计算查询与 Key 的点积。&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;// 这包括线程组内的归约操作。&lt;/span&gt;
        &lt;span class=&quot;kt&quot;&gt;float&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;qk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scale&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Qk_dot&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scalar_t&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;THREAD_GROUP_SIZE&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;&amp;gt;::&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;q_vecs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;thread_group_offset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k_vecs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
        
        &lt;span class=&quot;c1&quot;&gt;// 如果提供了 ALiBi 斜率，则添加偏置。&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;qk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;alibi_slope&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;!=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;?&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;alibi_slope&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;// 如果当前线程组偏移量为 0，则进行以下操作：&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;thread_group_offset&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;// 计算当前令牌是否超出序列长度，用于掩码。&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;const&lt;/span&gt; &lt;span class=&quot;kt&quot;&gt;bool&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;token_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
            
            &lt;span class=&quot;c1&quot;&gt;// 如果掩码为真，则将 logits 设为 0；否则，设为计算的 qk。&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;start_token_idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;?&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;f&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;qk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;;&lt;/span&gt;
            
            &lt;span class=&quot;c1&quot;&gt;// 更新查询与 Key 的最大值，用于后续的 softmax 计算。&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;qk_max&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mask&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;?&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;qk_max&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;fmaxf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;qk_max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;qk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;);&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;后续的代码就是去更新 softmax 和姨同样的操作去加载 v, 然后再做 gemv（softmax(qk^t) * v），最终得到 attention 输出，这里不再分析具体算法逻辑。&lt;/p&gt;

&lt;h2 id=&quot;二-paged页表原理分析&quot;&gt;二 Paged(页表)原理分析&lt;/h2&gt;

&lt;p&gt;这里的算法分析重点在于分析如何创建 block table、实现逻辑 table 和物理 table 的映射，以及如何针对每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seq&lt;/code&gt; 动态分配相应数量的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block&lt;/code&gt; 用于存储 kv cache。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/Manage_KV_Cache.png&quot; width=&quot;75%&quot; alt=&quot;Manage_KV_Cache&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;这块代码的实现是在 vllm 的请求调度层模块里，&lt;strong&gt;请求调度模块&lt;/strong&gt;的作用是将服务接收到的请求进行状态管理，包括入队出队操作，并且将请求解析成推理引擎的输入，由 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Worker&lt;/code&gt; 模块完成模型推理。&lt;/p&gt;

&lt;p&gt;请求调度的核心源码在：&lt;a href=&quot;https://github.com/vllm-project/vllm/blob/v0.3.0/vllm/core/scheduler.py#L73&quot;&gt;vllm/core/scheduler.py&lt;/a&gt; 文件中。核心实现在 Scheduler 类中，这个类中的 _schedule 函数调用内部的相关完成具体的请求调度。另外，调度器中的下发 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch&lt;/code&gt; 请求到 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Worker&lt;/code&gt; 模块中的相关函数，如 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_schedule_prefills&lt;/code&gt; 函数会先调用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_manager.can_allocate&lt;/code&gt; 函数判断是否有足够内存分配。而在初始化方法 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;__init__&lt;/code&gt; 函数有 kv cache 显存块状态管理器的初始化。具体代码如下所示：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;scheduler_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;SchedulerConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CacheConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;lora_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;LoRAConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pipeline_parallel_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;output_proc_callback&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Callable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scheduler_config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;scheduler_config&lt;/span&gt;
    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Note for LoRA scheduling: the current policy is extremely
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# simple and NOT fair. It can lead to starvation of some
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# LoRAs. This should be improved in the future.
&lt;/span&gt;    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lora_config&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lora_config&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;version&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;selfattn&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;scheduler_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;task&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;embedding&quot;&lt;/span&gt;
            &lt;span class=&quot;ow&quot;&gt;or&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_attention_free&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;version&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;placeholder&quot;&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;BlockSpaceManagerImpl&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockSpaceManager&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_block_space_manager_class&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;version&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Create the block space manager.
&lt;/span&gt;    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_manager&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockSpaceManagerImpl&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_cpu_blocks&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_cpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;sliding_window&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sliding_window&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;enable_caching&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;enable_prefix_caching&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;初始化函数中定义的块管理器 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_manager&lt;/code&gt; 就是我们关心的，它是 KVCache 显存块状态管理器。&lt;strong&gt;用于分配、释放 KVCache 显存块以及状态更新&lt;/strong&gt;，分配显存块时会返回显存块 id，用于 PagedAttention 计算时获取 KVCache 块显存地址。&lt;/p&gt;

&lt;p&gt;值得注意的是，BlockManager（和调度器）实际上只负责管理页表（即管理逻辑块和每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seq&lt;/code&gt; 到物理块的映射关系），实际的物理块中的数据不由它管理。这个实际上和 os 中的页表也差不多，BlockManager中的一个物理块就相当于页表中的一个PTE，而不是真实存放数据的物理块，实际进行内存分配的是 CacheEngine。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/BlockManager_CacheManager.png&quot; width=&quot;70%&quot; alt=&quot;BlockManager_CacheManager&quot; /&gt;
&lt;/div&gt;

&lt;h3 id=&quot;21-block-管理相关类&quot;&gt;2.1 Block 管理相关类&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BlockManager&lt;/code&gt; 相关类的包装关系对应文件: block_manager.py -&amp;gt; block_table.py -&amp;gt; naive_block.py&lt;/p&gt;

&lt;h4 id=&quot;blocktable&quot;&gt;BlockTable&lt;/h4&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_table.py&lt;/code&gt; 文件的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BlockTable&lt;/code&gt; 类&lt;strong&gt;将 tokens 序列映射到块列表 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt; 中&lt;/strong&gt;，其中每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block&lt;/code&gt; 代表序列一部分的连续内存分配。这些块由 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;DeviceAwareBlockAllocator&lt;/code&gt; 管理，它负责分配和释放这些逻辑块。&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;其中 SelfAttnBlockSpaceManager 类继承自 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BlockSpaceManager&lt;/code&gt;，父类只负责定义接口，子类才负责具体的实现。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;BlockTable 类最主要的函数是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;allocate&lt;/code&gt; 用于将 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tokens&lt;/code&gt; 序列映射相应物理内存块列表 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt; 中，具体物理内存块的分配是通过设备内存分配器的分配函数 ` self._allocator.allocate_immutable_blocks` 实现的。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;
&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;BlockTable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;管理特定序列的内存块的类。
    
    BlockTable 将一系列令牌映射到一组块中，每个块代表序列的一部分连续内存分配。这些块由 DeviceAwareBlockAllocator 管理，负责块的分配和释放。
    
    参数：
        block_size (int): 每个内存块可以存储的最大令牌数量。
        block_allocator (DeviceAwareBlockAllocator): 用于管理物理块内存的分配器。
        _blocks (Optional[List[Block]], optional): 可选的现有块列表，用于初始化 BlockTable。如果未提供，则创建一个空的 BlockTable。
        max_block_sliding_window (Optional[int], optional): 每个序列需要保留的最大块数。如果为 None，则保留所有块（例如，当不使用滑动窗口时）。至少应满足模型的滑动窗口大小。
    
    属性：
        _block_size (int): 每个内存块可以存储的最大令牌数量。
        _allocator (DeviceAwareBlockAllocator): 用于管理物理块内存的分配器。
        _blocks (Optional[List[Block]]): 由此 BlockTable 管理的逻辑块列表。
        _num_full_slots (int): 当前存储在块中的令牌数量。
    &quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DeviceAwareBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;max_block_sliding_window&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 设置每个块的大小
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocator&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_allocator&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 设置内存分配器
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_blocks&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockList&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockList&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 初始化块列表
&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_max_block_sliding_window&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_block_sliding_window&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 设置滑动窗口的最大块数
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_num_full_slots&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_get_num_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 获取当前存储的令牌数量
&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;allocate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                 &lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                 &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Device&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;GPU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;为给定的令牌序列分配内存块。
        
        此方法分配所需数量的块以存储给定的令牌序列。
        
        参数：
            token_ids (List[int]): 要存储的令牌 ID 序列。
            device (Device, optional): 要分配块的设备。默认为 GPU。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_is_allocated&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 确保尚未分配块
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 确保有令牌需要分配
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocate_blocks_for_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                                     &lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                                     &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 分配块
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 更新块表
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_num_full_slots&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 更新存储的令牌数量
&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_allocate_blocks_for_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                                        &lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                                        &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;为给定的令牌 ID 分配内存块。
        
        参数：
            prev_block (Optional[Block]): 前一个块。如果没有，则为 None。
            token_ids (List[int]): 要存储的令牌 ID 列表。
            device (Device): 要分配块的设备。
        
        返回：
            List[Block]: 分配的块列表。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;

        &lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;tail_token_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cur_token_ids&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;chunk_list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cur_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cur_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;tail_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cur_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;extend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocate_immutable_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 分配不可变块
&lt;/span&gt;            &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tail_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tail_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 仅有一个尾块
&lt;/span&gt;            &lt;span class=&quot;n&quot;&gt;cur_token_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tail_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

            &lt;span class=&quot;n&quot;&gt;block&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocate_mutable_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 分配可变块
&lt;/span&gt;            &lt;span class=&quot;n&quot;&gt;block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cur_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 追加令牌 ID
&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;重置块表为新提供的块（包括其对应的块 ID）。
        
        参数：
            blocks (List[Block]): 新分配的块列表。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_allocate_blocks_for_token_ids&lt;/code&gt; 函数会据块大小 (block_size) 将令牌序列分割成多个块，并使用 DeviceAwareBlockAllocator 来分配这些块。具体函数流程总结如下:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;创建空块列表: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks: List[Block] = []&lt;/code&gt;，也是函数返回的结果&lt;/li&gt;
  &lt;li&gt;分配逻辑块列表，通过分割 token ids 实现，每个子逻辑 block 包含的内容实际是 token ids。&lt;/li&gt;
  &lt;li&gt;分配实际物理块列表: 调用 _ allocator.allocate_immutable_blocks 函数返回实际块列表，并拓展到逻辑块列表 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt; 中。&lt;/li&gt;
  &lt;li&gt;返回分配的块列表 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt;。&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;下述是一个测试示例，展示如何使用 _allocate_blocks_for_token_ids 方法：&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;test_allocate_blocks_for_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 初始化分配器和 BlockTable
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;allocator&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DeviceAwareBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;block_table&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockTable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_allocator&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 定义测试令牌 ID
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 分配块
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_table&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocate_blocks_for_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;GPU&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 更新块表
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;block_table&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;block_table&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_num_full_slots&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 打印分配结果
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;分配的块数量:&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;enumerate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;块 &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;: Token IDs = &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;, Mutable = &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mutable&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 预期：
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# - 前两个块完全填满（4个令牌）
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# - 最后一个块不完全填满（1个令牌，5）
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# 所以应有 3 个块
&lt;/span&gt;    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;应分配 3 个块&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;第一个块令牌 ID 错误&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;第二个块令牌 ID 错误&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;第三个块令牌 ID 错误&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mutable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;第一个块应为不可变块&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mutable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;第二个块应为不可变块&quot;&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocated_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mutable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;第三个块应为可变块&quot;&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;测试通过！&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__name__&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;__main__&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;test_allocate_blocks_for_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;运行上述测试代码，结果如下所示:&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;分配的块数量: 3
块 0: Token IDs &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;1, 2, 3, 4], Mutable &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; False
块 1: Token IDs &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;5, 6, 7, 8], Mutable &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; False
块 2: Token IDs &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;[&lt;/span&gt;9], Mutable &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; True
测试通过！
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h4 id=&quot;cpugpublockallocator-类&quot;&gt;CpuGpuBlockAllocator 类&lt;/h4&gt;

&lt;p&gt;前面的分析内容我们知道，请求到逻辑表的分配是通过 CpuGpuBlockAllocator 类实现的。CpuGpuBlockAllocator 类是一个内存块分配器，能够在 CPU 和 GPU 内存中分配和管理内存块。它实现了 DeviceAwareBlockAllocator 基类的接口，提供了在多个设备（如 CPU 和 GPU）之间分配、释放、分叉（forking）和交换（swapping）内存块的功能。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;CpuGpuBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;DeviceAwareBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;一个能够在 CPU 和 GPU 内存中分配块的块分配器。
    
    该类实现了 `DeviceAwareBlockAllocator` 接口，提供了在 CPU 和 GPU 设备上分配和管理内存块的功能。
    
    `CpuGpuBlockAllocator` 维护了独立的 CPU 和 GPU 内存块池，并允许在这些内存池之间进行分配、释放、分叉和交换操作。
    &quot;&quot;&quot;&lt;/span&gt;
    
    &lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;create&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;allocator_type&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_cpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;DeviceAwareBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;创建一个具有指定配置的 CpuGpuBlockAllocator 实例。
        
        这个静态方法根据提供的参数创建并返回一个 CpuGpuBlockAllocator 实例。它初始化了 CPU 和 GPU 块分配器，指定块的数量、块大小和分配器类型。
        
        参数：
            allocator_type (str): 用于 CPU 和 GPU 块的块分配器类型。目前支持的值为 &quot;naive&quot; 和 &quot;prefix_caching&quot;。
            num_gpu_blocks (int): 要为 GPU 内存分配的块数量。
            num_cpu_blocks (int): 要为 CPU 内存分配的块数量。
            block_size (int): 每个块的大小，以令牌数量表示。
        
        返回：
            DeviceAwareBlockAllocator: 一个具有指定配置的 CpuGpuBlockAllocator 实例。
        
        注意：
            - 块 ID 是连续分配的，GPU 块 ID 在前，CPU 块 ID 在后。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# 对于 HPU，块 ID 0 仅用于填充
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;reserved_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;current_platform&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_hpu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reserved_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_cpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;reserved_blocks&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;gpu_block_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;cpu_block_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:]&lt;/span&gt;
    
        &lt;span class=&quot;c1&quot;&gt;# 根据 allocator_type 创建不同类型的块分配器
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocator_type&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;naive&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;gpu_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockAllocator&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NaiveBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;create_block&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NaiveBlock&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 创建不可变块的函数
&lt;/span&gt;                &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gpu_block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
            &lt;span class=&quot;n&quot;&gt;cpu_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockAllocator&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;NaiveBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;create_block&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;NaiveBlock&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 创建不可变块的函数
&lt;/span&gt;                &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_cpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cpu_block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;elif&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocator_type&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;prefix_caching&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;gpu_allocator&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;PrefixCachingBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gpu_block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
            &lt;span class=&quot;n&quot;&gt;cpu_allocator&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;PrefixCachingBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_cpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cpu_block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;raise&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;ValueError&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;未知的分配器类型 &lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocator_type&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CpuGpuBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;cpu_block_allocator&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cpu_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;gpu_block_allocator&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gpu_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cpu_block_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                 &lt;span class=&quot;n&quot;&gt;gpu_block_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;初始化 CpuGpuBlockAllocator 实例。
        
        参数：
            cpu_block_allocator (BlockAllocator): 用于管理 CPU 内存块的分配器。
            gpu_block_allocator (BlockAllocator): 用于管理 GPU 内存块的分配器。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# 确保 CPU 和 GPU 分配器的块 ID 没有交集
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;cpu_block_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;all_block_ids&lt;/span&gt;
            &lt;span class=&quot;o&quot;&gt;&amp;amp;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gpu_block_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;all_block_ids&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;CPU 和 GPU 块分配器的块 ID 不能有交集&quot;&lt;/span&gt;
    
        &lt;span class=&quot;c1&quot;&gt;# 将 CPU 和 GPU 分配器存储在字典中
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocators&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;Device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;CPU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cpu_block_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;Device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;GPU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gpu_block_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;}&lt;/span&gt;
    
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_swap_mapping&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{}&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 记录交换操作的块 ID 映射
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_null_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 用于存储空块
&lt;/span&gt;    
        &lt;span class=&quot;c1&quot;&gt;# 记录每个块 ID 对应的分配器
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_ids_to_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Dict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;{}&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocator&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocators&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;items&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_id&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;all_block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
                &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_ids_to_allocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;allocator&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;allocate_mutable_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                               &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;在指定设备上分配一个新的可变块。
        
        参数：
            prev_block (Optional[Block]): 序列中的前一个块。用于前缀哈希。
            device (Device): 要分配新块的设备。
        
        返回：
            Block: 新分配的可变块。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocators&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocate_mutable_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;allocate_immutable_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                                  &lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]],&lt;/span&gt;
                                  &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;在指定设备上分配一组包含提供的块令牌 ID 的不可变块。
        
        参数：
            prev_block (Optional[Block]): 序列中的前一个块。用于前缀哈希。
            block_token_ids (List[int]): 要存储在新块中的块令牌 ID 列表。
            device (Device): 要分配新块的设备。
        
        返回：
            List[Block]: 新分配的包含提供的块令牌 ID 的不可变块列表。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocators&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;allocate_immutable_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h4 id=&quot;naiveblockallocator&quot;&gt;NaiveBlockAllocator&lt;/h4&gt;

&lt;p&gt;上述类其实也还是一层包装，在看针对不同场景的类，以简单的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;NaiveBlockAllocator&lt;/code&gt; 类为例分析。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;NaiveBlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BlockAllocator&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;一个简单的块分配器，不支持前缀缓存。
    
    该类实现了 `BlockAllocator` 接口，提供了基本的内存块分配和释放功能。
    
    参数：
        create_block (Block.Factory): 用于创建新块的工厂函数。当 NaiveBlockAllocator 被前缀缓存分配器组合使用时，必须能够创建前缀缓存块（但不应了解其余细节）。
        num_blocks (int): 要管理的块的总数量。
        block_size (int): 每个块的大小，以令牌数量表示。
        block_ids (Optional[Iterable[int]], optional): 可选的块 ID 可迭代对象。如果未提供，块 ID 将从 0 到 `num_blocks - 1` 顺序分配。
    &quot;&quot;&quot;&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;create_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Iterable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BlockPool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 如果未提供块 ID，则顺序分配
&lt;/span&gt;        
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_free_block_indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;deque&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;deque&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 初始化自由块队列
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_all_block_indices&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;frozenset&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 所有块 ID 的集合
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_all_block_indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;块 ID 数量应与 num_blocks 相等&quot;&lt;/span&gt;
    
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_refcounter&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;RefCounter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;all_block_indices&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_free_block_indices&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;
    
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_cow_tracker&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CopyOnWriteTracker&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;refcounter&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_refcounter&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_pool&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;extra_factor&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;# 预分配 &quot;num_blocks * extra_factor&quot; 个块对象。
&lt;/span&gt;            &lt;span class=&quot;c1&quot;&gt;# &quot;* extra_factor&quot; 是为了允许分配比物理块更多的块对象
&lt;/span&gt;            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BlockPool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;create_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                         &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;extra_factor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;# 在这种情况下，块池由调用者提供，意味着可能需要在分配器之间共享块池
&lt;/span&gt;            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_pool&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_pool&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;allocate_immutable_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]],&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;分配一组新的不可变块，包含提供的块令牌 ID，并链接到前一个块。
        
        参数：
            prev_block (Optional[Block]): 序列中的前一个块。如果为 None，则分配的块为序列中的第一个块。
            block_token_ids (List[List[int]]): 要存储在新块中的块令牌 ID 列表。
            device (Optional[str], optional): 分配块的设备。对于 NaiveBlockAllocator，通常为 None。
        
        返回：
            List[Block]: 新分配的不可变块列表。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;NaiveBlockAllocator 不支持设备参数&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 需要分配的块数量
&lt;/span&gt;    
        &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocate_block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 分配块 ID
&lt;/span&gt;    
        &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;c1&quot;&gt;# 初始化块，设置前一个块、令牌 ID、块大小和物理块 ID
&lt;/span&gt;            &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_token_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                &lt;span class=&quot;n&quot;&gt;physical_block_id&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;allocate_mutable_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                               &lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                               &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;分配一个新的可变块，并链接到前一个块。
        
        参数：
            prev_block (Optional[Block]): 序列中的前一个块。如果为 None，则分配的块为序列中的第一个块。
            device (Optional[str], optional): 分配块的设备。对于 NaiveBlockAllocator，通常为 None。
        
        返回：
            Block: 新分配的可变块。
        &quot;&quot;&quot;&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;device&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;NaiveBlockAllocator 不支持设备参数&quot;&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_id&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_allocate_block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 分配一个块 ID
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;block&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;init_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;prev_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                            &lt;span class=&quot;n&quot;&gt;token_ids&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[],&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 初始化为空令牌 ID
&lt;/span&gt;                                            &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                            &lt;span class=&quot;n&quot;&gt;physical_block_id&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mutable&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# 设置为可变块
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;NaiveBlockAllocator 类的块函数也是调用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;self._block_pool.init_block&lt;/code&gt; 接口，再跳转到 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vllm/core/block/common.py&lt;/code&gt; 文件 BlockPool 类的 init_block 函数，其是通过 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;create_block: Block.Factory&lt;/code&gt; 类的相关接口来实现的。&lt;/p&gt;

&lt;p&gt;也就是说 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;NaiveBlockAllocator&lt;/code&gt;、&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CpuGpuBlockAllocator&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PrefixCachingBlockAllocator&lt;/code&gt; 三个 block 分配器类是为了适应不同设备和场景（简单场景和 PrefixCaching）而设计出来的，但真正的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block&lt;/code&gt; 定义类是通过块工厂创建函数 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Block.Factory&lt;/code&gt; 创建得到。&lt;/p&gt;

&lt;h4 id=&quot;blocklist-类&quot;&gt;BlockList 类&lt;/h4&gt;

&lt;p&gt;BlockList 类通过维护块及其对应的 ID 列表，优化对物理块 ID 的访问。提供方法来更新列表、向块添加令牌 ID，以及检索块或其 ID，避免在每次迭代块管理器时重新构建块 ID 列表。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;BlockList&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;This class is an optimization to allow fast-access to physical 
    block ids. It maintains a block id list that is updated with the 
    block list and this avoids the need to reconstruct the block id 
    list on every iteration of the block manager
    &quot;&quot;&quot;&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;List&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[]&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;update&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_add_block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BlockId&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_id&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_update_block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                         &lt;span class=&quot;n&quot;&gt;new_block_id&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BlockId&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;assert&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_block_id&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;is&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_block_ids&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;block_index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_block_id&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;#####省略代码######
&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h4 id=&quot;逻辑-block-管理类-selfattnblockspacemanager&quot;&gt;逻辑 block 管理类-SelfAttnBlockSpaceManager&lt;/h4&gt;

&lt;p&gt;SelfAttnBlockSpaceManager 类用于管理注意力机制中 KV（Key-Value）缓存块，主要负责&lt;strong&gt;逻辑内存块的分配、交换&lt;/strong&gt;、以及其他高级功能如前缀缓存（即请求共享前缀&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;systemp prompt&lt;/code&gt;）、分叉/写时复制（Forking/Copy-on-Write）和滑动窗口内存分配。&lt;/p&gt;

&lt;p&gt;和前面几个是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block&lt;/code&gt; 模块内部类不同，它是对外部模块提供的类，但 BlockManager（和调度器）实际上只负责管理页表（即管理每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;seq&lt;/code&gt; 到物理块的映射关系），实际的物理块中的数据不由它管理。&lt;/p&gt;

&lt;p&gt;先看构造函数 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;__init__()&lt;/code&gt;，函数中维护了一个逻辑 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_tables&lt;/code&gt;，它是一个字典，形式如 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_tables: Dict[SeqId, BlockTable] = {}&lt;/code&gt;，这个字典维护着整个 vllm 系统中每个 Sequence 实例到它的 block_table 之间的映射关系，&lt;strong&gt;方便快速查找到当前文本序列(Sequence) 对应的 PhysicalTokenBlock&lt;/strong&gt;。构造函数的输入参数比较多，这里重点看三个参数的意义：&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_size&lt;/code&gt;: 每个内存块的大小，表示可以存储多少个令牌的 KV 数据。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_gpu_blocks&lt;/code&gt;: 分配在 GPU 上的内存块数量。&lt;/li&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_cpu_blocks&lt;/code&gt;: 分配在 CPU 上的内存块数量。&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;__init__()&lt;/code&gt; 的部分代码如下所示:&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/block_tables.png&quot; width=&quot;60%&quot; alt=&quot;block_tables&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;SelfAttnBlockSpaceManager&lt;/code&gt; 类中内存块分配相关有 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;allocate&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_allocate_sequence&lt;/code&gt; 函数，分别用于为给定的序列组分配所需的内存块和为单个序列分配块表。&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;块的分配策略由 CpuGpuBlockAllocator 来实现&lt;/strong&gt;，分配策略有 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;naive&lt;/code&gt; 和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;prefix_caching&lt;/code&gt; 两种。CpuGpuBlockAllocator 是一个管理 CPU 和 GPU 内存块的分配器类，继承自 DeviceAwareBlockAllocator，提供了在 CPU 和 GPU 设备上分配和管理内存块的功能。&lt;/p&gt;

&lt;p&gt;块的分配策略类涉及到 4 个类，它们的类图关系如下所示：&lt;/p&gt;

&lt;pre&gt;&lt;code class=&quot;language-mermaid&quot;&gt;classDiagram
    DeviceAwareBlockAllocator &amp;lt;-- CpuGpuBlockAllocator
    BlockAllocator &amp;lt;-- NaiveBlockAllocator
    class DeviceAwareBlockAllocator {
        +allocate_mutable_block()
        +allocate_immutable_block()
        +free()
    }
    class BlockAllocator {
        +allocate_mutable_block()
        +allocate_immutable_block()
        +free()
    }
    class CpuGpuBlockAllocator {
        -_allocators: Dict
        +create()
        +swap()
    }
    class NaiveBlockAllocator {
        -_free_blocks: List
        +allocate_block()
        +free()
    }
&lt;/code&gt;&lt;/pre&gt;

&lt;p&gt;&lt;strong&gt;类的关系说明&lt;/strong&gt;：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;DeviceAwareBlockAllocator
    &lt;ul&gt;
      &lt;li&gt;顶层接口&lt;/li&gt;
      &lt;li&gt;定义设备相关的内存块分配方法&lt;/li&gt;
      &lt;li&gt;CpuGpuBlockAllocator 实现此接口&lt;/li&gt;
      &lt;li&gt;BlockAllocator&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;基础抽象类
    &lt;ul&gt;
      &lt;li&gt;定义基本的内存块分配操作&lt;/li&gt;
      &lt;li&gt;NaiveBlockAllocator 继承此类&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;CpuGpuBlockAllocator
    &lt;ul&gt;
      &lt;li&gt;实现 DeviceAwareBlockAllocator 接口&lt;/li&gt;
      &lt;li&gt;&lt;strong&gt;内部使用 BlockAllocator 的具体实现(如 NaiveBlockAllocator)&lt;/strong&gt;&lt;/li&gt;
      &lt;li&gt;管理 CPU 和 GPU 两种设备上的内存块&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
  &lt;li&gt;NaiveBlockAllocator
    &lt;ul&gt;
      &lt;li&gt;BlockAllocator 的简单实现&lt;/li&gt;
      &lt;li&gt;被 CpuGpuBlockAllocator 使用来管理具体设备上的内存块&lt;/li&gt;
    &lt;/ul&gt;
  &lt;/li&gt;
&lt;/ol&gt;

&lt;h3 id=&quot;22-slot-mapping&quot;&gt;2.2 slot mapping&lt;/h3&gt;

&lt;p&gt;上一节讲的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_tables&lt;/code&gt; 是逻辑层面的，而传给实际计算 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kernel&lt;/code&gt; 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_tables&lt;/code&gt; 是形状为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[batch_size, max_blocks_per_seq]&lt;/code&gt; 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torch.Tensor&lt;/code&gt; 表示每个序列的块地址列表，第一维表示序列 ID，第二维是物理块列表。&lt;/p&gt;
&lt;ul&gt;
  &lt;li&gt;例如，[0, 1, 2] 表示 tokens 存储在 kv cache 的第 0、1 和 2 个块中。&lt;/li&gt;
  &lt;li&gt;每个块最多可容纳 block_size 个 tokens。&lt;/li&gt;
  &lt;li&gt;如果启用了 cuda-graph 捕获，则第二维将填充至 max_blocks_per_seq 的大小。&lt;/li&gt;
&lt;/ul&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;block_tables&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Optional&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;另外 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vllm/attention/backends/utils.py&lt;/code&gt; 文件中提供了一些函数用于计算“槽映射”（slot mapping），并将序列中的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;token&lt;/code&gt; 索引映射到内存块中的槽索引。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/compute_slot_mapping_python.png&quot; width=&quot;60%&quot; alt=&quot;_allocate_kv_cache&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;主函数 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compute_slot_mapping&lt;/code&gt;，根据是否进行性能分析、是否需要填充以及使用哪种实现方式（Python 或 NumPy），计算序列的槽映射。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/compute_slot_mapping.png&quot; width=&quot;60%&quot; alt=&quot;compute_slot_mapping&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;在模型 forward 过程中调用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;flash_attention&lt;/code&gt; 做注意力分值计算时会按照 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;slot_mapping&lt;/code&gt; 指引位置将本层的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kv cache&lt;/code&gt; 存储到 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vllm&lt;/code&gt; 初始化过程中分配的全零张量中，这在 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cuda&lt;/code&gt; 函数中实现。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/reshape_and_cache_flash.png&quot; width=&quot;60%&quot; alt=&quot;reshape_and_cache_flash&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;FlashAttentionMetadata&lt;/code&gt; 数据类的定义如下：&lt;/p&gt;
&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dataclass&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;FlashAttentionMetadata&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# NOTE(sang): Definition of context_len, query_len, and seq_len.
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# |---------- N-1 iteration --------|
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# |---------------- N iteration ---------------------|
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# |- tokenA -|......................|-- newTokens ---|
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# |---------- context_len ----------|
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# |-------------------- seq_len ---------------------|
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;#                                   |-- query_len ---|
&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_actual_tokens&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;  &lt;span class=&quot;c1&quot;&gt;# Number of tokens excluding padding.
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;max_query_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;query_start_loc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;max_seq_len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;seq_start_loc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;block_table&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;slot_mapping&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Tensor&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;h3 id=&quot;23-物理-block-分配类-cacheengine&quot;&gt;2.3 物理 block 分配类-CacheEngine&lt;/h3&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;CacheEngine&lt;/code&gt; 给GPU分配空间的方式，本质上通过 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pytorch&lt;/code&gt; 的接口在 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;gpu&lt;/code&gt; 上分配 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_blocks&lt;/code&gt; 大小的零 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tensor&lt;/code&gt; 来作为物理块的空间的，而不是直接使用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cudaMalloc&lt;/code&gt; 进行操作的。&lt;/p&gt;

&lt;p&gt;和 &lt;a href=&quot;https://github.com/ModelTC/lightllm&quot;&gt;lightllm&lt;/a&gt; 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tokenattention&lt;/code&gt; 直接提前分配一个可用最大形状的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Tensor&lt;/code&gt;，且后续 kv cache 的获取和释放都从这里操作不同。&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;PagedAttention&lt;/code&gt; 为 transformer 模型的每个 layer 都分配一个&lt;strong&gt;可用最大尺寸&lt;/strong&gt;的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tensor&lt;/code&gt;，并组合成列表的形式。也就是如果模型为 layer 为 16，对应的 kv cache 就是一个拥有 16 个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tensor&lt;/code&gt; 的列表。&lt;/p&gt;

&lt;p&gt;代码通过循环遍历 self.num_attention_layers，&lt;strong&gt;为每个层分配独立的 KV 缓存张量&lt;/strong&gt;，确保每层的 kv 张量能够被单独存储和访问，避免不同层之间的干扰。&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_gpu_blocks&lt;/code&gt; 会通过 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;model_executor.determine_num_available_blocks&lt;/code&gt; 函数获取当前模型在指定设备上的每个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;layer&lt;/code&gt; 的最大可用物理 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt; 数目。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/allocate_kv_cache.png&quot; width=&quot;65%&quot; alt=&quot;_allocate_kv_cache&quot; /&gt;
&lt;/div&gt;

&lt;p&gt;绝大部分后端的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kv_cache_shape&lt;/code&gt; 形状都是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;[2, num_blocks, block_size, num_kv_heads, head_size]&lt;/code&gt;。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;c1&quot;&gt;# vllm/attention/ops/paged_attn.py
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;PagedAttention&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_kv_cache_shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;num_kv_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;head_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;...]:&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_kv_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;head_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;其中 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_size&lt;/code&gt; 由 llm 服务启动参数设定，而 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_gpu_blocks&lt;/code&gt; 既可以在服务启动的时候通过预热自动跑出来，也可通过服务启动参数 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_gpu_blocks_override&lt;/code&gt; 由用户自行设定并覆盖 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_gpu_blocks&lt;/code&gt;。&lt;/p&gt;

&lt;div align=&quot;center&quot;&gt;
&lt;img src=&quot;../images/vllm_pagedattention/initialize_kv_caches.png&quot; width=&quot;65%&quot; alt=&quot;llm_memory_waste&quot; /&gt;
&lt;/div&gt;

&lt;h4 id=&quot;231-num_gpu_blocks-获取-determine_num_available_blocks-函数&quot;&gt;2.3.1 num_gpu_blocks 获取-determine_num_available_blocks 函数&lt;/h4&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;determine_num_available_blocks&lt;/code&gt; 函数的具体实现是在 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;worker&lt;/code&gt; 目录下的各个设备的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;work.py&lt;/code&gt; 实现，先以简单 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cpu_work.py&lt;/code&gt; 的实现为例分析，cpu 中的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;num_gpu_blocks&lt;/code&gt;（实际是 cpu 的可用内存块数量）计算是通过理论计算得到的，通过 cpu/gpu 设备可用的内存空间除以相关 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kv_cache_block_size&lt;/code&gt; 得到可用 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt; 数量。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;n&quot;&gt;cache_block_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_cache_block_size_bytes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# self.cache_config.cpu_kvcache_space_bytes 可通过环境变量 `VLLM_CPU_KVCACHE_SPACE` 定义
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_cpu_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cpu_kvcache_space_bytes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;上述代码中的 get_cache_block_size_bytes 函数实际上是先计算对应模型 kv 每个 token 占用的空间，又因为 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_size&lt;/code&gt; 表示一个 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block&lt;/code&gt; 对应 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_size&lt;/code&gt; 个 tokens，自然需要再乘以 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_size&lt;/code&gt;。&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;get_cache_block_size&lt;/code&gt; 具体实现代码如下所示：&lt;/p&gt;
&lt;blockquote&gt;
  &lt;p&gt;值得一提的是，指定设备的 kv cache 可用空间，以及可分配的 tokens 数量，也可以参考我的 &lt;a href=&quot;https://github.com/harleyszhang/llm_counts&quot;&gt;llm_counts&lt;/a&gt; 工具理论计算得到，代码更优雅，使用更简单。&lt;/p&gt;
&lt;/blockquote&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;staticmethod&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_cache_block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;cache_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;str&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ModelConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;parallel_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ParallelConfig&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;head_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_head_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_num_kv_heads&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parallel_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_num_layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parallel_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;key_cache_block&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;block_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_heads&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;head_size&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;value_cache_block&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;key_cache_block&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;# 每层 layer 的 block size
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;total&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;key_cache_block&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;value_cache_block&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_dtype&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;&quot;auto&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;STR_DTYPE_TO_TORCH_DTYPE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cache_dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;dtype_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;element_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;total&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;gpu_work.py&lt;/code&gt; 中的实现的计算逻辑和 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cpu_work.py&lt;/code&gt; 一样，不同的是，这里借助了 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torch.cuda.mem_get_info()&lt;/code&gt; 函数直接获取 gpu 总内存和在加载完模型之后的剩余显存。另外，用户在启动 llm 服务时，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;--block-size&lt;/code&gt; 参数可设置的取值范围是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;{8,16,32,64,128}&lt;/code&gt;，默认值是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;16&lt;/code&gt;。&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;o&quot;&gt;@&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inference_mode&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;determine_num_available_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gpu_memory_utilization&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&amp;gt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Tuple&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]:&lt;/span&gt;
    &lt;span class=&quot;s&quot;&gt;&quot;&quot;&quot;
    评估模型的峰值内存使用情况，以确定在不发生内存溢出的情况下可以分配的 KV（键值）缓存块的数量。

    该方法首先清理 CUDA 缓存，然后使用虚拟输入执行一次前向传播，以评估模型的内存使用情况。
    接着，计算在剩余可用内存下，最多可以分配的 GPU 和 CPU 缓存块数量。

    提示：
        可以通过调整 `gpu_memory_utilization` 参数来限制 GPU 内存的使用。
    &quot;&quot;&quot;&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 清理 CUDA 缓存，以确保获取准确的内存使用信息
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;empty_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 使用虚拟输入执行一次前向传播，以评估模型的内存使用情况
&lt;/span&gt;    &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_runner&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;profile_run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 同步 CUDA 操作，确保内存信息准确
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;synchronize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;# 获取当前 GPU 的空闲内存和总内存（单位：字节）
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;free_memory_pre_profile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;total_gpu_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mem_get_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 计算模型加载后的峰值内存使用量
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# Get the peak memory allocation recorded by torch
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;peak_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;memory_stats&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;allocated_bytes.all.peak&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;# 清理未使用的缓存，计算非Torch分配的内存
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;empty_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;torch_allocated_bytes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;memory_stats&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;allocated_bytes.all.current&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;total_allocated_bytes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mem_get_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mem_get_info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;non_torch_allocations&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;total_allocated_bytes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch_allocated_bytes&lt;/span&gt;
    
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;non_torch_allocations&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;gt;&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;peak_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;non_torch_allocations&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;available_kv_cache_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;total_gpu_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gpu_memory_utilization&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;peak_memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    
    &lt;span class=&quot;c1&quot;&gt;# 计算每个缓存块的大小
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;cache_block_size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_get_cache_block_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model_config&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 计算在剩余可用内存下，最多可以分配的 GPU 缓存块数量
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;int&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
        &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;total_gpu_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;gpu_memory_utilization&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;
         &lt;span class=&quot;n&quot;&gt;peak_memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;//&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cache_block_size&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 确保缓存块数量不为负数
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;logger&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;info&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
            &lt;span class=&quot;s&quot;&gt;&quot;Memory profiling results: total_gpu_memory=%.2fGiB &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
            &lt;span class=&quot;s&quot;&gt;&quot; initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
            &lt;span class=&quot;s&quot;&gt;&quot; memory_usage_post_profile=%.2fGib &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
            &lt;span class=&quot;s&quot;&gt;&quot; non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB &lt;/span&gt;&lt;span class=&quot;se&quot;&gt;\n&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;&quot;&lt;/span&gt;
            &lt;span class=&quot;s&quot;&gt;&quot; gpu_memory_utilization=%.2f&quot;&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;total_gpu_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;total_gpu_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;free_memory_pre_profile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;peak_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;non_torch_allocations&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;total_allocated_bytes&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;non_torch_allocations&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;available_kv_cache_memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;gpu_memory_utilization&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# 进行垃圾回收，释放未使用的内存
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;gc&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;collect&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 再次清理 CUDA 缓存
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cuda&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;empty_cache&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# 返回可分配的 GPU 和 CPU 缓存块数量（此处 CPU 块数量为 0）
&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_gpu_blocks&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;总结：至此 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;vllm&lt;/code&gt; 的 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pagedattention&lt;/code&gt; 内核设计和动态分配、管理 kv cache 内存的模块分析完毕，难点主要有三个：&lt;/p&gt;
&lt;ol&gt;
  &lt;li&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cpu/gpu&lt;/code&gt; 设备在指定模型上的可分配的内存 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt; 的计算;&lt;/li&gt;
  &lt;li&gt;(不同设备的) &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_tables&lt;/code&gt; 的创建和管理，以及基于 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Scheduler&lt;/code&gt; 的调度策略去分配、释放和回收 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;blocks&lt;/code&gt;。&lt;/li&gt;
  &lt;li&gt;最后就是 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;pagedattention&lt;/code&gt; 内核代码中相关线程索引和偏移的计算怎么改成基于 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_tables&lt;/code&gt; 的形式，多了三个参数：&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;k_cache, v_cache, block_table&lt;/code&gt;，其中 &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;block_tables&lt;/code&gt; 尺寸为 [num_seqs, max_num_blocks_per_seq]，&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;v_cache&lt;/code&gt; 尺寸为 [num_blocks, num_heads, head_size, block_size]。&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;以上，这都需要反复阅读理解代码才能得到清晰的理解。&lt;/p&gt;

&lt;h2 id=&quot;参考资料&quot;&gt;参考资料&lt;/h2&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;a href=&quot;https://docs.google.com/presentation/d/1QL-XPFXiFpDBh86DbEegFXBXFXjix4v032GhShbKf3s/edit?pli=1#slide=id.p&quot;&gt;vllm First Meetup&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/658233994&quot;&gt;CUDA PagedAttention kernel源码解析–大模型推理服务框架vLLM要点简析（下）&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/668736097&quot;&gt;PageAttention代码走读&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/661360117&quot;&gt;vLLM &amp;amp; PagedAttention 论文深度解读（二）—— vLLM 服务架构及源码实现&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;a href=&quot;https://zhuanlan.zhihu.com/p/5085306075&quot;&gt;vLLM源码阅读&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;
</description>
      <pubDate>2024-11-17</pubDate>
      <link>https://c1lovez1.github.io/2024-11-17/vllm-pagedattention.html</link>
      <guid isPermaLink="true">https://c1lovez1.github.io/2024-11-17/vllm-pagedattention.html</guid>
    </item>
    
		
  </channel>
</rss>
