I have a neural network that needs to compute results sequentially, it reads in an initial state that’s learnable, and calculate step by step results based on last steps’ result. During training, there are no errors. But as I get to 2nd epoch, the time it takes to run backwards() starts to rise unreasonably. 1st epoch takes about 20s, but the 2nd epoch backward() takes hours. After searching online for a while, I found that it could be caused by not detaching some variables between epochs. I tried to detach the parameters shared between epochs, but don’t seem to have solved the bug. Any help is appreciated, thanks! Attached is the code for the complete model:
"""
Definitions for parametric functions M0, M1. Can use multiple features,
but only the first feature will be used to calculate labels.
"""
class M0_simplex(nn.Module):
def __init__(self,num_nodes):
super(M0_simplex, self).__init__()
# c0*x + c1
self.c0 = nn.Parameter(torch.tensor(0.0, requires_grad=True))
self.c1 = nn.Parameter(torch.tensor(0.0, requires_grad=True))
def reset_parameters(self):
nn.init.constant_(self.c0,-1.0)
nn.init.constant_(self.c1,1.0)
def get_feature_num(self):
return 1
def forward(self,x):
res = self.c0 * x + self.c1
return res
class M1_simplex(nn.Module):
def __init__(self,num_nodes):
super(M1_simplex, self).__init__()
# c0 * (x-x_neighb) + c1
self.c0 = nn.Parameter(torch.tensor(0.0, requires_grad=True))
self.c1 = nn.Parameter(torch.tensor(0.0, requires_grad=True))
def forward(self,x,neighb_x):
if x<0:
return 0
res = self.c0 * (x-neighb_x) + self.c1
return res
def reset_parameters(self):
nn.init.uniform_(self.c0)
nn.init.uniform_(self.c1)
"""
Definitions for architecture
"""
class info_flow(nn.Module):
def __init__(self,adj, M0, M1, num_nodes):
super(info_flow, self).__init__()
#Current implementation for binary labels only. Need to fix
self.adj = torch.tensor(adj)
self.num_nodes = num_nodes
self.M0 = M0(num_nodes)
self.M1 = M1(num_nodes)
self.features = self.M0.get_feature_num()
self.threshold = nn.Parameter(torch.tensor( 0.0, requires_grad=True ))
self.init_state = nn.Parameter(torch.zeros( (num_nodes,self.features), requires_grad=True ))
def fit(self,y,optim,device,epochs=50,summary_writer=None,add_tag='',steps=-1):
""" The function to fit information flow model
Args:
y: The target to be predicted. Needs to be one episode.
optim: The optimizer for model
device: The device for computation
epochs: The epochs to fit
summary_write: TensorboardX summary writer to keep track of acc
add_tag: The tag to use when adding acc/loss info in summary
Returns:
loss: The loss at last epoch
F1: The F1 score for each class
"""
assert y.shape[0] == self.num_nodes
self.train()
if steps == -1:
steps = y.shape[1]
else:
steps = min(steps, y.shape[1])
y = y.to(device)
self.adj = self.adj.to(device)
for epoch in range(epochs):
optim.zero_grad()
labels = self.forward(steps,device)
loss = F.nll_loss( torch.log(labels), y[:,:steps] )
print('Finished loss compute',epoch)
L = torch.argmax(labels,dim=1).detach().numpy()
F1 = f1_score( y[:,:steps].cpu().numpy().flatten(), L.flatten() )
loss.backward()
optim.step()
loss = loss.detach()
if summary_writer!=None: #Write in summary loss and F1 for each class
summary_writer.add_scalar( add_tag+'/loss', loss, epoch)
summary_writer.add_scalar( add_tag+'/F1', F1, epoch)
#Detach from next epoch's gradient computation
self.init_state.detach_()
self.threshold.detach_()
self.eval()
return loss, F1
def forward_step(self,x,device=None):
#x shape: [node_num, features]
if device == None:
res = torch.zeros( x.shape[0],x.shape[1] )
else:
res = torch.zeros( x.shape[0],x.shape[1], device=device )
for node in range(x.shape[0]):
self_feat = x[node,:]
res[node,:] = res[node,:] + self.M0( self_feat )
neighbs = (torch.abs(self.adj[:,node])>delta).nonzero()
for neighb_node in neighbs:
#Iterate over all neighbors to find sum of neighbor influences
w = self.adj[node,neighb_node]
res[node,:] = res[node,:] + self.M1( self_feat, x[neighb_node,:] ) * w
#return updated values at each node
return res
def forward(self,time_steps,device=None,label=True,binary=False):
""" This function produces the output for 'time_steps' steps
Args:
time_steps: How many steps of results to calculate
device: The device on which computations are done
label: A boolean variable for return type. If true, will return
the 2 labels' probability. If false, will return the raw features.
binary: A boolean variable for return type. If true, will return
max of 2 labels' probability.
Returns:
x: The raw features of size [num_node, time_steps, features] or
the label probabilities of size [num_node, classes, time_steps] or
the labels of size [num_node, time_steps]
"""
x = [self.init_state.unsqueeze(1)]
for step in range(time_steps-1):
x.append(self.forward_step(x[step],device).unsqueeze(2))
x = torch.cat( x, dim=1 )
if label:
#Threshold is the point where CSD vs. no CSD probs are 0.5 vs. 0.5
x = 0.5*x[:,:,0]/self.threshold
x[x>.95] = .95
x[x<.15] = .05 #Don't force to be 0 for numerical stability
x = torch.cat( [x.unsqueeze(2), (1-x).unsqueeze(2)], dim=2 ).transpose(2,1)
if binary:
x = torch.argmax( x, dim=1 )
return x
def reset_parameters(self):
nn.init.uniform_(self.threshold)
nn.init.uniform_(self.init_state)
self.M0.reset_parameters()
self.M1.reset_parameters()