Speed Confusion

Hi all, I am new to PyTorch. I found a previous discussion saying that backward usually takes about 1 to 2 times as long as forward, but my backward() is far slower than that.

import torch
import torch.nn as nn

class Nets(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(26, 15)
        self.hidden2 = nn.Linear(15, 15)
        self.out = nn.Linear(15, 1)

    def forward(self, x):
        # flow to 1st hidden layer
        x = self.hidden1(x)
        x = torch.tanh(x)

        # flow to 2nd hidden layer
        x = self.hidden2(x)
        x = torch.tanh(x)

        # flow to output layer
        x = self.out(x)
        output = torch.sum(x).view(-1)
        return output

inputs = torch.randn(63,26)

%%timeit
net = Nets()
a = net(inputs)
## 320 µs ± 11.6 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)

%%timeit
a.backward(retain_graph=True)
## 25.3 s ± 519 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

Is that normal, or can I make it faster? Thanks.

Hi,

I cannot reproduce this with the modified script below, which uses time.time() instead of %%timeit:

import torch
import torch.nn as nn
import time

class Nets(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(26, 15)
        self.hidden2 = nn.Linear(15, 15)
        self.out = nn.Linear(15, 1)

    def forward(self, x):
        # flow to 1st hidden layer
        x = self.hidden1(x)
        x = torch.tanh(x)

        # flow to 2nd hidden layer
        x = self.hidden2(x)
        x = torch.tanh(x)

        # flow to output layer
        x = self.out(x)
        output = torch.sum(x).view(-1)
        return output

inputs = torch.randn(63,26)

net = Nets()
start = time.time()
a = net(inputs)
print("fw: ", time.time() - start)

start = time.time()
a.backward(retain_graph=True)
print("bw: ", time.time() - start)

I get

fw:  0.0009152889251708984
bw:  0.0009768009185791016

which is what I expected: backward takes about as long as forward.
Maybe timeit has some weird side effects?
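
If you want averaged numbers without %%timeit, here is a minimal sketch of one way to do it. It runs a fresh forward pass before every backward call, so retain_graph=True is not needed, and it clears the gradients on each iteration. The helper name time_fw_bw and the iteration counts are my own choices for illustration, not something from your script, and I have not checked whether this explains the 25 s you saw.

```python
import time

import torch
import torch.nn as nn


class Nets(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden1 = nn.Linear(26, 15)
        self.hidden2 = nn.Linear(15, 15)
        self.out = nn.Linear(15, 1)

    def forward(self, x):
        x = torch.tanh(self.hidden1(x))
        x = torch.tanh(self.hidden2(x))
        return torch.sum(self.out(x)).view(-1)


def time_fw_bw(net, inputs, n_iters=100):
    """Hypothetical helper: mean forward and backward time in seconds."""
    fw_total = 0.0
    bw_total = 0.0
    for _ in range(n_iters):
        net.zero_grad()  # do not accumulate gradients across iterations

        start = time.perf_counter()
        out = net(inputs)  # builds a fresh graph every iteration
        fw_total += time.perf_counter() - start

        start = time.perf_counter()
        out.backward()  # fresh graph, so retain_graph is not required
        bw_total += time.perf_counter() - start
    return fw_total / n_iters, bw_total / n_iters


net = Nets()
inputs = torch.randn(63, 26)

time_fw_bw(net, inputs, n_iters=10)  # warm-up runs, results discarded
fw, bw = time_fw_bw(net, inputs)
print(f"fw: {fw:.6f} s, bw: {bw:.6f} s")
```

Timing the two passes inside the same loop keeps each backward call paired with the forward pass that built its graph, which is closer to what a training step does than calling backward repeatedly on one stored output.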