Custom loss function returns zero

The following code defines a network that combines a convolutional network with a fully connected network. I have written a custom dataloader and custom loss functions. There are two loss terms: one for regression and one for the underlying physics (pde_loss). However, pde_loss returns zero after the first iteration. In addition, the regression loss stays the same and training doesn't converge.
I would appreciate it if anyone could help me with this issue.

"""
Class for Convolutional Network
with ResNet backbone
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
import warnings
from torch.autograd import Variable

class ResBlock3D(nn.Module):
    """Bottleneck residual block for 5-D (N, C, D, H, W) feature maps.

    Main path: 1x1x1 conv -> BN -> ReLU -> 3x3x3 conv (padding 1) -> BN -> ReLU
    -> 1x1x1 conv -> BN. A 1x1x1 convolution projects the input to
    ``out_channels`` and is added to the main path, optionally followed by a
    ReLU. Every convolution has stride 1, so the spatial size is unchanged.
    """

    def __init__(self, in_channels, neck_channels, out_channels, final_relu=True):
        """Build the block.

        Args:
          in_channels: int, channels of the incoming tensor.
          neck_channels: int, width of the two bottleneck convolutions.
          out_channels: int, channels of the returned tensor.
          final_relu: bool, apply a ReLU after the residual sum.
        """
        super().__init__()
        self.in_channels = in_channels
        self.neck_channels = neck_channels
        self.out_channels = out_channels

        # Construction order (conv1..3, bn1..3, shortcut) is kept fixed so that
        # seeded initialisation and state_dict key order stay the same.
        conv = nn.Conv3d
        self.conv1 = conv(in_channels, neck_channels, 1)
        self.conv2 = conv(neck_channels, neck_channels, 3, padding=1)
        self.conv3 = conv(neck_channels, out_channels, 1)
        self.bn1, self.bn2, self.bn3 = (
            nn.BatchNorm3d(width)
            for width in (neck_channels, neck_channels, out_channels)
        )

        self.shortcut = conv(in_channels, out_channels, 1)
        self.final_relu = final_relu

    def forward(self, x):
        """Return main-path(x) + shortcut(x), optionally passed through ReLU."""
        residual = self.shortcut(x)

        out = x
        for conv, norm in ((self.conv1, self.bn1), (self.conv2, self.bn2)):
            out = F.relu(norm(conv(out)))

        out = self.bn3(self.conv3(out)) + residual
        return F.relu(out) if self.final_relu else out


class CNN_FC(nn.Module):
    """3D-CNN encoder followed by a per-point fully connected network.

    The CNN (five ResBlock3D blocks interleaved with four max-pools) encodes
    the input grid ``c`` into a flat feature vector. That vector is tiled
    along the point axis, and each point's (t, y, x) coordinates are
    concatenated with ONE scalar of it (hence ``fc0`` takes 4 inputs) before
    a 6-layer MLP with ``activation`` between the layers.

    NOTE(review): because of the tiling, point i is conditioned only on
    feature ``i % n_feat`` of the encoding. Which latent feature a point sees
    therefore depends on its position in the point list, not on its
    coordinates. If the intent was to condition every point on the whole
    latent vector, ``fc0`` would need ``3 + n_feat`` inputs and ``c`` would
    be broadcast instead of tiled. That changes the checkpoint format, so it
    is left as is here.
    """

    def __init__(self, in_features=2, out_features=3, nf=16,
                 activation=torch.nn.Tanh):
        """Build the network.

        Args:
          in_features: int, number of channels of the CNN input ``c``.
          out_features: int, number of outputs per point.
          nf: int, base width; CNN widths go nf -> 16*nf and the MLP hidden
            widths go 32*nf -> 2*nf.
          activation: zero-argument callable returning the MLP activation
            module (a single instance is shared by all layers).
        """
        super(CNN_FC, self).__init__()
        self.nf = nf
        self.in_features = in_features
        self.out_features = out_features
        self.activ = activation()

        self.conv_in = ResBlock3D(self.in_features, self.nf, self.nf)
        self.conv11 = ResBlock3D(self.nf, self.nf, self.nf*2)
        self.conv12 = ResBlock3D(self.nf*2, self.nf*2, self.nf*4)
        self.conv13 = ResBlock3D(self.nf*4, self.nf*4, self.nf*8)
        self.conv14 = ResBlock3D(self.nf*8, self.nf*8, self.nf*16)
        # The ModuleLists re-register the same modules under a second name;
        # kept so existing state_dict keys still match.
        self.convs = nn.ModuleList([self.conv_in, self.conv11, self.conv12, self.conv13, self.conv14])

        # Kernels/strides are (D, H, W); the first two pools halve D and H
        # only, the last two also stride over W.
        self.maxpool11 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
        self.maxpool12 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
        self.maxpool13 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
        self.maxpool14 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(1, 1, 2), padding=0, dilation=1, ceil_mode=False)
        self.maxpools = nn.ModuleList([self.maxpool11, self.maxpool12, self.maxpool13, self.maxpool14])

        self.flatten1 = nn.Flatten()
        self.flatten = nn.ModuleList([self.flatten1])

        # Input features per point: (t, y, x, one CNN feature).
        self.fc0 = nn.Linear(4, nf*32)
        self.fc1 = nn.Linear(nf*32, nf*16)
        self.fc2 = nn.Linear(nf*16, nf*8)
        self.fc3 = nn.Linear(nf*8, nf*4)
        self.fc4 = nn.Linear(nf*4, nf*2)
        self.fc5 = nn.Linear(nf*2, out_features)
        self.fc = nn.ModuleList([self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5])

    def forward(self, c, t, y, x):
        """Evaluate the network at a set of points.

        Args:
          c: CNN input, [batch, in_features, D, H, W]. The dataset in this
            file supplies [batch, 2, ny, nx, nt].
          t, y, x: point coordinates, each [batch, num_points, 1].

        Returns:
          [batch, num_points, out_features] tensor.

        Raises:
          ValueError: if num_points is not a multiple of the flattened CNN
            feature length. Without this check the tiling below would build a
            tensor of the wrong length and fail inside torch.cat.
        """
        c = self.conv_in(c)
        c = self.conv11(c)
        c = self.maxpool11(c)
        c = self.conv12(c)
        c = self.maxpool12(c)
        c = self.conv13(c)
        c = self.maxpool13(c)
        c = self.conv14(c)
        c = self.maxpool14(c)
        c = self.flatten1(c)  # [batch, n_feat]; nf*16 for a 16 x 16 x 3 grid

        n_points, n_feat = x.shape[1], c.shape[1]
        if n_points % n_feat != 0:
            raise ValueError(
                "number of points ({}) must be a multiple of the CNN feature "
                "length ({})".format(n_points, n_feat))
        # [batch, n_feat, 1] tiled to [batch, n_points, 1]: point i receives
        # feature i % n_feat.
        c = c.unsqueeze(-1).repeat(1, n_points // n_feat, 1)

        x_tmp = torch.cat((t, y, x, c), dim=-1)  # [batch, n_points, 4]
        last = len(self.fc) - 1
        for i, layer in enumerate(self.fc):
            x_tmp = layer(x_tmp)
            if i < last:  # no activation after the output layer
                x_tmp = self.activ(x_tmp)
        return x_tmp
import os
import torch
from torch.utils.data import Dataset, Sampler
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy import ndimage
from scipy.io import savemat, loadmat
import warnings

# in __init__ load lres data (entire data, generated in matlab)
#

class cylinder_dataset(Dataset):
    """Space-time crops of two-channel (u, v) flow data loaded from .mat files.

    Item ``index`` is the ``nt x ny x nx`` crop whose (t, y, x) start is
    ``rand_start_id[index]``. The crop is taken from both the training data
    and the "center" data, and is returned together with coordinate arrays
    for fidelity points and random PDE collocation points.
    """

    def __init__(self, data_dir="./", data_filename="./CFD_DATA_space_avg.mat",
                 center_data_filename="./CFD_DATA_centers.mat",
                 nx=16, ny=16, nt=3, n_samp_pts_per_crop=3000,
                 n_mesh_t=10, n_mesh_yx=300):
        """Load both .mat files and enumerate every valid crop start.

        Args:
          data_dir: directory that both file names are joined onto.
          data_filename: .mat file with keys 'u_space_avg' and 'v_space_avg'.
          center_data_filename: .mat file with keys 'u_space_center' and
            'v_space_center'; indexed with the same crop windows as the data.
          nx, ny, nt: crop size along x, y and t.
          n_samp_pts_per_crop: number of random PDE points per item.
          n_mesh_t, n_mesh_yx: the random mesh returned as ``pde_mesh_coord``
            has n_mesh_t x n_mesh_yx x n_mesh_yx points. The defaults
            reproduce the original 10 x 300 x 300 mesh.
        """
        self.nt = nt
        self.nx = nx
        self.ny = ny
        self.data_dir = data_dir
        self.data_filename = data_filename
        self.center_data_filename = center_data_filename
        self.n_samp_pts_per_crop = n_samp_pts_per_crop
        self.n_mesh_t = n_mesh_t
        self.n_mesh_yx = n_mesh_yx

        # Training data stacked to [c, t, y, x] with c = (u, v).
        # NOTE(review): assumes each .mat array is laid out [t, y, x] — confirm.
        npdata = loadmat(os.path.join(self.data_dir, self.data_filename))
        self.data = np.stack([npdata['u_space_avg'], npdata['v_space_avg']], axis=0).astype(np.float32)
        _, nt_data, ny_data, nx_data = self.data.shape

        # Every valid crop start, enumerated in (t, y, x) order with x varying
        # fastest: row 0 is [0, 0, 0], row 1 is [0, 0, 1], ...
        self.nx_start_range = np.arange(0, nx_data - nx + 1)
        self.ny_start_range = np.arange(0, ny_data - ny + 1)
        self.nt_start_range = np.arange(0, nt_data - nt + 1)
        self.rand_grid = np.stack(np.meshgrid(self.nt_start_range,
                                              self.ny_start_range,
                                              self.nx_start_range, indexing='ij'), axis=-1)
        self.rand_start_id = self.rand_grid.reshape([-1, 3])
        self.num_samples = self.rand_start_id.shape[0]

        # Center (fidelity) data, same [c, t, y, x] stacking as above.
        center_npdata = loadmat(os.path.join(self.data_dir, self.center_data_filename))
        self.center_data = np.stack([center_npdata['u_space_center'],
                                     center_npdata['v_space_center']], axis=0).astype(np.float32)

    def __getitem__(self, index):
        """Return one crop and its coordinate arrays.

        Returns a tuple of:
          space_time_crop: [c, ny, nx, nt] training-data crop (CNN input).
          center_space_time_crop: [c, ny, nx, nt] center-data crop.
          center_space_time_crop_reshaped: [nt*ny*nx, c] center-data values.
            Row i is the value at coordinate ``lres_coord[i]``.
          lres_coord: [nt*ny*nx, 3] integer-valued (t, y, x) grid
            coordinates, x varying fastest.
          pde_mesh_coord: [n_mesh_t, n_mesh_yx, n_mesh_yx, 3] random (t, y, x).
          pde_point_coord: [n_samp_pts_per_crop, 3] random (t, y, x) inside
            the crop.
        """
        t_id, y_id, x_id = self.rand_start_id[index]
        window = (slice(None),
                  slice(t_id, t_id + self.nt),
                  slice(y_id, y_id + self.ny),
                  slice(x_id, x_id + self.nx))

        # [c, t, y, x] -> [c, y, x, t]
        space_time_crop = np.transpose(self.data[window], (0, 2, 3, 1))

        center_tyx = self.center_data[window]  # [c, t, y, x]
        center_space_time_crop = np.transpose(center_tyx, (0, 2, 3, 1))  # [c, y, x, t]
        # Flatten from [t, y, x, c] so rows follow the same (t, y, x) order as
        # lres_coord. Flattening the [c, y, x, t] array instead would give
        # (y, x, t) order and pair each target with the wrong coordinate.
        center_space_time_crop_reshaped = np.reshape(center_tyx.transpose(1, 2, 3, 0),
                                                     (-1, center_tyx.shape[0]))

        lres_coord = np.stack(np.meshgrid(np.linspace(0, self.nt - 1, self.nt),
                                          np.linspace(0, self.ny - 1, self.ny),
                                          np.linspace(0, self.nx - 1, self.nx),
                                          indexing='ij'), axis=-1)
        lres_coord = np.reshape(lres_coord, (self.nt * self.ny * self.nx, 3))

        # n_mesh_t time values x n_mesh_yx**2 random (y, x) points.
        # NOTE(review): appears unused by the training loop in this file, and
        # at the defaults it holds ~2.7M float64 values (~21.6 MB) per item.
        pde_mesh_coord = np.stack(np.meshgrid(np.random.rand(self.n_mesh_t) * [self.nt - 1],
                                              np.random.rand(self.n_mesh_yx) * [self.ny - 1],
                                              np.random.rand(self.n_mesh_yx) * [self.nx - 1],
                                              indexing='ij'), axis=-1)

        # Uniform random (t, y, x) in [0, n-1] along each axis of the crop.
        upper = np.array([self.nt, self.ny, self.nx], dtype=np.int32) - 1
        pde_point_coord = np.random.rand(self.n_samp_pts_per_crop, 3) * upper

        return (space_time_crop, center_space_time_crop, center_space_time_crop_reshaped,
                lres_coord, pde_mesh_coord, pde_point_coord)

    def __len__(self):
        """Number of distinct crop start positions."""
        return self.rand_start_id.shape[0]
#===============================================================================
# fidelity_loss
#===============================================================================
def fidelity_loss(model, input_image, t, y, x, fid_image_tensor):
    """Mean-squared data-fit loss on the first two output channels.

    Args:
      model: callable ``model(input_image, t, y, x)`` returning a rank-3
        tensor [batch, num_points, n_out] whose channels 0 and 1 are compared
        against the target.
      input_image: passed straight through to ``model``.
      t, y, x: point coordinates passed straight through to ``model``.
      fid_image_tensor: [batch, num_points, >=2] target; channels 0 and 1
        are used.

    Returns:
      Scalar tensor: MSE(channel 0) + MSE(channel 1), each averaged over
      the batch and points.
    """
    prediction = model(input_image, t, y, x)
    per_channel = (
        torch.square(prediction[:, :, ch] - fid_image_tensor[:, :, ch]).mean()
        for ch in (0, 1)
    )
    return sum(per_channel)

#===============================================================================
# pde_loss
#===============================================================================
def _pde_grads(output, inputs):
    """Return d(output.sum())/d(inp) for each inp in ``inputs`` (graph kept).

    A derivative that is identically zero is returned as zeros instead of
    None or an autograd error. This happens when the output does not depend
    on an input, or when the output is not part of the graph at all (e.g. the
    first derivative of a function that is linear in that input).
    """
    if not output.requires_grad:
        return [torch.zeros_like(inp) for inp in inputs]
    grads = torch.autograd.grad(output.sum(), inputs, create_graph=True, allow_unused=True)
    return [torch.zeros_like(inp) if g is None else g for g, inp in zip(grads, inputs)]


def pde_loss_f(model, input_image, t, y, x, nu=1 / 160):
    """Squared residuals of the 2D incompressible Navier-Stokes equations.

    The model output channels are read as (U, V, P). The residuals are
      momentum x: U_t - nu*(U_xx + U_yy) + U*U_x + V*U_y + P_x
      momentum y: V_t - nu*(V_xx + V_yy) + U*V_x + V*V_y + P_y
      continuity: U_x + V_y
    and the loss is the sum of their mean squares.

    Because each point's output depends only on its own coordinates (and on
    the coordinate-independent CNN features), the gradient of the summed
    output gives the per-point derivative.

    Args:
      model: callable ``model(input_image, t, y, x)`` returning
        [batch, num_points, >=3].
      input_image: passed straight through to ``model``.
      t, y, x: [batch, num_points, 1] coordinates that are part of the
        autograd graph (e.g. slices of a leaf created with requires_grad=True).
      nu: viscosity coefficient. The default 1/160 equals the previously
        hard-coded value.
        NOTE(review): derivatives are taken w.r.t. whatever units t, y, x are
        in (presumably grid indices in this file's training loop), so nu must
        be expressed in those units — confirm.

    Returns:
      Scalar tensor.

    Raises:
      RuntimeError: if the model output is not part of the autograd graph.

    Note:
      Any output that does not vary with (t, y, x) makes every residual
      exactly zero. This loss therefore admits a trivial constant solution,
      and a PDE loss of ~0 alone does not indicate a good fit.
    """
    output_tensor = model(input_image, t, y, x)  # [batch, num_points, >=3]
    if not output_tensor.requires_grad:
        raise RuntimeError("pde_loss_f: model output does not require grad; call it with "
                           "grad enabled and coordinates that are part of the graph")

    U = output_tensor[:, :, 0].unsqueeze(-1)
    V = output_tensor[:, :, 1].unsqueeze(-1)
    P = output_tensor[:, :, 2].unsqueeze(-1)

    # One backward pass per field for all first derivatives.
    U_t, U_y, U_x = _pde_grads(U, (t, y, x))
    V_t, V_y, V_x = _pde_grads(V, (t, y, x))
    P_y, P_x = _pde_grads(P, (y, x))

    (U_yy,) = _pde_grads(U_y, (y,))
    (U_xx,) = _pde_grads(U_x, (x,))
    (V_yy,) = _pde_grads(V_y, (y,))
    (V_xx,) = _pde_grads(V_x, (x,))

    momentum_x_loss = torch.mean(torch.square(
        U_t - nu * (U_xx + U_yy) + (U * U_x + V * U_y) + P_x
    ))
    momentum_y_loss = torch.mean(torch.square(
        V_t - nu * (V_xx + V_yy) + (U * V_x + V * V_y) + P_y
    ))
    continuity_loss = torch.mean(torch.square(U_x + V_y))

    return momentum_x_loss + momentum_y_loss + continuity_loss
########## train #############

import argparse
import json
import os
from glob import glob
import numpy as np
from collections import defaultdict
np.set_printoptions(precision=4)

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.tensorboard import SummaryWriter

####### train function ##########
def train(CNN_FC, train_loader, epoch, global_step, device,
          logger, writer, optimizer):
    """Run one training epoch on the weighted fidelity + PDE loss.

    Args:
      CNN_FC: the model to train. Inside this function the parameter shadows
        the module-level ``CNN_FC`` class.
      train_loader: iterable of 6-tuples laid out like
        ``cylinder_dataset.__getitem__``'s return value (batched).
      epoch: epoch number, used only in log messages.
      global_step: step counter updated with ``+= 1`` once per batch. When it
        is a numpy array (as created in main) the update is in place and
        visible to the caller.
      device: torch device the used batch tensors are moved to.
      logger: object with an ``info(str)`` method.
      writer: TensorBoard SummaryWriter.
      optimizer: optimizer over the model's parameters.

    Returns:
      float: per-sample average of the weighted total loss over the epoch.

    Raises:
      ValueError: if ``train_loader`` yields no batches.
    """
    CNN_FC.train()

    # Loss-term weights, normalised so that they sum to 1.
    w_fid, w_pde = 1.0, 1.0
    w_sum = w_fid + w_pde
    w_fid = float(w_fid / w_sum)
    w_pde = float(w_pde / w_sum)

    tot_loss = 0.0
    count = 0

    for batch_idx, data_tensors in enumerate(train_loader):
        # center_input_grid and pde_mesh_coord are not used below, so they
        # are not copied to the device.
        (input_grid, _center_input_grid, center_input_grid_reshaped,
         lres_coord, _pde_mesh_coord, pde_point_coord) = data_tensors

        input_grid = input_grid.to(device).float()
        fid_target = center_input_grid_reshaped.to(device).float()
        lres_coord = lres_coord.to(device).float()
        # pde_loss_f differentiates the output w.r.t. these coordinates, so
        # they must require grad.
        pde_point_coord = pde_point_coord.to(device).float().requires_grad_(True)

        # [batch, n_points, 3] -> three [batch, n_points, 1] views (t, y, x).
        t_fid, y_fid, x_fid = lres_coord.split(1, dim=-1)
        t_pde, y_pde, x_pde = pde_point_coord.split(1, dim=-1)

        optimizer.zero_grad()
        fid_loss = fidelity_loss(CNN_FC, input_grid, t_fid, y_fid, x_fid, fid_target)
        pde_loss = pde_loss_f(CNN_FC, input_grid, t_pde, y_pde, x_pde)
        loss = w_fid * fid_loss + w_pde * pde_loss
        loss.backward()

        # Clip every gradient element to [-1, 1].
        torch.nn.utils.clip_grad_value_(CNN_FC.parameters(), 1.)
        optimizer.step()

        batch_size = input_grid.size(0)
        # `loss` is a per-batch mean: weight it by the batch size so that
        # tot_loss / count below is a per-sample average.
        tot_loss += loss.item() * batch_size
        count += batch_size

        if batch_idx % 10 == 0:
            fid_val, pde_val, loss_val = fid_loss.item(), pde_loss.item(), loss.item()
            # logger log
            logger.info(
                "Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss Sum: {:.6f}\t"
                "Loss Reg: {:.6f}\tLoss Pde: {:.6f}".format(
                    epoch, batch_idx * len(input_grid), len(train_loader) * len(input_grid),
                    100. * batch_idx / len(train_loader), loss_val,
                    w_fid * fid_val, w_pde * pde_val))
            # tensorboard log
            writer.add_scalar('train/reg_loss_unweighted', fid_val, global_step=int(global_step))
            writer.add_scalar('train/pde_loss_unweighted', pde_val, global_step=int(global_step))
            writer.add_scalar('train/sum_loss', loss_val, global_step=int(global_step))
            writer.add_scalars('train/losses_weighted',
                               {"reg_loss": w_fid * fid_val,
                                "pde_loss": w_pde * pde_val,
                                "sum_loss": loss_val}, global_step=int(global_step))

        global_step += 1

    if count == 0:
        raise ValueError("train_loader yielded no batches")
    return tot_loss / count

########################
import train_utils as utils
from torch.utils.tensorboard import SummaryWriter
def main():
    """Train CNN_FC on cylinder_dataset for 100 epochs, logging to ./log_dir.

    Snapshots the local *.py / *.sh files, writes text and TensorBoard logs,
    steps a ReduceLROnPlateau scheduler on each epoch's loss, and saves a
    checkpoint through ``utils.save_checkpoint`` after every epoch.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Samples per optimisation step, used by the DataLoader below.
    batch_size = 16

    # log and create snapshots
    os.makedirs("./log_dir", exist_ok=True)
    filenames_to_snapshot = glob("*.py") + glob("*.sh")
    utils.snapshot_files(filenames_to_snapshot, "./log_dir")
    logger = utils.get_logger(log_dir="./log_dir")

    # tensorboard writer
    writer = SummaryWriter(log_dir=os.path.join("./log_dir", 'tensorboard'))

    # random seed for reproducibility
    torch.manual_seed(1)
    np.random.seed(1)

    # create dataloaders
    trainset = cylinder_dataset(
        data_dir="./", data_filename="./CFD_DATA_space_avg.mat",
        center_data_filename="./CFD_DATA_centers.mat",
        nx=16, ny=16, nt=3, n_samp_pts_per_crop=3072)

    # Each epoch draws 3072 crop indices with replacement, i.e.
    # 3072 // batch_size optimisation steps (drop_last discards a remainder).
    train_sampler = RandomSampler(trainset, replacement=True, num_samples=3072)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=False, drop_last=True,
                              sampler=train_sampler, num_workers=1,
                              # pinned host memory only helps host->GPU copies
                              pin_memory=device.type == "cuda")

    # setup model; move it to the device before building the optimizer, as
    # recommended by the torch.optim documentation.
    cnn_fc = CNN_FC(in_features=2, out_features=3, nf=16, activation=torch.nn.Tanh)
    cnn_fc.to(device)

    all_model_params = list(cnn_fc.parameters())
    # NOTE(review): 1e-2 is a large Adam step for a Tanh coordinate MLP; if the
    # loss plateaus from the start, try 1e-3.
    optimizer = optim.Adam(all_model_params, lr=1e-2)

    start_ep = 0
    # 1-element array so train() can increment it in place.
    global_step = np.zeros(1, dtype=np.uint32)
    tracked_stats = np.inf

    model_param_count = lambda model: sum(x.numel() for x in model.parameters())
    logger.info("{} cnn_fc paramerters in total".format(model_param_count(cnn_fc)))

    checkpoint_path = os.path.join("./log_dir", "checkpoint_latest.pth.tar")

    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

    # training loop
    for epoch in range(start_ep + 1, 100 + 1):
        loss = train(cnn_fc, train_loader, epoch, global_step, device, logger, writer,
                     optimizer)

        scheduler.step(loss)

        is_best = loss < tracked_stats
        if is_best:
            tracked_stats = loss

        utils.save_checkpoint({
            "epoch": epoch,
            "cnn_fc_state_dict": cnn_fc.state_dict(),
            "optim_state_dict": optimizer.state_dict(),
            "tracked_stats": tracked_stats,
            "global_step": global_step,
        }, is_best, epoch, checkpoint_path, "_pdenet", logger)

if __name__ == "__main__":
    # Start training only when this file is executed as a script, not when it is imported.
    main()