pytorch hook,pytorch基本操作

  pytorch hook,pytorch基本操作

  本文主要介绍pytorch中的钩子机制register_forward_hook。forward前手动注册钩子,forward执行后自动执行钩子。下面详细介绍,有需要的朋友可以参考一下。

  00-1010 1、hook后台2、源代码读取3、定义一个测试hook的类4、定义hook函数5、注册所需图层的hook6、测试forward()返回的特征是否与hook 6.1测试forward()提供的输入输出特征6.2 hook 6.3记录的输入输出特征减去Hook和forward 7、完整代码。

  

目录

  钩子被称为钩子机制,并不是pytorch首创,在Windows编程中已经被广泛采用,包括进程内钩子和全局钩子。按照我自己的理解,hook的作用是通过系统维护一个链表,让用户截取(获取)处理事件的通信消息。

  Pytorch包含了两个钩子注册函数,forward和backward,用于获得forward和backward中的输入和输出。根据我的不完全理解,目的应该是“不改变网络的定义代码,也不返回前向函数中某一层感兴趣的输出,所以代码太冗杂”。

  

1、hook背景

  register_forward_hook()函数必须在forward()函数被调用之前使用,因为这个函数的源代码注释显示这个函数“由于这是在3360 func3360 `forward `被调用之后才被调用,所以它对forward没有作用”,也就是说这个函数在forward()之后就没有函数了!):

  得到作用:前进过程中各层的输入输出,用来比较hook记录是否正确。

  def寄存器_forward_hook(self,hook):

  r 在模块上注册一个前向挂钩。

  每当:func:`forward 计算出一个输出后,就会调用这个钩子。

  它应该具有以下签名:

  钩子(模块,输入,输出)-无或修改输出

  钩子可以修改输出。它可以就地修改输入,但是

  它不会对forward产生影响,因为这是在之后调用的

  :func:`forward 被调用。

  返回:

  torch . utils . hooks . removable handle

  一个句柄,可用于通过调用移除添加的挂钩

  “handle.remove()”

  手柄=挂钩。RemovableHandle(self。_forward_hooks)

  自我。_forward_hooks[handle.id]=钩子

  返回手柄

  

2、源码阅读

  如果每一层都是随机初始化的,那么将无法检验自己获得的输入输出是否是forward中的输入输出,所以需要将每一层的权重和偏移量设置为可识别的值(比如全部初始化为1)。网络由两层组成(需要导出的线性参数称为一层,

  而ReLU没有需要求导的参数不被称作一层),__init__()中调用initialize函数对所有层进行初始化。

  

注意:在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。

  

  

class TestForHook(nn.Module):

      def __init__(self):

          super().__init__()

          self.linear_1 = nn.Linear(in_features=2, out_features=2)

          self.linear_2 = nn.Linear(in_features=2, out_features=1)

          self.relu = nn.ReLU()

          self.relu6 = nn.ReLU6()

          self.initialize()

      def forward(self, x):

          linear_1 = self.linear_1(x)

          linear_2 = self.linear_2(linear_1)

          relu = self.relu(linear_2)

          relu_6 = self.relu6(relu)

          layers_in = (x, linear_1, linear_2)

          layers_out = (linear_1, linear_2, relu)

          return relu_6, layers_in, layers_out

      def initialize(self):

          """ 定义特殊的初始化,用于验证是不是获取了权重"""

          self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))

          self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))

          self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))

          self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))

          return True

  

  

4、定义hook函数

  hook()函数是register_forward_hook()函数必须提供的参数,好处是用户可以自行决定拦截了中间信息之后要做什么!,比如自己想单纯的记录网络的输入输出(也可以进行修改等更加复杂的操作)。

  首先定义几个容器用于记录:

  定义用于获取网络各层输入输出tensor的容器:

  

# 并定义module_name用于记录相应的module名字

  module_name = []

  features_in_hook = []

  features_out_hook = []

  hook函数需要三个参数,这三个参数是系统传给hook函数的,自己不能修改这三个参数:

  

  hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字

  

def hook(module, fea_in, fea_out):

      print("hooker working")

      module_name.append(module.__class__)

      features_in_hook.append(fea_in)

      features_out_hook.append(fea_out)

      return None

  

  

5、对需要的层注册hook

  注册钩子必须在forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):

  注册钩子可以对某些层单独进行:

  

