Memory leak with multiprocessing

Hi, I am running into a memory leak when I try to run model inference in parallel using Python's multiprocessing library. I am using a recurrent model whose code I got a few years ago, and I am trying to speed up inference with parallel processing.
Here is what I do:
I have a dataset which is basically a list of numpy arrays (each one accessed as file["EP"] in the code below). I pass this dataset to the model like this:

    from functools import partial
    import multiprocessing
    import time

    # bind the already-created model to the worker function
    worker_with_model = partial(GetModelPrediction, model=model)

    t0 = time.time()
    with multiprocessing.Pool(processes=num_workers, maxtasksperchild=1) as pool:
        predictions = pool.map(worker_with_model, dataset, chunksize=1)

The worker functions look like this:

    import torch

    def GetModelPrediction(file, model):
        # "file" is one dataset entry; "EP" holds the numpy array fed to the model
        pred = GetPrediction(model, torch.from_numpy(file["EP"]))
        return pred

    def GetPrediction(model, EP):
        if model:
            if type(model).__name__ == "FRAE":  # this is always true in my current case
                with torch.no_grad():
                    pred = model.evaluate(EP)
            else:
                pred = model(EP)

        pred = pred.detach().numpy()
        pred = postProcess(pred)  # postProcess is defined elsewhere in my code
        pred = pred.T
        return pred
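
For reference, each dataset entry holds its numpy array under the key "EP". A minimal stand-in entry would look roughly like this (the shapes and sizes here are made up for illustration, not my real data):

    import numpy as np

    # sequences of 22-dimensional frames; lengths and counts are placeholders
    dataset = [{"EP": np.random.randn(1000, 22).astype(np.float32)} for _ in range(5000)]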

My model looks like this:

    import numpy as np
    import torch
    import torch.nn as nn

    class FRAE(nn.Module):  # feedback recurrent autoencoder
        def __init__(self, recurrent_timesteps=1, neurons_per_layer=[22, 30, 14, 5], hidden_size=22, quant=None):
            super().__init__()

            self.hidden_size = hidden_size * recurrent_timesteps
            self.recurrent_timesteps = recurrent_timesteps
            self.neurons_per_layer = neurons_per_layer
            self.quant = quant
            self.new_hidden = self.initHidden()

            # build layers according to the list neurons_per_layer
            # encoder
            layers_enc = []
            for i in range(len(neurons_per_layer) - 1):
                # first layer takes input and hidden vector
                if i == 0:
                    layers_enc.append(nn.Linear(neurons_per_layer[i] + self.hidden_size, neurons_per_layer[i + 1]))
                else:
                    layers_enc.append(nn.Linear(neurons_per_layer[i], neurons_per_layer[i + 1]))
                layers_enc.append(nn.SiLU())

            # decoder (mirrors the encoder layer sizes)
            neurons_per_layer_rev = neurons_per_layer[::-1]
            layers_dec = []
            for i in range(len(neurons_per_layer) - 1):
                # first layer takes input and hidden vector
                if i == 0:
                    layers_dec.append(nn.Linear(neurons_per_layer_rev[i] + self.hidden_size, neurons_per_layer_rev[i + 1]))
                    layers_dec.append(nn.SiLU())
                else:
                    layers_dec.append(nn.Linear(neurons_per_layer_rev[i], neurons_per_layer_rev[i + 1]))
                    if i == len(neurons_per_layer) - 2:
                        layers_dec.append(nn.Tanh())  # last decoder layer uses Tanh
                    else:
                        layers_dec.append(nn.SiLU())

            self.encoder = nn.Sequential(*layers_enc)
            self.decoder = nn.Sequential(*layers_dec)
    
        def forward(self, x, hidden):
            # encode the current frame together with the hidden (feedback) vector
            combined_enc = torch.cat((x, hidden))
            encoded = self.encoder(combined_enc)

            if self.quant:
                # optional quantizer: round-trip the latent vector through encode/decode in numpy
                encoded = encoded.detach().numpy()
                encoded = np.reshape(encoded, (1, self.neurons_per_layer[-1]))
                encoded = self.quant.encode(encoded)
                encoded = self.quant.decode(encoded)
                encoded = torch.from_numpy(encoded)
                encoded = torch.reshape(encoded, (-1,))

            combined_dec = torch.cat((encoded, hidden))
            decoded = self.decoder(combined_dec)

            new_hidden = self.updateHidden(decoded, hidden)
            return decoded, new_hidden
    
        def getFlattenedWeights(self):
            # returns flattened tensor of all weights of the model
            if self.quant:
                weights_m = torch.cat(list(torch.flatten(p.data) for p in self.parameters())).double()
                weights_q = torch.from_numpy(self.quant.getFlattenedWeights()).detach().clone()
                return torch.cat((weights_m, weights_q))
            return torch.cat(list(torch.flatten(p.data) for p in self.parameters())).double()

        def setWeights(self, weights):
            # sets the weights according to the input tensor
            # assumes the weights tensor and the model have the same number of weights
            # indices keep track of which part of the weights tensor belongs to which param
            if self.quant:
                number_weights_m = torch.numel(torch.cat(list(torch.flatten(p.data) for p in self.parameters())))
                weights_m = weights[:number_weights_m]
                weights_q = weights[number_weights_m:]
                self.quant.setWeights(weights_q.detach().clone().numpy())
            else:
                weights_m = weights
            i_start = 0
            i_stop = 0
            for param in self.parameters():
                i_stop += param.numel()
                param_size = param.size()
                param_data = torch.reshape(weights_m[i_start:i_stop], param_size)
                param.data = nn.parameter.Parameter(param_data)
                i_start += param.numel()
    
        def initHidden(self):
            return torch.zeros(self.hidden_size)

        def updateHidden(self, last_decoded, hidden):
            # shift the hidden state and prepend the newest decoded frame
            if self.recurrent_timesteps == 1:
                new_hidden = last_decoded
            else:
                hid_size = self.hidden_size // self.recurrent_timesteps
                new_hidden = torch.zeros_like(hidden)
                new_hidden[hid_size:] = hidden[:self.hidden_size - hid_size]
                new_hidden[:hid_size] = last_decoded

            return new_hidden

        def evaluate(self, inputs):
            # run the autoencoder frame by frame over the whole sequence
            recon = torch.zeros(inputs.size())
            new_hidden = self.initHidden()
            for i in range(inputs.size()[0]):
                recon_i, new_hidden = self.forward(inputs[i], new_hidden)
                recon[i] = recon_i
            return recon

        def encode(self, inputs):
            # encoder-only pass over the sequence
            latent_dim = self.neurons_per_layer[-1]
            encoded = torch.zeros(inputs.size(dim=0), latent_dim)
            hidden = self.initHidden()
            for i in range(inputs.size()[0]):
                combined = torch.cat((inputs[i], hidden))
                encoded_i = self.encoder(combined)
                encoded[i] = encoded_i
            return encoded
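
For completeness, the model is created once in the main process before the pool is started, roughly like this (the checkpoint name and the exact loading code are placeholders, not my real setup):

    model = FRAE(recurrent_timesteps=1, neurons_per_layer=[22, 30, 14, 5], hidden_size=22)
    model.load_state_dict(torch.load("frae_weights.pt", map_location="cpu"))  # placeholder path
    model.eval()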

I see a steady increase in runtime for the prediction part of my code, and eventually it crashes with an OOM error. I am using Python 3.8 and PyTorch 2.4.1 on Linux.
I tried calling gc.collect() and torch.cuda.empty_cache(), but to no avail.
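
Concretely, this is the kind of thing I tried inside the worker (a sketch, not my exact code):

    import gc
    import torch

    def GetModelPrediction(file, model):
        pred = GetPrediction(model, torch.from_numpy(file["EP"]))
        gc.collect()                # attempted cleanup, made no difference
        torch.cuda.empty_cache()    # should be a no-op anyway since the model runs on the CPU
        return pred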

The model I am using is really small (maybe 5000 parameters), so it is faster to run it on the CPU. Is it not possible to use multiprocessing to predict the individual samples in parallel? It would come in quite handy as an easy way to speed up the code I have.
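
In case it helps with debugging, this is roughly the kind of check I can add to watch where the memory grows (a psutil-based sketch, just for illustration):

    import os
    import psutil

    def rss_mb():
        # resident memory of the current process in MB
        return psutil.Process(os.getpid()).memory_info().rss / 1e6

    # e.g. print rss_mb() in the parent around pool.map(), or at the start of
    # each worker call, to see whether the parent or the workers are growing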

Why does this leak occur? Does anyone know a solution? I am open to any suggestions.