Hi, I’m implementing the memory mechanism from the Compressive Transformer paper (https://arxiv.org/abs/1911.05507). There are two loss functions: the main cross-entropy of the language model and a reconstruction loss (MSE) for the difference between the attention on the real hidden states versus the compressed hidden states (see page 4). The latter is used to optimize how the compressed memory is constructed: it enforces a sort of scale-invariance between attention on original uncompressed memory and the smaller, compressed memory.
I’m constructing the reconstruction loss as follows:
```python
compressor_attn, _ = checkpoint(self.attn, q=h, k=original_memory, v=original_memory)
compressed_attn, _ = checkpoint(self.attn, q=h, k=compressed_memory, v=compressed_memory)
```
`self.attn` is an `nn.Module` containing all the attention logic and parameters, and `compressed_memory` is created by an `nn.Module` that downsamples the original context memory, e.g. from length 1000 to length 100. See below for an example:
```python
import numpy as np
import torch.nn as nn

class MemoryCompressor(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super(MemoryCompressor, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.encoder = nn.Linear(input_dim, latent_dim)
        self.apply(self.init_weights)

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.1 / np.sqrt(self.input_dim))

    def forward(self, input):
        return self.encoder(input).transpose(0, 2)

# Example of memory being compressed:
original_memory = ...
compressed_memory = MemoryCompressor(1000, 100)(original_memory)
```
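For clarity on shapes: I'm laying the memory out as `(d_model, batch, mem_len)` so the `nn.Linear` acts on the length axis, and the `transpose(0, 2)` puts the compressed memory back into `(mem_len, batch, d_model)` order. A toy check (the sizes 64/8 are made up; the init logic is stripped since it doesn't affect shapes):

```python
import torch
import torch.nn as nn

class MemoryCompressor(nn.Module):
    def __init__(self, input_dim, latent_dim):
        super().__init__()
        self.encoder = nn.Linear(input_dim, latent_dim)

    def forward(self, input):
        # input: (d_model, batch, mem_len); the Linear compresses the last axis
        return self.encoder(input).transpose(0, 2)

# mem_len=1000 compressed down to 100, with d_model=64, batch=8
original_memory = torch.randn(64, 8, 1000)
compressed_memory = MemoryCompressor(1000, 100)(original_memory)
print(compressed_memory.shape)  # torch.Size([100, 8, 64])
```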
Then the reconstruction loss is (at the end):
```python
recons_loss = F.mse_loss(compressed_attn, compressor_attn)
```
My problem is that right now the backward pass flows from `compressor_attn` into the parameters of `self.attn`. The behaviour I want is that only the `MemoryCompressor` parameters used to build `compressed_memory` receive gradients from `recons_loss`; and vice versa, the attention parameters should only receive gradients from the main cross-entropy loss of the language model.

Why? Because I want the attention part of the model to focus purely on making good predictions; I don’t want it distracted by also trying to minimise the reconstruction loss. The only component I want optimising the reconstruction loss is the `MemoryCompressor`, so I need to selectively construct the gradients somehow so that only gradients for the `MemoryCompressor` flow through to `recons_loss` (even though the attention parameters are used as an intermediate step).
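One approach I'm considering: recompute the attention with *detached copies* of its parameters via `torch.func.functional_call`, so the attention acts as a fixed function for the reconstruction loss. A toy sketch with stand-in modules — the `Attn` and `compressor` here are my own simplified versions (assuming the real `self.attn` also returns `(output, weights)`), and I'm assuming PyTorch ≥ 2.0 for `torch.func`:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.func import functional_call  # PyTorch >= 2.0

# Stand-in for self.attn (assumption: the real one also returns (output, weights)).
class Attn(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.proj_q = nn.Linear(dim, dim)
        self.proj_k = nn.Linear(dim, dim)
        self.proj_v = nn.Linear(dim, dim)

    def forward(self, q, k, v):
        scores = self.proj_q(q) @ self.proj_k(k).transpose(-1, -2) / q.shape[-1] ** 0.5
        w = torch.softmax(scores, dim=-1)
        return w @ self.proj_v(v), w

attn = Attn(16)
compressor = nn.Linear(16, 16)  # stand-in for MemoryCompressor

h = torch.randn(4, 16)
original_memory = torch.randn(10, 16)
compressed_memory = compressor(original_memory)

# Detached copies of the attention parameters: attention acts as a fixed
# function here, so recons_loss cannot push gradients into it.
frozen = {name: p.detach() for name, p in attn.named_parameters()}

# h is detached too, so the only remaining gradient path is through the compressor.
compressor_attn, _ = functional_call(attn, frozen, args=(h.detach(), original_memory, original_memory))
compressed_attn, _ = functional_call(attn, frozen, args=(h.detach(), compressed_memory, compressed_memory))

recons_loss = F.mse_loss(compressed_attn, compressor_attn)
recons_loss.backward()

print(attn.proj_q.weight.grad)             # None: attention untouched by recons_loss
print(compressor.weight.grad is not None)  # True: compressor receives gradients
```

The mirror image for the cross-entropy side would then be to feed `compressed_memory.detach()` into the main forward pass, so the LM loss never updates the compressor — but I'd welcome other patterns.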
Current state: I’ve figured out that I need to somehow zero out the gradients of the parts I don’t want, but I’m unsure how to do it with this particular control flow. I’m going to try to figure this out myself today - I think it’s just my inexperience! - but any pointers would be appreciated (happy to clarify if anything is unclear).