当前位置: 首页 > news >正文

Pytorch图像数据转为Tensor张量

PyTorch的所有模型(nn.Module)都只接受Tensor格式的输入,所以我们在使用图像数据集时,必须将图像转换为Tensor格式。PyTorch提供了torchvision.transforms模块来处理图像数据集。torchvision.transforms模块提供了一些常用的图像预处理方法,如Resize、CenterCrop、RandomCrop、RandomHorizontalFlip等。
torchvision.transforms模块还提供了ToTensor()方法,可以将PIL格式的图像转换为Tensor格式。ToTensor()方法会将图像的像素值从[0, 255]范围缩放到[0, 1]范围,并将图像的通道顺序从(H, W, C)转换为(C, H, W)。
本文以Pytorch中TorchVision的FashionMNIST数据集为例,展示如何将图像数据转换为Tensor格式。

import torch
from torchvision.datasets import FashionMNIST
from torchvision.transforms import ToTensor #将图像数据转换为张量 #加载数据集
train_data = FashionMNIST(root='./fashion_data', train=True, download=True) 
test_data = FashionMNIST(root='./fashion_data', train=False, download=True)

未转换为Tensor格式的FashionMNIST数据集

train_data  #Dataset对象(输入数据的集合) 60000个样本
Dataset FashionMNISTNumber of datapoints: 60000Root location: ./fashion_dataSplit: Train
import matplotlib.pyplot as plt train_data = FashionMNIST(root='./fashion_data', train=True, download=True) img,clzz = train_data[5]  #返回一个元组,第一个元素是图片,第二个元素是标签
plt.imshow(img, cmap='gray') #显示图片,cmap='gray'表示以灰度图显示;img是一个tensor,是一个PIL.Image对象(python原始数据类型)
plt.title(clzz)
plt.show()

在这里插入图片描述

train_data[1]  #返回的是一个元组,第一个元素是图片,第二个元素是标签
#train_data[1][0].shape  #图像数据(一个颜色通道:28*28,图像高度,图像宽度)
#train_data[1][0].reshape(-1).shape  #将图像数据展平为一维张量
(<PIL.Image.Image image mode=L size=28x28>, 0)

转换为Tensor格式的FashionMNIST数据集

#加载数据集
train_data = FashionMNIST(root='./fashion_data', train=True, download=True, transform=ToTensor()) 
test_data = FashionMNIST(root='./fashion_data', train=False, download=True, transform=ToTensor())
train_data  #Dataset对象(输入数据的集合) 60000个样本
Dataset FashionMNISTNumber of datapoints: 60000Root location: ./fashion_dataSplit: TrainStandardTransform
Transform: ToTensor()

可以对比下图像数据集转换前后的效果,未转换为Tensor格式的FashionMNIST数据集是PIL格式的图像,而转换为Tensor格式的FashionMNIST数据集是Tensor格式的图像。

