Optimizing RNN Run Time

Hello all,
I have been training an attention-based model, and it is painfully slow: a single epoch takes over 20 minutes on an RTX 2070, even though there are only 52 training samples.

Here is my entire code below, including the training loop. Is there any way I could make it run faster (maybe by optimizing the attention class)?

import torch
import torch.nn as nn
import math
from CNNAE import Encoder
class Attention(nn.Module):
    """Scaled dot-product attention of one query vector per batch item over a sequence.

    Args:
        qdim: width used for the ``1 / sqrt(qdim)`` score scale.
        kdim: accepted for interface symmetry; not used by this class.
        vdim: accepted for interface symmetry; not used by this class.
    """

    def __init__(self, qdim, kdim, vdim):
        super(Attention, self).__init__()
        self.scale = 1. / math.sqrt(qdim)

    def forward(self, q, k, v):
        """Attend from ``q`` over the sequence in ``k`` / ``v``.

        Args:
            q: query of shape (batch, dim).
            k: keys of shape (seq, batch, dim).
            v: values of shape (seq, batch, vdim).

        Returns:
            ``(energy, combination)`` where ``energy`` is (batch, 1, seq), the
            softmax weights over the sequence, and ``combination`` is
            (batch, 1, vdim), the weighted sum of the values.
        """
        query = q.unsqueeze(1)        # (batch, 1, dim)
        keys_t = k.permute(1, 2, 0)   # (batch, dim, seq)
        values = v.permute(1, 0, 2)   # (batch, seq, vdim)

        scores = torch.bmm(query, keys_t) * self.scale
        weights = nn.functional.softmax(scores, dim=2)
        return weights, torch.bmm(weights, values)

class LSTMclassifier(nn.Module):
    """Sequence classifier: BiLSTM -> TransformerEncoder -> BiLSTM -> attention -> linear.

    ``forward`` runs ``rnn`` on the input, refines its output with a 2-layer
    Transformer encoder, and feeds that to ``rnn_final``, which starts from
    ``rnn``'s final state. It then attends over ``rnn_final``'s outputs, using
    the last layer's final state as the query, and maps the result to class
    logits.

    Args:
        num_layers: stacked layers in each of the two LSTMs. NOTE(review): the
            default of 200 is kept for checkpoint compatibility, but stacked
            LSTM layers are evaluated one after another, so two 200-layer
            bidirectional LSTMs are very likely the dominant runtime cost.
            A handful of layers is far more usual; profile to confirm.
        hidden_size: hidden units per LSTM direction. The bidirectional output
            width, the Transformer ``d_model`` and the classifier input are all
            ``2 * hidden_size``.
        num_classes: number of output logits.
    """

    def __init__(self, num_layers=200, hidden_size=64, num_classes=4):
        super().__init__()
        model_dim = 2 * hidden_size  # width of a bidirectional LSTM output
        self.rnn = torch.nn.LSTM(input_size=1, hidden_size=hidden_size,
                                 num_layers=num_layers, bidirectional=True)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim, dim_feedforward=model_dim, nhead=8)
        self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
        self.autoencoder = Encoder().cuda()
        self.rnn_final = torch.nn.LSTM(input_size=model_dim, hidden_size=hidden_size,
                                       num_layers=num_layers, bidirectional=True)
        # NOTE(review): the attention query is 2 * hidden_size wide (two
        # directions concatenated), so qdim=hidden_size scales scores by
        # 1/sqrt(hidden_size) rather than 1/sqrt(2 * hidden_size). Kept as-is
        # so existing checkpoints behave identically.
        self.attention = Attention(qdim=hidden_size, kdim=model_dim, vdim=model_dim)
        self.classifier = nn.Linear(model_dim, num_classes)
        # Despite its name, this holds the best validation accuracy (number of
        # correct test predictions) reached so far by train_model.
        self.min_loss = 0

    def forward(self, x):
        """Return class logits for ``x``.

        Args:
            x: input for ``rnn`` (``input_size=1``, ``batch_first`` not set),
                i.e. shape (seq_len, batch, 1).

        Returns:
            Logits of shape (batch, 1, num_classes).
        """
        # With no initial state nn.LSTM starts from zeros on x's own device and
        # batch size. The original built a fixed (400, 1, 64) zero tensor on CUDA,
        # which is equivalent but tied to batch size 1 and the default sizes.
        out, hidden = self.rnn(x)
        out = self.encoder(out)

        out_final, (_, c_final) = self.rnn_final(out, hidden)
        # c_final is c_n, the cell state, as in the original (hidden_final[1]).
        # NOTE(review): h_n is the more common attention query; confirm intent.
        # Index -1 is the last layer's reverse direction and -2 its forward direction.
        query = torch.cat((c_final[-1], c_final[-2]), dim=1)

        _, output = self.attention(query, out_final, out_final)
        return self.classifier(output)

    def _encode(self, x, device):
        """Run the autoencoder on one sample and lay its code out as an LSTM input.

        The first value returned by ``self.autoencoder`` must hold 313 elements.
        It is arranged as (313, 1, 1): sequence length 313, batch 1, one feature.
        """
        # NOTE(review): this call is not under no_grad during training, so
        # backward also runs through the autoencoder. Its weights are only updated
        # if they are in the optimised parameters (e.g. model is self). If the
        # autoencoder is meant to be frozen, wrapping this in torch.no_grad()
        # also saves time.
        encoded, _ = self.autoencoder(x.to(device))
        return encoded.reshape((1, 1, 313)).permute(2, 1, 0)

    def train_model(self, model, epochs, train, test, lr=1e-3,
                    save_path='path/attention_9.pth'):
        """Train ``model`` with SGD and checkpoint it whenever validation accuracy improves.

        Args:
            model: module to optimise and evaluate; presumably ``self``. TODO confirm.
            epochs: number of passes over ``train``.
            train: iterable of ``(x, y)`` single-sample batches (the fixed
                reshapes and ``y.item()`` require batch size 1).
            test: iterable of ``(x, y)`` single-sample batches used for validation.
            lr: SGD learning rate.
            save_path: where ``model.state_dict()`` is saved on improvement.
        """
        loss_func = nn.CrossEntropyLoss()
        optim = torch.optim.SGD(model.parameters(), lr=lr)
        device = next(model.parameters()).device

        # The original looped range(0, epochs + 1), i.e. one epoch too many.
        for _ in range(epochs):
            model.train()
            for x, y in train:
                optim.zero_grad()
                out = model(self._encode(x, device)).reshape((1, -1))
                loss = loss_func(out, y.to(device))
                # The original computed the loss but never back-propagated or
                # stepped the optimiser, so no parameter was ever updated.
                loss.backward()
                optim.step()

            # eval() disables the Transformer's dropout for validation.
            model.eval()
            val_accuracy = 0
            with torch.no_grad():
                for x, y in test:
                    out = model(self._encode(x, device)).reshape((1, -1))
                    if torch.argmax(out).item() == y.item():
                        val_accuracy += 1

            if val_accuracy > self.min_loss:
                torch.save(model.state_dict(), save_path)
                self.min_loss = val_accuracy

