当前位置: 首页 > news >正文

tensor 的计算操作

1、创建tensor

  常见创建 tensor 的方法

函数

作用

torch.Tensor(*size)

通过指定尺寸,生成一个 值全为 0 的 tensor

torch.tensor(*list)

直接通过指定数据,生成tensor,支持 List、Numpy数组

torch.eye(row, column)

按照指定的行列数,生成二维单位tensor

torch.rand(*size)

从 (0, 1) 之间,进行均匀分布采样,生成指定size 的tensor

torch.randn(*size)

从标准正态分布中采样,生成指定size 的tensor

torch.ones(*size)

按照指定 size,生成 值全为1 的 tensor

torch.zeros(*size)

按照指定 size,生成 值全为0 的 tensor

torch.ones_like(t)

返回尺寸与 t 相同的,值全为1 的 tensor

torch.zeros_like(t)

返回尺寸与 t 相同的,值全为0 的 tensor

torch.arange(start, end, step)

在区间[start, end) 上,每间隔 step 生成一个序列张量

torch.linspace(start, end, steps)

从 start 到 end,均匀切分成 steps 份

torch.from_Numpy(ndarray)

根据 ndarray 生成 tensor

  你可以将如下生成的 tensor,逐个打印,进行观察

import numpy as np
import torcha = torch.Tensor(2, 3)
b = torch.tensor([[1, 2, 3], [4, 5, 6]])c = torch.eye(3, 3)d = torch.rand(2, 3)
e = torch.randn(2, 3)f = torch.ones(3, 2)
g = torch.zeros(2, 3)
h = torch.ones_like(b)
i = torch.zeros_like(b)j = torch.linspace(1, 10, 4)
k = torch.arange(1, 5, 2)l = np.arange(1, 5)
m = torch.from_numpy(l)

2、查看 tensor 的形状

函数

作用

tensor.numel()

统计 tensor 元素的个数

tensor.size()

获取 tensor的尺寸,tensor.size() 是一个方法

tensor.shape

获取 tensor的尺寸,tensor.shape 是一个属性

import torcha = torch.tensor([[1, 2, 3], [4, 5, 6]])
print(a.numel())
print(a.size())
print(a.shape)

 


3、修改 tensor 形状

函数

作用

tensor.view(*size)

修改 tensor 的尺寸,返回的对象 与 源tensor 共享内存(修改一个,另一个也会被修改)

tensor.resize(*size)

修改 tensor 的尺寸,类似于view,但在size 超出时会重新分配内存空间

tensor.reshape(*size)

修改 tensor 的尺寸,返回的对象是一个新生成的tensor,且不要求源tensor 是连续的

tensor.flatten()

将张量扁平化为一维张量

torch.unsqueeze(dim)

在指定维度增加一个 “1”

torch.squeeze(dim)

在指定维度删除一个 “1”

tensor.transpose(dim0, dim1)

交换 dim0 和 dim1 指定的两个维度,返回一个新的张量而不修改原始张量

tensor.permute(dims)

按照 dims 指定的顺序重新排列张量的维度,返回一个新的张量而不修改原始张量

import torcha = torch.tensor([[1, 2, 3], [4, 5, 6]])print(a.view(3, 2))
print(a.resize(3, 2))
print(a.reshape(3, 2))
print(a.flatten())

import torcha = torch.tensor([[[1, 2, 3], [4, 5, 6]]])
print(a.shape)b = torch.unsqueeze(a, 0)
print(b.shape)c = torch.squeeze(b, 0)
print(c.shape)

 

import torcha = torch.tensor([[1, 2, 3], [4, 5, 6]])
b = a.transpose(dim0=0, dim1=1)
print(a.shape)
print(b.shape)c = torch.tensor([[[1, 2, 2], [3, 4, 3]]])
d = c.permute(1, 2, 0)
print(c.shape)
print(d.shape)


4、按条件筛选

函数

作用

torch.index_select(input, dim, index)

从输入张量中按索引选择某些元素

torch.nonzero(input)

