view、reshape、resize 的区别
1、tensor.view()
1)当 tensor 连续时
当 tensor 连续时,tensor.view()
不改变存储区的真实数据,只改变元数据(Metadata) 中的信息
注:
.storage() :用于获取张量在存储区的内容
.data_ptr() :用于获取 张量在数据存储区的起始位置,即张量中第一个元素的存储位置
import torcha = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.view(2, 3)print(a)
print(b)
print(b.is_contiguous())# 查看结果,会发现二者输出一致,表示存储区的数据并没有发生改变
print(a.storage())
print(b.storage())# 查看结果,会发现二者输出一致,表示 a 和 b 共享存储区
print(a.storage().data_ptr())
print(b.storage().data_ptr())# 查看结果,会发现二者在元数据 (metadata) 中的 stride 信息发生了改变
print(a.stride())
print(b.stride())
2)当 tensor 不连续时
不连续的 tensor 是不能使用 torch.view() 方法的,否则会报错
import torcha = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()print(a)
print(b)
print(b.is_contiguous())c = b.view(6, 1)
print(c)
如果一定要用 torch.view() 方法,就必须先使用 .contiguous() 方法,让 tensor 先变得连续(重新开辟一块内存空间,生成一个新的、连续的张量对象),再使用 .view()
方法
import torcha = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()b = b.contiguous()
print(b.is_contiguous())c = b.view(6, 1)
print(c)
2、tensor.reshape()
1)当 tensor 连续时
当 tensor 连续时, tensor.reshape() 与 tensor.view() 效果一样,会和原来 tensor 共用存储区
import torcha = torch.tensor([1, 2, 3, 4, 5, 6])
b = a.reshape(2, 3)print(a)
print(b)
print(b.is_contiguous())# 查看结果,会发现二者输出一致,表示存储区的数据并没有发生改变
print(a.storage())
print(b.storage())# 查看结果,会发现二者输出一致,表示 a 和 b 共享存储区
print(a.storage().data_ptr())
print(b.storage().data_ptr())# 查看结果,会发现二者在元数据 (metadata) 中的 stride 信息发生了改变
print(a.stride())
print(b.stride())
2)当 tensor 不连续时
当 tensor 不连续时, reshape() = contiguous() + view(),即 :会先通过 .contiguous()
方法,在新的存储区创建一个连续的新的 tensor,再进行 view()
,它与原来 tensor 不共用存储区
import torch a = torch.tensor([1, 2, 3, 4, 5, 6]).view(2, 3)
b = a.t()print(a)
print(b)
print(b.is_contiguous())c = b.reshape(6, 1)
print(c)# 查看结果,会发现二者输出不一致,表示 a 和 b 不共享存储区
print(b.storage().data_ptr())
print(c.storage().data_ptr())
3、tensor.resize_()
注意 :是.resize_()
,不是 .resize()
前面说到的 .view()
和 .reshape()
都必须要用到全部的原始数据,比如:原始数据只有12个,无论你怎么变形都必须要用到 12个数字,不能多,不能少。因此,你就不能把有12个数字的 tensor 强行给 reshap 成 2*5 的维度的 tensor。
但是 .resize_()
可以,无论存储区原始数据有多少个元素,它都能将数据变成你想要的维度。
-
如果数字不够,它会用0进行填充,凑满你要的尺寸
-
如果数字多了,就只取你需要的部分
1)当原始数据 元素多余的时候
由以下代码,我们可以观察到,a 的原始数据始终是 1~7,但是在 a 中,它只取了前6个
import torcha = torch.tensor([1, 2, 3, 4, 5, 6, 7])
print(a.storage().data_ptr())a = a.resize_(2, 3)
print(a)
print(a.storage())
print(a.storage().data_ptr())
2)当原始数据 元素不够的时候
如果原始数据不够,它会开辟一个新的存储区,并用0进行填充,凑满你要的尺寸
import torcha = torch.tensor([1, 2, 3, 4, 5])
print(a.storage().data_ptr())a = a.resize_(2, 3)
print(a)
print(a.storage())
print(a.storage().data_ptr())