Exclude indices from being updated in optimizer.step()


My approach differs from the usual neural network setup in that I want to back-propagate to the input of the network, not to the weights. Let’s say I have an input tensor of shape 512x4096, where 512 is the number of samples. I run my model on this 512x4096 tensor, get some output, calculate a loss and back-propagate it to the input. The network weights are not updated.
The 512 samples are independent of each other! The gradient of one sample should not affect the other samples at all.

This works fine with different optimizers such as Adam or Adagrad. However, I have noticed that when I use Adam and split my input into batches, e.g. 2 batches of size 256, the output is different. With Adagrad the difference is almost 0, so it does not matter whether I use 1, 2, … batches. This is what I expect, since the samples are independent of each other.
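The Adagrad part of this observation can be reproduced in isolation. The following minimal check (a toy two-element parameter, not the real setup) shows that Adagrad leaves an element exactly where it was as long as its gradient is zero:

```python
import torch

# Toy check: Adagrad does not move an element whose gradient is zero,
# so batch splitting cannot disturb the samples of the other batch.
p = torch.zeros(2, requires_grad=True)
opt = torch.optim.Adagrad([p], lr=0.1)

# Step 1: both elements receive a gradient.
p.grad = torch.tensor([1.0, 1.0])
opt.step()

# Step 2: element 1 now has zero gradient.
before = p[1].item()
p.grad = torch.tensor([1.0, 0.0])
opt.step()

print(p[1].item() == before)  # True: the zero-grad element did not move
```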

Example (all samples at the same time):

# samples is a tensor of shape 512x4096 with requires_grad=True
adam = torch.optim.Adam(params=[samples])
for epoch in range(100):
    adam.zero_grad()
    output = my_model(samples)
    loss = lossfct(output, target)
    loss.backward()
    adam.step()

Example (2 batches):

# samples is a tensor of shape 512x4096 with requires_grad=True
adam = torch.optim.Adam(params=[samples])
for epoch in range(100):
    for start, end in [(0, 256), (256, 512)]:
        adam.zero_grad()
        output = my_model(samples[start:end, :])
        loss = lossfct(output, target[start:end, :])
        loss.backward()

        # Now all gradients of the current batch are set correctly
        # and the other batch has grad = 0.

        # This step should only update samples[start:end] and not the other
        # samples. How can I avoid that the momentum gets applied to them?
        # I need to disable updates outside this batch completely.
        adam.step()

My explanation for this behavior is that Adam uses a kind of momentum. When I feed batch 2 into my model and calculate the loss, batch 2 will have some gradient and batch 1 will have grad = 0 everywhere. However, when I then call adam.step(), the samples from batch 1 will be modified as well, because they still carry momentum from earlier steps. This is what I want to avoid.
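This suspicion is easy to confirm with a toy example (a two-element parameter standing in for the real input; the values are arbitrary):

```python
import torch

# Toy check: Adam keeps per-element momentum, so an element whose
# gradient is currently zero still moves on step().
p = torch.zeros(2, requires_grad=True)
adam = torch.optim.Adam([p], lr=0.1)

# Step 1: both elements receive a gradient, building up momentum.
p.grad = torch.tensor([1.0, 1.0])
adam.step()

# Step 2: element 1 now has zero gradient.
before = p[1].item()
p.grad = torch.tensor([1.0, 0.0])
adam.step()
after = p[1].item()

print(abs(after - before) > 0)  # True: element 1 moved despite zero grad
```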

I want a way to run Adam on only one batch and ignore everything else. What I could do is use a separate optimizer for each batch. But is there an easier solution?
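For reference, the per-batch-optimizer workaround could look roughly like the sketch below. To keep it self-contained, my_model is replaced by a dummy doubling function and lossfct by MSE loss; these stand-ins are assumptions, not the real setup. The key point is that each optimizer's momentum for the rows outside its batch is always built from zero gradients and therefore stays exactly zero:

```python
import torch

torch.manual_seed(0)
samples = torch.randn(512, 4096, requires_grad=True)
target = torch.randn(512, 4096)
my_model = lambda x: x * 2.0              # stand-in for the real network
lossfct = torch.nn.functional.mse_loss    # stand-in for the real loss

batches = [(0, 256), (256, 512)]
adams = [torch.optim.Adam([samples]) for _ in batches]  # one optimizer per batch

snapshot = samples.detach().clone()
adam, (start, end) = adams[0], batches[0]
samples.grad = None                       # clear grads from any previous batch
loss = lossfct(my_model(samples[start:end, :]), target[start:end, :])
loss.backward()
adam.step()  # this optimizer's state only ever sees grads for its own batch

# Only the first batch moved; this optimizer's momentum for the zero-grad
# rows is itself zero, so those rows stay put exactly.
changed = (samples.detach() != snapshot).any(dim=1)
print(changed[:256].all().item(), changed[256:].any().item())  # True False
```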


Looking at PyTorch’s Adam implementation, I see this:

for group in self.param_groups:
    for p in group['params']:
        if p.grad is None:
            continue
Unfortunately, this does not look at the gradient of each slice of the parameter, only at the gradient of the parameter as a whole. So I would need to slice the samples and initialize Adam like this:

adam = torch.optim.Adam(params=[samples[i] for i in range(512)])
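This exact construction is rejected by the optimizer, because indexing a leaf tensor produces non-leaf views (a small check with shrunken placeholder shapes, 4x8 instead of 512x4096):

```python
import torch

# Slices of a leaf tensor are non-leaf views, and the optimizer
# refuses to accept them.
samples = torch.randn(4, 8, requires_grad=True)
print(samples.is_leaf, samples[0].is_leaf)  # True False

try:
    torch.optim.Adam(params=[samples[i] for i in range(4)])
    rejected = False
except ValueError as e:
    rejected = True
    print("rejected:", e)
```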

This unfortunately leads to a new problem: the slices are no longer leaf tensors. And when I try to use the following instead, the original samples is not updated anymore, only the detached slices:

adam = torch.optim.Adam(params=[samples[i].detach().requires_grad_(True) for i in range(512)])

Any idea how to solve the issue elegantly?