当前位置: 首页 > news >正文

PyTorch数据加载与预处理

数据加载与预处理详解

1. 数据集类(Dataset和DataLoader)

1.1 Dataset基类

PyTorch中的Dataset是一个抽象类,所有自定义的数据集都应该继承这个类,并实现以下两个方法:

  • __len__(): 返回数据集的大小
  • __getitem__(): 根据索引返回一个样本

概念解析

  • Dataset类提供了统一的数据访问接口
  • 通过继承Dataset,我们可以轻松地将数据集成到PyTorch的生态系统中
  • 自定义数据集可以处理各种格式的数据(图像、文本、音频等)

1.2 自定义Dataset子类实例

import torch
from torch.utils.data import Dataset
from PIL import Image
import osclass CustomImageDataset(Dataset):def __init__(self, img_dir, transform=None):"""初始化数据集参数:img_dir (str): 图像目录路径transform (callable, optional): 可选的数据变换函数"""self.img_dir = img_dirself.transform = transform# 假设目录中每个文件都是一个图像样本self.img_names = os.listdir(img_dir)def __len__(self):"""返回数据集中的样本数"""return len(self.img_names)def __getitem__(self, idx):"""根据索引获取样本"""img_path = os.path.join(self.img_dir, self.img_names[idx])image = Image.open(img_path).convert('RGB')  # 转换为RGB格式if self.transform:image = self.transform(image)# 假设文件名格式为"label_image.jpg"label = int(self.img_names[idx].split('_')[0])return image, label

注意事项

  1. __getitem__中执行耗时的操作(如文件读取)可能会影响性能
  2. 确保__len__返回的是整数
  3. 如果数据需要预处理,最好在__init__中完成,而不是在__getitem__
  4. 多进程数据加载时,确保数据集是可序列化的

1.3 DataLoader

DataLoader是PyTorch提供的一个迭代器,它封装了Dataset,提供以下功能:

  • 批量处理数据(batching)
  • 数据打乱(shuffling)
  • 多进程数据加载

主要参数

  • dataset: 要加载的数据集(Dataset对象)
  • batch_size: 每个batch的大小(默认: 1)
  • shuffle: 是否在每个epoch打乱数据(默认: False)
  • num_workers: 用于数据加载的子进程数(默认: 0,表示在主进程中加载)
  • pin_memory: 是否将数据复制到CUDA固定内存中(加速GPU传输)
  • drop_last: 是否丢弃最后一个不完整的batch(默认: False)

实例

from torch.utils.data import DataLoader# 创建数据集实例
dataset = CustomImageDataset(img_dir='data/images', transform=None)# 创建DataLoader
dataloader = DataLoader(dataset,batch_size=32,shuffle=True,num_workers=4,pin_memory=True
)# 使用示例
for batch_images, batch_labels in dataloader:# batch_images是形状为[32, C, H, W]的张量# batch_labels是形状为[32]的张量pass

注意事项

  1. num_workers设置过高可能导致内存问题
  2. 在Windows系统上,多进程加载可能需要将代码放在if __name__ == '__main__':块中
  3. pin_memory在GPU训练时可以显著提高数据传输速度
  4. 当数据集大小不能被batch_size整除时,最后一个batch会比较小

2. 使用torchvision.datasets加载标准数据集

PyTorch的torchvision.datasets模块提供了许多常用的数据集:

  • MNIST
  • CIFAR-10/CIFAR-100
  • ImageNet
  • Fashion-MNIST
  • COCO
  • 等等

实例

import torchvision.datasets as datasets
import torchvision.transforms as transforms# 定义数据变换
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))
])# 下载并加载MNIST数据集
train_dataset = datasets.MNIST(root='./data',          # 数据存储路径train=True,             # 训练集download=True,          # 如果数据不存在则下载transform=transform     # 应用的数据变换
)test_dataset = datasets.MNIST(root='./data',train=False,            # 测试集download=True,transform=transform
)# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

注意事项

  1. 首次使用时会下载数据集,可能需要较长时间
  2. root参数指定数据存储位置,确保有足够的磁盘空间
  3. 对于大型数据集(如ImageNet),下载和解压可能需要特殊处理
  4. 不同数据集可能有不同的默认变换,需要查看文档

