Hello everybody! This neural network training code snippet uses the fftn function from the Jax library so that I can use mixed precision, as I cannot use mixed precision using PyTorch’s fftn function as there is restriction for that (input size must be power of 2, and in my case it is not). The same structure is used for the validation part and for the test.

I ask: because I don’t use PyTorch’s fftn is there a problem in my code regarding gradient calculations? If so, is this something serious that affects my network performance?

Thanks in advance!

```
for idx_epoch, epoch in enumerate(t):
print()
list_train_loss = []
list_train_loss_l1 = []
list_train_loss_l1_rel = []
list_train_l1_max = []
list_train_loss_fft_l1 = []
list_train_loss_fft_l1_rel = []
list_train_fft_l1_max = []
list_validation_loss = []
list_validation_loss_l1 = []
list_validation_loss_l1_rel = []
list_validation_l1_max = []
list_validation_loss_fft_l1 = []
list_validation_loss_fft_l1_rel = []
list_validation_fft_l1_max = []
list_test_loss = []
list_test_loss_l1 = []
list_test_loss_l1_rel = []
list_test_l1_max = []
list_test_loss_fft_l1 = []
list_test_loss_fft_l1_rel = []
list_test_fft_l1_max = []
#l1_distribution_validation = []
#l1_distribution_test = []
#model.train()
net.train()
scaler = torch.cuda.amp.GradScaler()
import jax
#os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION']='.50'
import time
for seq in seq_train:
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
X, y = seq
X = X.cuda()
y = y.cuda()
optimizer.zero_grad()
#output = model(X)
output = net(X)
#ffty = torch.fft.rfftn(y.float(), dim=(-1, -2, -3))
#fftout = torch.fft.rfftn(output.float(), dim=(-1, -2, -3))
#ffty = torch.fft.rfftn(y, dim=(-1, -2, -3))
#fftout = torch.fft.rfftn(output, dim=(-1, -2, -3))
ffty = jax.numpy.fft.fftn(y.detach().cpu().numpy(), axes=(-1,-2,-3))
ffty = torch.from_numpy(np.array(ffty))
fftout = jax.numpy.fft.fftn(output.detach().cpu().numpy(), axes=(-1,-2,-3))
fftout = torch.from_numpy(np.array(fftout))
#ffty = jax.numpy.fft.rfft(y.detach().cpu().numpy(), 1)
#ffty = torch.from_numpy(np.array(ffty))
#fftout = jax.numpy.fft.rfft(output.detach().cpu().numpy(), 1)
#fftout = torch.from_numpy(np.array(fftout))
#ffty = torch.fft.rfft(y, 1)
#fftout = torch.fft.rfft(output, 1)
fft_res = (fftout - ffty)
fft_res_abs = torch.abs(fft_res)
loss_fft_l1 = torch.mean(fft_res_abs)
loss_fft_l1_rel = torch.mean(fft_res_abs/(torch.abs(ffty) + 0.01))
fft_l1_max = torch.max(fft_res_abs)
res = (output - y)
res_abs = torch.abs(res)
loss_l1 = torch.mean(res_abs)
loss_l1_rel = torch.mean(res_abs/(torch.abs(y)+0.01))
l1_max = torch.max(res_abs)
loss = loss_l1_rel + loss_fft_l1_rel
#loss.backward()
scaler.scale(loss).backward()
#optimizer.step()
scaler.step(optimizer)
scaler.update()
list_train_loss.append(loss.cpu().item())
list_train_loss_l1.append(loss_l1.cpu().item())
list_train_loss_l1_rel.append(loss_l1_rel.cpu().item())
list_train_l1_max.append(l1_max.cpu().item())
list_train_loss_fft_l1.append(loss_fft_l1.cpu().item())
list_train_loss_fft_l1_rel.append(loss_fft_l1_rel.cpu().item())
list_train_fft_l1_max.append(fft_l1_max.cpu().item())
```