Hello all,
I have been training an attention-based model and it is painstakingly slow, i.e. it is taking over 20 minutes for a single epoch and the training samples are only 52 on an RTX 2070
Here is my entire code below, including the training, Is there any way I could make it run faster( maybe optimizations in the attention class)
import torch
import torch.nn as nn
import math
from CNNAE import Encoder
class Attention(nn.Module):
def __init__(self, qdim, kdim, vdim):
super(Attention, self).__init__()
self.scale = 1./math.sqrt(qdim)
def forward(self, q, k, v):
q = q.unsqueeze(1)
k = k.transpose(0, 1).transpose(1, 2)
energy = torch.bmm(q, k)
energy = nn.functional.softmax(energy.mul_(self.scale), dim=2)
v = v.transpose(0, 1)
combination = torch.bmm(energy, v)
return energy, combination
class LSTMclassifier(nn.Module):
def __init__(self):
super(LSTMclassifier, self).__init__()
self.rnn = torch.nn.LSTM(input_size=1, hidden_size=64, num_layers=200, bidirectional=True)
encoder_layer = nn.TransformerEncoderLayer(d_model=128, dim_feedforward=128, nhead=8)
self.encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
self.autoencoder = Encoder().cuda()
self.autoencoder.load_state_dict(torch.load('/path/encoder.pth'))
self.rnn_final = torch.nn.LSTM(input_size=128, hidden_size=64, num_layers=200, bidirectional=True)
self.attention = Attention(qdim=64, kdim=128, vdim=128)
self.classifier = nn.Linear(128, 4)
self.min_loss = 0
def forward(self, x):
h0 = c0 = torch.zeros((200*2, 1, 64)).cuda()
out, hidden = self.rnn(x, (h0, c0))
out = self.encoder(out)
out_final, hidden_final = self.rnn_final(out, (hidden[0], hidden[1]))
hidden_final = hidden_final[1]
hidden_final = torch.cat((hidden_final[-1], hidden_final[-2]), dim=1)
energy, output = self.attention(hidden_final, out_final, out_final)
return self.classifier(output)
def train_model(self,model, epochs, train, test):
model.cuda()
loss_func = nn.CrossEntropyLoss().cuda()
optim = torch.optim.SGD(model.parameters(), lr=1e-3)
for i in range(0, epochs+1):
print(i)
val_accuracy = 0
model.train()
for (x, y) in train:
optim.zero_grad()
x = x.cuda()
encoded, out = self.autoencoder(x)
encoded = encoded.reshape((1, 1, 313))
encoded = encoded.permute(2, 1, 0)
y = y.cuda()
out = model(encoded)
out = out.reshape((1, 4))
loss = loss_func(out, y)
loss.backward()
optim.step()
print('---testing---')
with torch.no_grad():
model.eval()
for (x,y) in test:
x = x.cuda()
encoded, out = self.autoencoder(x)
encoded = encoded.reshape((1, 1, 313))
encoded = encoded.permute(2, 1, 0)
out = model(encoded)
out = out.reshape((1, 4))
if torch.argmax(out).item() == y.item():
val_accuracy += 1
if(val_accuracy) > self.min_loss:
print('saving')
print(val_accuracy)
torch.save(model.state_dict(), 'path/attention_9.pth')
self.min_loss = val_accuracy
Thanks In Advance