
【PyTorch】Running VGG on Colab (Deep Learning) with the CIFAR10 Dataset

The run ends with a test accuracy of about 10%, i.e., the model underfits; on CIFAR10's 10 classes, 10% is no better than random guessing, so the network is effectively not learning.

import torch
import torchvision.datasets
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision import datasets, transforms

# Resize CIFAR10 images (32x32) up to the 224x224 input size VGG expects
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

train_data = datasets.CIFAR10(root='cifar', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='cifar', train=False, download=True, transform=transform)

train_data_size = len(train_data)
test_data_size = len(test_data)
print("Training data size: {}".format(train_data_size))
print("Testing data size: {}".format(test_data_size))

train_dataloader = DataLoader(train_data, batch_size=64)
test_dataloader = DataLoader(test_data, batch_size=64)

class VGG(nn.Module):
    def __init__(self):
        super(VGG, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),     # 224*224*64
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, 1, 1),    # 224*224*64
            nn.ReLU(),
            nn.MaxPool2d(2, 2),            # 112*112*64
            nn.Conv2d(64, 128, 3, 1, 1),   # 112*112*128
            nn.ReLU(),
            nn.Conv2d(128, 128, 3, 1, 1),  # 112*112*128
            nn.ReLU(),
            nn.MaxPool2d(2, 2),            # 56*56*128
            nn.Conv2d(128, 256, 3, 1, 1),  # 56*56*256
            nn.ReLU(),
            nn.Conv2d(256, 256, 3, 1, 1),  # 56*56*256
            nn.ReLU(),
            nn.MaxPool2d(2, 2),            # 28*28*256
            nn.Conv2d(256, 512, 3, 1, 1),  # 28*28*512
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),  # 28*28*512
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),  # 28*28*512
            nn.ReLU(),
            nn.MaxPool2d(2, 2),            # 14*14*512
            nn.Conv2d(512, 512, 3, 1, 1),  # 14*14*512
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),  # 14*14*512
            nn.ReLU(),
            nn.Conv2d(512, 512, 3, 1, 1),  # 14*14*512
            nn.ReLU(),
            nn.MaxPool2d(2, 2),            # 7*7*512
            nn.Flatten(),                  # 7*7*512 -> 25088
            nn.Linear(25088, 4096),        # 25088 -> 4096
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 4096),         # 4096 -> 4096
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(4096, 1000),         # 4096 -> 1000
        )

    def forward(self, x):
        x = self.model(x)
        return x

vgg = VGG()
vgg = vgg.to(device)

loss_fn = nn.CrossEntropyLoss()
optim = torch.optim.SGD(vgg.parameters(), lr=0.01)

total_train_step = 0
total_test_step = 0
epoch = 10

writer = SummaryWriter("../logs")

for i in range(epoch):
    print("---------------------------Epoch {} training starts-------------------------------------".format(i + 1))
    vgg.train()
    for data in train_dataloader:
        imgs, targets = data
        imgs, targets = imgs.to(device), targets.to(device)
        outputs = vgg(imgs)
        loss = loss_fn(outputs, targets)
        optim.zero_grad()
        loss.backward()
        optim.step()
        total_train_step += 1
        if total_train_step % 100 == 0:
            print("Training step {}, loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)

    # Testing
    total_test_loss = 0
    total_accuracy = 0
    vgg.eval()
    with torch.no_grad():
        for data in test_dataloader:
            imgs, targets = data
            imgs, targets = imgs.to(device), targets.to(device)
            outputs = vgg(imgs)
            loss = loss_fn(outputs, targets)
            total_test_loss += loss.item()
            accuracy = (outputs.argmax(1) == targets).sum()
            total_accuracy += accuracy
    print("Total loss on the test set: {}".format(total_test_loss))
    print("Accuracy on the test set: {}".format(total_accuracy / test_data_size))
    writer.add_scalar("test_loss", total_test_loss, total_test_step)
    total_test_step += 1
    torch.save(vgg, "Elite_{}.pth".format(i))

writer.close()
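
One likely reason the accuracy sits at chance level is that the classifier head keeps VGG's original 1000-way ImageNet output while CIFAR10 only has 10 classes, and the inputs are raw (unnormalized) pixels fed to plain SGD at lr=0.01, which can stall when training VGG from scratch. The snippet below is a minimal sketch of the adjustments I would try first; it reuses the names from the script above (transform, vgg, optim, device), assumes the last module of the nn.Sequential is the final nn.Linear(4096, 1000), and uses commonly quoted CIFAR10 channel statistics, so treat the exact numbers as assumptions rather than a verified fix.

# Normalize inputs with (approximate) CIFAR10 channel statistics -- assumed values.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                         std=[0.2470, 0.2435, 0.2616]),
])

# Match the output layer to CIFAR10's 10 classes instead of ImageNet's 1000.
# This indexes the nn.Sequential defined above; its last module is the
# final nn.Linear(4096, 1000).
vgg.model[-1] = nn.Linear(4096, 10).to(device)

# A smaller learning rate with momentum tends to be more stable for VGG
# trained from scratch; these hyperparameters are assumptions to experiment with.
optim = torch.optim.SGD(vgg.parameters(), lr=0.001, momentum=0.9)

With a 10-way head, seeing test accuracy climb above 10% after the first epoch is a quick sanity check that the network is actually learning; if it still sits at chance level, the train_loss curve already logged to TensorBoard via SummaryWriter should show whether the loss is decreasing at all.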
