当前位置: 首页 > news >正文

深度学习3.3 线性回归的简洁实现

步骤操作作用
前向计算net(X)计算预测值 y_hat = Xw + b
损失计算loss(y_hat, y)量化预测误差,驱动参数更新
反向传播l.backward()计算参数梯度
参数更新trainer.step()根据梯度调整参数,逼近最优解
梯度清零trainer.zero_grad()防止梯度累积(必须放在 backward() 之后,step() 之前)
训练监控loss(net(features), labels)评估模型整体性能,避免过拟合或欠拟合

3.3.1 生成数据集

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2ltrue_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

3.3.2 读取数据集

def load_array(data_arrays, batch_size, is_train=True):dataset = data.TensorDataset(*data_arrays)return data.DataLoader(dataset, batch_size, shuffle=is_train)batch_size = 10
data_iter = load_array((features, labels), batch_size)
next(iter(data_iter))

数据加载器 (DataLoader)
‌数据集封装‌:TensorDataset 将特征和标签包装为 PyTorch 数据集。‌
批量加载‌:DataLoader 按 batch_size=10 加载数据,训练时打乱数据 (shuffle=True)。

3.3.3 定义模型

from torch import nn
net = nn.Sequential(nn.Linear(2, 1))

3.3.4 初始化模型参数

net[0].weight.data.normal_(0, 0.01) # 权重初始化
net[0].bias.data.fill_(0) # 偏置初始化

3.3.5 定义损失函数

loss = nn.MSELoss() # 均方误差损失

3.3.6 定义优化算法

trainer = torch.optim.SGD(net.parameters(), lr=0.03)  # 随机梯度下降

3.3.7 训练

num_epochs = 3
for epoch in range(num_epochs):for X, y in data_iter:l = loss(net(X), y)     # 前向计算损失trainer.zero_grad()      # 清零梯度l.backward()            # 反向传播trainer.step()          # 参数更新# 计算并输出整个训练集的损失l = loss(net(features), labels)print(f'epoch{epoch + 1}, loss{l:f}')

epoch1, loss0.000205
epoch2, loss0.000094
epoch3, loss0.000094

# 输出参数估计误差
w = net[0].weight.data
print(f'w的估计误差:{true_w - w.reshape(true_w.shape)}')
b = net[0].bias.data
print(f'b的估计误差:{true_b - b}')

w的估计误差:tensor([5.9402e-04, 4.6015e-05])
b的估计误差:tensor([0.0001])

相关文章:

  • XXL-JOB 深入理解教程
  • 【MySQL】表的约束(主键、唯一键、外键等约束类型详解)、表的设计
  • javaSE.二叉查找树和平衡二叉树
  • EMQX安装使用和客户端认证
  • PCIE Spec ---Base Address Registers
  • 13 数据存储单位与 C 语言整数类型:从位到艾字节、常见整数类型及其范围、字面量后缀、精确宽度类型详解
  • 【嵌入式系统设计师(软考中级)】第二章:嵌入式系统硬件基础知识(上)
  • 玩转Docker | 使用Docker部署nullboard任务管理工具
  • 基于Python的图片/签名转CAD小工具开发方案
  • 数字IC后端PR阶段Innovus,ICC,ICC2修复short万能脚本分享
  • Sunscreen的TFHE 与Parasol编译器新愿景
  • 前端配置代理解决发送cookie问题
  • 算法 | 鲸鱼优化算法(WOA)与强化学习的结合研究
  • Google独立站和阿里国际站不是一回事
  • 【踩坑tip】解决两个一样的USB设备插入后第二个识别失败的问题
  • Ubuntu20.04安装Pangolin遇到的几种报错的解决方案
  • 记录seatunnel排查重复数据的案例分析
  • 第33周JavaSpringCloud微服务 实现电商项目
  • uni-app 开发企业级小程序课程
  • AI音乐解决方案:1分钟可切换suno、udio、luno、kuka等多种模型,suno风控秒切换 | AI Music API
  • 智飞生物一季度营收下滑79%,连续三个季度亏损,称业绩波动与行业整体趋势一致
  • 旁白丨无罪后领到国家赔偿,一位退休教师卸下了“包袱”
  • 新片|真人版《星际宝贝史迪奇》5月23日与北美同步上映
  • 世界读书日丨“好书最美”,国家图书馆举办读书日特别活动
  • 新闻1+1丨居民水电气计量收费乱象,如何治?
  • 最高法:抢票软件为用户提供不正当优势,构成不正当竞争