Thanks in advance.

Try using a profiler, such as the one built into PyCharm, or step through your code line by line and time each call.

Two things to check in general:
Are you running backpropagation through tensors that don't need gradients?
Have you vectorized your computation?

class AttentionBlock(NeuralNetworkModule):
    """Causally masked scaled dot-product self-attention with a residual concat.

    The input is projected to queries, keys and values. Each time step attends
    only to steps that are not later than itself, and the attended values are
    concatenated onto the input along the channel axis.

    Args:
        in_channels: channels of the input (its last dimension).
        key_size: width of the query and key projections.
        value_size: width of the value projection, i.e. the number of channels
            appended to the output.
        device: device the projection layers are moved to.
    """

    def __init__(self, in_channels, key_size, value_size, device):
        super().__init__()
        self.linear_query = nn.Linear(in_channels, key_size).to(device)
        self.linear_keys = nn.Linear(in_channels, key_size).to(device)
        self.linear_values = nn.Linear(in_channels, value_size).to(device)
        self.sqrt_key_size = m.sqrt(key_size)

    def forward(self, input: t.Tensor, time_steps: Union[t.Tensor, None] = None):
        """Apply masked self-attention to ``input``.

        Args:
            input: tensor of shape (N, T, in_channels).
            time_steps: optional (N, T) tensor of timestamps. Key ``i`` is hidden
                from query ``j`` when ``time_steps[:, j] < time_steps[:, i]``.
                With ``None`` the position index is used instead, which gives the
                standard causal (strict upper-triangular) mask.

        Returns:
            ``(output, rel, raw)``:
                output: (N, T, in_channels + value_size), the input with the
                    attended values appended.
                rel: (N, T, T) attention weights. Row ``j`` holds query ``j``'s
                    weights over the keys and sums to 1.
                raw: (N, T, T) unscaled, unmasked query-key scores.
            ``rel`` and ``raw`` are detached.
        """
        length = input.shape[1]

        # The original code had lost its `else:`, and a stray backslash joined
        # the comment onto the ones() line. Its None branch then dereferenced
        # time_steps, and the other branch left `mask` undefined. An all-ones
        # mask would also have set every score to -inf (NaN output).
        if time_steps is None:
            pos = t.arange(length, device=input.device)
            mask = pos.unsqueeze(1) < pos.unsqueeze(0)  # (T, T), broadcast over N
        else:
            # upper diagonal in timestamp order
            mask = time_steps.unsqueeze(dim=2) < time_steps.unsqueeze(dim=1)  # (N, T, T)

        keys = self.linear_keys(input)  # shape: (N, T, key_size)
        query = self.linear_query(input)  # shape: (N, T, key_size)
        values = self.linear_values(input)  # shape: (N, T, value_size)
        raw = t.bmm(query, t.transpose(keys, 1, 2))  # shape: (N, T, T)
        scores = raw.masked_fill(mask, -float('inf'))
        # Normalise over keys (dim=2): bmm(rel, values) sums over rel's last
        # axis, so each query's weights must sum to 1 along that axis. The
        # original used dim=1, which normalised over queries instead.
        rel = t.softmax(scores / self.sqrt_key_size, dim=2)  # shape: (N, T, T)
        attended = t.bmm(rel, values)  # shape: (N, T, value_size)

        # shapes: (N, T, in_channels + value_size), (N, T, T), (N, T, T)
        return t.cat((input, attended), dim=2), rel.detach(), raw.detach()

This is my implementation of a “modified” attention layer. If you set time_steps to None, it is equivalent to a normal attention layer.

@iffiX, thanks a lot, this did speed it up considerably.