当前位置: 首页 > news >正文

Python torchvision.datasets 下常用数据集配置和使用方法

torchvision.datasets 提供了许多常用的数据集类,以下是一些常用方法(datasets中有大量数据集处理方法,这里仅展示了部分数据集处理方法)及其实现类的介绍、用法和输入参数解释:

CIFAR

  • CIFAR10 :包含 10 个类别的彩色图像数据集,共有 60000 张 32x32 的图像,其中训练集 50000 张,测试集 10000 张。

    • 参数:

      • root :数据集根目录,用于存储或检索数据集。

      • train :布尔值,若为 True,则创建训练集;否则创建测试集。

      • transform :可调用函数,对 PIL 图像进行变换并返回变换后的版本,如 transforms.RandomCrop

      • target_transform :可调用函数,对目标进行变换。

      • download :布尔值,若为 True,则从网上下载数据集并放入根目录,若数据集已存在则不再下载。

  • CIFAR100 :与 CIFAR10 类似,但包含 100 个类别,每个类别有 600 张图像,共 60000 张图像,训练集和测试集分别为 50000 和 10000 张。

使用方法演示:

import torchvision
import torchvision.transforms as transforms# 定义数据预处理方式
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data',train=True,download=True,transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2
)# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data',train=False,download=True,transform=transform
)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标签
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(' '.join(f'{trainset.classes[labels[j]]}' for j in range(4)))

MNIST

  • MNIST :手写数字数据集,包含 60000 张训练图像和 10000 张测试图像,每个图像是 28x28 的灰度图像。

    • 参数:

      • root :数据集根目录。

      • train :布尔值,指示是否为训练集。

      • transform :对图像进行变换的函数。

      • target_transform :对目标进行变换的函数。

      • download :布尔值,指示是否下载数据集。

  • Fashion-MNIST :与 MNIST 类似,但包含时尚产品的灰度图像,共 10 个类别,每个类别 7000 张图像,训练集和测试集分别为 60000 和 10000 张。

  • KMNIST :与 MNIST 类似,但基于日文字符数据集,用于替代 MNIST 数据集进行实验。

使用方法演示:

import torchvision
import torchvision.transforms as transforms
# 加载 MNIST 数据集
transform = transforms.Compose([transforms.ToTensor()])trainset = torchvision.datasets.MNIST(root='./data',train=True,download=True,transform=transform
)
trainloader = torch.utils.data.DataLoader(trainset,batch_size=4,shuffle=True,num_workers=2
)testset = torchvision.datasets.MNIST(root='./data',train=False,download=True,transform=transform
)
testloader = torch.utils.data.DataLoader(testset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标签
dataiter = iter(trainloader)
images, labels = next(dataiter)
print(' '.join(f'{labels[j].item()}' for j in range(4)))

ImageFolder

  • ImageFolder不是具体某个数据集。它是一个通用的图像数据加载工具类,是PyTorch中用于加载图像分类数据集的一个实用类,特别适用于将文件夹结构映射到类别标签上‌。它能够自动遍历指定目录下的所有子文件夹,并将每个子文件夹视为一个不同的类别或标签,从而将图像数据加载并组织成PyTorch的Dataset对象‌

    • 参数:

      • root :根目录路径,图像按类别文件夹存放。

      • transform :对 PIL 图像进行变换的函数。

      • target_transform :对目标进行变换的函数。

      • loader :给定路径加载图像的函数。

      • is_valid_file :检查图像文件是否有效的函数。

使用方法演示:

import torchvision
import torchvision.transforms as transforms
# 假设数据存储在 ./data/custom_data/ 下,按文件夹分类
transform = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])dataset = torchvision.datasets.ImageFolder(root='./data/custom_data/',transform=transform
)
dataloader = torch.utils.data.DataLoader(dataset,batch_size=4,shuffle=True,num_workers=2
)# 输出数据集中的前 4 张图像及其标签
dataiter = iter(dataloader)
images, labels = next(dataiter)
print(' '.join(f'{dataset.classes[labels[j]]}' for j in range(4)))

