pytorch中的各种乘法
那么多相乘,讲实话我到现在也没仔细梳理过,所以现在搞一下子。
首先声明一个向量和一个二维矩阵
import torch vec = torch.arange(4) mtx = torch.arange(12).reshape(4,3) print(vec, mtx,sep='\n') 复制代码
输出结果:
>> tensor([0, 1, 2, 3]) tensor([[ 0, 1, 2], [ 3, 4, 5], [ 6, 7, 8], [ 9, 10, 11]]) 复制代码
按位置*
这个*
在pytorch中是按位置相乘,存在广播机制。
import torch vec = torch.arange(4) mtx = torch.arange(12).reshape(4,3) print(vec*vec) print(mtx*mtx) 复制代码
>> tensor([0, 1, 4, 9]) tensor([[ 0, 1, 4], [ 9, 16, 25], [ 36, 49, 64], [ 81, 100, 121]]) 复制代码
但是需要注意的一点是虽然众多地方提到向量默认是列向量,但是在pytorch中一维的张量没有这种说法。 就算你用3×4的张量乘4×1的张量,得出的结果本应该是3×1的张量,但是因为是一维张量,也会变成默认的3(也不是1×3)。
可以执行的状态下print(mtx*vec)
和print(vec*mtx)
的结果是完全一样的。
但是上面的例子中如果执行print(mtx*vec)
或者print(vec*mtx)
会报错。因为默认情况下,一维的张量和矩阵执行*
操作的时候,一维张量中元素的个数必须和二维矩阵列数相同,否则广播功能失效。
当然也可以使用reshap()
为其增加一个维度。但是增加维度之后要遵守一些维度规则。
import torch vec = torch.arange(4).reshape(4,1) # 增加维度 mtx = torch.arange(12).reshape(4,3) print(vec*mtx) print(mtx*vec) 复制代码
>> tensor([[ 0, 0, 0], [ 3, 4, 5], [12, 14, 16], [27, 30, 33]]) tensor([[ 0, 0, 0], [ 3, 4, 5], [12, 14, 16], [27, 30, 33]]) 复制代码
比如上边矩阵是4×3的。
第二行代码你可以使用
vec = torch.arange(4).reshape(4,1)
vec = torch.arange(3).reshape(1,3)
就是说必须在行或者列上保持元素个数的一致。
数乘torch.mul
torch.mul(input, value, out=None) 复制代码
用标量值value
乘以输入input
的每个元素,并返回一个新的结果张量。 就是张量的数乘运算。
import torch vec = torch.arange(4) mtx = torch.arange(12).reshape(3,4) print(torch.mul(vec,2)) print(torch.mul(mtx,2)) 复制代码
>> tensor([0, 2, 4, 6]) tensor([[ 0, 2, 4, 6], [ 8, 10, 12, 14], [16, 18, 20, 22]]) 复制代码
矩阵向量相乘torch.mv
torch.mv(mat, vec, out=None) → Tensor 复制代码
对矩阵mat
和向量vec
进行相乘。 如果mat
是一个n×m张量,vec
是一个m元 1维张量,将会输出一个n 元 1维张量。
必须前边是矩阵后边是向量,维度要符合矩阵乘法。出来的是一维张量。
import torch vec = torch.arange(4) mtx = torch.arange(12).reshape(3,4) print(torch.mv(mtx,vec)) 复制代码
>> tensor([14, 38, 62]) 复制代码
矩阵乘法torch.mm
torch.mm(mat1, mat2, out=None) → Tensor 复制代码
对矩阵mat1
和mat2
进行相乘。 如果mat1
是一个n×m张量,mat2
是一个 m×p张量,将会输出一个 n×p张量out
。
就是我们线代中学的矩阵乘法,维度必须对应正确。
import torch mtx = torch.arange(12) m1 = mtx.reshape(3,4) m2 = mtx.reshape(4,3) print(torch.mm(m1, m2)) 复制代码
>> tensor([[ 42, 48, 54], [114, 136, 158], [186, 224, 262]]) 复制代码
点乘积torch.dot
torch.dot(tensor1, tensor2) → float 复制代码
计算两个张量的点乘积(内积),两个张量都为一维向量。
import torch vec = torch.arange(4) print(torch.dot(vec, vec)) 复制代码
>> tensor(14) 复制代码
黑科技@
还存在一个黑科技@
,也是严格按照第一个参数的列数要等于第二个参数的行数。
import torch vec = torch.arange(4) mtx = torch.arange(12) m1 = mtx.reshape(4,3) m2 = mtx.reshape(3,4) print(vec @ vec) print(vec @ m1) print(m2 @ vec) print(m1 @ m2) 复制代码
>> tensor(14) tensor([42, 48, 54]) tensor([14, 38, 62]) tensor([[ 20, 23, 26, 29], [ 56, 68, 80, 92], [ 92, 113, 134, 155], [128, 158, 188, 218]]) 复制代码
上边的结果可能不够直观,那看看下边:
import torch vec = torch.arange(4) mtx = torch.arange(12) m1 = mtx.reshape(4,3) m2 = mtx.reshape(3,4) print(vec @ vec==torch.dot(vec,vec)) print(vec @ m1) # 本句直接使用torch.mv()无法执行。 print(m2 @ vec==torch.mv(m2,vec)) print(m1 @ m2==torch.mm(m1,m2)) 复制代码
使用一个@就可以替代上边的那三个函数。
对一维张量执行@操作就是dot
对一维和二维张量执行操作就是mv
对二维张量执行@操作就是mm
>> tensor(True) tensor([42, 48, 54]) tensor([True, True, True]) tensor([[True, True, True, True], [True, True, True, True], [True, True, True, True], [True, True, True, True]]) 复制代码
第二个无法替换的怎么办?为了满足强迫症,可以这样:
import torch vec = torch.arange(4) mtx = torch.arange(12).reshape(4,3) print(vec @ mtx) # 本句直接使用torch.mv()无法执行。 print(torch.mm(vec.reshape(1,4),mtx)) print(vec @ mtx==torch.mm(vec.reshape(1,4),mtx)) 复制代码
>> tensor([42, 48, 54]) tensor([[42, 48, 54]]) tensor([[True, True, True]])
作者:LolitaAnn
链接:https://juejin.cn/post/7023668955109654559