I’m using (torch.nn.functional.softplus(x) + 1e-6) to compute the variance that is trained in a policy gradient network. It occasionally throws a negative value which crashes the Normal(mean,variance) distribution creation for the error function. When I try and detect it and then print out the offending input, and run the same function outside of the loop, I get a different (valid) output for the same input.
I’ve been fighting this for 3 days - even using different softplus approximations like a manual squareplus function, and even clamping gradients to small values, and cannot figure out what’s going on. I’m using Pytorch-Lightning for the training module. Any ideas here? What could be going on in the training loop that might corrupt the final model output?
Code:
class GradientPolicy(nn.Module):
def __init__(self, in_features, out_dims, hidden_size=128, xrange=[0.0,1.0],min_variance=1e-6):
super().__init__()
assert(xrange[0] < xrange[1])
assert(min_variance >= 1e-6)
self.min_var = nn.Parameter(torch.zeros(out_dims, requires_grad=False) + min_variance)
self.range_min = nn.Parameter(torch.zeros(out_dims, requires_grad=False) + xrange[0])
self.range_size = nn.Parameter(torch.zeros(out_dims, requires_grad=False) + (xrange[1] - xrange[0]))
self.fc1 = nn.Linear(in_features, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc_mu = nn.Linear(hidden_size, out_dims)
self.fc_var = nn.Linear(hidden_size, out_dims)
def forward(self, x):
x = torch.tensor(x).float().to(device)
x = F.mish(self.fc1(x))
x = F.mish(self.fc2(x))
mu = F.sigmoid(self.fc_mu(x)) * self.range_size + self.range_min
v = F.softplus(self.fc_var(x)) + self.min_var
if (v <= 0.).any().item():
print("softplus input",self.fc_var(x))
print("gp variance: ",v)
return mu, v
And here’s where it prints the detected bad value immediately before the crash:
softplus input tensor([[-3.7529]], device='cuda:0')
gp variance: tensor([[-0.0329]], device='cuda:0')
Of course, when you manually run that input through softplus:
print(F.softplus(torch.tensor(-3.7529)))
You get a positive output: tensor(0.0232)
WTF is going on here?