How To Zero Grad for selective inputs (Compressive Transformer Memory)

Hi, I’m implementing the memory mechanism from the Compressive Transformer paper (Rae et al., 2019). There are two loss functions: the main cross-entropy loss of the language model, and a reconstruction loss (MSE) on the difference between attention over the real hidden states versus attention over the compressed hidden states (see page 4). The latter is used to optimize how the compressed memory is constructed: it enforces a sort of scale invariance between attention on the original, uncompressed memory and attention on the smaller, compressed memory.

I’m constructing the reconstruction loss as follows:

from torch.utils.checkpoint import checkpoint

compressor_attn, _ = checkpoint(self.attn, q=h, k=original_memory, v=original_memory)
compressed_attn, _ = checkpoint(self.attn, q=h, k=compressed_memory, v=compressed_memory)

Where self.attn is an nn.Module containing all the attention logic and parameters, and compressed_memory is created by an nn.Module that downsamples the length of the original context memory, e.g. from length 1000 to length 100; see below for an example:

import numpy as np
import torch.nn as nn

class MemoryCompressor(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(MemoryCompressor, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.encoder = nn.Linear(input_dim, latent_dim)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.1 / np.sqrt(self.input_dim))

    def forward(self, input):
        # Compress along the (last) sequence dimension, then restore its position
        return self.encoder(input).transpose(0, 2)

# Example of memory being compressed:
original_memory = ...
compressed_memory = MemoryCompressor(1000, 100)(original_memory)
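To sanity-check the shapes, here is a minimal self-contained run of the compressor as written. I’m assuming the memory tensor arrives with the sequence dimension last (which is what the `transpose(0, 2)` in `forward` implies); the concrete sizes (`d_model=512`, batch 4) are just made up for illustration:

```python
import torch
import torch.nn as nn

class MemoryCompressor(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Linear(input_dim, latent_dim)

    def forward(self, input):
        # Linear acts on the last dim (seq_len), then transpose moves the
        # compressed sequence axis back to the front.
        return self.encoder(input).transpose(0, 2)

# (d_model, batch, seq_len) -> compress seq_len 1000 -> 100
original_memory = torch.randn(512, 4, 1000)
compressed = MemoryCompressor(1000, 100)(original_memory)
print(compressed.shape)  # torch.Size([100, 4, 512])
```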

Then the reconstruction loss is (at the end):

recons_loss = F.mse_loss(compressed_attn, compressor_attn)

My problem is that right now the backward pass goes recons_loss -> compressor_attn -> the parameters of self.attn. The behaviour I want is that only the MemoryCompressor parameters used to build compressed_memory receive gradients from recons_loss; and vice versa, that the attention parameters receive gradients only from the main cross-entropy loss of the language model.

Why? Because I want the attention part of the model to focus purely on making good predictions; I don’t want it distracted by trying to also minimize the reconstruction loss. The only component I want trained on the reconstruction loss is the MemoryCompressor, so I need to selectively construct the gradients so that only MemoryCompressor’s parameters get gradients from recons_loss (even though the attention parameters are used as an intermediate step in computing it).
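One way to get exactly this split is to run the reconstruction forwards through a detached copy of the attention parameters, so autograd treats them as constants and the only trainable leaves on the recons_loss path belong to the compressor. This is just a sketch, not the paper’s code: `torch.func.functional_call` needs a recent PyTorch, and `SimpleAttn`, the shapes, and the `nn.Linear` stand-in for MemoryCompressor are all made up for illustration:

```python
import torch
import torch.nn as nn
from torch.func import functional_call

# Hypothetical stand-in for self.attn, just to make the sketch runnable.
class SimpleAttn(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, q, k, v):
        scores = self.q_proj(q) @ self.k_proj(k).transpose(-2, -1) / q.shape[-1] ** 0.5
        return scores.softmax(dim=-1) @ self.v_proj(v)

dim = 32
attn = SimpleAttn(dim)
compressor = nn.Linear(dim, dim)  # stand-in for MemoryCompressor

h = torch.randn(8, dim)
original_memory = torch.randn(16, dim)
compressed_memory = compressor(original_memory)

# Detached view of the attention parameters: same values, constants to autograd.
frozen = {name: p.detach() for name, p in attn.named_parameters()}

# Both reconstruction forwards use the frozen parameters; h and the original
# memory are detached too, so the target branch carries no gradient at all.
target_attn = functional_call(
    attn, frozen, (h.detach(), original_memory.detach(), original_memory.detach()))
compressed_attn = functional_call(
    attn, frozen, (h.detach(), compressed_memory, compressed_memory))

recons_loss = nn.functional.mse_loss(compressed_attn, target_attn)
recons_loss.backward()

print(compressor.weight.grad is not None)  # True: the compressor is trained
print(attn.q_proj.weight.grad)             # None: attention untouched by recons_loss
```

The language-model forward then uses self.attn normally, so its parameters get gradients only from the cross-entropy term, and the two losses can simply be summed before the optimizer step.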

Current state: I’ve figured out that I need to somehow zero out the gradients of the parts I don’t want trained, but I’m unsure how to do it with this particular control flow. I’m going to try to figure this out myself today (I think it’s just my inexperience!), but any pointers would be appreciated (happy to clarify if anything is unclear).

I think the key here is the “stop_gradient” pseudocode in Algorithm 2 of the paper, which in PyTorch is just h = hidden.detach() and old_mem = old_mem.detach(), or something similar.
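Right, `.detach()` is PyTorch’s stop_gradient: it returns the same values but cuts the autograd graph at that point. A tiny illustration of what that buys you (names here are purely illustrative):

```python
import torch

w = torch.randn(3, requires_grad=True)
v = torch.randn(3, requires_grad=True)
x = torch.randn(3)

hidden = w * x                      # depends on w
loss = (hidden.detach() * v).sum()  # detach: w is a constant on this path

loss.backward()
print(w.grad)              # None: the gradient stopped at detach()
print(v.grad is not None)  # True: v still gets a gradient
```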

Best regards