Python torchvision.transforms 下常用图像处理方法
torchvision.transforms
是 PyTorch 用于处理图像数据的一个模块,提供了丰富的图像变换功能。
1. transforms.Compose 的使用方法
transforms.Compose
用于将多个 transforms
操作组合起来,形成一个变换序列,然后按顺序对图像进行处理。其输入参数是一个包含多个 transforms
操作的列表或元组,这些操作会按照在列表或元组中的顺序依次执行。
2. transforms 下常用的图像处理方法及输入参数
格式转换
ToTensor
功能:将 PIL
图像或 NumPy
数组转换为 torch.FloatTensor
,并将像素值归一化到 [0.0,1.0]。 输入参数:
-
输入图像:
PIL
图像或NumPy
数组。 输出:归一化后的torch.FloatTensor
,其维度为 [C,H,W],其中 C 是通道数,H 是高度,W 是宽度。
ToPILImage
功能:将 torch.Tensor
或 NumPy
数组转换为 PIL
图像。 输入参数:
-
tensor
:输入的张量,要求是torch.Tensor
或NumPy
数组,其形状应为 [C,H,W] 或 [H,W]。 -
mode
:可选参数,用于指定输出图像的模式(如 "RGB"、"L" 等)。 输出:转换后的PIL
图像。
Normalize
功能:对张量进行标准化处理,公式为 (input−mean)/std。 输入参数:
-
mean
:均值序列,长度应与输入张量的通道数一致,可以是浮点数列表或元组。 -
std
:标准差序列,长度应与输入张量的通道数一致,可以是浮点数列表或元组。 -
inplace
:可选参数,默认为False
。若为True
,则在原张量上进行操作。 输出:标准化后的张量。
ConvertImageDtype
功能:将图像张量的数据类型转换为目标数据类型。 输入参数:
-
dtype
:目标数据类型,如torch.float32
、torch.uint8
等。 输出:转换数据类型后的张量。
常用方法的演示:
import torch
from torchvision.transforms import ToTensor, ToPILImage, Normalize, ConvertImageDtype
from PIL import Image
import numpy as np# 创建示例图像
img_pil = Image.new('RGB', (100, 100)) # 创建一个空白的 RGB 图像# 1. 使用 ToTensor
tensor_converter = ToTensor()
img_tensor = tensor_converter(img_pil)
print(f"ToTensor 输出张量的形状: {img_tensor.shape}, 数据类型: {img_tensor.dtype}")# 2. 使用 Normalize
normalizer = Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
normalized_tensor = normalizer(img_tensor)
print(f"Normalize 输出张量的形状: {normalized_tensor.shape}, 数据类型: {normalized_tensor.dtype}")# 3. 使用 ConvertImageDtype
dtype_converter = ConvertImageDtype(dtype=torch.float16)
converted_tensor = dtype_converter(normalized_tensor)
print(f"ConvertImageDtype 输出张量的形状: {converted_tensor.shape}, 数据类型: {converted_tensor.dtype}")# 4. 使用 ToPILImage
pil_converter = ToPILImage()
img_back_to_pil = pil_converter(converted_tensor.to(torch.float32)) # 转换回 float32 以避免类型问题
print(f"ToPILImage 输出图像的模式: {img_back_to_pil.mode}")
假设 img_pil
是一个 100x100 的 RGB 图像,示例代码的输出可能如下:
ToTensor 输出张量的形状: torch.Size([3, 100, 100]), 数据类型: torch.float32
Normalize 输出张量的形状: torch.Size([3, 100, 100]), 数据类型: torch.float32
ConvertImageDtype 输出张量的形状: torch.Size([3, 100, 100]), 数据类型: torch.float16
ToPILImage 输出图像的模式: RGB
基本变换
transforms.Resize
-
功能 :调整图像大小。若
size
为整数,则图像的较短边会被缩放到该大小,而较长边会按比例缩放以保持宽高比;若size
为元组,则图像会被直接调整为指定的大小。 -
输入参数 :
-
size
:指定调整后的大小,可以是int
或tuple
类型。例如,size=256
表示将图像的较短边缩放到 256,而size=(256, 128)
则表示将图像调整为 256×128 的大小。 -
interpolation
:指定插值方法,默认为PIL.Image.BICUBIC
,即双三次插值。其他常见的插值方法还包括:-
PIL.Image.NEAREST
:最近邻插值。 -
PIL.Image.LANCZOS
:兰索斯插值。 -
PIL.Image.BILINEAR
:双线性插值。
-
-
max_size
:可选参数,用于指定调整后图像的最大边长。如果调整后的图像的长边超过该值,则会根据max_size
进行进一步的调整。 -
antialias
:可选参数,是否使用抗锯齿算法。默认为None
,在某些版本中可能被默认启用或禁用。
-
from torchvision import transforms
from PIL import Image# 创建一个 Resize 变换
resize_transform = transforms.Resize(size=(256, 256), interpolation=Image.BILINEAR)# 打开图像
image = Image.open("image.jpg")# 应用变换
resized_image = resize_transform(image)# 保存变换后的图像
resized_image.save("resized_image.jpg")
transforms.CenterCrop
-
功能 :从图像的中心裁剪出指定大小的区域。
-
输入参数 :
-
size
:指定裁剪区域的大小,可以是int
或tuple
类型。若为int
,则裁剪出的区域是正方形;若为tuple
,则分别指定高度和宽度。
-
from torchvision import transforms
from PIL import Image# 创建一个 CenterCrop 变换
center_crop_transform = transforms.CenterCrop(size=(200, 200))# 打开图像
image = Image.open("image.jpg")# 应用变换
cropped_image = center_crop_transform(image)# 保存变换后的图像
cropped_image.save("cropped_image.jpg")
transforms.RandomCrop
-
功能 :从图像中随机裁剪出一个指定大小的区域。
-
输入参数 :
-
size
:指定裁剪区域的大小,可以是int
或tuple
类型。 -
padding
:可选参数,用于指定在裁剪前对图像进行填充的大小。可以是int
或tuple
类型。例如,padding=2
表示在图像的上下左右各填充 2 个像素;padding=(2, 4)
则表示在上下填充 2 个像素,在左右填充 4 个像素。 -
pad_if_needed
:可选参数,默认为False
。若为True
,当图像的大小小于裁剪区域大小时,会先对图像进行填充,使图像大小满足裁剪要求。 -
fill
:可选参数,指定填充的像素值,默认为 0。 -
padding_mode
:可选参数,默认为'constant'
,表示填充的模式。常见的模式包括:-
'constant'
:常数填充。 -
'edge'
:边缘像素填充。 -
'reflect'
:反射填充。 -
'symmetric'
:对称填充。
-
-
from torchvision import transforms
from PIL import Image# 创建一个 RandomCrop 变换
random_crop_transform = transforms.RandomCrop(size=(200, 200), padding=20, pad_if_needed=True, fill=0, padding_mode='constant')# 打开图像
image = Image.open("image.jpg")# 应用变换
cropped_image = random_crop_transform(image)# 保存变换后的图像
cropped_image.save("random_cropped_image.jpg")
transforms.RandomResizedCrop
- 功能 :先随机裁剪出一个区域,然后将该区域调整为指定的大小。这种变换常用于数据增强,可以增加模型对不同尺度和宽高比的图像的适应能力。
-
输入参数 :
-
size
:指定调整后的图像大小,可以是int
或tuple
类型。 -
scale
:可选参数,默认为(0.08, 1.0)
,指定裁剪区域相对于原图面积的比例范围。例如,(0.08, 1.0)
表示裁剪区域的面积在原图面积的 8% 到 100% 之间。 -
ratio
:可选参数,默认为(0.75, 1.3333333333333333)
,指定裁剪区域的宽高比范围。例如,(0.75, 1.3333333333333333)
表示宽高比在 3:4 到 4:3 之间。 -
interpolation
:可选参数,默认为PIL.Image.BICUBIC
,指定插值方法。
-
-
from torchvision import transforms
from PIL import Image# 创建一个 RandomResizedCrop 变换
random_resized_crop_transform = transforms.RandomResizedCrop(size=(224, 224), scale=(0.08, 1.0), ratio=(0.75, 1.3333333333333333), interpolation=Image.BICUBIC)# 打开图像
image = Image.open("image.jpg")# 应用变换
cropped_and_resized_image = random_resized_crop_transform(image)# 保存变换后的图像
cropped_and_resized_image.save("random_resized_cropped_image.jpg")
数据增强
RandomHorizontalFlip
-
使用方法:以一定的概率对图像进行水平翻转。
-
输入参数:
-
p
:指定图像被水平翻转的概率,范围为 [0,1]。默认值为 0.5。
-
RandomVerticalFlip
-
使用方法:以一定的概率对图像进行垂直翻转。
-
输入参数:
-
p
:指定图像被垂直翻转的概率,范围为 [0,1]。默认值为 0.5。
-
RandomRotation
-
使用方法:将图像随机旋转一定角度。
-
输入参数:
-
degrees
:旋转角度范围,可以是数字或数字序列(如元组、列表等)。若为数字,旋转角度在 (-degrees, degrees) 之间随机;若为序列,旋转角度在指定范围内随机,例如 (min_angle, max_angle)。 -
resample
:可选参数,指定重采样方法,默认为False
。 -
expand
:可选参数,若为True
,则会扩增图像以保持完整,默认为False
。 -
center
:可选参数,指定旋转中心,默认为None
。 -
fill
:可选参数,指定填充值,默认为 0。 -
fillcolor
:可选参数,指定填充颜色,默认为None
。
-
RandomResizedCrop
-
使用方法:先随机裁剪出一个区域,然后将该区域调整为指定的大小。
-
输入参数:
-
size
:指定调整后的图像大小,可以是整数或元组。 -
scale
:可选参数,默认为(0.08, 1.0)
,指定裁剪区域相对于原图面积的比例范围。 -
ratio
:可选参数,默认为(0.75, 1.3333333333333333)
,指定裁剪区域的宽高比范围。 -
interpolation
:可选参数,默认为PIL.Image.BICUBIC
,指定插值方法。
-
RandomAffine
-
使用方法:对图像进行随机仿射变换,包括平移、旋转、缩放等。
-
输入参数:
-
degrees
:旋转角度范围,可以是数字或数字序列。若为数字,旋转角度在 (-degrees, degrees) 之间随机;若为序列,旋转角度在指定范围内随机。 -
translate
:可选参数,指定平移比例范围,默认为None
。可以是元组,其中第一个值为水平方向平移比例,第二个值为垂直方向平移比例,范围为 [0,1]。 -
scale
:可选参数,指定缩放比例范围,默认为None
。可以是元组,其中第一个值为缩放比例下限,第二个值为缩放比例上限。 -
shear
:可选参数,指定剪切角度范围,默认为None
。可以是数字或数字序列,若为数字,剪切角度在 (-shear, shear) 之间随机;若为序列,剪切角度在指定范围内随机。 -
interpolation
:可选参数,默认为3
,指定插值方法。 -
fill
:可选参数,默认为 0,指定填充值。 -
center
:可选参数,默认为None
,指定仿射变换的中心。
-
RandomErasing
-
使用方法:随机擦除图像的一部分,用于增强模型的鲁棒性。
-
输入参数:
-
p
:可选参数,默认为 0.5,指定擦除操作执行的概率。 -
scale
:可选参数,默认为(0.02, 0.33)
,指定擦除区域相对于原图面积的比例范围。 -
ratio
:可选参数,默认为(0.3, 3.3)
,指定擦除区域的宽高比范围。 -
value
:可选参数,默认为 0,指定擦除区域的填充值。 -
inplace
:可选参数,默认为False
,若为True
,则在原图像上进行擦除操作。
-
import torchvision.transforms as transforms
from PIL import Image# 定义图像变换
transform = transforms.Compose([transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),transforms.RandomRotation(degrees=30),transforms.RandomResizedCrop(size=(224, 224), scale=(0.5, 1.0), ratio=(0.75, 1.3333333333333333)),transforms.RandomAffine(degrees=30, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10),transforms.ToTensor(),transforms.RandomErasing(p=0.5, scale=(0.02, 0.2), ratio=(0.3, 3.3), value=0, inplace=False)
])# 打开图像
image = Image.open("image.jpg")# 应用变换
transformed_image = transform(image)# 显示或保存图像
# transformed_image.show()
# transformed_image.save("transformed_image.jpg")