# Custom loss function detached from the graph

Hello,
following is my code with network combining CNN and a fully-connected net. This is a physics-informed neural network. So, I defined custom loss functions including a regression loss function and e PDE loss function. I’m getting None value for the model grads after and before backward step. I believe this is due to how I defined the loss function. I would appreciate it if anyone could help me with this issue.

``````
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
import warnings

class CNN_FC(nn.Module):
def __init__(self, in_features=2, out_features=3, nf=13,
activation=torch.nn.Tanh, cnn_activation=torch.nn.ReLU):
"""Initialization

Args:
in_channels: int, number of input channels.
neck_channels: int, number of channels in bottleneck layer.
out_channels: int, number of output channels.
final_relu: bool, add relu to the last layer.
"""
super(CNN_FC, self).__init__()
self.nf = nf
self.in_features = in_features
self.out_features = out_features
self.activ = activation()
self.cnn_activ = cnn_activation()

self.conv_in = nn.Conv3d(self.in_features, self.nf, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='valid')
self.conv11 = nn.Conv3d(self.nf, self.nf*2, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
self.conv12 = nn.Conv3d(self.nf*2, self.nf*3, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
self.conv13 = nn.Conv3d(self.nf*3, self.nf*6, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
self.conv14 = nn.Conv3d(self.nf*6, self.nf*13, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
self.convs = [self.conv_in, self.conv11, self.conv12, self.conv13, self.conv14]

self.convs = nn.ModuleList(self.convs)

self.maxpool11 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
self.maxpool12 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
self.maxpool13 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
self.maxpool14 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(1, 1, 2), padding=0, dilation=1, ceil_mode=False)
self.maxpools = [self.maxpool11, self.maxpool12, self.maxpool13, self.maxpool14]
self.maxpools = nn.ModuleList(self.maxpools)

self.flatten1 = nn.Flatten()
self.flatten = [self.flatten1]
self.flatten = nn.ModuleList(self.flatten)

self.fc0 = nn.Linear(679, nf*64)
self.fc1 = nn.Linear(nf*64 , nf*32)
self.fc2 = nn.Linear(nf*32 , nf*16)
self.fc3 = nn.Linear(nf*16 , nf*8)
self.fc4 = nn.Linear(nf*8 , nf*4)
self.fc5 = nn.Linear(nf*4, out_features)
self.fc = [self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5]
self.fc = nn.ModuleList(self.fc)

def forward(self, c, t, y, x):

# first entry of x should be batch size!  x = torch.randn(batch_size,numpoints,1)
# c is the output of convolutional prtion
c = self.conv_in(c)
c = self.cnn_activ(c)
c = self.conv11(c)
c = self.cnn_activ(c)
c = self.maxpool11(c)
c = self.conv12(c)
c = self.cnn_activ(c)
c = self.maxpool12(c)
c = self.conv13(c)
c = self.cnn_activ(c)
c = self.maxpool13(c)
c = self.conv14(c)
c = self.cnn_activ(c)
c = self.maxpool14(c)
c = self.flatten1(c)
# print(c.shape)

c = c.unsqueeze(1)
c = c.repeat(1,int(x.shape[1]),1)

x_tmp = torch.cat((c, t, y, x), dim=-1)
x_tmp = self.fc0(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc1(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc2(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc3(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc4(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc5(x_tmp)

return x_tmp

``````
``````import os
import torch
from torch.utils.data import Dataset, Sampler
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy import ndimage
import warnings

# in __init__ load lres data (entire data, generated in matlab)
#

class cylinder_dataset(Dataset):

def __init__(self, data_dir="./", data_filename="CFD_DATA_space_avg_before_cyl.mat", center_data_filename="CFD_DATA_centers_before_cyl.mat",
nx=26, ny=26, nt=3, n_samp_pts_per_crop=2028, normalize_output=True):

self.nt = nt
self.nx = nx
self.ny = ny
self.data_dir = data_dir
self.data_filename = data_filename
self.center_data_filename = center_data_filename
self.n_samp_pts_per_crop = n_samp_pts_per_crop
self.normalize_output = normalize_output

################## load low res data (training data)
self.data = np.stack([npdata['u_space_avg'], npdata['v_space_avg']], axis=0)    # it seems like each of p, b, u, and w are already [200, 512, 128] shaped arrays
self.data = self.data.astype(np.float32)
#self.data = self.data.transpose(0, 3, 2, 1)  # [c, t, y, x] # c is number of channels
nc_data, nt_data, ny_data, nx_data = self.data.shape
# if highres: nc = 2, nt = 1501, ny = 80, nx = 640
# if lowres (by factor of 5): nc=2, nt=1501, ny=16, nx=128

self.nx_start_range = np.arange(0, nx_data-nx+1)    # 0 to 560 fo high res, 0 to 112 for low res
self.ny_start_range = np.arange(0, ny_data-ny+1)    # returns an array shape (1,) ([0])
self.nt_start_range = np.arange(0, nt_data-nt+1)    # 0 to 1498
self.rand_grid = np.stack(np.meshgrid(self.nt_start_range,
self.ny_start_range,
self.nx_start_range, indexing='ij'), axis=-1)

self.rand_start_id = self.rand_grid.reshape([-1, 3])

self.num_samples = self.rand_start_id.shape[0]

################## load center data (fidelity data)
self.center_data = np.stack([center_npdata['u_space_center'], center_npdata['v_space_center']], axis=0)    # it seems like each of p, b, u, and w are already [200, 512, 128] shaped arrays
self.center_data = self.center_data.astype(np.float32)

# compute channel-wise mean and std
self._mean = np.mean(self.data, axis=(1, 2, 3))
self._std = np.std(self.data, axis=(1, 2, 3))

def __getitem__(self, index):
t_id, y_id, x_id = self.rand_start_id[index]    # idx is the id of the crop of the data that is passed to Dataloader.
space_time_crop = self.data[:,
t_id:t_id+self.nt,
y_id:y_id+self.ny,
x_id:x_id+self.nx]  # [c, t, y, x] c is the channel

space_time_crop = np.transpose(space_time_crop, (0, 2, 3, 1))

center_space_time_crop = self.center_data[:,
t_id:t_id+self.nt,
y_id:y_id+self.ny,
x_id:x_id+self.nx]  # [c, t, y, x] c is the channel

center_space_time_crop = np.transpose(center_space_time_crop, (0, 2, 3, 1)) # [c, y, x, t]
center_space_time_crop_reshaped = np.reshape(center_space_time_crop.transpose(3, 1, 2, 0), (-1, 2))   # [t,y,x,c]

lres_coord = np.stack(np.meshgrid(np.linspace(0, self.nt-1, self.nt),
np.linspace(0, self.ny-1, self.ny),
np.linspace(0, self.nx-1, self.nx),
indexing='ij'), axis=-1)
lres_coord = np.reshape(lres_coord, (self.nt*self.nx*self.ny,3))
# lres_coord returns an array that includes coordinates of points such that, [0,:,:,:] is the coordinates of the first slice (t=0) that marches first on x (index 2) then on y (index 1)

hres_coord = np.stack(np.meshgrid(np.linspace(0, self.nt-1, self.nt),
np.linspace(0, self.ny*3-1, self.ny*3),
np.linspace(0, self.nx*3-1, self.nx*3),
indexing='ij'), axis=-1)

hres_coord = np.reshape(hres_coord, (self.nt*self.nx*3*self.ny*3,3))

pde_mesh_coord = np.stack(np.meshgrid(np.random.rand(10)*[self.nt-1],
np.random.rand(300)*[self.ny-1],
np.random.rand(300)*[self.nx-1],
indexing='ij'), axis=-1)
# creates an array including 10 slices in time with 300*300 random points in y and x

pde_point_coord = np.random.rand(self.n_samp_pts_per_crop, 3) * (np.array([3, 26, 26], dtype=np.int32) - 1)
# return an array with three columns such that first column is 3000 randon number between zero and nt-1, second column is 3000 points between zero and ny, third is between zero and nx

if self.normalize_output:
space_time_crop = self.normalize_grid(space_time_crop)
center_space_time_crop = self.normalize_grid(center_space_time_crop)
center_space_time_crop_reshaped = np.reshape(center_space_time_crop.transpose(3, 1, 2, 0), (-1, 2))
pde_point_coord = np.random.rand(self.n_samp_pts_per_crop, 3)
lres_coord = lres_coord / (np.array([3, 26, 26], dtype=np.int32) - 1)
hres_coord = hres_coord / (np.array([3, 78, 78], dtype=np.int32) - 1)

return_tensors = [space_time_crop, center_space_time_crop, center_space_time_crop_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord]

# cast everything to float32
return_tensors = [t.astype(np.float32) for t in return_tensors]

return tuple(return_tensors)

def __len__(self):
return self.rand_start_id.shape[0]

@property
def channel_mean(self):
"""channel-wise mean of dataset."""
return self._mean

@property
def channel_std(self):
"""channel-wise mean of dataset."""
return self._std

@staticmethod
def _normalize_array(array, mean, std):
"""normalize array (np or torch)."""
if isinstance(array, torch.Tensor):
dev = array.device
std = torch.tensor(std, device=dev)
mean = torch.tensor(mean, device=dev)
return (array - mean) / std

@staticmethod
def _denormalize_array(array, mean, std):
"""normalize array (np or torch)."""
if isinstance(array, torch.Tensor):
dev = array.device
std = torch.tensor(std, device=dev)
mean = torch.tensor(mean, device=dev)
return array * std + mean

def normalize_grid(self, grid):
"""Normalize grid.
Args:
grid: np array or torch tensor of shape [2, ...], 2 are the num. of phys channels.
Returns:
channel normalized grid of same shape as input.
"""
# reshape mean and std to be broadcastable.
g_dim = len(grid.shape)
mean_bc = self.channel_mean[(...,)+(None,)*(g_dim-1)]  # unsqueeze from the back
std_bc = self.channel_std[(...,)+(None,)*(g_dim-1)]  # unsqueeze from the back
return self._normalize_array(grid, mean_bc, std_bc)

def denormalize_grid(self, grid):
"""Denormalize grid.
Args:
grid: np array or torch tensor of shape [2, ...], 2 are the num. of phys channels.
Returns:
channel denormalized grid of same shape as input.
"""
# reshape mean and std to be broadcastable.
g_dim = len(grid.shape)
mean_bc = self.channel_mean[(...,)+(None,)*(g_dim-1)]  # unsqueeze from the back
std_bc = self.channel_std[(...,)+(None,)*(g_dim-1)]  # unsqueeze from the back
return self._denormalize_array(grid, mean_bc, std_bc)

``````
``````import torch.nn.functional as F
#===============================================================================
# fidelity_loss
#===============================================================================
# criterion = F.mse_loss()
criterion = torch.nn.MSELoss()
def fidelity_loss(model, input_image, t, y, x, fid_image_tensor):

output_tensor = model(input_image, t, y, x)  # output shape: ([batch, num_points, 3]) take first out is u, then v, then p

fid_tensor_u = fid_image_tensor[:,:,0]
fid_tensor_v = fid_image_tensor[:,:,1]

output_u = output_tensor[:,:,0]
output_v = output_tensor[:,:,1]

# fid_u_loss = torch.mean(torch.square(output_u-fid_tensor_u))
# fid_v_loss = torch.mean(torch.square(output_v-fid_tensor_v))

fid_u_loss = criterion(output_u, fid_tensor_u)
fid_v_loss = criterion(output_v, fid_tensor_v)

return fid_u_loss + fid_v_loss

#===============================================================================
# pde_loss
#===============================================================================
def pde_loss_f(model, input_image, t, y, x):
output_tensor = model(input_image, t, y, x) # [batch, num_points, 3]

U = output_tensor[:,:,0].unsqueeze(-1)
V = output_tensor[:,:,1].unsqueeze(-1)
P = output_tensor[:,:,2].unsqueeze(-1)

momentum_x = (U_t) - ((1/160)*(U_xx + U_yy)) + (U*U_x + V*U_y) + (P_x)
momentum_x_loss = criterion(momentum_x, torch.zeros_like(momentum_x))
# momentum_x_loss = torch.mean(torch.square(
#     (U_t) - ((1/160)*(U_xx + U_yy)) + (U*U_x + V*U_y) + (P_x)
#     ))

momentum_y = (V_t) - ((1/160)*(V_xx + V_yy)) + (U*V_x + V*V_y) + (P_y)
momentum_y_loss = criterion(momentum_y, torch.zeros_like(momentum_y))
# momentum_y_loss = torch.mean(torch.square(
#     (V_t) - ((1/160)*(V_xx + V_yy)) + (U*V_x + V*V_y) + (P_y)
#     ))

continuity = (U_x) + (V_y)
continuity_loss = criterion(continuity, torch.zeros_like(continuity))
# continuity_loss = torch.mean(torch.square(
#     (U_x) + (V_y)
#     ))

# f'{nt}*dif(u,t)  -  {Re_inv}*(({nx})**2*dif(dif(u,x),x)+({ny})**2*dif(dif(u,y),y))  +  (u*{nx}*dif(u,x)+v*{ny}*dif(u,y))  +  dif(p,x)',
# f'{nt}*dif(v,t)  -  {Re_inv}*(({nx})**2*dif(dif(v,x),x)+({ny})**2*dif(dif(v,y),y))  +  (u*{nx}*dif(v,x)+v*{ny}*dif(v,y))  +  dif(p,y)',
# f'{nx} * dif(u, x) + {ny} * dif(v, y)')

return momentum_x_loss + momentum_y_loss + continuity_loss
``````
``````import argparse
import json
import os
from glob import glob
import numpy as np
from collections import defaultdict
np.set_printoptions(precision=4)

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.tensorboard import SummaryWriter
import pylab
from time import time
import matplotlib.pyplot as plt

tot_loss_list = []
reg_loss_list = []
pde_loss_list = []

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch_size = int(torch.cuda.device_count()) * int(16)

# log and create snapshots
# os.makedirs("./log_dir", exist_ok=True)
# filenames_to_snapshot = glob("*.py") + glob("*.sh")
# utils.snapshot_files(filenames_to_snapshot, "./log_dir")
# logger = utils.get_logger(log_dir="./log_dir")

# tensorboard writer
# writer = SummaryWriter(log_dir=os.path.join("./log_dir", 'tensorboard'))

# random seed for reproducability
torch.manual_seed(1)
np.random.seed(1)

trainset = cylinder_dataset(
data_dir="./", data_filename="./CFD_DATA_space_avg_before_cyl.mat", center_data_filename="./CFD_DATA_center_before_cyl.mat",
nx=26, ny=26, nt=3, n_samp_pts_per_crop=2028
)

############# check what this is
train_sampler = RandomSampler(trainset, replacement=True, num_samples=2028)

sampler=train_sampler, num_workers=2, pin_memory=True)

# train_sampler = RandomSampler(trainset, replacement=True, num_samples=3072)

#                          num_workers=2, pin_memory=True)

# setup model
cnn_fc = CNN_FC(in_features=2, out_features=3, nf=13, activation=torch.nn.SiLU)

def init_weights(m):
if isinstance(m, nn.Conv3d):
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear):
torch.nn.init.xavier_normal_(m.weight)
m.bias.data.fill_(0.0)

cnn_fc.apply(init_weights)

all_model_params = list(cnn_fc.parameters())

start_ep = 0
global_step = np.zeros(1, dtype=np.uint32)

cnn_fc.to(device)

model_param_count = lambda model: sum(x.numel() for x in model.parameters())
print("{} cnn_fc paramerters in total".format(model_param_count(cnn_fc)))

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# training loop
for epoch in range(start_ep + 1, 20 + 1):
# loss = train(cnn_fc, train_loader, epoch, global_step, device, optimizer)
# tot_loss = 0
# count = 0
# send tensors to device

data_tensors = [t.to(device) for t in data_tensors]
# [space_time_crop, center_space_time_crop, center_space_time_crop_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord]
input_grid, center_input_grid, center_input_grid_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord = data_tensors   # input_grid is lowres data, point coord and point values are lowres data in dataset. data_tensor is [space_time_crop, center_space_time_crop, lres_coord, pde_mesh_coord, pde_point_coord]

# input_grid = input_grid.float()
# center_input_grid = center_input_grid.float()
# center_input_grid_reshaped = center_input_grid_reshaped.float()

t_fid = lres_coord[:,:,0]
t_fid = torch.reshape(t_fid, (t_fid.shape[0], t_fid.shape[1], 1))

y_fid = lres_coord[:,:,1]
y_fid = torch.reshape(y_fid, (y_fid.shape[0], y_fid.shape[1], 1))

x_fid = lres_coord[:,:,2]
x_fid = torch.reshape(x_fid, (x_fid.shape[0], x_fid.shape[1], 1))

t_pde = pde_point_coord[:,:,0]
t_pde = torch.reshape(t_pde, (t_pde.shape[0], t_pde.shape[1], 1))

y_pde = pde_point_coord[:,:,1]
y_pde = torch.reshape(y_pde, (y_pde.shape[0], y_pde.shape[1], 1))

x_pde = pde_point_coord[:,:,2]
x_pde = torch.reshape(x_pde, (x_pde.shape[0], x_pde.shape[1], 1))

t_eval = hres_coord[:,:,0]
t_eval = torch.reshape(t_eval, (t_eval.shape[0], t_eval.shape[1], 1))

y_eval = hres_coord[:,:,1]
y_eval = torch.reshape(y_eval, (y_eval.shape[0], y_eval.shape[1], 1))

x_eval = hres_coord[:,:,2]
x_eval = torch.reshape(x_eval, (x_eval.shape[0], x_eval.shape[1], 1))

# the weights of the loss terms
w_fid = 1.0
w_pde = 0.001

# normalizing the weights such that their sum equals 1
w_sum = w_fid + w_pde
w_fid = float(w_fid/w_sum)
w_pde = float(w_pde/w_sum)

def loss_terms():
fid = fidelity_loss(cnn_fc, input_grid, t_fid, y_fid, x_fid, center_input_grid_reshaped)
pde = pde_loss_f(cnn_fc, input_grid, t_pde, y_pde, x_pde)
return fid, pde

def total_loss():
fid, pde = loss_terms()
return w_fid*fid + w_pde*pde

foo = loss_terms()
fid_loss, pde_loss = foo[0], foo[1]

loss = w_fid*fid_loss + w_pde*pde_loss
loss.backward()

optimizer.step()
# tot_loss += loss.item()
# count += input_grid.size()[0]

reg_loss_list.append((w_fid * fid_loss).detach().cpu())
pde_loss_list.append((w_pde * pde_loss).detach().cpu())
tot_loss_list.append((loss).detach().cpu())
time0 = time()

if batch_idx % 10 == 0:
# logger log
print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss Sum: {:.4e}\t"
"Loss Reg: {:.4e}\tLoss Pde: {:.4e}\t"
"Ten Iters Time: {:.4e}".format(
epoch, batch_idx * len(input_grid), len(train_loader) * len(input_grid),
100. * batch_idx / len(train_loader), loss.item(),
w_fid * fid_loss, w_pde * pde_loss, time()-time0))

if batch_idx % 100 == 0 and batch_idx > 100:
fig, ax = pylab.subplots()
pylab.plot(tot_loss_list, label='total loss')
pylab.plot(reg_loss_list, label='reg loss')
pylab.plot(pde_loss_list, label='pde loss')
pylab.legend(loc='upper right')
ax.set_yscale('log')
plt.show(block = False)

if batch_idx % 100 == 0 and batch_idx > 100:
pred = cnn_fc(input_grid, t_eval, y_eval, x_eval)
print(pred.shape)
u_eval = pred[:,:,0].detach().cpu()
print(u_eval.shape)
u_eval = torch.reshape(u_eval, (u_eval.shape[0], 3, 78, 78))  # [b, t, y, x]
#u_eval = u_eval.permute(0, 1, 3, 2)
print(u_eval.shape)
fig = plt.figure()
ax1.imshow(input_grid[2, 0, :, :, 1].detach().cpu())
ax2.imshow(u_eval[2,1,:,:])
plt.show()

global_step += 1
# tot_loss /= count

scheduler.step(loss)

print('\Training is done')
FILE = "cnn_fc.pth"
torch.save(cnn_fc.state_dict(), FILE)

# loaded_cnn_fc = Model(in_features=2, out_features=3, nf=16, activation=torch.nn.Tanh)

``````

It returns:

``````None
Train Epoch: 1 [0/2025 (0%)]	Loss Sum: 1.9692e+00	Loss Reg: 1.9692e+00	Loss Pde: 6.9719e-08	Ten Iters Time: 1.6594e-04
None
None
None
None
None
None
None
None
None
None
Train Epoch: 1 [50/2025 (2%)]	Loss Sum: 2.7263e+00	Loss Reg: 2.7263e+00	Loss Pde: 1.3230e-05	Ten Iters Time: 8.6784e-05
None
None
None
``````

I think the issue is here:

``````        # the weights of the loss terms
w_fid = 1.0
w_pde = 0.001

# normalizing the weights such that their sum equals 1
w_sum = w_fid + w_pde
w_fid = float(w_fid/w_sum)
w_pde = float(w_pde/w_sum)

def loss_terms():
fid = fidelity_loss(cnn_fc, input_grid, t_fid, y_fid, x_fid, center_input_grid_reshaped)
pde = pde_loss_f(cnn_fc, input_grid, t_pde, y_pde, x_pde)
return fid, pde

def total_loss():
fid, pde = loss_terms()
return w_fid*fid + w_pde*pde

foo = loss_terms()
fid_loss, pde_loss = foo[0], foo[1]

loss = w_fid*fid_loss + w_pde*pde_loss