Synchronization slowdown caused by .item() but not by .data[0]

I noticed a weird slowdown of the training phase when I accumulate the losses using .item() instead of .data[0] (note that I am testing this code on a Google Colab GPU). The network is a relatively simple CNN:

import torch
import time
from torch.autograd import Variable
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim


#these are the per channel mean and standard deviation of
#the CIFAR10 image database. We will use these to normalize each
#channel to zero mean and unit standard deviation.

mean_CIFAR10=np.array([0.49139968, 0.48215841, 0.44653091])
std_CIFAR10=np.array([0.2470, 0.2435, 0.2616]) #approximate per-channel std of CIFAR10

#this transformation is used to transform the images to 0 mean and 1 std.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean_CIFAR10 , std_CIFAR10)])

#load the CIFAR10 training and test sets
training_set_CIFAR10 = datasets.CIFAR10(root = 'cifar10/',
                                  transform = transform,
                                  train = True,
                                  download = True)


test_set_CIFAR10 = datasets.CIFAR10(root = 'cifar10/',
                                  transform = transform,
                                  train = False,
                                  download = True)

print('Number of training examples:', len(training_set_CIFAR10))
print('Number of test examples:', len(test_set_CIFAR10))

#there are ten classes in the CIFAR10 database
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#DataLoaders are used to iterate over the dataset images in batches rather
#than one by one, which would be slow in interpreted Python
training_loader_CIFAR10 = torch.utils.data.DataLoader(dataset=training_set_CIFAR10,
                                              batch_size=512,
                                              shuffle=True)

test_loader_CIFAR10 = torch.utils.data.DataLoader(dataset=test_set_CIFAR10,
                                            batch_size=512,
                                            shuffle=False)

#this function is used to test the accuracy of the model     
#over the test set. The network cnn is defined later on in the code.
def test():
    print('Started evaluating test accuracy...')
    cnn.eval()
    #calculate the accuracy of our model over the whole test set in batches
    correct = 0
    for x, y in test_loader_CIFAR10:
        x, y = Variable(x).cuda(), y.cuda()
        h = cnn.forward(x)
        pred = h.data.max(1)[1]
        correct += pred.eq(y).sum().item()
    return correct/len(test_set_CIFAR10)



#Below we define the convolutional network class.
#See the beginning of the document for the architecture

class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        
        #define the feature extraction layers
        self.conv1 = torch.nn.Conv2d(3,16,kernel_size=3,stride=1,padding=1)   
        self.pool1 = nn.MaxPool2d(2, stride=2)
        
        self.conv2 = torch.nn.Conv2d(16,32,kernel_size=3,stride=1,padding=1)
        self.pool2 = nn.MaxPool2d(2, stride=2)

        self.conv3 = torch.nn.Conv2d(32,64,kernel_size=3,stride=1,padding=1)
        self.pool3 = nn.MaxPool2d(2, stride=2)
        
        self.conv4 = torch.nn.Conv2d(64,128,kernel_size=3,stride=1,padding=1)
        self.pool4 = nn.MaxPool2d(2, stride=2)
        
        #define the categorization layers
        self.full1=nn.Linear(512,512)        
        self.full2=nn.Linear(512, 256)        
        self.full3=nn.Linear(256,10)
        

       
    #define the forward run for the input data x    
    def forward(self, x):
    
        #convolutional feature extraction layers
        x = F.relu(self.conv1(x))   
        x = self.pool1(x)
        x = F.relu(self.conv2(x))   
        x = self.pool2(x)
        x = F.relu(self.conv3(x))   
        x = self.pool3(x)
        x = F.relu(self.conv4(x))   
        x = self.pool4(x)
        
        
        #learning layers
        x = x.view(-1,512)
        x = F.relu(self.full1(x))
        x = F.relu(self.full2(x))
        x = self.full3(x)  #no relu here since we use CrossEntropyLoss, which applies log-softmax itself
      
        
        return x
        

#this is the training function. cnn is the network that is defined later;
#the cost criterion and the optimizer (with its learning rate lr) are passed in from below

def train(cycles,cost_criterion,cnn,optimizer):
    
    average_cost=0 #cost function for the training
    acc=0 #accuracy over the test set

    
    
    for e in range(cycles): #cycle through the database many times

        print('Cycle: ',e)
        cnn.train()
        loadt=0
        cudat=0
        forwardt=0
        costt=0
        stept=0
        avcostt=0
         
        #the following for loop cycles over the training set in batches
        #of batch_size=512 using the training_loader object
        
        s1 = time.clock() 
        t1 = time.clock()
        for i, (x, y) in enumerate(training_loader_CIFAR10 ,0):
            s2 = time.clock() 
            loadt=loadt+s2-s1
            #here x,y will store data from the training set in batches 
            x, y = Variable(x).cuda(), Variable(y).cuda()
            
            s3 = time.clock() 
            cudat=cudat+s3-s2

            h = cnn.forward(x) #calculate hypothesis over the batch
            
            s4 = time.clock() 
            forwardt=forwardt+s4-s3
            
            cost = cost_criterion(h, y) #calculate the cost of the batch
            #print(type(cost))
            s5 = time.clock() 
            costt=costt+s5-s4
            
            optimizer.zero_grad() #set the gradients to 0
            cost.backward() # calculate derivatives wrt parameters
            optimizer.step() #update parameters
            
            s6 = time.clock() 
            stept=stept+s6-s5

            average_cost+=cost.item(); #add the cost to the costs
            
            s1 = time.clock() 
            avcostt=avcostt+s1-s6
            
        t2 = time.clock()  
        
        print('total time %.2f loading time %.2f, cuda transfer time %.2f, forward time: %.2f, cost time %.2f, step time %.2f, average cost time %.2f'%(t2-t1,loadt,cudat,forwardt,costt,stept,avcostt))           
        average_cost=0
      


cycles = 50 #number of cycles that the training runs over the database
cost_criterion = torch.nn.CrossEntropyLoss() #cost function
cnn = ConvNet().cuda() #build the initial network (in the GPU)
optimizer=optim.Adam(cnn.parameters(), lr= 0.0001)

train(cycles,cost_criterion,cnn,optimizer)
torch.save(cnn.state_dict(), 'cnn_trained')
   

It happens when I try to accumulate losses by

average_cost+=cost.item()

in PyTorch 0.4 (the full code is shown above). The timing is as follows:

Cycle:  0
total time 16.31 loading time 10.85, cuda transfer time 0.11, forward time: 0.37, cost time 0.02, step time 0.80, average cost time 4.17
Cycle:  1
total time 16.32 loading time 10.84, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.80, average cost time 4.18
Cycle:  2
total time 16.32 loading time 10.84, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.80, average cost time 4.19
Cycle:  3

where total time is the time it takes the network to optimize through the whole dataset once and average cost time is the time it takes for the operation I mentioned above. If I use .data[0] instead I get

Cycle 0
total time 12.11 loading time 10.80, cuda transfer time 0.11, forward time: 0.38, cost time 0.02, step time 0.80, average cost time 0.01
Cycle:  1
total time 12.05 loading time 10.75, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.80, average cost time 0.01

Am I making a mistake elsewhere that affects this operation?

Even weirder is the following. I ran the same code with a more complicated network (a residual network). It shows the same behaviour, but something funny happens: when I replace .item() with .data[0], the time for accumulating the cost decreases, but the time for transferring the tensors to CUDA increases. The code is below:


import torch
import time
from torch.autograd import Variable
import torchvision
from torchvision import transforms, datasets
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import torch.nn as nn
import torch.optim as optim


#these are the per channel mean and standard deviation of
#the CIFAR10 image database. We will use these to normalize each
#channel to zero mean and unit standard deviation.

mean_CIFAR10=np.array([0.49139968, 0.48215841, 0.44653091])
std_CIFAR10=np.array([0.2470, 0.2435, 0.2616]) #approximate per-channel std of CIFAR10


#this transformation is used to transform the images to 0 mean and 1 std.
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize(mean_CIFAR10 , std_CIFAR10)])

#load the CIFAR10 training and test sets
training_set_CIFAR10 = datasets.CIFAR10(root = 'cifar10/',
                                  transform = transform,
                                  train = True,
                                  download = True)


test_set_CIFAR10 = datasets.CIFAR10(root = 'cifar10/',
                                  transform = transform,
                                  train = False,
                                  download = True)

print('Number of training examples:', len(training_set_CIFAR10))
print('Number of test examples:', len(test_set_CIFAR10))

#there are ten classes in the CIFAR10 database
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

#DataLoaders are used to iterate over the dataset images in batches rather
#than one by one, which would be slow in interpreted Python
training_loader_CIFAR10 = torch.utils.data.DataLoader(dataset=training_set_CIFAR10,
                                              batch_size=512,
                                              shuffle=True)

test_loader_CIFAR10 = torch.utils.data.DataLoader(dataset=test_set_CIFAR10,
                                            batch_size=512,
                                            shuffle=False)

#this function is used to test the accuracy of the model     
#over the test set. The network cnn is defined later on in the code.
def test():
    print('Started evaluating test accuracy...')
    cnn.eval()
    #calculate the accuracy of our model over the whole test set in batches
    correct = 0
    for x, y in test_loader_CIFAR10:
        x, y = Variable(x).cuda(), y.cuda()
        h = cnn.forward(x)
        pred = h.data.max(1)[1]
        correct += pred.eq(y).sum().item()
    return correct/len(test_set_CIFAR10)


 
#These are the two types of basic blocks in a residual network. The residual network
#in this code is built by concatenating several such blocks together.
#Basic blocks are of the form x -> D(x) + F(x), where D(x) is x downsampled
#to the same dimensions as F(x) by a single convolution, and F(x) is a collection of
#successive operations involving several convolutions and batchnorms.
class BasicResBlock1(nn.Module):
    def __init__(self, input, output, downsample, stride=1):
       super(BasicResBlock1, self).__init__()
       
       self.conv1 = torch.nn.Conv2d(input,output,kernel_size=3,stride=stride,padding=1, bias=False)
       self.batchNorm1 = torch.nn.BatchNorm2d(output)
       self.conv2 = torch.nn.Conv2d(output,output,kernel_size=3,padding=1, stride=1, bias=False)
       self.downsample=downsample #applied to the residual to downsample it to the output dimensions

    def forward(self,x1):       
      
       residual = self.downsample(x1)
     
  
       x2 = self.conv1(x1)
       x2 = self.batchNorm1(x2)
       x2 = F.relu(x2,inplace=True) 
       x2 = self.conv2(x2)
       
      
       x2+= residual

      
     
       return x2
       
class BasicResBlock2(nn.Module):
    def __init__(self, input, output):
       super(BasicResBlock2, self).__init__()
       
       self.conv1 = torch.nn.Conv2d(input,output,kernel_size=3,stride=1,padding=1, bias=False)
       self.batchNorm1 = torch.nn.BatchNorm2d(input)
       self.conv2 = torch.nn.Conv2d(output,output,kernel_size=3,padding=1, stride=1, bias=False)   
       self.batchNorm2 = torch.nn.BatchNorm2d(output) 
       self.batchNorm3 = torch.nn.BatchNorm2d(output) 
        
    def forward(self,x1):       
       
        
       residual = x1
        
       
       x2 = self.batchNorm1(x1)
       x2 = F.relu(x2,inplace=True)  
       x2 = self.conv1(x2)
        
       x2 = self.batchNorm2(x2)
       x2 = F.relu(x2,inplace=True)  
       x2 = self.conv2(x2)
       

       x2+= residual
        
       x2 = self.batchNorm3(x2)  
       x2 = F.relu(x2, inplace=True)
      
     
       return x2       
  

#Below we define the residual network class
class ResNet(nn.Module):
    def __init__(self,width, number_of_blocks):
        super(ResNet, self).__init__()
        
        #these are the inital layers applied before basic blocks
        
        self.conv1 = torch.nn.Conv2d(3,width,kernel_size=3,stride=1,padding=1, bias=False)         
        self.batchNorm1 = torch.nn.BatchNorm2d(width) 
        self.relu1 = nn.ReLU(inplace=True)
        

        #resLayer1 is the basic block for the residual network that is formed by
        #concatenating several basic blocks of increasing dimensions together.
        self.downsample1=torch.nn.Conv2d(width,2*width,kernel_size=1,stride=1,bias=False)  
        self.downsample2=torch.nn.Conv2d(2*width,4*width,kernel_size=1,stride=2,bias=False)
        self.downsample3=torch.nn.Conv2d(4*width,8*width,kernel_size=1,stride=2,bias=False)
        
        self.resLayer1=[]
        self.resLayer1.append(BasicResBlock1(width,2*width,self.downsample1,1))
        for x in range (0, number_of_blocks[0]) :      #stage1
            self.resLayer1.append(BasicResBlock2(2*width,2*width))
        self.resLayer1=nn.Sequential(*self.resLayer1)
        
        self.resLayer2=[]
        self.resLayer2.append(BasicResBlock1(2*width,4*width,self.downsample2,2)) #stage2
        for x in range (0, number_of_blocks[1]) :
            self.resLayer2.append(BasicResBlock2(4*width,4*width))
        self.resLayer2=nn.Sequential(*self.resLayer2)
        
        self.resLayer3=[]
        self.resLayer3.append(BasicResBlock1(4*width,8*width,self.downsample3,2)) #stage3
        for x in range (0, number_of_blocks[2]) :
            self.resLayer3.append(BasicResBlock2(8*width,8*width))
        self.resLayer3=nn.Sequential(*self.resLayer3)   

        
        self.avgpool1 = torch.nn.AvgPool2d(8,stride=1)
        
        #define the final linear classifier layer
        self.full1=nn.Linear(8*width,10)
        
        #weight initializations
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                torch.nn.init.kaiming_normal(m.weight, mode='fan_out')
                
            elif isinstance(m, nn.BatchNorm2d):
                torch.nn.init.constant(m.weight, 1)
                torch.nn.init.constant(m.bias, 0)

            elif isinstance(m, nn.Linear):   
                torch.nn.init.kaiming_normal(m.weight, mode='fan_out')
                torch.nn.init.constant(m.bias, 0)
        
    #define the forward run for the input data x    
    def forward(self, x):
    
        #initial layers before basic blocks    
        x = self.conv1(x)  
        x = self.batchNorm1(x)
        x = self.relu1(x)
  
        
        
        #residual layers and then average pooling
        x = self.resLayer1(x);
        x = self.resLayer2(x);
        x = self.resLayer3(x);
        #x = self.resLayer4(x);
     
        x = self.avgpool1(x)
     
    
        
        #linear classifier layer (since we
        #use CrossEntropyLoss for the loss function,
        #which already has log-softmax incorporated,
        #we don't have any activation function here.)
        x = x.view(x.size(0), -1)
    
        
        x = self.full1(x)
        return x 

#this is the training function. cnn is the network that is defined later;
#the cost criterion and the optimizer (with its learning rate lr) are passed in from below

def train(cycles,cost_criterion,cnn,optimizer):
    
    average_cost=0 #cost function for the training
    acc=0 #accuracy over the test set

    
    
    for e in range(cycles): #cycle through the database many times

        print('Cycle: ',e)
        cnn.train()
        loadt=0
        cudat=0
        forwardt=0
        costt=0
        stept=0
        avcostt=0
         
        #the following for loop cycles over the training set in batches
        #of batch_size=512 using the training_loader object
        
        s1 = time.clock() 
        t1 = time.clock()
        for i, (x, y) in enumerate(training_loader_CIFAR10 ,0):
            s2 = time.clock() 
            loadt=loadt+s2-s1
            #here x,y will store data from the training set in batches 
            x, y = Variable(x).cuda(), Variable(y).cuda()
            
            s3 = time.clock() 
            cudat=cudat+s3-s2

            h = cnn.forward(x) #calculate hypothesis over the batch
            
            s4 = time.clock() 
            forwardt=forwardt+s4-s3
            
            cost = cost_criterion(h, y) #calculate the cost of the batch
            #print(type(cost))
            s5 = time.clock() 
            costt=costt+s5-s4
            
            optimizer.zero_grad() #set the gradients to 0
            cost.backward() # calculate derivatives wrt parameters
            optimizer.step() #update parameters
            
            s6 = time.clock() 
            stept=stept+s6-s5

            average_cost+=cost.data[0]; #add the cost to the costs
            
            s1 = time.clock() 
            avcostt=avcostt+s1-s6
            
        t2 = time.clock()  
        
        print('total time %.2f loading time %.2f, cuda transfer time %.2f, forward time: %.2f, cost time %.2f, step time %.2f, average cost time %.2f'%(t2-t1,loadt,cudat,forwardt,costt,stept,avcostt))           
        average_cost=0
      


cycles = 50 #number of cycles that the training runs over the database
cost_criterion = torch.nn.CrossEntropyLoss() #cost function
cnn = ResNet(16,[1, 1, 1]).cuda() #build the initial network (in the GPU)
optimizer=optim.Adam(cnn.parameters(), lr= 0.0001)

train(cycles,cost_criterion,cnn,optimizer)
torch.save(cnn.state_dict(), 'cnn_trained')
   

In this case, if I use .item() I get

Cycle:  1
total time 51.80 loading time 10.91, cuda transfer time 0.10, forward time: 9.27, cost time 0.02, step time 2.63, average cost time 28.87

whereas if I use .data[0] I get

Cycle:  1
total time 41.51 loading time 10.99, cuda transfer time 18.51, forward time: 9.34, cost time 0.02, step time 2.65, average cost time 0.01

I am completely confused. I checked the type of the cost, thinking that maybe it is also a GPU tensor and I am messing something up somewhere, but it is on the CPU.

That's kind of a lot of code to read through, e.g. it's hard to confirm visually things like whether it is on the CPU. Question: are the costs you are getting similar? Or is e.g. the .data[0] case giving weird, uninitialized values (or, e.g., always the same value)? What happens if you print both, trying first .data[0] and then .item(), and also the other way around?
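
i.e. something like this quick check inside the loop (just a sketch of what I mean; run it once with each ordering):

print('data[0]:', cost.data[0], 'item():', cost.item())   # .data[0] accessed first
print('item():', cost.item(), 'data[0]:', cost.data[0])   # .item() accessed first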

I think your timing might give weird results, because your synchronization points are different in both implementations.
Calling .item() on a tensor gives you a standard Python number, which is pushed to the CPU.
This line of code would add a synchronization to wait for the GPU to finish calculating.
I’m not completely sure if that's also the case when calling .data[0] without e.g. a print statement, or if that operation can run asynchronously.

If you would like to time certain operations, you should call torch.cuda.synchronize() before stopping the timer.
This could make your overall script slower, but will yield valid timing results.
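
For example, a minimal timing sketch (assuming a CUDA device is available; cnn is the ConvNet from your script, and time.perf_counter() is just used here instead of time.clock()):

import time
import torch

x = torch.randn(512, 3, 32, 32).cuda()   #dummy batch with CIFAR10 dimensions

torch.cuda.synchronize()                  #make sure nothing queued earlier leaks into the measurement
t0 = time.perf_counter()
h = cnn(x)                                #this call only queues the kernels on the GPU
torch.cuda.synchronize()                  #wait until the forward pass has actually finished
t1 = time.perf_counter()
print('forward time: %.4f s' % (t1 - t0))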


I see, so basically the code can move on to the next line before the GPU finishes calculating, and I should put a sync wherever I either copy something to the GPU or do some calculation on the GPU. I will update the code as you suggested and try again.
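
Concretely, I plan to instrument the training loop from the script above roughly like this (a fragment showing the pattern for one timer; the same synchronize call would go before every other clock read):

            optimizer.zero_grad() #set the gradients to 0
            cost.backward() #calculate derivatives wrt parameters
            optimizer.step() #update parameters

            torch.cuda.synchronize() #wait for the queued GPU work so that step time is measured correctly
            s6 = time.clock()
            stept=stept+s6-s5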

Okay I will also try this along with the suggestion by ptrblck.

Yes, the GPU operations can be performed in the background while your Python script continues its execution.
Once you get to a point where you push a GPU result to the CPU or print it, the script has to wait for the GPU, so a sync point is added automatically.
Timing is therefore a bit tricky, because it often doesn't show the true GPU op times.
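
A rough illustration of this asynchronous behaviour (just a sketch; the exact numbers will vary with the GPU):

import time
import torch

a = torch.randn(4096, 4096).cuda()
b = torch.randn(4096, 4096).cuda()
torch.cuda.synchronize()

t0 = time.perf_counter()
c = torch.matmul(a, b)        #returns almost immediately, the kernel is still running on the GPU
t1 = time.perf_counter()
v = c[0, 0].item()            #needs the value on the CPU -> implicit synchronization happens here
t2 = time.perf_counter()

print('launch: %.6f s, wait inside .item(): %.6f s' % (t1 - t0, t2 - t1))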

So here is what happens when I sync (with the residual network)

total time 52.23 loading time 11.10, cuda transfer time 0.10, forward time: 12.89, cost time 0.03, step time 28.09, average cost time with data 0.02, average cost time with item 0.01

whereas previously it was (with .item())

total time 51.80 loading time 10.91, cuda transfer time 0.10, forward time: 9.27, cost time 0.02, step time 2.63, average cost time 28.87

and with .data[0]

total time 41.51 loading time 10.99, cuda transfer time 18.51, forward time: 9.34, cost time 0.02, step time 2.65, average cost time 0.01

So now there is no difference between .data[0] and .item(), and the step time (the time it takes for the .backward() + .step() operations) increases. Now most of the time is spent on the backward and forward calculations, as expected.

I assume that when one uses .data[0] instead of .item(), some of the sync happens during the loading of the next batch, which is what seems to affect the CUDA transfer time (presumably that is where the sync happens in the .data[0] case). This mostly resolves the question, thanks. I am still not sure why .data[0] decreases the overall time spent, though; I am going to do some more tests to see whether it really decreases or whether my timers miss a sync step somewhere.

This is why I’d like to see a comparison of the results of both operations :slight_smile:

I am doing those right now, will be back in a minute.

I tried what you suggested with the simpler convolutional network, since that was where the total time difference appeared. Both .data[0] and .item() accumulate the losses correctly, no matter what order they are in.

This is when .data[0] is before .item()

Cycle:0
total time 16.52 loading time 11.07, cuda transfer time 0.11, forward time: 1.40, cost time 0.02, step time 3.90, average cost time with .data[0] 0.01, average cost time with .item() 0.01
cost with data[0] is 3.51052 cost with .item() is 3.51052

Cycle:  1
total time 16.56 loading time 11.10, cuda transfer time 0.11, forward time: 1.40, cost time 0.02, step time 3.90, average cost time with .data[0] 0.01, average cost time with .item() 0.01
cost with data[0] is 2.21125 cost with .item() is 2.21125

Cycle:  2
total time 16.61 loading time 11.17, cuda transfer time 0.11, forward time: 1.39, cost time 0.02, step time 3.89, average cost time with .data[0] 0.01, average cost time with .item() 0.01
cost with data[0] is 2.01861 cost with .item() is 2.01861

This is when .item() is before .data[0]

Cycle:0
total time 16.52 loading time 11.08, cuda transfer time 0.11, forward time: 1.40, cost time 0.02, step time 3.89, average cost time with .data[0] 0.01, average cost time with .item() 0.01
cost with data[0] is 3.51901 cost with .item() is 3.51901

Cycle:  1
total time 16.51 loading time 11.06, cuda transfer time 0.11, forward time: 1.40, cost time 0.02, step time 3.90, average cost time with .data[0] 0.01, average cost time with .item() 0.01
cost with data[0] is 2.23629 cost with .item() is 2.23629

Cycle: 3 
total time 16.47 loading time 11.03, cuda transfer time 0.11, forward time: 1.40, cost time 0.02, step time 3.90, average cost time with .data[0] 0.01, average cost time with .item() 0.01
cost with data[0] is 2.05838 cost with .item() is 2.05838

This is when I only use .data[0]

Cycle:  0
total time 16.55 loading time 11.10, cuda transfer time 0.11, forward time: 1.40, cost time 0.02, step time 3.90, average cost time with data 0.01
cost with data[0] is 3.47288

Cycle:  1
total time 16.54 loading time 11.10, cuda transfer time 0.11, forward time: 1.40, cost time 0.02, step time 3.90, average cost time with data 0.01
cost with data[0] is 2.18770

Note that in the previous attempts, when I did not use sync and used only .data[0], it was

Cycle 0
total time 12.11 loading time 10.80, cuda transfer time 0.11, forward time: 0.38, cost time 0.02, step time 0.80, average cost time 0.01
Cycle:  1
total time 12.05 loading time 10.75, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.80, average cost time 0.01

Does that mean that when you use .data[0], there is some synchronization that never actually happens unless you specifically ask for a manual sync? This behaviour only shows up in the smaller network, so it might also be related to the amount of memory available on the GPU.

for the first two sections, are those both using sync, or both not using sync?

These are all using sync after each operation involving the GPU.

Can you rerun the first two sections with no syncs please?

(edit: because the concern I have is that maybe you are getting the result earlier in the .data case because it’s not waiting for the data to arrive first, and just giving you some garbage result :stuck_out_tongue: )

The results are below. It seems almost as if .item() forces a synchronization that is not needed when .data[0] is used. Most of the difference is definitely related to the synchronization of the calculations that happen at the update step.

When I put the sync before .data[0] the total time is about 12, whereas if I put it at the end of the cycle it is 16. If I remove all syncs inside the for loop over the dataset, but put one after the whole loop, just before the final measurement is made, the time is again 12. What does this suggest? That there is some unnecessary sync done by .item() that is not done by .data[0]?

This is when I remove sync everywhere and .item() is before .data[0]

Cycle:  0
total time 16.41 loading time 10.95, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.78, average cost time with .data[0] 0.02, average cost time with .item() 4.19
cost with data[0] is 3.33067 cost with .item() is 3.33067

Cycle:  1
total time 16.56 loading time 11.09, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.79, average cost time with .data[0] 0.01, average cost time with .item() 4.19
cost with data[0] is 2.20410 cost with .item() is 2.20410

Cycle:  2
total time 16.61 loading time 11.12, cuda transfer time 0.11, forward time: 0.37, cost time 0.02, step time 0.80, average cost time with .data[0] 0.01, average cost time with .item() 4.18
cost with data[0] is 2.05243 cost with .item() is 2.05243

And when .data[0] is before .item()

Cycle:  0
total time 16.26 loading time 10.79, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.78, average cost time with .data[0] 0.01, average cost time with .item() 4.18
cost with data[0] is 3.50083 cost with .item() is 3.50083

Cycle:  1
total time 16.31 loading time 10.85, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.78, average cost time with .data[0] 0.01, average cost time with .item() 4.18
cost with data[0] is 2.17939 cost with .item() is 2.17939

Cycle:  2
total time 16.19 loading time 10.73, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.78, average cost time with .data[0] 0.01, average cost time with .item() 4.19
cost with data[0] is 2.01196 cost with .item() is 2.01196

When I comment out .item() (i.e. only .data[0])

Cycle:  0
total time 11.96 loading time 10.69, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.78, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 3.57116 cost with .item() is 0.00000

Cycle:  1
total time 11.92 loading time 10.65, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.78, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 2.27203 cost with .item() is 0.00000

Cycle:  2
total time 11.95 loading time 10.68, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.78, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 2.08692 cost with .item() is 0.00000

Cycle:  3
total time 11.88 loading time 10.61, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.77, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 1.92515 cost with .item() is 0.00000

When I comment out .item() (i.e. only .data[0]) and put torch.cuda.synchronize() at the end of the cycle

Cycle 0
total time 16.23 loading time 10.77, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.78, average cost time with .data[0] 0.01, average cost time with .item() 4.18
cost with data[0] is 3.46457 cost with .item() is 0.00000

Cycle:  1
total time 16.06 loading time 10.62, cuda transfer time 0.11, forward time: 0.35, cost time 0.02, step time 0.77, average cost time with .data[0] 0.01, average cost time with .item() 4.20
cost with data[0] is 2.20793 cost with .item() is 0.00000

When I comment out .item() (i.e. only .data[0]) and put torch.cuda.synchronize() BEFORE the accumulation of the cost with .data[0]

Cycle:  0
total time 16.53 loading time 11.06, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 4.97, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 3.53346 cost with .item() is 0.00000

Cycle:  1
total time 16.47 loading time 11.00, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 4.96, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 2.23358 cost with .item() is 0.00000

Cycle:  2
total time 16.43 loading time 10.97, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 4.95, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 2.03888 cost with .item() is 0.00000

Cycle:  3
total time 16.42 loading time 10.95, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 4.97, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 1.92610 cost with .item() is 0.00000

When I comment out .item() (i.e. only .data[0]) and put torch.cuda.synchronize() BEFORE the accumulation of the cost with .data[0] and BEFORE the update phase (i.e. step time)

Cycle:  0
total time 16.11 loading time 10.69, cuda transfer time 0.11, forward time: 0.35, cost time 1.05, step time 3.90, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 3.52324 cost with .item() is 0.00000

Cycle:  1
total time 16.10 loading time 10.67, cuda transfer time 0.11, forward time: 0.36, cost time 1.05, step time 3.90, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 2.19925 cost with .item() is 0.00000

Cycle:  2
total time 16.13 loading time 10.70, cuda transfer time 0.11, forward time: 0.36, cost time 1.05, step time 3.90, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 1.97843 cost with .item() is 0.00000

Cycle:  3
total time 16.10 loading time 10.68, cuda transfer time 0.11, forward time: 0.36, cost time 1.04, step time 3.90, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 1.86069 cost with .item() is 0.00000

Finally, when I remove all the syncs inside the loop that iterates over the dataset, but put one just before the timer that is read once the loop is finished:


Cycle:  0
total time 11.96 loading time 10.67, cuda transfer time 0.11, forward time: 0.35, cost time 0.02, step time 0.77, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 3.41078 cost with .item() is 0.00000

Cycle:  1
total time 12.03 loading time 10.73, cuda transfer time 0.11, forward time: 0.35, cost time 0.02, step time 0.77, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 2.16014 cost with .item() is 0.00000

Cycle:  2
total time 12.03 loading time 10.73, cuda transfer time 0.11, forward time: 0.36, cost time 0.02, step time 0.77, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 1.98470 cost with .item() is 0.00000

Cycle:  3
total time 12.00 loading time 10.71, cuda transfer time 0.11, forward time: 0.35, cost time 0.02, step time 0.77, average cost time with .data[0] 0.01, average cost time with .item() 0.00
cost with data[0] is 1.84018 cost with .item() is 0.00000

P.S.: In light of these findings I changed the title of the topic to be less misleading.

Basically what happens in pseudoish code is:

time1

for items in dataset:
    do some stuff
    update parameters
    accumulate costs
end of for

time2

If I use .data[0] to accumulate the costs, the total time is 11-12, and if I use .item() it is 16. Moreover, putting a sync just before time2 still gives a total time of 12, but putting a sync anywhere inside the for loop increases the total time to 16.
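
To make that concrete, here is how I currently picture the sync points in the two variants of the loop from the script above (just a sketch of my understanding, not verified against the PyTorch internals):

for x, y in training_loader_CIFAR10:
    x, y = Variable(x).cuda(), Variable(y).cuda()  #queues the host-to-device copies
    h = cnn(x)                                     #queues the forward kernels
    cost = cost_criterion(h, y)                    #queues the loss kernel

    optimizer.zero_grad()
    cost.backward()                                #queues the backward kernels
    optimizer.step()                               #queues the update kernels

    #variant A: .item() needs the value on the CPU, so the CPU waits here
    #for everything queued above; the wait shows up as "average cost time"
    average_cost += cost.item()

    #variant B: this line returns quickly in my runs; the wait seems to
    #surface later (the next iteration's CUDA transfer, or an explicit
    #torch.cuda.synchronize()), which is what the numbers above suggest
    #average_cost += cost.data[0]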

Hmmm. That's kind of mysterious. You are printing the actual concrete cost, and that is being displayed, and that's included in the timing of .data[0] and .item(), and yet they give different times?

Because, in order to print the value, the data needs to have been transferred from the gpu: the sync must have already happened.

cost is a zero-dimensional tensor, right? Like, if you print cost.size(), it's []?

Yes, it is [].

I am going to try to replicate this with a simpler code so that people can read it and see if there is something I am missing.


So here is another test. I increased the number of workers for data loading from 1 to 2. If I use .item() the total time is 3.77, whereas if I use .data[0] it is 2.71. Previously the time difference was about 6, but now it is about 1. It seems the delay is somehow caused by the data loading stage, and it can be reduced by increasing the number of workers…
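
For reference, the only change was the num_workers argument of the DataLoader (a sketch; 2 is the value used in this test):

training_loader_CIFAR10 = torch.utils.data.DataLoader(dataset=training_set_CIFAR10,
                                                      batch_size=512,
                                                      shuffle=True,
                                                      num_workers=2)   #load batches in 2 worker processes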
