pytorch中的hook机制是什么


本篇内容介绍了“pytorch中的hook机制是什么”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!Hook被成为钩子机制,这不是pytorch的首创,在Windows的编程中已经被普遍采用,包括进程内钩子和全局钩子。按照自己的理解,hook的作用是通过系统来维护一个链表,使得用户拦截(获取)通信消息,用于处理事件。pytorch中包含forwardbackward两个钩子注册函数,用于获取forward和backward中输入和输出,按照自己不全面的理解,应该目的是“不改变网络的定义代码,也不需要在forward函数中return某个感兴趣层的输出,这样代码太冗杂了”。register_forward_hook()函数必须在forward()函数调用之前被使用,因为这个函数源码注释显示这个函数“ it will not have effect on forward since this is called after :func:`forward` is called”,也就是这个函数在forward()之后就没有作用了!!!):作用:获取forward过程中每层的输入和输出,用于对比hook是不是正确记录。如果随机的初始化每个层,那么就无法测试出自己获取的输入输出是不是forward中的输入输出了,所以需要将每一层的权重和偏置设置为可识别的值(比如全部初始化为1)。网络包含两层(Linear有需要求导的参数被称为一个层,而ReLU没有需要求导的参数不被称作一层),__init__()中调用initialize函数对所有层进行初始化。注意:在forward()函数返回各个层的输出,但是ReLU6没有返回,因为后续测试的时候不对这一层进行注册hook。hook()函数是register_forward_hook()函数必须提供的参数,好处是“用户可以自行决定拦截了中间信息之后要做什么!”,比如自己想单纯的记录网络的输入输出(也可以进行修改等更加复杂的操作)。首先定义几个容器用于记录:定义用于获取网络各层输入输出tensor的容器:hook函数负责将获取的输入输出添加到feature列表中;并提供相应的module名字注册钩子必须在免费云主机域名forward()函数被执行之前,也就是定义网络进行计算之前就要注册,下面的代码对网络除去ReLU6以外的层都进行了注册(也可以选定某些层进行注册):注册钩子可以对某些层单独进行:由于前面的forward()函数返回了需要记录的特征,这里可以直接测试:得到下面的输出是理所当然的:*****forward return features*****
(tensor([[0.1000, 0.1000],
[0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=), tensor([[3.4000],
[3.4000]], grad_fn=))
(tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=), tensor([[3.4000],
[3.4000]], grad_fn=), tensor([[3.4000],
[3.4000]], grad_fn=))
*****forward return features*****
hook通过list结构进行记录,所以可以直接print测试features_in是不是存储了输入:得到和forward一样的结果:*****hook record features*****
[(tensor([[0.1000, 0.1000],
[0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=),), (tensor([[3.4000],
[3.4000]], grad_fn=),)]
[tensor([[1.2000, 1.2000],
[1.2000, 1.2000]], grad_fn=), tensor([[3.4000],
[3.4000]], grad_fn=), tensor([[3.4000],
[3.4000]], grad_fn=)]
[,
,
]
*****hook record features*****
如果害怕会有小数点后面的数值不一致,或者数据类型的不匹配,可以对hook记录的特征和forward记录的特征做减法:测试forward返回的feautes_in是不是和hook记录的一致:得到的全部都是0,说明hook没问题:定义用于获取网络各层输入输出tensor的容器,并定义module_name用于记录相应的module名字hook函数负责将获取的输入输出添加到feature列表中,并提供相应的module名字定义全部是1的输入:注册钩子可以对某些层单独进行:测试网络输出:out, features_in_forward, features_out_forward = net(x)
print(“*”*5+”forward return features”+”*”*5)
print(features_in_forward)
print(features_out_forward)
print(“*”*5+”forward return features”+”*”*5)测试features_in是不是存储了输入:测试forward返回的feautes_in是不是和hook记录的一致:print(“sub result”)
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
print(forward_return-hook_record[0])“pytorch中的hook机制是什么”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注百云主机网站,小编将为大家输出更多高质量的实用文章!

相关推荐: HTML5中如何用路径描画线条

这篇文章主要介绍“HTML5中如何用路径描画线条”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“HTML5中如何用路径描画线条”文章能帮助大家解决问题。 对于HTML5 Canvas,我们可以使用“路径”来描画任何图形。…

免责声明:本站发布的图片视频文字,以转载和分享为主,文章观点不代表本站立场,本站不承担相关法律责任;如果涉及侵权请联系邮箱:360163164@qq.com举报,并提供相关证据,经查实将立刻删除涉嫌侵权内容。

(0)
打赏 微信扫一扫 微信扫一扫
上一篇 01/09 12:12
下一篇 01/09 12:42

相关推荐