3. 数据变换(Transforms)

3.1 torchvision.transforms常用变换

PyTorch提供了多种数据预处理变换:

  1. 基础变换:

    • ToTensor(): 将PIL图像或NumPy数组转换为PyTorch张量
    • Normalize(mean, std): 用给定的均值和标准差对张量进行归一化
    • Resize(size): 调整图像大小
    • CenterCrop(size): 中心裁剪
    • Pad(padding): 填充
  2. 几何变换:

    • RandomHorizontalFlip(p=0.5): 随机水平翻转
    • RandomVerticalFlip(p=0.5): 随机垂直翻转
    • RandomRotation(degrees): 随机旋转
    • RandomAffine(degrees, translate, scale, shear): 随机仿射变换
  3. 颜色变换:

    • ColorJitter(brightness, contrast, saturation, hue): 随机改变亮度、对比度、饱和度和色调
    • Grayscale(num_output_channels=1): 转换为灰度图
    • RandomGrayscale(p=0.1): 以概率p转换为灰度图
  4. 组合变换:

    • Compose([transforms]): 将多个变换组合在一起

3.2 变换实例

from torchvision import transforms# 定义训练集变换
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),          # 随机裁剪并调整大小transforms.RandomHorizontalFlip(),          # 随机水平翻转transforms.ColorJitter(0.4, 0.4, 0.4, 0.1), # 随机颜色变换transforms.ToTensor(),                      # 转换为张量transforms.Normalize(                       # 归一化mean=[0.485, 0.456, 0.406],            # ImageNet均值std=[0.229, 0.224, 0.225]              # ImageNet标准差)
])# 定义测试集变换 (通常不包含数据增强)
test_transform = transforms.Compose([transforms.Resize(256),                     # 调整大小transforms.CenterCrop(224),                 # 中心裁剪transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])

注意事项

  1. 变换的顺序很重要,例如ToTensor()通常应该在Normalize()之前
  2. 训练集和测试集的变换可能不同(训练集通常包含数据增强)
  3. 归一化的参数(mean和std)应该与预训练模型或数据集统计信息匹配
  4. 某些变换只适用于PIL图像,某些只适用于张量

3.3 自定义变换

我们可以通过实现__call__方法创建自定义变换:

import random
import torchvision.transforms.functional as Fclass RandomRotationWithReflect:"""随机旋转并保持图像大小不变(通过反射填充)"""def __init__(self, degrees):self.degrees = degreesdef __call__(self, img):angle = random.uniform(-self.degrees, self.degrees)return F.rotate(img, angle, fill=(0, 0, 0))  # 用黑色填充# 使用自定义变换
custom_transform = transforms.Compose([transforms.Resize(256),RandomRotationWithReflect(30),  # 自定义变换transforms.ToTensor()
])

注意事项

  1. 自定义变换应该处理PIL图像或张量,取决于它在变换链中的位置
  2. 确保变换是可重现的(如果设置了随机种子)
  3. 考虑变换对性能的影响,避免在__call__中执行耗时操作

4. 综合案例

下面是一个完整的图像分类任务数据加载和预处理示例:

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, datasets
import os
from PIL import Image# 1. 自定义数据集类
class CustomImageDataset(Dataset):def __init__(self, img_dir, transform=None):self.img_dir = img_dirself.transform = transformself.classes = sorted(os.listdir(img_dir))  # 假设每个子目录是一个类别self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}self.samples = []# 收集所有图像路径和标签for cls in self.classes:cls_dir = os.path.join(img_dir, cls)for img_name in os.listdir(cls_dir):self.samples.append((os.path.join(cls_dir, img_name),self.class_to_idx[cls]))def __len__(self):return len(self.samples)def __getitem__(self, idx):img_path, label = self.samples[idx]img = Image.open(img_path).convert('RGB')if self.transform:img = self.transform(img)return img, label# 2. 定义数据变换
# 训练集变换 (包含数据增强)
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.RandomVerticalFlip(),transforms.RandomRotation(30),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 验证集/测试集变换 (不包含数据增强)
val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])# 3. 创建数据集
train_dataset = CustomImageDataset(img_dir='data/train',transform=train_transform
)val_dataset = CustomImageDataset(img_dir='data/val',transform=val_transform
)# 4. 创建数据加载器
train_loader = DataLoader(train_dataset,batch_size=64,shuffle=True,num_workers=4,pin_memory=True,drop_last=True
)val_loader = DataLoader(val_dataset,batch_size=64,shuffle=False,num_workers=4,pin_memory=True
)# 5. 使用示例
def train_model():for epoch in range(10):# 训练阶段model.train()for images, labels in train_loader:images = images.to(device)labels = labels.to(device)# 训练代码...# 验证阶段model.eval()with torch.no_grad():for images, labels in val_loader:images = images.to(device)labels = labels.to(device)# 验证代码...if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')train_model()

