Hi,
I'm a newcomer to PyTorch.
I am trying to replicate a scientific ML model (a DeepONet) in PyTorch, which is basically two sub-networks combined by an inner product at the end. The TL;DR version: two feed-forward networks with ReLU activations in all but the last layer, an inner product over their outputs, and a standard MSE loss.
Training usually requires a large number of paired input-output samples. I am attaching my code below, with two objectives:
- Speeding up the overall training somehow… A JAX version of the same code is quite fast but hard to read for a newcomer.
- Understanding why the num_workers > 1 case fails (my current guess is sketched after the DataGenerator class below).
import torch
from torch import nn
from torch.utils import data
import pickle, random
import numpy as np
from tqdm.notebook import tqdm, trange
# Branch Net
class BranchNet(nn.Module):
def __init__(self, num_branch_inputs, width, depth):
super(BranchNet, self).__init__()
self.num_branch_inputs = num_branch_inputs
self.width = width
        self.depth = depth  # NOTE: stored but currently unused; the layers below are hardwired
self.branch = nn.Sequential(
nn.Linear(num_branch_inputs, width),
nn.ReLU(),
nn.Linear(width, width),
nn.ReLU(),
nn.Linear(width, width)
)
def _weight_init(self):
for layer in self.branch:
if isinstance(layer, nn.Linear):
nn.init.xavier_normal_(layer.weight)
if layer.bias is not None:
nn.init.constant_(layer.bias, 0.)
def forward(self, x):
return self.branch(x)
## Trunk Net
class TrunkNet(nn.Module):
def __init__(self, num_trunk_inputs, width, depth):
super(TrunkNet, self).__init__()
        self.num_trunk_inputs = num_trunk_inputs
self.width = width
        self.depth = depth  # NOTE: stored but currently unused, as in BranchNet
self.trunk = nn.Sequential(
nn.Linear(num_trunk_inputs, width),
nn.ReLU(),
nn.Linear(width, width),
nn.ReLU(),
nn.Linear(width, width)
)
def _weight_init(self):
for layer in self.trunk:
if isinstance(layer, nn.Linear):
nn.init.xavier_normal_(layer.weight)
if layer.bias is not None:
nn.init.constant_(layer.bias, 0.)
def forward(self, x):
return self.trunk(x)
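## (Aside, a sketch: BranchNet and TrunkNet above are identical MLPs, and `depth` is
## stored but never used. Assuming `depth` should count the Linear layers, a single
## parameterized module could replace both. `FeedForwardNet` is a name I made up here,
## not part of the original model.)
class FeedForwardNet(nn.Module):
    def __init__(self, num_inputs, width, depth):
        super(FeedForwardNet, self).__init__()
        layers = [nn.Linear(num_inputs, width), nn.ReLU()]
        for _ in range(depth - 2):                 # hidden layers
            layers += [nn.Linear(width, width), nn.ReLU()]
        layers.append(nn.Linear(width, width))     # last layer, no activation
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)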
## Total
class DeepONet:
def __init__(
self, branch_width, branch_depth, trunk_width, trunk_depth, num_branch_inputs, num_trunk_inputs
):
        self._dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.branch_net = BranchNet(num_branch_inputs, branch_width, branch_depth)
self.trunk_net = TrunkNet(num_trunk_inputs, trunk_width, trunk_depth)
self.branch_net._weight_init()
self.trunk_net._weight_init()
self.branch_net.to(self._dev)
self.trunk_net.to(self._dev)
self.loss_log = []
def get_loss(self, inputs, outputs):
# inputs = inputs.to(self._dev)
# outputs = outputs.to(self._dev)
u, y = inputs
        pred_branch = self.branch_net(u)  # already on self._dev, no extra transfer needed
        pred_trunk = self.trunk_net(y)
        pred = torch.sum(pred_branch * pred_trunk, dim=1)  # per-sample inner product, shape (batch_size,)
loss = torch.nn.functional.mse_loss(pred.flatten(), outputs.flatten())
return loss
def train(self, num_epochs, lr, u_train, y_train, s_train):
dataset = DataGenerator(u_train, y_train, s_train, self._dev)
dataloader = data.DataLoader(dataset, batch_size=10000, shuffle=False, num_workers=8) ### THE CODE WORKS if `num_workers=0`
optimizer = torch.optim.Adam([*self.branch_net.parameters(), *self.trunk_net.parameters()], lr=lr)
for epoch in tqdm(range(num_epochs)):
for i, (inputs, outputs) in tqdm(enumerate(dataloader)):
loss = self.get_loss(inputs, outputs)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if epoch % 100 == 0:
print(f'Epoch {epoch}: {loss.item()}')
self.loss_log.append(loss.item())
return self.loss_log
def save(self, path):
torch.save(self.branch_net.state_dict(), path+"branch_net.pth")
torch.save(self.trunk_net.state_dict(), path+"trunk_net.pth")
with open(path+"loss_log.pkl", "wb") as f:
pickle.dump(self.loss_log, f)
def load(self, path):
        self.branch_net.load_state_dict(torch.load(path+"branch_net.pth", map_location=self._dev))
        self.trunk_net.load_state_dict(torch.load(path+"trunk_net.pth", map_location=self._dev))
with open(path+"loss_log.pkl", "rb") as f:
self.loss_log = pickle.load(f)
self.branch_net.eval()
self.trunk_net.eval()
    def predict(self, inputs, outputs):
        u, y = inputs
        with torch.no_grad():  # inference only, no autograd bookkeeping
            pred_branch = self.branch_net(u)
            pred_trunk = self.trunk_net(y)
            pred = torch.sum(pred_branch * pred_trunk, dim=1)  # per-sample inner product, shape (batch_size,)
            loss = torch.nn.functional.mse_loss(pred.flatten(), outputs.flatten())
        return loss.item()
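## (A speed-up idea, sketched under the assumption that the data fits in GPU memory —
## 676000 x 26 floats is roughly 70 MB: move the full tensors to the device once and
## slice random mini-batches directly, skipping the Dataset/DataLoader machinery and
## its per-sample indexing. `train_gpu_resident` is a hypothetical helper of mine.)
def train_gpu_resident(net, num_epochs, lr, u_train, y_train, s_train, batch_size=10000):
    dev = net._dev
    u, y, s = u_train.to(dev), y_train.to(dev), s_train.to(dev)  # one-time transfer
    optimizer = torch.optim.Adam([*net.branch_net.parameters(), *net.trunk_net.parameters()], lr=lr)
    n = u.shape[0]
    for epoch in range(num_epochs):
        perm = torch.randperm(n, device=dev)       # shuffle indices on the device
        for start in range(0, n, batch_size):
            idx = perm[start:start + batch_size]
            loss = net.get_loss((u[idx], y[idx]), s[idx])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        net.loss_log.append(loss.item())
    return net.loss_log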
## Data Generator
class DataGenerator(data.Dataset):
def __init__(self, u, y, s, dev):
'Initialization'
self.u = u
self.y = y
self.s = s
self._dev = dev
self.N = u.shape[0]
def __getitem__(self, index):
        'Generate one sample (the DataLoader assembles the batches)'
inputs = (self.u[index, :].to(self._dev), self.y[index, :].to(self._dev))
outputs = self.s[index, :].to(self._dev)
return inputs, outputs
def __len__(self):
        'Denotes the total number of samples'
return self.N
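## (My current guess for the num_workers > 1 failure, not verified: the per-sample
## .to(self._dev) calls in __getitem__ run inside forked worker processes, and CUDA
## cannot be re-initialized in a forked subprocess. A sketch of the usual pattern —
## the dataset returns CPU tensors, and the training loop moves each batch to the device:)
class CPUDataGenerator(data.Dataset):
    def __init__(self, u, y, s):
        self.u, self.y, self.s = u, y, s
    def __getitem__(self, index):
        # no device transfer here; workers only ever touch CPU tensors
        return (self.u[index, :], self.y[index, :]), self.s[index, :]
    def __len__(self):
        return self.u.shape[0]
# ...and then, at the top of get_loss, something like:
#     u, y = inputs
#     u, y, outputs = u.to(self._dev), y.to(self._dev), outputs.to(self._dev)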
## Load the data
# =========== CONSIDER FOR THE TIME BEING RANDOM DATA ===========
u_train = torch.rand(676000, 26)
y_train = torch.rand(676000, 2)
s_train = torch.rand(676000, 1)
# u_train = torch.from_numpy(np.load('u_train.npz')['u_train'])
# y_train = torch.from_numpy(np.load('y_train.npz')['y_train'])
# s_train = torch.from_numpy(np.load('s_train.npz')['s_train'])
deep_o_net = DeepONet(branch_width=10, branch_depth=4, trunk_width=10, trunk_depth=4, num_branch_inputs=u_train.shape[1], num_trunk_inputs=y_train.shape[1])
# ************************ Can I somehow speed up the training? **********************************
deep_o_net.train(num_epochs=1000, lr=0.0001, u_train=u_train, y_train=y_train, s_train=s_train) # extremely slow
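# For reference, the GPU-resident loop sketched above (the hypothetical train_gpu_resident)
# would replace the call above with something like:
# train_gpu_resident(deep_o_net, num_epochs=1000, lr=0.0001, u_train=u_train, y_train=y_train, s_train=s_train)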