深度学习数据预处理:Dataset类的全面解析与实战指南
前言
在深度学习项目中,数据预处理是模型训练前至关重要的一环。一个高效、灵活的数据预处理流程不仅能提升模型性能,还能大大加快开发效率。本文将深入探讨PyTorch中的Dataset类,介绍数据预处理的常见技巧,并通过实战示例展示如何构建自己的数据预处理流程。
一、Dataset作用
在深度学习项目中,原始数据通常需要经过一系列处理才能输入模型。Dataset类的主要作用包括:
1. 数据统一接口:为不同类型的数据提供统一的访问接口
2. 内存高效利用:实现按需加载,避免一次性加载所有数据
3. 数据增强:方便集成各种数据增强技术
4. 代码可维护性:使数据处理逻辑模块化,便于维护和复用
二、Dataset基础
PyTorch提供了两个核心类来处理数据:
- torch.utils.data.Dataset:抽象类,所有自定义数据集应继承此类
- torch.utils.data.DataLoader:数据加载器,负责批量生成数据
基本Dataset实现:
from torch.utils.data import Datasetclass CustomDataset(Dataset):def __init__(self, data, labels, transform=None):self.data = dataself.labels = labelsself.transform = transformdef __len__(self):return len(self.data)def __getitem__(self, idx):sample = self.data[idx]label = self.labels[idx]if self.transform:sample = self.transform(sample)return sample, label
三、常见数据预处理技术
1. 图像数据预处理
from torchvision import transforms# 常见的图像预处理流程
image_transform = transforms.Compose([transforms.Resize(256), # 调整大小transforms.CenterCrop(224), # 中心裁剪transforms.ToTensor(), # 转为Tensortransforms.Normalize( # 标准化mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
2. 文本数据预处理
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator# 分词器
tokenizer = get_tokenizer('basic_english')# 构建词汇表
def yield_tokens(data_iter):for text, _ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>", "<pad>"])
vocab.set_default_index(vocab["<unk>"])# 文本转tensor
def text_pipeline(text):return torch.tensor([vocab[token] for token in tokenizer(text)], dtype=torch.long)
3. 数值数据预处理
from sklearn.preprocessing import StandardScaler# 标准化数值特征
scaler = StandardScaler()
train_data = scaler.fit_transform(train_data)
test_data = scaler.transform(test_data) # 使用相同的scaler
四、高级Dataset技巧
1. 懒加载大数据集
对于大型数据集(如图像数据集),我们通常不希望一次性加载所有数据:
class LazyImageDataset(Dataset):def __init__(self, file_paths, labels, transform=None):self.file_paths = file_pathsself.labels = labelsself.transform = transformdef __getitem__(self, idx):img_path = self.file_paths[idx]image = Image.open(img_path).convert('RGB') # 按需加载if self.transform:image = self.transform(image)return image, self.labels[idx]
2. 多模态数据集处理
处理同时包含图像和文本的数据:
class MultiModalDataset(Dataset):def __init__(self, image_paths, texts, labels, image_transform, text_transform):self.image_paths = image_pathsself.texts = textsself.labels = labelsself.image_transform = image_transformself.text_transform = text_transformdef __getitem__(self, idx):image = Image.open(self.image_paths[idx])text = self.texts[idx]label = self.labels[idx]if self.image_transform:image = self.image_transform(image)if self.text_transform:text = self.text_transform(text)return {"image": image, "text": text}, label
3. 数据增强技巧
# 训练和验证时使用不同的预处理
train_transform = transforms.Compose([transforms.RandomResizedCrop(224),transforms.RandomHorizontalFlip(),transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])val_transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
五、实战:构建图像分类Dataset
让我们实现一个完整的图像分类数据集:
import osimport numpy as np
from PIL import Imagedef train_test_file(root,dir):file_txt=open(dir+'.txt','w')path=os.path.join(root,dir)for roots,directories,files in os.walk(path):if len(directories) !=0:dirs=directorieselse:now_dir=roots.split('\\')for file in files:path_1=os.path.join(roots,file)print(path_1)file_txt.write(path_1+' '+str(dirs.index(now_dir[-1]))+'\n')file_txt.close()
root=r'.\食物分类\food_dataset'
train_dir='train'
test_dir='test'
train_test_file(root,train_dir)
train_test_file(root,test_dir)import torch
from torch import nn #导入神经网络模块,
from torch.utils.data import DataLoader #数据包管理工具,打包数据,
from torchvision import transforms
from torch.utils.data import Datasetdata_transforms={
'train':
transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45), # 随机旋转,-45到45度之间随机选transforms.CenterCrop(256), # 从中心开始裁剪[256,256]transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转 选择一个概率概率transforms.RandomVerticalFlip(p=0.5), # 随机垂直翻转transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1), # 概率转换成灰度率,3通道就是R=G=Btransforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
]),
'valid':
transforms.Compose([transforms.Resize([256, 256]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])}food_type={0:"八宝粥",1:"巴旦木",2:"白萝卜",3:"板栗",4:"菠萝",5:"草莓",6:"蛋",7:"蛋挞",8:"骨肉相连",9:"瓜子",10:"哈密瓜",11:"汉堡",12:"胡萝卜",13:"火龙果",14:"鸡翅",15:"青菜",16:"生肉",17:"圣女果",18:"薯条",19:"炸鸡"}class food_dataset(Dataset):def __init__(self,file_path,transform=None):self.file_path=file_pathself.imgs=[]self.labels=[]self.transform=transformwith open(self.file_path) as f:samples=[x.strip().split(' ') for x in f.readlines()]for img_path,label in samples:self.imgs.append(img_path)self.labels.append(label)def __len__(self):return len(self.imgs)def __getitem__(self, idx):image=Image.open(self.imgs[idx])if self.transform:image=self.transform(image)label = self.labels[idx]label = torch.from_numpy(np.array(label,dtype=np.int64))return image,labeltraining_data=food_dataset(file_path='train.txt', transform=data_transforms['train'])
test_data=food_dataset(file_path='test.txt', transform=data_transforms['valid'])train_dataloader=DataLoader(training_data,batch_size=64)
test_dataloader=DataLoader(test_data,batch_size=64)'''断当前设备是否支持GPU,其中mps是苹果m系列芯片的GPU。'''
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
print(f"Using {device} device") #字符串的格式化'''定义神经网络 类的继承'''
class CNN(nn.Module): # 通过调用类的形式来使用神经网络,神经网络的模型nn.moudledef __init__(self):super().__init__() # 继承父类的初始化self.conv1=nn.Sequential(nn.Conv2d(in_channels=3,out_channels=16,kernel_size=5,stride=1,padding=2,),nn.ReLU(), #(16,28,28)nn.MaxPool2d(kernel_size=2) #(16,14,14))self.conv2=nn.Sequential(nn.Conv2d(16,32,5,1,2), #32,14,14nn.ReLU(),)self.conv3=nn.Sequential(nn.Conv2d(32,64,5,1,2), #128,7,7nn.ReLU())self.out=nn.Linear(64*128*128,20)def forward(self, x): # 前向传播,指明数据的流向,使神经网络连接起来,函数名称不能修改x=self.conv1(x)x=self.conv2(x)x=self.conv3(x)x=x.view(x.size(0),-1)out=self.out(x)return outmodel = CNN().to(device)
print(model)def train(dataloader,model,loss_fn,optimizer):model.train() #告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w。在训练过程中,w会被修改的
#pytorch提供2种方式来切换训练和测试的模式,分别是:model.train()和 model.eval()。
#一般用法是:在训练开始之前写上model.trian(),在测试时写上 model.eval()batch_size_num=1for X,y in dataloader: #其中batch为每一个数据的编号X,y=X.to(device),y.to(device) #把训练数据集和标签传入cpu或GPUpred=model.forward(X) #.forward可以被省略,父类中已经对次功能进行了设置。自动初始化loss=loss_fn(pred,y) #通过交叉熵损失函数计算损失值loss# Backpropagation 进来一个batch的数据,计算一次梯度,更新一次网络optimizer.zero_grad() #梯度值清零loss.backward() #反向传播计算得到每个参数的梯度值woptimizer.step() #根据梯度更新网络w参数loss_value=loss.item() #从tensor数据中提取数据出来,tensor获取损失值if batch_size_num %1 ==0:print(f'loss:{loss:>7f} [number:{batch_size_num}]')batch_size_num+=1def test(dataloader, model, loss_fn):size = len(dataloader.dataset)num_batches = len(dataloader)model.eval()test_loss, correct = 0, 0with torch.no_grad():for X, y in dataloader:X, y = X.to(device), y.to(device)pred = model.forward(X)test_loss += loss_fn(pred, y).item()correct += (pred.argmax(1) == y).type(torch.float).sum().item()a = (pred.argmax(1) == y)b = (pred.argmax(1) == y).type(torch.float)test_loss /= num_batchescorrect /= size# print(food_type)# print(pred.argmax(1).tolist())# print(y.tolist())result=zip(pred.argmax(1).tolist(),y.tolist())for i in result:print(f"当前测试的结果为:{food_type[i[0]]},当前真实的结果为:{food_type[i[1]]}")print(f"Test result:\n Accurracy:{(100 * correct)}%,AVG loss:{test_loss}")test_loss /=num_batchescorrect /=sizeprint(f'Test result: \n Accuracy: {(100*correct)}%, Avg loss: {test_loss}')loss_fn=nn.CrossEntropyLoss() #创建交叉熵损失函数对象,因为手写字识别中一共有10个数字,输出会有10个结果
optimizer=torch.optim.Adam(model.parameters(),lr=0.01) #创建一个优化器,SGD为随机梯度下降算法
# #params:要训练的参数,一般我们传入的都是model.parameters()#
# lr:learning_rate学习率,也就是步长#loss表示模型训练后的输出结果与,样本标签的差距。如果差距越小,就表示模型训练越好,越逼近干真实的模型。# train(train_dataloader,model,loss_fn,optimizer)
# test(test_dataloader,model,loss_fn)epoch=10
for i in range(epoch):print(i + 1)train(train_dataloader, model, loss_fn, optimizer)test(test_dataloader, model, loss_fn)
总结
数据预处理是深度学习项目成功的关键因素之一。通过合理设计Dataset类,我们可以:
1. 实现高效的数据加载和预处理
2. 方便地应用各种数据增强技术
3. 保持代码的整洁和可维护性
4. 轻松处理不同类型的数据(图像、文本、音频等)