In my model, I have a critic and I use the Wasserstein distance with gradient penalty. Here is the loss function in the model class:
def wgan_gp_reg(self, x_real, x_fake, center=1., lambda_gp=10.0):
batch_size = x_real.shape[0]
eps = torch.rand(batch_size, 1, 1, device=self.device, dtype=x_real.dtype)
eps = eps.expand_as(x_real)
#eps = torch.randn_like(x_real).to(self.device)
x_interp = torch.autograd.Variable((eps * x_real + (1 - eps) * x_fake), requires_grad=True)
d_out = self.discriminator(x_interp)
gradients = torch.autograd.grad(inputs = x_interp,
outputs = d_out,
grad_outputs = torch.ones_like(d_out, device=self.device),
create_graph = True,
retain_graph = True,
)[0]
gradients = gradients.view(gradients.size(), -1)
gradient_penalty = (((gradients + 1e-16).norm(2, dim=1) - center) ** 2).mean() * lambda_gp
return gradient_penalty
I also use torch.autocast
and torch.cuda.amp.GradScaler
together during training
def train(epoch):
modelstate.model.train()
total_loss = 0
total_batches = 0
total_points = 0
if torch.cuda.is_available():
scaler = torch.cuda.amp.GradScaler()
for i, (u, y) in enumerate(loader_train):
u = u.to(device)
y = y.to(device)
modelstate.optimizer.zero_grad()
if torch.cuda.is_available():
with torch.autocast(device_type='cuda', dtype=torch.float32) and torch.backends.cudnn.flags(enabled=False):
loss_ = modelstate.model(u, y)
scaled_grad_params = torch.autograd.grad(outputs=scaler.scale(loss_),
inputs=modelstate.model.parameters(),
create_graph=True,
#retain_graph=True,
allow_unused=True #Whether to allow differentiation of unused parameters.
)
inv_scale = 1./scaler.get_scale()
grad_params = [ p * inv_scale if p is not None and not torch.isnan(p).any() else torch.tensor(0, device=device, dtype=torch.float32) for p in scaled_grad_params ]
with torch.autocast(device_type='cuda', dtype=torch.float32):
#grad_norm = torch.tensor(0, device=grad_params[0].device, dtype=grad_params[0].dtype)
grad_norm = 0
for grad in grad_params:
grad_norm += grad.pow(2).sum()
grad_norm = grad_norm**0.5
# Compute the L2 Norm as penalty and add that to loss
loss_ = loss_ + grad_norm
However, I got this error
→ 2827 gradient_penalty = self.wgan_gp_reg(input_feature, fake_input_feature)
2828 d_loss = -torch.mean(disc_real) + torch.mean(disc_fake) + gradient_penalty
2829 total_loss += d_loss/tmp/ipykernel_59758/571975632.py in wgan_gp_reg(self, x_real, x_fake, center, lambda_gp)
2691 d_out = self.discriminator(x_interp)
2692
→ 2693 gradients = torch.autograd.grad(inputs = x_interp,
2694 outputs = d_out,
2695 grad_outputs = torch.ones_like(d_out, device=self.device),~/anaconda3/lib/python3.9/site-packages/torch/autograd/init.py in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched)
298 return _vmap_internals.vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs)
299 else:
→ 300 return Variable.execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
301 t_outputs, grad_outputs, retain_graph, create_graph, t_inputs,
302 allow_unused, accumulate_grad=False) # Calls into the C++ engine to run the backward passRuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I tried the same critic with the given wasserstein gradient penalty module when I did not use torch.autocast
without any error. I am wondering whether this error related to this and how I can integrate both together?