Transformer training is slow and uses a lot of memory

Hi everyone,

I have a problem with my transformer architecture (for 2D input): it runs out of GPU memory already for system sizes of ~6x6 = 36 pixels. I am using a batch size of 100, and when I decrease it the computation time becomes incredibly long. Does anyone have experience with similar problems?
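For scale, a back-of-the-envelope estimate of the attention-score memory per layer (the head count below is an assumption, not from the post) suggests 36 tokens alone should only cost a few MB, so the bulk of the memory is more likely in cached layer outputs and other activations for the whole batch:

```python
# Back-of-the-envelope: each self-attention layer materializes a
# (batch, heads, seq, seq) score tensor. Head count is assumed.
def attn_scores_bytes(batch, heads, seq, dtype_bytes=4):
    return batch * heads * seq * seq * dtype_bytes

seq = 6 * 6   # 6x6 system -> 36 tokens
batch = 100
heads = 8     # assumed, not from the post
print(attn_scores_bytes(batch, heads, seq) / 1e6, "MB per layer")  # ~4.1 MB
```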

I am using the pytorch TransformerEncoder layer + caching of the layer outputs:

minibatch = self.minibatch
repeat = int(np.ceil(batch / minibatch))
amp = []
phase = []
new_cache = []
for i in range(repeat):
    src_i = src[:, i * minibatch:(i + 1) * minibatch]
    src_i = self.encoder(src_i) * math.sqrt(self.embedding_size)  # (seq, batch, embedding)
    src_i = self.pos_encoder(src_i, self.system_size)  # (seq, batch, embedding)
    if compute_cache:
        c = cache[i] if cache is not None else None
        output_i, c = self.next_with_cache(src_i, self.src_mask, c)
        if c is not None:
            new_cache.append(c)
        psi_output = output_i
    else:
        output_i = self.transformer_encoder(src_i, self.src_mask)
        psi_output = output_i[self.seq_prefix_len - 1:]  # only use the physical degrees of freedom
    amp_i = F.log_softmax(self.amp_head(psi_output), dim=-1)  # (seq, batch, phys_dim)

with the function

def next_with_cache(self, tgt, mask, cache=None, idx=-1):
    output = tgt
    new_token_cache = []
    # go through each layer and apply self-attention only with the last input as query
    for i, layer in enumerate(self.transformer_encoder.layers):
        src = output[idx:, :, :]  # query: only the last element of the sequence

        # self-attention part (straight from the torch TransformerEncoderLayer code);
        # no causal mask needed, since the last token may attend to everything before it
        src2 = layer.self_attn(src, output, output)[0]
        src = src + layer.dropout1(src2)
        src = layer.norm1(src)
        # feed-forward part
        src2 = layer.linear2(layer.dropout(layer.activation(layer.linear1(src))))
        src = src + layer.dropout2(src2)
        src = layer.norm2(src)
        new_token_cache.append(src)
        if cache is not None:
            # layers after layer 1 need the cached previous outputs of this layer
            # prepended, so the next layer sees the full sequence
            output =[cache[i], src], dim=0)
        else:
            output = src

    # update cache with the new outputs: (num_layers, seq, batch, embedding)
    new_cache = torch.stack(new_token_cache, dim=0)
    if cache is not None:
        new_cache =[cache, new_cache], dim=1)
    return output, new_cache

You don’t mention the “size” of your transformer, e.g., the embedding dimension and the number of encoder blocks. The number of trainable parameters of a transformer grows very quickly with those numbers.
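To make that concrete, here is a rough closed-form parameter count for one PyTorch TransformerEncoderLayer with the default dim_feedforward=2048 (a sketch that includes the biases and the two LayerNorms; the embedding sizes are just examples):

```python
# Rough parameter count of one nn.TransformerEncoderLayer
# (PyTorch default dim_feedforward=2048), biases included.
def encoder_layer_params(d_model, dim_ff=2048):
    attn = 4 * d_model * d_model + 4 * d_model    # in-proj (q,k,v) + out-proj
    ff = 2 * d_model * dim_ff + dim_ff + d_model  # linear1 + linear2
    norms = 2 * 2 * d_model                       # two LayerNorms (weight + bias)
    return attn + ff + norms

for d in (32, 128, 512):
    print(d, encoder_layer_params(d))  # d=32 already gives 137504 parameters
```

Note that even at d_model=32 the feed-forward sublayer dominates because of the wide default hidden dimension.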

Hi, I typically use an embedding size of 32 and 1-4 layers. Then each training step usually takes ~2 seconds, but it takes almost the same time on a GPU, which is strange…
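With an embedding size of 32 the model is tiny, so per-kernel launch overhead can dominate on the GPU and leave it no faster than the CPU. Also, CUDA calls are asynchronous, so timings without a synchronize can be misleading. A minimal timing sketch (shapes taken from the post, layer count assumed):

```python
import time
import torch
import torch.nn as nn

# Minimal sketch: time one forward pass correctly. CUDA kernels launch
# asynchronously, so synchronize before reading the clock, otherwise you
# mostly measure launch overhead. Falls back to CPU if no GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
layer = nn.TransformerEncoderLayer(d_model=32, nhead=4)
model = nn.TransformerEncoder(layer, num_layers=2).to(device)
x = torch.randn(36, 100, 32, device=device)  # (seq, batch, embedding)

if device == "cuda":
    torch.cuda.synchronize()
t0 = time.perf_counter()
with torch.no_grad():
    out = model(x)
if device == "cuda":
    torch.cuda.synchronize()
print(f"forward pass: {time.perf_counter() - t0:.4f} s, output {tuple(out.shape)}")
```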