Hello,
My network uses a 2-value tensor (real and imaginary parts) to represent complex numbers. At an intermediate stage, I convert the numbers to a complex tensor to compute an FFT and then convert them back to a 2-value tensor to continue processing. During backpropagation I get a warning saying that casting complex numbers to float discards the imaginary part. I'm pretty sure I'm not casting complex values to real anywhere myself. Should I worry that some information is getting lost, or can I safely ignore the warning?
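For concreteness, here is a minimal sketch of the round trip described above. It assumes the real and imaginary parts sit in a trailing dimension of size 2 and that a 1-D FFT is taken over the preceding axis; the function name `fft_two_channel` and the shapes are hypothetical, not taken from the actual model. The conversion uses `torch.view_as_complex` / `torch.view_as_real`, which return views of the same memory rather than casting. The sketch then runs `torch.autograd.gradcheck` in double precision, which compares autograd's gradients against finite differences. That is one way to test whether gradient information is actually being lost.

```python
import torch
import torch.fft  # on PyTorch 1.7.x the torch.fft module must be imported explicitly


def fft_two_channel(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: FFT of a real tensor holding (real, imag) pairs.

    Assumed layout: x has shape (..., n, 2); the result has the same shape.
    """
    # view_as_complex needs a trailing dimension of size 2 with stride 1,
    # so make the tensor contiguous first.
    z = torch.view_as_complex(x.contiguous())  # shape (..., n), complex dtype
    z_hat = torch.fft.fft(z, dim=-1)           # FFT along the n axis
    return torch.view_as_real(z_hat)           # back to shape (..., n, 2)


# Compare autograd gradients with finite differences in double precision.
# If the complex round trip dropped gradient information, gradcheck would
# report a mismatch by raising an error.
x = torch.randn(3, 8, 2, dtype=torch.float64, requires_grad=True)
ok = torch.autograd.gradcheck(fft_two_channel, (x,))
print(ok)  # True if the gradients agree; gradcheck raises otherwise
```

Running the same check on the installed version (here 1.7.1) would show whether the warning corresponds to a real gradient error for this particular pipeline.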
The full trace of the warning is:
---------------------------------------------------------------------------
UserWarning Traceback (most recent call last)
Input In [9], in <cell line: 8>()
5 model = UNet(len(TRs)*2, 3)
6 loss_dc = DataConsistencyLoss(TRs, TEs, kmask, N, FOV, device=device)
----> 8 trained_model, loss_evol, loss_gt, loss_dc = train_model(model, dataloader, loss_dc, torch.nn.MSELoss(), alpha=1.0, lr=0.001, epochs=500)
Input In [8], in train_model(nn, data, loss_dc, loss_gt, alpha, lr, epochs)
17 loss = lgt + alpha*ldc
18 optimizer.zero_grad()
---> 19 loss.backward(retain_graph=True)
21 optimizer.step()
22 tepoch.set_postfix(Loss=loss.item(), Loss_gt=lgt.item(), Loss_dc=ldc.item())
File ~/miniconda3/envs/mrf/lib/python3.8/site-packages/torch/tensor.py:221, in Tensor.backward(self, gradient, retain_graph, create_graph)
213 if type(self) is not Tensor and has_torch_function(relevant_args):
214 return handle_torch_function(
215 Tensor.backward,
216 relevant_args,
(...)
219 retain_graph=retain_graph,
220 create_graph=create_graph)
--> 221 torch.autograd.backward(self, gradient, retain_graph, create_graph)
File ~/miniconda3/envs/mrf/lib/python3.8/site-packages/torch/autograd/__init__.py:130, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
127 if retain_graph is None:
128 retain_graph = create_graph
--> 130 Variable._execution_engine.run_backward(
131 tensors, grad_tensors_, retain_graph, create_graph,
132 allow_unreachable=True)
UserWarning: Casting complex values to real discards the imaginary part (Triggered internally at /pytorch/aten/src/ATen/native/Copy.cpp:162.)
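The message comes from PyTorch's tensor copy kernel (the `Copy.cpp` named in the trace), which emits it when a complex tensor is converted to a real dtype. The sketch below uses toy values, not data from the model, to show what such a conversion does. The cast keeps only the real part, while `torch.view_as_real` keeps both parts. Because the warning is raised inside `backward`, the cast is presumably happening implicitly in autograd rather than in the forward code. One possible source, not confirmed from this trace, is a gradient being converted back to the real dtype of a tensor that was promoted to complex in the forward pass.

```python
import torch

z = torch.tensor([1.0 + 2.0j, 3.0 - 4.0j])  # complex64 with the default dtype settings

# Complex -> real cast: this goes through the copy kernel and drops the imaginary part
# (it is also what triggers "Casting complex values to real discards the imaginary part").
r = z.to(torch.float32)
print(r)  # tensor([1., 3.])

# Reinterpreting the memory instead keeps both parts as a trailing dimension of size 2.
both = torch.view_as_real(z)
print(both)  # values [[1, 2], [3, -4]]

# Taking the real part explicitly gives the same values as the cast, without the warning.
print(torch.equal(r, z.real))
```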
If it helps, I am using an old version of PyTorch (1.7.1) because I can't update the CUDA drivers on my university's shared cluster.