当前位置: 首页 > news >正文

深度学习3.7 softmax回归的简洁实现

import torch
from torch import nn
from d2l import torch as d2lbatch_size = 256
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.7.1 初始化模型参数

net = nn.Sequential(nn.Flatten(), nn.Linear(784, 10))def init_weights(m):if type(m) == nn.Linear:nn.init.normal_(m.weight, std=0.01)net.apply(init_weights);

3.7.2 重新审视Softmax的实现

loss = nn.CrossEntropyLoss(reduction='none')

3.7.3 优化算法

# 在这里,我们(使用学习率为0.1的小批量随机梯度下降作为优化算法)
trainer = torch.optim.SGD(net.parameters(), lr=0.1)

3.7.4 训练

num_epochs = 10
d2l.train_ch3(net, train_iter, test_iter, loss, num_epochs, trainer)

在这里插入图片描述

3.7.5 预测

batch_size = 256 #迭代器批量
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)def predict_ch3(net, test_iter, n=6):  """Predict labels (defined in Chapter 3)."""for X, y in test_iter:  # 获取第一批测试数据breaktrues = d2l.get_fashion_mnist_labels(y)  # 真实标签转文本preds = d2l.get_fashion_mnist_labels(d2l.argmax(net(X), axis=1))  # 预测标签转文本titles = [true +'\n' + pred for true, pred in zip(trues, preds)]  # 组合标签d2l.show_images(d2l.reshape(X[0:n], (n, 28, 28)), 1, n, titles=titles[0:n])  # 可视化predict_ch3(net, test_iter)

在这里插入图片描述

相关文章:

  • 基于大模型的食管平滑肌瘤全周期预测与诊疗方案研究
  • Kaamel白皮书:Model Context Protocol (MCP) 隐私安全最佳实践
  • 沁恒CHV203中断嵌套导致修改线程栈-韦东山
  • 什么是IT人力外包?IT人力外包服务流程分为哪些步骤?
  • 序论文42 | patch+MLP用于长序列预测
  • Python基础语法:标识符,运算符,数据输入input(),数据输出print(),转义字符,续行符
  • CompletableFuture到底怎么用?
  • 飞算 JavaAI 的 “需求变更” 解决方案:让开发更灵活!
  • 如何解决PyQt从主窗口打开新窗口时出现闪退的问题
  • ai人才需要掌握什么
  • linux 桌面环境
  • JCE cannot authenticate the provider BC
  • 三国杀专业分析面板,立志成为桌游界的stockfish
  • Git多人协作与企业级开发模型
  • AXOP34032: 40V/40µA 轨到轨输入输出双通道运算放大器
  • 如何在windows10上英伟达gtx1060上部署通义千问-7B-Chat
  • 嵌入式:Linux系统应用程序(APP)启动流程概述
  • rk3588 驱动开发(三)第五章 新字符设备驱动实验
  • 算法设计与分析(基础)
  • 抽象类相关
  • 经济日报刊文:积极应对稳住外贸基本盘
  • 上海市闵行区原二级巡视员琚汉铮接受纪律审查和监察调查
  • 2025年一季度上海市生产总值
  • 陈曦任中华人民共和国二级大法官
  • 具身智能资本盛宴:3个月37笔融资,北上深争锋BAT下场,人形机器人最火
  • 上海与丰田汽车签署战略合作协议,雷克萨斯纯电动汽车项目落子金山