Out-of-memory error in PyTorch

I am getting an MPS out-of-memory error while using PyTorch in the following code. Could anyone please help me resolve it?

Here is the code:

# Import necessary PyTorch libraries
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchvision
import torchvision.transforms as transforms

# Set the device to MPS if available, then CUDA, else CPU.
# BUG FIX: the original only checked torch.cuda.is_available(), so on an
# Apple-silicon machine it silently fell back to CPU despite the comment;
# MPS must be selected via torch.backends.mps.is_available().
if torch.backends.mps.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Additional libraries for visualization and utilities
import matplotlib.pyplot as plt
import numpy as np

# Import the adapted Echo noise functions
from echo import echo_sample, echo_loss

from torchvision import datasets
from torch.utils.data import DataLoader, random_split

# Define transformations: Resize if needed and normalize the data.
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1].
transform = transforms.Compose([
    # transforms.Resize((28, 28)), # Uncomment if resizing is needed
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# Load the MNIST dataset (downloaded to ./data on first run)
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)

# Splitting dataset into training and validation sets (80/20 split: 48000/12000)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# Create DataLoader instances.
# NOTE(review): batch_size=100 appears to need to match EchoModel's default
# batch_size, since echo_sample builds batch-sized index tensors — confirm.
train_loader = DataLoader(train_dataset, batch_size=100, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=100, shuffle=False)

print("Data loaders created for training and validation.")

class Encoder(nn.Module):
    """Convolutional encoder producing the Echo parameters (f(x), log-variance).

    For a 28x28 input, the five conv layers reduce the spatial size
    28 -> 14 -> 7 -> 1 (the final 7x7 valid conv collapses the map), and two
    linear heads project the flattened features to vectors of size
    ``latent_dims[5]``:

    - ``f_x``: tanh-squashed mean, bounded in (-1, 1)
    - ``log_var``: unbounded log-variance for the diagonal noise scale
    """

    def __init__(self, input_shape, latent_dims):
        super(Encoder, self).__init__()
        self.input_shape = input_shape
        self.latent_dims = latent_dims

        # Downsampling stack: the two stride-2 convs halve the spatial size,
        # then the 7x7 valid conv turns 7x7 feature maps into 1x1.
        self.conv1 = nn.Conv2d(input_shape[0], latent_dims[0], kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(latent_dims[0], latent_dims[1], kernel_size=5, stride=2, padding=2)
        self.conv3 = nn.Conv2d(latent_dims[1], latent_dims[2], kernel_size=5, stride=1, padding=2)
        self.conv4 = nn.Conv2d(latent_dims[2], latent_dims[3], kernel_size=5, stride=2, padding=2)
        self.conv5 = nn.Conv2d(latent_dims[3], latent_dims[4], kernel_size=7, stride=1, padding=0)

        # Two heads over the same flattened features.
        self.fc_mean = nn.Linear(latent_dims[4] * 1 * 1, latent_dims[5])
        self.fc_log_var = nn.Linear(latent_dims[4] * 1 * 1, latent_dims[5])

    def forward(self, x):
        """Return ``(f_x, log_var)`` for a batch of images ``x``."""
        x = torch.relu(self.conv1(x))
        x = torch.relu(self.conv2(x))
        x = torch.relu(self.conv3(x))
        x = torch.relu(self.conv4(x))
        x = torch.relu(self.conv5(x))

        x = x.view(x.size(0), -1)  # flatten to (batch, latent_dims[4])

        f_x = torch.tanh(self.fc_mean(x))
        log_var = self.fc_log_var(x)

        # BUG FIX: the original returned the undefined name `log_varca`,
        # which raises NameError on the first forward pass.
        return f_x, log_var

class Decoder(nn.Module):
    """Convolutional hourglass: five ReLU convs downsample the input
    (28 -> 14 -> 7 spatially for 28x28 inputs), then five transposed convs
    upsample back to ``output_shape`` with a final sigmoid in [0, 1].
    """

    def __init__(self, latent_dims, output_shape):
        super(Decoder, self).__init__()
        self.latent_dims = latent_dims
        self.output_shape = output_shape

        # (in_channels, out_channels, stride) for the downsampling convs.
        down_specs = [
            (1, latent_dims[0], 1),
            (latent_dims[0], latent_dims[1], 2),
            (latent_dims[1], latent_dims[2], 1),
            (latent_dims[2], latent_dims[3], 2),
            (latent_dims[3], latent_dims[4], 1),
        ]
        for idx, (c_in, c_out, stride) in enumerate(down_specs, start=1):
            setattr(self, f"conv{idx}",
                    nn.Conv2d(c_in, c_out, kernel_size=3, stride=stride, padding=1))

        # (in_channels, out_channels, kernel, stride) for the upsampling side;
        # the stride-2 layers use kernel 4 so sizes double exactly.
        up_specs = [
            (latent_dims[4], latent_dims[3], 3, 1),
            (latent_dims[3], latent_dims[2], 4, 2),
            (latent_dims[2], latent_dims[1], 3, 1),
            (latent_dims[1], latent_dims[0], 4, 2),
            (latent_dims[0], output_shape[0], 3, 1),
        ]
        for idx, (c_in, c_out, k, s) in enumerate(up_specs, start=1):
            setattr(self, f"deconv{idx}",
                    nn.ConvTranspose2d(c_in, c_out, kernel_size=k, stride=s, padding=1))

    def forward(self, x):
        """Map a batch of 1-channel maps to reconstructions in [0, 1]."""
        hidden_layers = (
            self.conv1, self.conv2, self.conv3, self.conv4, self.conv5,
            self.deconv1, self.deconv2, self.deconv3, self.deconv4,
        )
        for layer in hidden_layers:
            x = torch.relu(layer(x))
        return torch.sigmoid(self.deconv5(x))

class EchoModel(nn.Module):
    """Encoder/decoder pair with Echo noise injection and an iterative,
    diffusion-style reconstruction loop.

    NOTE(review): `reconstruct` runs the decoder T-1 times (T defaults to
    1000) inside every forward pass, and during training all of those
    activations stay alive in the autograd graph until backward().  Memory
    therefore grows roughly linearly with T — this is the most likely cause
    of the reported MPS out-of-memory error.  Consider a much smaller T, or
    running the loop under torch.no_grad() if gradients through it are not
    intended.
    """

    def __init__(self, input_shape, latent_dims, output_shape, T=1000, batch_size=100):
        super(EchoModel, self).__init__()
        self.input_shape = input_shape
        self.latent_dims = latent_dims
        self.output_shape = output_shape
        self.T = T  # number of steps in the reconstruction loop
        self.batch_size = batch_size  # Add batch_size as an attribute of EchoModel
        
        self.encoder = Encoder(input_shape, latent_dims)
        self.decoder = Decoder(latent_dims, output_shape)
        
        # Define the noise schedule
        self.alpha = self.create_noise_schedule(T)
        
    def create_noise_schedule(self, T):
        # Linear schedule decaying from ~1 (0.9999) down to ~0 (1e-5) over T steps.
        alpha = torch.linspace(0.9999, 1e-5, T)
        return alpha
    
    def forward(self, x):
        # Encode to the Echo parameters: f(x) and a per-sample log-variance.
        f_x, log_var = self.encoder(x)
        
        # Convert log-variance to diagonal elements of S(x)
        diagonal_sx = torch.exp(log_var)
        
        # Create the full square matrix representation of S(x)
        # (batch, latent, latent) with diagonal_sx on the diagonal.
        sx_matrix = torch.diag_embed(diagonal_sx)

        print(f"Shape of fx: {f_x.shape}")
        print(f"Shape of Sx: {sx_matrix.shape}")
        # Generate the noise variable z using echo_sample
        z = echo_sample([f_x, sx_matrix], d_max = 5, batch_size=self.batch_size)  # Pass batch_size explicitly
        print(f"Shape of output: {z.shape}")
        # Perform the reconstruction process using Algorithm 2
        # NOTE(review): `reconstruct` only uses x and the decoder below;
        # z, f_x and sx_matrix are passed but not read inside it — confirm intent.
        reconstructed_x = self.reconstruct(x, z, f_x, sx_matrix)
        
        return reconstructed_x, f_x, sx_matrix
    
    def reconstruct(self, x_t, z, f_x, sx_matrix):
        # Iteratively refine x_s from step T-1 down to step 1.
        # NOTE(review): each iteration calls self.decoder, so T-1 decoder
        # graphs are retained for backward — see class docstring re: OOM.
        x_s = x_t
        for s in range(self.T-1, 0, -1):
            sqrt_alpha_s = torch.sqrt(self.alpha[s])
            sqrt_one_minus_alpha_s = torch.sqrt(1 - self.alpha[s])
            
            # Estimate the original image using the decoder
            x_0_hat = self.decoder(x_s)
            
            # Calculate the estimated noise using Eq. (3)
            z_hat = (x_s - sqrt_alpha_s * x_0_hat) / sqrt_one_minus_alpha_s
            
            # Calculate D(x_0_hat, s) and D(x_0_hat, s-1) using Eq. (5) and (6)
            D_x_0_hat_s = sqrt_alpha_s * x_0_hat + sqrt_one_minus_alpha_s * z_hat
            D_x_0_hat_s_minus_1 = torch.sqrt(self.alpha[s-1]) * x_0_hat + torch.sqrt(1 - self.alpha[s-1]) * z_hat
            
            # Update x_s using Eq. (7)
            x_s = x_s - D_x_0_hat_s + D_x_0_hat_s_minus_1
        
        return x_s

def train(model, optimizer, train_loader, device, num_epochs, loss_weights):
    """Optimize `model` on `train_loader` for `num_epochs` epochs.

    The per-batch objective is a weighted sum of an L1 reconstruction loss
    and the Echo mutual-information penalty; weights come from
    loss_weights['reconstruction'] and loss_weights['mi_penalty'].
    Returns the trained model.
    """
    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0

        for data, _ in train_loader:  # labels are unused
            data = data.to(device)

            optimizer.zero_grad()
            reconstructed_x, f_x, sx_matrix = model(data)

            # L1 reconstruction term plus the mutual-information penalty.
            reconstruction_loss = nn.functional.l1_loss(reconstructed_x, data)
            mi_penalty = echo_loss([f_x, sx_matrix])
            total_loss = (
                loss_weights['reconstruction'] * reconstruction_loss
                + loss_weights['mi_penalty'] * mi_penalty
            )

            total_loss.backward()
            optimizer.step()
            running_loss += total_loss.item()

        # Report the mean batch loss for this epoch.
        avg_loss = running_loss / len(train_loader)
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}")

    return model

# Define the input shape, latent dimensions, and output shape
input_shape = (1, 28, 28)  # Example shape for MNIST
latent_dims = [32, 32, 64, 64, 256, 32]  # Latent dimensions from echo.json
output_shape = (1, 28, 28)  # Example shape for MNIST

# Create an instance of the EchoModel and move it to the selected device.
# NOTE(review): EchoModel's default batch_size (100) appears to need to match
# the DataLoader batch size above — confirm.
model = EchoModel(input_shape, latent_dims, output_shape).to(device)

# Define the optimizer
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Define the number of epochs and loss weights
num_epochs = 100
loss_weights = {'reconstruction': 1.0, 'mi_penalty': 0.0}  # Adjust the weights as needed

# Train the model
trained_model = train(model, optimizer, train_loader, device, num_epochs, loss_weights)

Here is the code of the echo.py file that is imported in the notebook:

import numpy as np
import torch
import torch.nn.functional as F


def random_indices(n, d):
    """Return a flat int32 tensor of n*d uniform random draws from [0, n)."""
    total = n * d
    return torch.randint(low=0, high=n, size=(total,), dtype=torch.int32)


def gather_nd_reshape(t, indices):
    """Gather elements from `t` at multi-dimensional `indices` (tf.gather_nd-style).

    Each row of `indices` holds the last-k coordinates of one element of `t`
    (k = indices.shape[-1]); coordinates are converted to flat offsets via the
    row-major strides of `t` and looked up in the flattened tensor.

    BUG FIX: the original built the stride tensor on the CPU, which raises a
    device-mismatch error when `t`/`indices` live on MPS or CUDA; the strides
    are now created on `indices.device`.  `reshape(-1)` is used instead of
    `view(-1)` so non-contiguous inputs also work.
    """
    indices = indices.long()

    # Row-major strides: strides[i] = prod(t.shape[i+1:]), strides[-1] = 1.
    t_shape = t.shape
    stride_vals = [1]
    for dim in reversed(t_shape[1:]):
        stride_vals.insert(0, stride_vals[0] * int(dim))
    strides = torch.tensor(stride_vals, dtype=torch.long, device=indices.device)

    # Use only the strides for the trailing dims that `indices` addresses.
    flat_indices = (indices * strides[-indices.shape[-1]:]).sum(dim=-1)

    return t.reshape(-1)[flat_indices]


def indices_without_replacement(batch_size, d_max=-1, replace=False, pop=True):
    """Build (batch_size, d_max, 2) index pairs for echo sampling.

    For each batch element i, selects d_max distinct batch indices (excluding
    i itself when `pop` is True) and pairs each with its echo position
    0..d_max-1.  Negative `d_max` is interpreted relative to the batch size
    (default -1 -> batch_size - 1).  `replace` is accepted for interface
    parity but unused, matching the original.

    BUG FIX: the original built `inds` but never returned it, so the function
    always returned None and echo_sample would crash when set_batch=False.
    """
    if d_max < 0:
        d_max = batch_size + d_max

    dmax_range = torch.arange(d_max)
    per_sample = []
    for i in range(batch_size):
        batch_range = torch.arange(batch_size)
        if pop:
            # Exclude the current sample from its own echoes.
            batch_range = batch_range[batch_range != i]
        shuffled = torch.randperm(batch_range.size(0))[:d_max]
        # Column 0: echo position; column 1: sampled batch index.
        per_sample.append(torch.stack((dmax_range, batch_range[shuffled]), dim=1))

    return torch.stack(per_sample, dim=0)


def permute_neighbor_indices(batch_size, d_max=-1, replace=False, pop=True):
    """Build (batch_size, d_max, 2) index pairs for echo sampling.

    For each batch element i, picks d_max batch indices — excluding i when
    `pop` is True, with or without replacement per `replace` — and pairs each
    with its echo position 0..d_max-1.  Negative `d_max` counts back from the
    batch size (default -1 -> batch_size - 1).
    """
    if d_max < 0:
        d_max = batch_size + d_max

    position_col = torch.arange(d_max)
    per_sample = []
    for i in range(batch_size):
        # Candidate pool: everything, or everything except sample i.
        if pop:
            candidates = torch.cat((torch.arange(i), torch.arange(i + 1, batch_size)))
        else:
            candidates = torch.arange(batch_size)

        if replace:
            # Uniform draws with replacement from the candidate pool.
            drawn = torch.multinomial(
                torch.ones(candidates.shape), num_samples=d_max, replacement=True
            )
            chosen = candidates[drawn]
        else:
            # Random permutation, truncated to the first d_max entries.
            chosen = candidates[torch.randperm(candidates.shape[0])[:d_max]]

        # Column 0: echo position; column 1: sampled batch index.
        per_sample.append(torch.stack((position_col, chosen), dim=1))

    return torch.stack(per_sample, dim=0)


def echo_sample(
    inputs,
    clip=None,
    d_max=100,
    batch_size=100,
    multiplicative=False,
    echo_mc=False,
    replace=False,
    fx_clip=None,
    plus_sx=True,
    calc_log=True,
    set_batch=True,
    return_noise=False,
):
    """Draw an Echo-noise sample z = f(x) + S(x) @ eps for each batch element.

    The noise eps for sample i is built from the *other* samples in the batch
    ("echoes"): d_max of their f-values, each scaled by the cumulative product
    of the corresponding S-scales, are summed.  `inputs` is either
    [f(x), S(x)] or just f(x) (then S is treated as zero in log-space).

    NOTE(review): this is a PyTorch port; tensor-shape assumptions (diagonal
    vs full S(x)) should be confirmed against the original implementation.
    """
    # Accept [fx, sx] or a lone fx.
    if isinstance(inputs, list):
        fx = inputs[0]
        sx = inputs[1]
    else:
        fx = inputs
        sx = None

    # Default clip keeps the d_max-fold product of scales above float32
    # machine epsilon (2**-23), so cumulative products stay representable.
    if clip is None:
        max_fx = fx_clip if fx_clip is not None else 1.0
        clip = (2 ** (-23) / max_fx) ** (1.0 / d_max)

    clip = torch.tensor(clip, dtype=fx.dtype, device=fx.device)

    # Optionally bound fx to [-fx_clip, fx_clip].
    if fx_clip is not None:
        fx = torch.clamp(fx, -fx_clip, fx_clip)

    if sx is not None:
        if not calc_log:
            # Raw-scale path: multiply by clip and push values away from
            # exact zero so later products/divisions stay finite.
            sx = clip * sx
            sx = torch.where(
                torch.abs(sx) < torch.finfo(sx.dtype).eps,
                torch.sign(sx) * torch.finfo(sx.dtype).eps,
                sx,
            )
        else:
            # Log-scale path: log(clip) + sx (or - sx when plus_sx is False).
            sx = torch.log(clip) + (-sx if not plus_sx else sx)
    else:
        sx = torch.zeros_like(fx)

    # Optional Monte-Carlo centering of fx across the batch.
    if echo_mc:
        fx = fx - fx.mean(dim=0, keepdim=True)

    if replace:
        # Sample echo indices WITH replacement directly over the batch axis.
        sx = sx.view(sx.size(0), -1) if len(sx.shape) > 2 else sx
        fx = fx.view(fx.size(0), -1) if len(fx.shape) > 2 else fx

        inds = torch.randint(
            0, batch_size, (batch_size * d_max,), dtype=torch.long, device=fx.device
        )
        inds = inds.view(-1, 1)

        select_sx = gather_nd_reshape(sx, inds).view(batch_size, d_max, -1)
        select_fx = gather_nd_reshape(fx, inds).view(batch_size, d_max)

        if len(sx.shape) > 2:
            select_sx = select_sx.unsqueeze(2).unsqueeze(2)
            sx = sx.unsqueeze(1).unsqueeze(1)

        if len(fx.shape) > 2:
            select_fx = select_fx.unsqueeze(2).unsqueeze(2)
            fx = fx.unsqueeze(1).unsqueeze(1)
    else:
        # Broadcast fx/sx to (batch, batch, ...) so each row can address the
        # whole batch via the 2-column index pairs.
        # NOTE(review): these batch x batch copies are large — with a full
        # (latent x latent) S(x) this is a plausible contributor to the
        # out-of-memory error; confirm sizes for your batch/latent dims.
        repeat_fx = torch.ones_like(fx.unsqueeze(0)) * torch.ones_like(fx.unsqueeze(1))
        stack_fx = fx * repeat_fx

        repeat_sx = torch.ones_like(sx.unsqueeze(0)) * torch.ones_like(sx.unsqueeze(1))
        stack_sx = sx * repeat_sx

        if not set_batch:
            inds = indices_without_replacement(batch_size, d_max)
        else:
            inds = permute_neighbor_indices(batch_size, d_max, replace=replace)

        select_sx = gather_nd_reshape(stack_sx, inds).view(batch_size, d_max, -1)
        select_fx = gather_nd_reshape(stack_fx, inds).view(batch_size, d_max, -1)

    # Cumulative product of scales along the echo dimension
    # (a cumsum in log-space, exponentiated below).
    if calc_log:
        sx_echoes = torch.cumsum(select_sx, dim=1)
    else:
        sx_echoes = torch.cumprod(select_sx, dim=1)

    sx_echoes = torch.exp(sx_echoes) if calc_log else sx_echoes

    # eps_i = sum_k f(x_k) * prod_{j<=k} S(x_j)
    fx_sx_echoes = select_fx * sx_echoes
    noise = torch.sum(fx_sx_echoes, dim=1)

    # Convert sx back to raw scale for the final affine combination.
    if sx is not None:
        sx = sx if not calc_log else torch.exp(sx)

    # NOTE(review): expand() only works if noise.view(batch_size, -1) already
    # has sx.size(1) columns (or exactly 1) — confirm this holds for the
    # full-matrix S(x) produced by EchoModel.forward.
    noise_expanded = noise.view(batch_size, -1)
    noise_expanded = noise_expanded.expand(batch_size, sx.size(1))

    # z = fx + S @ eps (or exp(...) in the multiplicative variant).
    if multiplicative:
        output = torch.exp(
            fx + torch.matmul(sx, noise_expanded.view(batch_size, -1, 1)).squeeze(-1)
        )
    else:
        output = fx + torch.matmul(sx, noise_expanded.view(batch_size, -1, 1)).squeeze(
            -1
        )

    return output if not return_noise else noise


def echo_loss(inputs, clip=0.8359, calc_log=True, plus_sx=True, multiplicative=False):
    """Echo mutual-information penalty.

    `inputs` is [z_mean, z_scale] (only the last entry is used) or a bare
    z_scale tensor.  Builds the per-element MI matrix -log(clip * scale)
    (directly in log-space when calc_log) and returns the batch-averaged
    negative log-determinant via calculate_avg_log_det.  `multiplicative`
    is accepted for interface parity but unused.
    """
    if isinstance(inputs, list):
        z_mean = inputs[0]  # kept for interface parity; not used below
        z_scale = inputs[-1]
    else:
        z_scale = inputs

    clip_tensor = torch.tensor(clip, dtype=z_scale.dtype, device=z_scale.device)

    if calc_log:
        # Log-space scales: mi = -(log(clip) ± z_scale).
        if plus_sx:
            mi = -(torch.log(clip_tensor) + z_scale)
        else:
            mi = -(torch.log(clip_tensor) - z_scale)
    else:
        # Raw scales: clip, nudge away from zero, then take -log|.|.
        scaled = clip_tensor * z_scale
        scaled = torch.where(
            torch.abs(scaled) < 1e-7,
            torch.sign(scaled) * 1e-7,
            scaled,
        )
        mi = -torch.log(torch.abs(scaled) + 1e-7)

    print("MI matrix: ", mi)

    return calculate_avg_log_det(mi)


def calculate_avg_log_det(sx_matrices, eps=1e-6):
    """Return the batch mean of -log(det(M + eps*I)) over `sx_matrices`.

    A small eps is added to the diagonal for numerical stability, and any
    non-finite log-determinants (singular / negative-determinant matrices)
    are zeroed out before averaging.
    """
    # Stabilize: nudge the diagonal so near-singular matrices stay invertible.
    identity = torch.eye(
        sx_matrices.size(-1), dtype=sx_matrices.dtype, device=sx_matrices.device
    )
    stabilized = sx_matrices + eps * identity

    # Per-matrix negative log-determinant.
    neg_log_det = -torch.logdet(stabilized)

    # Drop inf/nan results rather than poisoning the batch average.
    invalid = torch.isinf(neg_log_det) | torch.isnan(neg_log_det)
    neg_log_det = torch.where(invalid, torch.zeros_like(neg_log_det), neg_log_det)

    return torch.mean(neg_log_det)