Hey everyone — I'm running out of GPU memory during my backward calls, even though I'm already training a single sample at a time. I suspect operations like torch.cat
might be contributing, but I can't figure out where the memory is actually accumulating.
I'm using the torch_geometric
package for graph neural network learning, combined with a custom algorithm. Here's a general overview of what my code does:
class network(torch.nn.Module):
    """Graph network: three rounds of NNConv message passing whose outputs
    drive a GRU, followed by a pairwise scoring head and loopy belief
    propagation (``sim_lbp``, defined elsewhere).

    Args:
        in_channels: node feature size fed to the MPNN.
        out_channels: node representation size produced by the MPNN/GRU.
        h_size: hidden size used by the edge network and the scoring head.
        e_size: edge attribute size.
        n_embed: number of embedding rows for ``data.p_embed`` lookups.
    """

    def __init__(self, in_channels, out_channels, h_size, e_size, n_embed, **kwargs):
        super(network, self).__init__(**kwargs)
        # Feature sizes
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.h_size = h_size
        self.n_embed = n_embed
        # Embeddings
        self.embed = torch.nn.Embedding(n_embed, out_channels)
        # MPNN: self.wt maps each edge attribute to a flattened
        # (in_channels x out_channels) weight consumed by NNConv.
        self.wt = torch.nn.Sequential(
            torch.nn.Linear(e_size, h_size * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(h_size * 2, in_channels * out_channels))
        self.mpnn = torch_geometric.nn.NNConv(in_channels, out_channels, self.wt)
        self.m_bn = torch.nn.BatchNorm1d(out_channels)
        self.m_relu = torch.nn.ReLU()
        self.gru = torch.nn.GRU(out_channels, out_channels)
        self.g_bn = torch.nn.BatchNorm1d(out_channels)
        # Weight translation layer: concatenated node pair -> scalar score.
        self.weight_transform = torch.nn.Sequential(
            torch.nn.Linear(out_channels * 2, h_size * 4),
            torch.nn.BatchNorm1d(h_size * 4),
            torch.nn.ReLU(),
            torch.nn.Linear(h_size * 4, h_size * 2),
            torch.nn.BatchNorm1d(h_size * 2),
            torch.nn.ReLU(),
            torch.nn.Linear(h_size * 2, 1),
            torch.nn.BatchNorm1d(1))

    def reset_parameters(self):
        """Re-initialize the learnable message-passing and GRU weights."""
        self.mpnn.reset_parameters()
        self.gru.reset_parameters()

    def forward(self, data):
        """Run message passing, score candidate node pairs, and return the
        second belief set from loopy belief propagation.

        Assumes ``data`` carries ``p_feat``, ``d_feat``, ``p_embed``,
        ``real_edges_original`` (E x 2), ``real_edge_attr`` (E,), and
        ``assign_edge_index`` (2 x A) — TODO confirm against the caller.
        """
        n_nodes = data.d_feat.size(0)
        x = torch.cat((data.p_feat, data.d_feat), dim=0)
        # Make the graph undirected by appending every edge reversed.
        # flip(1) swaps the two endpoint columns and is equivalent to the
        # original index_select(1, LongTensor([1, 0])), but avoids building
        # a fresh CPU tensor and transferring it to the device on every
        # forward call. BUG FIX: the original referenced
        # `data.pax_feat.device` — the attribute used everywhere else is
        # `p_feat`, so that line would raise AttributeError.
        r_edge_index = torch.cat((data.real_edges_original,
                                  data.real_edges_original.flip(1)), dim=0).t()
        r_edge_attr = torch.cat((data.real_edge_attr, data.real_edge_attr),
                                dim=0).unsqueeze(1)
        # Three rounds of message passing; the GRU hidden state is seeded
        # from the first round's (normalized, rectified) messages.
        h = None
        for _ in range(3):
            m = self.mpnn(x, edge_index=r_edge_index, edge_attr=r_edge_attr)
            m = self.m_bn(m)
            m = self.m_relu(m)
            if h is None:
                h = m.clone().unsqueeze(0)
            out, h = self.gru(m.unsqueeze(0), h)
        out = out.squeeze(0)
        out = self.g_bn(out)
        # Get embeddings
        embeds = self.embed(data.p_embed)
        # Stack node representations: p-nodes, embeddings, d-nodes.
        real_p = out[:data.p_feat.size(0), :]
        real_d = out[data.p_feat.size(0):, :]
        all_feat = torch.cat((real_p, embeds, real_d), dim=0)
        # Score every candidate pair with the weight head.
        node_pairs = torch.cat((all_feat[data.assign_edge_index[0, :]],
                                all_feat[data.assign_edge_index[1, :]]), dim=1)
        wts = self.weight_transform(node_pairs)
        # MEMORY FIX: the original materialized a dense (2n x 2n) matrix
        # only to keep its [:n, n:] quadrant — 4x the needed memory, held
        # in the autograd graph until backward(). Build the (n x n)
        # quadrant directly; scores falling outside it were discarded by
        # the original slice, so masking them out is behavior-identical.
        rows = data.assign_edge_index[0, :]
        cols = data.assign_edge_index[1, :]
        in_quadrant = (rows < n_nodes) & (cols >= n_nodes)
        wt_mat = torch.zeros((n_nodes, n_nodes), device=data.p_feat.device)
        wt_mat[rows[in_quadrant], cols[in_quadrant] - n_nodes] = \
            wts.view(-1)[in_quadrant]
        # LBP (defined elsewhere); returns two belief sets, keep the second.
        b_a, b_b = self.sim_lbp(n_nodes, wt_mat, 2)
        return b_b
The sim_lbp
function just keeps copies of b_a
and b_b
for its updates, and performs only max
and sum
operations.
I'm stumped at this stage. My setup: PyTorch 1.4.0, CUDA 10.1, on an NVIDIA V100 16 GB GPU.