I am training a small network as follows:
class Model(nn.Module):
    """Two-tower model that scores a (user, item) pair by cosine similarity.

    The "user" tower maps a numKW-wide input through two ReLU layers to an
    outSize-wide vector; the "item" tower does the same for a numUsers-wide
    input. forward() returns the cosine similarity of the two vectors,
    clamped to [eps, 1 - eps] so a caller can take both log(output) and
    log(1 - output) without producing inf/NaN.

    Memory note: the first layer of each tower is a dense weight matrix of
    shape (hiddenSize, inputDim). With numUsers ~ 1.3M and hiddenSize=256,
    item1.weight alone holds ~333M float32 values (~1.3 GB); its gradient
    needs the same again, and a stateful optimizer such as Adam allocates
    extra per-parameter buffers (two per parameter for Adam) lazily on the
    first optimizer.step() -- the point where an out-of-memory error would
    first surface. Memory scales linearly with hiddenSize. If the inputs are
    sparse/multi-hot (not visible here -- confirm), nn.EmbeddingBag with
    sparse=True plus a sparse-capable optimizer avoids the dense matrix.

    Args:
        numUsers: input width of the item tower.
        numKW: input width of the user tower.
        hiddenSize: width of each tower's hidden layer (default 256).
        outSize: width of each tower's output vector (default 32).
        eps: epsilon for CosineSimilarity and margin for the output clamp.
    """

    def __init__(self, numUsers, numKW, hiddenSize=256, outSize=32, eps=1e-6):
        super(Model, self).__init__()
        self.eps = eps
        self.user1 = nn.Linear(numKW, hiddenSize)
        self.user2 = nn.Linear(hiddenSize, outSize)
        self.item1 = nn.Linear(numUsers, hiddenSize)
        self.item2 = nn.Linear(hiddenSize, outSize)
        self.cos = nn.CosineSimilarity(dim=1, eps=eps)

    def forward(self, userBatch, itemBatch):
        """Return the clamped cosine similarity between the two tower outputs.

        Args:
            userBatch: float tensor, presumably (batch, numKW).
            itemBatch: float tensor, presumably (batch, numUsers).

        Returns:
            Similarities reduced over dim=1, clamped to [eps, 1 - eps].
        """
        output1 = F.relu(self.user1(userBatch))
        output1 = F.relu(self.user2(output1))
        output2 = F.relu(self.item1(itemBatch))
        output2 = F.relu(self.item2(output2))
        output = self.cos(output1, output2)
        # The upper bound must stay strictly below 1: a BCE-style loss takes
        # log(1 - output), which is -inf at output == 1 (and NaN once it is
        # multiplied by a zero label).
        return torch.clamp(output, self.eps, 1.0 - self.eps)
and my training loop is as follows:
def train(userData, kwData, clickData, model, optimizer):
    """Train ``model`` for ``epochs`` epochs using minibatches of ``batchSize``.

    Relies on the module-level settings ``epochs``, ``batchSize`` and
    ``device`` and on the helper ``prepare_batch``, all defined elsewhere in
    this file.

    Args:
        userData, kwData, clickData: parallel sequences sliced into
            minibatches; assumed to have equal length (TODO confirm).
        model: called as ``model(user_batch, kw_batch)``; its output is
            treated as a click probability in [0, 1].
        optimizer: optimizer over ``model``'s parameters.
    """
    # Ceil division. The previous ``len // batchSize + 1`` produced an extra,
    # empty batch whenever the length was an exact multiple of batchSize.
    numOfBatches = (len(userData) + batchSize - 1) // batchSize
    for _epoch in range(epochs):
        for i in range(numOfBatches):
            print(i)
            start, stop = i * batchSize, (i + 1) * batchSize
            user_batch, kw_batch = prepare_batch(userData[start:stop],
                                                 kwData[start:stop])
            # Build float32 tensors directly. Creating float64 on the device
            # and then calling .float() first allocates a copy twice the size
            # of the final tensor, which matters for very wide input rows.
            user_batch = torch.tensor(user_batch, dtype=torch.float32, device=device)
            kw_batch = torch.tensor(kw_batch, dtype=torch.float32, device=device)
            # A freshly built tensor does not require grad; no detach() needed.
            # NOTE(review): assumes clickData[start:stop] yields a 1-D target
            # matching the model output's shape (batch,) -- confirm.
            click_batch = torch.tensor(clickData[start:stop],
                                       dtype=torch.float32, device=device)

            optimizer.zero_grad()
            output = model(user_batch, kw_batch)
            # Equals -sum(y*log(p) + (1-y)*log(1-p)), but PyTorch clamps the
            # log terms, so p == 0 or p == 1 cannot turn the loss into inf/NaN.
            loss = F.binary_cross_entropy(output, click_batch, reduction='sum')
            loss.backward()
            optimizer.step()
In the model, numUsers is nearly 1.3 million and numKW is nearly 127,000. On the very first minibatch of the first epoch, I get a CUDA out-of-memory error at optimizer.step().
I tried reducing the minibatch size all the way down to 1, but GPU memory still explodes.
EDIT:
My network has two separate paths in the graph, which merge at `output = self.cos(output1, output2)`. I tried commenting out one of the paths and running the other, and each path runs fine on its own (with the other output initialised as a randn matrix), using up to 7.5 GB of GPU memory. Only when I join the two paths does GPU memory explode. Could anyone help in this direction?