Transformer 注意力机制入门

用直觉解释、数值例子和最小代码示例,真正理解 Q/K/V、缩放点积注意力与 Multi-Head 的工作方式。

难度

入门

阅读时长

约 90 分钟

更新日期

2026/03/23

主题

Transformer / Attention / 基础原理

先修知识

线性代数基础向量点积矩阵乘法

学习目标

读完这篇教程,你应该能回答四个问题:

  1. 为什么 Transformer 不再依赖 RNN,也能处理序列。
  2. QKV 到底分别在表示什么。
  3. 为什么注意力公式里要除以 sqrt(d_k)
  4. 为什么一个头不够,Multi-Head 到底带来了什么。

如果你是第一次接触这个主题,建议一边看文中的数值例子,一边打开站内的 Attention Sandbox 做参数实验,会更容易形成直觉。

为什么会需要注意力机制

在传统 RNN 里,信息需要沿着时间步一层层传递。句子越长,早期信息越难稳定地传到后面。Transformer 的想法更直接:

不要让信息只能沿着时间顺序慢慢走,而是让当前 token 直接去“看”整个上下文里谁和自己最相关。

这就是注意力机制的核心直觉。可以把它想成一个会议场景:

  • Query 表示“我现在想找什么信息”。
  • Key 表示“每个 token 手里举着什么标签,告诉别人自己适合被谁关注”。
  • Value 表示“如果你关注我,你真正会取走的信息内容”。

于是,注意力就是“拿着我的查询需求,去上下文里给每个人打分,再按分数把信息汇总回来”。

从一句话理解 Q、K、V

如果只看公式,很多人会把 Q/K/V 背成三个抽象字母。更有效的方法是把它们理解成三种不同视角下的表示:

  • Q:当前 token 正在提什么问题。
  • K:每个 token 能不能回答这个问题。
  • V:一旦被关注,这个 token 应该贡献什么信息。

同一个输入向量为什么要投影出三份表示?因为“被谁匹配上”和“输出什么内容”不是一回事。比如代词“它”在查找先行词时,关注规则和最终要取回的信息就不完全相同。

单头注意力的最小数值例子

先看一个简化到二维向量的例子。假设当前 query 是:

q = [1, 0]

上下文里有三个 key:

  • k1 = [1, 0]
  • k2 = [0, 1]
  • k3 = [1, 1]

对应的 value 分别是:

  • v1 = [10, 0]
  • v2 = [0, 10]
  • v3 = [6, 6]

第一步:算相似度分数

用点积计算 q 和每个 k 的相关性:

  • q · k1 = 1
  • q · k2 = 0
  • q · k3 = 1

所以原始分数是 [1, 0, 1]

第二步:做缩放

这里 d_k = 2,所以除以 sqrt(2),得到近似分数:

[0.71, 0, 0.71]

第三步:过 softmax

softmax 后可以近似看成:

[0.40, 0.20, 0.40]

意思是当前 token 大约把 40% 的注意力放给第 1 个位置,20% 给第 2 个位置,40% 给第 3 个位置。

第四步:加权汇总 value

输出向量为:

0.40 * v1 + 0.20 * v2 + 0.40 * v3

结果约等于:

[6.4, 4.4]

这就是一次注意力输出。你可以看到,模型不是简单复制某个位置的 value,而是把多个位置的信息按权重混合成一个新的表示。

为什么要除以 sqrt(d_k)

这是初学者最容易忽略、但非常重要的细节。

当向量维度变大时,点积的数值范围通常也会增大。如果不做缩放,softmax 输入就可能非常极端,导致:

  • 某个位置权重几乎变成 1,其他位置接近 0。
  • 梯度变得很小,训练不稳定。

除以 sqrt(d_k) 的作用,就是把分数拉回更适合 softmax 的区间。你可以把它理解成一种“温度校准”,让不同维度规模下的注意力分布都保持相对稳定。

把单头扩展成矩阵形式

真实模型不会一个 token 一个 token 地手算,而是把整个序列打包成矩阵:

  • Q = XW_Q
  • K = XW_K
  • V = XW_V

然后统一做:

softmax(QK^T / sqrt(d_k))V

