Hi, Thank you very much for the reply.
I ran it a few times and did not observe a gradual memory increase. Usage stays stable at around 9 GB of the 11 GB available; the only problem is that it suddenly runs out of memory partway through a run.
Please see the code snippets below.
This is the training script:
"""define some global variables"""
# Module-level training state shared by main(), train_phase() and valid_phase().
args = parser.parse_args()
# Timestamp string; train_phase() uses it as the output directory for logs and checkpoints.
args.train_folder = datetime.strftime(datetime.now(), '%Y%m%d-%H%M%S')
# Fixed seeds for torch and numpy; set before the model is constructed below.
torch.manual_seed(20190331)
np.random.seed(20190331)
# First CUDA device when available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
model = Resnet50(embedding_dim = args.emb_dim, pretrained=False).to(device)
optimizer = optim.Adam(model.parameters(), lr = args.learning_rate)
# StepLR with step_size=50, gamma=0.1: the learning rate is scaled by 0.1 every 50 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size = 50, gamma = 0.1)
# NOTE(review): the string below attributes an out-of-memory problem to where the loss
# object is created; building the loss module once here is harmless, but whether its
# placement actually affects memory use is not demonstrated by this file -- confirm.
"""triplet loss fn is here because it will cause out of memory by building graphs repeatedly during loops"""
#p=2 for eucliedean loss, reduction = mean to calculate avg triplet loss
triplet_loss_fn = torch.nn.TripletMarginLoss(margin=args.margin, p=2, reduction = 'mean')
def main():
    """Run the train/validation loop, optionally resuming from a checkpoint.

    Uses the module-level ``args``, ``model``, ``optimizer``, ``scheduler``
    and ``device``. When ``args.model_weight_path`` is set, the model and
    optimizer states are restored from that checkpoint and training
    continues with the epoch following the one stored in it; otherwise
    training starts at epoch 0. ``args.num_epochs`` epochs are run in either
    case, and the data loaders are rebuilt at the start of every epoch.
    """
    if args.model_weight_path:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only host (or on a different GPU index) instead of failing.
        checkpoint = torch.load(args.model_weight_path, map_location=device)
        model.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        # train_phase() writes the checkpoint after epoch `epoch` has been
        # trained, so resuming must begin with the next epoch.
        args.start_epoch = checkpoint['epoch'] + 1
        # NOTE(review): the scheduler state is not part of the checkpoint,
        # so the StepLR counter restarts from its initial value on resume.
    else:
        args.start_epoch = 0

    end_epoch = args.start_epoch + args.num_epochs
    for epoch in tqdm(range(args.start_epoch, end_epoch), desc='epochs'):
        print(80 * '=')
        print('Epoch [{}/{}]'.format(epoch, end_epoch - 1))

        # Loaders are rebuilt every epoch.
        data_loaders, data_size = get_dataloader(
            args.train_root_dir, args.valid_root_dir,
            args.train_csv_name, args.valid_csv_name,
            args.num_train_triplets,
            args.batch_size, args.num_workers,
            args.image_size, args.cropped_image_size)
        print('data loaded for epoch {}'.format(epoch))

        train_phase(model, optimizer, scheduler, epoch, data_loaders, data_size)
        valid_phase(model, epoch, data_loaders, data_size)
