当前位置: 首页 > news >正文

深度学习3.5 图像分类数据集

%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l

代码执行流程图

下载FashionMNIST数据集
定义标签转换函数
构建数据加载器
可视化第一批次图像
配置批量加载参数
测试数据加载速度
动态调整图像尺寸
验证调整后的数据形状

3.5.1 读取数据集

trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)

下载并加载FashionMNIST数据集
‌关键参数‌:
transform=trans:将图像转换为张量(形状 [1, 28, 28],值域 [0,1])。
download=True:若本地无数据则自动下载。
数据集结构‌:
训练集:60,000 张 28x28 灰度图像。
测试集:10,000 张 28x28 灰度图像。

def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]

标签映射
将数字标签(0-9)转换为可读的文本标签(如 0 → ‘t-shirt’)。

def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes

输入 imgs 可以是张量或PIL图像。
squeeze():移除单通道维度(1x28x28 → 28x28),否则 imshow 可能报错。
cmap=‘gray’:确保灰度图正确显示(默认可能为彩色)。

X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));

‌输出‌:显示 2行x9列 的图像网格,标题为对应的文本标签。
X.reshape(18, 28, 28):调整形状以匹配 imshow 的输入要求(原始形状为 18x1x28x28)。

在这里插入图片描述

3.5.2 读取小批量

batch_size = 256def get_dataloader_workers():return 4  # 根据CPU核心数调整(通常设为4-8)train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())

shuffle=True:打乱训练数据顺序,避免模型记忆批次。
num_workers=4:启用4个进程并行加载数据,加速数据读取。

timer = d2l.Timer()
for X, y in train_iter:continue
print(f'加载时间:{timer.stop():.2f} sec')

‘2.30 sec’

3.5.3 整合所有组件

def load_data_fashion_mnist(batch_size, resize=None):trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) # Resize必须在ToTensor前trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))

‌功能扩展‌:支持调整图像尺寸(如 resize=64 将图像缩放为 64x64)。
‌预处理顺序‌:
Resize(若指定)
ToTensor(转为张量并归一化)

train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(f'X形状: {X.shape}, 数据类型: {X.dtype}')  # 输出如 torch.Size([32,1,64,64])print(f'y形状: {y.shape}, 数据类型: {y.dtype}')  # 输出如 torch.int64break

X形状: torch.Size([32, 1, 64, 64]), 数据类型: torch.float32
y形状: torch.Size([32]), 数据类型: torch.int64

X.shape = [batch_size, channels, height, width]
y 为标签张量,形状 [batch_size]

相关文章:

  • 每日算法-250421
  • Java 并发包核心机制深度解析:锁的公平性、异步调度、AQS 原理全解
  • 【MySQL】:数据库事务管理
  • JavaEE--2.多线程
  • 把dll模块注入到游戏进程的方法_基于文件修改的注入方式
  • MCP:AI时代的“万能插座”,开启大模型无限可能
  • SvelteKit 最新中文文档教程(22)—— 最佳实践之无障碍与 SEO
  • 进程与线程:02 多进程图像
  • 在统信UOS 1060上实现自动关机
  • 高防IP能抵御哪些类型的网络攻击?
  • Buildroot、BusyBox与Yocto:嵌入式系统构建工具对比与实战指南
  • 辛格迪客户案例 | 苏州富士莱医药GMP培训管理(TMS)项目
  • 深度学习3.3 线性回归的简洁实现
  • XXL-JOB 深入理解教程
  • 【MySQL】表的约束(主键、唯一键、外键等约束类型详解)、表的设计
  • javaSE.二叉查找树和平衡二叉树
  • EMQX安装使用和客户端认证
  • PCIE Spec ---Base Address Registers
  • 13 数据存储单位与 C 语言整数类型:从位到艾字节、常见整数类型及其范围、字面量后缀、精确宽度类型详解
  • 【嵌入式系统设计师(软考中级)】第二章:嵌入式系统硬件基础知识(上)
  • 为青少年写新中国成立的故事,刘统遗著《火种》出版
  • 王毅同印尼外长苏吉约诺会谈
  • 罗马教皇方济各去世
  • 阿塞拜疆总统阿利耶夫将访华
  • 十大券商看后市|A股下行波动风险有限,震荡中有望逐步抬升
  • 涉嫌在饭局后性侵一女子,湖南机场董事长邱继兴被警方刑拘