这里 QK^T 会得到一个“每个位置对所有位置”的打分矩阵。矩阵中的每一行,表示一个 query 对整段上下文的注意力分布。

这就是 Transformer 能高效并行的关键:所有 token 之间的关系,可以在一次大矩阵运算中算出来,而不是像 RNN 那样逐步递推。

Multi-Head 到底带来什么

如果只有一个注意力头,模型所有关系都要在同一个表示空间里解决。Multi-Head 的想法是:

  1. 先把输入投影到多个子空间。
  2. 每个子空间各自做一次注意力。
  3. 再把多个头的结果拼接起来,统一映射回输出空间。

这样做的好处不是“头越多越神奇”,而是让模型有机会同时学习不同类型的关联,例如:

  • 一个头偏向局部语法关系。
  • 一个头偏向长距离指代。
  • 一个头偏向特殊分隔符或结构 token。

从解释性研究的角度看,不是每个头都一定具有清晰语义,但多头确实给了模型更多并行建模视角。后续很多模型分析工作,都会去研究“哪些头在做什么”。

自注意力和交叉注意力的区别

Transformer 中常见两种注意力:

  • 自注意力:Q/K/V 都来自同一段序列。比如 encoder 里,输入句子内部彼此关注。
  • 交叉注意力:Q 来自 decoder 当前状态,K/V 来自 encoder 输出。也就是“生成端去读取输入端的编码结果”。

理解这个区别很重要,因为它能帮你把 Transformer 看成一个统一积木:不同模块只是 Q/K/V 的来源不同,底层计算形式并没有变。

一个最小 PyTorch 实现

下面这段代码省略了 batch 内复杂细节,但保留了注意力的核心过程:

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))

    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))

    weights = torch.softmax(scores, dim=-1)
    output = weights @ v
    return output, weights

batch = 2
tokens = 4
d_model = 8

x = torch.randn(batch, tokens, d_model)
Wq = torch.randn(d_model, d_model)
Wk = torch.randn(d_model, d_model)
Wv = torch.randn(d_model, d_model)

q = x @ Wq
k = x @ Wk
v = x @ Wv

output, attn = scaled_dot_product_attention(q, k, v)
print(output.shape)  # [2, 4, 8]
print(attn.shape)    # [2, 4, 4]

如果你已经能读懂这段代码,说明你对单头注意力已经建立了基本理解。接下来要补的是“多个头如何拆分维度并拼接”,以及“为什么 decoder 里需要 causal mask”。

学习时最容易混淆的几点

1. 权重高不等于 value 大

注意力权重来自 QK 的匹配,最终输出内容来自 V。所以“谁被关注”与“输出什么信息”是两层逻辑。

2. 注意力不是检索数据库

它更像一种可微分的软检索。模型不会精确地找到唯一答案,而是把多个相关位置混合成新的表示。

3. 多头不是多次重复计算

每个头都在不同线性投影后的子空间中工作,因此看到的是不同的特征切面,而不是简单复制。

4. 注意力强不等于因果解释

可视化出来的注意力热图很有帮助,但不能把它直接当成模型“真正思考过程”的完整解释。它更适合帮助我们建立直觉,而不是替代理论分析。

一个建议的学习顺序

如果你要从这里继续往下学,可以按下面顺序推进:

  1. 先确保自己能手算一个单头注意力例子。
  2. 再理解 Multi-Head 的维度拆分与拼接。
  3. 接着学习位置编码,理解模型如何感知顺序。
  4. 最后回到完整 Transformer Block,看残差、LayerNorm 和 FFN 如何配合。

对应到站内内容,可以继续看:

练习题

  1. 如果把上文数值例子里的 q 改成 [0, 1],注意力权重会发生什么变化?
  2. 为什么当 d_k 很大时,不做缩放更容易让 softmax 饱和?
  3. 如果两个头学到完全一样的模式,多头机制还带来额外收益吗?
  4. decoder 中为什么必须使用 mask,而 encoder 中通常不需要?

延伸阅读

配套模拟器

先看原理,再到模拟器里调参验证,学习效果更稳定。

相关阅读

从相近主题继续深入,建立连续学习链路。