Custom loss function detached from the graph

Hello,
the following is my code for a network that combines a CNN with a fully-connected net. It is a physics-informed neural network, so I defined custom loss functions: a regression (fidelity) loss and a PDE loss. I'm getting None for the model gradients both before and after the backward step. I believe this is due to how I defined the loss functions. I would appreciate it if anyone could help me with this issue.


import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
import warnings
from torch.autograd import Variable


class CNN_FC(nn.Module):
  def __init__(self, in_features=2, out_features=3, nf=13,
              activation=torch.nn.Tanh, cnn_activation=torch.nn.ReLU):
    """Initialization

        Args:
          in_channels: int, number of input channels.
          neck_channels: int, number of channels in bottleneck layer.
          out_channels: int, number of output channels.
          final_relu: bool, add relu to the last layer.
    """
    super(CNN_FC, self).__init__()
    self.nf = nf
    self.in_features = in_features
    self.out_features = out_features
    self.activ = activation()
    self.cnn_activ = cnn_activation()


    self.conv_in = nn.Conv3d(self.in_features, self.nf, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='valid')
    self.conv11 = nn.Conv3d(self.nf, self.nf*2, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
    self.conv12 = nn.Conv3d(self.nf*2, self.nf*3, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
    self.conv13 = nn.Conv3d(self.nf*3, self.nf*6, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
    self.conv14 = nn.Conv3d(self.nf*6, self.nf*13, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
    self.convs = [self.conv_in, self.conv11, self.conv12, self.conv13, self.conv14]

    self.convs = nn.ModuleList(self.convs)
    

    self.maxpool11 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
    self.maxpool12 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
    self.maxpool13 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
    self.maxpool14 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(1, 1, 2), padding=0, dilation=1, ceil_mode=False)
    self.maxpools = [self.maxpool11, self.maxpool12, self.maxpool13, self.maxpool14]
    self.maxpools = nn.ModuleList(self.maxpools)

    self.flatten1 = nn.Flatten()
    self.flatten = [self.flatten1]
    self.flatten = nn.ModuleList(self.flatten)

    self.fc0 = nn.Linear(679, nf*64)
    self.fc1 = nn.Linear(nf*64 , nf*32)
    self.fc2 = nn.Linear(nf*32 , nf*16)
    self.fc3 = nn.Linear(nf*16 , nf*8)
    self.fc4 = nn.Linear(nf*8 , nf*4)
    self.fc5 = nn.Linear(nf*4, out_features)
    self.fc = [self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5]
    self.fc = nn.ModuleList(self.fc)

  
    
  def forward(self, c, t, y, x):

    # first entry of x should be batch size!  x = torch.randn(batch_size,numpoints,1)
    # c is the output of the convolutional portion
    c = self.conv_in(c)
    c = self.cnn_activ(c)
    c = self.conv11(c)
    c = self.cnn_activ(c)
    c = self.maxpool11(c)
    c = self.conv12(c)
    c = self.cnn_activ(c)
    c = self.maxpool12(c)
    c = self.conv13(c)
    c = self.cnn_activ(c)
    c = self.maxpool13(c)
    c = self.conv14(c)
    c = self.cnn_activ(c)
    c = self.maxpool14(c)
    c = self.flatten1(c)
    # print(c.shape)

    c = c.unsqueeze(1)
    c = c.repeat(1,int(x.shape[1]),1)

    x_tmp = torch.cat((c, t, y, x), dim=-1)
    x_tmp = self.fc0(x_tmp)
    x_tmp = self.activ(x_tmp)
    x_tmp = self.fc1(x_tmp)
    x_tmp = self.activ(x_tmp)
    x_tmp = self.fc2(x_tmp)
    x_tmp = self.activ(x_tmp)
    x_tmp = self.fc3(x_tmp)
    x_tmp = self.activ(x_tmp)
    x_tmp = self.fc4(x_tmp)
    x_tmp = self.activ(x_tmp)
    x_tmp = self.fc5(x_tmp)

    return x_tmp
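
For reference, a quick shape sanity check of the model above (a minimal sketch with random tensors; the crop shape [batch, 2, ny=26, nx=26, nt=3] and the [batch, num_points, 1] coordinate tensors are the shapes the dataset below produces):

# minimal shape check with random data (not the real dataset)
model = CNN_FC(in_features=2, out_features=3, nf=13)
c = torch.randn(5, 2, 26, 26, 3)     # low-res crop: [batch, channels, y, x, t]
t = torch.randn(5, 2028, 1)          # per-point time coordinate
y = torch.randn(5, 2028, 1)          # per-point y coordinate
x = torch.randn(5, 2028, 1)          # per-point x coordinate
out = model(c, t, y, x)
print(out.shape)                     # torch.Size([5, 2028, 3]) -> (u, v, p) per point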

import os
import torch
from torch.utils.data import Dataset, Sampler
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy import ndimage
from scipy.io import savemat, loadmat
import warnings

# in __init__ load lres data (entire data, generated in matlab)
#

class cylinder_dataset(Dataset):

  def __init__(self, data_dir="./", data_filename="CFD_DATA_space_avg_before_cyl.mat", center_data_filename="CFD_DATA_centers_before_cyl.mat",
                 nx=26, ny=26, nt=3, n_samp_pts_per_crop=2028, normalize_output=True):

    self.nt = nt
    self.nx = nx
    self.ny = ny
    self.data_dir = data_dir
    self.data_filename = data_filename
    self.center_data_filename = center_data_filename
    self.n_samp_pts_per_crop = n_samp_pts_per_crop
    self.normalize_output = normalize_output

    ################## load low res data (training data)
    npdata = loadmat(os.path.join(self.data_dir, self.data_filename))
    self.data = np.stack([npdata['u_space_avg'], npdata['v_space_avg']], axis=0)    # it seems like each of p, b, u, and w are already [200, 512, 128] shaped arrays
    self.data = self.data.astype(np.float32)
    #self.data = self.data.transpose(0, 3, 2, 1)  # [c, t, y, x] # c is number of channels
    nc_data, nt_data, ny_data, nx_data = self.data.shape
    # if highres: nc = 2, nt = 1501, ny = 80, nx = 640
    # if lowres (by factor of 5): nc=2, nt=1501, ny=16, nx=128


    self.nx_start_range = np.arange(0, nx_data-nx+1)    # 0 to 560 for high res, 0 to 112 for low res
    self.ny_start_range = np.arange(0, ny_data-ny+1)    # returns an array shape (1,) ([0])
    self.nt_start_range = np.arange(0, nt_data-nt+1)    # 0 to 1498
    self.rand_grid = np.stack(np.meshgrid(self.nt_start_range,
                                          self.ny_start_range,
                                          self.nx_start_range, indexing='ij'), axis=-1)

    self.rand_start_id = self.rand_grid.reshape([-1, 3])
    
    self.num_samples = self.rand_start_id.shape[0]

    ################## load center data (fidelity data)
    center_npdata = loadmat(os.path.join(self.data_dir, self.center_data_filename))
    self.center_data = np.stack([center_npdata['u_space_center'], center_npdata['v_space_center']], axis=0)    # it seems like each of p, b, u, and w are already [200, 512, 128] shaped arrays
    self.center_data = self.center_data.astype(np.float32)

    # compute channel-wise mean and std
    self._mean = np.mean(self.data, axis=(1, 2, 3))
    self._std = np.std(self.data, axis=(1, 2, 3))

  def __getitem__(self, index):
    t_id, y_id, x_id = self.rand_start_id[index]    # idx is the id of the crop of the data that is passed to Dataloader.
    space_time_crop = self.data[:,
                                t_id:t_id+self.nt,
                                y_id:y_id+self.ny,
                                x_id:x_id+self.nx]  # [c, t, y, x] c is the channel

    space_time_crop = np.transpose(space_time_crop, (0, 2, 3, 1))
    
    
    center_space_time_crop = self.center_data[:,
                                t_id:t_id+self.nt,
                                y_id:y_id+self.ny,
                                x_id:x_id+self.nx]  # [c, t, y, x] c is the channel

    center_space_time_crop = np.transpose(center_space_time_crop, (0, 2, 3, 1)) # [c, y, x, t]
    center_space_time_crop_reshaped = np.reshape(center_space_time_crop.transpose(3, 1, 2, 0), (-1, 2))   # [t,y,x,c]

    lres_coord = np.stack(np.meshgrid(np.linspace(0, self.nt-1, self.nt),
                                      np.linspace(0, self.ny-1, self.ny),
                                      np.linspace(0, self.nx-1, self.nx),
                                      indexing='ij'), axis=-1)
    lres_coord = np.reshape(lres_coord, (self.nt*self.nx*self.ny,3))
    # lres_coord returns an array that includes coordinates of points such that, [0,:,:,:] is the coordinates of the first slice (t=0) that marches first on x (index 2) then on y (index 1)

    hres_coord = np.stack(np.meshgrid(np.linspace(0, self.nt-1, self.nt),
                                      np.linspace(0, self.ny*3-1, self.ny*3),
                                      np.linspace(0, self.nx*3-1, self.nx*3),
                                      indexing='ij'), axis=-1)
    
    hres_coord = np.reshape(hres_coord, (self.nt*self.nx*3*self.ny*3,3))

    pde_mesh_coord = np.stack(np.meshgrid(np.random.rand(10)*[self.nt-1],
                                      np.random.rand(300)*[self.ny-1],
                                      np.random.rand(300)*[self.nx-1],
                                      indexing='ij'), axis=-1)
    # creates an array including 10 slices in time with 300*300 random points in y and x

    pde_point_coord = np.random.rand(self.n_samp_pts_per_crop, 3) * (np.array([3, 26, 26], dtype=np.int32) - 1) 
    # returns an array with three columns: the first holds n_samp_pts_per_crop random numbers between 0 and nt-1, the second between 0 and ny-1, the third between 0 and nx-1

    if self.normalize_output:
        space_time_crop = self.normalize_grid(space_time_crop)
        center_space_time_crop = self.normalize_grid(center_space_time_crop)
        center_space_time_crop_reshaped = np.reshape(center_space_time_crop.transpose(3, 1, 2, 0), (-1, 2))
        pde_point_coord = np.random.rand(self.n_samp_pts_per_crop, 3)
        lres_coord = lres_coord / (np.array([3, 26, 26], dtype=np.int32) - 1)
        hres_coord = hres_coord / (np.array([3, 78, 78], dtype=np.int32) - 1)

    return_tensors = [space_time_crop, center_space_time_crop, center_space_time_crop_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord]

    # cast everything to float32
    return_tensors = [t.astype(np.float32) for t in return_tensors]

    return tuple(return_tensors)

  def __len__(self):
    return self.rand_start_id.shape[0]


  @property
  def channel_mean(self):
      """channel-wise mean of dataset."""
      return self._mean

  @property
  def channel_std(self):
      """channel-wise mean of dataset."""
      return self._std

  @staticmethod
  def _normalize_array(array, mean, std):
      """normalize array (np or torch)."""
      if isinstance(array, torch.Tensor):
          dev = array.device
          std = torch.tensor(std, device=dev)
          mean = torch.tensor(mean, device=dev)
      return (array - mean) / std

  @staticmethod
  def _denormalize_array(array, mean, std):
      """normalize array (np or torch)."""
      if isinstance(array, torch.Tensor):
          dev = array.device
          std = torch.tensor(std, device=dev)
          mean = torch.tensor(mean, device=dev)
      return array * std + mean

  def normalize_grid(self, grid):
      """Normalize grid.
      Args:
        grid: np array or torch tensor of shape [2, ...], 2 are the num. of phys channels.
      Returns:
        channel normalized grid of same shape as input.
      """
      # reshape mean and std to be broadcastable.
      g_dim = len(grid.shape)
      mean_bc = self.channel_mean[(...,)+(None,)*(g_dim-1)]  # unsqueeze from the back
      std_bc = self.channel_std[(...,)+(None,)*(g_dim-1)]  # unsqueeze from the back
      return self._normalize_array(grid, mean_bc, std_bc)


  def denormalize_grid(self, grid):
      """Denormalize grid.
      Args:
        grid: np array or torch tensor of shape [2, ...], 2 are the num. of phys channels.
      Returns:
        channel denormalized grid of same shape as input.
      """
      # reshape mean and std to be broadcastable.
      g_dim = len(grid.shape)
      mean_bc = self.channel_mean[(...,)+(None,)*(g_dim-1)]  # unsqueeze from the back
      std_bc = self.channel_std[(...,)+(None,)*(g_dim-1)]  # unsqueeze from the back
      return self._denormalize_array(grid, mean_bc, std_bc)

import torch.nn.functional as F
#===============================================================================
# fidelity_loss
#===============================================================================
# criterion = F.mse_loss()
criterion = torch.nn.MSELoss()
def fidelity_loss(model, input_image, t, y, x, fid_image_tensor):

    output_tensor = model(input_image, t, y, x)  # output shape: [batch, num_points, 3]; the three outputs are u, v, p
    
    
    fid_tensor_u = fid_image_tensor[:,:,0]
    fid_tensor_v = fid_image_tensor[:,:,1]

    output_u = output_tensor[:,:,0]
    output_v = output_tensor[:,:,1]


    # fid_u_loss = torch.mean(torch.square(output_u-fid_tensor_u))
    # fid_v_loss = torch.mean(torch.square(output_v-fid_tensor_v))

    fid_u_loss = criterion(output_u, fid_tensor_u)
    fid_v_loss = criterion(output_v, fid_tensor_v)

    return fid_u_loss + fid_v_loss

#===============================================================================
# pde_loss
#===============================================================================
def pde_loss_f(model, input_image, t, y, x):
  output_tensor = model(input_image, t, y, x) # [batch, num_points, 3]

  U = output_tensor[:,:,0].unsqueeze(-1)
  V = output_tensor[:,:,1].unsqueeze(-1)
  P = output_tensor[:,:,2].unsqueeze(-1)

  U_t = torch.autograd.grad(U.sum(), t, create_graph=True, retain_graph=True, allow_unused=True)[0]

  U_y = torch.autograd.grad(U.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
  U_yy = torch.autograd.grad(U_y.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]

  U_x = torch.autograd.grad(U.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
  U_xx = torch.autograd.grad(U_x.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]

  V_t = torch.autograd.grad(V.sum(), t, create_graph=True, retain_graph=True, allow_unused=True)[0]

  V_y = torch.autograd.grad(V.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
  V_yy = torch.autograd.grad(V_y.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]

  V_x = torch.autograd.grad(V.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
  V_xx = torch.autograd.grad(V_x.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]

  P_y = torch.autograd.grad(P.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
  P_x = torch.autograd.grad(P.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]


  momentum_x = (U_t) - ((1/160)*(U_xx + U_yy)) + (U*U_x + V*U_y) + (P_x)
  momentum_x_loss = criterion(momentum_x, torch.zeros_like(momentum_x))
  # momentum_x_loss = torch.mean(torch.square(
  #     (U_t) - ((1/160)*(U_xx + U_yy)) + (U*U_x + V*U_y) + (P_x)
  #     ))
  
  momentum_y = (V_t) - ((1/160)*(V_xx + V_yy)) + (U*V_x + V*V_y) + (P_y)
  momentum_y_loss = criterion(momentum_y, torch.zeros_like(momentum_y))
  # momentum_y_loss = torch.mean(torch.square(
  #     (V_t) - ((1/160)*(V_xx + V_yy)) + (U*V_x + V*V_y) + (P_y)
  #     ))

  continuity = (U_x) + (V_y)
  continuity_loss = criterion(continuity, torch.zeros_like(continuity))
  # continuity_loss = torch.mean(torch.square(
  #     (U_x) + (V_y)
  #     ))

  # f'{nt}*dif(u,t)  -  {Re_inv}*(({nx})**2*dif(dif(u,x),x)+({ny})**2*dif(dif(u,y),y))  +  (u*{nx}*dif(u,x)+v*{ny}*dif(u,y))  +  dif(p,x)',
  # f'{nt}*dif(v,t)  -  {Re_inv}*(({nx})**2*dif(dif(v,x),x)+({ny})**2*dif(dif(v,y),y))  +  (u*{nx}*dif(v,x)+v*{ny}*dif(v,y))  +  dif(p,y)',
  # f'{nx} * dif(u, x) + {ny} * dif(v, y)')

  return momentum_x_loss + momentum_y_loss + continuity_loss
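
For context, the derivative terms above use the standard autograd.grad pattern of differentiating an output with respect to an input tensor that has requires_grad=True. A minimal self-contained sketch of that pattern, with a toy function standing in for the network (names here are just for illustration):

# toy illustration of torch.autograd.grad w.r.t. an input tensor (not the actual model)
t_toy = torch.rand(4, 10, 1, requires_grad=True)
u_toy = t_toy ** 3                                                    # stand-in for the network output U
u_t = torch.autograd.grad(u_toy.sum(), t_toy, create_graph=True)[0]   # dU/dt, same shape as t_toy
u_tt = torch.autograd.grad(u_t.sum(), t_toy, create_graph=True)[0]    # d2U/dt2
print(u_t.shape, u_tt.shape)   # torch.Size([4, 10, 1]) torch.Size([4, 10, 1])
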
import argparse
import json
import os
from glob import glob
import numpy as np
from collections import defaultdict
np.set_printoptions(precision=4)

import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.tensorboard import SummaryWriter
import pylab
from time import time
import matplotlib.pyplot as plt

tot_loss_list = []
reg_loss_list = []
pde_loss_list = []

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

batch_size = int(torch.cuda.device_count()) * int(16)

# log and create snapshots
# os.makedirs("./log_dir", exist_ok=True)
# filenames_to_snapshot = glob("*.py") + glob("*.sh")
# utils.snapshot_files(filenames_to_snapshot, "./log_dir")
# logger = utils.get_logger(log_dir="./log_dir")

# tensorboard writer
# writer = SummaryWriter(log_dir=os.path.join("./log_dir", 'tensorboard'))

# random seed for reproducability
torch.manual_seed(1)
np.random.seed(1)

# create dataloaders
trainset = cylinder_dataset(
    data_dir="./", data_filename="./CFD_DATA_space_avg_before_cyl.mat", center_data_filename="./CFD_DATA_center_before_cyl.mat",
              nx=26, ny=26, nt=3, n_samp_pts_per_crop=2028
)

############# check what this is
train_sampler = RandomSampler(trainset, replacement=True, num_samples=2028)

train_loader = DataLoader(trainset, batch_size=5, shuffle=False, drop_last=True,
                          sampler=train_sampler, num_workers=2, pin_memory=True)

# train_sampler = RandomSampler(trainset, replacement=True, num_samples=3072)

# train_loader = DataLoader(trainset, batch_size=16, shuffle=False, drop_last=True,
#                          num_workers=2, pin_memory=True)


# setup model
cnn_fc = CNN_FC(in_features=2, out_features=3, nf=13, activation=torch.nn.SiLU)

def init_weights(m):
  if isinstance(m, nn.Conv3d):
    torch.nn.init.xavier_uniform_(m.weight)
  if isinstance(m, nn.Linear):
    torch.nn.init.xavier_normal_(m.weight)
    m.bias.data.fill_(0.0)

cnn_fc.apply(init_weights)

all_model_params = list(cnn_fc.parameters())

optimizer = optim.Adam(all_model_params, lr=1e-2)

start_ep = 0
global_step = np.zeros(1, dtype=np.uint32)

cnn_fc.to(device)

model_param_count = lambda model: sum(x.numel() for x in model.parameters())
print("{} cnn_fc paramerters in total".format(model_param_count(cnn_fc)))

scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

# training loop
for epoch in range(start_ep + 1, 20 + 1):
    # loss = train(cnn_fc, train_loader, epoch, global_step, device, optimizer)
    # tot_loss = 0
    # count = 0
    for batch_idx, data_tensors in enumerate(train_loader):
        # send tensors to device
        
        data_tensors = [t.to(device) for t in data_tensors]
        # [space_time_crop, center_space_time_crop, center_space_time_crop_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord]
        input_grid, center_input_grid, center_input_grid_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord = data_tensors   # input_grid is lowres data, point coord and point values are lowres data in dataset. data_tensor is [space_time_crop, center_space_time_crop, lres_coord, pde_mesh_coord, pde_point_coord]
        optimizer.zero_grad()

        # input_grid = input_grid.float()
        # center_input_grid = center_input_grid.float()
        # center_input_grid_reshaped = center_input_grid_reshaped.float()
        # lres_coord = lres_coord.float().requires_grad_(True)
        pde_mesh_coord = pde_mesh_coord.float().requires_grad_(True)
        pde_point_coord = pde_point_coord.float().requires_grad_(True)

        t_fid = lres_coord[:,:,0]
        t_fid = torch.reshape(t_fid, (t_fid.shape[0], t_fid.shape[1], 1))

        y_fid = lres_coord[:,:,1]
        y_fid = torch.reshape(y_fid, (y_fid.shape[0], y_fid.shape[1], 1))

        x_fid = lres_coord[:,:,2]
        x_fid = torch.reshape(x_fid, (x_fid.shape[0], x_fid.shape[1], 1))


        t_pde = pde_point_coord[:,:,0]
        t_pde = torch.reshape(t_pde, (t_pde.shape[0], t_pde.shape[1], 1))

        y_pde = pde_point_coord[:,:,1]
        y_pde = torch.reshape(y_pde, (y_pde.shape[0], y_pde.shape[1], 1))

        x_pde = pde_point_coord[:,:,2]
        x_pde = torch.reshape(x_pde, (x_pde.shape[0], x_pde.shape[1], 1))

        
        t_eval = hres_coord[:,:,0]
        t_eval = torch.reshape(t_eval, (t_eval.shape[0], t_eval.shape[1], 1))

        y_eval = hres_coord[:,:,1]
        y_eval = torch.reshape(y_eval, (y_eval.shape[0], y_eval.shape[1], 1))

        x_eval = hres_coord[:,:,2]
        x_eval = torch.reshape(x_eval, (x_eval.shape[0], x_eval.shape[1], 1))


        # the weights of the loss terms
        w_fid = 1.0
        w_pde = 0.001

        # normalizing the weights such that their sum equals 1
        w_sum = w_fid + w_pde
        w_fid = float(w_fid/w_sum)
        w_pde = float(w_pde/w_sum)


        def loss_terms():
          fid = fidelity_loss(cnn_fc, input_grid, t_fid, y_fid, x_fid, center_input_grid_reshaped)
          pde = pde_loss_f(cnn_fc, input_grid, t_pde, y_pde, x_pde)
          return fid, pde

        def total_loss():
          fid, pde = loss_terms()
          return w_fid*fid + w_pde*pde

        foo = loss_terms()
        fid_loss, pde_loss = foo[0], foo[1]
        
        loss = w_fid*fid_loss + w_pde*pde_loss
        print(loss.grad)
        # loss.register_hook(lambda grad: print(grad))
        loss.backward()


        # gradient clipping
        # torch.nn.utils.clip_grad_value_(cnn_fc.parameters(), clip_value=1.0)
        # torch.nn.utils.clip_grad_value_(imnet.module.parameters(), args.clip_grad)

        optimizer.step()
        # tot_loss += loss.item()
        # count += input_grid.size()[0]

        reg_loss_list.append((w_fid * fid_loss).detach().cpu())
        pde_loss_list.append((w_pde * pde_loss).detach().cpu())
        tot_loss_list.append((loss).detach().cpu())
        time0 = time()

        if batch_idx % 10 == 0:
            # logger log
            print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss Sum: {:.4e}\t"
                "Loss Reg: {:.4e}\tLoss Pde: {:.4e}\t"
                "Ten Iters Time: {:.4e}".format(
                    epoch, batch_idx * len(input_grid), len(train_loader) * len(input_grid),
                    100. * batch_idx / len(train_loader), loss.item(),
                    w_fid * fid_loss, w_pde * pde_loss, time()-time0))
            
        if batch_idx % 100 == 0 and batch_idx > 100:
            fig, ax = pylab.subplots()
            pylab.plot(tot_loss_list, label='total loss')
            pylab.plot(reg_loss_list, label='reg loss')
            pylab.plot(pde_loss_list, label='pde loss')
            pylab.legend(loc='upper right')
            ax.set_yscale('log')
            plt.show(block = False)

        if batch_idx % 100 == 0 and batch_idx > 100:
            pred = cnn_fc(input_grid, t_eval, y_eval, x_eval)
            print(pred.shape)
            u_eval = pred[:,:,0].detach().cpu()
            print(u_eval.shape)
            u_eval = torch.reshape(u_eval, (u_eval.shape[0], 3, 78, 78))  # [b, t, y, x]
            #u_eval = u_eval.permute(0, 1, 3, 2)
            print(u_eval.shape)
            fig = plt.figure()
            ax1 = fig.add_subplot(121)
            ax2 = fig.add_subplot(122)
            ax1.imshow(input_grid[2, 0, :, :, 1].detach().cpu())
            ax2.imshow(u_eval[2,1,:,:])
            plt.show()

        global_step += 1
    # tot_loss /= count

    
    scheduler.step(loss)

print('\nTraining is done')
FILE = "cnn_fc.pth"
torch.save(cnn_fc.state_dict(), FILE)

# loaded_cnn_fc = Model(in_features=2, out_features=3, nf=16, activation=torch.nn.Tanh)
# loaded_cnn_fc.load_state_dict(torch.load(FILE))
# loaded_cnn_fc.eval()

It returns:

None
Train Epoch: 1 [0/2025 (0%)]	Loss Sum: 1.9692e+00	Loss Reg: 1.9692e+00	Loss Pde: 6.9719e-08	Ten Iters Time: 1.6594e-04
None
None
None
None
None
None
None
None
None
None
Train Epoch: 1 [50/2025 (2%)]	Loss Sum: 2.7263e+00	Loss Reg: 2.7263e+00	Loss Pde: 1.3230e-05	Ten Iters Time: 8.6784e-05
None
None
None

I think the issue is here:

        # the weights of the loss terms
        w_fid = 1.0
        w_pde = 0.001

        # normalizing the weights such that their sum equals 1
        w_sum = w_fid + w_pde
        w_fid = float(w_fid/w_sum)
        w_pde = float(w_pde/w_sum)


        def loss_terms():
          fid = fidelity_loss(cnn_fc, input_grid, t_fid, y_fid, x_fid, center_input_grid_reshaped)
          pde = pde_loss_f(cnn_fc, input_grid, t_pde, y_pde, x_pde)
          return fid, pde

        def total_loss():
          fid, pde = loss_terms()
          return w_fid*fid + w_pde*pde

        foo = loss_terms()
        fid_loss, pde_loss = foo[0], foo[1]
        
        loss = w_fid*fid_loss + w_pde*pde_loss
        print(loss.grad)
        # loss.register_hook(lambda grad: print(grad))
        loss.backward()

but I don't know how to handle this problem.
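
For what it's worth, this is roughly how I check the model gradients (a sketch; run just before and just after loss.backward()):

for name, p in cnn_fc.named_parameters():
    print(name, None if p.grad is None else p.grad.abs().sum().item())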

I don't see any obvious issues in your code, so could you post a minimal, executable code snippet which would reproduce this issue, please?