I have a network that incorporates a cosine similarity on the data into its loss function. This works well when the network is allowed to see the entire dataset on each pass, but when training in batches on large datasets the cosine similarity matrix looks different on every batch. My question is: what would be the best way to keep a moving average of the cosine similarity, so that on each pass it is a weighted average of the current and previous cosine similarity matrices?
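Concretely, the update I have in mind is an ordinary exponential moving average, something like this (alpha is just my smoothing factor):

cos_sim_avg = alpha * cos_sim_current + (1 - alpha) * cos_sim_avg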
I have seen resources that suggest storing the moving average on the loss module as a class attribute or registered buffer, so that its state persists across calls. However, it is still not clear to me whether we should retain these tensors or detach them when updating the moving average. If we do retain them, does that mean each batch's computation graph will be kept in memory?
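If I understand those resources correctly, they mean something like the sketch below (the class name, alpha and dim are mine, and the detach() is exactly the part I am unsure about):

import torch
import torch.nn as nn

class EMALoss(nn.Module):
    def __init__(self, alpha=0.1, dim=3):
        super().__init__()
        self.alpha = alpha
        # register_buffer keeps the state with the module: it moves with .to(device)
        # and is saved in the state_dict, but is not a trainable parameter
        self.register_buffer('cos_sim_avg', torch.zeros(dim, dim))
        self.initialized = False

    def update(self, current):
        current = current.detach()  # <-- retain or detach? this is my question
        if not self.initialized:
            self.cos_sim_avg.copy_(current)
            self.initialized = True
        else:
            # EMA update done in-place on the buffer
            self.cos_sim_avg.mul_(1 - self.alpha).add_(self.alpha * current)
        return self.cos_sim_avg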
Would declaring a global variable be a better option for this purpose? I’m not sure how this would work or if it would interfere with backprop.
Here is my attempt at a moving average, which is very slow:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt


class Loss(nn.Module):
    def __init__(self, lambda_orth=10, moving_avg_alpha=0.1):
        super(Loss, self).__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.lambda_orth = lambda_orth
        self.moving_avg_alpha = moving_avg_alpha  # Smoothing factor for the EMA
        self.cos_sim_moving_avg = None  # Moving average, initialized on the first call

    def update_moving_average(self, current_value):
        """Update the exponential moving average."""
        # current_value = current_value.detach()  # unsure whether this should be detached (see question above)
        if self.cos_sim_moving_avg is None:  # Initialize on the first batch
            self.cos_sim_moving_avg = current_value
        else:  # Update using the EMA formula
            self.cos_sim_moving_avg = (
                self.moving_avg_alpha * current_value
                + (1 - self.moving_avg_alpha) * self.cos_sim_moving_avg
            )
        return self.cos_sim_moving_avg

    def compute_orthogonality_loss(self, z, show_plot=False):
        s = torch.mm(z.t(), z)  # similarity matrix between latent dimensions
        if show_plot:
            plt.imshow(s.cpu().detach().numpy(),
                       aspect='auto',  # allow rectangular pixels
                       interpolation='none')
            plt.show()
        avg_cos_sim = self.update_moving_average(s)
        # Upper-triangular indices, excluding the diagonal
        idx0, idx1 = torch.triu_indices(avg_cos_sim.shape[0], avg_cos_sim.shape[1], offset=1)
        cos_sim = avg_cos_sim[idx0, idx1]
        orth_loss = torch.mean(cos_sim.square())
        return orth_loss

    def forward(self, recon_x, x, z, show_plot=False):
        MSE = self.mse(recon_x, x)
        ORTH = self.compute_orthogonality_loss(z, show_plot=show_plot)
        return MSE + ORTH * self.lambda_orth


class OrthAE(nn.Module):
    def __init__(self, input_dim=3, latent_dim=3):
        super(OrthAE, self).__init__()
        # Encoder
        self.encoder = nn.Linear(input_dim, latent_dim, bias=False)
        # Decoder
        self.decoder = nn.Linear(latent_dim, input_dim, bias=False)
        # Loss function
        self.loss_fn = Loss(lambda_orth=2, moving_avg_alpha=0.1)

    def encode(self, x):
        return self.encoder(x)

    def decode(self, z):
        return self.decoder(z)

    def forward(self, x):
        z = self.encode(x)
        return self.decode(z), z

    def train_model(self, data_loader, num_epochs=100, learning_rate=1e-3):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.to(device)
        optimizer = optim.Adam(self.parameters(), lr=learning_rate)
        self.train()
        for epoch in range(num_epochs):
            total_loss = 0
            for batch_idx, data in enumerate(data_loader):
                data = data.to(device)
                optimizer.zero_grad()
                recon_batch, z = self(data)
                # Only show the plot on the first batch of every second epoch
                show_plot = (epoch % 2 == 0 and batch_idx == 0)
                loss = self.loss_fn(recon_batch, data, z, show_plot=show_plot)
                loss.backward(retain_graph=True)  # fails without retain_graph=True, since the moving average is not detached
                optimizer.step()
                total_loss += loss.item()
            avg_loss = total_loss / len(data_loader.dataset)
            if epoch % 10 == 0:
                print(f'Epoch [{epoch}/{num_epochs}] Loss: {avg_loss:.4f}')
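For completeness, this is roughly how I call it (the data here is just random noise to keep the example self-contained; my real dataset is much larger):

from torch.utils.data import DataLoader

data = torch.randn(1000, 3)  # 1000 samples, input_dim=3
loader = DataLoader(data, batch_size=64, shuffle=True)

model = OrthAE(input_dim=3, latent_dim=3)
model.train_model(loader, num_epochs=100, learning_rate=1e-3)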