Are a model's parameters modified globally when it is passed to a function in PyTorch?

Hi. I would like to ask: if I do

def change_params(net):
    for param in net.parameters():
        ...  # do something

do the net's parameters get updated globally, or only locally inside the function? In other words, if I use the model outside the function, am I using the updated model or the original one?

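It depends on what you’ll be doing. The function receives a reference to the very same module object (`net` is `model`), so any change made to its parameters *in place* is visible outside the function. For instance, the following snippet works well: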

import torch
model = torch.nn.Linear(2, 3)
def change_params(net):
    """Disable gradient tracking for every bias parameter of ``net``, in place.

    ``net`` is the caller's module object itself, not a copy, so the
    ``requires_grad_(False)`` calls below remain visible through ``model``
    after the function returns.

    Biases are recognised by parameter name (the last dotted component starts
    with ``"bias"``, e.g. ``bias`` or ``0.bias``). This avoids relying on a
    particular element count, which only identifies the bias of one specific
    layer shape.
    """
    for name, param in net.named_parameters():
        if name.rpartition(".")[2].startswith("bias"):
            param.requires_grad_(False)


# Show every parameter (values and requires_grad marker), one per line,
# before the call ...
print(*model.parameters(), sep="\n")  # Before

change_params(model)

# ... and after it, so the in-place change made by change_params is visible.
print(*model.parameters(), sep="\n")  # After

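You can’t do just anything with the model’s parameters, though. For instance, the following snippet throws an error, because autograd does not allow in-place operations on leaf tensors that require grad (wrapping the update in `with torch.no_grad():` avoids this):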

import torch
model = torch.nn.Linear(2, 3)
def change_params(net):
    """Try to add 1 to every parameter of ``net`` in place.

    Intentionally failing example: the in-place ``+=`` on a parameter that
    requires grad is rejected by autograd, so this call raises instead of
    updating ``model``.
    """
    for param in net.parameters():
        #do something
        # In-place add on a leaf parameter with requires_grad=True -> error.
        # To really update the weights, the loop would need to run under
        # `with torch.no_grad():`.
        param += torch.ones_like(param)

for p in model.parameters(): print(p) # Before

# Expected to raise (see change_params), so the "After" loop below never runs.
change_params(model)

for p in model.parameters(): print(p) # After

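Moreover, your change might simply be ignored. For instance, below `param = param.double()` only rebinds the local name `param` to a new tensor, so the model’s parameters keep their original type: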

import torch
model = torch.nn.Linear(2, 3)
def change_params(net):
   """Attempt to convert every parameter of ``net`` to double precision.

   Intentional no-op example: ``param.double()`` produces a new tensor, and
   assigning it to ``param`` only rebinds the local loop variable. Nothing is
   written back into ``net``, so ``model`` keeps its original dtype.
   """
   for param in net.parameters():
       #do something
       # Rebinds the local name only; the module's parameter is not replaced.
       # Presumably the intent was a module-level cast such as `net.double()`.
       param = param.double()


# Print each parameter's tensor type before the call ...
for p in model.parameters(): print(p.type())

change_params(model)

# ... and after it: the printed types are the same, since nothing changed.
for p in model.parameters(): print(p.type())

So, it depends on what you want to do.

Hi @LeviViana
I don’t think that is necessarily the case. Have a look at the code snippet below: I multiply the gradients by 2, PyTorch does not throw an error, and the result is correct (the gradient sum doubles):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def change_grads(net, factor=2):
    """Scale the gradient of every parameter of ``net`` in place.

    ``net`` is the caller's module and each ``.grad`` tensor is multiplied in
    place, so the scaled gradients are visible outside this function (e.g. to
    a later ``optimizer.step()``).

    Args:
        net: module whose parameter gradients are scaled.
        factor: multiplier applied to every gradient; defaults to 2.
    """
    for param in net.parameters():
        # Parameters not reached by a backward pass (or before any backward
        # call) have ``grad is None``; skip them instead of raising TypeError.
        if param.grad is not None:
            param.grad.mul_(factor)

def gradSum(net):
    """Return the sum of every gradient entry of ``net``'s parameters.

    Parameters whose ``.grad`` is ``None`` (not yet reached by a backward
    pass) contribute nothing instead of raising ``AttributeError``. Returns
    ``0`` when no parameter has a gradient; otherwise a Python float.
    """
    total = 0
    for param in net.parameters():
        if param.grad is None:
            continue
        # .item() already yields a plain Python number detached from the
        # graph, so no explicit .detach() is needed.
        total += param.grad.sum().item()
    return total

# Mean and standard deviation of all MNIST pixels (precomputed constants),
# used to normalise each image after converting it to a tensor.
mean_gray = 0.1307
stddev_gray = 0.3081

transformation = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((mean_gray,), (stddev_gray,)),
    ]
)

# MNIST training split, stored under ./data and downloaded if missing.
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transformation,
    download=True,
)

# Shuffled mini-batches of 100 training images.
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=100,
    shuffle=True,
)

class CNN(nn.Module):
    """Small convolutional classifier for single-channel 28x28 images (MNIST).

    Layers: conv(1->8, 3x3) -> ReLU -> 2x2 max-pool ->
    conv(8->32, 5x5) -> ReLU -> 2x2 max-pool -> flatten ->
    linear(1568->600) -> ReLU -> linear(600->10).

    Both convolutions preserve the spatial size (stride 1, padding
    (kernel-1)/2), so the two pools take 28x28 -> 14x14 -> 7x7, and
    32 * 7 * 7 = 1568 features reach ``fc1``. ``forward`` returns raw logits
    (no softmax), as expected by ``nn.CrossEntropyLoss``.
    """

    def __init__(self):
        super().__init__()
        self.cnn1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(kernel_size=2)
        self.cnn2 = nn.Conv2d(in_channels=8, out_channels=32, kernel_size=5, stride=1, padding=2)
        self.fc1 = nn.Linear(in_features=1568, out_features=600)
        self.fc2 = nn.Linear(in_features=600, out_features=10)

    def forward(self, x):
        """Map a batch of shape (N, 1, 28, 28) to logits of shape (N, 10)."""
        out = self.maxpool(self.relu(self.cnn1(x)))
        out = self.maxpool(self.relu(self.cnn2(out)))
        # Flatten per sample while keeping the batch dimension N. A fixed
        # view(-1, 1568) would silently regroup elements across samples for
        # wrongly sized inputs (changing N); this way a size mismatch is
        # reported as an error by fc1 instead.
        out = torch.flatten(out, start_dim=1)
        out = self.relu(self.fc1(out))
        return self.fc2(out)

# Use the GPU when available; otherwise fall back to the CPU instead of
# failing on a .cuda() call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = CNN().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001, amsgrad=True)

# Process a single batch: compute the loss, backpropagate, then show that
# change_grads(model) rescales the gradients held by this very model object.
for images, target in train_loader:
    # `async` has been a reserved keyword since Python 3.7, so
    # `.cuda(async=True)` is a SyntaxError; `non_blocking` is the current name.
    # Tensors are used directly: wrapping them in the deprecated
    # torch.autograd.Variable is unnecessary.
    target = target.to(device, non_blocking=True)
    images = images.to(device)

    # compute output
    output = model(images)
    loss = criterion(output, target)
    model.zero_grad()
    loss.backward()
    # Print sum of gradients before
    print(gradSum(model))
    change_grads(model)
    # Print sum of gradients after (doubled by change_grads)
    print(gradSum(model))
    break

This gives:
20.637027357704937
41.274054715409875