class Net(nn.Module):
def __init__(self,network):
super(Net,self).__init__()
self.features0 = nn.Sequential(
network.model.conv1,
network.model.bn1,
network.model.relu,
network.model.maxpool
)
self.features1 = nn.Sequential(network.model.layer1)
self.features2 = nn.Sequential(network.model.layer2)
self.features3 = nn.Sequential(network.model.layer3)
self.features4 = nn.Sequential(network.model.layer4)
def forward(self, x):
x= self.features0(x)
layer1 = self.features1(x)
layer2 = self.features2(layer1)
layer3 = self.features3(layer2)
layer4 = self.features4(layer3)
a = torch.cat((F.interpolate(layer4,[8,8],mode='bilinear',align_corners=False),layer3),dim=1)
b= torch.cat((F.interpolate(a,[16,16],mode='bilinear',align_corners=False),layer2),dim=1)
c = torch.cat((F.interpolate(b,[32,32],mode='bilinear',align_corners=False),layer1),dim=1)
d = F.interpolate(c,[128,128],mode='bilinear',align_corners=False)
return d
def FE(d,landmarks,batch_size):
visual_feature = F.grid_sample(d,landmarks.unsqueeze(2),mode='bilinear',padding_mode='zeros',align_corners=False)[:,:,:,0]
visual_feature = visual_feature.permute(0,2,1)
init_landmark = landmarks[:,None,:,:]-landmarks[:,:,None,:]
shape_feature = init_landmark.reshape(batch_size, landmarks.shape[1],-1)
return torch.cat([visual_feature,shape_feature],-1)
class GIN(nn.Module):
def __init__(self,in_ch,out_ch,num_landmark=3):
super(GIN, self).__init__()
self.adj = torch.nn.Parameter((torch.ones((num_landmark, num_landmark), requires_grad=True) / num_landmark), requires_grad=True)
self.linear1 = nn.Linear(in_ch,in_ch)
self.relu1 = nn.LeakyReLU(inplace=True)
self.linear2 = nn.Linear(in_ch, out_ch)
def forward(self, node_feat):
nd = node_feat
message = torch.matmul(self.adj, nd)+nd
x = self.linear1(message)
x= self.relu1(x)
x= self.linear2(x)
return x
class GCN(nn.Module):
def __init__(self,in_ch=966,out_ch=9,batch_size=8):
super(GCN,self).__init__()
self.net = Net(Network())
self.gin1 = GIN(in_ch,in_ch)
self.gin2 = GIN(in_ch,int(in_ch/2))
self.gin3 = GIN(int(in_ch/2),int(in_ch/2))
self.lin = nn.Linear(in_ch+int(in_ch/2)+int(in_ch/2),out_ch)
def forward(self, image, init_land):
prediction = self.net(image)
predictions = FE(prediction, init_land, batch_size)
layer0=self.gin1(predictions)
layer1=self.gin2(layer0)
layer2=self.gin3(layer1)
layer0_sum = torch.sum(layer0, dim=1)
layer1_sum = torch.sum(layer1, dim=1)
layer2_sum = torch.sum(layer2, dim=1)
cat = torch.cat([layer0_sum, layer1_sum, layer2_sum], -1)
out = self.lin(cat)
return out, prediction
class GIN_local(nn.Module):
def __init__(self, steps=3, h_dim=966, num_landmarks=3, batch_size = 8):
super(GIN_local,self).__init__()
self.steps = steps
for step in range(self.steps):
if step == 0:
self.gnns = nn.ModuleList(
[GIN(h_dim, 2)])
else:
self.gnns.append(GIN(h_dim, 2))
def forward(self,d,x):
for step in range(self.steps):
shift=self.gnns[step](FE(d, x, batch_size))
updated_landmarks = x + shift
y = updated_landmarks
return y
The model that I made above
def make_perspective_transform(global_t, init_land):
global_t = global_t.view(batch_size,3,3)
m = nn.ConstantPad1d((0, 1),1)
init_land_1 = m(init_land)
init_land_2 = init_land_1.permute(0,2,1)
p_land = torch.matmul(global_t,init_land_2.cuda()) # [ rx ry r]t
p_land = p_land.permute(0,2,1)
p_land[:, :, 0] = p_land[:, :, 0] / p_land[:, :, 2]
p_land[:, :, 1] = p_land[:, :, 1] / p_land[:, :, 2]
pt_land = p_land[:,:,:2]
return pt_land
#gcn_in_ch = 966
gcn = GCN().cuda()
gin_local = GIN_local().cuda()
torch.autograd.set_detect_anomaly(True)
criterion = nn.MSELoss()
optimizer_global = optim.AdamW(gcn.parameters(), lr=3e-4,weight_decay=0.00001)
optimizer_local = optim.AdamW(gin_local.parameters(),lr=3e-4,weight_decay=0.00001)
#scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=optimizer_global, patience=3,threshold=e-10,factor=0.9)
loss_min = np.inf
num_epochs = 200
start_time = time.time()
loss_train_list =[]
loss_valid_list=[]
train_init_land = torch.tensor([]).type(torch.cuda.FloatTensor)
for epoch in range(1,num_epochs+1):
loss_train = 0
loss_valid = 0
running_loss = 0
Network().train()
gcn.train()
gin_local.train()
for step in range(1,len(train_loader)+1):
images, landmarks, init_land = next(iter(train_loader))
images = images.cuda()
landmarks = landmarks.reshape(landmarks.shape[0],-1).cuda()
init_land = init_land.cuda()
if step == 1:
train_init_land = torch.cat([train_init_land,init_land])
global_t, prediction = gcn(images,init_land) #(batch_size,9) / extracted feature map
pt_land = make_perspective_transform(global_t, init_land)
loss_global = criterion(pt_land.reshape(pt_land.shape[0],-1), landmarks)
loss_global.backward()
optimizer_global.step()
outputs = gin_local(prediction, pt_land)
outputs = outputs.reshape(outputs.shape[0],-1)
# clear all the gradients before calculating them
optimizer_local.zero_grad()
# find the loss for the current step
loss_local = criterion(outputs, landmarks)
# calculate the gradients
loss_local.backward()
# update the parameters
optimizer_local.step()
loss_train_step = loss_global + loss_local
loss_train += loss_train_step.item()
running_loss = loss_train/step
print_overwrite(step, len(train_loader), running_loss, 'train')
Network().eval()
gcn.eval()
net.eval()
gin_local.eval()
with torch.no_grad():
for step in range(1,len(valid_loader)+1):
images, landmarks= next(iter(valid_loader))
images = images.cuda()
landmarks = landmarks.reshape(landmarks.shape[0],-1).cuda()
train_init_land = train_init_land.cuda()
global_t, prediction = gcn(images, train_init_land) #(batch_size,9) / extracted feature map
pt_land = make_perspective_transform(global_t, train_init_land)
loss_global = criterion(pt_land.reshape(pt_land.shape[0],-1), landmarks)
outputs = gin_local(prediction, pt_land)
outputs = outputs.reshape(outputs.shape[0],-1)
# find the loss for the current step
loss_local = criterion(outputs, landmarks)
loss_valid_step = loss_global + loss_local
loss_valid += loss_valid_step.item()
running_loss = loss_valid/step
print_overwrite(step, len(valid_loader), running_loss, 'valid')
loss_train /= len(train_loader)
loss_valid /= len(valid_loader)
#scheduler.step(loss_valid)
print('\n--------------------------------------------------')
print('Epoch: {} Train Loss: {:.4f} Valid Loss: {:.4f}'.format(epoch, loss_train, loss_valid))
print('--------------------------------------------------')
loss_train_list.append(loss_train)
loss_valid_list.append(loss_valid)
plt.plot(loss_train_list,label='train loss')
plt.plot(loss_valid_list,label='valid loss')
plt.legend(loc='upper right')
plt.title('epoch: %d '%(epoch))
plt.pause(.0001)
if loss_valid < loss_min:
loss_min = loss_valid
torch.save(network.state_dict(), '/content/drive/MyDrive/IXIDB_axial_model_save/210511_gnn_param.pth')
print("\nMinimum Validation Loss of {:.4f} at epoch {}/{}".format(loss_min, epoch, num_epochs))
print('Model Saved\n')
print('Training Complete')
print("Total Elapsed Time : {} s".format(time.time()-start_time))
This is my training cell but I got RunTimeError on loss_global.backward()
I made code based on this, https://arxiv.org/pdf/2004.08190.pdf
please let me know how to solve it