5. 关键知识点总结

  1. Dataset类:

    • 必须实现__len____getitem__方法
    • 负责数据的读取和单个样本的预处理
    • 可以是磁盘数据的接口,也可以是内存数据的包装器
  2. DataLoader:

    • 负责批量生成数据
    • 提供数据打乱和多进程加载功能
    • 主要参数:batch_size, shuffle, num_workers, pin_memory
  3. 标准数据集:

    • torchvision.datasets提供了常用数据集
    • 通常自动下载和管理数据
    • 可以指定变换和数据分割(训练/测试)
  4. 数据变换:

    • transforms.Compose用于组合多个变换
    • 训练集通常包含数据增强变换
    • 测试集通常只包含必要的预处理
    • 变换顺序很重要(如ToTensor通常在Normalize之前)
  5. 最佳实践:

    • 训练集和验证集使用不同的变换
    • 使用num_workers加速数据加载
    • GPU训练时启用pin_memory
    • 确保数据变换是可重现的(设置随机种子)

通过合理的数据加载和预处理流程,可以显著提高模型训练效率和最终性能。数据预处理应该与模型的需求和数据的特性相匹配,同时考虑计算效率和内存使用。

相关文章:

  • Redis的两种持久化方式:RDB和AOF
  • OSPF的不规则区域和特殊区域
  • WPF实现多语言切换
  • Java 实用工具类:深入讲解 CollectionUtils
  • CCF CSP 第30次(2023.05)(4_电力网络_C++)
  • C++:string 1
  • 游戏状态管理:用Pygame实现场景切换与暂停功能
  • Java 日志:掌握本地与网络日志技术
  • 6.1腾讯技术岗2025面试趋势前瞻:大模型、云原生与安全隐私新动向
  • HTML与安全性:XSS、防御与最佳实践
  • 华为OD机试真题——二维伞的雨滴效应(2025A卷:200分)Java/python/JavaScript/C/C++/GO最佳实现
  • 在WSL2+Ubuntu22.04中通过conda pack导出一个conda环境包,然后尝试导入该环境包
  • 【Linux网络】打造初级网络计算器 - 从协议设计到服务实现
  • 1.4 大模型应用产品与技术架构
  • 静态多态和动态多态的区别
  • 【Tauri】桌面程序exe开发 - Tauri+Vue开发Windows应用 - 比Electron更轻量!8MB!
  • 【高频考点精讲】实现垂直居中的多种CSS方法比较与最佳实践
  • BS架构与CS架构的对比分析:了解两种架构的不同特点与应用
  • 计算机网络 | 应用层(4)--DNS:因特网的目录服务
  • (done) 吴恩达版提示词工程 5. 推理 (情绪分类,控制输出格式,输出 JSON,集成多个任务,文本主题推断和索引,主题内容提醒)
  • 张家界乒乓球公开赛设干部职级门槛引关注,回应:仅限嘉宾组
  • 《我的后半生》:人生下半场,也有活力重启的可能
  • 青海一只人工繁育秃鹫雏鸟破壳后脱险成活,有望填补国内空白
  • 韩国检方重启调查金建希操纵股价案
  • 李良生已任应急管理部党委委员、政治部主任
  • 马上评丨从东方红一号到神二十,中国航天步履不停