Can torch.no_grad() be used in training?

I want to add a new layer to a pretrained model. The pretrained model will not be updated; only the added layer will be trained. So my question is: can I use torch.no_grad() to wrap the forward pass of the pretrained part? Is this reasonable? Will it reduce memory usage and speed up training?

Thanks!

Yes, this should work as shown in this small code snippet:

class MyModel(nn.Module):
    """Small demo model whose convolutional stage is excluded from training.

    ``conv1`` runs inside ``torch.no_grad()``, so autograd records nothing
    through it and its parameters never receive gradients; only ``fc1``
    is trained.  With ``kernel_size=3, stride=1, padding=1`` the conv keeps
    the spatial size unchanged, so ``fc1`` expects a 4x4 input image
    (16 channels * 4 * 4 = 256 features after flattening).
    """

    def __init__(self):
        super().__init__()
        # Creation order (conv1 first, then fc1) is kept so parameter
        # names, ordering and random initialisation are unchanged.
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        self.fc1 = nn.Linear(4 * 4 * 16, 10)

    def forward(self, x):
        with torch.no_grad():
            # Frozen feature extractor: no graph is built for these ops.
            features = F.relu(self.conv1(x))
            flat = features.view(features.size(0), -1)
        # Only this layer participates in backpropagation.
        return F.relu(self.fc1(flat))
    

# Build the model, run one forward/backward pass on a random 4x4 image.
model = MyModel()
x = torch.randn(1, 3, 4, 4)
out = model(x)
out.mean().backward()

# Report only the parameters that actually received a gradient; with the
# no_grad() guard in MyModel.forward, conv1's parameters are skipped.
for name, param in model.named_parameters():
    if param.grad is None:
        continue
    print(name, param.grad.abs().sum())

If you remove the torch.no_grad() guard, all layers will get gradients.
Alternatively, you could set the requires_grad attribute of the pretrained parameters to False.

2 Likes

Thanks for your reply!

So which way will be faster? torch.no_grad() or requires_grad=False?

In my view, torch.no_grad() will not calculate the gradients of the inputs of the layers in the pretrained model, while requires_grad=False does. So will torch.no_grad() be faster? Is that right?

I think neither approach will store the intermediate tensors, but let me know if you see any differences in profiling. :wink:

I tried both ways but didn’t see an obvious difference in running time…

class Net(nn.Module):
    """Benchmark model: a stack ``net_1`` followed by a trainable head ``net_2``.

    Args:
        no_grad: if True, ``forward`` runs ``net_1`` inside
            ``torch.no_grad()`` so no autograd graph is recorded through it.
        requires_grad: if False, every parameter of ``net_1`` is frozen
            (``requires_grad=False``) at construction time.
        features: width of every ``nn.Linear`` layer (in and out features).
        depth: number of ``nn.Linear`` layers in ``net_1``.

    With the defaults this builds six 1000x1000 layers in ``net_1`` and one
    in ``net_2``. Layers are created in the same order, so parameter names
    (``net_1.0.weight`` ... ``net_2.bias``) are unchanged.
    """

    def __init__(self, no_grad=False, requires_grad=True, features=1000, depth=6):
        super().__init__()
        # ``nn.Sequential`` explicitly: a bare ``Sequential`` only resolves if
        # it was imported on its own, which this class must not rely on.
        self.net_1 = nn.Sequential(
            *(nn.Linear(features, features) for _ in range(depth))
        )
        self.net_2 = nn.Linear(features, features)
        if not requires_grad:
            self._freeze_param()
        self.no_grad = no_grad

    def _freeze_param(self):
        """Disable gradient tracking for every parameter of ``net_1``."""
        # Walk the submodule directly instead of matching name prefixes:
        # a test like ``startswith("net_1")`` would also catch e.g. "net_10".
        for param in self.net_1.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Run ``net_1`` (optionally without autograd), then ``net_2``."""
        if self.no_grad:
            with torch.no_grad():
                x = self.net_1(x)
        else:
            # Respect whatever grad mode the caller has set.
            x = self.net_1(x)
        return self.net_2(x)

def net_time_test(net, times, inputs=None):
    """Return the total wall-clock seconds spent in ``times`` backward passes.

    Each iteration runs ``net(inputs)`` (not timed) and then
    ``out.backward`` with a random upstream gradient (timed). Gradients are
    never zeroed, so they accumulate across iterations; that does not
    affect the timing.

    Args:
        net: module to benchmark; called as ``net(inputs)``.
        times: number of forward/backward iterations.
        inputs: batch fed to ``net``. Defaults to the module-level
            ``input_`` so existing callers keep working.

    Returns:
        Total backward time in seconds as a float (``0.0`` when
        ``times == 0``). Callers print it themselves.
    """
    from time import perf_counter  # monotonic, high-resolution interval clock

    if inputs is None:
        inputs = input_
    duration = 0.0
    for _ in range(times):
        out = net(inputs)
        # Upstream gradient shaped like the output (rather than a hard-coded
        # (64, 1000)), generated outside the timed interval so RNG cost is
        # not counted as backward time.
        grad = torch.randn_like(out)
        start = perf_counter()
        out.backward(gradient=grad)
        duration += perf_counter() - start
    return duration

# Fixed seed so inputs and weight initialisation are reproducible across runs.
torch.manual_seed(0)
input_ = torch.randn(64, 1000)
net_baseline = Net(no_grad=False, requires_grad=True)
net_no_grad = Net(no_grad=True, requires_grad=True)
net_no_requires_grad = Net(no_grad=False, requires_grad=False)

# Warm up every configuration once before timing, so one-time costs
# (allocator growth, library initialisation) are not charged to whichever
# network happens to be measured first.
for _net in (net_baseline, net_no_grad, net_no_requires_grad):
    _net(input_).sum().backward()
    _net.zero_grad()

# The numbers below are the ones the poster recorded (100 iterations,
# no warm-up); your measurements will differ.
print(net_time_test(net_baseline, 100))
# 21.8435652256
print(net_time_test(net_no_grad, 100))
# 9.42938923836
print(net_time_test(net_no_requires_grad, 100))
# 7.28591799736

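In my experiment, requires_grad=False was faster than torch.no_grad() in training. (Note that net_time_test times only the backward pass, and in both of these configurations backward only runs through net_2.)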

4 Likes