Caching parameters, and randomly using one of them for computing gradients

Hi, I am working on a problem where I need to cache the model parameters (weights) for the last k iterations. In the next iteration, my model needs to use the parameters (randomly picked from the cached values) to compute the gradients.

I tried the following.

from collections import deque
import torch

k = 5  # number of past iterations to keep
queue = deque(maxlen=k)

model = torch.nn.Sequential(
    torch.nn.Linear(1000, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, 10),
)
# snapshot the current weights so later in-place updates don't change them
queue.append([p.detach().clone() for p in model.parameters()])
delayed_params = queue.popleft()

However, I am unable to make the model use delayed_params for computing the gradients. Is there any way to solve this?

Here’s a hacky class definition I put together to solve a similar issue: https://gist.github.com/SsnL/f2f56534aefb22d8612dbd7a5da28ed8
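In more recent PyTorch versions there is also a built-in way to run a module with externally supplied parameters. Here is a rough sketch (this assumes torch.func.functional_call, available in PyTorch 2.x, and reuses the model and delayed_params from the snippet above; the batch and loss are just placeholders):

import torch
from torch.func import functional_call

# pair the cached tensors with the model's parameter names
names = [name for name, _ in model.named_parameters()]
delayed = {name: w.requires_grad_(True) for name, w in zip(names, delayed_params)}

x = torch.randn(32, 1000)            # placeholder batch
out = functional_call(model, delayed, (x,))
loss = out.sum()                     # placeholder loss
loss.backward()                      # gradients land in delayed[name].grad

Note that the gradients end up on the cached tensors rather than on model.parameters(), so you still have to decide how to apply them (for example, copy them over to the live parameters before the optimizer step).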


Thanks a lot for the link. I had trouble understanding some lines in your class definition. Specifically, I did not understand why these lines had to be written in forward_with_weights:

        for (m, n), w in zip(self._module_names, old_ws):
            super(nn.Module, m).__setattr__(n, w)

From my understanding, you compute the output using the new weights, but why do we need to set the weights back to the old ones? And during backward, how will the new weights be used if they have been set back to the old ones?

The output is computed using the new weights before the lines you referenced, here: https://gist.github.com/SsnL/f2f56534aefb22d8612dbd7a5da28ed8#file-reparam_module-py-L51. I set the weights back to the old ones because I didn’t want the module's weights to be different before and after forward_with_weights.
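To make that concrete, here is a small standalone illustration of the same idea on a plain nn.Linear (this is not the gist itself, just a sketch): the autograd graph built during forward keeps references to whatever tensors were actually used, so restoring the attribute afterwards does not change where the gradients go.

import torch
import torch.nn as nn

lin = nn.Linear(10, 1)
old_w = lin.weight                    # the registered nn.Parameter
new_w = old_w.detach().clone().requires_grad_(True)

# bypass nn.Module.__setattr__ (which would reject a plain Tensor here) and
# store new_w directly in the instance __dict__, shadowing the parameter
object.__setattr__(lin, 'weight', new_w)

out = lin(torch.randn(4, 10)).sum()   # forward uses new_w

# restore the original weight by removing the shadowing attribute
del lin.__dict__['weight']

out.backward()
print(new_w.grad is not None)         # True: gradients flow to new_w
print(old_w.grad)                     # None: old_w was not used in this forward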

I stopped working on this back when I posted it, and only restarted a week ago. I found your ReparamModule really helpful. However, I would like to understand one thing. Using your class definition, I wrote a simple module as follows.

# add_metaclass is assumed to come from the six package;
# PatchModules is the metaclass defined in the gist linked above
from six import add_metaclass

import torch.nn as nn
import torch.nn.functional as F


@add_metaclass(PatchModules)
class ReparamModule(nn.Module):
    def __init__(self):
        super(ReparamModule, self).__init__()
        self.fc1 = nn.Linear(10, 1)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        return x

    # other methods omitted

rpm = ReparamModule()

Also, I did not assign the old weights back as you did in the lines quoted above, since I want the gradients to be computed with respect to the newly assigned parameters new_ws.

I find that after forward_with_weights, rpm.fc1.weight and rpm.fc1.bias hold the newly assigned values, but if I call rpm.get_weights() I get the old weights; the entries in rpm.parameters() are also unchanged. Similarly, I had to access rpm.fc1.weight.grad explicitly instead of looping over rpm.parameters() and using .grad.data. Why is this so? How are the parameters decoupled from weight and bias? I attached the code for your reference:
https://gist.github.com/JayanthRR/07aa7a8d45a027954a8d0d2bb1fbaec7
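My current guess, based on a plain nn.Linear rather than your gist, is that the newly assigned tensor is stored as an ordinary instance attribute that shadows the registered parameter, so module.weight and parameters() can report different tensors:

import torch
import torch.nn as nn

lin = nn.Linear(10, 1)
old_w = lin.weight
new_w = old_w.detach().clone().requires_grad_(True)

# a plain instance attribute bypasses parameter registration
object.__setattr__(lin, 'weight', new_w)

print(lin.weight is new_w)                              # True: attribute lookup finds new_w
print(dict(lin.named_parameters())['weight'] is old_w)  # True: the registry still holds old_w

If PatchModules does something along these lines under the hood, that would explain why rpm.parameters() and rpm.get_weights() still return the old values.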