Following code includes a network that combines a convolutional network and a fully connected net. I have built up a custom dataloader and also custom loss functions. This has two loss functions, one for regression and another for the underlying physics (pde_loss). however, the pde_loss return zero after the first iteration. Also, regression loss remains the same, and the training doesn’t converge.
I would appreciate if anyone could help me with this issue.
"""
Class for Convolutional Network
with ResNet backbone
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
import math
import warnings
from torch.autograd import Variable
class ResBlock3D(nn.Module):
"""3D convolutional Residue Block. Maintains same resolution.
"""
def __init__(self, in_channels, neck_channels, out_channels, final_relu=True):
"""Initialization.
Args:
in_channels: int, number of input channels.
neck_channels: int, number of channels in bottleneck layer.
out_channels: int, number of output channels.
final_relu: bool, add relu to the last layer.
"""
super(ResBlock3D, self).__init__()
self.in_channels = in_channels
self.neck_channels = neck_channels
self.out_channels = out_channels
self.conv1 = nn.Conv3d(in_channels, neck_channels, kernel_size=1, stride=1)
self.conv2 = nn.Conv3d(neck_channels, neck_channels, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv3d(neck_channels, out_channels, kernel_size=1, stride=1)
self.bn1 = nn.BatchNorm3d(num_features=neck_channels)
self.bn2 = nn.BatchNorm3d(num_features=neck_channels)
self.bn3 = nn.BatchNorm3d(num_features=out_channels)
self.shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=1, stride=1)
self.final_relu = final_relu
def forward(self, x): # pylint:
identity = x
x = self.conv1(x)
x = self.bn1(x)
x = F.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = F.relu(x)
x = self.conv3(x)
x = self.bn3(x)
x += self.shortcut(identity)
if self.final_relu:
x = F.relu(x)
return x
class CNN_FC(nn.Module):
def __init__(self, in_features=2, out_features=3, nf=16,
activation=torch.nn.Tanh):
super(CNN_FC, self).__init__()
self.nf = nf
self.in_features = in_features
self.out_features = out_features
self.activ = activation()
self.conv_in = ResBlock3D(self.in_features, self.nf, self.nf) # ResBlock3D(in=2, neck=16, out=16)
self.conv11 = ResBlock3D(self.nf, self.nf, self.nf*2)
self.conv12 = ResBlock3D(self.nf*2, self.nf*2, self.nf*4)
self.conv13 = ResBlock3D(self.nf*4, self.nf*4, self.nf*8)
self.conv14 = ResBlock3D(self.nf*8, self.nf*8, self.nf*16)
self.convs = [self.conv_in, self.conv11, self.conv12, self.conv13, self.conv14]
self.convs = nn.ModuleList(self.convs)
self.maxpool11 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
self.maxpool12 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 1), padding=0, dilation=1, ceil_mode=False)
self.maxpool13 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(2, 2, 2), padding=0, dilation=1, ceil_mode=False)
self.maxpool14 = nn.MaxPool3d(kernel_size=(2, 2, 1), stride=(1, 1, 2), padding=0, dilation=1, ceil_mode=False)
self.maxpools = [self.maxpool11, self.maxpool12, self.maxpool13, self.maxpool14]
self.maxpools = nn.ModuleList(self.maxpools)
self.flatten1 = nn.Flatten()
self.flatten = [self.flatten1]
self.flatten = nn.ModuleList(self.flatten)
self.fc0 = nn.Linear(4, nf*32)
self.fc1 = nn.Linear(nf*32 , nf*16)
self.fc2 = nn.Linear(nf*16 , nf*8)
self.fc3 = nn.Linear(nf*8 , nf*4)
self.fc4 = nn.Linear(nf*4 , nf*2)
self.fc5 = nn.Linear(nf*2, out_features)
self.fc = [self.fc0, self.fc1, self.fc2, self.fc3, self.fc4, self.fc5]
self.fc = nn.ModuleList(self.fc)
def forward(self, c, t, y, x):
# first entry of x should be batch size! x = torch.randn(batch_size,numpoints,1)
# c is the output of convolutional prtion
c = self.conv_in(c)
c = self.conv11(c)
c = self.maxpool11(c)
c = self.conv12(c)
c = self.maxpool12(c)
c = self.conv13(c)
c = self.maxpool13(c)
c = self.conv14(c)
c = self.maxpool14(c)
c = self.flatten1(c)
c = c.unsqueeze(-1)
c = c.repeat(1,int(x.shape[1]/c.shape[1]),1)
# c = Variable(c, requires_grad=True)
x_tmp = torch.cat((t, y, x, c), dim=-1)
x_tmp = self.fc0(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc1(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc2(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc3(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc4(x_tmp)
x_tmp = self.activ(x_tmp)
x_tmp = self.fc5(x_tmp)
return x_tmp
import os
import torch
from torch.utils.data import Dataset, Sampler
import numpy as np
from scipy.interpolate import RegularGridInterpolator
from scipy import ndimage
from scipy.io import savemat, loadmat
import warnings
# in __init__ load lres data (entire data, generated in matlab)
#
class cylinder_dataset(Dataset):
def __init__(self, data_dir="./", data_filename="./CFD_DATA_space_avg.mat", center_data_filename="./CFD_DATA_centers.mat",
nx=16, ny=16, nt=3, n_samp_pts_per_crop=3000):
self.nt = nt
self.nx = nx
self.ny = ny
self.data_dir = data_dir
self.data_filename = data_filename
self.center_data_filename = center_data_filename
self.n_samp_pts_per_crop = n_samp_pts_per_crop
################## load low res data (training data)
npdata = loadmat(os.path.join(self.data_dir, self.data_filename))
self.data = np.stack([npdata['u_space_avg'], npdata['v_space_avg']], axis=0) # it seems like each of p, b, u, and w are already [200, 512, 128] shaped arrays
self.data = self.data.astype(np.float32)
#self.data = self.data.transpose(0, 3, 2, 1) # [c, t, y, x] # c is number of channels
nc_data, nt_data, ny_data, nx_data = self.data.shape
# if highres: nc = 2, nt = 1501, ny = 80, nx = 640
# if lowres (by factor of 5): nc=2, nt=1501, ny=16, nx=128
self.nx_start_range = np.arange(0, nx_data-nx+1) # 0 to 560 fo high res, 0 to 112 for low res
self.ny_start_range = np.arange(0, ny_data-ny+1) # returns an array shape (1,) ([0])
self.nt_start_range = np.arange(0, nt_data-nt+1) # 0 to 1498
self.rand_grid = np.stack(np.meshgrid(self.nt_start_range,
self.ny_start_range,
self.nx_start_range, indexing='ij'), axis=-1)
self.rand_start_id = self.rand_grid.reshape([-1, 3])
# creates a indexing array for moving window, that starts from [0,0,0] ([t,y,x]), first moves along x, then moves along t (y has only a single index), such that, index 0 returns [0,0,0], index 1 returns [0,0,1], index 113 (for low res input data) returns [1,0,0], and so on
self.num_samples = self.rand_start_id.shape[0]
################## load center data (fidelity data)
center_npdata = loadmat(os.path.join(self.data_dir, self.center_data_filename))
self.center_data = np.stack([center_npdata['u_space_center'], center_npdata['v_space_center']], axis=0) # it seems like each of p, b, u, and w are already [200, 512, 128] shaped arrays
self.center_data = self.center_data.astype(np.float32)
def __getitem__(self, index):
t_id, y_id, x_id = self.rand_start_id[index] # idx is the id of the crop of the data that is passed to Dataloader.
space_time_crop = self.data[:,
t_id:t_id+self.nt,
y_id:y_id+self.ny,
x_id:x_id+self.nx] # [c, t, y, x] c is the channel
space_time_crop = np.transpose(space_time_crop, (0, 2, 3, 1))
#space_time_crop = np.swapaxes(space_time_crop, 1, 3)
center_space_time_crop = self.center_data[:,
t_id:t_id+self.nt,
y_id:y_id+self.ny,
x_id:x_id+self.nx] # [c, t, y, x] c is the channel
center_space_time_crop = np.transpose(center_space_time_crop, (0, 2, 3, 1)) # [c, y, x, t]
center_space_time_crop_reshaped = np.reshape(center_space_time_crop.transpose(1,2,3,0), (-1, 2)) # [t,y,x,c]
lres_coord = np.stack(np.meshgrid(np.linspace(0, self.nt-1, self.nt),
np.linspace(0, self.ny-1, self.ny),
np.linspace(0, self.nx-1, self.nx),
indexing='ij'), axis=-1)
lres_coord = np.reshape(lres_coord, (self.nt*self.nx*self.ny,3))
# lres_coord returns an array that includes coordinates of points such that, [0,:,:,:] is the coordinates of the first slice (t=0) that marches first on x (index 2) then on y (index 1)
pde_mesh_coord = np.stack(np.meshgrid(np.random.rand(10)*[self.nt-1],
np.random.rand(300)*[self.ny-1],
np.random.rand(300)*[self.nx-1],
indexing='ij'), axis=-1)
# creates an array including 10 slices in time with 300*300 random points in y and x
pde_point_coord = np.random.rand(self.n_samp_pts_per_crop, 3) * (np.array([3, 16, 16], dtype=np.int32) - 1)
return_tensors = [space_time_crop, center_space_time_crop, center_space_time_crop_reshaped, lres_coord, pde_mesh_coord, pde_point_coord]
return tuple(return_tensors)
def __len__(self):
return self.rand_start_id.shape[0]
#===============================================================================
# fidelity_loss
#===============================================================================
def fidelity_loss(model, input_image, t, y, x, fid_image_tensor):
output_tensor = model(input_image, t, y, x) # output shape: ([batch, num_points, 3])
fid_tensor_u = fid_image_tensor[:,:,0]
fid_tensor_v = fid_image_tensor[:,:,1]
output_u = output_tensor[:,:,0]
output_v = output_tensor[:,:,1]
fid_u_loss = torch.mean(torch.square(output_u-fid_tensor_u))
fid_v_loss = torch.mean(torch.square(output_v-fid_tensor_v))
# fid_image_tensor: ([batch, 768, 2(channel)])
return fid_u_loss + fid_v_loss
#===============================================================================
# pde_loss
#===============================================================================
def pde_loss_f(model, input_image, t, y, x):
output_tensor = model(input_image, t, y, x) # [batch, num_points, 3]
U = output_tensor[:,:,0].unsqueeze(-1)
V = output_tensor[:,:,1].unsqueeze(-1)
P = output_tensor[:,:,2].unsqueeze(-1)
U_t = torch.autograd.grad(U.sum(), t, create_graph=True, retain_graph=True, allow_unused=True)[0]
U_y = torch.autograd.grad(U.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
U_yy = torch.autograd.grad(U_y.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
U_x = torch.autograd.grad(U.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
U_xx = torch.autograd.grad(U_x.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_t = torch.autograd.grad(V.sum(), t, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_y = torch.autograd.grad(V.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_yy = torch.autograd.grad(V_y.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_x = torch.autograd.grad(V.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
V_xx = torch.autograd.grad(V_x.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
P_y = torch.autograd.grad(P.sum(), y, create_graph=True, retain_graph=True, allow_unused=True)[0]
P_x = torch.autograd.grad(P.sum(), x, create_graph=True, retain_graph=True, allow_unused=True)[0]
momentum_x_loss = torch.mean(torch.square(
(U_t) - ((1/160)*(U_xx + U_yy)) + (U*U_x + V*U_y) + (P_x)
))
momentum_y_loss = torch.mean(torch.square(
(V_t) - ((1/160)*(V_xx + V_yy)) + (U*V_x + V*V_y) + (P_y)
))
continuity_loss = torch.mean(torch.square(
(U_x) + (V_y)
))
return momentum_x_loss + momentum_y_loss + continuity_loss
########## train #############
import argparse
import json
import os
from glob import glob
import numpy as np
from collections import defaultdict
np.set_printoptions(precision=4)
import torch
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader, RandomSampler
from torch.utils.tensorboard import SummaryWriter
####### train function ##########
def train(CNN_FC, train_loader, epoch, global_step, device,
logger, writer, optimizer):
"""Training function."""
CNN_FC.train()
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# cnn_fc = CNN_FC()
# cnn_fc = cnn_fc.to(device)
# optimizer = torch.optim.Adam(cnn_fc.parameters(), lr=0.001)
#list(cnn_fc.parameters())
tot_loss = 0
count = 0
for batch_idx, data_tensors in enumerate(train_loader):
# send tensors to device
data_tensors = [t.to(device) for t in data_tensors]
# [space_time_crop, center_space_time_crop, center_space_time_crop_reshaped, lres_coord, pde_mesh_coord, pde_point_coord]
input_grid, center_input_grid, center_input_grid_reshaped, lres_coord, pde_mesh_coord, pde_point_coord = data_tensors # input_grid is lowres data, point coord and point values are lowres data in dataset. data_tensor is [space_time_crop, center_space_time_crop, lres_coord, pde_mesh_coord, pde_point_coord]
optimizer.zero_grad()
input_grid = Variable(input_grid.float(), requires_grad=False)
center_input_grid = Variable(center_input_grid.float(), requires_grad=False)
center_input_grid_reshaped = Variable(center_input_grid_reshaped.float(), requires_grad=False)
lres_coord = Variable(lres_coord.float(), requires_grad=False)
pde_mesh_coord = Variable(pde_mesh_coord.float(), requires_grad=True)
pde_point_coord = Variable(pde_point_coord.float(), requires_grad=True)
t_fid = lres_coord[:,:,0]
t_fid = torch.reshape(t_fid, (t_fid.shape[0], t_fid.shape[1], 1))
y_fid = lres_coord[:,:,1]
y_fid = torch.reshape(y_fid, (y_fid.shape[0], y_fid.shape[1], 1))
x_fid = lres_coord[:,:,2]
x_fid = torch.reshape(x_fid, (x_fid.shape[0], x_fid.shape[1], 1))
t_pde = pde_point_coord[:,:,0]
t_pde = torch.reshape(t_pde, (t_pde.shape[0], t_pde.shape[1], 1))
y_pde = pde_point_coord[:,:,1]
y_pde = torch.reshape(y_pde, (y_pde.shape[0], y_pde.shape[1], 1))
x_pde = pde_point_coord[:,:,2]
x_pde = torch.reshape(x_pde, (x_pde.shape[0], x_pde.shape[1], 1))
# the weights of the loss terms
w_fid = 1.0
w_pde = 1.0
# normalizing the weights such that their sum equals 1
w_sum = w_fid + w_pde
w_fid = float(w_fid/w_sum)
w_pde = float(w_pde/w_sum)
def loss_terms():
fid = fidelity_loss(CNN_FC, input_grid, t_fid, y_fid, x_fid, center_input_grid_reshaped)
pde = pde_loss_f(CNN_FC, input_grid, t_pde, y_pde, x_pde)
return fid, pde
def total_loss():
fid, pde = loss_terms()
return w_fid*fid + w_pde*pde
foo = loss_terms()
fid_loss, pde_loss = foo[0], foo[1]
loss = w_fid*fid_loss + w_pde*pde_loss
loss.backward()
################# CNN_FC.module.parameters() change to CNN_FC.parameters()
# gradient clipping
torch.nn.utils.clip_grad_value_(CNN_FC.parameters(), 1.)
# torch.nn.utils.clip_grad_value_(imnet.module.parameters(), args.clip_grad)
optimizer.step()
tot_loss += loss.item()
count += input_grid.size()[0]
if batch_idx % 10 == 0:
# logger log
logger.info(
"Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss Sum: {:.6f}\t"
"Loss Reg: {:.6f}\tLoss Pde: {:.6f}".format(
epoch, batch_idx * len(input_grid), len(train_loader) * len(input_grid),
100. * batch_idx / len(train_loader), loss.item(),
w_fid * fid_loss, w_pde * pde_loss))
# tensorboard log
writer.add_scalar('train/reg_loss_unweighted', fid_loss, global_step=int(global_step))
writer.add_scalar('train/pde_loss_unweighted', pde_loss, global_step=int(global_step))
writer.add_scalar('train/sum_loss', loss, global_step=int(global_step))
writer.add_scalars('train/losses_weighted',
{"reg_loss": w_fid * fid_loss,
"pde_loss": w_pde * pde_loss,
"sum_loss": loss}, global_step=int(global_step))
global_step += 1
tot_loss /= count
return tot_loss
########################
import train_utils as utils
from torch.utils.tensorboard import SummaryWriter
def main():
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
batch_size = int(torch.cuda.device_count()) * int(16)
# log and create snapshots
os.makedirs("./log_dir", exist_ok=True)
filenames_to_snapshot = glob("*.py") + glob("*.sh")
utils.snapshot_files(filenames_to_snapshot, "./log_dir")
logger = utils.get_logger(log_dir="./log_dir")
# tensorboard writer
writer = SummaryWriter(log_dir=os.path.join("./log_dir", 'tensorboard'))
# random seed for reproducability
torch.manual_seed(1)
np.random.seed(1)
# create dataloaders
trainset = cylinder_dataset(
data_dir="./", data_filename="./CFD_DATA_space_avg.mat", center_data_filename="./CFD_DATA_centers.mat",
nx=16, ny=16, nt=3, n_samp_pts_per_crop=3072
)
############# check what this is
train_sampler = RandomSampler(trainset, replacement=True, num_samples=3072)
train_loader = DataLoader(trainset, batch_size=16, shuffle=False, drop_last=True,
sampler=train_sampler, num_workers=1, pin_memory=True)
# setup model
cnn_fc = CNN_FC(in_features=2, out_features=3, nf=16, activation=torch.nn.Tanh)
all_model_params = list(cnn_fc.parameters())
optimizer = optim.Adam(all_model_params, lr=1e-2)
start_ep = 0
global_step = np.zeros(1, dtype=np.uint32)
tracked_stats = np.inf
# if args.resume:
# resume_dict = torch.load(args.resume)
# start_ep = resume_dict["epoch"]
# global_step = resume_dict["global_step"]
# tracked_stats = resume_dict["tracked_stats"]
# unet.load_state_dict(resume_dict["unet_state_dict"])
# imnet.load_state_dict(resume_dict["imnet_state_dict"])
# optimizer.load_state_dict(resume_dict["optim_state_dict"])
# for state in optimizer.state.values():
# for k, v in state.items():
# if isinstance(v, torch.Tensor):
# state[k] = v.to(device)
cnn_fc.to(device)
model_param_count = lambda model: sum(x.numel() for x in model.parameters())
logger.info("{} cnn_fc paramerters in total".format(model_param_count(cnn_fc)))
checkpoint_path = os.path.join("./log_dir", "checkpoint_latest.pth.tar")
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')
# training loop
for epoch in range(start_ep + 1, 100 + 1):
loss = train(cnn_fc, train_loader, epoch, global_step, device, logger, writer,
optimizer)
scheduler.step(loss)
if loss < tracked_stats:
tracked_stats = loss
is_best = True
else:
is_best = False
# "cnn_fc_state_dict": cnn_fc.module.state_dict()
utils.save_checkpoint({
"epoch": epoch,
"cnn_fc_state_dict": cnn_fc.state_dict(),
"optim_state_dict": optimizer.state_dict(),
"tracked_stats": tracked_stats,
"global_step": global_step,
}, is_best, epoch, checkpoint_path, "_pdenet", logger)
if __name__ == "__main__":
main()