Hello,
I've been trying to train a BLSTM model using the following code:
from cmath import isnan
from unicodedata import bidirectional
import torch
import torch.nn as nn
import torch.nn.functional as F
# NOTE(review): not referenced anywhere in the code shown here; its purpose
# (and unit) cannot be determined from this snippet — confirm or remove.
R_THRESHOLD = 20
class speech_lstm(nn.Module):
    """BLSTM mask estimator applied chunk-by-chunk with repeated refinement passes.

    The last axis of the input is split into non-overlapping chunks of
    ``CHUNK_LEN`` frames. Each chunk is run through ``sys_pass``
    ``NUM_SPEAKERS`` times. Every pass emits one mask, shrinks the residual
    ``R`` for that chunk, and updates one 128-dim state vector ``z``. That
    vector is carried over to the same pass index on the next chunk.
    """

    # Chunk length along the last axis of Y. This is tied to fc_norm, whose
    # first layer is nn.Linear(160, ...) applied over the chunk's frames.
    CHUNK_LEN = 160
    # Number of sequential passes per chunk. Each pass owns one state vector.
    NUM_SPEAKERS = 3

    def __init__(self):
        super(speech_lstm, self).__init__()
        # First block: BLSTM and FC. Input features are Y and R concatenated,
        # 257 each.
        self.blstm = nn.LSTM(257 * 2, 128, 2, batch_first=True, bidirectional=True)
        self.fclstm = nn.Linear(256, 128, bias=False)
        # Second block: speaker adaptation.
        self.fc_speaker_adapt = nn.Linear(128, 128, bias=False)
        # Masking.
        self.fc_mask = nn.Linear(128, 257, bias=False)
        # Speaker embedding.
        self.fc_spk_1 = nn.Linear(128, 50, bias=False)
        self.fc_spk_2 = nn.Linear(50, 50, bias=False)
        self.fc_spk_3 = nn.Linear(50, 128, bias=False)
        self.fc_affine = nn.Linear(128, 128)
        # Gate.
        self.fc_gate = nn.Linear(128, 128, bias=False)
        self.fc_norm = nn.Sequential(nn.Linear(160, 70), nn.Linear(70, 1))
        # Final (h_n, c_n) of the most recent BLSTM call, stored detached.
        self.hidden = None

    def sys_pass(self, x):  # changed forward to sys_pass
        """Run one refinement pass on a single chunk.

        Args:
            x: tuple ``(Y, R, z)``. ``Y`` and ``R`` are concatenated on dim 1
                and fed to a BLSTM that expects 257*2 features. Each is
                therefore presumably (batch, 257, CHUNK_LEN); confirm against
                the caller. ``z`` is (batch, 128).

        Returns:
            ``(M_out, R_out, z_out)``:
            - ``M_out``: sigmoid mask of shape (batch, 257, chunk frames).
            - ``R_out``: ``R - M_out`` clamped at 0.
            - ``z_out``: updated state vector, L2-normalised over dim 1.
        """
        Y, R, z = x
        # BLSTM: stack Y and R on the feature axis, then put time on dim 1
        # (batch_first=True).
        lstm_out, hidden = self.blstm(torch.cat((Y, R), 1).permute(0, 2, 1))
        # Store the final state detached. Keeping the graph-attached tensors
        # on the module would keep autograd history from this call reachable
        # after the pass returns.
        self.hidden = tuple(state.detach() for state in hidden)
        lstm_out = torch.sigmoid(self.fclstm(lstm_out))
        # Speaker adapt: gate every frame with a projection of z.
        speaker_adapt = lstm_out * torch.sigmoid(self.fc_speaker_adapt(z).unsqueeze(1))
        # Mask output, with the channel axis moved back to dim 1.
        M_out = torch.sigmoid(self.fc_mask(speaker_adapt)).permute(0, 2, 1)
        # Speaker embedding: two 50-unit sigmoid layers, a projection to 128,
        # a mean over time, then an affine layer.
        speaker_embedding_temp = torch.sigmoid(
            self.fc_spk_2(torch.sigmoid(self.fc_spk_1(speaker_adapt)))
        )
        speaker_embedding = self.fc_affine(
            torch.mean(torch.sigmoid(self.fc_spk_3(speaker_embedding_temp)), dim=1)
        )
        # Gate output: fc_norm collapses the time axis, which must therefore
        # be CHUNK_LEN long. The result is added to z and L2-normalised.
        gate_temp = torch.sigmoid(self.fc_gate(speaker_adapt))
        pre_norm = gate_temp * speaker_embedding.unsqueeze(1)
        z_out = F.normalize(
            self.fc_norm(pre_norm.permute(0, 2, 1)).squeeze(-1) + z, dim=1
        )
        # Residual mask: what this pass's mask has not yet claimed, floored at 0.
        # NOTE: original author's TODO: "need something that converges -normalize".
        R_out = torch.clamp(R - M_out, min=0.0)
        return M_out, R_out, z_out

    def forward(self, Y):  # this is meant to create the different cells of our forward feed
        """Apply ``sys_pass`` over every full chunk of ``Y``.

        Args:
            Y: tensor whose last axis is split into chunks of ``CHUNK_LEN``.
                It is presumably (batch, 257, time) given the BLSTM input size;
                confirm against the caller. Trailing frames that do not fill a
                whole chunk are not processed.

        Returns:
            ``(M, R, z)``:
            - ``M``: (num_chunks, NUM_SPEAKERS, batch, channels, CHUNK_LEN),
              one mask per chunk and pass.
            - ``R``: same shape as ``Y``. Each processed chunk holds the residual
              after the last pass; unprocessed trailing frames are ones.
            - ``z``: (batch, max(num_chunks, NUM_SPEAKERS), 128) state vectors.
        """
        chunk = self.CHUNK_LEN
        n_passes = self.NUM_SPEAKERS
        batch = Y.shape[0]
        num_of_parts = Y.shape[-1] // chunk

        # One state vector per pass index, starting at ones and carried
        # across chunks.
        z_state = [Y.new_ones(batch, 128) for _ in range(n_passes)]

        # Per-chunk outputs are collected and combined once at the end.
        # Previously they were assigned in place into preallocated buffers.
        # That kept in-place slice copies in the autograd graph, and M was
        # preallocated with a last dim (len // 4) that only matched the
        # chunk length for ~640-frame inputs.
        masks = []
        residuals = []
        for time in range(num_of_parts):
            window = slice(chunk * time, chunk * (time + 1))
            Y_i = Y[..., window]
            R_i = Y.new_ones(Y_i.shape)
            chunk_masks = []
            for id_num in range(n_passes):
                M_i, R_i, z_state[id_num] = self.sys_pass((Y_i, R_i, z_state[id_num]))
                chunk_masks.append(M_i)
            masks.append(torch.stack(chunk_masks))
            residuals.append(R_i)

        if masks:
            M = torch.stack(masks)
        else:
            M = Y.new_zeros((0, n_passes, batch, Y.shape[1], chunk))

        # Frames after the last full chunk are never processed and stay at one.
        tail_len = Y.shape[-1] - num_of_parts * chunk
        residuals.append(Y.new_ones((*Y.shape[:-1], tail_len)))
        R = torch.cat(residuals, dim=-1)

        z = torch.stack(z_state, dim=1)
        # NOTE(review): the original allocated z as (batch, num_of_parts, 128)
        # but indexed axis 1 with the pass index. It crashed for fewer than
        # NUM_SPEAKERS chunks and never wrote rows past NUM_SPEAKERS. Those
        # rows are padded with ones, so the returned shape is unchanged
        # whenever there are at least NUM_SPEAKERS chunks.
        extra_rows = num_of_parts - n_passes
        if extra_rows > 0:
            z = torch.cat((z, z.new_ones(batch, extra_rows, 128)), dim=1)
        return M, R, z

    def init_hidden(self, device, batch_size=32):
        """Return zero ``(h_0, c_0)`` for the BLSTM on ``device``.

        The leading dim is 4 = 2 layers x 2 directions; the hidden size is 128.
        ``batch_size`` defaults to the previously hard-coded 32.
        """
        shape = (4, batch_size, 128)
        return (torch.zeros(shape, device=device),
                torch.zeros(shape, device=device))
With every batch, the allocated memory seems to grow until training crashes.
I've tried deleting the tensors once I'm finished with them, but that doesn't seem to help much.
Is there a way to limit or preallocate the required memory?
Or, more generally, how can I fix this issue?