Python torchvision.datasets 下常用数据集配置和使用方法
torchvision.datasets
提供了许多常用的数据集类,以下是一些常用方法(datasets中有大量数据集处理方法,这里仅展示了部分数据集处理方法
)及其实现类的介绍、用法和输入参数解释:
CIFAR
-
CIFAR10 :包含 10 个类别的彩色图像数据集,共有 60000 张 32x32 的图像,其中训练集 50000 张,测试集 10000 张。
-
参数:
-
root
:数据集根目录,用于存储或检索数据集。 -
train
:布尔值,若为 True,则创建训练集;否则创建测试集。 -
transform
:可调用函数,对 PIL 图像进行变换并返回变换后的版本,如transforms.RandomCrop
。 -
target_transform
:可调用函数,对目标进行变换。 -
download
:布尔值,若为 True,则从网上下载数据集并放入根目录,若数据集已存在则不再下载。
-
-
-
CIFAR100 :与 CIFAR10 类似,但包含 100 个类别,每个类别有 600 张图像,共 60000 张图像,训练集和测试集分别为 50000 和 10000 张。
使用方法演示:
import torchvision
import torchvision.transforms as transforms# 定义数据预处理方式
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2
)# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform
)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标签
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(' '.join(f'{trainset.classes[labels[j]]}' for j in range(4)))
MNIST
-
MNIST :手写数字数据集,包含 60000 张训练图像和 10000 张测试图像,每个图像是 28x28 的灰度图像。
-
参数:
-
root
:数据集根目录。 -
train
:布尔值,指示是否为训练集。 -
transform
:对图像进行变换的函数。 -
target_transform
:对目标进行变换的函数。 -
download
:布尔值,指示是否下载数据集。
-
-
-
Fashion-MNIST :与 MNIST 类似,但包含时尚产品的灰度图像,共 10 个类别,每个类别 7000 张图像,训练集和测试集分别为 60000 和 10000 张。
-
KMNIST :与 MNIST 类似,但基于日文字符数据集,用于替代 MNIST 数据集进行实验。
使用方法演示:
import torchvision
import torchvision.transforms as transforms
# 加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor()])trainset = torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2
)testset = torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=transform
)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标签
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(' '.join(f'{labels[j].item()}' for j in range(4)))
ImageFolder
-
ImageFolder不是具体某个数据集。它是一个通用的图像数据加载工具类,是PyTorch中用于加载图像分类数据集的一个实用类,特别适用于将文件夹结构映射到类别标签上。它能够自动遍历指定目录下的所有子文件夹,并将每个子文件夹视为一个不同的类别或标签,从而将图像数据加载并组织成PyTorch的Dataset对象
-
参数:
-
root
:根目录路径,图像按类别文件夹存放。 -
transform
:对 PIL 图像进行变换的函数。 -
target_transform
:对目标进行变换的函数。 -
loader
:给定路径加载图像的函数。 -
is_valid_file
:检查图像文件是否有效的函数。
-
-
使用方法演示:
import torchvision
import torchvision.transforms as transforms
# 假设数据存储在 ./data/custom_data/ 下,按文件夹分类
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])dataset = torchvision.datasets.ImageFolder(root='./data/custom_data/',transform=transform
)
dataloader = torch.utils.data.DataLoader(dataset,batch_size=4,shuffle=True,num_workers=2
)# 输出数据集中的前 4 张图像及其标签
dataiter = iter(dataloader)
images, labels = next(dataiter)
print(' '.join(f'{dataset.classes[labels[j]]}' for j in range(4)))
COCO
-
CocoCaptions :
torchvision.datasets
中的 COCO 数据集,全称为 Microsoft Common Objects in Context(MS COCO),是计算机视觉领域一个大规模、丰富的数据集,广泛应用于目标检测、分割和字幕生成等任务。包含超过 33 万张图片,其中 20 万张有标注。每张图片有 5 个字幕描述场景,标注了 80 个对象类别和 91 种材料类别,还包含 150 万个目标实例,25 万个带关键点标注的行人。适用于目标检测、图像分割、关键点检测、图像字幕生成等多种任务。-
参数:
-
root
:存储图像的根目录。 -
annFile
:标注文件的路径。 -
transform
:对 PIL 图像进行变换的函数。 -
target_transform
:对目标进行变换的函数。 -
transforms
:同时对输入样本及其目标进行变换的函数。
-
-
-
CocoDetection :COCO 数据集中的目标检测任务数据集,参数与 CocoCaptions 类似。
import torchvision
import torchvision.transforms as transforms# 定义数据预处理方式
transform = transforms.Compose([transforms.ToTensor(),
])# 加载 COCO 数据集中的目标检测任务数据集
train_dataset = torchvision.datasets.CocoDetection(root='./data/coco/images/train2017',annFile='./data/coco/annotations/instances_train2017.json',transform=transform
)# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2
)# 输出数据集中的前 4 张图像及其标注信息
dataiter = iter(train_loader)
images, targets = next(dataiter)
for i in range(4):print(f"图像 {i+1} 的标注信息:{targets[i]}")
SVHN
-
SVHN :SVHN(Street View House Numbers)数据集是一个用于数字识别、目标检测等任务的计算机视觉数据集,由斯坦福大学于 2011 年发布,数据均采自 Google 街景图像中的门牌号码。主要用于训练数字识别模型,例如在停车场管理系统、道路交通管理系统等场景中,自动识别车辆号码中的数字。训练集包含 73257 张图像,用于训练模型。测试集包含 26032 张图像,用于评估模型性能。所有图像的大小为 32×32 像素,通道数为 3,是彩色图像,这增加了数据的复杂性,使其更接近现实世界的场景。对应 10 个类别,即数字 0 至 9,其中数字 “0” 的标签为 10。每个图像中可能包含多个数字,因此是一个多标签分类问题。
- 参数:
-
download
:布尔值,指示是否下载数据集。 -
target_transform
:对目标进行变换的函数。 -
transform
:对图像进行变换的函数。 -
split
:指定数据集的分割方式,可选值为 'train'、'test' 或 'extra'。 -
root
:数据集根目录。
-
import torchvision
import torchvision.transforms as transforms# 定义数据预处理方式
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载 SVHN 训练集和测试集
train_dataset = torchvision.datasets.SVHN(root='./data',split='train',download=True,transform=transform
)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2
)test_dataset = torchvision.datasets.SVHN(root='./data',split='test',download=True,transform=transform
)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标签
dataiter = iter(train_loader)
images, labels = next(dataiter)
print(' '.join(f'{labels[j].item()}' for j in range(4)))
STL10
STL10 数据集是一个用于开发无监督特征学习、深度学习和自监督学习算法的图像识别数据集。所有图像均为 96×96 像素的彩色图像,比 CIFAR-10 的图像分辨率更高。共有 113000 张图像,包含 10 个类别,如飞机、鸟、车等。训练集有 5000 张标注图像,每个类别 500 张;测试集有 8000 张标注图像,每个类别 800 张;还有 100000 张无标签图像用于无监督学习。
-
参数:
-
root
:数据集根目录。 -
split
:指定数据集的分割方式,可选值为 'train'、'test'、'unlabeled' 等。 -
folds
:指定使用哪个折叠的数据,可选值为 0-9 或 None。 -
transform
:对图像进行变换的函数。 -
target_transform
:对目标进行变换的函数。 -
download
:布尔值,指示是否下载数据集。
-
# 加载 STL10 训练集和测试集
train_dataset = torchvision.datasets.STL10(root='./data',split='train',download=True,transform=transform
)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2
)test_dataset = torchvision.datasets.STL10(root='./data',split='test',download=True,transform=transform
)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标签
dataiter = iter(train_loader)
images, labels = next(dataiter)
print(' '.join(f'{labels[j].item()}' for j in range(4)))
VOC
PASCAL VOC(Visual Object Classes)数据集是计算机视觉领域的重要资源,最初由英国牛津大学的计算机视觉小组创建,并在 PASCAL VOC 挑战赛中使用。VOC 2007 包含 9963 张标注过的图片,标注出 24640 个物体;VOC 2012 包含 11540 张图片,共 27450 个物体,类别数均为 20,包括人、动物(猫、狗、鸟等)、交通工具(汽车、自行车、飞机等)和室内物品(椅子、桌子等)等。
- 参数:
-
transforms
:同时对输入样本及其目标进行变换的函数。 -
target_transform
:对目标进行变换的函数。 -
transform
:对 PIL 图像进行变换的函数。 -
download
:布尔值,指示是否下载数据集。 -
image_set
:选择使用的图像集,可选值为 'train'、'trainval' 或 'val',若 year 为 2007,还可以是 'test'。 -
year
:数据集年份,支持 2007 到 2012 年。 -
root
:VOC 数据集的根目录。
-
-
VOCDetection :Pascal VOC 数据集中的目标检测任务数据集,参数与 VOCSegmentation 类似。
# 加载 VOC 数据集中的目标检测任务数据集
train_dataset = torchvision.datasets.VOCDetection(root='./data',year='2007',image_set='train',download=True
)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2
)test_dataset = torchvision.datasets.VOCDetection(root='./data',year='2007',image_set='val',download=True
)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标注信息
dataiter = iter(train_loader)
for i in range(4):image, target = next(dataiter)print(f"图像 {i+1} 的标注信息:{target}")