Hi everyone,

I am trying to implement a function that takes a 4-D `Tensor` as input and raises its absolute value to a power `alpha`, preserving the sign. The catch is that `alpha` is a trainable parameter, but when the gradient is computed w.r.t. `alpha` I get `nan`, which is probably due to overflow.

I have tried using `autograd.Function` and `nn.Module`, but no luck fixing the issue so far. Here is my code with `autograd.Function`:

```
import torch

a = torch.tensor(2., requires_grad=True)

class Power(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha):
        result = x.sign() * torch.abs(x) ** alpha
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return result * torch.log(result.abs() + 1e-6), None

power = Power.apply
N, D_in = 64, 100  # example sizes
x = torch.randn(N, D_in, device='cpu', dtype=torch.float)
out = torch.sum(power(x, a))
out.backward()
out.grad  # None here
```

and using `nn.Module`:

```
import torch
import torch.nn as nn

class Power(nn.Module):
    def __init__(self, alpha=2.):
        super(Power, self).__init__()
        self.alpha = nn.Parameter(torch.tensor(alpha))

    def forward(self, x):
        return x.sign() * torch.abs(x) ** self.alpha
```

Is there any way to properly fix the function to avoid the overflow / `nan`s? Thanks in advance.