WITRAN_2DPSGMU_Encoder 类
以下是对 WITRAN_2DPSGMU_Encoder
类的主要逻辑和流程的图示化表示,帮助你更直观地理解代码的执行过程。
1. 整体流程图
输入数据 (input) [4, 32, 24, 11]
↓
【数据预处理】
1. 调整维度顺序 (permute)
- 如果 flag=0: [4, 32, 24, 11]
- 如果 flag=1: [32, 4, 24, 11]
↓
2. 计算扩展后的序列长度 Water2sea_slice_len = 27
↓
3. 初始化隐藏状态
- hidden_slice_row [128, 32]
- hidden_slice_col [128, 32]
↓
4. 滑动窗口填充
- input_transfer [4, 32, 27, 11]
↓
【多层网络计算】
5. 遍历每一层 (num_layers)
↓
【时间步计算】
6. 遍历每个时间步 (Water2sea_slice_len = 27)
↓
6.1 拼接输入
- hidden_slice_row [128, 32]
- hidden_slice_col [128, 32]
- 当前时间步输入 a[:, slice, :] [128, 11]
- 拼接后输入 [128, 75]
↓
6.2 线性变换 (self.linear)
- 输入 [128, 75]
- 权重 W [192, 75]
- 偏置 B [192]
- 输出 gate [128, 192]
↓
6.3 分割 gate
- sigmod_gate [128, 128] → 更新门、输出门
- tanh_gate [128, 64] → 输入门
↓
6.4 更新隐藏状态
- 更新行隐藏状态 hidden_slice_row [128, 32]
- 更新列隐藏状态 hidden_slice_col [128, 32]
↓
6.5 拼接输出
- 拼接 hidden_slice_row 和 hidden_slice_col
- 输出 output_slice [128, 64]
↓
6.6 保存输出
- 保存 output_slice 到 output_all_slice_list
- 保存行隐藏状态到 hidden_row_all_list
- 保存列隐藏状态到 hidden_col_all_list
↓
【残差连接】
7. 如果启用残差连接,将当前层的输出与第一层的输出相加
↓
【最终输出】
8. 堆叠所有时间步的输出
- output_all_slice [128, 27, 64]
- hidden_row_all [32, num_layers, 4, 32]
- hidden_col_all [32, num_layers, 24, 32]
↓
返回结果
2. 图示化表示
(1) 数据预处理
输入数据 (input) [4, 32, 24, 11]
↓ permute
调整维度顺序 (flag=0 或 flag=1)
↓
扩展序列长度 Water2sea_slice_len = 27
↓
滑动窗口填充
input_transfer [4, 32, 27, 11]
(2) 多层网络计算
遍历每一层 (num_layers)
↓
遍历每个时间步 (Water2sea_slice_len = 27)
↓
拼接输入
hidden_slice_row [128, 32]
hidden_slice_col [128, 32]
当前时间步输入 a[:, slice, :] [128, 11]
拼接后输入 [128, 75]
↓
线性变换 (self.linear)
输入 [128, 75]
权重 W [192, 75]
偏置 B [192]
输出 gate [128, 192]
↓
分割 gate
sigmod_gate [128, 128] → 更新门、输出门
tanh_gate [128, 64] → 输入门
↓
更新隐藏状态
hidden_slice_row [128, 32]
hidden_slice_col [128, 32]
↓
拼接输出
output_slice [128, 64]
↓
保存输出
output_all_slice_list
hidden_row_all_list
hidden_col_all_list
(3) 残差连接
如果启用残差连接
当前层输出 + 第一层输出
(4) 最终输出
堆叠所有时间步的输出
output_all_slice [128, 27, 64]
hidden_row_all [32, num_layers, 4, 32]
hidden_col_all [32, num_layers, 24, 32]
返回结果
3. 关键点总结
-
输入数据:
- 输入是二维时间序列数据,按天和小时组织。
- 通过滑动窗口扩展序列长度,捕获时间维度上的依赖关系。
-
门控机制:
- 类似于 LSTM 的门控机制,包含更新门、输出门和输入门。
- 分别更新行隐藏状态和列隐藏状态。
-
多层结构:
- 支持多层网络,每层输出隐藏状态。
- 可选残差连接,增强模型的表达能力。
-
最终输出:
- 返回每个时间步的输出,以及行隐藏状态和列隐藏状态。
通过以上图示化表示,可以更直观地理解 WITRAN_2DPSGMU_Encoder
的逻辑和执行流程。