当前位置: 首页 > news >正文

深度学习3.5图像分类数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

代码执行流程图

下载FashionMNIST数据集
定义标签转换函数
构建数据加载器
可视化第一批次图像
配置批量加载参数
测试数据加载速度
动态调整图像尺寸
验证调整后的数据形状

3.5.1 读取数据集

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

下载并加载FashionMNIST数据集
‌关键参数‌:
transform=trans:将图像转换为张量(形状 [1, 28, 28],值域 [0,1])。
download=True:若本地无数据则自动下载。
数据集结构‌:
训练集:60,000 张 28x28 灰度图像。
测试集:10,000 张 28x28 灰度图像。

def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

标签映射
将数字标签(0-9)转换为可读的文本标签(如 0 → ‘t-shirt’)。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

输入 imgs 可以是张量或PIL图像。
squeeze():移除单通道维度(1x28x28 → 28x28),否则 imshow 可能报错。
cmap=‘gray’:确保灰度图正确显示(默认可能为彩色)。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

‌输出‌:显示 2行x9列 的图像网格,标题为对应的文本标签。
X.reshape(18, 28, 28):调整形状以匹配 imshow 的输入要求(原始形状为 18x1x28x28)。

在这里插入图片描述

3.5.2 读取小批量

batch_size = 256def get_dataloader_workers():return 4  # 根据CPU核心数调整(通常设为4-8)train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())

shuffle=True:打乱训练数据顺序,避免模型记忆批次。
num_workers=4:启用4个进程并行加载数据,加速数据读取。

timer = d2l.Timer()
for X, y in train_iter:continue
print(f'加载时间:{timer.stop():.2f} sec')

‘2.30 sec’

3.5.3 整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) # Resize必须在ToTensor前trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

‌功能扩展‌:支持调整图像尺寸(如 resize=64 将图像缩放为 64x64)。
‌预处理顺序‌:
Resize(若指定)
ToTensor(转为张量并归一化)

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(f'X形状: {X.shape}, 数据类型: {X.dtype}')  # 输出如 torch.Size([32,1,64,64])print(f'y形状: {y.shape}, 数据类型: {y.dtype}')  # 输出如 torch.int64break

X形状: torch.Size([32, 1, 64, 64]), 数据类型: torch.float32
y形状: torch.Size([32]), 数据类型: torch.int64

X.shape = [batch_size, channels, height, width]
y 为标签张量,形状 [batch_size]

相关文章:

  • SQL 使用 UPDATE FROM 语法进行更新
  • vue2练习项目 家乡特色网站—前端静态网站模板
  • (7)NodeJS的使用与NPM包管理器
  • OpenCV基础函数学习4
  • 快手砍掉本地生活的门槛
  • 【ZYNQ MP开发】Linux下使用bootgen命令生成BOOT.bin报错架构不对问题探究
  • 科学养生指南:解锁健康生活新方式
  • 3200温控板电路解析
  • XMC4800 芯片深度解读:架构、特性、应用与开发指南
  • WebRTC通信技术EasyRTC音视频实时通话安全巡检搭建低延迟、高可靠的智能巡检新体系
  • 视频生成上下文并行方案
  • SQL 中 ROLLUP 的使用方法
  • 大模型面经 | 春招、秋招算法面试常考八股文附答案(三)
  • vue3 主题模式 结合 element-plus的主题
  • 《数据结构之美--双向链表》
  • Spring_MVC 高级特性详解与实战应用
  • Debian GNU/Linux的新手入门介绍
  • 【Spring】深入解析 Spring AOP 核心概念:切点、连接点、通知、切面、通知类型和使用 @PointCut 定义切点的方法
  • 安装Github软件详细流程,win10系统从配置git到安装软件详解,以及github软件整合包制作方法(
  • BUUCTF PWN刷题笔记(1-9)
  • 水利部启动干旱防御Ⅳ级响应,指导广西陕西抗旱保供保灌
  • 钱理群|直面衰老与死亡
  • 延安市委副书记马月逢已任榆林市委副书记、市政府党组书记
  • 深化应用型人才培养,这所高校聘任行业企业专家深度参与专业设置
  • 第六次国民体质监测展开,高抬腿俯卧撑等新增运动指标受关注
  • 俄罗斯与乌克兰互换246名在押人员