Julia:Zygote 上自定义后向传播

Zygote 是 Julia 上一个实现自动微分、自动求导的包,其中 @adjoint 宏是 Zygote 接口的一个重要组成部分。使用 @adjoint 可以自定义函数的后向传播。

Pullbacks

要理解 @adjoint 首先要先理解更为底层的函数 pullbackgradient 实际上就是 pullback 的语法糖(syntactic sugar)。

julia> y, back = Zygote.pullback(sin, 0.5)
(0.479425538604203, Zygote.var"#41#42"{Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}}(Zygote.ZBack{ChainRules.var"#sin_pullback#1430"{Float64}}(ChainRules.var"#sin_pullback#1430"{Float64}(0.8775825618903728))))

julia> y
0.479425538604203

pullback 输入两个参数 sin0.5 分别代表要求导的函数和要求导的值,会得到两个输出:给定函数的结果 sin(0.5) 以及一个 pullback,也就是上面代码中的 back 变量。back 对函数 sin 进行梯度计算,~~接受的是一个派生,并且产生新的一个变量。~~从数学上讲,就是 vector-Jacobian 积的实现。其中 y=f(x)y=f(x) 和梯度 lx\frac{\partial{l}}{\partial{x}} 写为 xˉ\bar{x},pullback By\mathcal{B}_y 如下计算:

xˉ=lx=lyyx=By(yˉ)\bar{x}=\frac{\partial l}{\partial x}=\frac{\partial l}{\partial y} \frac{\partial y}{\partial x}=\mathcal{B}_{y}(\bar{y})

更为具体的讲,以上面的代码为例子,函数 y=sin(x)y=\sin(x). yx=cos(x)\frac{\partial y}{\partial x}=\cos (x),所以 pullback 就为 yˉcos(x)\bar{y}\cos(x),其中 yˉ=ly\bar{y}=\frac{\partial l}{\partial y}。换句话说,pullback(sin, x)dsin(x) = (sin(x), ȳ -> (ȳ * cos(x),)) 等价。

gradient 中函数 l=f(x)l=f(x) 并且假设 lˉ=ll=1\bar{l}=\frac{\partial l}{\partial l}=1,并且将其输入到 pullback 中。在 sin 的例子中,

julia> dsin(x) = (sin, ȳ -> (ȳ * cos(x),))
dsin (generic function with 1 method)

julia> function gradsin(x)
           _, back = dsin(x)
           back(1)
       end
gradsin (generic function with 1 method)

julia> gradsin(0.5)
(0.8775825618903728,)

julia> cos(0.5)
0.8775825618903728
                
julia> back(1)
(0.8775825618903728,)

个人理解,为什么前面要加一项 ly\frac{\partial l}{\partial y},这是为了实现链式法则。比如假设最终的损失是 ll,函数 y(x)y(x),要得到损失函数 ll 对参数 xx 的微分 lx\frac{\partial l}{\partial x},根据链式法则就是损失函数对函数 yy 的微分乘以函数对参数 xx 的微分,即 lyyx\frac{\partial l}{\partial y} \frac{\partial y}{\partial x}。函数 yypullback 就是损失函数对函数 yy 的微分(用 yˉ\bar{y} 表示)乘以函数对 xx 的微分。

对于上面的例子,pullback 函数返回的第一个结果为:假设函数 y=sin(x)y=\sin(x) 就是损失函数 ll 时,x=0.5x=0.5 时的结果,即 cos(0.5)\cos(0.5),并且返回的 back 就是一个关于 ly\frac{\partial l}{\partial y} 的函数,可以看成是 B(ly)=lycos(0.5)\mathcal{B}(\frac{\partial l}{\partial y})=\frac{\partial l}{\partial y}\cos(0.5)

假如 l=0.5y=0.5sin(x)l=0.5y=0.5\sin(x),我们可以得到 ly=0.5\frac{\partial l}{\partial y}=0.5,那么 lx=B(ly)=B(0.5)\frac{\partial l}{\partial x}=\mathcal{B}(\frac{\partial l}{\partial y})=\mathcal{B}(0.5)


参考:

[1] Custom Adjoints • Zygote

智能之路 文章被收录于专栏

包括机器学习、神经网络、深度学习、强化学习各种方面的文章

全部评论

相关推荐

牛客101244697号:这个衣服和发型不去投偶像练习生?
点赞 评论 收藏
分享
评论
点赞
收藏
分享
牛客网
牛客企业服务