Simple Linear Regression: why is backward() so slow?

I am trying to solve a simple multivariate linear regression problem, argmin_w ||Y - Xw||^2, using gradient descent.
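
In the code below the loss is scaled by 0.5, and the manual (no-backward) version applies its gradient directly:

$$\nabla_w\,\tfrac{1}{2}\lVert Xw - Y\rVert^2 = X^\top (Xw - Y)$$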

The code using .backward() is about 5 times slower than the version that computes the gradient manually without .backward():

import torch
import time

INPUT_SIZE = 256
HIDDEN_SIZE = 2048
BASIS_SIZE = [INPUT_SIZE, HIDDEN_SIZE]
batch_size = 100

X = torch.nn.Linear(BASIS_SIZE[1],BASIS_SIZE[0]).cuda()
target = torch.randn([batch_size, INPUT_SIZE]).cuda()
eta=0.01

w = torch.cuda.FloatTensor(batch_size,HIDDEN_SIZE).fill_(0)
w.requires_grad = True
start_time = time.time()
for k in range(2000):
    Res = X(w)-target
    loss = 0.5*(Res**2).sum()
    loss.backward()
    w.data = w.data.add(-eta*w.grad)
    X.weight.grad.zero_()
    w.grad.zero_()
print('With Backward -- Runtime: {}, loss: {}'.format(time.time()-start_time,loss))

w = torch.cuda.FloatTensor(batch_size,HIDDEN_SIZE).fill_(0)
w.requires_grad = True
start_time = time.time()
for k in range(2000):
    Res = X(w)-target
    loss = 0.5*(Res**2).sum()
    # manual gradient step: grad_w = X^T (Xw - target), written with transposes
    w.data = w.data.add(-eta*X.weight.data.t().mm(Res.T).T)
    
print('Without Backward -- Runtime: {}, loss: {}'.format(time.time()-start_time,loss))
With Backward -- Runtime: 1.8691670894622803, loss: 2.3017024993896484
Without Backward -- Runtime: 0.4285166263580322, loss: 2.3017029762268066

I think there are a few things going on here. A large portion of the timing difference is due to the fact that the “backward” version is run first: the first invocation of GPU kernels is usually much slower, since libraries (such as cuBLAS in this case) need to be loaded, a kernel for the specific shape has to be selected, and so on. CUDA kernels also launch asynchronously, so timing with time.time() is only meaningful after a torch.cuda.synchronize(); the version below therefore skips the first 100 iterations and synchronizes around the timed region.
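
To see the first-call cost in isolation, a minimal sketch along these lines (the shapes here are arbitrary, not tuned to your problem) usually shows the first matmul of a given shape taking far longer than later ones:

import torch
import time

a = torch.randn(100, 2048, device='cuda')
b = torch.randn(2048, 256, device='cuda')

# first matmul of this shape pays for loading cuBLAS and selecting a kernel
torch.cuda.synchronize()
t0 = time.time()
c = a @ b
torch.cuda.synchronize()
print('first call :', time.time() - t0)

# later calls reuse the already-loaded library/kernel and are much cheaper
torch.cuda.synchronize()
t0 = time.time()
c = a @ b
torch.cuda.synchronize()
print('second call:', time.time() - t0)
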
Additionally, I don’t think the “manual” gradient computation and update match the backward call exactly. For example, the backward pass also calculates gradients for X’s parameters (which the manual version skips), as well as a gradient for the bias of the linear layer. There are probably still other mismatches, but accounting for these seems to close the gap somewhat:

import torch
import time

INPUT_SIZE = 256
HIDDEN_SIZE = 2048
BASIS_SIZE = [INPUT_SIZE, HIDDEN_SIZE]
batch_size = 100

X = torch.nn.Linear(BASIS_SIZE[1],BASIS_SIZE[0], bias=False).cuda()
target = torch.randn([batch_size, INPUT_SIZE]).cuda()
eta=0.01

w = torch.cuda.FloatTensor(batch_size,HIDDEN_SIZE).fill_(0)
X.weight.requires_grad = False
w.requires_grad = True


for k in range(2000):
    if k == 100:
        torch.cuda.synchronize()
        start_time = time.time()
    Res = X(w)-target
    loss = 0.5*(Res**2).sum()
    loss.backward()
    w.data = w.data.add(-eta*w.grad)
    #X.weight.grad.zero_()
    w.grad.zero_()
torch.cuda.synchronize()
print(X.weight.grad)
print('With Backward -- Runtime: {}, loss: {}'.format(time.time()-start_time,loss))

w = torch.cuda.FloatTensor(batch_size,HIDDEN_SIZE).fill_(0)
w.requires_grad = True

for k in range(2000):
    if k == 100:
        torch.cuda.synchronize()
        start_time = time.time()
    Res = X(w)-target
    w.data = w.data.add(-eta*X.weight.data.t().mm(Res.T).T)
torch.cuda.synchronize()
print('Without Backward -- Runtime: {}, loss: {}'.format(time.time()-start_time,loss))
None
With Backward -- Runtime: 0.5746731758117676, loss: 2.399951934814453
Without Backward -- Runtime: 0.24897408485412598, loss: 2.399951934814453
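
As a quick sanity check that the manual update really computes the same gradient for w as autograd (with bias=False and X.weight frozen as above), something along these lines could be used; the atol passed to torch.allclose is just an arbitrary tolerance:

import torch

INPUT_SIZE, HIDDEN_SIZE, batch_size = 256, 2048, 100
X = torch.nn.Linear(HIDDEN_SIZE, INPUT_SIZE, bias=False).cuda()
X.weight.requires_grad = False
target = torch.randn(batch_size, INPUT_SIZE).cuda()
w = torch.randn(batch_size, HIDDEN_SIZE, device='cuda', requires_grad=True)

Res = X(w) - target
loss = 0.5*(Res**2).sum()

# gradient of the loss w.r.t. w according to autograd
autograd_grad, = torch.autograd.grad(loss, w)

# manual gradient: X^T (Xw - target), written with the same transposes as above
manual_grad = X.weight.data.t().mm(Res.detach().T).T

print(torch.allclose(autograd_grad, manual_grad, atol=1e-4))  # expect True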

Thanks for answering!

I did notice that the first run is slower because of library loading, but that accounts for less than 0.1 s of the difference. Strangely, when I run your code I can’t reproduce your result. Even if I increase the number of iterations, the runtime gap doesn’t seem to get any smaller:

import torch
import time

INPUT_SIZE = 256
HIDDEN_SIZE = 2048
BASIS_SIZE = [INPUT_SIZE, HIDDEN_SIZE]
batch_size = 100

X = torch.nn.Linear(BASIS_SIZE[1],BASIS_SIZE[0], bias=False).cuda()
target = torch.randn([batch_size, INPUT_SIZE]).cuda()
eta=0.01

w = torch.cuda.FloatTensor(batch_size,HIDDEN_SIZE).fill_(0)
X.weight.requires_grad = False
w.requires_grad = True


for k in range(6000):
    if k == 100:
        torch.cuda.synchronize()
        start_time = time.time()
    Res = X(w)-target
    loss = 0.5*(Res**2).sum()
    loss.backward()
    w.data = w.data.add(-eta*w.grad)
    #X.weight.grad.zero_()
    w.grad.zero_()
torch.cuda.synchronize()
print(X.weight.grad)
print('With Backward -- Runtime: {}, loss: {}'.format(time.time()-start_time,loss))

w = torch.cuda.FloatTensor(batch_size,HIDDEN_SIZE).fill_(0)
w.requires_grad = True

for k in range(6000):
    if k == 100:
        torch.cuda.synchronize()
        start_time = time.time()
    Res = X(w)-target
    w.data = w.data.add(-eta*X.weight.data.t().mm(Res.T).T)
torch.cuda.synchronize()
print('Without Backward -- Runtime: {}, loss: {}'.format(time.time()-start_time,loss))
None
With Backward -- Runtime: 4.442049503326416, loss: 9.112495717999991e-06
Without Backward -- Runtime: 0.7066187858581543, loss: 9.112495717999991e-06
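
In case it helps narrow this down, the per-iteration time could also be measured with CUDA events instead of time.time(). This is only a rough sketch (the helper name time_loop and the warm-up/iteration counts are arbitrary); step would be one iteration of either loop above:

import torch

def time_loop(step, iters=1000, warmup=100):
    # warm up so first-call overhead is excluded from the measurement
    for _ in range(warmup):
        step()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        step()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # milliseconds per iteration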