net = TestForHook()

  net_chilren = net.children()

  for child in net_chilren:

      if not isinstance(child, nn.ReLU6):

          child.register_forward_hook(hook=hook)

  

  

6、测试forward()返回的特征和hook记录的是否一致

  

  

6.1 测试forward()提供的输入输出特征

  由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:

  

out, features_in_forward, features_out_forward = net(x)

  print("*"*5+"forward return features"+"*"*5)

  print(features_in_forward)

  print(features_out_forward)

  print("*"*5+"forward return features"+"*"*5)

  得到下面的输出是理所当然的:

  

*****forward return features*****
(tensor([[0.1000, 0.1000],
[0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****

  

  

  

6.2 hook记录的输入特征和输出特征

  hook通过list结构进行记录,所以可以直接print

  测试features_in是不是存储了输入:

  

print("*"*5+"hook record features"+"*"*5)

  print(features_in_hook)

  print(features_out_hook)

  print(module_name)

  print("*"*5+"hook record features"+"*"*5)

  得到和forward一样的结果:

  

*****hook record features*****
[(tensor([[0.1000, 0.1000],
[0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
[3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>,
<class 'torch.nn.modules.linear.Linear'>,
<class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****

  

  

  

6.3 把hook记录的和forward做减法

  如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook记录的特征和forward记录的特征做减法:

  测试forward返回的feautes_in是不是和hook记录的一致:

  

print("sub result")

  for forward_return, hook_record in zip(features_in_forward, features_in_hook):

      print(forward_return-hook_record[0])

  得到的全部都是0,说明hook没问题:

  

sub result

  tensor([[0., 0.],

          [0., 0.]])

  tensor([[0., 0.],

          [0., 0.]], grad_fn=<SubBackward0>)

  tensor([[0.],

          [0.]], grad_fn=<SubBackward0>)

  

  

7、完整代码

  

import torch

  import torch.nn as nn

  class TestForHook(nn.Module):

      def __init__(self):

          super().__init__()

          self.linear_1 = nn.Linear(in_features=2, out_features=2)

          self.linear_2 = nn.Linear(in_features=2, out_features=1)

          self.relu = nn.ReLU()

          self.relu6 = nn.ReLU6()

          self.initialize()

      def forward(self, x):

          linear_1 = self.linear_1(x)

          linear_2 = self.linear_2(linear_1)

          relu = self.relu(linear_2)

          relu_6 = self.relu6(relu)

          layers_in = (x, linear_1, linear_2)

          layers_out = (linear_1, linear_2, relu)

          return relu_6, layers_in, layers_out

      def initialize(self):

          """ 定义特殊的初始化,用于验证是不是获取了权重"""

          self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))

          self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))

          self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))

          self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))

          return True

  定义用于获取网络各层输入输出tensor的容器,并定义module_name用于记录相应的module名字

  

module_name = []

  features_in_hook = []

  features_out_hook = []

  hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字

  

def hook(module, fea_in, fea_out):

      print("hooker working")

      module_name.append(module.__class__)

      features_in_hook.append(fea_in)

      features_out_hook.append(fea_out)

      return None

  定义全部是1的输入:

  

x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])

  注册钩子可以对某些层单独进行:

  

net = TestForHook()

  net_chilren = net.children()

  for child in net_chilren:

      if not isinstance(child, nn.ReLU6):

          child.register_forward_hook(hook=hook)

  测试网络输出:

  

out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

  

  测试features_in是不是存储了输入:

  

print("*"*5+"hook record features"+"*"*5)

  print(features_in_hook)

  print(features_out_hook)

  print(module_name)

  print("*"*5+"hook record features"+"*"*5)

  测试forward返回的feautes_in是不是和hook记录的一致:

  

print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
print(forward_return-hook_record[0])

  

  到此这篇关于pytorch中的hook机制register_forward_hook的文章就介绍到这了,更多相关pytorch中的hook机制内容请搜索盛行IT软件开发工作室以前的文章或继续浏览下面的相关文章希望大家以后多多支持盛行IT软件开发工作室!

郑重声明:本文由网友发布,不代表盛行IT的观点,版权归原作者所有,仅为传播更多信息之目的,如有侵权请联系,我们将第一时间修改或删除,多谢。

留言与评论(共有 条评论)
   
验证码: