当前位置: 首页 > news >正文

动手学深度学习:手语视频在NiN模型中的测试

前言

NiN模型是在LeNet的基础上修改,提出了1x1卷积层和全局平均池化层的概念,减少了全连接所带来的参数量很多的问题。本篇在之前代码的基础上添加了模型保存,loss和acc记录以及记录模型时间等功能,所以模型后面的代码会重新记录一下。

模型

NiN模型主要的特色有1x1卷积和全局平均池化,以下是我个人的一些看法。

1x1卷积

由于再模型结尾将不再使用全连接层,如果还是原有的3x3等卷积的话就会丢失通道之间的信息,而1x1卷积在不改变图片大小的前提下,对通道进行卷积,可以解决这一问题。

全局平均池化

这个层主要是对每一个通道的图像进行池化变成1x1大小,也是取代全连接缩小像素得到要输出的类别大小的功能,如果说全连接是横向排布,不断减少到需要的数量的话(第一张图),那么全局平均池化就是竖向连接,一次性缩小到需要的形状(第二张图)。
在这里插入图片描述

在这里插入图片描述

代码

import torch.nn as nn
import os
import time
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# 定义设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
def nin_block(in_channels,out_channels,kernel_size,strides,padding):return nn.Sequential(nn.Conv2d(in_channels,out_channels,kernel_size,strides,padding),nn.ReLU(),nn.Conv2d(out_channels,out_channels,1),nn.ReLU(),nn.Conv2d(out_channels,out_channels,1),nn.ReLU())
net=nn.Sequential(
nin_block(frames_len,96,11,4,0),
nn.MaxPool2d(3,2),
nin_block(96,256,5,1,2),
nn.MaxPool2d(3,2),
nin_block(256,384,3,1,1),
nn.MaxPool2d(3,2),
nn.Dropout(0.5),
nin_block(384,len(labels),3,1,1),
nn.AdaptiveAvgPool2d((1,1)),
nn.Flatten()).to(device)
def init_weight(m):if type(m)==nn.Linear or type(m)==nn.Conv2d:nn.init.xavier_uniform_(m.weight)
net.apply(init_weight)

在这里插入图片描述
学习率单独设置了一个变量

loss_fn=nn.CrossEntropyLoss()
lr=0.001
optimer=torch.optim.SGD(net.parameters(),lr=lr)#0.01会导致loss为nan

定义保存路径和模型名次,将主要需要调节的参数作为整个保存的文件夹名更易于区分。

# 初始化最小测试损失
best_test_loss = float('inf')
model_name="NiN"
epochs_num=300
save_path="./save/"+model_name+"_input_channels"+str(frames_len)+"_output_channels"+str(len(labels))+"_lr"+str(lr)+"_epochs"+str(epochs_num)+"/"
if not os.path.exists(save_path):os.makedirs(save_path)print(f"文件夹 '{save_path}' 已创建。")
else:print(f"文件夹 '{save_path}' 已存在。")
best_model_path = save_path+'model.pt'
best_onnx_path=save_path+"model.onnx"

添加计时功能,便于查看模型训练时间

train_len=len(train_iter.dataset)
all_acc=[]
all_loss=[]
test_all_acc=[]
test_all_loss=[]
start_time = time.time()
shape=None
for epoch in range(epochs_num):acc=0loss=0for x,y in train_iter:x=x.to(device)y=y.to(device)hat_y=net(x)l=loss_fn(hat_y,y)loss+=loptimer.zero_grad()l.backward()optimer.step()acc+=(hat_y.argmax(1)==y).sum()all_acc.append((acc/train_len).cpu().numpy())all_loss.append(loss.detach().cpu().numpy())
#     print(all_loss)test_acc=0test_loss=0test_len=len(test_iter.dataset)with torch.no_grad():for x,y in test_iter:x=x.to(device)y=y.to(device)shape=x.shapehat_y=net(x)test_loss+=loss_fn(hat_y,y)test_acc+=(hat_y.argmax(1)==y).sum()test_all_acc.append((test_acc/test_len).cpu().numpy())test_all_loss.append(test_loss.detach().cpu().numpy())print(f'{epoch}的test的acc{test_acc/test_len}')# 保存测试损失最小的模型if test_loss < best_test_loss:best_test_loss = test_losstorch.save(net, best_model_path)
#         dummy_input = torch.randn(shape).to(device)
#         torch.onnx.export(net, dummy_input, best_onnx_path, opset_version=11)print(f'Saved better model with Test Loss: {best_test_loss:.4f}')
end_time = time.time()
elapsed_time = end_time - start_time  # 计算耗时
print(f"程序运行了 {elapsed_time:.4f} 秒")  # 保留4位小数