获取张量中非零元素的索引

torch.masked_select(input, masked)

根据掩码(mask)从输入张量中选择元素

torch.gather(input, dim, index)

在指定维度上根据索引从输入张量中选择元素

torch.scatter(input, dim, index, src)

在指定维度(dim)上根据指定的索引(index),将源张量(src)的值散射到目标张量(input)中

 1)index_select

torch.index_select(input, dim, index) : 从输入张量中按索引选择某些元素,并返回一个新的张量

import torchtensor = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])# 创建一个要选择的索引张量
indices = torch.tensor([0, 2])# 在第0维上使用index_select选择索引为0和2的行
selected_rows = torch.index_select(tensor, 0, indices)print("Selected rows:\n", selected_rows)

2)nonzero

torch.nonzero(input) : 获取张量中非零元素的索引

import torchtensor = torch.tensor([[0, 1, 0],[2, 0, 3],[0, 4, 0]])# 获取非零元素的索引
nonzero_indices = torch.nonzero(tensor)
print("Non-zero indices:\n", nonzero_indices)

 

3)masked_select

torch.masked_select(input, masked) : 根据掩码(mask)从输入张量中选择元素

import torch# 创建一个示例张量
input_tensor = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])# 创建一个掩码张量
mask_tensor = torch.tensor([[0, 1, 0],[1, 0, 1],[0, 1, 0]], dtype=torch.bool)# 使用掩码选择张量中的元素
selected_tensor = torch.masked_select(input_tensor, mask_tensor)
print("Selected tensor:\n", selected_tensor)

4)gather

torch.gather(input, dim, index) : 在指定维度上根据索引从输入张量中选择元素。若 dim=n :

  • 要求除了第n个维度,input_tensor 和 index_tensor 其他维度数量一致

  • 在第n个维度上,通过 index_tensor 指定的索引,从 input_tensor 中取出对应位置的元素。

import torch# 创建一个示例张量
input_tensor = torch.tensor([[1, 2, 3],[4, 5, 6],[7, 8, 9]])# 创建一个示例索引张量
index_tensor = torch.tensor([[0, 2],[1, 0],[2, 1]])# 在第1维度上使用索引张量收集值
gathered_tensor = torch.gather(input_tensor, 1, index_tensor)print("Gathered tensor:\n", gathered_tensor)

 

5)scatter

torch.scatter(input, dim, index, src) :在指定维度(dim)上根据指定的索引(index),将源张量(src)的值放到目标张量(input)中

  • src 的形状 和 index的形状 必须保持一致,否则会报错

import torch# 创建一个示例目标张量
input_tensor = torch.zeros((4, 4), dtype=torch.int32)# 创建一个示例索引张量
index_tensor = torch.tensor([[0, 1],[2, 3]], dtype=torch.long)# 创建一个示例源张量
src_tensor = torch.tensor([[1, 2],[3, 4]], dtype=torch.int32)# 在第1维度上使用索引张量散射值
scattered_tensor = torch.scatter(input_tensor, 1, index_tensor, src_tensor)
print("Scattered tensor:\n", scattered_tensor)

 

No. 1

  • src_tensor 第0行第0列的值为“1”

  • index_tensor 第0行第0列的值为“0”, dim=1 表示 “0” 是列索引

  • 将 “1” 放到 input_tensor 中,对应行第“0”列的位置上

No. 2

  • src_tensor 第0行第1列的值为“2”,

  • index_tensor 第0行第1列的值为“1”,dim=1 表示 “1” 是列索引

  • 将 “2” 放到 input_tensor 中,对应行第“1”列的位置上

No. 3

  • src_tensor 第1行第0列的值为“3”,

  • index_tensor 第1行第0列的值为“2”,dim=1 表示 “2” 是列索引

  • 将 “3” 放到 input_tensor 中,对应行第“2”列的位置上

No. 4

  • src_tensor 第1行第1列的值为“4”,

  • index_tensor 第1行第1列的值为“3”,dim=1 表示 “3” 是列索引

  • 将 “4” 放到 input_tensor 中,对应行第“3”列的位置上


