def PointNet_Net(self, fitted_params, images, landmarks, image_masks, occ_dict, savefolder, model_name='full', shape_gt=None):
    """
    Train PointNetMesh to adjust the coarse mesh vertices based on input features.
    """
    # Lazily initialize the model and its optimizer
    if not hasattr(self, 'pointnet'):
        self.pointnet = PointNetMesh(input_dim=fitted_params['vertices'].shape[2],
                                     num_points=fitted_params['vertices'].shape[1]).to('cuda:0')
    if not hasattr(self, 'optimizer_pointnet'):
        self.optimizer_pointnet = torch.optim.Adam(self.pointnet.parameters(), lr=1e-3)

    for step in range(2):  # set the desired number of iterations
        cam = fitted_params['cam']
        albedo = fitted_params['albedo'].expand(self.num_samples, -1, -1, -1)
        light = fitted_params['light'].expand(self.num_samples, -1, -1)
        pose = fitted_params['full_pose'].expand(self.num_samples, -1, -1)
        images = images.clone().expand(self.num_samples, -1, -1, -1)
        shape_in = fitted_params['shape']
        batch_size = shape_in.shape[0]
        images_occ = images * image_masks
        gt_landmarks = landmarks

        # Expand shape_gt if available, otherwise fall back to the coarse shape
        if shape_gt is not None:
            shape_gt = shape_gt.expand(self.num_samples, -1, -1)
        else:
            shape_gt = shape_in

        # Predict per-vertex displacements with PointNet and apply them
        global_features, point_features, displacement = self.pointnet(shape_in)
        vertices = shape_in + displacement
        shape_out = vertices.expand(self.num_samples, -1, -1)

        # Compute losses
        landmarks2d, landmarks3d = self.flame.get_landmarks(shape_out, pose)
        trans_verts = utils.batch_orth_proj(shape_out, cam)
        trans_verts[..., 1:] = -trans_verts[..., 1:]
        lmk_loss = utils.l2_distance(landmarks2d[:, :, :2], gt_landmarks[:, :, :2], conf=occ_dict['landmarks_valid'])
        ops = self.render(shape_out, trans_verts, albedo, light)
        pho_loss = (image_masks[0:1] * (ops['images'] - images[0:1]).abs()).mean()

        # Optimization step
        self.pointnet.train()
        self.optimizer_pointnet.zero_grad()
        loss_network = pho_loss * self.w_pho + lmk_loss * self.w_lmk
        loss_network.backward()
        self.optimizer_pointnet.step()

        # Logging
        print(f"PointNet iteration {step}, Pho Loss: {pho_loss.item()}, Lmk Loss: {lmk_loss.item()}, Total Loss: {loss_network.item()}")

    # Update fitted_params with the adjusted shape and the last point features
    fitted_params['shape'] = vertices.detach()
    fitted_params['point_features'] = point_features
def GCNN_Net(self, fitted_params, images, landmarks, image_masks, occ_dict, savefolder, model_name='full',
             shape_gt=None, image_features=None, hidden_layers=None, output_dim=None):
    """
    Train GCNN to refine mesh vertices based on fused features and adjacency matrix.
    """
    # Build the adjacency matrix once and restrict it to the visible vertices
    faces = fitted_params['faces']
    edges = get_mesh_edges(faces)
    adj_matrix = edges_to_adjacency(edges, fitted_params['vertices'].shape[1]).to('cuda:0')
    visibility_ind = torch.nonzero(fitted_params['vis_mask'][0], as_tuple=False).squeeze()
    keep_indices = torch.tensor([i for i in range(adj_matrix.size(0)) if i in visibility_ind]).to('cuda:0')
    refined_adj_matrix = adj_matrix[keep_indices][:, keep_indices].to('cuda:0')

    # Lazily initialize the model and its optimizer
    if not hasattr(self, 'gcnn'):
        self.gcnn = Net(
            A=refined_adj_matrix,
            nfeat=fitted_params['point_features'].shape[-1] + image_features.shape[1],
            nhid=int(hidden_layers),  # guard against float inputs
            nout=int(output_dim)
        ).to('cuda:0')
    if not hasattr(self, 'optimizer_gcnn'):
        self.optimizer_gcnn = torch.optim.Adam(self.gcnn.parameters(), lr=1e-3)

    for step in range(2):  # set the desired number of iterations
        cam = fitted_params['cam']
        albedo = fitted_params['albedo'].expand(self.num_samples, -1, -1, -1)
        light = fitted_params['light'].expand(self.num_samples, -1, -1)
        pose = fitted_params['full_pose'].expand(self.num_samples, -1, -1)
        images = images.clone().expand(self.num_samples, -1, -1, -1)
        vertices = fitted_params['vertices']
        gt_landmarks = landmarks

        # Sample image features at the projected vertex locations and fuse them
        # with the stored PointNet features
        proj_vertices = utils.batch_orth_proj(vertices, cam)
        proj_vertices[..., 1:] = -proj_vertices[..., 1:]
        proj_vertices_norm = proj_vertices[:, :, :-1].unsqueeze(2)
        extracted_features = F.grid_sample(
            image_features, proj_vertices_norm, mode='bilinear', align_corners=True
        ).squeeze(3).permute(0, 2, 1)  # shape: [B, N, C]
        fused_features = torch.cat([fitted_params['point_features'], extracted_features], dim=-1)

        # Restrict to visible vertices
        vis_fused_features = fused_features[:, fitted_params['vis_mask'].squeeze(0), :].to('cuda:0')
        vis_vertices = vertices[:, fitted_params['vis_mask'].squeeze(0), :].to('cuda:0')

        # Forward pass through the GCNN predicts residual offsets
        residual_vertices = self.gcnn(vis_fused_features[0])
        refined_vertices = residual_vertices + vis_vertices

        # Scatter the refined visible vertices back into the full mesh
        vertices_copy = vertices.clone()
        vertices_copy[:, fitted_params['vis_mask'].squeeze(0), :] = refined_vertices.view(vertices.shape[0], -1, 3)
        shape_out_gcnn = vertices_copy.expand(self.num_samples, -1, -1)

        # Compute losses
        landmarks2d, landmarks3d = self.flame.get_landmarks(shape_out_gcnn, pose)
        trans_verts = utils.batch_orth_proj(shape_out_gcnn, cam)
        trans_verts[..., 1:] = -trans_verts[..., 1:]
        lmk_loss = utils.l2_distance(landmarks2d[:, :, :2], gt_landmarks[:, :, :2], conf=occ_dict['landmarks_valid'])
        ops = self.render(shape_out_gcnn, trans_verts, albedo, light)
        pho_loss = (image_masks[0:1] * (ops['images'] - images[0:1]).abs()).mean()

        # Optimization step
        self.gcnn.train()
        self.optimizer_gcnn.zero_grad()
        loss_gcnn = pho_loss * self.w_pho + lmk_loss * self.w_lmk
        loss_gcnn.backward()
        self.optimizer_gcnn.step()

        # Logging
        print(f"GCNN iteration {step}, Pho Loss: {pho_loss.item()}, Lmk Loss: {lmk_loss.item()}, Total Loss: {loss_gcnn.item()}")

    # Update fitted_params with the adjusted shape
    fitted_params['shape'] = vertices_copy.detach()
These are my models, and I am running them back to back with:

self.PointNet_Net(fitted_params, images, landmarks, image_masks, occ_dict, savefolder, model_name='full', shape_gt=None)
self.GCNN_Net(fitted_params, images, landmarks, image_masks, occ_dict, savefolder, model_name='full', shape_gt=None, image_features=image_features, hidden_layers=8, output_dim=3)

I am getting the following error:
._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
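
For reference, here is a minimal, self-contained sketch that reproduces what I believe is the same failure mode in plain PyTorch (net_a, net_b, and feats are illustrative names, not from my code): a tensor that still carries an already-backpropagated graph is reused in a second loss.

import torch

net_a = torch.nn.Linear(4, 4)
net_b = torch.nn.Linear(4, 1)

x = torch.randn(2, 4)
feats = net_a(x)             # feats still carries net_a's autograd graph

loss_a = feats.sum()
loss_a.backward()            # frees the saved tensors of net_a's graph

loss_b = net_b(feats).sum()  # loss_b references the freed graph via feats
# loss_b.backward()          # -> RuntimeError: Trying to backward through
#                            #    the graph a second time ...

loss_b = net_b(feats.detach()).sum()  # cutting the old graph avoids the error
loss_b.backward()            # fine: gradients flow only through net_b

I suspect fitted_params['point_features'] plays the role of feats here, since it is stored at the end of PointNet_Net without detach() and then concatenated into the GCNN input, so loss_gcnn.backward() would walk back into the PointNet graph that loss_network.backward() has already freed. If that is the cause, storing point_features.detach() should be the analogous fix, but I would like confirmation that this is what the traceback points to.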