def train_phase(model, optimizer, scheduler, epoch, dataloaders, data_size):
    """Train the model for one epoch and write the epoch's logs and checkpoint.

    Parameters
    ----------
    model : embedding network, called separately on anchor, positive and
        negative image batches.
    optimizer : optimizer updating ``model``'s parameters.
    scheduler : learning-rate scheduler, stepped once per epoch after the
        optimizer updates of that epoch.
    epoch : int
        Epoch index, used in log lines and in the checkpoint file name.
    dataloaders : dict
        ``dataloaders['train']`` yields dicts with the keys ``'anc_img'``,
        ``'pos_img'`` and ``'neg_img'``.
    data_size : dict
        ``data_size['train']`` is the divisor for the epoch-average loss.

    Side effects
    ------------
    Creates ``args.train_folder`` if needed and writes ``args.txt``,
    appends to ``train_loss.txt`` and saves ``checkpoint_epoch_<epoch>.pth``
    inside it.
    """
    phase = 'train'
    print('in phase train for epoch {}'.format(epoch))
    triplet_loss_sum = 0.0
    model.train()

    # Wrapping the loader (not the enumerate object) lets tqdm see its length.
    for batch_idx, batch_sample in enumerate(tqdm(dataloaders['train'], desc='batches')):
        batch_size = batch_sample['anc_img'].shape[0]  # current batch size

        anc_img = batch_sample['anc_img'].to(device)
        pos_img = batch_sample['pos_img'].to(device)
        neg_img = batch_sample['neg_img'].to(device)

        with torch.set_grad_enabled(True):
            # Embeddings of the three images of each triplet.
            anc_embed = model(anc_img)
            pos_embed = model(pos_img)
            neg_embed = model(neg_img)

            # Triplet loss averaged over the batch.
            batch_avg_triplet_loss = triplet_loss_fn(anc_embed, pos_embed, neg_embed)
            # Read the scalar once; every .item() call copies from the device.
            batch_loss = batch_avg_triplet_loss.item()

            # A zero loss means no triplet violates the margin: nothing to learn.
            if batch_loss == 0:
                continue

            optimizer.zero_grad()
            batch_avg_triplet_loss.backward()
            optimizer.step()

        triplet_loss_sum += batch_loss * batch_size
        print('loss for epoch {}, batch {}: {}'.format(epoch, batch_idx, batch_loss))
        torch.cuda.empty_cache()

    # Step the schedule after this epoch's optimizer updates; stepping it
    # first skips the initial learning-rate value on PyTorch >= 1.1.
    scheduler.step()

    epoch_avg_triplet_loss = triplet_loss_sum / data_size[phase]
    print('{} set - Average Triplet Loss for epoch {}: {}'.format(phase, epoch, epoch_avg_triplet_loss))

    # Make the output directory if it does not exist yet.
    if not os.path.isdir(args.train_folder):
        os.mkdir(args.train_folder)

    # 'w' keeps a single copy of the arguments instead of appending the
    # same dump again after every epoch.
    with open('{}/args.txt'.format(args.train_folder), 'w') as f:
        f.writelines('%s: %s\n' % (key, str(value))
                     for key, value in vars(args).items())

    with open('{}/train_loss.txt'.format(args.train_folder), 'a') as f:
        f.write('{}: {}\n'.format(str(epoch), str(round(epoch_avg_triplet_loss, 4))))

    torch.save({'epoch': epoch,
                'state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict()},
               '{}/checkpoint_epoch_{}.pth'.format(args.train_folder, epoch))
def valid_phase(model, epoch, dataloaders, data_size, chunk_size=1024):
    """Compute nearest-neighbour accuracy on the validation set.

    Every validation image is embedded, its closest other validation image
    (by ``pdist`` distance) is looked up, and the prediction counts as
    correct when both images carry the same class label.

    Parameters
    ----------
    model : embedding network.
    epoch : int
        Epoch index, used in the log lines.
    dataloaders : dict
        ``dataloaders['valid']`` yields dicts with the keys ``'img'``,
        ``'img_class'`` and ``'img_path'``.
    data_size : dict
        Kept for interface compatibility. The number of samples is taken
        from the embeddings actually collected, so the distance matrix and
        the accuracy denominator always agree.
    chunk_size : int, optional
        Number of query rows compared against all embeddings at a time.
        Bounds the distance matrix to ``chunk_size x N`` instead of
        ``N x N``.

    Side effects
    ------------
    Prints per-image diagnostics and appends the accuracy to
    ``accuracy.txt`` inside ``args.train_folder``.
    """
    print('in phase valid for epoch {}'.format(epoch))
    model.eval()

    full_emb = []
    full_class = []
    full_img_path = []

    with torch.no_grad():
        for batch_sample in dataloaders['valid']:
            img = batch_sample['img'].to(device)
            full_emb.append(model(img))
            full_class.append(batch_sample['img_class'].to(device))
            full_img_path += batch_sample['img_path']
            torch.cuda.empty_cache()

    if not full_emb:
        # torch.cat([]) and the accuracy division would both fail.
        print('no validation samples for epoch {}, skipping accuracy'.format(epoch))
        return

    full_emb = torch.cat(full_emb, 0)
    full_class = torch.cat(full_class, 0)
    valid_size = full_emb.size(0)

    # Nearest neighbour of every sample, computed one block of rows at a time.
    closest_img_idx = []
    with torch.no_grad():
        for start in range(0, valid_size, chunk_size):
            stop = min(start + chunk_size, valid_size)
            block_dist = pdist(full_emb[start:stop], full_emb)
            # Exclude each sample from being its own nearest neighbour.
            rows = torch.arange(stop - start, device=block_dist.device)
            block_dist[rows, rows + start] = float('inf')
            closest_img_idx.append(torch.argmin(block_dist, dim=1))
    closest_img_idx = torch.cat(closest_img_idx, 0)

    # Map neighbour indices to neighbour class labels; comparing the raw
    # index with a class label would be meaningless.
    closest_class = full_class[closest_img_idx]

    matches = 0
    for idx in range(valid_size):
        person_class = full_class[idx].item()
        closest_img_class = closest_class[idx].item()
        print('full_emb: ', full_emb[idx, :10])
        print('full_img_path: ', full_img_path[idx])
        print('own class: ', person_class)
        print('closest img class: ', closest_img_class)
        if int(person_class) == int(closest_img_class):
            matches += 1

    # Convert before dividing so the ratio is not truncated under Python 2.
    accuracy = float(matches) / valid_size
    print('accuracy for epoch {} = {}/{} = {}'.format(epoch, matches, valid_size, accuracy))

    with open('{}/accuracy.txt'.format(args.train_folder), 'a') as f:
        f.write('{}: {}\n'.format(str(epoch), str(round(accuracy, 4))))
# Start training only when this file is executed as a script, not when it is imported.
if __name__ == '__main__':
main()
This is the model
class Resnet50(nn.Module):
    """ResNet-50 whose final fully connected layer outputs an embedding.

    The backbone's ``fc`` layer is replaced by a linear layer with
    ``embedding_dim`` outputs, and ``forward`` returns the embedding scaled
    to (approximately) unit L2 norm along dimension 1.
    """

    def __init__(self, embedding_dim = 512, pretrained = False):
        super(Resnet50, self).__init__()
        self.embedding_dim = embedding_dim
        self.resnet50 = models.resnet50(pretrained=pretrained)
        # The new head takes the same number of inputs as the layer it replaces.
        in_features = self.resnet50.fc.in_features
        self.linear = nn.Linear(in_features, embedding_dim)
        self.resnet50.fc = self.linear
        self.init_weights()

    def init_weights(self):
        """Initialise the head: weights from N(0, 0.02), biases set to zero."""
        head = self.linear
        head.weight.data.normal_(0.0, 0.02)
        head.bias.data.fill_(0)

    def l2_norm(self, input):
        """Divide each row of a 2-D ``input`` by its L2 norm.

        ``1e-10`` is added to the squared norm before the square root, so an
        all-zero row yields zeros instead of a division by zero.
        """
        squared_norm = torch.pow(input, 2).sum(1) + 1e-10
        row_norm = squared_norm.sqrt().view(-1, 1)
        scaled = input / row_norm.expand_as(input)
        return scaled.view(input.size())

    def forward(self, images):
        """Return the L2-normalised embedding of ``images``."""
        return self.l2_norm(self.resnet50(images))
This is the loss
class TripletLoss(torch.nn.Module):
    """Batch-averaged triplet margin loss on Euclidean distances.

    Implemented as a module rather than an ``autograd.Function``: the loss
    is built only from differentiable torch operations, so no custom
    backward is needed, and instance-style ``Function`` objects with a
    non-static ``forward`` are rejected by current PyTorch releases.
    Both ``loss(a, p, n)`` and ``loss.forward(a, p, n)`` work.

    Parameters
    ----------
    margin : float
        Minimum gap required between the anchor-negative and the
        anchor-positive distance.
    """

    def __init__(self, margin):
        super(TripletLoss, self).__init__()
        self.margin = margin
        self.pdist = PairwiseDistance(p=2)

    def forward(self, anchor, positive, negative):
        """Return ``mean(max(margin + d(a, p) - d(a, n), 0))`` over the batch."""
        pos_dist = self.pdist(anchor, positive)
        neg_dist = self.pdist(anchor, negative)
        # Triplets that already satisfy the margin contribute zero.
        hinge_dist = torch.clamp(self.margin + pos_dist - neg_dist, min=0.0)
        return torch.mean(hinge_dist)
def pdist(emb1, emb2, eps=1e-6):
    """Compute the matrix of all pairwise Euclidean distances.

    Uses the expansion ``|a - b|^2 = |a|^2 + |b|^2 - 2 a.b``; see
    https://stackoverflow.com/questions/51986758/calculating-euclidian-norm-in-pytorch-trouble-understanding-an-implementation/52032471

    Arguments
    ---------
    emb1 : torch.Tensor
        The first sample, of shape ``(n_1, d)``.
    emb2 : torch.Tensor
        The second sample, of shape ``(n_2, d)``.
    eps : float
        Added under the square root, so a zero squared distance is returned
        as ``sqrt(eps)`` rather than exactly 0.

    Returns
    -------
    torch.Tensor
        Matrix of shape ``(n_1, n_2)``. The [i, j]-th entry is
        ``sqrt(eps + | ||emb1[i, :] - emb2[j, :]||_2^2 |)``.
    """
    sq_norms_1 = (emb1 ** 2).sum(dim=1, keepdim=True)        # (n_1, 1)
    sq_norms_2 = (emb2 ** 2).sum(dim=1, keepdim=True).t()    # (1, n_2)
    # Broadcasting the column against the row gives the (n_1, n_2) sum.
    norm_sums = sq_norms_1 + sq_norms_2
    cross_terms = emb1.mm(emb2.t())
    distances_squared = norm_sums - 2 * cross_terms
    # abs() absorbs small negative values caused by floating-point rounding.
    return torch.sqrt(eps + torch.abs(distances_squared))