5、运算操作

函数

作用

torch.abs()

取绝对值

torch.add()

相加

torch.addcdiv(input, tensor1, tensor2, value)

result =input + value * tensor1 / tensor2

torch.addcmul(input, tensor1, tensor2, value)

result =input + value * tensor1 * tensor2

torch.ceil()

向上取整

torch.floor()

向下取整

torch.clamp(input, min, max)

对张量的元素进行截断操作,将超出指定范围的元素限制在指定范围内

torch.exp() / torch.log() / torch.pow

指数 / 对数 / 幂

torch.mul()

逐元素乘法,效果和 * 一样

torch.neg()

取反

torch.sqrt()

开根号

torch.sign()

取符号

 1)abs

import torchx = torch.arange(-5, 5)
y = torch.abs(x)
print(y)

2)add

import torchx = torch.arange(2, 5)
y = torch.arange(4, 7)
z = torch.add(x, y)
print(z)

 

3)addcdiv

  torch.addcdiv(input, tensor1, tensor2, value)

result =input + value * tensor1 / tensor2

import torcht = torch.randn(1, 3)
t1 = torch.randn(3, 1)
t2 = torch.randn(1, 3)a = t + 0.1 *(t1 / t2)
print(a)b = torch.addcdiv(t, t1, t2, value=0.1)
print(b)

 

4)addcmul

torch.addcmul(input, tensor1, tensor2, value)

result =input + value * tensor1 * tensor2

import torcht = torch.randn(1, 3)
t1 = torch.randn(3, 1)
t2 = torch.randn(1, 3)a = t + 0.1 * t1 * t2
print(a)b = torch.addcmul(t, t1, t2, value=0.1)
print(b)

 

5)ceil、floor

  • torch.ceil(input) :向上取整

  • torch.floor(input) :向下取整

import torchtorch.manual_seed(8)
x = torch.randn(3) * 10
y = torch.ceil(x)
z = torch.floor(x)
print(x)  
print(y) 
print(z)

6)clamp

  将张量元素大小限制在指定区间范围内

import torchx = torch.arange(1, 8)
y = torch.clamp(x, 2, 5)
print(y)

 

7)exp、log、pow

import torchtorch.manual_seed(8)
x = torch.arange(3)
print(x)
print(torch.exp(x))
print(torch.log(x))  # 以e为底
print(torch.pow(x, 3)) 

8)mul

  逐元素乘法,效果和 * 一样

import torcha = torch.tensor([[2, 2, 2],[2, 3, 4]])b = torch.tensor([[3, 3, 3],[4, 5, 6]])print(torch.mul(a,b))
print(a*b)

 

9)neg、sqrt、sign

  • torch.neg() 取反

  • torch.sqrt() 开根号

  • torch.sign() 取符号

import torcha = torch.tensor([[2, -2, 2], [2, -3, 4]])print(torch.neg(a)) print(torch.sqrt(a))print(torch.sign(a))


6、统计操作

函数

作用

torch.sum() / torch.prod()

求和 / 求积

torch.cumsum() / torch.cumprod()

在指定维度上进行累加 / 在指定维度上进行累乘

torch.mean()、torch.median()

均值 / 中位数

torch.std() / torch.var()

标准差 / 方差

torch.norm(t, p)

t 的 p阶范数

torch.dist(a, b, p)

a,b 之间的 p阶范数

1)sum、prod、cumsum、cumprod

import torcha = torch.linspace(0, 10, 6).view(2, 3)
print(a)b = a.sum(dim=0)
c = torch.cumsum(a, dim=0)
print('\n维度0上求和 与 累加')
print(b)
print(c)d = a.prod(dim=1)
e = torch.cumprod(a, dim=1)
print('\n维度0上求积 与 累积')
print(d)
print(e)

 

2)mean、median

import torcha = torch.tensor([[2., 2., 5.],[3., 3., 8.],[4., 4., 4.]])print('求均值')
print(torch.mean(a))
print(torch.mean(a, 0))
print(torch.mean(a, 1))print('\n求中位数')
print(torch.median(a))
print(torch.median(a, 0))
print(torch.median(a, 1))

