LSTM
LSTM 构造
1. LSTM 的基本结构
- 每个时间步 t:
- 输入向量:
- 隐藏状态:
- 细胞状态(Cell State):
- 输入向量:
- 与普通 RNN 相比,LSTM 额外维护了细胞状态
,可长期存储信息。
2. 三个门控制机制
LSTM 通过门控机制控制信息的流动,包括:
门名称 | 作用 | 激活函数 | 数学表达式 |
---|---|---|---|
遗忘门 |
控制保留多少旧记忆 |
sigmoid | |
输入门 |
控制写入多少新信息 | sigmoid | |
候选值 |
新信息候选值 | tanh | |
输出门 |
控制是否输出当前记忆 | sigmoid |
并行化计算
3. 细胞状态更新
- 过去的信息
乘以遗忘门 ,决定哪些部分保留。 - 新的信息
乘以输入门 ,决定哪些部分加入细胞状态。
4. 计算隐藏状态
- 细胞状态
经过 变换,使得数值限制在 [-1,1] 之间。 - 乘以输出门
决定是否输出该信息。
LSTM 反向传播的优势
用Element-wise Multiplication 的乘法
矩阵乘法 |
元素乘法 |
|
---|---|---|
计算复杂度 | 需要 O(n^2) 的计算量 | 只需要 O(n) |
梯度传播 | 可能导致梯度爆炸或消失(取决于 W 的最大奇异值) | 梯度稳定(仅在 0-1 范围内缩放) |
数学性质 | 可能会产生长距离的依赖不稳定 | 保持梯度方向一致,更稳定 |
数学推导
2. LSTM 如何计算 W 的梯度?
梯度在反向传播时如何到达 W:
- 损失 L 首先反向传播到隐藏状态
- 再通过
传递到细胞状态 - 通过细胞状态
传递到各个门(输入门 、遗忘门 、输出门 、候选状态 ) - 梯度最终回传到 W,因为门的计算公式中包含 W
梯度传播的公式:
由于每个时间步都要计算这个梯度,并进行 累加,所以时间步数越多,梯度的累积计算就越多。
LSTM 如何用元素乘法解决梯度问题
在 LSTM 中,细胞状态的更新方式是:
其中:
(遗忘门)控制旧信息的保留,其值在 0 到 1 之间。 (输入门)决定是否加入新信息。
在反向传播时:
为什么 LSTM 更稳定?
(1) LSTM 允许梯度缓慢衰减
- 在 RNN 里,梯度指数级衰减(每一步都乘
和 )。 - 在 LSTM 里,梯度只是缓慢乘以
。
(2) 遗忘门
- LSTM 不是固定的
,而是可以学习最优的 。 - 如果模型检测到需要长期记忆,
会自动接近 1,使得梯度可以长时间传播。 - 在重要的信息上,LSTM 可以动态调整
,确保梯度不会过早消失。
(3) LSTM 允许梯度选择性地消失
- 有些信息应该被遗忘,而有些信息应该长期存储。
- LSTM 通过
控制哪些信息应该被保留,哪些可以丢弃。 - 这比 RNN 全局性梯度消失/爆炸的情况更合理。
(4) 细胞状态不经过 tanh
- 在 RNN 里,梯度在每个时间步都要乘
,导致指数级衰减。 - 在 LSTM 里,细胞状态梯度传播时不经过 tanh,减少了梯度衰减。
LSTM 和ResNet 的相似性
- LSTM 通过细胞状态
形成“梯度高速公路”,类似于 ResNet 的残差连接(skip connection)。 - 高速公路网络(Highway Networks)是介于 LSTM 和 ResNet 之间的一种模型,它结合了门控机制和残差思想。
- 深度 CNN 和深度 RNN 在训练时面临相似的问题,许多架构设计的灵感可以相互借鉴。
1. LSTM 和 ResNet 的梯度高速公路
在深度网络(无论是 CNN 还是 RNN)中,一个主要挑战是 梯度消失。两种方法都使用了跳跃连接(Skip Connection) 来缓解这个问题:
(1) ResNet 的残差连接
ResNet 解决梯度消失的关键思想是 引入身份连接(Identity Connection):
y = F(x) + x
其中:
- F(x) 是普通的 CNN 层(例如卷积 + ReLU)。
- 残差连接 + x 确保梯度可以直接从深层传播回前面,不会过度衰减。
- 梯度不会完全依赖于 F(x) 的计算,因此即使 F(x) **计算出一个接近 0 的梯度,**x 仍然能保证信息流动。
梯度计算(反向传播):
这意味着,即使
(2) LSTM 的细胞状态
LSTM 也有一个类似的 “身份连接”,即:
在反向传播时:
- 类似 ResNet 的跳跃连接,LSTM 允许梯度绕过
等非线性变换,直接传播到前面的时间步。 - 如果
,梯度可以很好地保持,防止梯度消失。 - 这种梯度“高速公路”让 LSTM 在长时间依赖任务上表现比 RNN 好很多。
总结:
- ResNet 通过残差连接确保 CNN 层的梯度可以稳定传播。
- LSTM 通过细胞状态 c_t 提供了一条直接的梯度通道,确保梯度不会完全消失。
- 两者都用到了“梯度高速公路”这一核心思想。
2. 高速公路网络(Highway Networks)
在 LSTM 和 ResNet 之间,还有一个早期的概念叫 高速公路网络(Highway Networks),它借鉴了 LSTM 的门控机制,并将其应用到 CNN 里。
(1) 高速公路网络的公式
高速公路网络的核心公式:
其中:
- F(x) 是普通的神经网络层(例如卷积层)。
- T(x) 是一个门控机制(通常是 sigmoid 函数),决定当前层的输出应该来自于 F(x) 还是直接跳过并使用 x。
如果 T(x) = 1,则:
y = F(x)
如果 T(x) = 0,则:
y = x
这意味着:
- 如果网络需要保留原始信息,就可以让 T(x) 变小,相当于一个跳跃连接。
- 如果网络需要计算更复杂的特征,就让 T(x) 变大,使 F(x) 生效。
梯度计算:
• 这意味着,即使 F(x) 计算出的梯度很小,梯度仍然可以通过 x 直接回传,减少梯度消失问题。
总结:高速公路网络
• 结合了 LSTM 的门控机制(T(x))和 ResNet 的跳跃连接(+ x)。
• 允许网络动态决定是否应该让信息直接传播,或者应该进行更复杂的计算。
• 高速公路网络的思想后来被用于 ResNet,成为深度 CNN 的关键技术。
架构 | 关键思想 | 梯度如何传播? |
---|---|---|
RNN | 隐藏状态 |
梯度容易消失(因为每步乘 |
LSTM | 细胞状态 |
梯度不会指数级消失(因为 |
ResNet | 残差连接 y = F(x) + x | 梯度可以直接传播回前层,不会衰减 |
Highway Networks | 门控连接 |
结合 LSTM 门控机制 + ResNet 的跳跃连接 |