在这里插入图片描述
针对loss添加了test的记录并且将图片保存起来便于以后查看

import matplotlib.pyplot as plt
plt.plot(range(1,epochs_num+1),all_loss,'.-',label='train_loss')
plt.text(epochs_num, all_loss[-1], f'{all_loss[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.plot(range(1,epochs_num+1),test_all_loss,'.-',label='test_loss')
plt.text(epochs_num, test_all_loss[-1], f'{test_all_loss[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.legend()
plt.xlabel("epoch")
plt.ylabel("loss")
plt.savefig(save_path+"train_loss.png")

在这里插入图片描述

acc同理处理

plt.plot(range(1,epochs_num+1),all_acc,'-',label='train_acc')
plt.text(epochs_num, all_acc[-1], f'{all_acc[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.plot(range(1,epochs_num+1),test_all_acc,'-.',label='test_acc')
plt.text(epochs_num, test_all_acc[-1], f'{test_all_acc[-1]:.4f}', fontsize=12, verticalalignment='bottom')
plt.legend()
plt.xlabel("epoch")
plt.ylabel("acc")
plt.savefig(save_path+"acc.png")

在这里插入图片描述

结论

NiN整体效果上比VGG还是要差一点的,收敛速度也很慢。但是运行时间比VGG快了快一倍,VGG花费了下图时间。
在这里插入图片描述

相关文章:

  • 万物互联时代,AWS IoT Core如何构建企业级物联网中枢平台?
  • MCP系列之实践篇:搭建你的第一个MCP应用
  • DemoGen:用于数据高效视觉运动策略学习的合成演示生成
  • Python 文本和字节序列(支持字符串和字节序列的双模式API)
  • Webview+Python:用HTML打造跨平台桌面应用的创新方案
  • DHTMLX宣布推出支持 Redux、TypeScript 和 MUI 的 React Gantt甘特图控件
  • xml+html 概述
  • 【前端HTML生成条形码——MQ】
  • 极狐GitLab 项目导入导出设置介绍?
  • #Linux动态大小裁剪以及包大小变大排查思路
  • ApiHug 前端解决方案 - M1 内侧
  • Clickhouse 配置参考
  • 类型补充,scan 和数据库管理命令
  • 一本通 2063:【例1.4】牛吃牧草 1005:地球人口承载力估计
  • 下载electron 22.3.27 源码错误集锦
  • 记录一次问题排查,前台传的日期参数到后台取到的时候少了一天。
  • 考研系列-计算机网络-第二章、物理层
  • IntelliJ IDEA clean git password
  • 广搜bfs-P1443 马的遍历
  • 8.Rust+Axum 数据库集成实战:从 ORM 选型到用户管理系统开发
  • 涉李小龙形象商标被判定无效,真功夫:暂无更换计划
  • 居民被脱落的外墙瓦砖砸中致十级伤残,小区物业赔付16万元
  • 俄总统助理:普京与美特使讨论了恢复俄乌直接谈判的可能性
  • 2025全国知识产权宣传周:用AI生成的图片要小心什么?
  • 钱学森数据服务中心在沪上线,十万个数字资源向公众开放
  • 美称中美贸易谈判仍在进行中,外交部:美方不要混淆视听