深度学习3.5 图像分类数据集
%matplotlib inline
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
代码执行流程图
3.5.1 读取数据集
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)
下载并加载FashionMNIST数据集
关键参数:
transform=trans:将图像转换为张量(形状 [1, 28, 28],值域 [0,1])。
download=True:若本地无数据则自动下载。
数据集结构:
训练集:60,000 张 28x28 灰度图像。
测试集:10,000 张 28x28 灰度图像。
def get_fashion_mnist_labels(labels):text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', 'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']return [text_labels[int(i)] for i in labels]
标签映射
将数字标签(0-9)转换为可读的文本标签(如 0 → ‘t-shirt’)。
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5):figsize = (num_cols * scale, num_rows * scale)_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)axes = axes.flatten()for i, (ax, img) in enumerate(zip(axes, imgs)):if torch.is_tensor(img):ax.imshow(img.numpy())else:ax.imshow(img)ax.axes.get_xaxis().set_visible(False)ax.axes.get_yaxis().set_visible(False)if titles:ax.set_title(titles[i])return axes
输入 imgs 可以是张量或PIL图像。
squeeze():移除单通道维度(1x28x28 → 28x28),否则 imshow 可能报错。
cmap=‘gray’:确保灰度图正确显示(默认可能为彩色)。
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
输出:显示 2行x9列 的图像网格,标题为对应的文本标签。
X.reshape(18, 28, 28):调整形状以匹配 imshow 的输入要求(原始形状为 18x1x28x28)。
3.5.2 读取小批量
batch_size = 256def get_dataloader_workers():return 4 # 根据CPU核心数调整(通常设为4-8)train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, num_workers=get_dataloader_workers())
shuffle=True:打乱训练数据顺序,避免模型记忆批次。
num_workers=4:启用4个进程并行加载数据,加速数据读取。
timer = d2l.Timer()
for X, y in train_iter:continue
print(f'加载时间:{timer.stop():.2f} sec')
‘2.30 sec’
3.5.3 整合所有组件
def load_data_fashion_mnist(batch_size, resize=None):trans = [transforms.ToTensor()]if resize:trans.insert(0, transforms.Resize(resize)) # Resize必须在ToTensor前trans = transforms.Compose(trans)mnist_train = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True)mnist_test = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True)return (data.DataLoader(mnist_train, batch_size, shuffle=True,num_workers=get_dataloader_workers()),data.DataLoader(mnist_test, batch_size, shuffle=False,num_workers=get_dataloader_workers()))
功能扩展:支持调整图像尺寸(如 resize=64 将图像缩放为 64x64)。
预处理顺序:
Resize(若指定)
ToTensor(转为张量并归一化)
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:print(f'X形状: {X.shape}, 数据类型: {X.dtype}') # 输出如 torch.Size([32,1,64,64])print(f'y形状: {y.shape}, 数据类型: {y.dtype}') # 输出如 torch.int64break
X形状: torch.Size([32, 1, 64, 64]), 数据类型: torch.float32
y形状: torch.Size([32]), 数据类型: torch.int64
X.shape = [batch_size, channels, height, width]
y 为标签张量,形状 [batch_size]