Can torch.no_grad() be used in training?

I want to add a new layer to a pretrained model. The pretrained part will not be updated; only the added layer will be trained. Can I use torch.no_grad() to wrap the forward pass of the pretrained part? Is this reasonable, and will it reduce memory usage and speed up training?


Yes, this should work as shown in this small code snippet:

class MyModel(nn.Module):
    """Toy model: a conv stage run under ``torch.no_grad()`` plus a trainable linear head.

    Because ``conv1`` runs inside ``torch.no_grad()``, autograd records nothing
    for it. Its parameters never receive gradients, and only ``fc1`` is trained.
    """

    def __init__(self):
        # nn.Module.__init__ must run before any submodule is assigned;
        # otherwise the attribute assignments below raise AttributeError.
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        # 4*4*16: a 3x3 kernel with stride 1 and padding 1 preserves the
        # spatial size, so this head assumes 4x4 inputs with 16 channels out.
        self.fc1 = nn.Linear(4*4*16, 10)

    def forward(self, x):
        """Return ``relu(fc1(flatten(relu(conv1(x)))))``; conv1 is excluded from autograd."""
        with torch.no_grad():
            x = F.relu(self.conv1(x))
            x = x.view(x.size(0), -1)
        # Outside the no_grad block, fc1's weights require grad, so autograd
        # tracks from here even though its input does not require grad.
        x = F.relu(self.fc1(x))
        return x

model = MyModel()
x = torch.randn(1, 3, 4, 4)
out = model(x)
# Gradients are only populated by a backward pass. Without this call every
# param.grad stays None and the loop below prints nothing.
out.mean().backward()

# Only the fc1 parameters should be printed, because conv1 ran under torch.no_grad().
for name, param in model.named_parameters():
    if param.grad is not None:
        print(name, param.grad.abs().sum())

If you remove the torch.no_grad() guard, all layers will get gradients.
Alternatively, you could set the requires_grad attribute of the pretrained parameters to False.


Thanks for your reply!

So which way will be faster? torch.no_grad() or requires_grad=False?

In my view, torch.no_grad() will not calculate gradients with respect to the inputs of the layers in the pretrained model, while requires_grad=False does. So torch.no_grad() will be faster? Is that right?

I think neither approach will store the intermediate tensors, but let me know if you see any differences when profiling. :wink:

I tried both ways but didn’t see an obvious difference in running time…

class Net(nn.Module):
    """Benchmark model comparing two ways of freezing the ``net_1`` stage.

    Args:
        no_grad: if True, ``forward`` runs ``net_1`` inside ``torch.no_grad()``.
        requires_grad: if False, every ``net_1`` parameter gets
            ``requires_grad=False`` at construction time.

    ``net_2`` is always trainable.
    """

    def __init__(self, no_grad=False, requires_grad=True):
        super().__init__()
        # NOTE(review): the original snippet was truncated after the first
        # layer of this Sequential; a second Linear(1000, 1000) is assumed.
        self.net_1 = nn.Sequential(nn.Linear(1000, 1000),
                                   nn.Linear(1000, 1000))
        self.net_2 = nn.Linear(1000, 1000)
        if not requires_grad:
            self._freeze_param()
        self.no_grad = no_grad

    def _freeze_param(self):
        """Disable gradient tracking for every parameter of ``net_1``."""
        # Iterate net_1's own parameters rather than matching a name prefix,
        # which would also catch siblings such as a hypothetical "net_10".
        for param in self.net_1.parameters():
            param.requires_grad = False

    def forward(self, x):
        """Apply ``net_1`` exactly once (optionally untracked), then ``net_2``."""
        if self.no_grad:
            # Nothing computed in this block is recorded by autograd.
            with torch.no_grad():
                x = self.net_1(x)
        else:
            x = self.net_1(x)
        x = self.net_2(x)
        return x

def net_time_test(net, times, inputs=None):
    """Return the total seconds spent in ``times`` forward+backward passes of ``net``.

    Args:
        net: module to benchmark; its output must be a floating-point tensor.
        times: number of iterations to run.
        inputs: batch fed to ``net`` on every iteration. Defaults to the
            module-level ``input_`` tensor, which preserves the original
            two-argument call.

    Returns:
        Accumulated wall-clock duration in seconds, as a float. The forward
        pass is included because ``torch.no_grad()`` changes what autograd
        records during forward, not only during backward.

    Note:
        The timings assume CPU tensors. CUDA kernels run asynchronously and
        would need ``torch.cuda.synchronize()`` before each clock read.
    """
    from time import perf_counter

    if inputs is None:
        inputs = input_
    duration = 0.0
    for _ in range(times):
        start = perf_counter()
        out = net(inputs)
        elapsed = perf_counter() - start
        # Random upstream gradient shaped like the output. It is drawn outside
        # the timed window so RNG cost is not attributed to the model.
        grad_out = torch.randn_like(out)
        start = perf_counter()
        out.backward(gradient=grad_out)
        elapsed += perf_counter() - start
        duration += elapsed
    return duration

# Shared input batch plus one network per freezing strategy, built in the
# same order as before so parameter initialisation consumes the RNG identically.
input_ = torch.randn(64, 1000)
net_baseline = Net()                               # tracks everything
net_no_grad = Net(no_grad=True)                    # net_1 under torch.no_grad()
net_no_requires_grad = Net(requires_grad=False)    # net_1 params frozen

# Timings reported in the original run (seconds for 100 iterations):
#   net_baseline           21.8435652256
#   net_no_grad             9.42938923836
#   net_no_requires_grad    7.28591799736
for benchmarked_net in (net_baseline, net_no_grad, net_no_requires_grad):
    print(net_time_test(benchmarked_net, 100))

In my experiment, requires_grad=False is faster than torch.no_grad() in training.