torch.matmul CUDA OOM

I have been trying to run my transformer codebase on a single GPU.

But I hit a wall when the code runs the matrix multiplication in self-attention:

def forward(self, hidden_states, attention_mask):
    mixed_query_layer   = self.query(hidden_states)
    mixed_key_layer     = self.key(hidden_states)
    mixed_value_layer   = self.value(hidden_states)

    query_layer         = self.transpose_for_scores(mixed_query_layer)
    key_layer           = self.transpose_for_scores(mixed_key_layer)
    value_layer         = self.transpose_for_scores(mixed_value_layer)
    print(query_layer.shape, key_layer.transpose(-1, -2).shape)
    # torch.Size([1, 8, 10381, 16]) torch.Size([1, 8, 16, 10381])
    # The result has shape [1, 8, 10381, 10381]: one score per query/key pair
    attention_scores    = torch.matmul(query_layer, key_layer.transpose(-1, -2))

It always bombs out on the last operation (torch.matmul) with the following error:

RuntimeError: CUDA out of memory. Tried to allocate 8.38 GiB (GPU 0; 15.78 GiB total capacity; 8.75 GiB already allocated; 5.47 GiB free; 8.76 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I’m surprised that the matmul needs so much memory when the input matrices are this small.
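For a sense of scale, the inputs are small but the output is not: multiplying [1, 8, 10381, 16] by [1, 8, 16, 10381] yields a [1, 8, 10381, 10381] score tensor. A minimal sketch of the arithmetic, assuming float32 scores (4 bytes per element) and the shapes printed above; `attention_scores_bytes` is a hypothetical helper, not a PyTorch function:

```python
import torch


def attention_scores_bytes(batch, heads, seq_len, dtype=torch.float32):
    # Hypothetical helper: torch.matmul of [B, H, N, d] by [B, H, d, N]
    # returns a [B, H, N, N] tensor, so the output size grows with N * N
    # and does not depend on the head dimension d.
    element_size = torch.tensor([], dtype=dtype).element_size()
    return batch * heads * seq_len * seq_len * element_size


# Shapes printed in the post: batch=1, heads=8, seq_len=10381
nbytes = attention_scores_bytes(1, 8, 10381)
print(f"{nbytes / 2**30:.2f} GiB")  # 3.21 GiB
```

That is about 3.21 GiB for the output of this one call, and in a typical attention layer the following softmax (and dropout, if any) each produce another tensor of the same size. The 8.38 GiB requested in the error is larger than this estimate, which suggests the failing call may have had a longer sequence or a larger batch than the printed one.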

Has anyone run into the same thing? Any workarounds?
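One commonly used workaround is to split the queries into chunks so that only a slice of the score matrix exists at a time. Below is a minimal sketch under stated assumptions: standard scaled dot-product attention with softmax, no attention mask or dropout, and inputs shaped [batch, heads, seq_len, head_dim] like `query_layer` above. `chunked_attention` and `chunk_size` are hypothetical names, not part of the original code or of PyTorch.

```python
import math

import torch


def chunked_attention(query, key, value, chunk_size=1024):
    # Hypothetical helper. query, key, value: [batch, heads, seq_len, head_dim].
    # Assumes no attention mask and no dropout.
    # Processes query rows in chunks, so only a
    # [batch, heads, chunk_size, seq_len] score tensor exists at a time
    # instead of the full [batch, heads, seq_len, seq_len] one.
    scale = 1.0 / math.sqrt(query.size(-1))
    key_t = key.transpose(-1, -2)
    outputs = []
    for start in range(0, query.size(-2), chunk_size):
        q_chunk = query[..., start:start + chunk_size, :]
        scores = torch.matmul(q_chunk, key_t) * scale
        probs = torch.softmax(scores, dim=-1)
        outputs.append(torch.matmul(probs, value))
    # Reassemble the chunks along the sequence dimension
    return torch.cat(outputs, dim=-2)


# Usage: compare against the unchunked computation on small random inputs
torch.manual_seed(0)
q, k, v = (torch.randn(1, 8, 300, 16) for _ in range(3))
reference = torch.softmax(q @ k.transpose(-1, -2) / 4.0, dim=-1) @ v
out = chunked_attention(q, k, v, chunk_size=64)
print(torch.allclose(out, reference, atol=1e-5))
```

This mainly lowers peak memory at inference (for example under torch.no_grad()). During training, autograd still keeps each chunk's probabilities for the backward pass, so the saved activations still add up to a full score matrix unless this is combined with gradient checkpointing. On PyTorch 2.0 or later, torch.nn.functional.scaled_dot_product_attention may also avoid materializing the full score matrix by dispatching to a memory-efficient kernel, depending on the GPU, dtype, and inputs.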

Appreciate any help.