Hello,
following is my code with network combining CNN and a fully-connected net. This is a physics-informed neural network. So, I defined custom loss functions including a regression loss function and e PDE loss function. I’m getting None value for the model grads after and before backward step. I believe this is due to how I defined the loss function. I would appreciate it if anyone could help me with this issue.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
import warnings
from torch.autograd import Variable
class CNN_FC(nn.Module):
def __init__(self, in_features=2, out_features=3, nf=13,
activation=torch.nn.Tanh, cnn_activation=torch.nn.ReLU):
"""Initialization
Args:
in_channels: int, number of input channels.
neck_channels: int, number of channels in bottleneck layer.
out_channels: int, number of output channels.
final_relu: bool, add relu to the last layer.
"""
super(CNN_FC, self).__init__()
self.nf = nf
self.in_features = in_features
self.out_features = out_features
self.activ = activation()
self.cnn_activ = cnn_activation()
self.conv_in = nn.Conv3d(self.in_features, self.nf, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='valid')
self.conv11 = nn.Conv3d(self.nf, self.nf*2, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
self.conv12 = nn.Conv3d(self.nf*2, self.nf*3, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
self.conv13 = nn.Conv3d(self.nf*3, self.nf*6, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
self.conv14 = nn.Conv3d(self.nf*6, self.nf*13, kernel_size=(2, 2, 1), stride=(1, 1, 1), padding='same')
self.convs = [self.conv_in, self.conv11, self.conv12, self.conv13, self.conv14]
self.convs = nn.ModuleList(self.convs)
self.maxpool11 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
self.maxpool12 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
self.maxpool13 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
self.maxpool14 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(1, 1, 2), padding=0, dilation=1, ceil_mode=False)
self.maxpools = [self.maxpool11, self.maxpool12, self.maxpool13, self.maxpool14]
self.maxpools = nn.ModuleList(self.maxpools)
self.flatten1 = nn.Flatten()
self.flatten = [self.flatten1]
self.flatten = nn.ModuleList(self.flatten)
self.fc0 = nn.Linear(679, nf*64)
self.fc1 = nn.Linear(nf*64 , nf*32)
self.fc2 = nn.Linear(nf*32 , nf*16)
self.fc3 = nn.Linear(nf*16 , nf*8)
self.fc4 = nn.Linear(nf*8 , nf*4)
self.fc5 = nn.Linear(nf*4, out_features)
self.fc = [self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5]
self.fc = nn.ModuleList(self.fc)
def forward(self, c, t, y, x):
# first entry of x should be batch size! x = torch.randn(batch_size,numpoints,1)
# c is the output of convolutional prtion
c = self.conv_in(c)
c = self.cnn_activ(c)
c = self.conv11(c)
c = self.cnn_activ(c)
c = self.maxpool11(c)
c = self.conv12(c)
c = self.cnn_activ(c)
c = self.maxpool12(c)
c = self.conv13(c)
c = self.cnn_activ(c)
c = self.maxpool13(c)
c = self.conv14(c)
c = self.cnn_activ(c)
c = self.maxpool14(c)
c = self.flatten1(c)
# print(c.shape)
c = c.unsqueeze(1)
c = c.repeat(1,int(x.shape[1]),1)
x_tmp = torch.cat((c, t, y, x), dim=-1)
x_tmp = self.fc0(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc1(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc2(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc3(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc4(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc5(x_tmp)
return x_tmp
import os
import torch
from torch.utils.data import Dataset, Sampler
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy import ndimage
from scipy.io import savemat, loadmat
import warnings
# in __init__ load lres data (entire data, generated in matlab)
#
class cylinder_dataset(Dataset):
def __init__(self, data_dir="./", data_filename="CFD_DATA_space_avg_before_cyl.mat", center_data_filename="CFD_DATA_centers_before_cyl.mat",
nx=26, ny=26, nt=3, n_samp_pts_per_crop=2028, normalize_output=True):
self.nt = nt
self.nx = nx
self.ny = ny
self.data_dir = data_dir
self.data_filename = data_filename
self.center_data_filename = center_data_filename
self.n_samp_pts_per_crop = n_samp_pts_per_crop
self.normalize_output = normalize_output
################## load low res data (training data)
npdata = loadmat(os.path.join(self.data_dir, self.data_filename))
self.data = np.stack([npdata['u_space_avg'], npdata['v_space_avg']], axis=0) # it seems like each of p, b, u, and w are already [200, 512, 128] shaped arrays
self.data = self.data.astype(np.float32)
#self.data = self.data.transpose(0, 3, 2, 1) # [c, t, y, x] # c is number of channels
nc_data, nt_data, ny_data, nx_data = self.data.shape
# if highres: nc = 2, nt = 1501, ny = 80, nx = 640
# if lowres (by factor of 5): nc=2, nt=1501, ny=16, nx=128
self.nx_start_range = np.arange(0, nx_data-nx+1) # 0 to 560 fo high res, 0 to 112 for low res
self.ny_start_range = np.arange(0, ny_data-ny+1) # returns an array shape (1,) ([0])
self.nt_start_range = np.arange(0, nt_data-nt+1) # 0 to 1498
self.rand_grid = np.stack(np.meshgrid(self.nt_start_range,
self.ny_start_range,
self.nx_start_range, indexing='ij'), axis=-1)
self.rand_start_id = self.rand_grid.reshape([-1, 3])
self.num_samples = self.rand_start_id.shape[0]
################## load center data (fidelity data)
center_npdata = loadmat(os.path.join(self.data_dir, self.center_data_filename))
self.center_data = np.stack([center_npdata['u_space_center'], center_npdata['v_space_center']], axis=0) # it seems like each of p, b, u, and w are already [200, 512, 128] shaped arrays
self.center_data = self.center_data.astype(np.float32)
# compute channel-wise mean and std
self._mean = np.mean(self.data, axis=(1, 2, 3))
self._std = np.std(self.data, axis=(1, 2, 3))
def __getitem__(self, index):
t_id, y_id, x_id = self.rand_start_id[index] # idx is the id of the crop of the data that is passed to Dataloader.
space_time_crop = self.data[:,
t_id:t_id+self.nt,
y_id:y_id+self.ny,
x_id:x_id+self.nx] # [c, t, y, x] c is the channel
space_time_crop = np.transpose(space_time_crop, (0, 2, 3, 1))
center_space_time_crop = self.center_data[:,
t_id:t_id+self.nt,
y_id:y_id+self.ny,
x_id:x_id+self.nx] # [c, t, y, x] c is the channel
center_space_time_crop = np.transpose(center_space_time_crop, (0, 2, 3, 1)) # [c, y, x, t]
center_space_time_crop_reshaped = np.reshape(center_space_time_crop.transpose(3, 1, 2, 0), (-1, 2)) # [t,y,x,c]
lres_coord = np.stack(np.meshgrid(np.linspace(0, self.nt-1, self.nt),
np.linspace(0, self.ny-1, self.ny),
np.linspace(0, self.nx-1, self.nx),
indexing='ij'), axis=-1)
lres_coord = np.reshape(lres_coord, (self.nt*self.nx*self.ny,3))
# lres_coord returns an array that includes coordinates of points such that, [0,:,:,:] is the coordinates of the first slice (t=0) that marches first on x (index 2) then on y (index 1)
hres_coord = np.stack(np.meshgrid(np.linspace(0, self.nt-1, self.nt),
np.linspace(0, self.ny*3-1, self.ny*3),
np.linspace(0, self.nx*3-1, self.nx*3),
indexing='ij'), axis=-1)
hres_coord = np.reshape(hres_coord, (self.nt*self.nx*3*self.ny*3,3))
pde_mesh_coord = np.stack(np.meshgrid(np.random.rand(10)*[self.nt-1],
np.random.rand(300)*[self.ny-1],
np.random.rand(300)*[self.nx-1],
indexing='ij'), axis=-1)
# creates an array including 10 slices in time with 300*300 random points in y and x
pde_point_coord = np.random.rand(self.n_samp_pts_per_crop, 3) * (np.array([3, 26, 26], dtype=np.int32) - 1)
# return an array with three columns such that first column is 3000 randon number between zero and nt-1, second column is 3000 points between zero and ny, third is between zero and nx
if self.normalize_output:
space_time_crop = self.normalize_grid(space_time_crop)
center_space_time_crop = self.normalize_grid(center_space_time_crop)
center_space_time_crop_reshaped = np.reshape(center_space_time_crop.transpose(3, 1, 2, 0), (-1, 2))
pde_point_coord = np.random.rand(self.n_samp_pts_per_crop, 3)
lres_coord = lres_coord / (np.array([3, 26, 26], dtype=np.int32) - 1)
hres_coord = hres_coord / (np.array([3, 78, 78], dtype=np.int32) - 1)
return_tensors = [space_time_crop, center_space_time_crop, center_space_time_crop_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord]
# cast everything to float32
return_tensors = [t.astype(np.float32) for t in return_tensors]
return tuple(return_tensors)
def __len__(self):
return self.rand_start_id.shape[0]
@property
def channel_mean(self):
"""channel-wise mean of dataset."""
return self._mean
@property
def channel_std(self):
"""channel-wise mean of dataset."""
return self._std
@staticmethod
def _normalize_array(array, mean, std):
"""normalize array (np or torch)."""
if isinstance(array, torch.Tensor):
dev = array.device
std = torch.tensor(std, device=dev)
mean = torch.tensor(mean, device=dev)
return (array - mean) / std
@staticmethod
def _denormalize_array(array, mean, std):
"""normalize array (np or torch)."""
if isinstance(array, torch.Tensor):
dev = array.device
std = torch.tensor(std, device=dev)
mean = torch.tensor(mean, device=dev)
return array * std + mean
def normalize_grid(self, grid):
"""Normalize grid.
Args:
grid: np array or torch tensor of shape [2, ...], 2 are the num. of phys channels.
Returns:
channel normalized grid of same shape as input.
"""
# reshape mean and std to be broadcastable.
g_dim = len(grid.shape)
mean_bc = self.channel_mean[(...,)+(None,)*(g_dim-1)] # unsqueeze from the back
std_bc = self.channel_std[(...,)+(None,)*(g_dim-1)] # unsqueeze from the back
return self._normalize_array(grid, mean_bc, std_bc)
def denormalize_grid(self, grid):
"""Denormalize grid.
Args:
grid: np array or torch tensor of shape [2, ...], 2 are the num. of phys channels.
Returns:
channel denormalized grid of same shape as input.
"""
# reshape mean and std to be broadcastable.
g_dim = len(grid.shape)
mean_bc = self.channel_mean[(...,)+(None,)*(g_dim-1)] # unsqueeze from the back
std_bc = self.channel_std[(...,)+(None,)*(g_dim-1)] # unsqueeze from the back
return self._denormalize_array(grid, mean_bc, std_bc)
import torch.nn.functional as F
#===============================================================================
# fidelity_loss
#===============================================================================
# criterion = F.mse_loss()
criterion = torch.nn.MSELoss()
def fidelity_loss(model, input_image, t, y, x, fid_image_tensor):
output_tensor = model(input_image, t, y, x) # output shape: ([batch, num_points, 3]) take first out is u, then v, then p
fid_tensor_u = fid_image_tensor[:,:,0]
fid_tensor_v = fid_image_tensor[:,:,1]
output_u = output_tensor[:,:,0]
output_v = output_tensor[:,:,1]
# fid_u_loss = torch.mean(torch.square(output_u-fid_tensor_u))
# fid_v_loss = torch.mean(torch.square(output_v-fid_tensor_v))
fid_u_loss = criterion(output_u, fid_tensor_u)
fid_v_loss = criterion(output_v, fid_tensor_v)
return fid_u_loss + fid_v_loss
#===============================================================================
# pde_loss
#===============================================================================
def pde_loss_f(model, input_image, t, y, x):
output_tensor = model(input_image, t, y, x) # [batch, num_points, 3]
U = output_tensor[:,:,0].unsqueeze(-1)
V = output_tensor[:,:,1].unsqueeze(-1)
P = output_tensor[:,:,2].unsqueeze(-1)
U_t = torch.autograd.grad(U.sum(), t, create_graph=True, retain_graph=True, allow_unused=True)[0]
U_y = torch.autograd.grad(U.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
U_yy = torch.autograd.grad(U_y.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
U_x = torch.autograd.grad(U.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
U_xx = torch.autograd.grad(U_x.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_t = torch.autograd.grad(V.sum(), t, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_y = torch.autograd.grad(V.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_yy = torch.autograd.grad(V_y.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_x = torch.autograd.grad(V.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_xx = torch.autograd.grad(V_x.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
P_y = torch.autograd.grad(P.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
P_x = torch.autograd.grad(P.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
momentum_x = (U_t) - ((1/160)*(U_xx + U_yy)) + (U*U_x + V*U_y) + (P_x)
momentum_x_loss = criterion(momentum_x, torch.zeros_like(momentum_x))
# momentum_x_loss = torch.mean(torch.square(
# (U_t) - ((1/160)*(U_xx + U_yy)) + (U*U_x + V*U_y) + (P_x)
# ))
momentum_y = (V_t) - ((1/160)*(V_xx + V_yy)) + (U*V_x + V*V_y) + (P_y)
momentum_y_loss = criterion(momentum_y, torch.zeros_like(momentum_y))
# momentum_y_loss = torch.mean(torch.square(
# (V_t) - ((1/160)*(V_xx + V_yy)) + (U*V_x + V*V_y) + (P_y)
# ))
continuity = (U_x) + (V_y)
continuity_loss = criterion(continuity, torch.zeros_like(continuity))
# continuity_loss = torch.mean(torch.square(
# (U_x) + (V_y)
# ))
# f'{nt}*dif(u,t) - {Re_inv}*(({nx})**2*dif(dif(u,x),x)+({ny})**2*dif(dif(u,y),y)) + (u*{nx}*dif(u,x)+v*{ny}*dif(u,y)) + dif(p,x)',
# f'{nt}*dif(v,t) - {Re_inv}*(({nx})**2*dif(dif(v,x),x)+({ny})**2*dif(dif(v,y),y)) + (u*{nx}*dif(v,x)+v*{ny}*dif(v,y)) + dif(p,y)',
# f'{nx} * dif(u, x) + {ny} * dif(v, y)')
return momentum_x_loss + momentum_y_loss + continuity_loss
import argparse
import json
import os
from glob import glob
import numpy as np
from collections import defaultdict
np.set_printoptions(precision=4)
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.tensorboard import SummaryWriter
import pylab
from time import time
import matplotlib.pyplot as plt
tot_loss_list = []
reg_loss_list = []
pde_loss_list = []
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = int(torch.cuda.device_count()) * int(16)
# log and create snapshots
# os.makedirs("./log_dir", exist_ok=True)
# filenames_to_snapshot = glob("*.py") + glob("*.sh")
# utils.snapshot_files(filenames_to_snapshot, "./log_dir")
# logger = utils.get_logger(log_dir="./log_dir")
# tensorboard writer
# writer = SummaryWriter(log_dir=os.path.join("./log_dir", 'tensorboard'))
# random seed for reproducability
torch.manual_seed(1)
np.random.seed(1)
# create dataloaders
trainset = cylinder_dataset(
data_dir="./", data_filename="./CFD_DATA_space_avg_before_cyl.mat", center_data_filename="./CFD_DATA_center_before_cyl.mat",
nx=26, ny=26, nt=3, n_samp_pts_per_crop=2028
)
############# check what this is
train_sampler = RandomSampler(trainset, replacement=True, num_samples=2028)
train_loader = DataLoader(trainset, batch_size=5, shuffle=False, drop_last=True,
sampler=train_sampler, num_workers=2, pin_memory=True)
# train_sampler = RandomSampler(trainset, replacement=True, num_samples=3072)
# train_loader = DataLoader(trainset, batch_size=16, shuffle=False, drop_last=True,
# num_workers=2, pin_memory=True)
# setup model
cnn_fc = CNN_FC(in_features=2, out_features=3, nf=13, activation=torch.nn.SiLU)
def init_weights(m):
if isinstance(m, nn.Conv3d):
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear):
torch.nn.init.xavier_normal_(m.weight)
m.bias.data.fill_(0.0)
cnn_fc.apply(init_weights)
all_model_params = list(cnn_fc.parameters())
optimizer = optim.Adam(all_model_params, lr=1e-2)
start_ep = 0
global_step = np.zeros(1, dtype=np.uint32)
cnn_fc.to(device)
model_param_count = lambda model: sum(x.numel() for x in model.parameters())
print("{} cnn_fc paramerters in total".format(model_param_count(cnn_fc)))
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# training loop
for epoch in range(start_ep + 1, 20 + 1):
# loss = train(cnn_fc, train_loader, epoch, global_step, device, optimizer)
# tot_loss = 0
# count = 0
for batch_idx, data_tensors in enumerate(train_loader):
# send tensors to device
data_tensors = [t.to(device) for t in data_tensors]
# [space_time_crop, center_space_time_crop, center_space_time_crop_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord]
input_grid, center_input_grid, center_input_grid_reshaped, lres_coord, pde_mesh_coord, pde_point_coord, hres_coord = data_tensors # input_grid is lowres data, point coord and point values are lowres data in dataset. data_tensor is [space_time_crop, center_space_time_crop, lres_coord, pde_mesh_coord, pde_point_coord]
optimizer.zero_grad()
# input_grid = input_grid.float()
# center_input_grid = center_input_grid.float()
# center_input_grid_reshaped = center_input_grid_reshaped.float()
# lres_coord = lres_coord.float().requires_grad_(True)
pde_mesh_coord = pde_mesh_coord.float().requires_grad_(True)
pde_point_coord = pde_point_coord.float().requires_grad_(True)
t_fid = lres_coord[:,:,0]
t_fid = torch.reshape(t_fid, (t_fid.shape[0], t_fid.shape[1], 1))
y_fid = lres_coord[:,:,1]
y_fid = torch.reshape(y_fid, (y_fid.shape[0], y_fid.shape[1], 1))
x_fid = lres_coord[:,:,2]
x_fid = torch.reshape(x_fid, (x_fid.shape[0], x_fid.shape[1], 1))
t_pde = pde_point_coord[:,:,0]
t_pde = torch.reshape(t_pde, (t_pde.shape[0], t_pde.shape[1], 1))
y_pde = pde_point_coord[:,:,1]
y_pde = torch.reshape(y_pde, (y_pde.shape[0], y_pde.shape[1], 1))
x_pde = pde_point_coord[:,:,2]
x_pde = torch.reshape(x_pde, (x_pde.shape[0], x_pde.shape[1], 1))
t_eval = hres_coord[:,:,0]
t_eval = torch.reshape(t_eval, (t_eval.shape[0], t_eval.shape[1], 1))
y_eval = hres_coord[:,:,1]
y_eval = torch.reshape(y_eval, (y_eval.shape[0], y_eval.shape[1], 1))
x_eval = hres_coord[:,:,2]
x_eval = torch.reshape(x_eval, (x_eval.shape[0], x_eval.shape[1], 1))
# the weights of the loss terms
w_fid = 1.0
w_pde = 0.001
# normalizing the weights such that their sum equals 1
w_sum = w_fid + w_pde
w_fid = float(w_fid/w_sum)
w_pde = float(w_pde/w_sum)
def loss_terms():
fid = fidelity_loss(cnn_fc, input_grid, t_fid, y_fid, x_fid, center_input_grid_reshaped)
pde = pde_loss_f(cnn_fc, input_grid, t_pde, y_pde, x_pde)
return fid, pde
def total_loss():
fid, pde = loss_terms()
return w_fid*fid + w_pde*pde
foo = loss_terms()
fid_loss, pde_loss = foo[0], foo[1]
loss = w_fid*fid_loss + w_pde*pde_loss
print(loss.grad)
# loss.register_hook(lambda grad: print(grad))
loss.backward()
# gradient clipping
# torch.nn.utils.clip_grad_value_(cnn_fc.parameters(), clip_value=1.0)
# torch.nn.utils.clip_grad_value_(imnet.module.parameters(), args.clip_grad)
optimizer.step()
# tot_loss += loss.item()
# count += input_grid.size()[0]
reg_loss_list.append((w_fid * fid_loss).detach().cpu())
pde_loss_list.append((w_pde * pde_loss).detach().cpu())
tot_loss_list.append((loss).detach().cpu())
time0 = time()
if batch_idx % 10 == 0:
# logger log
print("Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss Sum: {:.4e}\t"
"Loss Reg: {:.4e}\tLoss Pde: {:.4e}\t"
"Ten Iters Time: {:.4e}".format(
epoch, batch_idx * len(input_grid), len(train_loader) * len(input_grid),
100. * batch_idx / len(train_loader), loss.item(),
w_fid * fid_loss, w_pde * pde_loss, time()-time0))
if batch_idx % 100 == 0 and batch_idx > 100:
fig, ax = pylab.subplots()
pylab.plot(tot_loss_list, label='total loss')
pylab.plot(reg_loss_list, label='reg loss')
pylab.plot(pde_loss_list, label='pde loss')
pylab.legend(loc='upper right')
ax.set_yscale('log')
plt.show(block = False)
if batch_idx % 100 == 0 and batch_idx > 100:
pred = cnn_fc(input_grid, t_eval, y_eval, x_eval)
print(pred.shape)
u_eval = pred[:,:,0].detach().cpu()
print(u_eval.shape)
u_eval = torch.reshape(u_eval, (u_eval.shape[0], 3, 78, 78)) # [b, t, y, x]
#u_eval = u_eval.permute(0, 1, 3, 2)
print(u_eval.shape)
fig = plt.figure()
ax1 = fig.add_subplot(121)
ax2 = fig.add_subplot(122)
ax1.imshow(input_grid[2, 0, :, :, 1].detach().cpu())
ax2.imshow(u_eval[2,1,:,:])
plt.show()
global_step += 1
# tot_loss /= count
scheduler.step(loss)
print('\Training is done')
FILE = "cnn_fc.pth"
torch.save(cnn_fc.state_dict(), FILE)
# loaded_cnn_fc = Model(in_features=2, out_features=3, nf=16, activation=torch.nn.Tanh)
# loaded_cnn_fc.load_state_dict(torch.load(FILE))
# loaded_cnn_fc.eval()
It returns:
None
Train Epoch: 1 [0/2025 (0%)] Loss Sum: 1.9692e+00 Loss Reg: 1.9692e+00 Loss Pde: 6.9719e-08 Ten Iters Time: 1.6594e-04
None
None
None
None
None
None
None
None
None
None
Train Epoch: 1 [50/2025 (2%)] Loss Sum: 2.7263e+00 Loss Reg: 2.7263e+00 Loss Pde: 1.3230e-05 Ten Iters Time: 8.6784e-05
None
None
None
I think the issue is here:
# the weights of the loss terms
w_fid = 1.0
w_pde = 0.001
# normalizing the weights such that their sum equals 1
w_sum = w_fid + w_pde
w_fid = float(w_fid/w_sum)
w_pde = float(w_pde/w_sum)
def loss_terms():
fid = fidelity_loss(cnn_fc, input_grid, t_fid, y_fid, x_fid, center_input_grid_reshaped)
pde = pde_loss_f(cnn_fc, input_grid, t_pde, y_pde, x_pde)
return fid, pde
def total_loss():
fid, pde = loss_terms()
return w_fid*fid + w_pde*pde
foo = loss_terms()
fid_loss, pde_loss = foo[0], foo[1]
loss = w_fid*fid_loss + w_pde*pde_loss
print(loss.grad)
# loss.register_hook(lambda grad: print(grad))
loss.backward()
but I don’t know how to handle this problem