pytorch自定义损失函数,pytorch crossentropyloss损失函数

  pytorch自定义损失函数,pytorch crossentropyloss损失函数

  这篇文章主要介绍了框架自定义失败损失函数,自定义失败的方法有很多,本文要介绍的是把失败作为一个框架的模块,下面详细资料需要的小伙伴可以参考一下

  

目录
步骤1:添加自定义的类步骤2:修改使用的失败函数自定义失败的方法有很多,但是在博主查资料的时候发现有挺多写法会有问题,靠谱一点的方法是把失败作为一个框架的模块,

  比如:

  类海关损失(纳.模块): #注意继承nn .组件

  def __init__(self):

  超级(CustomLoss,self).__init__()

  向前定义(自身,x,y):

  # .这里写x与y的处理逻辑,即失败的计算方法

  回波损耗#注意最后只能返回张量值,且带梯度,即loss.requires_grad==True

  示例代码:

  以一个pytorch求解线性回归的代码为例:

  进口火炬

  将torch.nn作为神经网络导入

  将数组作为铭牌导入

  导入操作系统

  OS。KMP附近重复_LIB_OK]=TRUE

  def get_x_y():

  随机种子(0)

  x=np.random.randint(0,50,300)

  y值=2 * 21

  x=np.array(x,dtype=np.float32)

  y=np.array(y_values,dtype=np.float32)

  x=x .形状(-1,1)

  y=y形(-1,1)

  返回x,y

  类线性回归模型(神经网络).模块):

  def __init__(self,input_dim,output_dim):

  super(LinearRegressionModel,self).__init__()

  self.linear=nn .线性(输入尺寸,输出尺寸)#输入的个数,输出的个数

  定义向前(自身,x):

  out=自线性(十)

  退回去

  if __name__==__main__:

  输入尺寸=1

  output_dim=1

  x_train,y_train=get_x_y()

  model=LinearRegressionModel(input _ dim,output_dim)

  纪元=1000 #迭代次数

  优化器=火炬。optim。新币(型号。参数(),lr=0.001)

  模型_损耗=nn .ms loss()#使用均方误差(均方误差)作为失败

  # 开始训练模型

  对于范围内的纪元(纪元):

  纪元=1

  # 注意转行成张量

  输入=torch.from_numpy(x_train)

  标签=火炬。从_ numpy(y _ train)

  # 梯度要清零每一次迭代

  optimizer.zero_grad()

  # 前向传播

  输出:火炬。张量=模型(输入)

  # 计算损失

  损耗=模型损耗(输出,标签)

  # 返向传播

  loss.backward()

  # 更新权重参数

  optimizer.step()

  如果纪元% 50==0:

  打印(纪元{},损失{} 。format(epoch,loss.item()))

  

步骤

  1:添加自定义的类

  我们就用自定义的写法来写与MSE相同的效果,MSE计算公式如下:

  

  添加一个类:

  

class CustomLoss(nn.Module):

      def __init__(self):

          super(CustomLoss, self).__init__()

          self.mse_loss = nn.MSELoss()

      def forward(self, x, y):

          mse_loss = torch.mean(torch.pow((x - y), 2)) # x与y相减后平方,求均值即为MSE

          return mse_loss

  

  

步骤2:修改使用的loss函数

  只需要把原始代码中的:

  

model_loss = nn.MSELoss() # 使用MSE作为loss

  改为:

  

model_loss = CustomLoss()  # 自定义loss

  即可

  完整代码:

  

import torch

  import torch.nn as nn

  import numpy as np

  import os

  os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

  def get_x_y():

      np.random.seed(0)

      x = np.random.randint(0, 50, 300)

      y_values = 2 * x + 21

      x = np.array(x, dtype=np.float32)

      y = np.array(y_values, dtype=np.float32)

      x = x.reshape(-1, 1)

      y = y.reshape(-1, 1)

      return x, y

  class LinearRegressionModel(nn.Module):

      def __init__(self, input_dim, output_dim):

          super(LinearRegressionModel, self).__init__()

          self.linear = nn.Linear(input_dim, output_dim)  # 输入的个数,输出的个数

      def forward(self, x):

          out = self.linear(x)

          return out

  class CustomLoss(nn.Module):

      def __init__(self):

          super(CustomLoss, self).__init__()

          self.mse_loss = nn.MSELoss()

      def forward(self, x, y):

          mse_loss = torch.mean(torch.pow((x - y), 2))

          return mse_loss

  if __name__ == __main__:

      input_dim = 1

      output_dim = 1

      x_train, y_train = get_x_y()

      model = LinearRegressionModel(input_dim, output_dim)

      epochs = 1000  # 迭代次数

      optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

      # model_loss = nn.MSELoss() # 使用MSE作为loss

      model_loss = CustomLoss()  # 自定义loss

      # 开始训练模型

      for epoch in range(epochs):

          epoch += 1

          # 注意转行成tensor

          inputs = torch.from_numpy(x_train)

          labels = torch.from_numpy(y_train)

          # 梯度要清零每一次迭代

          optimizer.zero_grad()

          # 前向传播

          outputs: torch.Tensor = model(inputs)

          # 计算损失

          loss = model_loss(outputs, labels)

          # 返向传播

          loss.backward()

          # 更新权重参数

          optimizer.step()

          if epoch % 50 == 0:

              print(epoch {}, loss {}.format(epoch, loss.item()))

  

  到此这篇关于pytorch自定义loss损失函数的文章就介绍到这了,更多相关pytorch loss损失函数内容请搜索盛行IT软件开发工作室以前的文章或继续浏览下面的相关文章希望大家以后多多支持盛行IT软件开发工作室!

郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。

留言与评论(共有 条评论)
   
验证码: