After one batch, training fails with `RuntimeError: Function 'MinBackward1' returned nan values in its 0th output.` I checked the loss and the inputs to the function in the forward pass, and none of them contain NaN.
Here is the forward function:
```python
def forward(self, input):
    input_vec = input.flatten(start_dim=1)
    t_vec = input_vec @ self.A
    input = (t_vec - t_vec.min()) * input_vec
    input = input.view(-1, self.n_ticks, self.n_classes)
    self.input = input / (input.sum(dim=-1, keepdim=True) + 1e-8)
    return self.input
```
I checked the loss itself, the tensors used in `input = (t_vec - t_vec.min()) * input_vec` (the line where the error is raised), and `self.A`. None of them contain NaN values.
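For reference, here is a minimal self-contained sketch of the layer that can be run on its own. The class name `NormalizedLayer`, the constructor, and the square, randomly initialized `self.A` are hypothetical (only `forward` is shown above); `A` has to be square for `t_vec` and `input_vec` to have the same shape in the elementwise product. The error message above is the one reported by autograd anomaly detection, so the backward pass is wrapped in `torch.autograd.set_detect_anomaly(True)` to get the failing node named in the same way.

```python
import torch
import torch.nn as nn


class NormalizedLayer(nn.Module):  # hypothetical name; the real class is not shown
    def __init__(self, n_ticks: int, n_classes: int):
        super().__init__()
        self.n_ticks = n_ticks
        self.n_classes = n_classes
        d = n_ticks * n_classes
        # Assumption: A is a learnable square (d x d) matrix.
        self.A = nn.Parameter(torch.randn(d, d))

    def forward(self, input):
        input_vec = input.flatten(start_dim=1)   # (batch, d)
        t_vec = input_vec @ self.A               # (batch, d)
        # t_vec.min() is one scalar taken over the whole batch
        input = (t_vec - t_vec.min()) * input_vec
        input = input.view(-1, self.n_ticks, self.n_classes)
        # Normalize over the class dimension; 1e-8 avoids division by exactly zero
        self.input = input / (input.sum(dim=-1, keepdim=True) + 1e-8)
        return self.input


torch.manual_seed(0)
layer = NormalizedLayer(n_ticks=4, n_classes=3)
x = torch.rand(8, 4, 3)  # random positive inputs, not the real data

with torch.autograd.set_detect_anomaly(True):
    out = layer(x)
    loss = (out * torch.randn_like(out)).sum()  # placeholder loss
    loss.backward()
```

On random inputs like these the backward pass is expected to finish without NaN, so the real data and weights from the failing batch would presumably be needed to reproduce the error.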
Can somebody help with this issue?
Thanks,
S