COCO

  • CocoCaptionstorchvision.datasets 中的 COCO 数据集,全称为 Microsoft Common Objects in Context(MS COCO),是计算机视觉领域一个大规模、丰富的数据集,广泛应用于目标检测、分割和字幕生成等任务。包含超过 33 万张图片,其中 20 万张有标注。每张图片有 5 个字幕描述场景,标注了 80 个对象类别和 91 种材料类别,还包含 150 万个目标实例,25 万个带关键点标注的行人。适用于目标检测、图像分割、关键点检测、图像字幕生成等多种任务。

    • 参数:

      • root :存储图像的根目录。

      • annFile :标注文件的路径。

      • transform :对 PIL 图像进行变换的函数。

      • target_transform :对目标进行变换的函数。

      • transforms :同时对输入样本及其目标进行变换的函数。

  • CocoDetection :COCO 数据集中的目标检测任务数据集,参数与 CocoCaptions 类似。

import torchvision
import torchvision.transforms as transforms# 定义数据预处理方式
transform = transforms.Compose([transforms.ToTensor(),
])# 加载 COCO 数据集中的目标检测任务数据集
train_dataset = torchvision.datasets.CocoDetection(root='./data/coco/images/train2017',annFile='./data/coco/annotations/instances_train2017.json',transform=transform
)# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2
)# 输出数据集中的前 4 张图像及其标注信息
dataiter = iter(train_loader)
images, targets = next(dataiter)
for i in range(4):print(f"图像 {i+1} 的标注信息:{targets[i]}")

 SVHN

  • SVHN :SVHN(Street View House Numbers)数据集是一个用于数字识别、目标检测等任务的计算机视觉数据集,由斯坦福大学于 2011 年发布,数据均采自 Google 街景图像中的门牌号码。主要用于训练数字识别模型,例如在停车场管理系统、道路交通管理系统等场景中,自动识别车辆号码中的数字。训练集包含 73257 张图像,用于训练模型。测试集包含 26032 张图像,用于评估模型性能。所有图像的大小为 32×32 像素,通道数为 3,是彩色图像,这增加了数据的复杂性,使其更接近现实世界的场景。对应 10 个类别,即数字 0 至 9,其中数字 “0” 的标签为 10。每个图像中可能包含多个数字,因此是一个多标签分类问题。

  • 参数:
    • download :布尔值,指示是否下载数据集。

    • target_transform :对目标进行变换的函数。

    • transform :对图像进行变换的函数。

    • split :指定数据集的分割方式,可选值为 'train'、'test' 或 'extra'。

    • root :数据集根目录。

import torchvision
import torchvision.transforms as transforms# 定义数据预处理方式
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])# 加载 SVHN 训练集和测试集
train_dataset = torchvision.datasets.SVHN(root='./data',split='train',download=True,transform=transform
)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2
)test_dataset = torchvision.datasets.SVHN(root='./data',split='test',download=True,transform=transform
)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标签
dataiter = iter(train_loader)
images, labels = next(dataiter)
print(' '.join(f'{labels[j].item()}' for j in range(4)))

 STL10

STL10 数据集是一个用于开发无监督特征学习、深度学习和自监督学习算法的图像识别数据集。所有图像均为 96×96 像素的彩色图像,比 CIFAR-10 的图像分辨率更高。共有 113000 张图像,包含 10 个类别,如飞机、鸟、车等。训练集有 5000 张标注图像,每个类别 500 张;测试集有 8000 张标注图像,每个类别 800 张;还有 100000 张无标签图像用于无监督学习。

  • 参数:

    • root :数据集根目录。

    • split :指定数据集的分割方式,可选值为 'train'、'test'、'unlabeled' 等。

    • folds :指定使用哪个折叠的数据,可选值为 0-9 或 None。

    • transform :对图像进行变换的函数。

    • target_transform :对目标进行变换的函数。

    • download :布尔值,指示是否下载数据集。

