I'm testing a model with PyTorch Lightning and observing different results depending on the batch size. I found that the following modification fixed the issue. Does anyone know why? To be frank, I'm not even sure how the original formulation came about; it certainly seems wasteful, although at first glance it doesn't appear incorrect.
First the original:
```python
x = data
y = y[:, -1:].squeeze(-1).float()
w = w[:, -1:].squeeze(-1)
y_hat = self(x).squeeze(-1)
if len(y.shape) == 1:
    y = y.unsqueeze(0)
if len(w.shape) == 1:
    w = w.unsqueeze(0)
if len(y_hat.shape) == 1:
    y_hat = y_hat.unsqueeze(0)
return y.cpu(), y_hat.detach().cpu(), w.cpu()
```
And the fix:
```python
x = data
y = y[:, -1:].float()
w = w[:, -1:]
y_hat = self(x)
return y.cpu(), y_hat.detach().cpu(), w.cpu()
```
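In case it helps with the diagnosis: the two versions don't return the same shapes, and the original's shape depends on the batch size. The sketch below mimics just the `y` pipeline with NumPy (whose `squeeze`/`expand_dims` mirror `torch.squeeze`/`unsqueeze`), using a hypothetical sequence length of 8:

```python
import numpy as np

def original_shape(batch):
    # Mimics: y[:, -1:].squeeze(-1), then unsqueeze(0) if 1-D
    y = np.zeros((batch, 8))       # (batch, seq)
    y = y[:, -1:].squeeze(-1)      # (batch, 1) -> (batch,)
    if y.ndim == 1:                # always true here
        y = np.expand_dims(y, 0)   # -> (1, batch)
    return y.shape

def fixed_shape(batch):
    # Mimics: y[:, -1:] with no squeeze/unsqueeze
    y = np.zeros((batch, 8))
    return y[:, -1:].shape         # stays (batch, 1)

print(original_shape(1), original_shape(32))  # (1, 1) (1, 32)
print(fixed_shape(1), fixed_shape(32))        # (1, 1) (32, 1)
```

So the original puts the batch on the last axis as `(1, batch)`, which means outputs collected from batches of different sizes (e.g. a final partial batch) have different widths, while the fix always returns `(batch, 1)`.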