Training a function to generate NN weights

I am following the example posted here, where the parameters of a sinusoid are learned using an SGD optimizer. I have a custom NN module defined with a function that has one parameter, theta, to be learned. Theta generates a new set of NN weights, which I then use to update my NN via the load_state_dict function. After updating the weights, I compute a loss between y_true and y_pred and want to compute dloss/dtheta. However, whenever I print theta.grad, it is None. I believe the weight update mechanism I am using, load_state_dict, is breaking the computational graph. How can I set up my modules and optimizer to get the required gradient?
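
To illustrate, here is a stripped-down sketch of the behaviour I am seeing, using a single linear layer instead of my real network (the layer and all names are just for illustration):

import torch
import torch.nn as nn

# build new weights as a function of theta, push them in with load_state_dict,
# then backprop: theta.grad stays None because load_state_dict only copies values
layer = nn.Linear(2, 2)
theta = nn.Parameter(torch.randn(1))

new_state = {k: v + theta for k, v in layer.state_dict().items()}
layer.load_state_dict(new_state)

loss = layer(torch.randn(1, 2)).sum()
loss.backward()
print(theta.grad)  # None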

In the following minimal working example, ParentModule1 and SubModule1 preserve the computational graph and the gradient of the loss w.r.t. theta is computed, whereas with ParentModule2 and SubModule2 it is not. The question is: how can I set up my modules so that I can replace the weights of the network with values that are a function of theta and still compute dloss/dtheta?

import torch.nn as nn
import torch
import torchvision.models as models
import urllib.request
from PIL import Image
from torchvision import transforms
import copy

class ParentModule1(nn.Module):
    def __init__(self, net, second_module):
        super().__init__()
        self.net = net
        self.second_module = second_module
    def forward(self, x, t=None):
        a = self.second_module(t)
        return self.net(x+a)

class SubModule1(nn.Module):
    def __init__(self):
        super().__init__()
        self.theta = nn.Parameter(torch.randn(1), requires_grad=True)
    def forward(self, x):
        return self.theta + x

class ParentModule2(nn.Module):
    def __init__(self, net, second_module):
        super().__init__()
        self.net = net
        self.second_module = second_module
    def forward(self, x, t=None):
        a = self.second_module(t)
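        # suspected culprit: load_state_dict copies the new values into the
        # existing parameters, which detaches them from theta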
        self.net.load_state_dict(a)
        return self.net(x)

class SubModule2(nn.Module):
    def __init__(self, state_dict_in):
        super().__init__()
        self.state = state_dict_in
        self.theta = nn.Parameter(torch.randn(1), requires_grad=True)
    def forward(self, x):
        state_copy = copy.copy(self.state)
        for k, v in self.state.items():
            state_copy[k] = v + self.theta
        return state_copy

if __name__ == '__main__':
    net = models.alexnet(pretrained=True)
    # when using SubModule1 and ParentModule1, theta.grad is not None
    # submod = SubModule1()
    # model = ParentModule1(net, submod)
    # when using SubModule2 and ParentModule2, theta.grad is None
    submod = SubModule2(net.state_dict())
    model = ParentModule2(net, submod)

    url, filename = ("https://github.com/pytorch/hub/raw/master/images/dog.jpg", "dog.jpg")
    urllib.request.urlretrieve(url, filename)  # download a sample image

    input_image = Image.open(filename)
    preprocess = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0)  # create a mini-batch as expected by the model

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.second_module.parameters(), lr=1e-3)
    t = torch.zeros(1)
    y = model(input_batch, t)
    loss = criterion(y, torch.LongTensor([3]))

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('gradient', model.second_module.theta.grad)
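
For reference, this is the kind of behaviour I am trying to get for the whole network. With a single layer I can do it by calling the functional API directly with weights built from theta (again, just an illustrative sketch with made-up names), but I do not see how to scale this to a full pretrained model without rewriting its forward pass:

import torch
import torch.nn as nn
import torch.nn.functional as F

# the forward pass uses weights that are an explicit function of theta,
# so backward() reaches theta and theta.grad is populated
base = nn.Linear(2, 2)
theta = nn.Parameter(torch.randn(1))

weight = base.weight.detach() + theta
bias = base.bias.detach() + theta

loss = F.linear(torch.randn(1, 2), weight, bias).sum()
loss.backward()
print(theta.grad)  # a real gradient, not None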