# 加载 STL10 训练集和测试集
train_dataset = torchvision.datasets.STL10(root='./data',split='train',download=True,transform=transform
)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2
)test_dataset = torchvision.datasets.STL10(root='./data',split='test',download=True,transform=transform
)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标签
dataiter = iter(train_loader)
images, labels = next(dataiter)
print(' '.join(f'{labels[j].item()}' for j in range(4)))

VOC

PASCAL VOC(Visual Object Classes)数据集是计算机视觉领域的重要资源,最初由英国牛津大学的计算机视觉小组创建,并在 PASCAL VOC 挑战赛中使用。VOC 2007 包含 9963 张标注过的图片,标注出 24640 个物体;VOC 2012 包含 11540 张图片,共 27450 个物体,类别数均为 20,包括人、动物(猫、狗、鸟等)、交通工具(汽车、自行车、飞机等)和室内物品(椅子、桌子等)等。

  • 参数:
    • transforms :同时对输入样本及其目标进行变换的函数。

    • target_transform :对目标进行变换的函数。

    • transform :对 PIL 图像进行变换的函数。

    • download :布尔值,指示是否下载数据集。

    • image_set :选择使用的图像集,可选值为 'train'、'trainval' 或 'val',若 year 为 2007,还可以是 'test'。

    • year :数据集年份,支持 2007 到 2012 年。

    • root :VOC 数据集的根目录。

  • VOCDetection :Pascal VOC 数据集中的目标检测任务数据集,参数与 VOCSegmentation 类似。

# 加载 VOC 数据集中的目标检测任务数据集
train_dataset = torchvision.datasets.VOCDetection(root='./data',year='2007',image_set='train',download=True
)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size=4,shuffle=True,num_workers=2
)test_dataset = torchvision.datasets.VOCDetection(root='./data',year='2007',image_set='val',download=True
)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size=4,shuffle=False,num_workers=2
)# 输出训练集中的前 4 张图像及其标注信息
dataiter = iter(train_loader)
for i in range(4):image, target = next(dataiter)print(f"图像 {i+1} 的标注信息:{target}")

相关文章:

  • 如何根据需求选择合适的氢气监测分析仪?
  • C++ Lambda 表达式
  • 24FIC 决赛 计算机部分
  • SAP SuccessFactors Recruiting and Onboarding The Comprehensive Guide
  • [250423] Caddy 2.10 正式发布:引入 ECH、后量子加密等重要更新
  • 基于javaweb的SpringBoot校园服务平台系统设计与实现(源码+文档+部署讲解)
  • 差分探头关键性能参数解析
  • 【Python语言基础】24、并发编程
  • 单片机 + 图像处理芯片 + TFT彩屏 触摸滑动条控件
  • github 简单访问方法(无魔法)
  • YOLOv8 涨点新方案:SlideLoss FocalLoss 优化,小目标检测效果炸裂!
  • LeetCode算法题(Go语言实现)_60
  • 【python】一文掌握 markitdown 库的操作(用于将文件和办公文档转换为Markdown的Python工具)
  • 第1讲:Transformers 的崛起:从RNN到Self-Attention
  • 【AI提示词】艺人顾问
  • 实验三 进程间通信实验
  • Flink介绍——实时计算核心论文之Flink论文
  • 入门-C编程基础部分:19、输入 输出
  • nuxt3持久化存储全局变量
  • 深入浅出:Pinctrl与GPIO子系统详解
  • 印媒称印巴在克什米尔控制线沿线发生小规模交火,巴方暂未回应
  • 政治局会议:要提高中低收入群体收入,设立服务消费与养老再贷款
  • 女子隐私被“上墙”莫名遭网暴,网警揪出始作俑者
  • 长三角议事厅|国际产业转移对中国产业链韧性的影响与对策
  • “70后”女博士张姿卸任国家国防科技工业局副局长
  • 百年前的亚裔艺术家与巴黎