3)std、var

import torch# 创建示例张量
a = torch.tensor([[1.0, 2.0],[3.0, 4.0]])# 计算张量的标准差
std_value = torch.std(a)
print("Standard deviation:", std_value)# 计算张量的方差
var_value = torch.var(a)
print("Variance:", var_value)

 

4)norm、dist

  • torch.norm(t, p) :t 的 p阶范数

  • torch.dist(a, b, p) :a,b 之间的 p阶范数

import torch# 创建示例张量
tensor1 = torch.tensor([1.0, 2.0, 3.0])
tensor2 = torch.tensor([4.0, 5.0, 6.0])# 使用 torch.norm() 计算张量的范数
norm_value = torch.norm(tensor1)
print("tensor1 的L2范数:", norm_value)# 使用 torch.dist() 计算两个张量的范数
dist_value = torch.dist(tensor1, tensor2)
print("tensor1 和 tensor2 之间的L2范数:", dist_value)


7、比较操作

函数

作用

torch.eq

比较 tensor 是否相等 (支持 broadcast)

torch.equal

比较 tensor 是否有相同的 shape 与 值

torch.gt / torch.lt

大于 / 小于 gt : great than ; lt : less than

torch.ge / torch.le

大于等于 / 小于等于 ge : greater than or equal to ; le : less than or equal to

torch.max / torch.min(t,axis)

最大值 / 最小值

torch.topk(t, k, axis)

在指定维度上(axis)取最高的 k个值

 1)eq、 equal

import torch# 示例张量
t = torch.tensor([[1, 2, 3],[4, 5, 6]])# 使用eq()比较张量中的元素是否等于2
result_eq = torch.eq(t, 2)
print("Result of eq():\n", result_eq)# 使用equal()比较两个张量是否相等
t1 = torch.tensor([[1, 2, 3],[4, 5, 6]])
t2 = torch.tensor([[1, 2, 3],[4, 5, 6]])
result_equal = torch.equal(t1, t2)
print("\nResult of equal():", result_equal)

2)gt、 lt、ge、 le

import torch# 示例张量
t = torch.tensor([[1, 2, 3],[4, 5, 6]])# 使用gt()比较张量中的元素是否大于3
result_gt = torch.gt(t, 3)
print("Result of gt():\n", result_gt)# 使用lt()比较张量中的元素是否小于3
result_lt = torch.lt(t, 3)
print("\nResult of lt():\n", result_lt)# 使用gt()比较张量中的元素是否大于等于3
result_ge = torch.ge(t, 3)
print("\nResult of ge():\n", result_ge)# 使用le()比较张量中的元素是否小于等于3
result_le = torch.le(t, 3)
print("\nResult of le():\n", result_le)

 

3)max、min

import torch# 示例张量
t = torch.tensor([[1, 2, 3],[4, 5, 6]])# 计算张量的最大值和最小值
max_value = torch.max(t)
min_value = torch.min(t)
print("Max value:", max_value)
print("Min value:", min_value)# 沿着指定维度计算张量的最大值和最小值
max_value_axis_0 = torch.max(t, axis=0)
min_value_axis_1 = torch.min(t, axis=1)
print("\n沿0维上的最大值:", max_value_axis_0.values)
print("沿0维上的最大值索引:", max_value_axis_0.indices)
print("沿1维上的最小值:", min_value_axis_1.values)
print("沿1维上的最小值索引:", min_value_axis_1.indices)

4)topk

import torch# 示例张量
t = torch.tensor([[1, 2, 3],[4, 5, 6]])# 沿着指定维度获取张量中最大的两个值及其索引
topk_values, topk_indices = torch.topk(t, k=2, dim=1)
print("沿维度1上的最大的2个值:\n", topk_values)
print("\n沿维度1上的最大的2个值的索引:\n", topk_indices)

 


8、矩阵操作

函数

作用

