tensor.repeat和tensor.repeat_interleave
tensor.repeat
在指定维度上整体复制张量内容:
x = torch.arange(6).reshape(2,3)
print(x)
print(x.repeat(2,1))
上述代码的执行结果为:
tensor([[0, 1, 2],[3, 4, 5]])
tensor([[0, 1, 2],[3, 4, 5],[0, 1, 2],[3, 4, 5]])
可以看到,x.repeat(2, 1)
表示沿着第一个维度(行)复制 2 次,而第二个维度(列)不变,相当于将原始矩阵整体复制一份后在行方向上拼接。
tensor.repeat_interleave
这里只讲解dim
不为None
时的用法:
x = torch.arange(6).reshape(2,3)
print(x)
print(x.repeat_interleave(repeats=2,dim=0))
结果为:
tensor([[0, 1, 2],[3, 4, 5]])
tensor([[0, 1, 2],[0, 1, 2],[3, 4, 5],[3, 4, 5]])
可以看出,repeat_interleave(repeats=2, dim=0)
会将原张量在第 0 维上逐行重复每一行 2 次。
总结
repeat
和 repeat_interleave
都可以用于在张量的某个维度上进行扩展,但适用的场景略有不同。选择哪个函数取决于你想要的复制粒度:块级 还是 行/元素级。