pytorch 乘法,pytorch 矩阵乘法
Pytorch用于训练,TensorRT用于推理是很多AI应用开发的标准。人们往往更熟悉pytorch的操作符,而不太熟悉TensorRT的操作符。本文介绍了Pytorch中常用的乘法的TensorRT实现,有兴趣的可以看看。
目录
1.乘法运算概述2。乘法运算符的实现2.1矩阵乘法运算符2.2点乘法运算符本文介绍Pytorch中常用乘法的TensorRT实现。Pytorch用于训练,TensorRT用于推理是很多AI应用开发的标准。人们往往更熟悉pytorch的操作符,而不太熟悉TensorRT的操作符。这里通过比较两个框架中常用的乘法运算,我们会有更直观的认识。
1.乘法运算总览
首先,概述pytorch中一些常用的乘法运算:
Torch.mm:用于两个矩阵(不包括向量)相乘,比如维数为(m,n)的矩阵乘以维数为(n,p)的矩阵;Torch.bmm:用于三维向量与batch的相乘,比如维数为(b,m,n)的矩阵乘以维数为(b,n,p)的矩阵;Torch.mul:用于同维矩阵的逐像素相乘,即点乘,如维(m,n)的矩阵与维(m,n)的矩阵的点乘。这种方法支持广播,即矩阵和元素点乘。Torch.mv:用于矩阵和向量相乘,矩阵在前,向量在后,如维数为(m,n)的矩阵乘以维数为(n)的向量,输出维数为(m);Torch.matmul:用于两个张量相乘,或者矩阵和向量相乘,函数有torch.mm,torch.bmm,torch.mv;@:效果相当于torch.matmul;*:效果相当于torch.mul;如上所述,可以得出结论,常用的乘法只有两种:矩阵乘法和点乘,所以分以下两类介绍。
2.乘法算子实现
2.1矩阵乘算子实现
我们先来看看pytorch对矩阵乘法的实现(以下实现在终端):
进口火炬
#火炬
a=torch.randn(66,99)
b=torch.randn(99,88)
c=torch.mm(a,b)
c .形状
火炬. size([66,88])
# torch.bmm
a=torch.randn(3,66,99)
b=torch.randn(3,99,77)
c=torch.bmm(a,b)
c .形状
torch.size([3,66,77])
# torch.mv
a=torch.randn(66,99)
b=torch.randn(99)
c=torch.mv(a,b)
c .形状
torch.size([66])
# torch.matmul
a=torch.randn(32,3,66,99)
b=torch.randn(32,3,99,55)
c=torch.matmul(a,b)
c .形状
torch.size([32,3,66,55])
# @
d=a @ b
d .形状
torch.size([32,3,66,55])
再看TensorRT的实现,上面的乘法可以用addMatrixMultiply方法覆盖,对应torch.matm
ul,先来看该方法的定义:
//!//! \brief Add a MatrixMultiply layer to the network.
//!
//! \param input0 The first input tensor (commonly A).
//! \param op0 The operation to apply to input0.
//! \param input1 The second input tensor (commonly B).
//! \param op1 The operation to apply to input1.
//!
//! \see IMatrixMultiplyLayer
//!
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new matrix multiply layer, or nullptr if it could not be created.
//!
IMatrixMultiplyLayer* addMatrixMultiply(
ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept
{
return mImpl->addMatrixMultiply(input0, op0, input1, op1);
}
可以看到这个方法有四个传参,对应两个张量和其operation
。来看这个算子在 TensorRT 中怎么添加:
// 构造张量 Tensor0nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);
// 构造张量 Tensor1
nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);
// 添加矩阵乘法
nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);
// 获取输出
matmulOutput = Matmul_layer->getOputput(0);
2.2点乘算子实现
再来看看点乘的 pytorch 的实现 (以下实现在终端):
>>> import torch>>> # torch.mul
>>> a = torch.randn(66, 99)
>>> b = torch.randn(66, 99)
>>> c = torch.mul(a, b)
>>> c.shape
torch.size([66, 99])
>>> d = 0.125
>>> e = torch.mul(a, d)
>>> e.shape
torch.size([66, 99])
>>> # *
>>> f = a * b
>>> f.shape
torch.size([66, 99])
来看 TensorRT 的实现,以上乘法都可使用addScale
方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:
//!//! \brief Add a Scale layer to the network.
//!
//! \param input The input tensor to the layer.
//! This tensor is required to have a minimum of 3 dimensions in implicit batch mode
//! and a minimum of 4 dimensions in explicit batch mode.
//! \param mode The scaling mode.
//! \param shift The shift value.
//! \param scale The scale value.
//! \param power The power value.
//!
//! If the weights are available, then the size of weights are dependent on the ScaleMode.
//! For ::kUNIFORM, the number of weights equals 1.
//! For ::kCHANNEL, the number of weights equals the channel dimension.
//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.
//!
//! \see addScaleNd
//! \see IScaleLayer
//! \warning Int32 tensors are not valid input tensors.
//!
//! \return The new Scale layer, or nullptr if it could not be created.
//!
IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept
{
return mImpl->addScale(input, mode, shift, scale, power);
}
可以看到有三个模式:
- kUNIFORM:weights 为一个值,对应张量乘一个元素;
- kCHANNEL:weights 维度和输入张量通道的 c 维度对应,可以做一些以通道为基准的预处理;
- kELEMENTWISE:weights 维度和输入张量的 c、h、w 对应,不考虑 batch,所以是输入的后三维;
再来看这个算子在 TensorRT 中怎么添加:
// 构造张量 inputnvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);
// scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISE
scalemode = kUNIFORM;
// 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量
nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };
nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };
// !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行
// 添加张量乘法
nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);
// 获取输出
scaleOutput = Scale_layer->getOputput(0);
有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。
到此这篇关于Pytorch实现常用乘法算子TensorRT的示例代码的文章就介绍到这了,更多相关Pytorch乘法算子TensorRT内容请搜索盛行IT软件开发工作室以前的文章或继续浏览下面的相关文章希望大家以后多多支持盛行IT软件开发工作室!
郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。