torch.dot(t1, t2)

计算 1维张量的内积或点积

torch.mul(t1, t2)

逐元素相乘

torch.mm(t1, t2) / torch.mv(t, v)

计算矩阵乘法 / 计算矩阵t与向量v 的乘法

bmm

含 batch 的 3D 矩阵乘法

svd

计算 t 的 SVD分解

 1)dot

  Torch的 dot 只能对两个 一维张量 进行点积运算,否则会报错;Numpy中的dot无此限制。

import torcha = torch.tensor([2, 3])
b = torch.tensor([3, 4])print(torch.dot(a, b))

import torcha = torch.tensor([[2, 3],[3, 4]])b = torch.tensor([[3, 4],[1, 2]])print(torch.dot(a, b))

2)mul

  • a 和 b 必须尺寸相同。

  • torch.mul(a, b) 和 a * b 效果一样

  • torch.mul(a, b) 是逐元素相乘,torch.mm(a, b) 是矩阵相乘

import torcha = torch.tensor([[2, 3],[3, 4]])b = torch.tensor([[3, 4],[1, 2]])print(torch.mul(a, b))
print(a * b)
print(torch.mm(a, b))

 

3)mm、mv

  • torch.mm(t1, t2) 是矩阵相乘, torch.mv(t, v) 矩阵与向量乘法

  • torch.mv(t, v) , 矩阵t为第一个参数,向量v为第二个参数,位置不能换,否则会报错

import torcha = torch.tensor([[1, 2, 3],[2, 3, 4]])b = torch.tensor([[1, 2],[1, 2],[3, 4]])c = torch.tensor([1, 2, 3])print(torch.mm(a, b))
print(torch.mv(a, c))

 

4)bmm 

import torchbatch1 = torch.tensor([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], dtype=torch.float32)  # Shape: (2, 2, 2)
batch2 = torch.tensor([[[2, 0], [0, 2]], [[1, 0], [0, 1]]], dtype=torch.float32)  # Shape: (2, 2, 2)result = torch.bmm(batch1, batch2)
print("Result:\n", result)

5)svd

import torcha = torch.randn(2, 3)
print(torch.svd(a))

 

相关文章:

  • AUTOSAR图解==>AUTOSAR_RS_InteractionWithBehavioralModels
  • Kafka 配置参数性能调优建议
  • 第十四届蓝桥杯Scratch03月stema选拔赛——九九乘法表
  • vite项目tailwindcss4的使用
  • WebGIS开发之地形土方开挖回填分析
  • Vue3 + Element-Plus + 阿里云文件上传
  • SpringBoot 接口国际化i18n 多语言返回 中英文切换 全球化 语言切换
  • Mioty|采用报文分割(Telegram Splitting)以提高抗干扰能力的无线通信技术
  • 北极花携语音唤醒、专家鉴定等新功能 亮相第十七届中国林业青年学术年会
  • 继承(c++版 非常详细版)
  • C++ CRC16校验方法详解
  • QT中的多线程
  • Leetcode算法题:字符串转换整数(atoi)
  • ship_plant船舶模型
  • 小草GrassRouter多卡聚合路由器聚合卫星、MESH网络应用解决方案
  • 低功耗蓝牙BLE之高斯频移键控(GFSK)
  • 【Vue.js】组件数据通信:基于Props 实现父组件→子组件传递数据(最基础案例)
  • 前端连接websocket服务报错 Unexpected response code: 301
  • 31、简要描述Promise.all的用途
  • 生成对抗网络(Generative Adversarial Nets,GAN)
  • 在岸、离岸人民币对美元汇率双双升破7.26关口
  • 俄宣布停火三天,外交部:希望各方继续通过对话谈判解决危机
  • 铁路五一假期运输今日启动,预计发送旅客1.44亿人次
  • 太好玩了!坐进大卫·霍克尼的敞篷车穿越他画笔下的四季
  • 王星昊再胜连笑,夺得中国围棋天元赛冠军
  • 美情报机构攻击中国大型商用密码产品提供商,调查报告公布