I am following the example posted here, where the parameters of a sinusoid are learned using an SGD optimizer. I have a custom NN module with one learnable parameter, theta. From theta I generate a new set of NN weights, which I then load into my network with the load_state_dict function. After updating the weights, I compute a loss between y_true and y_pred, and want to compute dloss/dtheta. However, whenever I print the value of theta.grad, it is None. I believe the weight-update mechanism I am using, load_state_dict, is causing the issue. How can I set up my modules and optimizer so that I can compute the gradient I need?
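As far as I can tell, load_state_dict copies tensor data into the existing parameters under torch.no_grad(), which would sever any autograd history attached to the generated weights. A minimal sketch of what I think is happening, on a toy layer (the layer is just for illustration, not my actual network):

import torch
import torch.nn as nn

theta = nn.Parameter(torch.randn(1))
layer = nn.Linear(2, 2)
# build a new state dict as a differentiable function of theta...
new_state = {k: v + theta for k, v in layer.state_dict().items()}
# ...but load_state_dict copies the values in place, under no_grad
layer.load_state_dict(new_state)
out = layer(torch.randn(1, 2)).sum()
out.backward()
print(theta.grad)  # None: the in-place copy severed the path back to theta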
In the following Minimal Working Example, ParentModule1 and SubModule1 preserve the computational graph, and the gradient of the loss w.r.t. theta is computed; whereas if ParentModule2 and SubModule2 are used, the gradient of the loss w.r.t. theta cannot be computed. The question is: how can I set up my modules so that I can replace the weights of the network (which are a function of theta) and still successfully compute dloss/dtheta? (A sketch of a stateless alternative I came across follows the example below.)
import torch.nn as nn
import torch
import torchvision.models as models
import urllib
from PIL import Image
from torchvision import transforms
import copy

class ParentModule1(nn.Module):
    def __init__(self, net, second_module):
        super().__init__()
        self.net = net
        self.second_module = second_module

    def forward(self, x, t=None):
        a = self.second_module(t)
        return self.net(x + a)

class SubModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        return self.theta + x

class ParentModule2(nn.Module):
    def __init__(self, net, second_module):
        super().__init__()
        self.net = net
        self.second_module = second_module

    def forward(self, x, t=None):
        a = self.second_module(t)
        self.net.load_state_dict(a)  # replace the network weights in place
        return self.net(x)

class SubModule2(nn.Module):
    def __init__(self, state_dict_in):
        super().__init__()
        self.state = state_dict_in
        self.theta = nn.Parameter(torch.randn(1), requires_grad=True)

    def forward(self, x):
        # generate a new state dict as a function of theta
        state_copy = copy.copy(self.state)
        for k, v in self.state.items():
            state_copy[k] = v + self.theta
        return state_copy

if __name__ == '__main__':
    net = models.alexnet(pretrained=True)
    # when using SubModule1 and ParentModule1, theta.grad is not None
    # submod = SubModule1()
    # model = ParentModule1(net, submod)
    # when using SubModule2 and ParentModule2, theta.grad is None
    submod = SubModule2(net.state_dict())
    model = ParentModule2(net, submod)
    url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
    try: urllib.URLopener().retrieve(url, filename)
    except: urllib.request.urlretrieve(url, filename)
    input_image = Image.open(filename)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.second_module.parameters(), lr=1e-3)
    t = torch.zeros(1)
    y = model(input_batch, t)
    loss = nn.functional.cross_entropy(y, torch.LongTensor([3]))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('gradient', model.second_module.theta.grad)
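For reference, the sketch mentioned above: the stateless API seems designed for exactly this pattern. Instead of copying the generated weights into the module, they are passed per call, so autograd can trace from the loss back to theta. This is how I understand it (assuming torch.func.functional_call, available in PyTorch 2.x; torch.nn.utils.stateless.functional_call plays the same role in 1.12/1.13), though I would still like to know whether my module setup above can be made to work:

# continuing from the script above (net and input_batch already defined)
from torch.func import functional_call

theta = nn.Parameter(torch.randn(1))
# weights as a differentiable function of theta; nothing is copied into net
new_state = {k: v + theta for k, v in net.state_dict().items()}
y = functional_call(net, new_state, (input_batch,))
loss = nn.functional.cross_entropy(y, torch.LongTensor([3]))
loss.backward()
print(theta.grad)  # no longer None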