Incorporate Moving Average Into Loss

I have a network whose loss function includes a cosine-similarity term on the data. It works well when the model is allowed to see the entire dataset on each pass, but when I use batches for large datasets the cosine-similarity matrix looks different on every batch. My question is: what would be the best way to apply a moving average to the cosine similarity, so that on each pass the loss uses a weighted average of the current and previous cosine-similarity matrices?
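Concretely, what I have in mind is an exponential moving average along the lines of avg_new = alpha * cos_sim_batch + (1 - alpha) * avg_old, where alpha is a smoothing factor I would choose (I use 0.1 in the code below).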

I have seen resources that say to give the loss class a variable and store the moving average as a buffer or attribute so its state persists across calls. However, it is still not clear to me whether we should retain these tensors or detach them when updating the moving average. If we do retain them, does that mean each one will be kept in memory?
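For reference, my rough picture of that buffer pattern is below. The class name EMAOrthLoss, the buffer name cos_sim_avg, and the placement of detach() are all my own guesses, not something from those resources:

import torch
import torch.nn as nn


class EMAOrthLoss(nn.Module):
    """Sketch of the buffer idea: keep the running similarity matrix as a buffer."""

    def __init__(self, latent_dim=3, alpha=0.1):
        super().__init__()
        self.alpha = alpha
        # register_buffer: persists across calls, follows .to(device),
        # is saved in state_dict, and is never updated by the optimizer
        self.register_buffer('cos_sim_avg', torch.zeros(latent_dim, latent_dim))

    def forward(self, z):
        s = torch.mm(z.t(), z)
        # EMA update: gradients flow through the current batch's matrix s,
        # while the stored average is detached so old graphs are not kept alive
        avg = self.alpha * s + (1 - self.alpha) * self.cos_sim_avg
        self.cos_sim_avg = avg.detach()
        idx0, idx1 = torch.triu_indices(avg.shape[0], avg.shape[1], offset=1)
        return avg[idx0, idx1].square().mean()

Is something like this what those resources mean, and is the detach in the right place, or does detaching defeat the purpose of averaging?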

Would declaring a global variable be a better option for this purpose? I’m not sure how this would work or if it would interfere with backprop.

Here is my attempt at the moving average, which is very slow:

import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


class Loss(nn.Module):
    def __init__(self, lambda_orth=10, moving_avg_alpha=0.1):
        super(Loss, self).__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.lambda_orth = lambda_orth
        self.moving_avg_alpha = moving_avg_alpha  # Smoothing factor for EMA
        self.cos_sim_moving_avg = None  # Initialize the moving average variable

    def update_moving_average(self, current_value):
        """Update the exponential moving average."""
        # Unsure whether this should be detached here:
        # current_value = current_value.detach()

        if self.cos_sim_moving_avg is None:  # Initialize if not already done
            self.cos_sim_moving_avg = current_value
        else:  # Update using EMA formula
            self.cos_sim_moving_avg = (
                    self.moving_avg_alpha * current_value
                    + (1 - self.moving_avg_alpha) * self.cos_sim_moving_avg
            )

        return self.cos_sim_moving_avg

    def compute_orthogonality_loss(self, z, show_plot=False):
        s = torch.mm(z.t(), z)  # similarity matrix between latent dimensions

        if show_plot:
            plt.imshow(s.cpu().detach().numpy(),
                aspect='auto',  # Allow rectangular pixels
                interpolation='none')
            plt.show()

        avg_cos_sim = self.update_moving_average(s)

        idx0, idx1 = torch.triu_indices(avg_cos_sim.shape[0], avg_cos_sim.shape[1], offset=1)  # indices of triu w/o diagonal
        cos_sim = avg_cos_sim[idx0, idx1]
        orth_loss = torch.mean(cos_sim.square())

        return orth_loss

    def forward(self, recon_x, x, z, show_plot=False):
        MSE = self.mse(recon_x, x)
        ORTH = self.compute_orthogonality_loss(z, show_plot=show_plot)
        return MSE + ORTH * self.lambda_orth


class OrthAE(nn.Module):
    def __init__(self, input_dim=3, latent_dim=3):
        super(OrthAE, self).__init__()

        # Encoder
        self.encoder = nn.Linear(input_dim, latent_dim, bias=False)
        # Decoder
        self.decoder = nn.Linear(latent_dim, input_dim, bias=False)
        # Loss function
        self.loss_fn = Loss(lambda_orth=2, moving_avg_alpha=0.1)

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z), z

    def train_model(self, data_loader, num_epochs=100, learning_rate=1e-3):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        optimizer = optim.Adam(self.parameters(), lr=learning_rate)

        self.train()
        for epoch in range(num_epochs):
            total_loss = 0
            for batch_idx, data in enumerate(data_loader):
                data = data.to(device)
                optimizer.zero_grad()
                recon_batch, z = self(data)

                # Only show the plot every 2 epochs
                show_plot = (epoch % 2 == 0 and batch_idx == 0)
                loss = self.loss_fn(recon_batch, data, z, show_plot=show_plot)

                loss.backward(retain_graph=True)  # errors without retain_graph=True, since the moving average still references earlier batches' graphs
                optimizer.step()
                total_loss += loss.item()

            avg_loss = total_loss / len(data_loader.dataset)
            if epoch % 10 == 0:
                print(f'Epoch [{epoch}/{num_epochs}] Loss: {avg_loss:.4f}')