I tried to train my encoder with PatchGAN, I will calculate losses for outputs from every layer of the encoder. So this result in the memory leak? my training scripts is below
def main(args):
# loading the pretrained model
model_path = os.path.join("models", args.model_name)
encoder_path = os.path.join(model_path, "encoder.pth")
src_model = networks.ResnetEncoder(18, False)
loaded_dict_src = torch.load(encoder_path, map_location=device)
depth_decoder_path = os.path.join(model_path, "depth.pth")
feed_height = loaded_dict_src['height']
feed_width = loaded_dict_src['width']
# feed_height = 256
# feed_width = 512
filtered_dict_src = {k: v for k, v in loaded_dict_src.items() if k in src_model.state_dict()}
src_model.load_state_dict(filtered_dict_src)
src_model.to(device)
# load target model
target_model = networks.ResnetEncoder(18, False)
target_model.load_state_dict(filtered_dict_src)
target_model.to(device)
# load discriminators
discriminator = networks.Discriminator(input_nc=3).to(device)
discriminators = []
# load two datasets
day_path = args.day_path
night_path = args.night_path
half_batch = args.batch_size // 2
src_dataset = datasets.OxfordRAWDataset(day_path, feed_width, feed_height, transform=None)
src_dataloader = DataLoader(src_dataset, batch_size=half_batch, shuffle=True)
tgt_dataset = datasets.OxfordRAWDataset(night_path,feed_width, feed_height, transform=None)
tgt_dataloader = DataLoader(tgt_dataset, batch_size=half_batch, shuffle=True)
len_data_loader = min(len(src_dataloader), len(tgt_dataloader))
# define optimizers and criterion
# discriminator_optim = torch.optim.Adam(discriminator.parameters())
discriminator_optims = []
target_optim = torch.optim.Adam(target_model.parameters())
criterion = nn.BCEWithLogitsLoss()
target_model.train()
# discriminator.train()
# train
for epoch in range(1, args.epoch):
data_zip = enumerate(zip(src_dataloader, tgt_dataloader))
for step, (img_src, img_tgt) in data_zip:
#####################
# train discriminator
#####################
img_src = img_src.cuda()
img_tgt = img_tgt.cuda()
# extract features
src_features = src_model(img_src)
tgt_features = target_model(img_tgt)
len_features = len(src_features)
loss_disc = 0
# for every layer
for i, (src_feature, tgt_feature) in enumerate(zip(src_features, tgt_features)):
discriminators.append(networks.Discriminator(input_nc=src_feature.size()[1]).to(device))
# concat features, prepare real and fake labels
discriminator_x = torch.cat([src_feature, tgt_feature])
discriminator_y = torch.cat([torch.ones(img_src.size(0), device=device),
torch.zeros(img_tgt.size(0), device=device)])
discriminator_optims.append(torch.optim.Adam(discriminators[i].parameters()))
discriminators[i].train()
# predict on discrminator
preds = discriminators[i](discriminator_x)
preds_mean = torch.mean(preds, axis=[2,3]).squeeze()
# compute loss
loss_disc += criterion(preds_mean, discriminator_y)
# zero gradient
for i in range(len_features):
discriminator_optims[i].zero_grad()
loss_disc /= len_features
loss_disc.backward()
# optimize discriminator
for i in range(len_features):
discriminator_optims[i].step()
# total_loss += loss.item()
######################
# train target encoder
######################
# zero gradient
for i in range(len_features):
discriminator_optims[i].zero_grad()
target_optim.zero_grad()
# extract the features
tgt_features = target_model(img_tgt)
discriminator_y = torch.ones(img_tgt.size(0), device=device)
# prepare fake labels
loss_tgt = 0
for i, tgt_feature in enumerate(tgt_features):
# predicts on discriminator
preds = discriminators[i](tgt_feature)
# compute loss for target encoder
preds_mean = torch.mean(preds, axis=[2,3]).squeeze()
loss_tgt += criterion(preds_mean, discriminator_y)
loss_tgt /= len_features
loss_tgt.backward()
target_optim.step()
#################
# print step info
#################
if ((step+1) % 20 == 0):
print("Epoch [{}/{}] Step [{}/{}]: d_loss={:.5f}, g_loss={:.5f}"
.format(epoch+1,
args.epoch,
step+1,
len_data_loader,
loss_disc.item(),
loss_tgt.item()))
torch.cuda.empty_cache()
torch.save(target_model.state_dict(), 'models/adda_encoder.pt')
torch.save(discriminator.state_dict(), 'models/discriminator.pt')