# Build the train/test datasets and wrap each one in a loader.
# NOTE(review): the keyword is spelled "parttern" to match the Dataset
# constructor defined elsewhere (presumably) — confirm before renaming it.
_common_loader_opts = dict(batch_size=args.batch_size, num_workers=0, pin_memory=True)

train_set = Dataset(args.data_dir, parttern='train')
test_set = Dataset(args.data_dir, parttern='test')

train_loader = Dataloader(
    dataset=train_set,
    shuffle=True,
    drop_last=False,
    **_common_loader_opts,
)
# NOTE(review): the test loader is shuffled and drops its final partial batch.
# That is fine for sampling a random batch, but if test_loader is used for a
# full evaluation pass, some test samples are never scored — confirm intent.
test_loader = Dataloader(
    dataset=test_set,
    shuffle=True,
    drop_last=True,
    **_common_loader_opts,
)
# Main training loop: one iteration per epoch, resuming after start_epoch.
for i in tqdm(range(start_epoch + 1, args.max_epoch + 1)):
    # Put the model in training mode at the start of every epoch.
    model.train()
    # A fresh test iterator each epoch. It is not used in this chunk;
    # presumably it is consumed later in the epoch — confirm.
    # (The training loop below iterates train_loader directly: a `for` over a
    # DataLoader already creates a new, reshuffled iterator every epoch.)
    test_iter = iter(test_loader)

    # Running per-term loss sums for this epoch; logged once after the loop.
    epoch_loss_sums = {}
    num_batches = 0
    for origin, mask, inpaint in train_loader:
        # non_blocking lets the host-to-device copy overlap with compute; it
        # only takes effect because the loaders use pin_memory=True.
        origin = origin.to(device, non_blocking=True)
        mask = mask.to(device, non_blocking=True)
        inpaint = inpaint.to(device, non_blocking=True)

        result = model(origin, mask)
        loss_dict = inpaint_crit(origin, mask, result, inpaint)
        # The optimized objective is the unweighted sum of all loss terms.
        loss = sum(loss_dict.values())

        for key, value in loss_dict.items():
            epoch_loss_sums[key] = epoch_loss_sums.get(key, 0.0) + float(value)
        num_batches += 1

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Log one point per loss term per epoch (the batch mean) at step i.
    # Logging every batch at global_step=i would stack many values on the same
    # step instead of producing a curve.
    for key, total in epoch_loss_sums.items():
        writer.add_scalar('loss_{:s}'.format(key), total / num_batches, i)