train_data[1]  #输出的是一个元组,第一个元素是图片,第二个元素是标签
(tensor([[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,0.0000, 0.0000, 0.1608, 0.7373, 0.4039, 0.2118, 0.1882, 0.1686,0.3412, 0.6588, 0.5216, 0.0627, 0.0000, 0.0000, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000, 0.0000, 0.1922,0.5333, 0.8588, 0.8471, 0.8941, 0.9255, 1.0000, 1.0000, 1.0000,1.0000, 0.8510, 0.8431, 0.9961, 0.9059, 0.6275, 0.1765, 0.0000,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0549, 0.6902, 0.8706,0.8784, 0.8314, 0.7961, 0.7765, 0.7686, 0.7843, 0.8431, 0.8000,0.7922, 0.7882, 0.7882, 0.7882, 0.8196, 0.8549, 0.8784, 0.6431,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.7373, 0.8588, 0.7843,0.7765, 0.7922, 0.7765, 0.7804, 0.7804, 0.7882, 0.7686, 0.7765,0.7765, 0.7843, 0.7843, 0.7843, 0.7843, 0.7882, 0.7843, 0.8824,0.1608, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.2000, 0.8588, 0.7804, 0.7961,0.7961, 0.8314, 0.9333, 0.9725, 0.9804, 0.9608, 0.9765, 0.9647,0.9686, 0.9882, 0.9725, 0.9216, 0.8118, 0.7961, 0.7961, 0.8706,0.5490, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.4549, 0.8863, 0.8078, 0.8000,0.8118, 0.8000, 0.3961, 0.2941, 0.1843, 0.2863, 0.1882, 0.1961,0.1765, 0.2000, 0.2471, 0.4431, 0.8706, 0.7922, 0.8078, 0.8627,0.8784, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.7843, 0.8706, 0.8196, 0.7961,0.8431, 0.7843, 0.0000, 0.2745, 0.3843, 0.0000, 0.4039, 0.2314,0.2667, 0.2784, 0.1922, 0.0000, 0.8588, 0.8078, 0.8392, 0.8235,0.9804, 0.1490, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.9686, 0.8549, 0.8314, 0.8235,0.8431, 0.8392, 0.0000, 0.9961, 0.9529, 0.5451, 1.0000, 0.6824,0.9843, 1.0000, 0.8039, 0.0000, 0.8431, 0.8510, 0.8392, 0.8157,0.8627, 0.3725, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.1765, 0.8863, 0.8392, 0.8392, 0.8431,0.8784, 0.8039, 0.0000, 0.1647, 0.1373, 0.2353, 0.0627, 0.0667,0.0471, 0.0510, 0.2745, 0.0000, 0.7412, 0.8471, 0.8314, 0.8078,0.8314, 0.6118, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.6431, 0.9216, 0.8392, 0.8275, 0.8627,0.8471, 0.7882, 0.2039, 0.2784, 0.3490, 0.3686, 0.3255, 0.3059,0.2745, 0.2980, 0.3608, 0.3412, 0.8078, 0.8118, 0.8706, 0.8353,0.8588, 0.8157, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.4157, 0.7333, 0.8745, 0.9294, 0.9725,0.8275, 0.7765, 0.9882, 0.9804, 0.9725, 0.9608, 0.9725, 0.9882,0.9922, 0.9804, 0.9882, 0.9373, 0.7882, 0.8314, 0.8824, 0.8431,0.7569, 0.4431, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0667, 0.2118, 0.6235,0.8706, 0.7569, 0.8157, 0.7529, 0.7725, 0.7843, 0.7843, 0.7843,0.7843, 0.7882, 0.7961, 0.7647, 0.8235, 0.6471, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.1843,0.8824, 0.7529, 0.8392, 0.7961, 0.8078, 0.8000, 0.8000, 0.8039,0.8078, 0.8000, 0.8314, 0.7725, 0.8549, 0.4196, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0235, 0.0000, 0.1804,0.8314, 0.7647, 0.8314, 0.7922, 0.8078, 0.8039, 0.8000, 0.8039,0.8078, 0.8000, 0.8314, 0.7843, 0.8549, 0.3569, 0.0000, 0.0118,0.0039, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0431,0.7725, 0.7804, 0.8039, 0.7922, 0.8039, 0.8078, 0.8000, 0.8039,0.8118, 0.8000, 0.8039, 0.8039, 0.8549, 0.3020, 0.0000, 0.0196,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0118, 0.0000, 0.0078,0.7490, 0.7765, 0.7882, 0.8039, 0.8078, 0.8039, 0.8039, 0.8078,0.8196, 0.8078, 0.7804, 0.8196, 0.8588, 0.2902, 0.0000, 0.0196,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0000,0.7373, 0.7725, 0.7843, 0.8118, 0.8118, 0.8000, 0.8118, 0.8118,0.8235, 0.8157, 0.7765, 0.8118, 0.8667, 0.2824, 0.0000, 0.0157,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0078, 0.0000, 0.0000,0.8431, 0.7765, 0.7961, 0.8078, 0.8157, 0.8039, 0.8118, 0.8118,0.8235, 0.8157, 0.7843, 0.7922, 0.8706, 0.2941, 0.0000, 0.0157,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,0.8314, 0.7765, 0.8196, 0.8078, 0.8196, 0.8078, 0.8157, 0.8118,0.8275, 0.8078, 0.8039, 0.7765, 0.8667, 0.3137, 0.0000, 0.0118,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,0.8000, 0.7882, 0.8039, 0.8157, 0.8118, 0.8039, 0.8275, 0.8039,0.8235, 0.8235, 0.8196, 0.7647, 0.8667, 0.3765, 0.0000, 0.0118,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,0.7922, 0.7882, 0.8039, 0.8196, 0.8118, 0.8039, 0.8353, 0.8078,0.8235, 0.8196, 0.8235, 0.7608, 0.8510, 0.4118, 0.0000, 0.0078,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,0.8000, 0.8000, 0.8039, 0.8157, 0.8118, 0.8039, 0.8431, 0.8118,0.8235, 0.8157, 0.8275, 0.7569, 0.8353, 0.4510, 0.0000, 0.0078,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.8000, 0.8118, 0.8118, 0.8157, 0.8078, 0.8078, 0.8431, 0.8235,0.8235, 0.8118, 0.8314, 0.7647, 0.8235, 0.4627, 0.0000, 0.0078,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,0.7765, 0.8157, 0.8157, 0.8157, 0.8000, 0.8118, 0.8314, 0.8314,0.8235, 0.8118, 0.8275, 0.7686, 0.8118, 0.4745, 0.0000, 0.0039,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,0.7765, 0.8235, 0.8118, 0.8157, 0.8078, 0.8196, 0.8353, 0.8314,0.8275, 0.8118, 0.8235, 0.7725, 0.8118, 0.4863, 0.0000, 0.0039,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.6745, 0.8235, 0.7961, 0.7882, 0.7804, 0.8000, 0.8118, 0.8039,0.8000, 0.7882, 0.8039, 0.7725, 0.8078, 0.4980, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,0.7373, 0.8667, 0.8392, 0.9176, 0.9255, 0.9333, 0.9569, 0.9569,0.9569, 0.9412, 0.9529, 0.8392, 0.8784, 0.6353, 0.0000, 0.0078,0.0000, 0.0000, 0.0000, 0.0000],[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0039, 0.0000, 0.0000,0.5451, 0.5725, 0.5098, 0.5294, 0.5294, 0.5373, 0.4902, 0.4863,0.4902, 0.4745, 0.4667, 0.4471, 0.5098, 0.2980, 0.0000, 0.0000,0.0000, 0.0000, 0.0000, 0.0000]]]),0)

