(Performance question) How to avoid the cost of initializing a scalar Variable to 0?

Hey all –

I’ve looked around as much as I can but can’t seem to find an explanation on how to best do this.

I have a loss function I implemented, and upon doing a little bit of profiling, I’ve realized that the vast majority of the runtime is spent initializing a scalar Variable (the loss), which is then accumulated from a couple of different components.

It looks something like this:

import time

import torch
from torch.autograd import Variable


class PixelwiseContrastiveLoss(torch.nn.Module):

    def __init__(self):
        super(PixelwiseContrastiveLoss, self).__init__()

    def forward(self, *inputs):  # lots of inputs in the real signature
        start = time.time()
        # allocate the scalar loss directly on the GPU
        loss = Variable(torch.cuda.FloatTensor([0]))
        print time.time() - start, " for init 0 loss"

        start = time.time()
        loss += # LOTS of torch math operations
        print time.time() - start, " for first component of loss"

        start = time.time()
        loss += # LOTS of torch math operations
        print time.time() - start, " for second component of loss"

        return loss

And my poor man’s profiler consistently gives these approximate times:

0.0570869445801  for init 0 loss
0.0010199546814  for first component of loss
0.0110859870911  for second component of loss
---
0.0719039440155 total getting loss

No matter what I try, I still get about 0.07 seconds per call to this loss function. I’ve tried every variant I can think of: keeping a Variable around as a class attribute and re-zeroing it on each forward (sketched below), and so on.
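For concreteness, that class-attribute variant looked roughly like this (a minimal sketch; the class name and the _zero attribute are just illustrative, and the actual loss terms are elided as above):

import torch
from torch.autograd import Variable


class PixelwiseContrastiveLossReusedZero(torch.nn.Module):

    def __init__(self):
        super(PixelwiseContrastiveLossReusedZero, self).__init__()
        # allocate the scalar once, up front, on the GPU
        self._zero = Variable(torch.cuda.FloatTensor([0]))

    def forward(self, *inputs):
        # re-zero the stored scalar instead of allocating a fresh Variable
        self._zero.data.zero_()
        loss = self._zero
        # ... the same loss components are accumulated into loss here ...
        return loss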

What’s interesting to me is that if I never even create a loss variable, and just return the sum of the two big torch math expressions directly, the total time is still about 0.07 seconds. But broken out the way I have it above, the vast majority of that time appears to be just the initialization.

So, my question is: what can I do to avoid this time cost, which appears to come just from initializing memory?

I’m probably just misunderstanding something about proper use of autograd…

Thanks!

P.S. Note also that when the object was created I called .cuda() on it, i.e. pixelwise_contrastive_loss = PixelwiseContrastiveLoss().cuda(), so I’m under the impression the object should already live on the GPU.

After lots of trial and error, I think my print statements are measuring artifacts…

Even if I massively simplify the loss function to something like loss = img_b_pred[0][0][0] - img_a_pred[0][0][0], so that the loss itself is computed in about 0.001 seconds, the overall training step time (about 0.26 seconds total for forward + backward) does not appreciably decrease from the scenario presented in the original question; the computation time simply shifts into what I measure backward() as taking.

I would still be interested to know more about why the above scenario measures such a large time for initializing the loss to 0 though, if people are knowledgeable and willing to share! Perhaps it happens to kick off something behind the scenes with respect to the graph?

Hi,

I think the problem with your timings is that the CUDA API is asynchronous, so your measurements are not correct. If you want to time GPU code like this, you need to call torch.cuda.synchronize() before and after each timed region.
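For example, timing the zero-initialization with explicit synchronization would look roughly like this (a minimal sketch, not taken verbatim from the thread):

import time

import torch
from torch.autograd import Variable

torch.cuda.synchronize()  # make sure previously queued GPU work has finished
start = time.time()

loss = Variable(torch.cuda.FloatTensor([0]))  # the allocation being timed

torch.cuda.synchronize()  # wait for the GPU to finish before reading the clock
print(time.time() - start)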

For the initialization, you can simply replace loss = Variable(torch.cuda.FloatTensor([0])) with loss = 0. :slight_smile:
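In other words, something like this (a minimal sketch; the actual loss terms are elided just as in the original post):

import torch


class PixelwiseContrastiveLoss(torch.nn.Module):

    def forward(self, *inputs):
        # start from a plain Python scalar; once the (elided) CUDA loss terms
        # are added, loss becomes a Variable living on the GPU
        loss = 0.
        # loss += first loss component (the torch math operations)
        # loss += second loss component
        return loss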

Hi @albanD, thank you for clearing this up for me!