I am trying to reproduce AvatarCLIP in PyTorch.
And I got the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [1, 6890]], which is output 0 of SelectBackward, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
I would really appreciate it if you could help me.
def get_pose(self, text_feature: Tensor) -> Tensor:
    """Optimize a 32-D VPoser latent code so its decoded body pose matches *text_feature*.

    Args:
        text_feature: CLIP embedding of the target text description.

    Returns:
        The padded body pose decoded from the optimized latent code,
        detached from the autograd graph.
    """
    # Allocate the latent code directly on the target device. The original
    # created it on CPU and re-copied it with ``.to(self.device)`` on every
    # iteration, so the optimizer stepped a CPU leaf through a per-step
    # host->device transfer; creating it on-device removes that loop-invariant
    # copy while leaving gradients flowing to the same parameter.
    latent_code = nn.Parameter(torch.randn(32, device=self.device))
    optim_cls = getattr(torch.optim, self.optim_name)
    optimizer = optim_cls([latent_code], **self.optim_cfg)
    for _ in tqdm(range(self.num_iteration)):
        # Rebuild the full forward pass each step; no tensor from a previous
        # iteration is reused, so every backward sees fresh graph nodes.
        new_latent_code = latent_code.unsqueeze(0)
        new_pose = self.vp.decode(new_latent_code)['pose_body']
        new_pose = new_pose.contiguous().view(-1)
        clip_feature = self.get_pose_feature(new_pose).squeeze(0)
        # Cosine-similarity loss: 0 when the pose feature aligns with the text.
        loss = (1 - F.cosine_similarity(clip_feature, text_feature, dim=-1)).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # NOTE(review): the reported in-place RuntimeError names a [1, 6890]
    # tensor (SMPL vertex count) — presumably it originates inside
    # ``self.vp.decode`` / the SMPL forward or ``pose_padding``, not in this
    # loop; verify those helpers do not mutate tensors saved for backward.
    return pose_padding(new_pose.detach())
def get_topk_poses(self, text: str) -> Tensor:
    """Generate ``self.topk`` candidate poses for *text* and return them score-ranked.

    Args:
        text: Natural-language pose description.

    Returns:
        A tensor stacking the candidate poses along dim 0, ordered by score.
    """
    feature = self.get_text_feature(text)
    candidates = []
    for _ in range(self.topk):
        candidates.append(self.get_pose(feature))
    ranked = self.sort_poses_by_score(text, candidates)
    return torch.stack(ranked, dim=0)
def get_pose(self, text_feature: Tensor) -> Tensor:
    """Directly optimize a 63-D axis-angle body pose to match *text_feature*.

    Args:
        text_feature: CLIP embedding of the target text description.

    Returns:
        The padded optimized pose on ``self.device``, detached from the
        autograd graph.
    """
    # Allocate the pose on the target device up front: the original created
    # it on CPU and called ``pose.to(self.device)`` every iteration, a
    # loop-invariant transfer that made the optimizer update a CPU leaf
    # while the forward pass ran on the device.
    pose = nn.Parameter(torch.randn(63, device=self.device))
    optim_cls = getattr(torch.optim, self.optim_name)
    optimizer = optim_cls([pose], **self.optim_cfg)
    for _ in tqdm(range(self.num_iteration)):
        clip_feature = self.get_pose_feature(pose).squeeze(0)
        # Cosine-similarity loss: 0 when the pose feature aligns with the text.
        loss = (1 - F.cosine_similarity(clip_feature, text_feature, dim=-1)).mean()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # ``.detach()`` replaces the deprecated ``.data`` attribute, which
    # silently bypasses autograd's in-place version tracking — the exact
    # class of bug behind "modified by an inplace operation" errors.
    return pose_padding(pose.detach()).to(self.device)
def get_topk_poses(self, text: str) -> Tensor:
    """Sample ``self.topk`` poses for *text*, sort them by score, and stack them.

    Args:
        text: Natural-language pose description.

    Returns:
        A tensor of the sorted candidate poses stacked along dim 0.
    """
    text_emb = self.get_text_feature(text)
    pose_list = [self.get_pose(text_emb) for _ in range(self.topk)]
    sorted_poses = self.sort_poses_by_score(text, pose_list)
    stacked = torch.stack(sorted_poses, dim=0)
    return stacked