I had a look at your code, and it seems your error calculation overflows.
In this method you are calculating the error:
def error_criterion(outputs,labels):
    max_vals, max_indices = torch.max(outputs,1)
    train_error = (max_indices != labels).sum().data[0]/max_indices.size()[0]
    return train_error
The comparison (max_indices != labels) returns a torch.ByteTensor, whose sum can overflow with your batch size of 10000, since an unsigned 8-bit value can only hold numbers up to 255.
Adding a .float() to this line, (max_indices != labels).float().sum()..., will give a train error of ~0.622 and a test error of ~0.640.
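For reference, here is a minimal sketch of what the corrected function could look like. It assumes PyTorch 0.4 or newer, where .item() replaces the older .data[0] idiom; the small tensors at the bottom are made-up data for illustration only, not taken from your code.

```python
import torch


def error_criterion(outputs, labels):
    # Index of the highest-scoring class for each sample.
    _, max_indices = torch.max(outputs, 1)
    # Cast the comparison result to float before summing, so the count
    # of mismatches is not accumulated in an 8-bit integer type.
    wrong = (max_indices != labels).float().sum()
    # .item() extracts a Python number (assumes PyTorch >= 0.4).
    return wrong.item() / max_indices.size(0)


# Made-up example: 4 samples, 2 classes, one wrong prediction.
outputs = torch.tensor([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7], [0.6, 0.4]])
labels = torch.tensor([1, 1, 1, 0])
print(error_criterion(outputs, labels))
```

Summing in float32 counts exactly up to 2**24, which is well above a batch size of 10000.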
Did you not get an error? I got a RuntimeError when trying to run your code:
RuntimeError: value cannot be converted to type uint8_t without overflow: 8821
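To see why the value in this message cannot fit, here is a quick plain-Python check of the uint8 range. It only does the arithmetic and does not reproduce PyTorch's behaviour; the variable names are just illustrative.

```python
# Largest value an unsigned 8-bit integer can hold.
UINT8_MAX = 2**8 - 1

# Value reported in the RuntimeError message.
value = 8821

# Does it fit into a uint8?
fits = value <= UINT8_MAX

# What would remain if the value silently wrapped around modulo 256.
wrapped = value % 2**8

print(UINT8_MAX, fits, wrapped)
```

So a count in the thousands is far outside the 0 to 255 range, and a silent wraparound would give a meaningless error rate.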