当前位置: 首页 > news >正文

深度学习3.6 softmax回归的从零开始实现

本章节引入3.5的数据集

import torch
from IPython import display
from d2l import torch as d2lbatch_size = 256 #迭代器批量
train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size)

3.6.1 初始化模型参数

num_inputs = 784 # 权重矩阵长度
num_outputs = 10 # 类别数量
W = torch.normal(0, 0.01, size=(num_inputs, num_outputs), requires_grad=True) # 权重矩阵
b = torch.zeros(num_outputs, requires_grad=True) # 偏置

图像尺寸28*28像素
‌权重W‌:从均值为0、标准差0.01的正态分布采样,形状 [784, 10]。
‌偏置b‌:初始化为全0,形状 [10]。
‌梯度追踪‌:requires_grad=True 启用自动微分。

3.6.2 定义softmax操作

def softmax(X):X_exp = torch.exp(X) # 处理计算自然指数函数e的幂(GPU计算效率高)partition = X_exp.sum(1, keepdim=True) # 0:列,1:行,计算为x行1列张量return X_exp / partition # 归一化-概率[[1/3,2/3],[3/7,4/7]]X = torch.normal(0, 1, (2, 5)) # torch.normal 用于生成服从‌正态分布(高斯分布)‌的随机数张量,支持多种参数形式(均值,标准差,(形状))
X_prob = softmax(X) # 概率
X_prob, X_prob.sum(1) # 概率和=1

在这里插入图片描述

3.6.3 定义模型

def net(X):a1 = X.reshape((-1, W.shape[0])) # 保持[*,len(W)]a2 = torch.matmul(a1, W) # torch.matmul矩阵乘法return softmax(a2 + b) # 返回对应概率

展平输入:X.reshape((-1, 784))(将 [batch_size,1,28,28] 转为 [batch_size,784])。
线性变换:XW+b(输出 [batch_size,10])。
Softmax归一化:得到每个类别的概率分布。

3.6.4 定义损失函数

y = torch.tensor([0, 2])
y_hat = torch.tensor([[0.1, 0.3, 0.6], [0.3, 0.2, 0.5]])
y_hat[[0, 1], y]

tensor([0.1000, 0.5000])
高级索引 : 索引列表会按‌位置配对‌,从y_hat中提取特定位置的元素
‌第一个元素‌:y_hat[0行, y[0]=0列] → 0.1
‌第二个元素‌:y_hat[1行, y[1]=2列] → 0.5

def cross_entropy(y_hat, y):return - torch.log(y_hat[range(len(y_hat)), y])cross_entropy(y_hat, y)

tensor([2.3026, 0.6931])

3.6.5 分类精度

def accuracy(y_hat, y):if len(y_hat.shape) > 1 and y_hat.shape[1] > 1:y_hat = y_hat.argmax(axis=1)cmp = y_hat.type(y.dtype) == yreturn float(cmp.type(y.dtype).sum())

相关文章:

  • 模拟实现strncat、qsort、atoi
  • 低光环境下双目云台摄像头监控性能解析
  • Element UI、Element Plus 里的表单验证的required必填的属性不能动态响应?
  • 题解:[ABC385F] Visible Buildings
  • GNOME桌面隐藏回收站和分区
  • 赛灵思 XC7K325T-2FFG900I FPGA Xilinx Kintex‑7
  • 基于SpringBoot的中华诗词文化分享平台-项目分享
  • 【FPGA开发】Vivado开发中的LUTRAM占用LUT资源吗
  • FPGA设计 时空变换
  • 前端学习笔记
  • 系统架构师2025年论文《论软件三层结构的设计》
  • Ubuntu24.04安装ROS2问题
  • 服务器上安装maven
  • 题解:P11185 奖牌排序
  • linux下内存地址数学运算
  • HTTP状态码有哪些常见的类型?
  • 搭建 Spark - Local 模式:开启数据处理之旅
  • 推荐一个简单又好用的在线视频编辑工具,在线免费使用,便捷高效!
  • ​​批发商商城小程序制作哪家强?开启高效批发新模式!
  • Python爬虫从入门到实战详细版教程Char01:爬虫基础与核心技术
  • 网培机构围猎中老年人:低价引流卖高价课、“名师”无资质,舆论呼吁加强监管
  • 伊朗外长访华将会见哪些人?讨论哪些议题?外交部回应
  • 动力电池、风光电设备退役潮来袭,国家队即将推出“再生计划”
  • 全球在役最大火电厂被通报
  • 广西出现今年首场超警洪水
  • 重大虚开发票偷税骗补案被查处:价税2.26亿,涉700余名主播