ive been trying to train a BLSTM model, using the following:
from cmath import isnan from unicodedata import bidirectional import torch import torch.nn as nn import torch.nn.functional as F R_THRESHOLD = 20 class speech_lstm(nn.Module): def __init__(self): #first block - BLSTM and FC super(speech_lstm, self).__init__() self.blstm = nn.LSTM(257*2, 128, 2, batch_first = True, bidirectional = True) self.fclstm = nn.Linear(256,128, bias = False) #second - speaker adapt self.fc_speaker_adapt = nn.Linear(128,128, bias = False) #masking self.fc_mask = nn.Linear(128 , 257, bias = False) #speaker embedding self.fc_spk_1 = nn.Linear(128,50, bias = False) self.fc_spk_2 = nn.Linear(50,50, bias = False) self.fc_spk_3 = nn.Linear(50,128, bias = False) self.fc_affine = nn.Linear(128,128) #gate self.fc_gate = nn.Linear(128,128, bias = False) self.fc_norm = nn.Sequential(nn.Linear(160,70),nn.Linear(70,1)) def sys_pass(self, x): # changed forward to sys_pass Y, R, z = x #BLSTM lstm_out, self.hidden = self.blstm(torch.cat((Y,R),1).permute(0,2,1)) lstm_out = torch.sigmoid(self.fclstm(lstm_out)) #SPEAKER ADAPT speaker_adapt = lstm_out * torch.sigmoid(self.fc_speaker_adapt(z).unsqueeze(1)) #MASK - OUTPUT M_out = torch.sigmoid(self.fc_mask(speaker_adapt)).permute(0,2,1) #SPK. EMBEDDING speaker_embedding_temp = torch.sigmoid(self.fc_spk_2(torch.sigmoid(self.fc_spk_1(speaker_adapt)))) speaker_embedding = self.fc_affine(torch.mean(torch.sigmoid(self.fc_spk_3(speaker_embedding_temp)), dim = 1)) #GATE - OUTPUT gate_temp = torch.sigmoid(self.fc_gate(speaker_adapt)) pre_norm = gate_temp * speaker_embedding.unsqueeze(1) z_out = F.normalize(self.fc_norm(pre_norm.permute(0,2,1)).squeeze(-1) + z , dim = 1) #RES-MASK - OUTPUT R_out = torch.max(R-M_out, torch.zeros(R.shape,device=R.device)) # need something that converges -normalize return M_out, R_out, z_out def forward(self, Y): # this is meant to create the different cells of our forward feed sample_length = Y.shape[-1] R = torch.ones(Y.shape, device = Y.device) num_of_parts = sample_length // 160 M = torch.zeros(num_of_parts,3,Y.shape,Y.shape,Y.shape//4, device = Y.device) z = torch.ones((Y.shape ,num_of_parts, 128),device = Y.device) for time in range(num_of_parts): Y_i = Y[..., 160 *time :160 * (time+1)] for id_num in range(3): M_i_, R_i_, z_i_ = self.sys_pass((Y_i, R[..., 160 *time :160 * (time+1)], z[:, id_num, :])) R[..., 160 *time :160 * (time+1)] = R_i_ z[:, id_num, :] = z_i_ M[time, id_num, ...] = M_i_ # M_i, R[..., 160 *time :160 * (time+1)], z[:, id_num, :] = self.sys_pass((Y_i, R[..., 160 *time :160 * (time+1)], z[:, id_num, :])) return M ,R, z def init_hidden(self, device): return (torch.zeros((4,32,128)).to(device), torch.zeros((4,32,128)).to(device))
with every batch, the allocated memory seems to enlarge to a point where training crashes.
ive tried deleting the tensors once I finished using them, but that doesn’t seem to help to much.
is there a way to limit, or preallocate the required memory?
or in general fix this issue?