PyTorch数据加载与预处理
数据加载与预处理详解
1. 数据集类(Dataset和DataLoader)
1.1 Dataset基类
PyTorch中的Dataset
是一个抽象类,所有自定义的数据集都应该继承这个类,并实现以下两个方法:
__len__()
: 返回数据集的大小__getitem__()
: 根据索引返回一个样本
概念解析:
Dataset
类提供了统一的数据访问接口- 通过继承
Dataset
,我们可以轻松地将数据集成到PyTorch的生态系统中 - 自定义数据集可以处理各种格式的数据(图像、文本、音频等)
1.2 自定义Dataset子类实例
import torch
from torch.utils.data import Dataset
from PIL import Image
import osclass CustomImageDataset(Dataset):def __init__(self, img_dir, transform=None):"""初始化数据集参数:img_dir (str): 图像目录路径transform (callable, optional): 可选的数据变换函数"""self.img_dir = img_dirself.transform = transform# 假设目录中每个文件都是一个图像样本self.img_names = os.listdir(img_dir)def __len__(self):"""返回数据集中的样本数"""return len(self.img_names)def __getitem__(self, idx):"""根据索引获取样本"""img_path = os.path.join(self.img_dir, self.img_names[idx])image = Image.open(img_path).convert('RGB') # 转换为RGB格式if self.transform:image = self.transform(image)# 假设文件名格式为"label_image.jpg"label = int(self.img_names[idx].split('_')[0])return image, label
注意事项:
- 在
__getitem__
中执行耗时的操作(如文件读取)可能会影响性能 - 确保
__len__
返回的是整数 - 如果数据需要预处理,最好在
__init__
中完成,而不是在__getitem__
中 - 多进程数据加载时,确保数据集是可序列化的
1.3 DataLoader
DataLoader
是PyTorch提供的一个迭代器,它封装了Dataset
,提供以下功能:
- 批量处理数据(batching)
- 数据打乱(shuffling)
- 多进程数据加载
主要参数:
dataset
: 要加载的数据集(Dataset对象)batch_size
: 每个batch的大小(默认: 1)shuffle
: 是否在每个epoch打乱数据(默认: False)num_workers
: 用于数据加载的子进程数(默认: 0,表示在主进程中加载)pin_memory
: 是否将数据复制到CUDA固定内存中(加速GPU传输)drop_last
: 是否丢弃最后一个不完整的batch(默认: False)
实例:
from torch.utils.data import DataLoader# 创建数据集实例
dataset = CustomImageDataset(img_dir='data/images', transform=None)# 创建DataLoader
dataloader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=4,pin_memory=True
)# 使用示例
for batch_images, batch_labels in dataloader:# batch_images是形状为[32, C, H, W]的张量# batch_labels是形状为[32]的张量pass
注意事项:
num_workers
设置过高可能导致内存问题- 在Windows系统上,多进程加载可能需要将代码放在
if __name__ == '__main__':
块中 pin_memory
在GPU训练时可以显著提高数据传输速度- 当数据集大小不能被
batch_size
整除时,最后一个batch会比较小
2. 使用torchvision.datasets加载标准数据集
PyTorch的torchvision.datasets
模块提供了许多常用的数据集:
- MNIST
- CIFAR-10/CIFAR-100
- ImageNet
- Fashion-MNIST
- COCO
- 等等
实例:
import torchvision.datasets as datasets
import torchvision.transforms as transforms# 定义数据变换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', # 数据存储路径train=True, # 训练集download=True, # 如果数据不存在则下载transform=transform # 应用的数据变换
)test_dataset = datasets.MNIST(root='./data',train=False, # 测试集download=True,transform=transform
)# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)
注意事项:
- 首次使用时会下载数据集,可能需要较长时间
root
参数指定数据存储位置,确保有足够的磁盘空间- 对于大型数据集(如ImageNet),下载和解压可能需要特殊处理
- 不同数据集可能有不同的默认变换,需要查看文档
3. 数据变换(Transforms)
3.1 torchvision.transforms常用变换
PyTorch提供了多种数据预处理变换:
-
基础变换:
ToTensor()
: 将PIL图像或NumPy数组转换为PyTorch张量Normalize(mean, std)
: 用给定的均值和标准差对张量进行归一化Resize(size)
: 调整图像大小CenterCrop(size)
: 中心裁剪Pad(padding)
: 填充
-
几何变换:
RandomHorizontalFlip(p=0.5)
: 随机水平翻转RandomVerticalFlip(p=0.5)
: 随机垂直翻转RandomRotation(degrees)
: 随机旋转RandomAffine(degrees, translate, scale, shear)
: 随机仿射变换
-
颜色变换:
ColorJitter(brightness, contrast, saturation, hue)
: 随机改变亮度、对比度、饱和度和色调Grayscale(num_output_channels=1)
: 转换为灰度图RandomGrayscale(p=0.1)
: 以概率p转换为灰度图
-
组合变换:
Compose([transforms])
: 将多个变换组合在一起
3.2 变换实例
from torchvision import transforms# 定义训练集变换
train_transform = transforms.Compose([transforms.RandomResizedCrop(224), # 随机裁剪并调整大小transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), # 随机颜色变换transforms.ToTensor(), # 转换为张量transforms.Normalize( # 归一化mean=[0.485, 0.456, 0.406], # ImageNet均值std=[0.229, 0.224, 0.225] # ImageNet标准差)
])# 定义测试集变换 (通常不包含数据增强)
test_transform = transforms.Compose([transforms.Resize(256), # 调整大小transforms.CenterCrop(224), # 中心裁剪transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
注意事项:
- 变换的顺序很重要,例如
ToTensor()
通常应该在Normalize()
之前 - 训练集和测试集的变换可能不同(训练集通常包含数据增强)
- 归一化的参数(mean和std)应该与预训练模型或数据集统计信息匹配
- 某些变换只适用于PIL图像,某些只适用于张量
3.3 自定义变换
我们可以通过实现__call__
方法创建自定义变换:
import random
import torchvision.transforms.functional as Fclass RandomRotationWithReflect:"""随机旋转并保持图像大小不变(通过反射填充)"""def __init__(self, degrees):self.degrees = degreesdef __call__(self, img):angle = random.uniform(-self.degrees, self.degrees)return F.rotate(img, angle, fill=(0, 0, 0)) # 用黑色填充# 使用自定义变换
custom_transform = transforms.Compose([transforms.Resize(256),RandomRotationWithReflect(30), # 自定义变换transforms.ToTensor()
])
注意事项:
- 自定义变换应该处理PIL图像或张量,取决于它在变换链中的位置
- 确保变换是可重现的(如果设置了随机种子)
- 考虑变换对性能的影响,避免在
__call__
中执行耗时操作
4. 综合案例
下面是一个完整的图像分类任务数据加载和预处理示例:
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os
from PIL import Image# 1. 自定义数据集类
class CustomImageDataset(Dataset):def __init__(self, img_dir, transform=None):self.img_dir = img_dirself.transform = transformself.classes = sorted(os.listdir(img_dir)) # 假设每个子目录是一个类别self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}self.samples = []# 收集所有图像路径和标签for cls in self.classes:cls_dir = os.path.join(img_dir, cls)for img_name in os.listdir(cls_dir):self.samples.append((os.path.join(cls_dir, img_name),self.class_to_idx[cls]))def __len__(self):return len(self.samples)def __getitem__(self, idx):img_path, label = self.samples[idx]img = Image.open(img_path).convert('RGB')if self.transform:img = self.transform(img)return img, label# 2. 定义数据变换
# 训练集变换 (包含数据增强)
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(30),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 验证集/测试集变换 (不包含数据增强)
val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 3. 创建数据集
train_dataset = CustomImageDataset(img_dir='data/train',transform=train_transform
)val_dataset = CustomImageDataset(img_dir='data/val',transform=val_transform
)# 4. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True,num_workers=4,pin_memory=True,drop_last=True
)val_loader = DataLoader(val_dataset,batch_size=64,shuffle=False,num_workers=4,pin_memory=True
)# 5. 使用示例
def train_model():for epoch in range(10):# 训练阶段model.train()for images, labels in train_loader:images = images.to(device)labels = labels.to(device)# 训练代码...# 验证阶段model.eval()with torch.no_grad():for images, labels in val_loader:images = images.to(device)labels = labels.to(device)# 验证代码...if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')train_model()
5. 关键知识点总结
-
Dataset类:
- 必须实现
__len__
和__getitem__
方法 - 负责数据的读取和单个样本的预处理
- 可以是磁盘数据的接口,也可以是内存数据的包装器
- 必须实现
-
DataLoader:
- 负责批量生成数据
- 提供数据打乱和多进程加载功能
- 主要参数:
batch_size
,shuffle
,num_workers
,pin_memory
-
标准数据集:
torchvision.datasets
提供了常用数据集- 通常自动下载和管理数据
- 可以指定变换和数据分割(训练/测试)
-
数据变换:
transforms.Compose
用于组合多个变换- 训练集通常包含数据增强变换
- 测试集通常只包含必要的预处理
- 变换顺序很重要(如
ToTensor
通常在Normalize
之前)
-
最佳实践:
- 训练集和验证集使用不同的变换
- 使用
num_workers
加速数据加载 - GPU训练时启用
pin_memory
- 确保数据变换是可重现的(设置随机种子)
通过合理的数据加载和预处理流程,可以显著提高模型训练效率和最终性能。数据预处理应该与模型的需求和数据的特性相匹配,同时考虑计算效率和内存使用。