Hello.
I am implementing a custom loss function consisting of two parts: clustering and regularization. I created a simple dataset with two clusters to see how this loss function behaves.
I use kmeans_mod, which performs only a single k-means iteration. Here is the plot of the loss values for each part, clustering and regularization. As you can see, the clustering_loss value is increasing, while the regularization_loss value is decreasing and unstable! I have changed many parameters, such as the network architecture, learning rate, number of epochs, and batch size, but unfortunately I have not gotten a good result.
Moreover, to check whether the clustering part works correctly, I wrote a program that applies kmeans_mod to the dataset and computes clustering_loss over several iterations. The resulting plot shows clustering_loss decreasing. But when I add the regularization part and use the combined loss as the custom loss function of a network, it does not work!
Would you please help me fix this problem? If you need more information, please let me know.
class Autoencoder(nn.Module):
    """Fully connected autoencoder with a symmetric encoder/decoder.

    Layer widths are n1 -> n2 -> n3 -> n4 -> n5 (encoder), then
    n5 -> n4 -> n3 -> n2 -> n1 (decoder). Every Linear layer, including the
    bottleneck and the final output layer, is followed by a ReLU.

    ``forward`` returns ``(y, x)``: the n5-dimensional embedding and the
    n1-dimensional reconstruction.
    """

    def __init__(self, n1=10, n2=8, n3=6, n4=4, n5=2):
        # Must be the dunder ``__init__``; a plain ``init`` is never called by
        # the constructor, so no layers would be registered.
        super().__init__()
        self.nl = nn.ReLU()
        self.n1 = n1
        # encoder
        self.enc1 = nn.Linear(n1, n2)
        self.enc2 = nn.Linear(n2, n3)
        self.enc3 = nn.Linear(n3, n4)
        self.enc4 = nn.Linear(n4, n5)
        # decoder
        self.dec1 = nn.Linear(n5, n4)
        self.dec2 = nn.Linear(n4, n3)
        self.dec3 = nn.Linear(n3, n2)
        self.dec4 = nn.Linear(n2, n1)

    def forward(self, x):
        """Encode ``x`` and decode it again.

        Args:
            x: input batch whose last dimension is ``n1``.

        Returns:
            ``(y, x_hat)`` where ``y`` has last dimension ``n5`` and ``x_hat``
            has last dimension ``n1``. Both are non-negative because of the
            trailing ReLUs.
        """
        x = self.nl(self.enc1(x))
        x = self.nl(self.enc2(x))
        x = self.nl(self.enc3(x))
        # NOTE(review): a ReLU on a 2-unit bottleneck can zero out ("kill") an
        # embedding dimension, which makes the clusters hard to separate.
        # Consider leaving this layer linear.
        y = self.nl(self.enc4(x))
        x = self.nl(self.dec1(y))
        x = self.nl(self.dec2(x))
        x = self.nl(self.dec3(x))
        # NOTE(review): the final ReLU clamps reconstructions to >= 0. If the
        # data can be negative, the reconstruction can never match it.
        x = self.nl(self.dec4(x))
        return y, x
