Got TypeError when running the blitz tutorial

At the very last part of the 60-minute blitz tutorial, we train the network on the GPU. I do it like this:

net.cuda()
for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data

        # wrap them in Variable
        inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.data[0]
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

Which results in the following error:

TypeError: add_ received an invalid combination of arguments - got (int, torch.cuda.FloatTensor), but expected one of:
 * (float value)
 * (torch.FloatTensor other)
 * (torch.SparseFloatTensor other)
 * (float value, torch.FloatTensor other)
      didn't match because some of the arguments have invalid types: (int, !torch.cuda.FloatTensor!)
 * (float value, torch.SparseFloatTensor other)
      didn't match because some of the arguments have invalid types: (int, !torch.cuda.FloatTensor!)

Anyone know how to resolve the error?

use_cuda = torch.cuda.is_available()
if use_cuda:
    data, target = Variable(data.cuda(async=True, volatile=True)), Variable(target.cuda(async=True))  # On GPU
else:
    data, target = Variable(data), Variable(target)
# You will get RuntimeError: expected CPU tensor (got CUDA tensor) if you don't do this

Do I need to import something for volatile? I got this error:

TypeError: _cuda() got an unexpected keyword argument 'volatile'

Just remove it; there's no need for it.
See full example here:
https://github.com/QuantScientist/Deep-Learning-Boot-Camp/blob/master/day%2002%20PyTORCH%20and%20PyCUDA/PyTorch/21-PyTorch-CIFAR-10-Custom-data-loader-from-scratch.ipynb
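For completeness, a minimal sketch of the wrapping without volatile. In pre-0.4 PyTorch, volatile was an inference-only flag that belonged on the Variable constructor, not on .cuda(), and it would break loss.backward() in a training loop. The data/target names are assumed from the snippet above:

use_cuda = torch.cuda.is_available()
if use_cuda:
    # Move to the GPU first, then wrap; async=True only speeds things up
    # for pinned-memory host tensors.
    data, target = Variable(data.cuda(async=True)), Variable(target.cuda(async=True))
else:
    data, target = Variable(data), Variable(target)

# For evaluation only, volatile went on the Variable, not on .cuda():
# data = Variable(data.cuda(), volatile=True)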

Unfortunately I still got the same error :(

Here’s the updated code:

if torch.cuda.is_available():
    print("Using the GPU")
    net = Net().cuda() # On GPU

for epoch in range(2):  # loop over the dataset multiple times

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        # get the inputs
        inputs, labels = data
        
        # wrap them in Variable
        use_cuda = torch.cuda.is_available()
        if use_cuda:
            inputs, labels = Variable(inputs.cuda()), Variable(labels.cuda())
            print('yes cuda')
        else:
            inputs, labels = Variable(inputs), Variable(labels)

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize
        outputs = net(inputs)
        
        if use_cuda:
            loss = criterion(outputs, labels).cuda()            
        else:
            loss = criterion(outputs, labels)
        #loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # print statistics
        running_loss += loss.data[0]
        if i % 2000 == 1999:    # print every 2000 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

I'm wondering if anyone else can run the example code without problems: http://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html#training-on-gpu

Can you do this:

print(type(data))
print(type(target))
print("Label:" + str(target))

target should be a torch.LongTensor of size BATCH_SIZE,
and data a [torch.FloatTensor of size 4x3x32x32].
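If it helps, here's a minimal sketch for checking the raw types straight off the trainloader, before any Variable wrapping (trainloader as defined in the tutorial):

data, target = next(iter(trainloader))
print(type(data))    # expected: <class 'torch.FloatTensor'>, size 4x3x32x32
print(type(target))  # expected: <class 'torch.LongTensor'>, size 4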

Which line throws the error?

Here's the output of the code you requested:

<class 'torch.autograd.variable.Variable'>
<class 'torch.autograd.variable.Variable'>
Label:Variable containing:
 5
 6
 6
 6
[torch.cuda.LongTensor of size 4 (GPU 0)]

The error is thrown on the optimizer.step() line; here's the full traceback:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-67-d9149eef4ecc> in <module>()
     30         #loss = criterion(outputs, labels)
     31         loss.backward()
---> 32         optimizer.step()
     33 
     34         # print statistics

~/anaconda2/envs/pytorch/lib/python3.5/site-packages/torch/optim/sgd.py in step(self, closure)
     90                     else:
     91                         buf = param_state['momentum_buffer']
---> 92                         buf.mul_(momentum).add_(1 - dampening, d_p)
     93                     if nesterov:
     94                         d_p = d_p.add(momentum, buf)

TypeError: add_ received an invalid combination of arguments - got (int, torch.cuda.FloatTensor), but expected one of:
 * (float value)
 * (torch.FloatTensor other)
 * (torch.SparseFloatTensor other)
 * (float value, torch.FloatTensor other)
      didn't match because some of the arguments have invalid types: (int, !torch.cuda.FloatTensor!)
 * (float value, torch.SparseFloatTensor other)
      didn't match because some of the arguments have invalid types: (int, !torch.cuda.FloatTensor!)

The output seems to be a Variable instead of a Tensor (e.g. [torch.FloatTensor]).

Please upload a full Jupyter notebook / .py file so that I can run it on my computer.

I came across the same problem. It happens because

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

was created while the net was still on the CPU, so the momentum buffers the optimizer built are CPU tensors, while after net.cuda() the gradients are torch.cuda.FloatTensors — exactly the mismatch shown in the traceback. So you need to define the optimizer again after net.cuda(), i.e.

net.cuda()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

Then it will be solved.
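Putting it together, a minimal sketch of the corrected ordering, using the Net, criterion, and hyperparameters from the blitz tutorial:

import torch
import torch.nn as nn
import torch.optim as optim

net = Net()  # Net as defined in the blitz tutorial
criterion = nn.CrossEntropyLoss()

# Move the parameters to the GPU *before* constructing the optimizer, so
# that any state it creates (e.g. SGD momentum buffers) is built from
# CUDA tensors rather than their old CPU counterparts.
if torch.cuda.is_available():
    net.cuda()

optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

The same rule applies to any optimizer that keeps per-parameter state (e.g. Adam's running averages): construct it, or reconstruct it, after the model is on its final device.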