Sorry for the late response but I was trying different things trying to fix my error but are still having troubles. I’m just going to post my entire code. If the loss function itself is not the issue it might be something else. Thank you so much for trying to help.
Also I did try it with changing the bits to make it similar to how you tested it, however the model parameters do not seem to change if I set img_inter to True for its requires_grad.
import torch
import gc
with torch.cuda.device('cuda:3'):
torch.cuda.empty_cache()
import os
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torchvision import transforms
from PIL import Image, ImageOps
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.models as models
from torch.autograd import Variable
class AgeDBDataset(Dataset):
def __init__(self, image, label, transform=None):
self.image = image
self.label = label
self.transform = transform
def __len__(self):
return(len(self.image))
def __getitem__(self, index):
if self.transform:
tran_image = self.transform(image[index])
else:
tran_image = self.image[index]
return(tran_image, self.label[index])
def DataOrganizer(directory):
y = []
dataset = os.scandir(directory)
for title in dataset:
label = title.name.split('_')
y.append(label[2])
y = np.array(y)
y = y.astype('float32')
y = torch.from_numpy(y)
X = torch.zeros((len(y),3,224,224))
count = 0
for img in os.listdir(directory):
trans = transforms.Compose([transforms.ToTensor(),
transforms.Resize((224,224))])
pic = Image.open(directory+img)
#pic = ImageOps.grayscale(pic)
pic = trans(pic)
if pic.size(0) == 1:
pic = torch.cat((pic,pic,pic),0)
X[count] = (pic)
count = count + 1
return(X, y)
class Regress_Loss(torch.nn.Module):
def __init__(self):
super(Regress_Loss,self).__init__()
def contrastive_loss(self,img_inter,labels_test,temp,batch_size,device):
img_sim = -torch.sqrt(torch.sum(torch.square((img_inter.unsqueeze(2)-img_inter.squeeze()).squeeze()),2))
dist = (labels_test.unsqueeze(2)-labels_test.squeeze()).squeeze()
A = img_sim
m = A.size(0)
s0,s1 = A.stride()
out = A.as_strided((m-1, m), (s0+s1, s1)).unfold(0, m-1, 1)
img_sim_num = out
processed_dist = dist.abs().flatten()
processed_sim = (-torch.exp(img_sim/temp)).flatten()
processed_sim_num = (-torch.exp(img_sim_num/temp)).flatten()
(sorted_dist, indices_d) = processed_dist.sort()
total_den = torch.sum(processed_sim)
index = 0
in_index = 0
subtractor = 0
denom = torch.zeros(dist.size(0)*dist.size(0)-dist.size(0)).to(device)
for distances in sorted_dist:
if (distances != 0 or index > np.sqrt(list(sorted_dist.size())[0])-1):
subtractor = subtractor + processed_sim[indices_d[index]]
inter = total_den - subtractor + processed_sim[indices_d[index]]
denom[in_index] = inter
in_index = in_index + 1
index = index + 1
loss = -torch.div(torch.sum(torch.log(torch.div(processed_sim_num,denom))),
torch.mul(batch_size,batch_size-1))
return loss
def get_activation(name):
def hook(model, input, output):
activation[name] = output.detach()
return hook
temp = 2
batch_size = 256
linearepoch = 100
encoderepoch = 400
device = "cuda:3" if torch.cuda.is_available() else "cpu"
transform_0 = transforms.Compose([transforms.RandomResizedCrop((224,224)),
transforms.RandomVerticalFlip(0.5),
transforms.RandomHorizontalFlip(0.5)])
transform_1 = transforms.Compose([transforms.ColorJitter()])
images, labels = DataOrganizer(r"/home/parkn1/Documents/AgeDB/")
images_0 = transform_0(images)
images_1 = transform_1(images)
aug_img = torch.cat((images_0, images_1),0)
aug_lab = torch.cat((labels,labels), 0)
dataset = AgeDBDataset(aug_img, aug_lab, transform=None)
train_set, validation_set, test_set = torch.utils.data.random_split(dataset, [250*2,125*2,125*2])
train_loader = DataLoader(dataset=train_set, batch_size=batch_size, shuffle = True)
validation_loader = DataLoader(dataset=validation_set, batch_size=batch_size, shuffle = True)
test_loader = DataLoader(dataset=test_set, batch_size=batch_size, shuffle = True)
resnet18 = models.resnet18(weights=True)
resnet18.to(device)
loss_class = Regress_Loss().to(device)
lr = 0.1
optimizer = torch.optim.SGD(params = resnet18.parameters(), lr = lr)
lambda1 = lambda epoch: epoch / 10
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, encoderepoch)
for epoch in range(encoderepoch):
for (inputs, labels) in train_loader:
inputs, labels = inputs.to(device), labels.to(device)
resnet18.train()
layer = resnet18._modules.get('avgpool')
activation = {}
resnet18.avgpool.register_forward_hook(get_activation('avgpool'))
output = resnet18(inputs.squeeze())
img_inter = activation['avgpool'].squeeze(dim = 3).squeeze(dim = 2).unsqueeze(dim = 0)
print(img_inter.size())
print(labels.unsqueeze(0).size())
loss = loss_class.contrastive_loss(img_inter,
labels.unsqueeze(0)
,temp,batch_size,device)
optimizer.zero_grad()
a = list(resnet18.parameters())[0].clone()
loss.backward()
optimizer.step()
b = list(resnet18.parameters())[0].clone()
print(torch.equal(a.data,b.data))
#validation where it should go
scheduler.step()
print(loss)