Training Part
def train_fn(model, device, train_loader, optimizer, epoch, log_interval,
             num_cluster, centroids, reg_weight=1.0):
    """Run one training epoch on the joint clustering + regularization loss.

    For each batch, ``kmeans_mod`` assigns the model's reconstructions to
    ``num_cluster`` clusters, starting from ``centroids``. The optimized loss
    is ``loss1 + reg_weight * loss2``, where ``loss1`` is the clustering loss
    and ``loss2`` is the regularization loss.

    Args:
        model: network whose forward returns ``(enc, dec)``.
        device: device each batch is moved to.
        train_loader: DataLoader yielding ``(data, target)``; targets are ignored.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch index, used only for logging.
        log_interval: print a progress line every ``log_interval`` batches.
        num_cluster: number of clusters passed to ``kmeans_mod``.
        centroids: starting centroids for ``kmeans_mod`` (``None`` on the first
            epoch in ``main``). Replaced by the updated centroids after every
            batch.
        reg_weight: weight of the regularization term. The default 1.0 keeps
            the unweighted sum; tune it when one term dominates the other.

    Returns:
        A tuple ``(centroids, cluster_ids_x, loss1, loss2, mean_loss)``:
        - ``centroids`` and ``cluster_ids_x``: taken from the last batch.
        - ``loss1`` and ``loss2``: per-sample epoch means of the clustering and
          regularization losses, as Python floats.
        - ``mean_loss``: per-sample epoch mean of the total loss.
    """
    import warnings

    model.train()
    cluster_ids_x = None
    n_seen = 0
    sum_total = sum_clust = sum_reg = 0.0
    for batch_idx, (data, _) in enumerate(train_loader):
        data = data.to(device)
        optimizer.zero_grad()
        enc, dec = model(data)
        # NOTE(review): clustering runs on the reconstruction ``dec``. Deep
        # clustering usually clusters the embedding ``enc`` -- confirm intent.
        cluster_ids_x, centroids = kmeans_mod(X=dec, num_clusters=num_cluster, centroid=centroids)
        # Clustering part. Never force ``loss1.requires_grad = True``:
        # PyTorch only permits that on leaf tensors. If it succeeds, loss1 had
        # no grad_fn -- it was disconnected from the model -- so the
        # clustering term gave the parameters zero gradient. It must be
        # computed from ``dec`` without .detach()/.numpy()/torch.tensor(...).
        loss1 = loss_fn_clustering(dec, cluster_ids_x, centroids)
        if not loss1.requires_grad:
            warnings.warn(
                "clustering loss is not connected to the model's computation "
                "graph, so it contributes no gradient; compute it from the "
                "model output without detaching it.",
                RuntimeWarning,
            )
        near_ids = nearest(data)  # indices of the nearest neighbour of each data point
        loss2 = loss_fn_reg(data, dec, near_ids)  # Regularization part
        loss = loss1 + reg_weight * loss2
        loss.backward()
        optimizer.step()
        # Centroids carried to the next batch must not drag this batch's
        # (already freed) autograd graph along with them.
        if torch.is_tensor(centroids):
            centroids = centroids.detach()
        n = len(data)
        n_seen += n
        sum_total += loss.item() * n
        sum_clust += loss1.item() * n
        sum_reg += loss2.item() * n
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
    denom = max(n_seen, 1)  # avoid division by zero on an empty loader
    return centroids, cluster_ids_x, sum_clust / denom, sum_reg / denom, sum_total / denom
def main():
    """Train the autoencoder on the two-cluster toy data and plot per-epoch losses.

    Training centroids are carried from epoch to epoch. Validation receives
    the current training centroids but never replaces them, so validation data
    cannot influence training.
    """
    # Training settings
    model = Autoencoder()
    num_cluster = 2
    batch_size = 128
    num_epochs = 100
    learning_rate = 0.0001
    log_interval = 2
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")

    ### Training set and Validation set ###
    train_dataset = Twoclusters(1000, 1500, [15, 19])
    # NOTE(review): the validation set is built with different arguments
    # ([7, 4] vs [15, 19]). If these are cluster parameters, validation data
    # come from a different distribution than training -- confirm intent.
    val_dataset = Twoclusters(100, 150, [7, 4])
    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

    model.to(device)  # load the neural network on to the device
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    loss_value_train = np.zeros(num_epochs)
    loss_value_1 = np.zeros(num_epochs)
    loss_value_2 = np.zeros(num_epochs)
    loss_value_val = np.zeros(num_epochs)

    # None on the first epoch, as before; afterwards the centroids produced by
    # training on the previous epoch.
    centroids = None
    for epoch in range(num_epochs):
        centroids, _, loss_1, loss_2, training_loss = train_fn(
            model, device, train_loader, optimizer, epoch, log_interval, num_cluster,
            centroids=centroids)
        # Centroids returned by validation are discarded on purpose.
        # NOTE(review): if val_fn modifies centroids in place, pass a copy.
        _, validation_loss, _, _ = val_fn(
            model, device, val_loader, num_cluster, centroids=centroids)
        # float() works for Python numbers and 0-d tensors (even ones that
        # require grad), unlike direct assignment of such a tensor into numpy.
        loss_value_train[epoch] = float(training_loss)
        loss_value_1[epoch] = float(loss_1)
        loss_value_2[epoch] = float(loss_2)
        loss_value_val[epoch] = float(validation_loss)

    plt.figure(figsize=(10, 7))
    plt.plot(loss_value_train, label='Training loss')
    plt.plot(loss_value_val, label='Validation loss')
    plt.plot(loss_value_1, label='loss1: Clustering')
    plt.plot(loss_value_2, label='loss2: Regularization')
    plt.legend()
    plt.show()
# Run the experiment only when executed as a script, not on import.
if __name__ == "__main__":
    main()