Problems with weight array of FloatTensor type in loss function

I have mostly worked with Keras on the TensorFlow backend and have occasionally dabbled with Torch7. I was intrigued by the PyTorch project and wanted to test it out, so I tried to run a simple model on a dataset where I loaded both my features and my target labels into np.float64 arrays. PyTorch automatically converted them both to DoubleTensors, which seems fine to me. However, the loss function expects DoubleTensors for the weight and the bias, but it is apparently getting FloatTensors for both of them. I am not sure how to change my code to ensure that the loss function gets what it is expecting. My model definition is below:

import torch.nn as nn
import torch.nn.functional as F


class ColorizerNet(nn.Module):

	def __init__(self):
		super(ColorizerNet, self).__init__()
		# nn.Conv2d(in_channels, out_channels, kernel_size=2, stride=2)
		self.layer1 = nn.Conv2d(1, 8, 2, 2)
		self.layer2 = nn.Conv2d(8, 16, 2, 2)
		self.layer3 = nn.Conv2d(16, 8, 2, 2)
		self.layer4 = nn.Conv2d(8, 1, 2, 2)


	def forward(self, x):
		x = F.relu(self.layer1(x))
		x = F.relu(self.layer2(x))
		x = F.relu(self.layer3(x))
		x = F.relu(self.layer4(x))
		return x

Additionally, this is how I am training the model (I am basically using the code from the Deep Learning with PyTorch: A 60 Minute Blitz notebook in the tutorials repo):

for epoch in range(num_epochs): # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(data_loader, 0):
        # get the inputs
        inputs, labels = data
        
        # wrap them in Variable
        inputs, labels = Variable(inputs), Variable(labels)
        
        # zero the parameter gradients
        optimizer.zero_grad()
        
        # forward + backward + optimize
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()        
        optimizer.step()
        
        # accumulate statistics
        running_loss += loss.data[0]
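
For completeness, here is roughly how the rest of my script is set up. The arrays, loss, and optimizer below are just stand-ins to make the example reproducible, not my actual data pipeline:

import numpy as np
import torch
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable

# stand-ins for my real data: both arrays are np.float64,
# which torch.from_numpy turns into DoubleTensors
features = np.random.rand(100, 1, 32, 32)
labels = np.random.rand(100, 1, 2, 2)

dataset = torch.utils.data.TensorDataset(torch.from_numpy(features), torch.from_numpy(labels))
data_loader = torch.utils.data.DataLoader(dataset, batch_size=4)

model = ColorizerNet()
criterion = nn.MSELoss()  # placeholder loss; the error occurs before it is even called
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)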

Here is the error I am getting:

TypeError: DoubleSpatialConvolutionMM_updateOutput 
received an invalid combination of arguments -  
got (int, torch.DoubleTensor, torch.DoubleTensor, torch.FloatTensor, torch.FloatTensor, torch.DoubleTensor, torch.DoubleTensor, int, int, int, int, int, int), 
but expected (int state, torch.DoubleTensor input, torch.DoubleTensor output, torch.DoubleTensor weight, [torch.DoubleTensor bias or None], torch.DoubleTensor finput, torch.DoubleTensor fgradInput, int kW, int kH, int dW, int dH, int padW, int padH)

If there is some obvious mistake I am making, please let me know. I am happy to provide any other information needed to reproduce the error. I have also already looked at this note, but since that issue is associated with the type of the target, I don't think it applies here; I cannot figure out what to change in order to affect the type of the weight or bias.


Your input and target tensors are DoubleTensors, but your model parameters are FloatTensors. You have to convert either your inputs or your parameters.

To convert your inputs to float (recommended):

inputs, labels = data
inputs = inputs.float()
labels = labels.float()
inputs, labels = Variable(inputs), Variable(labels)

To convert your model to double:

model = ColorizerNet()
model.double()

I recommend using floats instead of doubles. It’s the default tensor type in PyTorch. On GPUs, float calculations are much faster than double calculations.
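
You can check this directly on the parameters of your ColorizerNet, for example:

model = ColorizerNet()
print(next(model.parameters()).data.type())  # torch.FloatTensor -- the default parameter type

model.double()
print(next(model.parameters()).data.type())  # torch.DoubleTensor after the conversion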


Yes, I had heard about that. Thank you for pointing out the missing link. I think I will change the inputs to float. Just one follow-up question: why does PyTorch convert NumPy's float64 to DoubleTensors? If FloatTensors are the go-to type in PyTorch, I would have thought the NumPy-to-torch conversion would default to that; instead the array gets converted to a DoubleTensor. Could you shed some light on this?

Conversion from NumPy to torch with the torch.from_numpy method keeps the original data type, while converting with torch.Tensor(my_np_array) uses the default tensor type (you can also select the resulting type explicitly, e.g. torch.IntTensor(my_np_array)).
Also, note that float64 is actually double.
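
A small sketch of the difference, using .type() to print each tensor's type name:

import numpy as np
import torch

arr = np.array([1.0, 2.0, 3.0])      # NumPy defaults to float64
print(torch.from_numpy(arr).type())  # torch.DoubleTensor -- dtype preserved
print(torch.Tensor(arr).type())      # torch.FloatTensor  -- default tensor type
print(torch.IntTensor(np.array([1, 2], dtype=np.int32)).type())  # torch.IntTensor -- explicitly chosen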


Just keep in mind that creating a torch.FloatTensor out of a NumPy float64 array will be very slow. It's better to use torch.from_numpy(arr).float(). The .float() call will be a no-op if the array is already of float32 type.
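
In other words, something like this (shapes here are only illustrative):

import numpy as np
import torch

arr = np.random.rand(4, 1, 32, 32)      # float64 by default
inputs = torch.from_numpy(arr).float()  # from_numpy shares memory; .float() then makes one float32 copy

arr32 = arr.astype(np.float32)
inputs = torch.from_numpy(arr32).float()  # here .float() is a no-op: the data is already float32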


Thank you guys! This helps a lot and clarifies a lot of doubts I had!

I think there is a little inconsistency, because

torch.FloatTensor(np.array([1, 2], dtype=np.float64))  # works
torch.FloatTensor(np.array([1, 2], dtype=np.float32))  # works
torch.IntTensor(np.array([1, 2], dtype=np.int64))      # fails
torch.IntTensor(np.array([1, 2], dtype=np.int32))      # works

The first statement should give an error, because float32 cannot necessarily hold float64 values, is that right?