CUDA out of memory issue for BLSTM model

Hello,
I've been trying to train a BLSTM model using the following code:

import torch
import torch.nn as nn
import torch.nn.functional as F

R_THRESHOLD = 20

class speech_lstm(nn.Module):
    def __init__(self):
        super(speech_lstm, self).__init__()
        # first block - BLSTM and FC
        self.blstm = nn.LSTM(257 * 2, 128, 2, batch_first=True, bidirectional=True)
        self.fclstm = nn.Linear(256, 128, bias=False)

        # second - speaker adapt
        self.fc_speaker_adapt = nn.Linear(128, 128, bias=False)

        # masking
        self.fc_mask = nn.Linear(128, 257, bias=False)

        # speaker embedding
        self.fc_spk_1 = nn.Linear(128, 50, bias=False)
        self.fc_spk_2 = nn.Linear(50, 50, bias=False)
        self.fc_spk_3 = nn.Linear(50, 128, bias=False)

        self.fc_affine = nn.Linear(128, 128)

        # gate
        self.fc_gate = nn.Linear(128, 128, bias=False)
        self.fc_norm = nn.Sequential(nn.Linear(160, 70), nn.Linear(70, 1))

    def sys_pass(self, x):  # changed forward to sys_pass
        Y, R, z = x
        # BLSTM
        lstm_out, self.hidden = self.blstm(torch.cat((Y, R), 1).permute(0, 2, 1))
        lstm_out = torch.sigmoid(self.fclstm(lstm_out))

        # SPEAKER ADAPT
        speaker_adapt = lstm_out * torch.sigmoid(self.fc_speaker_adapt(z).unsqueeze(1))

        # MASK - OUTPUT
        M_out = torch.sigmoid(self.fc_mask(speaker_adapt)).permute(0, 2, 1)

        # SPK. EMBEDDING
        speaker_embedding_temp = torch.sigmoid(self.fc_spk_2(torch.sigmoid(self.fc_spk_1(speaker_adapt))))
        speaker_embedding = self.fc_affine(torch.mean(torch.sigmoid(self.fc_spk_3(speaker_embedding_temp)), dim=1))

        # GATE - OUTPUT
        gate_temp = torch.sigmoid(self.fc_gate(speaker_adapt))
        pre_norm = gate_temp * speaker_embedding.unsqueeze(1)
        z_out = F.normalize(self.fc_norm(pre_norm.permute(0, 2, 1)).squeeze(-1) + z, dim=1)

        # RES-MASK - OUTPUT
        R_out = torch.max(R - M_out, torch.zeros(R.shape, device=R.device))  # need something that converges - normalize
        return M_out, R_out, z_out


    def forward(self, Y):  # this is meant to create the different cells of our forward feed
        sample_length = Y.shape[-1]
        R = torch.ones(Y.shape, device=Y.device)
        num_of_parts = sample_length // 160
        M = torch.zeros(num_of_parts, 3, Y.shape[0], Y.shape[1], Y.shape[2] // 4, device=Y.device)
        z = torch.ones((Y.shape[0], num_of_parts, 128), device=Y.device)
        for time in range(num_of_parts):
            Y_i = Y[..., 160 * time : 160 * (time + 1)]
            for id_num in range(3):
                M_i_, R_i_, z_i_ = self.sys_pass((Y_i, R[..., 160 * time : 160 * (time + 1)], z[:, id_num, :]))
                R[..., 160 * time : 160 * (time + 1)] = R_i_
                z[:, id_num, :] = z_i_
                M[time, id_num, ...] = M_i_
        return M, R, z

    def init_hidden(self, device):
        return (torch.zeros((4, 32, 128)).to(device),
                torch.zeros((4, 32, 128)).to(device))

With every batch, the allocated memory seems to grow until training crashes.
I've tried deleting the tensors once I finish using them, but that doesn't seem to help much.

Is there a way to limit or preallocate the required memory,
or to fix this issue in general?
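For reference, a minimal sketch of how the per-batch growth can be confirmed (the loader, criterion, and optimizer names below are placeholders, not my exact training loop); steadily increasing numbers usually mean the autograd graph is being kept alive between iterations rather than the cache simply being too small:

for step, batch in enumerate(loader):          # placeholder data loader
    out = model(batch.to(device))
    loss = criterion(out)                      # placeholder loss computation
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # Report currently allocated and peak GPU memory after each step.
    print(f"step {step}: "
          f"allocated={torch.cuda.memory_allocated() / 1e6:.1f} MB, "
          f"peak={torch.cuda.max_memory_allocated() / 1e6:.1f} MB")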

.detach() seems to have solved the problem!
I used it quite randomly, both in the train loop and in the model itself, but hey - it works!
Making sure to del tensors after every iteration of the forward loop may have also helped.
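In this model, the recurrent state tensors written back into R and z still reference the graph of every previous chunk, so memory grows with each of the num_of_parts iterations. Below is a minimal sketch (not necessarily the exact change I made) of what detaching inside forward could look like, assuming you only need gradients to flow within the current chunk, i.e. truncated backprop through time:

    def forward(self, Y):
        sample_length = Y.shape[-1]
        R = torch.ones(Y.shape, device=Y.device)
        num_of_parts = sample_length // 160
        M = torch.zeros(num_of_parts, 3, Y.shape[0], Y.shape[1], Y.shape[2] // 4, device=Y.device)
        z = torch.ones((Y.shape[0], num_of_parts, 128), device=Y.device)
        for time in range(num_of_parts):
            Y_i = Y[..., 160 * time : 160 * (time + 1)]
            for id_num in range(3):
                M_i_, R_i_, z_i_ = self.sys_pass(
                    (Y_i, R[..., 160 * time : 160 * (time + 1)], z[:, id_num, :])
                )
                # Detach the recurrent state before writing it back: the next
                # pass then starts from plain values instead of dragging the
                # whole previous graph along (truncated backprop).
                R[..., 160 * time : 160 * (time + 1)] = R_i_.detach()
                z[:, id_num, :] = z_i_.detach()
                # M keeps its graph so the loss can still backprop through
                # each chunk's own computations.
                M[time, id_num, ...] = M_i_
        return M, R, z

Note that this changes what gradients you get (nothing flows back across chunk boundaries), so it is a trade-off, not a free fix.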

Hey, Adam!
Good for you! I have a similar problem.
Can you explain how you used .detach()?
Where should I put it?
Thanks!

I'm sorry, my solution was misleading.
Do not use .detach() randomly.
I used it to remove the grad from the training objects (a.k.a. the features).
I'll keep trying to minimize the computation graphs, which are most likely your problem too.
Good luck - feel free to keep me posted!
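For completeness, a minimal sketch of what "removing the grad from the features" can look like in a training loop (loader, criterion, optimizer, and target are placeholder names, not the exact code from this thread):

for Y_batch, target in loader:                 # placeholder data loader
    # Features are data, not parameters, so they should not carry a graph.
    Y_batch = Y_batch.detach().to(device)

    M, R, z = model(Y_batch)
    loss = criterion(M, target)                # placeholder loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    # Keep only a Python float for logging; accumulating the loss tensor
    # itself across iterations would keep its computation graph referenced.
    running_loss = loss.item()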