基于论文 Attention Is All You Need (Vaswani et al., 2017)
1. 概述 Transformer 是一种基于自注意力机制(Self-Attention) 的深度学习模型,完全抛弃了传统的 RNN/CNN 结构,仅依靠注意力机制来捕捉序列中的全局依赖关系。
核心优势:
并行计算 :不像 RNN 需要逐步处理,Transformer 可以并行处理整个序列
长距离依赖 :自注意力机制直接计算序列中任意两个位置的关系,不受距离限制
可扩展性 :易于堆叠多层,模型容量大
架构总览:
1 2 3 输入序列 → [ Embedding + Positional Encoding ] → N × Encoder Block → Encoder 输出 ↓ 目标序列 → [ Embedding + Positional Encoding ] → N × Decoder Block → 输出概率
2. 自注意力机制 (Self-Attention) 2.1 核心思想 对于序列中的每个位置,通过查询(Query) 去与所有位置的键(Key) 做匹配,得到注意力权重,再用权重对值(Value) 加权求和,得到该位置的输出。
直观类比: 在数据库中查询信息 —— Query 是搜索关键词,Key 是每条记录的索引,Value 是记录内容。Query 与 Key 越匹配,对应的 Value 权重越大。
2.1.1 Q、K、V三者关系的整体理解 自注意力的核心公式:$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V$
Q、K、V来自同一个输入 $x$,分别通过 $W_Q$、$W_K$、$W_V$ 投影得到。它们分工协作,各司其职:
角色
功能
优化目标
直觉类比
Q
决定”我想要什么信息”
找到最相关的K
招聘者发布需求
K
决定”我有什么信息标签”
被正确的Q匹配到
求职者的简历标签
V
决定”被选中后提供什么内容”
提供最有用的信息
求职者的实际能力
整体数据流:
1 2 3 输入 x ──→ W_Q → Q "我需要什么?" ──→ 与K匹配 ──→ 得到注意力权重α ──→ 对V加权求和 ──→ 输出 输入 x ──→ W_K → K "我是谁?" ──→ 被Q匹配 ↑ ↑ 输入 x ──→ W_V → V "我有什么?" ──→──────────────────────────────────────被α选中──→ 贡献内容
Q和K的交互 决定”路由”:哪个位置的信息流向哪个位置(注意力权重 $\alpha$)
V 决定”内容”:被选中后实际提供的语义信息
整个过程可以理解为:Q×K 建立信息通道,V在通道中传输内容
为什么三者不会坍缩为相同的向量?
三者的梯度路径完全不同:
1 2 3 Loss → output → V (直接路径:V直接影响输出内容) Loss → output → α → Q (间接路径:Q只影响路由,再通过路由影响输出) Loss → output → α → K (间接路径:K只影响路由,再通过路由影响输出)
V的梯度:$\frac{\partial L}{\partial V_j} = \sum_i \alpha_{ij} \cdot \frac{\partial L}{\partial \text{output}_i}$ → V优化”被选中后提供什么”
K的梯度:$\frac{\partial L}{\partial K_j} = \sum_i \frac{\partial L}{\partial \alpha_{ij}} \cdot \frac{\partial \alpha_{ij}}{\partial K_j}$ → K优化”应该被谁关注”
Q的梯度:$\frac{\partial L}{\partial Q_i} = \sum_j \frac{\partial L}{\partial \alpha_{ij}} \cdot \frac{\partial \alpha_{ij}}{\partial Q_i}$ → Q优化”应该关注谁”
三条梯度路径指向三个不同的优化目标,合并参数会同时满足三个矛盾需求,梯度会自动避免坍缩。
Q和K的分化:从随机到协作 初始时 :W_Q和W_K都是Xavier随机初始化的,从第一天起Q和K就不同,但这种差异是”随机噪声”,没有语义意义。
训练中逐步分化 的关键是——Q和K承担不同的角色,梯度方向不同,导致参数朝不同方向演化 :
**Q是”需求方”**:决定”我想要什么信息”——优化目标是让Q能找到最相关的K
**K是”供给方”**:决定”我有什么信息标签”——优化目标是让K能被正确的Q匹配到
角色不同 → 梯度不同 → 参数分化。
逐步分化的具体过程(以翻译”我爱你”→”I love you”为例):
Step 0(随机初始化): Q和K是随机投影,注意力权重接近均匀分布(每个位置获得≈1/3的注意力),模型没有学到任何有用的模式。
Step 100(初步分化): 模型发现”翻译’I’时应关注源序列的’我’”。为了实现这一点:
W_Q调整 :让”I”的Q向量更接近”我”的K向量 → Q学习”我是代词,我需要找名词”
W_K调整 :让”我”的K向量更容易被”I”的Q匹配到 → K学习”我是中文代词,我的标签应该容易被英文代词找到”
两者的梯度方向天然相反 :
W_Q的梯度推动Q去”寻找”(Q更泛化,能匹配多个相关K)
W_K的梯度推动K去”被找到”(K更独特,只被相关Q匹配)
Step 1000(角色定型): 经过大量训练样本后:
Q形成了查询模式 :代词的Q总是寻找名词、动词的Q总是寻找宾语…
K形成了索引模式 :名词的K总是被代词Q匹配、宾语的K总是被动词Q匹配…
对注意力权重 $\alpha_{ij} = \text{softmax}(Q_i K_j^T / \sqrt{d_k})$,注意 $\frac{\partial \alpha_{ij}}{\partial W_Q}$ 和 $\frac{\partial \alpha_{ij}}{\partial W_K}$ 的结构不同 :
对Q的梯度:$\frac{\partial \alpha_{ij}}{\partial Q_i} = \alpha_{ij}(e_j - \sum_k \alpha_{ik} e_k)$ → 推动Q远离”平均K”,朝”目标K”移动
对K的梯度:$\frac{\partial \alpha_{ij}}{\partial K_j} = \alpha_{ij}(Q_i - \sum_k \alpha_{ik} Q_i \cdot \text{缩放因子})$ → 推动K远离”平均匹配”,朝”专属匹配”移动
直觉类比: Q像”招聘者”,学习发布什么样的需求描述能找到最合适的人;K像”简历标签”,学习标注什么样的技能关键词能被最合适的招聘者搜到。两者的优化方向不同,经过多轮互动(训练迭代),双方各自演化出不同的”语言体系”。
K和V的分工:路由与内容 核心区别:K控制”能否被选中”,V控制”被选中后提供什么”。
K :决定注意力权重 $\alpha_{ij}$——“哪些Q会关注我”
V :决定被选中后的输出内容——“关注我后得到什么信息”
V的梯度直接 影响output——V是最终输出的原材料;K的梯度间接 影响output——K只改变注意力权重(路由),再通过权重影响output。V优化的是”产品质量”,K优化的是”产品曝光”。
逐步分化过程(继续以翻译为例):
Step 0: W_K和W_V随机初始化,K起不到索引作用(注意力接近均匀),V提供的内容也是随机噪声。
Step 100: 模型发现错误——翻译”我”时输出不对。两条优化路径同时启动:
路径A(K的优化):”我”的K没有被”I”的Q足够关注 → 增大K(“我”)与Q(“I”)的点积 → K学习”我是代词类,代词Q应该找我”
路径B(V的优化):即使K被正确关注了,V(“我”)提供的信息不对 → 修正V(“我”)的向量 → V学习”我被选中后应该提供’第一人称’的语义信息”
两者必须协同收敛 :如果K不被关注 → V再好也没用(信息无法传递);如果K被关注了但V提供垃圾 → 输出仍然错误。
Step 1000: K和V形成稳定的分工:
1 2 3 4 5 6 7 token "我" : K ("我" ) ≈ [代词标志, 主语标志, ...] ← 紧凑的"标签向量" ,编码"我是什么类型" V ("我" ) ≈ [第一人称, 人称代词, ...] ← 丰富的"内容向量" ,编码"我有什么含义" token "爱" : K ("爱" ) ≈ [动词标志, 情感标志, ...] ← 标签:我是动词类,宾语Q应该找我 V ("爱" ) ≈ [喜爱, 情感动作, ...] ← 内容:被选中后提供"喜爱" 的语义
K追求紧凑区分性 (名词K vs 动词K差异大,方便路由);V追求丰富语义性 (同一个词的V包含多方面含义)。这两个需求矛盾:路由要维度少且差异大,内容要维度多且细粒度。分开后各得其所。
直觉类比: K像”餐厅招牌”(”正宗川菜”),吸引特定顾客走进来;V像”菜品”(水煮鱼、麻婆豆腐),顾客进来后实际吃到的东西。招牌优化目标:吸引爱吃川菜的顾客(路由准确性);菜品优化目标:让顾客吃完满意(内容质量)。
Q和V的关系:需求的”口味偏好” Q和V没有直接交互(Q不与V做点积),但通过注意力权重间接关联:
$$\text{output}_i = \sum_j \alpha(Q_i, K_j) \cdot V_j$$
Q决定从哪些位置取V(路由),但不直接约束取到什么V。然而,训练过程中Q和V会形成隐式配合:
Q学会只关注”能提供有用V的位置”——如果某个位置的V总是提供噪声,Q会学会降低对它的注意力
V学会提供”被Q选中后有价值的内容”——如果一个位置的V从未被任何Q选中,它的梯度几乎为零(参数更新停滞)
类比:招聘者(Q)会学会只面试有实力的候选人(V),而不会浪费时间面试能力差的候选人。
Q、K、V三者的动态博弈 把三者放在一起看,整个训练过程是一个三方协同博弈 :
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 训练初期: Q : 随机搜索,不知道找谁 → 注意力均匀分布 K : 随机标签,不知道被谁找 → 无法有效路由 V : 随机内容,不知道该提供什么 → 输出质量差 训练中期: Q →K : Q 学会向"有用的K" 靠近 → 注意力开始集中 K →Q : K 学会向"需要我的Q" 靠近 → 路由逐渐准确 K →V : K 锁定目标客户 → V 知道该服务谁 V →Q : V 提供有价值的内容 → Q 学会只关注有用位置 训练后期: Q 、K 、V 协同收敛: Q 精确知道找谁( K ) → K 精确知道被谁找( Q ) → V 精确知道该提供什么 形成稳定的信息流通路:Q ×K 建路,V 在路上传货
一句话总结:Q是”我要什么”,K是”我是谁”,V是”我有什么”。Q和K握手建通道,V在通道里传内容。三者梯度路径不同,自然分化,无需外力干预。
2.2 数学公式 缩放点积注意力(Scaled Dot-Product Attention):
$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right) V $$
$Q \in \mathbb{R}^{n \times d_k}$,$K \in \mathbb{R}^{m \times d_k}$,$V \in \mathbb{R}^{m \times d_v}$
$d_k$ 为每个注意力头中 Key/Query 的维度,除以 $\sqrt{d_k}$ 是为了防止点积值过大导致 softmax 梯度消失
$n$ 为查询序列长度,$m$ 为键/值序列长度
关键维度参数的含义:
参数
含义
原论文值
选择逻辑
$d_{\text{model}}$
所有层的向量维度
512
模型容量:越大表达力越强,但参数量和计算量也越大
$n_{\text{heads}}$
注意力头数
8
多视角:越多视角越丰富,但每个头表达力越弱
$d_k$
每个头的Key/Query维度
64
由 $d_{\text{model}}/n_{\text{heads}}$ 决定,不是独立选择
$d_{\text{model}}$(模型维度): Transformer中所有层的输入/输出的向量维度。它是整个模型的”通用货币”——Embedding输出512维,Encoder每层输入输出512维,Decoder每层输入输出512维。所有子层的输入输出都是同一个维度,只有FFN中间层临时扩展到2048维。为什么所有层都用同一个维度?因为残差连接要求维度一致 :x + SubLayer(x) 要求 x 和 SubLayer(x) shape完全相同才能相加。如果不同层维度不同,残差连接就无法实现。
$n_{\text{heads}}$(注意力头数): 多头注意力中并行计算的注意力子空间数量。每个头独立在自己的 $d_k$ 维子空间里计算注意力,8个头相当于模型同时用8个”视角”观察序列。
$d_k$ 是怎么决定的?
$d_k$ 不是独立选择的,而是由 $d_{\text{model}}$ 和 $n_{\text{heads}}$ 的硬约束决定:
$$n_{\text{heads}} \times d_k = d_{\text{model}} \implies d_k = \frac{d_{\text{model}}}{n_{\text{heads}}}$$
必须保证 $d_{\text{model}}$ 能被 $n_{\text{heads}}$ 整除,否则 $d_k$ 不是整数,无法均分。512维向量要均分成8个头,每个头拿64维。如果不能整除(如512分成7个头),就无法均分,有的头拿73维有的拿72维,实现困难且不对称。
整个多头注意力的流程决定了这个约束:
1 2 3 4 5 6 输入 x: 维 → W_Q投影: → (一次性投影所有头) → split成n_heads个头: → n_heads × ← 每个头拿d_k维 → 每个头独立计算注意力: 维的Q与维的K → 输出维 → concat所有头: n_heads × → → W_O投影: → ← 必须回到d_model!
最后一行要求:$n_{\text{heads}} \times d_v = d_{\text{model}}$。原论文中 $d_v = d_k$,所以 $n_{\text{heads}} \times d_k = d_{\text{model}}$。
$d_{\text{model}}$、$n_{\text{heads}}$、$d_k$ 三者互相约束:
$d_{\text{model}}$ 太小 → 模型表达力不足
$n_{\text{heads}}$ 太多 → $d_k$ 太小 → 每个头表达力不足
$n_{\text{heads}}$ 太少 → $d_k$ 太大 → 退化为单头,失去多视角优势
方案
头数
$d_k$
效果
单头注意力
1
512
一个大空间,所有模式混在一起
8头注意力
8
64
8个小空间,不同头学不同模式
16头注意力
16
32
太小,每个头表达力不足
原论文实验了 $n_{\text{heads}}=8, d_k=64$ 和 $n_{\text{heads}}=1, d_k=512$,发现8头效果更好,确定了这个组合。
计算步骤拆解:
$QK^T$:计算 Query 与所有 Key 的相似度(点积),得到 $n \times m$ 的分数矩阵
$\div \sqrt{d_k}$:缩放,防止分数过大
$\text{softmax}$:归一化为概率分布(每行和为1)
$\times V$:用注意力权重对 Value 加权求和
数值示例(d_k=2, 序列长度=3):
假设输入序列有3个token,经过投影后得到:
1 2 3 Q = [[1, 0], [0, 1], [1, 1]] # 3 个查询向量 K = [[1, 0], [0, 1], [1, 1]] # 3 个键向量 V = [[10, 0], [0, 10], [5, 5]] # 3 个值向量
步骤1:$QK^T$ 计算点积:
1 2 3 scores = Q @ K^T = [[1*1 +0*0 , 1*0 +0*1 , 1*1 +0*1 ], = [[1, 0, 1], [0*1 +1*0 , 0*0 +1*1 , 0*1 +1*1 ], [0, 1, 1], [1*1 +1*0 , 1*0 +1*1 , 1*1 +1*1 ]] [1, 1, 2]]
步骤2:除以 $\sqrt{d_k} = \sqrt{2} \approx 1.414$:
1 2 3 scaled = [[0.707, 0, 0.707], [0, 0.707, 0.707], [0.707, 0.707, 1.414]]
步骤3:softmax归一化(每行):
1 2 3 weights ≈ [[0.34 , 0.17 , 0.34 ], # 第1个token: 主要关注自己和第3个 [0.17, 0.34, 0.34 ], # 第2个token: 主要关注自己和第3个 [0.21, 0.21, 0.42 ]] # 第3个token: 最关注自己(点积最大)
步骤4:对V加权求和:
1 2 3 output [0 ] = 0 .34 *[10 ,0 ] + 0 .17 *[0 ,10 ] + 0 .34 *[5 ,5 ] ≈ [5.1, 2.5] output [1 ] = 0 .17 *[10 ,0 ] + 0 .34 *[0 ,10 ] + 0 .34 *[5 ,5 ] ≈ [2.5, 5.1] output [2 ] = 0 .21 *[10 ,0 ] + 0 .21 *[0 ,10 ] + 0 .42 *[5 ,5 ] ≈ [4.2, 4.2]
为什么除以 $\sqrt{d_k}$?——梯度消失的数学推导:
假设 $q$ 和 $k$ 的每个元素独立、均值为0、方差为1,则点积 $q \cdot k = \sum_{i=1}^{d_k} q_i k_i$ 的均值为0、方差为 $d_k$(因为独立变量乘积的方差之和)。
当 $d_k$ 很大时(如64),点积的方差约为64,值集中在 $\pm 8$ 附近。softmax 函数 $\sigma(x_i) = \frac{e^{x_i}}{\sum_j e^{x_j}}$ 在输入值很大时进入饱和区:
$$\frac{\partial \sigma(x_i)}{\partial x_i} = \sigma(x_i)(1 - \sigma(x_i))$$
当某个 $x_i \gg x_j$ 时,$\sigma(x_i) \approx 1$,梯度 $\approx 1 \times (1-1) = 0$ —— 梯度消失 !
除以 $\sqrt{d_k}$ 后,点积方差从 $d_k$ 降为1,值集中在 $\pm 1$ 附近,softmax梯度保持正常范围(最大值约0.25),训练稳定。
直觉: 假设4个考生考同一份试卷,原始分数如下:
考生
原始分数(未缩放)
softmax后的”录取概率”
缩放后分数(÷√d_k)
softmax后的”录取概率”
A
800
99.6%
100
33.3%
B
400
0.2%
50
16.7%
C
300
0.1%
37.5
11.7%
D
200
0.1%
25
38.3%
未缩放时:A碾压其他人(99.6%),B/C/D之间的差异完全被淹没(0.2% vs 0.1%几乎无区别)。softmax看到的是”A遥遥领先,其他人都差不多”。
缩放后:概率分布更均匀(33%→16%→11%→38%),每个考生之间的差异都能被感知。softmax能区分”B比C稍好”、”D也不错”等细微差异。
对训练的影响:未缩放时,模型无法通过调整注意力来强调B而非C(因为两者概率几乎相同,调整权重的梯度≈0);缩放后,模型能精确控制注意力分配(梯度≈0.17~0.24),训练有效。
2.3 PyTorch 实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 import torchimport torch.nn as nnimport torch.nn.functional as Fimport mathclass ScaledDotProductAttention (nn.Module): """缩放点积注意力机制 公式: Attention(Q, K, V) = softmax(Q @ K^T / sqrt(d_k)) @ V 核心作用: 让每个查询位置(Query)自主决定"应该关注哪些键位置(Key)", 然后从这些键位置对应的值(Value)中提取信息, 加权求和作为输出. 三种角色类比: - Query: "我想要什么信息?" (主动方, 发起查询) - Key: "我有什么信息?" (被动方, 提供索引标签) - Value: "我的具体内容" (被选中后, 提供实际数据) 匹配过程: Q和K的点积衡量"需求与标签的匹配程度", 匹配度越高, 该位置的V在输出中权重越大. Args: dropout: 注意力权重的dropout概率, 用于正则化防止过拟合 训练时随机置零部分注意力权重, 强制模型不过度依赖少数位置 推理时dropout自动关闭, 所有权重完整保留 """ def __init__ (self, dropout=0.1 ): super ().__init__() self.dropout = nn.Dropout(dropout) def forward (self, Q, K, V, mask=None ): """前向传播 数据流: Q,K,V → 点积 → 缩放 → mask → softmax → dropout → 加权求和 → 输出 Args: Q: 查询矩阵, shape=[batch, n_heads, seq_len_q, d_k] 每一行代表一个查询位置想寻找的信息模式 K: 键矩阵, shape=[batch, n_heads, seq_len_k, d_k] 每一行代表一个键位置提供的信息标签 seq_len_k可以与seq_len_q不同(如Cross-Attention时) V: 值矩阵, shape=[batch, n_heads, seq_len_k, d_v] 每一行代表一个键位置携带的实际内容 d_v可以与d_k不同(但原论文中d_v=d_k) mask: 掩码矩阵, shape=[batch, 1, seq_len_q, seq_len_k] 或可广播的形状 0表示屏蔽该位置(设为-1e9, softmax后≈0) 1表示保留该位置(正常参与注意力计算) 两种用途: 1. Decoder的Masked Self-Attention: 防止看到未来位置(下三角) 2. Encoder/Decoder的padding mask: 屏蔽<PAD>位置 Returns: output: 注意力输出, shape=[batch, n_heads, seq_len_q, d_v] 每个查询位置融合了所有(未被mask屏蔽的)值向量信息的加权平均 attn_weights: 注意力权重, shape=[batch, n_heads, seq_len_q, seq_len_k] 可用于可视化模型"在看哪里", 也可用于后续分析 """ d_k = Q.size(-1 ) scores = torch.matmul(Q, K.transpose(-2 , -1 )) scores = scores / math.sqrt(d_k) if mask is not None : scores = scores.masked_fill(mask == 0 , -1e9 ) attn_weights = F.softmax(scores, dim=-1 ) attn_weights = self.dropout(attn_weights) output = torch.matmul(attn_weights, V) return output, attn_weights
3. 多头注意力 (Multi-Head Attention) 3.1 核心思想 将 Q、K、V 分别投影到多个不同的子空间(头),每个头独立计算注意力,最后将所有头的输出拼接并投影回原始维度。
好处: 不同头可以关注序列中不同类型的关系模式(如语法关系、语义关系、位置关系等),让模型同时从多个角度理解序列。
为什么不直接用一个大注意力头?
一个大头($d_k = d_{\text{model}} = 512$)虽然能捕捉所有信息,但容易将不同模式的信息混杂在一起。8个小头($d_k = 64$)各自学习不同的注意力模式,最后通过 $W^O$ 组合,相当于给了模型8个”视角”去观察序列,每个视角独立地决定”看哪里、看什么”。
各头的典型关注模式(实际训练中观察到的现象):
头1-2:关注相邻位置(局部依赖,类似n-gram)
头3-4:关注句法关系(主语-谓语、修饰语-被修饰词)
头5-6:关注语义关系(同义词、反义词、上下位词)
头7-8:关注特定结构(标点符号、句首句尾)
3.2 数学公式 $$ \text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, …, \text{head}_h) W^O $$
$$ \text{head}_i = \text{Attention}(QW_i^Q, KW_i^K, VW_i^V) $$
其中 $h$ 为头数,$d_k = d_v = d_{\text{model}} / h$。
3.3 PyTorch 实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 class MultiHeadAttention (nn.Module): """多头注意力机制 将Q/K/V投影到h个不同的子空间, 每个子空间独立计算注意力, 最后拼接所有头的输出并做线性投影. 整体流程: 1. 线性投影: 将输入的d_model维向量分别投影为Q/K/V (各d_model维) 2. 分头: 把d_model拆成 n_heads × d_k, 每个头处理d_k维的子空间 3. 并行计算: 每个头独立做缩放点积注意力, 得到 d_v维输出 4. 拼接: 把所有头的输出拼接成 n_heads × d_v = d_model维 5. 输出投影: 通过W_O将拼接结果映射回d_model维 参数量分析: - W_Q, W_K, W_V: 各 d_model × d_model = 3 × 512² = 786,432 - W_O: d_model × d_model = 512² = 262,144 - 总计: 1,048,576 (与单头注意力 d_model×d_model×4 相同!) → 多头注意力的参数量并不比单头多, 只是把参数分配到了不同的子空间 Args: d_model: 模型的特征维度 (如512) n_heads: 注意力头数 (如8), 要求 d_model 能被 n_heads 整除 dropout: dropout概率 """ def __init__ (self, d_model, n_heads, dropout=0.1 ): super ().__init__() assert d_model % n_heads == 0 self.d_k = d_model // n_heads self.n_heads = n_heads self.W_Q = nn.Linear(d_model, d_model) self.W_K = nn.Linear(d_model, d_model) self.W_V = nn.Linear(d_model, d_model) self.W_O = nn.Linear(d_model, d_model) self.attention = ScaledDotProductAttention(dropout) self.dropout = nn.Dropout(dropout) def forward (self, Q, K, V, mask=None ): """前向传播 Args: Q: 查询输入, shape=[batch, seq_len_q, d_model] Self-Attention时: Q=K=V=上一层的输出x Cross-Attention时: Q=Decoder的输出, K=V=Encoder的输出 K: 键输入, shape=[batch, seq_len_k, d_model] Self-Attention时: 与Q相同 Cross-Attention时: Encoder的输出 V: 值输入, shape=[batch, seq_len_k, d_model] Self-Attention时: 与K相同 Cross-Attention时: 与K相同 mask: 掩码, shape=[batch, 1, seq_len_q, seq_len_k] 通过unsqueeze扩展到多头维度后应用于所有头 Returns: output: 多头注意力输出, shape=[batch, seq_len_q, d_model] 融合了所有头的注意力信息 attn_weights: 注意力权重, shape=[batch, n_heads, seq_len_q, seq_len_k] 注意: 这里返回的是经过分头reshape后的中间状态 包含所有头的权重, 可用于可视化每个头的注意力模式 """ batch_size = Q.size(0 ) Q = self.W_Q(Q).view(batch_size, -1 , self.n_heads, self.d_k).transpose(1 , 2 ) K = self.W_K(K).view(batch_size, -1 , self.n_heads, self.d_k).transpose(1 , 2 ) V = self.W_V(V).view(batch_size, -1 , self.n_heads, self.d_k).transpose(1 , 2 ) if mask is not None : mask = mask.unsqueeze(1 ) attn_output, attn_weights = self.attention(Q, K, V, mask) attn_output = attn_output.transpose(1 , 2 ).contiguous().view( batch_size, -1 , self.n_heads * self.d_k ) output = self.W_O(attn_output) return output, attn_weights
4. 位置编码 (Positional Encoding) 4.1 为什么需要位置编码? 自注意力机制本身是位置无关 的(置换等变),同一组 Q/K/V 无论如何排列,注意力输出只是换了个排列顺序,每个token得到的表示向量不变。因此需要注入位置信息。
详细证明:为什么”我爱你”和”你爱我”的自注意力输出一样?
“我爱你”的embedding矩阵(无位置编码):
1 X₁ = [e_我, e_爱, e_你] ← 位置0 =我, 位置1 =爱, 位置2 =你
“你爱我”的embedding矩阵:
1 X₂ = [e_你, e_爱, e_我] ← 位置0 =你, 位置1 =爱, 位置2 =我
自注意力的计算:output_i = Σ_j softmax(Q_i · K_j / √d_k) · V_j
对”我爱你”中”我”(位置0)的输出:
1 2 output_我(位置0 ) = Σ_j attn(Q(e_ 我) , K_j, V_j) 其中 j遍历 {e_我, e_爱, e_你} = α₀·V(e_ 我) + α₁·V(e_ 爱) + α₂·V(e_ 你)
对”你爱我”中”我”(位置2)的输出:
1 2 output_我(位置2 ) = Σ_j attn(Q(e_ 我) , K_j, V_j) 其中 j遍历 {e_你, e_爱, e_我} = α₂'·V(e_ 你) + α₁'·V(e_ 爱) + α₀'·V(e_ 我)
关键: softmax是对所有K做归一化,求和遍历的是同一个集合 {e_我, e_爱, e_你}(只是顺序不同)。集合的元素相同 → Q·K的点积值只是换了位置 → softmax后的权重也只是换了位置 → 加权求和的结果完全相同 。
即:
1 2 3 output_ 我(在"我爱你" 的位置0 ) = output_我(在"你爱我" 的位置2 ) ← 同一个向量!output_ 爱(在"我爱你" 的位置1 ) = output_爱(在"你爱我" 的位置1 ) ← 同一个向量!output_ 你(在"我爱你" 的位置2 ) = output_你(在"你爱我" 的位置0 ) ← 同一个向量!
三个token的输出向量完全一样 ,只是排列顺序不同:
1 2 "我爱你" : [output_我, output_爱, output_你]"你爱我" : [output_你, output_爱, output_我] ← 只是重排了一下
作为”含义”来看 :这两组输出的多集(multiset)完全相同 {output_我, output_爱, output_你},模型无法区分哪个是”我爱你”哪个是”你爱我”——它把两个句子当成了一回事。
这就是”置换等变性”(permutation equivariance):Permute(Input) → Permute(Output),每个token的表示只取决于”我是谁”和”序列中有哪些token”,不取决于”我在哪个位置”。所以模型失去了位置感知能力,把序列当成了无序的词袋(bag of words) 。
加位置编码后:e_我 → e_我 + PE(0),e_你 → e_你 + PE(2),同一token在不同位置获得不同的向量,打破等变性,模型就能区分顺序了。
4.2 正弦余弦位置编码 $$ PE(pos, 2i) = \sin\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$
$$ PE(pos, 2i+1) = \cos\left(\frac{pos}{10000^{2i/d_{\text{model}}}}\right) $$
特点:
对于任意固定偏移 $k$,$PE(pos+k)$ 可以表示为 $PE(pos)$ 的线性函数,使得模型可以学习相对位置关系
偶数维度用 sin,奇数维度用 cos
低维度频率高(捕捉局部位置),高维度频率低(捕捉全局位置),类似二进制编码
为什么 sin/cos 编码能表达线性相对位置关系?
利用三角恒等式,可以证明位置 $pos+k$ 的编码是位置 $pos$ 编码的线性变换:
$$\sin(pos+k) = \sin(pos)\cos(k) + \cos(pos)\sin(k)$$
$$\cos(pos+k) = \cos(pos)\cos(k) - \sin(pos)\sin(k)$$
因此,$PE(pos+k)$ 可以通过一个仅依赖于 $k$ 的线性变换 $M_k$ 从 $PE(pos)$ 得到:
$$PE(pos+k) = M_k \cdot PE(pos)$$
其中 $M_k$ 是一个由 $\cos(k\cdot\omega_i)$ 和 $\sin(k\cdot\omega_i)$ 构成的旋转矩阵(块对角矩阵,每块2×2)。这意味着模型只需学习一组线性变换权重,就能推断任意相对位置的关系 —— 这比直接学习每个绝对位置的表示更加高效和泛化。
频率分布与”二进制编码”类比:
1 2 3 维度0 (ω=1 /10000 ^0 = 1 ): sin(pos)周期=2 π≈6.28 — 频率最高, 区分近邻位置 维度2 (ω=1 /10000^{2/ 512 }≈0.012 ): 周期≈523 — 中等频率 维度510 (ω=1 /10000^{510/ 512 }≈0.0001 ): 周期≈62832 — 频率最低, 区分远距离位置
类比二进制计数器:
1 2 3 4 5 位置 0 : 000 → PE的低频维度变化慢(几乎不变), 高频维度变化快(每个位置不同) 位置 1 : 001 → 最右位(bit0=高频)每次都翻转, 最左位(bit2=低频)很少翻转 位置 2 : 010 → sin/cos编码的频率递减模式与此类似 位置 3 : 011 位置 4 : 100
每个维度相当于一个”时钟指针”,低维度(短周期)的指针转得快,区分细微位置差异;高维度(长周期)的指针转得慢,区分宏观位置差异。多个”时钟”组合起来,就能唯一地编码任意位置。
4.3 PyTorch 实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 class PositionalEncoding (nn.Module): """正弦-余弦位置编码 为每个位置生成固定的位置向量, 加到输入embedding上, 让模型感知序列中每个token的位置信息. 核心思想: 用不同频率的sin/cos函数为每个位置生成独特的编码向量 - 低维度用高频sin/cos(周期短): 区分相邻位置的细微差异 - 高维度用低频sin/cos(周期长): 区分远距离位置的宏观差异 - 组合起来形成每个位置唯一的"指纹" 为什么用固定编码而不是可学习的位置向量? 1. 固定编码无需额外参数, 不增加训练负担 2. sin/cos的线性关系特性(见上文推导)让模型更容易学习相对位置 3. 固定编码可以处理任意长度的序列(只要不超过max_len) 可学习编码只能处理训练时见过的长度范围 4. 实际效果: 研究表明两者性能相近, 但固定编码更简洁 Args: d_model: 模型特征维度 (如512) max_len: 支持的最大序列长度 (如5000) 预计算max_len个位置的编码, 超过此长度会报错 dropout: dropout概率 """ def __init__ (self, d_model, max_len=5000 , dropout=0.1 ): super ().__init__() self.dropout = nn.Dropout(dropout) pe = torch.zeros(max_len, d_model) position = torch.arange(0 , max_len, dtype=torch.float ).unsqueeze(1 ) div_term = torch.exp( torch.arange(0 , d_model, 2 ).float () * (-math.log(10000.0 ) / d_model) ) pe[:, 0 ::2 ] = torch.sin(position * div_term) pe[:, 1 ::2 ] = torch.cos(position * div_term) pe = pe.unsqueeze(0 ) self.register_buffer('pe' , pe) def forward (self, x ): """前向传播: 将位置编码加到输入上 Args: x: 输入embedding, shape=[batch, seq_len, d_model] 经过nn.Embedding后的连续向量表示 Returns: 加上位置编码后的输出, shape=[batch, seq_len, d_model] x + pe 后每个token向量同时包含了语义信息(embedding)和位置信息(PE) 数学含义: output[b, t, :] = embedding[b, t, :] + PE[t, :] 即: 同一位置的不同batch共享相同的位置编码 (位置编码是绝对位置, 不依赖batch内容) """ x = x + self.pe[:, :x.size(1 ), :] return self.dropout(x)
5. Encoder Block 5.1 结构 每个 Encoder Block 由两个子层组成:
1 输入 → Multi-Head Self-Attention → Add & Norm → Feed-Forward Network → Add & Norm → 输出
Add & Norm :残差连接 + Layer Normalization,即 $LayerNorm(x + SubLayer(x))$
残差连接:缓解深层网络梯度消失,让信息可以跨层传递
LayerNorm:对每个样本的特征维度做归一化,稳定训练
Feed-Forward Network :两层线性变换 + ReLU,即 $FFN(x) = \max(0, xW_1 + b_1)W_2 + b_2$
逐位置独立应用(不同位置共享参数,但各自独立计算)
中间维度 $d_{ff}$ 通常为 $4 \times d_{\text{model}}$,提供非线性变换能力
为什么要残差连接(Add)?
深层网络的核心问题:每经过一个子层,信息可能被”扭曲”或”稀释”。残差连接让原始输入 $x$ 直接跳过子层,与子层输出 $SubLayer(x)$ 相加:
$$output = x + SubLayer(x)$$
好处:
梯度直通 :反向传播时,梯度可以沿残差路径直接回传($∂output/∂x = 1 + ∂SubLayer(x)/∂x$),至少有1的梯度保底,不会梯度消失
信息保真 :即使子层学到了”无用变换”(SubLayer(x)≈0),输出≈x,信息不会丢失 —— 子层只需学习”增量修正”
训练稳定 :让深层网络(6层+)的训练变得可行,否则梯度在6次非线性变换后几乎消失
为什么用LayerNorm而不是BatchNorm?
特性
BatchNorm
LayerNorm
归一化维度
沿batch维度
沿feature维度
对batch大小敏感
是(小batch不稳定)
否(每个样本独立)
适用序列长度
固定长度
变长序列
隐藏状态依赖
依赖同batch其他样本
仅依赖当前样本
Transformer选择LayerNorm的原因:
序列长度可变 → BatchNorm沿batch归一化时, 不同位置的统计量混合, 不合理
推理时batch_size可能=1 → BatchNorm需要用训练时的running_mean/var, 不精确
自注意力输出混合了所有位置的信息 → 每个位置的特征分布不同, 应独立归一化
5.2 PyTorch 实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 class LayerNorm (nn.Module): """Layer Normalization 对输入的最后一个维度做归一化: output = gamma * (x - mean) / (std + eps) + beta 与BatchNorm不同, LayerNorm对每个样本独立归一化, 不依赖batch中其他样本, 适合变长序列. 归一化的作用: 1. 稳定训练: 防止深层网络中特征值的分布随训练逐渐偏移(内部协变量偏移) 2. 加速收敛: 让梯度保持在合理范围, 不因为特征值过大或过小而训练不稳定 3. 不依赖batch: 每个样本独立归一化, 推理时(batch_size=1)行为与训练一致 gamma和beta(可学习参数)的作用: - 归一化后所有特征均值0方差1, 丢失了原始分布信息 - gamma/beta通过"缩放+偏移"恢复表达能力: 模型可以学出任意均值和方差 - 初始化gamma=1, beta=0 → 初始时归一化后分布不变, 随训练逐步调整 Args: d_model: 归一化的特征维度 eps: 防止除零的小常数, 通常1e-6 当std≈0时(如所有特征值相同), (x-mean)/(std+eps)≈0而不是inf """ def __init__ (self, d_model, eps=1e-6 ): super ().__init__() self.gamma = nn.Parameter(torch.ones(d_model)) self.beta = nn.Parameter(torch.zeros(d_model)) self.eps = eps def forward (self, x ): """前向传播 Args: x: 输入, shape=[batch, seq_len, d_model] 可能是Self-Attention或FFN的输出(经过残差连接后) Returns: 归一化后的输出, shape不变 每个位置的特征维度d_model被独立归一化 计算步骤: 1. mean: 每个位置(每个样本的每个token)的d_model维均值 2. std: 每个位置(每个样本的每个token)的d_model维标准差 3. 归一化: (x-mean)/(std+eps) → 均值0, 标准差≈1 4. 仿射变换: gamma * 归一化结果 + beta → 可学习的均值和方差 """ mean = x.mean(-1 , keepdim=True ) std = x.std(-1 , keepdim=True ) return self.gamma * (x - mean) / (std + self.eps) + self.betaclass PositionwiseFeedForward (nn.Module): """位置前馈网络 (Position-wise Feed-Forward Network) 对序列中每个位置独立应用相同的两层全连接网络: FFN(x) = W2 * ReLU(W1 * x + b1) + b2 "位置独立"(Position-wise)的含义: - 不同token位置使用相同的W1,W2参数(共享权重) - 但每个位置独立计算, 不交换信息(不像Self-Attention那样跨位置) - 可以理解为: 用同一个1×1卷积核扫描序列的每个位置 为什么需要FFN? - Self-Attention: 收集全局信息(跨位置交互), 但本质是加权平均 → 线性操作 - FFN: 提供非线性变换能力, 对每个位置的信息做"深加工" - 两者互补: Attention负责"看哪里", FFN负责"看懂什么" 为什么中间维度d_ff=4*d_model? - 高维中间层(2048 vs 512)提供更大的"记忆容量" - 类似KV存储: 先把信息展开到高维空间(存储), 再压缩回低维(检索) - 实验表明: d_ff太小(如2*d_model)性能下降, d_ff太大(如8*d_model)收益递减 - 4倍是经验最佳值, 也是原论文的设置 Args: d_model: 输入/输出维度 (如512) d_ff: 中间隐藏层维度 (如2048), 通常为4倍d_model dropout: dropout概率, 在ReLU激活后应用 """ def __init__ (self, d_model, d_ff, dropout=0.1 ): super ().__init__() self.fc1 = nn.Linear(d_model, d_ff) self.fc2 = nn.Linear(d_ff, d_model) self.dropout = nn.Dropout(dropout) def forward (self, x ): """前向传播 Args: x: 输入, shape=[batch, seq_len, d_model] 每个位置是d_model维向量, 位置间独立处理 Returns: 输出, shape=[batch, seq_len, d_model] 每个位置经过非线性变换后的d_model维向量 内部形状变化: [batch, seq_len, d_model] → fc1 → [batch, seq_len, d_ff] → ReLU → [batch, seq_len, d_ff] (约一半变0) → dropout → [batch, seq_len, d_ff] (再随机置零) → fc2 → [batch, seq_len, d_model] """ return self.fc2(self.dropout(F.relu(self.fc1(x))))class EncoderLayer (nn.Module): """Encoder的一层 (一个Encoder Block) 结构: Self-Attention → Add&Norm → FFN → Add&Norm 数据流详解: x → MultiHeadAttention(Q=K=V=x, mask) → attn_output x + dropout(attn_output) → norm1 → x' [残差连接+归一化] x' → FFN(x') → ff_output x' + dropout(ff_output) → norm2 → output [残差连接+归一化] 注意: 这里用的是Post-LN(先子层再归一化): output = LN(x + SubLayer(x)) 另一种变体是Pre-LN(先归一化再子层): output = x + SubLayer(LN(x)) Pre-LN训练更稳定(梯度更平滑), 但Post-LN是原论文的做法 Args: d_model: 模型特征维度 n_heads: 注意力头数 d_ff: FFN中间层维度 dropout: dropout概率 """ def __init__ (self, d_model, n_heads, d_ff, dropout=0.1 ): super ().__init__() self.self_attn = MultiHeadAttention(d_model, n_heads, dropout) self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) def forward (self, x, mask=None ): """前向传播 Args: x: 输入, shape=[batch, seq_len, d_model] 来自上一层EncoderLayer的输出, 或最底层来自Embedding+PE mask: 自注意力掩码, 用于屏蔽padding位置 shape=[batch, 1, 1, seq_len] 或可广播的形状 padding位置为0, 有效位置为1 Returns: 输出, shape=[batch, seq_len, d_model] 每个位置融合了全局上下文信息(经过Attention)和非线性变换(经过FFN) """ attn_output, _ = self.self_attn(x, x, x, mask) x = self.norm1(x + self.dropout1(attn_output)) ff_output = self.ffn(x) x = self.norm2(x + self.dropout2(ff_output)) return x
6. Decoder Block 6.1 结构 每个 Decoder Block 由三个子层组成:
1 2 3 输入 → Masked Multi-Head Self-Attention → Add & Norm → Multi-Head Cross-Attention (用Encoder输出) → Add & Norm → Feed-Forward Network → Add & Norm → 输出
关键区别:
Masked Self-Attention :用掩码防止看到未来位置(训练时确保自回归性质,即预测第t个位置只能看到1~t-1的位置)
Cross-Attention :Query 来自 Decoder,Key 和 Value 来自 Encoder 输出 —— 这是 Decoder 获取源序列信息的方式
三子层的详细作用:
Masked Self-Attention :让Decoder的每个位置了解自己已经生成了什么内容
类比:写作时回顾前面已经写的句子,确保续写内容与前面连贯
为什么必须Masked?如果不屏蔽未来位置,模型就能”偷看”后面的答案,训练时不需要费力预测 —— 这就破坏了自回归训练的意义
Cross-Attention :让Decoder的每个位置获取源序列中最相关的信息
类比:翻译时回头看原文,找到当前正在翻译的词对应的源语言词
Q来自Decoder(”我需要什么信息?”),K/V来自Encoder(”源序列有什么信息?”)
这是Seq2Seq的核心:将源语言信息传递给目标语言的生成过程
FFN :对融合了源信息和目标历史信息的每个位置做非线性变换
类比:消化吸收收集到的所有信息,做出最终的表达决策
6.2 掩码机制 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 def generate_mask (seq_len ): """生成下三角掩码矩阵, 用于Decoder的Masked Self-Attention 掩码为下三角矩阵, 确保位置i只能attend到位置0~i(包括自身), 不能attend到未来位置i+1~seq_len-1. 为什么需要这个掩码? - Decoder是自回归模型: 预测第t个token时, 只能使用第0~t-1个token的信息 - 如果不加掩码, 训练时第t个位置可以直接"看到"第t+1位置的ground truth - 这会让模型偷懒: 不需要预测, 直接复制未来位置的信息即可 - 掩码强制模型真正学习预测能力: 只基于历史信息推断未来 掩码在训练和推理中都起作用: - 训练: 输入是完整的目标序列(teacher forcing), 但掩码让每个位置只能看历史 - 推理: 逐步生成, 自然只能看已生成的部分(掩码是冗余的安全保障) Args: seq_len: 序列长度 Returns: mask: 掩码矩阵, shape=[1, 1, seq_len, seq_len] 1表示可attend, 0表示屏蔽 """ mask = torch.tril(torch.ones(seq_len, seq_len)) mask = mask.unsqueeze(0 ).unsqueeze(0 ) return mask
掩码示例(seq_len=4):
位置 1 只能看到位置 1,位置 2 可以看到位置 1 和 2,以此类推。
6.3 PyTorch 实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 class DecoderLayer (nn.Module): """Decoder的一层 (一个Decoder Block) 结构: 1. Masked Self-Attention → Add&Norm (Decoder自身的历史信息) 2. Cross-Attention → Add&Norm (从Encoder获取源序列信息) 3. FFN → Add&Norm (非线性变换) 与EncoderLayer的关键区别: - 多了一个Cross-Attention子层(子层2), 这是Encoder-Decoder架构的核心 - 子层1是Masked Self-Attention而非普通Self-Attention - 每个DecoderLayer都接收Encoder的最终输出(enc_output)作为参数 Args: d_model: 模型特征维度 n_heads: 注意力头数 d_ff: FFN中间层维度 dropout: dropout概率 """ def __init__ (self, d_model, n_heads, d_ff, dropout=0.1 ): super ().__init__() self.masked_self_attn = MultiHeadAttention(d_model, n_heads, dropout) self.cross_attn = MultiHeadAttention(d_model, n_heads, dropout) self.ffn = PositionwiseFeedForward(d_model, d_ff, dropout) self.norm1 = LayerNorm(d_model) self.norm2 = LayerNorm(d_model) self.norm3 = LayerNorm(d_model) self.dropout1 = nn.Dropout(dropout) self.dropout2 = nn.Dropout(dropout) self.dropout3 = nn.Dropout(dropout) def forward (self, x, enc_output, src_mask=None , tgt_mask=None ): """前向传播 Args: x: Decoder当前层的输入, shape=[batch, tgt_len, d_model] 来自上一层DecoderLayer的输出, 或最底层来自Embedding+PE enc_output: Encoder的最终输出, shape=[batch, src_len, d_model] 所有DecoderLayer共享同一个enc_output(Encoder只跑一次) 注意: enc_output不随Decoder层数变化, 每层看到的源信息相同 src_mask: 源序列掩码, shape=[batch, 1, 1, src_len] 或可广播 用于Cross-Attention屏蔽Encoder中的padding位置 防止Decoder关注源序列中无意义的PAD token tgt_mask: 目标序列掩码, shape=[batch, 1, tgt_len, tgt_len] 下三角矩阵, 用于Masked Self-Attention屏蔽未来位置 Returns: 输出, shape=[batch, tgt_len, d_model] 每个位置融合了: 历史目标信息(Self-Attn) + 源序列信息(Cross-Attn) + 非线性变换(FFN) """ attn_output, _ = self.masked_self_attn(x, x, x, tgt_mask) x = self.norm1(x + self.dropout1(attn_output)) cross_output, _ = self.cross_attn(x, enc_output, enc_output, src_mask) x = self.norm2(x + self.dropout2(cross_output)) ff_output = self.ffn(x) x = self.norm3(x + self.dropout3(ff_output)) return x
7.1 PyTorch 完整实现 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 class Encoder (nn.Module): """Transformer Encoder: N个EncoderLayer堆叠 + 最终LayerNorm 整体流程: Embedding+PE → EncoderLayer_1 → EncoderLayer_2 → ... → EncoderLayer_N → Final LN 每个EncoderLayer的输出shape不变([batch, src_len, d_model]), 但内容逐层丰富: 底层捕捉局部关系, 高层捕捉更抽象的语义关系. 为什么需要堆叠N层? - 1层: 只能捕捉一次全局交互, 关系建模能力有限 - 6层: 每层在前一层的基础上进一步细化, 低层→高层形成层次化表示 - 类比: 从像素→边缘→纹理→部件→对象, 逐层抽象 为什么最后有一个额外的LayerNorm? - Post-LN架构中, 每个子层后都有LN, 但残差连接可能让输出偏离归一化范围 - 最终LN确保Encoder输出的数值稳定性, 方便Decoder接收 - 这不是原论文的做法(原论文没有最终LN), 但是常见的稳定训练变体 Args: n_layers: Encoder层数 (原论文为6) d_model: 模型特征维度 n_heads: 注意力头数 d_ff: FFN中间层维度 dropout: dropout概率 """ def __init__ (self, n_layers, d_model, n_heads, d_ff, dropout=0.1 ): super ().__init__() self.layers = nn.ModuleList([ EncoderLayer(d_model, n_heads, d_ff, dropout) for _ in range (n_layers) ]) self.norm = LayerNorm(d_model) def forward (self, x, mask=None ): """前向传播 Args: x: Embedding+位置编码后的输入, shape=[batch, src_len, d_model] mask: 源序列掩码, 屏蔽padding位置 传递给每个EncoderLayer的Self-Attention Returns: Encoder最终输出, shape=[batch, src_len, d_model] 每个位置的向量包含整个源序列的全局上下文信息(经过N层逐步融合) """ for layer in self.layers: x = layer(x, mask) return self.norm(x)class Decoder (nn.Module): """Transformer Decoder: N个DecoderLayer堆叠 + 最终LayerNorm 整体流程: Embedding+PE → DecoderLayer_1(enc_out) → ... → DecoderLayer_N(enc_out) → Final LN 关键特点: - 每个DecoderLayer都接收同一个enc_output(Encoder只跑一次) - 但不同层的Cross-Attention可以关注不同的源位置模式 - 低层可能关注词级对应, 高层可能关注短语/句子级对应 Args: n_layers: Decoder层数 (原论文为6) d_model: 模型特征维度 n_heads: 注意力头数 d_ff: FFN中间层维度 dropout: dropout概率 """ def __init__ (self, n_layers, d_model, n_heads, d_ff, dropout=0.1 ): super ().__init__() self.layers = nn.ModuleList([ DecoderLayer(d_model, n_heads, d_ff, dropout) for _ in range (n_layers) ]) self.norm = LayerNorm(d_model) def forward (self, x, enc_output, src_mask=None , tgt_mask=None ): """前向传播 Args: x: 目标序列的Embedding+位置编码, shape=[batch, tgt_len, d_model] enc_output: Encoder的输出, shape=[batch, src_len, d_model] 注意: 所有DecoderLayer共享同一个enc_output Encoder只执行一次, 其输出被每层DecoderLayer的Cross-Attention复用 这意味着不同Decoder层看到相同的源信息, 但各自的Cross-Attention 可以学到不同的关注模式(不同层关注源序列的不同位置) src_mask: 源序列掩码 (用于Cross-Attention屏蔽Encoder中的padding) tgt_mask: 目标序列掩码 (用于Masked Self-Attention屏蔽未来位置) Returns: Decoder最终输出, shape=[batch, tgt_len, d_model] 每个位置的向量融合了: 历史目标信息 + 源序列信息 + N层逐步精炼 """ for layer in self.layers: x = layer(x, enc_output, src_mask, tgt_mask) return self.norm(x)class Transformer (nn.Module): """完整的Transformer模型 (Encoder-Decoder架构) 数据流: 1. 源序列 → Embedding → ×√d_model → +位置编码 → Encoder → enc_output 2. 目标序列 → Embedding → ×√d_model → +位置编码 → Decoder(enc_output) → dec_output 3. dec_output → 线性投影 → vocab_size维logits → softmax → 概率分布 关键设计决策: 1. Embedding缩放(×√d_model): 使embedding量级与PE匹配(详见forward中注释) 2. 权重共享: 论文中提议Embedding和proj共享权重, 此实现未采用(简化) 3. Xavier初始化: 深层网络需要合理的初始化, 防止梯度消失/爆炸 参数量估算(d_model=512, n_heads=8, n_layers=6, d_ff=2048): - Embedding: 2 × vocab_size × 512 - 每个EncoderLayer: 4×512²(注意力) + 2×512×2048(FFN) + 4×512(LN) ≈ 2.5M - 每个DecoderLayer: 4×512²×2(两种注意力) + 2×512×2048(FFN) + 6×512(LN) ≈ 3.1M - Proj: 512 × vocab_size - 总计(不含Embedding/Proj): 6×2.5M + 6×3.1M ≈ 33.6M Args: src_vocab_size: 源语言词表大小 tgt_vocab_size: 目标语言词表大小 d_model: 模型特征维度 (默认512, 原论文值) n_heads: 注意力头数 (默认8, 原论文值) n_layers: Encoder/Decoder层数 (默认6, 原论文值) d_ff: FFN中间层维度 (默认2048, 原论文值) dropout: dropout概率 (默认0.1) max_len: 最大序列长度 (默认5000) """ def __init__ (self, src_vocab_size, tgt_vocab_size, d_model=512 , n_heads=8 , n_layers=6 , d_ff=2048 , dropout=0.1 , max_len=5000 ): super ().__init__() self.d_model = d_model self.src_embed = nn.Embedding(src_vocab_size, d_model) self.tgt_embed = nn.Embedding(tgt_vocab_size, d_model) self.src_pe = PositionalEncoding(d_model, max_len, dropout) self.tgt_pe = PositionalEncoding(d_model, max_len, dropout) self.encoder = Encoder(n_layers, d_model, n_heads, d_ff, dropout) self.decoder = Decoder(n_layers, d_model, n_heads, d_ff, dropout) self.proj = nn.Linear(d_model, tgt_vocab_size) self._init_weights() def _init_weights (self ): """Xavier均匀初始化 对所有维度>1的参数(即权重矩阵, 不含偏置和LayerNorm参数) 使用Xavier均匀分布初始化, 使前向传播和反向传播的方差保持一致. Xavier初始化的数学原理: - 目标: 让每层的输出方差 ≈ 输入方差, 避免逐层放大或缩小 - 线性层 y = Wx, 输入方差Var[x], 输出方差Var[y] = n_in × Var[W] × Var[x] - 要求 Var[y] = Var[x], 则 Var[W] = 1/n_in - 同时考虑反向传播: Var[∂L/∂x] = n_out × Var[W] × Var[∂L/∂y] - 折中: Var[W] = 2/(n_in + n_out) - Xavier均匀分布: W ~ Uniform(-a, a), a = √(6/(n_in+n_out)) 为什么只初始化dim>1的参数? - dim=1的是偏置向量(b)和LayerNorm参数(gamma,beta) - 偏置: 初始化为0是标准做法, Xavier初始化偏置没有意义(只有1维) - gamma/beta: 已经在LayerNorm.__init__中初始化为1和0, 不应覆盖 """ for p in self.parameters(): if p.dim() > 1 : nn.init.xavier_uniform_(p) def forward (self, src, tgt, src_mask=None , tgt_mask=None ): """前向传播 完整数据流: src → src_embed × √d_model → src_pe → Encoder → enc_output tgt → tgt_embed × √d_model → tgt_pe → Decoder(enc_output) → dec_output → proj → logits Args: src: 源序列token ID, shape=[batch, src_len] 例如翻译任务中: 源语言句子"我 爱 北京"的token ID序列 tgt: 目标序列token ID (shifted right), shape=[batch, tgt_len] 训练时传入完整目标序列(去掉最后一个token) 例如: [BOS, I, love, Beijing] (不含EOS) "shifted right": 相比原始序列右移一位, 让Decoder从BOS开始预测 src_mask: 源序列掩码, shape=[batch, 1, 1, src_len] 屏蔽源序列中的padding位置, 防止Encoder/Decoder关注PAD tgt_mask: 目标序列掩码, shape=[batch, 1, tgt_len, tgt_len] 下三角矩阵, 防止Decoder看到未来位置 Returns: output: logits, shape=[batch, tgt_len, tgt_vocab_size] 每个位置对词表中每个词的原始分数 取softmax得到概率分布, 取argmax得到预测token ID """ src_embedded = self.src_embed(src) * math.sqrt(self.d_model) tgt_embedded = self.tgt_embed(tgt) * math.sqrt(self.d_model) src_embedded = self.src_pe(src_embedded) tgt_embedded = self.tgt_pe(tgt_embedded) enc_output = self.encoder(src_embedded, src_mask) dec_output = self.decoder(tgt_embedded, enc_output, src_mask, tgt_mask) output = self.proj(dec_output) return output
7.2 论文中的默认参数
参数
值
说明
$d_{\text{model}}$
512
模型特征维度
$n_{\text{heads}}$
8
注意力头数
$n_{\text{layers}}$
6 (Encoder) / 6 (Decoder)
堆叠层数
$d_{\text{ff}}$
2048
FFN中间层维度 (4×d_model)
$d_k = d_v$
64 (= 512 / 8)
每个头的Key/Value维度
Dropout
0.1
正则化概率
8. 训练示例 8.1 简单的 Seq2Seq 训练流程 Teacher Forcing 策略详解:
训练时使用 Teacher Forcing:每步的Decoder输入不是模型自己上一步的预测结果,而是真实的ground truth token。这是序列模型训练的标准做法。
为什么用 Teacher Forcing 而不用模型自己的预测?
训练稳定 :如果模型某一步预测错误,后续步骤的输入就会偏离 → 错误累积 → 梯度爆炸
收敛更快 :每步的输入都是”正确答案”,模型只需学习单步预测,不需要处理错误输入
**暴露偏差(Exposure Bias)**:训练时只见过正确输入,推理时却要面对自己的错误预测 → 训练和推理分布不一致。解决方案:Scheduled Sampling(逐步从模型预测中采样替代ground truth)
Shifted Right 的含义:
1 2 3 4 5 6 7 8 9 目标序列: [BOS, I, love, Beijing, EOS] ← 完整序列(含BOS和EOS) tgt_input: [BOS, I, love, Beijing] ← 去掉最后一个(EOS), 作为Decoder输入tgt_output: [I, love, Beijing, EOS] ← 去掉第一个(BOS), 作为训练标签 对应关系: Decoder看到 BOS → 应该预测 I Decoder看到 [BOS,I] → 应该预测 love Decoder看到 [BOS,I,love] → 应该预测 Beijing Decoder看到 [BOS,I,love,Beijing] → 应该预测 EOS
“Shifted Right”就是将目标序列右移一位(去掉EOS),让Decoder从BOS开始逐步预测。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 def train_step (model, src, tgt, criterion, optimizer, device ): """单步训练函数 训练一个batch的完整流程: 1. 数据准备: shifted right + 生成mask 2. 前向传播: src→Encoder, tgt_input→Decoder→logits 3. 计算损失: logits vs tgt_output (CrossEntropy) 4. 反向传播: loss.backward() 5. 参数更新: optimizer.step() Args: model: Transformer模型 src: 源序列batch, shape=[batch, src_len], token ID 例如: [2, 15, 47, 8, 0, 0] ← "我 爱 北京 PAD PAD" tgt: 目标序列batch, shape=[batch, tgt_len], token ID 包含<BOS>开头和<EOS>结尾, 如: [BOS, w1, w2, ..., wn, EOS] 注意: tgt包含完整序列(BOS到EOS), train_step内部会做shifted right处理 criterion: 损失函数 (CrossEntropyLoss) optimizer: 优化器 device: 计算设备 (cpu/cuda) Returns: loss.item(): 当前步的损失值(浮点数) """ model.train() src = src.to(device) tgt = tgt.to(device) tgt_input = tgt[:, :-1 ] tgt_output = tgt[:, 1 :] seq_len = tgt_input.size(1 ) tgt_mask = generate_mask(seq_len).to(device) output = model(src, tgt_input, src_mask=None , tgt_mask=tgt_mask) loss = criterion(output.view(-1 , output.size(-1 )), tgt_output.view(-1 )) optimizer.zero_grad() loss.backward() optimizer.step() return loss.item() model = Transformer(src_vocab_size=1000 , tgt_vocab_size=1000 ).to(device) criterion = nn.CrossEntropyLoss(ignore_index=0 ) optimizer = torch.optim.Adam(model.parameters(), lr=1e-4 , betas=(0.9 , 0.98 ), eps=1e-9 )for epoch in range (num_epochs): for batch_src, batch_tgt in train_dataloader: loss = train_step(model, batch_src, batch_tgt, criterion, optimizer, device) print (f"Epoch {epoch} : Loss = {loss:.4 f} " )
8.2 学习率调度:Warmup + Decay 原论文使用了一种特殊的学习率调度策略:先线性增长(warmup),再按步数平方根衰减。
$$lr = d_{\text{model}}^{-0.5} \cdot \min(step^{-0.5}, step \cdot warmup_steps^{-1.5})$$
为什么需要Warmup?
训练初期不稳定 :模型参数随机初始化,梯度方向混乱,大学习率会导致训练发散
Warmup阶段 :学习率从0线性增长到峰值,给模型”适应期”,逐步建立稳定的梯度方向
Decay阶段 :学习率逐步衰减,精细调整参数,防止后期震荡
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 class WarmupDecayScheduler : """学习率调度器: Warmup(线性增长) + Decay(平方根衰减) 公式: lr = d_model^{-0.5} * min(step^{-0.5}, step * warmup_steps^{-1.5}) 效果: - step < warmup_steps: lr = d_model^{-0.5} * step * warmup_steps^{-1.5} (线性增长) - step >= warmup_steps: lr = d_model^{-0.5} * step^{-0.5} (平方根衰减) 例如 d_model=512, warmup_steps=4000: - step=0: lr=0 (开始时学习率为0) - step=2000: lr≈5e-4 (warmup中途, 线性增长) - step=4000: lr≈1e-3 (峰值, warmup结束时) - step=8000: lr≈7e-4 (衰减阶段) - step=80000: lr≈2.5e-4 (继续衰减) Args: d_model: 模型维度, 用于学习率缩放 warmup_steps: warmup步数, 学习率线性增长的阶段长度 """ def __init__ (self, d_model, warmup_steps=4000 ): self.d_model = d_model self.warmup_steps = warmup_steps def get_lr (self, step ): """根据当前步数计算学习率 Args: step: 当前训练步数(全局, 不是epoch内) Returns: 当前学习率 """ factor = self.d_model ** (-0.5 ) step_factor = min (step ** (-0.5 ), step * (self.warmup_steps ** (-1.5 ))) return factor * step_factor scheduler = WarmupDecayScheduler(d_model=512 , warmup_steps=4000 )for step in range (total_steps): lr = scheduler.get_lr(step) for param_group in optimizer.param_groups: param_group['lr' ] = lr
8.3 推理(贪心解码) 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 def greedy_decode (model, src, max_len, start_symbol, eos_symbol, device ): """贪心解码: 逐步生成目标序列, 每步选概率最大的token 与训练不同, 推理时没有目标序列作为输入, 需要模型自己逐步生成: 每生成一个token, 就追加到已生成序列中, 作为下一步的输入. 贪心解码的优缺点: - 优点: 简单、快速、确定性强(同一输入永远得到同一输出) - 缺点: 每步只选局部最优(概率最大的), 不保证全局最优 - 例如: 第1步选了概率0.3的A(而非0.25的B), 但B可能引出更好的后续序列 - 改进: Beam Search(见8.5节)同时维护多个候选序列, 更可能找到全局最优 Args: model: 训练好的Transformer模型 src: 源序列, shape=[1, src_len] (单条样本) 注意: 贪心解码通常逐条处理, batch_size=1 如果要批量推理, 需要更复杂的实现(不同样本长度不同) max_len: 最大生成长度, 防止无限循环 如果模型始终不生成EOS, 最多生成max_len个token后强制停止 start_symbol: 起始符<BOS>的token ID eos_symbol: 结束符<EOS>的token ID, 遇到则停止生成 device: 计算设备 Returns: ys: 生成的目标序列, shape=[1, gen_len], 包含BOS和EOS """ model.eval () src = src.to(device) with torch.no_grad(): src_embedded = model.src_embed(src) * math.sqrt(model.d_model) src_embedded = model.src_pe(src_embedded) enc_output = model.encoder(src_embedded) ys = torch.ones(1 , 1 ).fill_(start_symbol).long().to(device) for i in range (max_len - 1 ): tgt_mask = generate_mask(ys.size(1 )).to(device) tgt_embedded = model.tgt_embed(ys) * math.sqrt(model.d_model) tgt_embedded = model.tgt_pe(tgt_embedded) out = model.decoder(tgt_embedded, enc_output, tgt_mask=tgt_mask) prob = model.proj(out[:, -1 ]) _, next_word = torch.max (prob, dim=1 ) next_word = next_word.item() ys = torch.cat([ ys, torch.ones(1 , 1 ).fill_(next_word).long().to(device) ], dim=1 ) if next_word == eos_symbol: break return ys
8.3.1 自回归生成:为什么下一个token基于已生成的所有token? 自回归(Autoregressive) 是Decoder推理的核心特性:每一步只生成一个token,且这个token的概率依赖前面所有已生成的token。
$$P(\text{整个序列}) = P(t_1) \times P(t_2|t_1) \times P(t_3|t_1,t_2) \times \cdots \times P(t_n|t_1,\ldots,t_{n-1})$$
即:序列的联合概率 = 各步条件概率的乘积,每步的条件概率只依赖历史,不依赖未来。
为什么必须自回归?
Decoder的Masked Self-Attention用下三角掩码确保:位置 $t$ 只能看到位置 $0 \sim t-1$ 的信息,看不到 $t+1$ 及之后的”未来”token。这是训练时就建立的约束:
1 2 3 4 训练时的目标序列: [BOS, I, love, Beijing, EOS] ↓ ↓ ↓ ↓ ↓ Decoder每步看到: BOS BOS,I BOS,I,love ... 只看历史,不看未来 预测目标: I love Beijing EOS 基于历史预测下一步
这个训练约束延续到推理:模型只学会了”基于历史预测下一个token”的能力,推理时自然也必须逐步生成——每步把新token追加到历史中,作为下一步的输入。
非自回归的替代方案(简述):
非自回归模型(NAT)试图一步直接输出所有token,但面临两个难题:
输出长度未知 :推理前不知道要生成多少个token(需要额外的长度预测器)
多模态问题 :不同位置的token可能互相矛盾(第2步选了”love”,第3步却选了”hate”,因为两步独立决策) Transformer原论文选择了自回归方案,保证生成质量。
8.3.2 KV Cache:避免重复计算 朴素推理的浪费:
看贪心解码代码,每一步都对整个已生成序列 重新跑一遍Decoder:
第3步时,”BOS”和”I”的注意力计算在第1步、第2步已经做过了一遍,第3步又重新计算——这是巨大的浪费。
KV Cache的原理:
注意力计算 $\text{output}_i = \sum_j \alpha(Q_i, K_j) \cdot V_j$ 中,新token的Q只需要与所有历史位置的K和V 做交互。历史位置的K/V不随新token的加入而改变(因为Masked Self-Attention下,位置j只看位置0~j-1,新token在j之后,不影响j的计算结果)。
因此,可以缓存每一步计算的K和V ,后续步骤只需:
计算新token的Q、K、V
把新token的K、V追加到缓存中
新token的Q与缓存中所有K做注意力计算,得到新token的输出
不用KV Cache vs 用KV Cache的计算量对比:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 生成n个token: 不用KV Cache: 第1 步: 1 个token的注意力 → O(1² × d ) 第2 步: 2 个token的注意力 → O(2² × d ) 第3 步: 3 个token的注意力 → O(3² × d ) ... 第n步: n个token的注意力 → O(n ² × d ) 总计: O(n ³ × d ) ← 每步重算所有历史, 计算量爆炸式增长 用KV Cache: 第1 步: 计算1 个token的Q/K/V → O(d ²) , 1 个Q与1 个K → O(1 × d ) 第2 步: 计算1 个新token的Q/K/V → O(d ²) , 1 个Q与2 个K → O(2 × d ) 第3 步: 计算1 个新token的Q/K/V → O(d ²) , 1 个Q与3 个K → O(3 × d ) ... 第n步: 计算1 个新token的Q/K/V → O(d ²) , 1 个Q与n个K → O(n × d ) 总计: O(n ² × d ) ← 只计算新token与历史的交互, 线性增长!
KV Cache对注意力矩阵的影响:
1 2 3 4 5 6 7 8 9 10 11 12 不用KV Cache (每步重新计算完整注意力矩阵): 第3 步的注意力矩阵: [[α(BOS→BOS), α(BOS→I), α(BOS→love)] ← BOS行的第2、3列重新计算了 [α(I→BOS), α(I→I), α(I→love)] ← I行的第3 列重新计算了 [α(love→BOS), α(love→I), α(love→love)]] ← 新行, 必须计算 但前两行的前两列在第1 、2 步已经算过了, 重复计算! 用KV Cache (只计算新行): 第3 步只计算: [α(love→BOS), α(love→I), α(love→love)] 因为Masked Self-Attention保证: 位置0 (BOS)和位置1(I)的输出不受位置2(love)的影响 所以不需要重算前两行
为什么只缓存K和V,不缓存Q?
因为Q只在”当前步骤”被使用,历史Q永远不会再被需要。每一步只需要计算当前新token的输出 。要算位置t的输出,只需要:
$Q_t$:当前新token的查询向量 —— “我想找谁”
所有历史 $K_{0 \sim t-1}$ 和 $V_{0 \sim t-1}$:历史token提供的”标签”和”内容”
历史位置的Q($Q_0, Q_1, …, Q_{t-1}$)完全不需要 ,因为:
1 2 3 第1步: 算output_0 — 需要Q_0与K_0,V_0交互 → 输出output_0,Q_0任务完成,不再需要 第2步: 算output_1 — 需要Q_1与K_0,K_1,V_0,V_1交互 → Q_0完全没用(output_0已经算过了) 第3步: 算output_2 — 需要Q_2与K_0,K_1,K_2,V_0,V_1,V_2交互 → Q_0,Q_1都没用
每一步只算一个位置 的输出,只需要一个Q 。而K和V是”被查阅的资源”——未来的每一步都需要它们,所以必须缓存。
类比:
Q像”顾客”:每个顾客只来一次,点完菜就走,不需要记住这个顾客
K像”菜单”:每个餐厅的菜单要长期保留,因为未来所有新顾客都要看
V像”菜品”:菜品要长期供应,因为未来所有新顾客都要吃
所以只缓存”菜单”(K)和”菜品”(V),不缓存”顾客”(Q)。
KV Cache的内存代价:
KV Cache虽然节省了计算,但需要存储所有历史token的K和V向量:
1 2 3 4 5 6 7 8 9 10 每个token的KV缓存大小 = 2 × n_layers × n_heads × d_k × sizeof(float16) = 2 × 6 × 8 × 64 × 2 bytes (float16) = 12,288 bytes ≈ 12KB per token 生成1000个token: 12KB × 1000 = 12MB (单条样本) 批量推理batch_size=32: 12MB × 32 = 384MB 大模型(Llama-2 70B, d_model=8192, n_layers=80, n_heads=64): 每token KV缓存 ≈ 2 × 80 × 64 × 128 × 2 = 2.5MB 生成2048个token: 2.5MB × 2048 ≈ 5GB (单条样本!)
这就是为什么大模型推理时,内存瓶颈往往不是模型权重,而是KV Cache 。优化KV Cache内存是当前大模型推理加速的核心研究方向(如MQA/GQA多头共享、PagedAttention等)。
PyTorch实现KV Cache的简化示例:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 def greedy_decode_with_kv_cache (model, src, max_len, start_symbol, eos_symbol, device ): """带KV Cache的贪心解码 与朴素实现的区别: - 朴素: 每步把整个ys序列送入Decoder, 重新计算所有token的Q/K/V和注意力 - KV Cache: 只把新token送入Decoder, K/V从缓存中读取 """ model.eval () src = src.to(device) with torch.no_grad(): src_embedded = model.src_embed(src) * math.sqrt(model.d_model) src_embedded = model.src_pe(src_embedded) enc_output = model.encoder(src_embedded) self_attn_k_cache = [] self_attn_v_cache = [] cross_attn_k_cache = [] cross_attn_v_cache = [] for layer in model.decoder.layers: cross_k = layer.cross_attn.W_K(enc_output) cross_v = layer.cross_attn.W_V(enc_output) cross_k = cross_k.view(1 , enc_output.size(1 ), model.decoder.layers[0 ].cross_attn.n_heads, model.decoder.layers[0 ].cross_attn.d_k).transpose(1 , 2 ) cross_v = cross_v.view(1 , enc_output.size(1 ), model.decoder.layers[0 ].cross_attn.n_heads, model.decoder.layers[0 ].cross_attn.d_k).transpose(1 , 2 ) cross_attn_k_cache.append(cross_k) cross_attn_v_cache.append(cross_v) ys = torch.ones(1 , 1 ).fill_(start_symbol).long().to(device) for step in range (max_len - 1 ): new_token = ys[:, -1 :] new_embedded = model.tgt_embed(new_token) * math.sqrt(model.d_model) new_embedded = model.tgt_pe(new_embedded) x = new_embedded for layer_idx, layer in enumerate (model.decoder.layers): new_q = layer.masked_self_attn.W_Q(x).view(1 , 1 , layer.masked_self_attn.n_heads, layer.masked_self_attn.d_k).transpose(1 , 2 ) new_k = layer.masked_self_attn.W_K(x).view(1 , 1 , layer.masked_self_attn.n_heads, layer.masked_self_attn.d_k).transpose(1 , 2 ) new_v = layer.masked_self_attn.W_V(x).view(1 , 1 , layer.masked_self_attn.n_heads, layer.masked_self_attn.d_k).transpose(1 , 2 ) if step == 0 : self_attn_k_cache.append(new_k) self_attn_v_cache.append(new_v) else : self_attn_k_cache[layer_idx] = torch.cat([self_attn_k_cache[layer_idx], new_k], dim=2 ) self_attn_v_cache[layer_idx] = torch.cat([self_attn_v_cache[layer_idx], new_v], dim=2 ) scores = torch.matmul(new_q, self_attn_k_cache[layer_idx].transpose(-2 , -1 )) scores = scores / math.sqrt(layer.masked_self_attn.d_k) attn_weights = F.softmax(scores, dim=-1 ) attn_weights = layer.dropout1(attn_weights) attn_output = torch.matmul(attn_weights, self_attn_v_cache[layer_idx]) attn_output = attn_output.transpose(1 , 2 ).contiguous().view(1 , 1 , -1 ) attn_output = layer.masked_self_attn.W_O(attn_output) x = layer.norm1(x + attn_output) cross_q = layer.cross_attn.W_Q(x).view(1 , 1 , layer.cross_attn.n_heads, layer.cross_attn.d_k).transpose(1 , 2 ) scores = torch.matmul(cross_q, cross_attn_k_cache[layer_idx].transpose(-2 , -1 )) scores = scores / math.sqrt(layer.cross_attn.d_k) attn_weights = F.softmax(scores, dim=-1 ) cross_output = torch.matmul(attn_weights, cross_attn_v_cache[layer_idx]) cross_output = cross_output.transpose(1 , 2 ).contiguous().view(1 , 1 , -1 ) cross_output = layer.cross_attn.W_O(cross_output) x = layer.norm2(x + cross_output) ff_output = layer.ffn(x) x = layer.norm3(x + ff_output) prob = model.proj(x) _, next_word = torch.max (prob[0 , 0 , :], dim=0 ) next_word = next_word.item() ys = torch.cat([ys, torch.ones(1 , 1 ).fill_(next_word).long().to(device)], dim=1 ) if next_word == eos_symbol: break return ys
8.5 Beam Search 解码 Beam Search 是贪心解码的改进版:同时维护 k 个最优候选序列,每步从所有候选的扩展中选概率最大的 k 个继续。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 def beam_search_decode (model, src, max_len, start_symbol, eos_symbol, beam_size, device ): """Beam Search解码: 维护beam_size个最优候选序列 与贪心解码的区别: - 贪心: 每步只保留1个最优token → 只有一条候选路径 - Beam: 每步保留beam_size个最优候选 → 同时探索多条路径 - 效果: Beam Search通常比贪心解码产出更高质量的翻译 工作原理(beam_size=3为例): 第1步: 从BOS出发, 找概率最大的3个token → 3条候选序列 第2步: 每条候选扩展1个token → 3×vocab_size个扩展 → 选概率最大的3个 第3步: 同上, 逐步扩展 最终: 选3条候选中概率最大的那条作为输出 概率计算: - 每条候选序列的概率 = 所有token概率的对数求和 - log P(BOS, w1, w2) = log P(w1|BOS) + log P(w2|BOS,w1) - 用对数避免概率乘积的数值下溢(多个小概率相乘→接近0→浮点精度丢失) Args: model: 训练好的Transformer模型 src: 源序列, shape=[1, src_len] max_len: 最大生成长度 start_symbol: <BOS>的token ID eos_symbol: <EOS>的token ID beam_size: beam宽度, 同时维护的候选序列数(常用3-5) device: 计算设备 Returns: best_sequence: 概率最大的生成序列, shape=[1, gen_len] """ model.eval () src = src.to(device) with torch.no_grad(): src_embedded = model.src_embed(src) * math.sqrt(model.d_model) src_embedded = model.src_pe(src_embedded) enc_output = model.encoder(src_embedded) beams = [(0 , torch.ones(1 , 1 ).fill_(start_symbol).long().to(device))] completed = [] for step in range (max_len - 1 ): all_candidates = [] for log_prob, seq in beams: tgt_mask = generate_mask(seq.size(1 )).to(device) tgt_embedded = model.tgt_embed(seq) * math.sqrt(model.d_model) tgt_embedded = model.tgt_pe(tgt_embedded) out = model.decoder(tgt_embedded, enc_output, tgt_mask=tgt_mask) logits = model.proj(out[:, -1 ]) log_probs_next = F.log_softmax(logits, dim=-1 ) topk_log_probs, topk_indices = log_probs_next.topk(beam_size, dim=-1 ) for k in range (beam_size): new_log_prob = log_prob + topk_log_probs[0 , k].item() new_token = topk_indices[0 , k].item() new_seq = torch.cat([ seq, torch.ones(1 , 1 ).fill_(new_token).long().to(device) ], dim=1 ) if new_token == eos_symbol: length_penalty = ((5 + new_seq.size(1 )) / 5 ) ** 0.6 completed.append((new_log_prob / length_penalty, new_seq)) else : all_candidates.append((new_log_prob, new_seq)) if not all_candidates: break all_candidates.sort(key=lambda x: x[0 ], reverse=True ) beams = all_candidates[:beam_size] if not completed: completed = beams completed.sort(key=lambda x: x[0 ], reverse=True ) best_score, best_sequence = completed[0 ] return best_sequence
9. 关键知识点总结
问题
解决方案
原理
RNN 无法并行
自注意力机制并行计算所有位置
所有位置的Q/K/V同时计算, matmul一次完成
RNN 长距离依赖衰减
自注意力直接建模任意距离的关系
任意两位置之间只需1步attention, 不像RNN需n步
无位置信息
正弦余弦位置编码注入位置
PE(pos+k)可由PE(pos)线性变换得到
深层网络梯度问题
残差连接 + LayerNorm
梯度至少有1保底(残差), LN稳定分布
过拟合
Dropout + 多头注意力分散计算
随机置零防止过度依赖, 多头分散注意力
9.2 Attention 计算过程图解 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 输入: X [batch, seq_len, d_model] │ ├──→ W_Q → Q [batch, seq_len, d_model] → split → [batch, n_heads, seq_len, d_k] ├──→ W_K → K [batch, seq_len, d_model] → split → [batch, n_heads, seq_len, d_k] └──→ W_V → V [batch, seq_len, d_model] → split → [batch, n_heads, seq_len, d_k] │ ↓ 计算注意力 (每个头独立) Q @ K^T = scores [batch, n_heads, seq_len, seq_len] ← 注意力分数矩阵 scores / √d_k = scaled_scores ← 缩放防止梯度消失 + mask → masked_scores ← 屏蔽不可见位置 softmax → attn_weights [batch, n_heads, seq_len, seq_len] ← 概率分布(每行和=1 ) attn_weights @ V = head_output [batch, n_heads, seq_len, d_k] ← 加权求和 │ ↓ 合并多头 concat → [batch, seq_len, n_heads × d_k] = [batch, seq_len, d_model] W_O → output [batch, seq_len, d_model] ← 最终线性投影
模型
年份
架构
特点
典型应用
Transformer
2017
Encoder-Decoder
原始Seq2Seq
机器翻译
BERT
2018
Encoder-only
双向理解, MLM预训练
文本分类、NER
GPT-1/2/3
2018-2020
Decoder-only
单向生成, LM预训练
文本生成
T5
2019
Encoder-Decoder
文本到文本统一框架
多任务NLP
ViT
2020
Encoder-only
图像切patch当token
图像分类
GPT-4
2023
Decoder-only
多模态大模型
通用AI助手
Encoder-only vs Decoder-only vs Encoder-Decoder:
架构
代表模型
适用任务
核心特点
Encoder-only
BERT
理解类(分类、抽取)
双向注意力, 看到完整输入
Decoder-only
GPT
生成类(对话、续写)
单向注意力, 自回归生成
Encoder-Decoder
T5
翻译、摘要
先理解(Encoder)再生成(Decoder)
9.4 计算复杂度分析
操作
复杂度
说明
何时成为瓶颈
Self-Attention
$O(n^2 \cdot d)$
n为序列长度, d为维度
n很大时(长文本)
FFN
$O(n \cdot d^2)$
逐位置独立, 与序列长度线性关系
d很大时(大模型)
RNN单步
$O(n \cdot d^2)$
序列长时优于Attention, 但无法并行
总是(无法并行)
Conv(局部k)
$O(n \cdot k \cdot d^2)$
k为卷积核宽度, 有限范围交互
需要多层堆叠捕捉远距离
自注意力的 $O(n^2)$ 瓶颈:
当序列长度n很大时(如n=8192), 注意力分数矩阵的size = n² = 67M, 这是Transformer处理长文本的主要瓶颈。解决方案:
方法
复杂度
思路
Sparse Attention
$O(n \cdot \sqrt{n})$
只计算部分位置的注意力(局部+全局)
Linformer
$O(n)$
用低秩投影压缩K/V的序列维度
Flash Attention
$O(n^2)$ 但更快
利用GPU内存层次优化, 减少HBM读写
KV Cache
推理优化
缓存已计算的K/V, 每步只算新token
局限
描述
改进方向
长序列瓶颈
$O(n^2)$注意力计算
Sparse/Linear Attention
位置编码局限
固定PE不适应动态长度
RoPE(旋转位置编码)、ALiBi
单向信息流
Decoder只能看历史
BERT的双向理解
无局部性
注意力无空间/位置约束
加局部窗口约束
计算量大
大模型需要大量GPU
量化、蒸馏、剪枝
参考资料