此时,图像数据集的每个图像都是一个Tensor对象,Tensor对象的形状为(C, H, W),其中C表示图像的通道数,H表示图像的高度,W表示图像的宽度。对于灰度图像,C=1;对于RGB图像,C=3。Tensor对象的dtype为torch.float32。
Tensor对象的像素值范围为[0, 1],而PIL格式的图像的像素值范围为[0, 255]。Tensor对象的像素值是浮点数,而PIL格式的图像的像素值是整数。

train_data[1][0].shape  #图像数据(一个颜色通道:28*28,图像高度,图像宽度)
torch.Size([1, 28, 28])

上面输出的是FashionMNIST数据集中第0个图像的Tensor对象。Tensor对象的形状为(1, 28, 28),表示图像的通道数为1,高度为28,宽度为28。Tensor对象的dtype为torch.float32,表示数据类型为32位浮点数。Tensor对象的像素值范围为[0, 1],表示像素值是浮点数。

train_data[1][0].reshape(-1).shape  #将图像数据展平为一维张量
torch.Size([784])

相关文章:

  • 大厂面试:MySQL篇
  • create_function()漏洞利用
  • centos stream 10 修改 metric
  • LSTM-GAN生成数据技术
  • 4. 继承基类实现浏览器_Chrome
  • 6.1.多级缓存架构
  • 【Axure高保真原型】动态折线图
  • MongoDB Ubuntu 安装
  • 智能文档解析系统架构师角色定义
  • 智驭未来:NVIDIA自动驾驶安全白皮书与实验室创新实践深度解析
  • Axure按钮设计分享:打造高效交互体验的六大按钮类型
  • Anomize: Better Open Vocabulary Video Anomaly Detection
  • 3.第三章:数据治理的战略价值
  • 初识Redis · 持久化
  • 配置 Nginx 的 HTTPS
  • 分布式理论和事务
  • Python常用的第三方模块之【jieba库】支持三种分词模式:精确模式、全模式和搜索引擎模式(提高召回率)
  • 从Nacos derby RCE学习derby数据库的利用
  • 【Linux】冯诺依曼体系结构及操作系统架构图的具体剖析
  • Redisson Watchdog实现原理与源码解析:分布式锁的自动续期机制
  • “家门口的图书馆”有多好?上海静安区居民给出答案
  • 时隔七年,上合组织国家电影节再度在中国举办
  • 谁为金子疯狂:有人贷款十万博两千,有人不敢再贸然囤货
  • 88岁罗马教皇方济各突然去世,遗嘱内容对外公布
  • 澳大利亚大选提前投票开始
  • 【社论】地铁读书人也是一道城市风景