pytorch 实现 GRL Gradient Reversal Layer
在GRL中,要实现的目标是:在前向传导的时候,运算结果不变化,在梯度传导的时候,传递给前面的叶子节点的梯度变为原来的相反方向。举个例子最好说明了:
import torch from torch.autograd import Function x = torch.tensor([1.,2.,3.],requires_grad=True) y = torch.tensor([4.,5.,6.],requires_grad=True) z = torch.pow(x,2) + torch.pow(y,2) f = z + x + y s =6* f.sum() print(s) s.backward() print(x) print(x.grad)复制代码
这个程序的运行结果是:
tensor(672., grad_fn=<MulBackward0>) tensor([1., 2., 3.], requires_grad=True) tensor([18., 30., 42.])复制代码
这个运算过程对于tensor中的每个维度上的运算为:
那么对于x的导数为:
所以当输入x=[1,2,3]时,对应的梯度为:[18,30,42]
因此这个是正常的梯度求导过程,但是如何进行梯度翻转呢?很简单,看下方的代码:
import torch from torch.autograd import Function x = torch.tensor([1.,2.,3.],requires_grad=True) y = torch.tensor([4.,5.,6.],requires_grad=True) z = torch.pow(x,2) + torch.pow(y,2) f = z + x + y class GRL(Function): def forward(self,input): return input def backward(self,grad_output): grad_input = grad_output.neg() return grad_input Grl = GRL() s =6* f.sum() s = Grl(s) print(s) s.backward() print(x) print(x.grad)复制代码
运行结果为:
tensor(672., grad_fn=<GRL>) tensor([1., 2., 3.], requires_grad=True) tensor([-18., -30., -42.])复制代码
这个程序相对于上一个程序,只是差在加了一个梯度翻转层:
class GRL(Function): def forward(self,input): return input def backward(self,grad_output): grad_input = grad_output.neg() return grad_input复制代码
这个部分的forward没有进行任何操作,backward里面做了.neg()操作,相当于进行了梯度的翻转。在torch.autograd 中的FUnction 的backward部分,在不做任何操作的情况下,这里的grad_output的默认值是1.
作者:Tchunren
链接:https://juejin.cn/post/7031102368079069192