Hello,
I’m trying to run a customized version of the seq2seq tutorial. The relevant snippets are included in the following gist:
decoder.py
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from ilmt.utils.gpu_profile import counter
class DecoderRNN(nn.Module):
    def __init__(self, hidden_size, output_size, num_layers):
        super(DecoderRNN, self).__init__()
        self.num_layers = num_layers
(truncated)
encoder.py
import torch
import torch.nn as nn
from torch.autograd import Variable
from ilmt.utils.gpu_profile import counter
class EncoderRNN(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers):
        super(EncoderRNN, self).__init__()
        self.num_layers = num_layers
        self.hidden_size = hidden_size
(truncated)
run.log
epoch: 0%| | 0/10 [00:00<?, ?it/s/home/jerin/.local/lib/python3.5/site-packages/nltk/app/__init__.py:44: UserWarning: nltk.app.wordfreq not loaded (requires the matplotlib library).
warnings.warn("nltk.app.wordfreq not loaded "
defaultdict(<class 'list'>,
{torch.Size([3, 383, 500]): [4],
torch.Size([383, 10]): [4],
torch.Size([383, 10, 500]): [4],
torch.Size([1500]): [24],
torch.Size([1500, 500]): [24],
torch.Size([3085]): [2],
torch.Size([3085, 500]): [4],
(truncated)
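The per-size counts in the log come from the counter helper imported from ilmt.utils.gpu_profile. Roughly, it does something like the following (a minimal sketch, the real helper differs in details): each call tallies the live CUDA tensors by shape and appends the tally, so any size whose list keeps growing is being allocated again on every call.

```python
import gc
from collections import defaultdict

import torch

_history = defaultdict(list)


def counter():
    # Count live CUDA tensors, grouped by shape.
    current = defaultdict(int)
    for obj in gc.get_objects():
        if torch.is_tensor(obj) and obj.is_cuda:
            current[obj.size()] += 1
    # Append this snapshot so sizes that grow across calls stand out.
    for size, count in current.items():
        _history[size].append(count)
    print(_history)
```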
I’ve tracked it down to the nn.GRU calls in the encoder and decoder: a torch.cuda.FloatTensor of size 4509000 keeps getting added every time forward is called, eventually running out of memory. As far as I understand, the trouble happens after the GRU call. I’m not very familiar with the internals, so I’m not sure how to dig deeper. After a few more iterations the counts look like this:
{torch.Size([3, 383, 500]): [4, 2, 2, 2, 2],
torch.Size([383, 10, 500]): [4, 2, 2, 2, 2],
torch.Size([1500]): [24, 24, 24, 24, 24],
torch.Size([1500, 500]): [24, 24, 24, 24, 24],
torch.Size([3085]): [2, 2, 2, 2, 2],
torch.Size([3085, 500]): [4, 4, 4, 4, 4],
torch.Size([3258]): [2, 2, 2, 2, 2],
torch.Size([3258, 500]): [4, 4, 4, 4, 4],
**torch.Size([4509000]): [1, 2, 3, 4, 5],**
torch.Size([153200000]): [1, 1, 1, 1, 1]})
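To narrow it down, this is the kind of standalone check I’m planning to run next (the dimensions are guesses loosely based on the shapes above, not my actual code): if the per-size counts still grow in this loop, the GRU call itself is accumulating tensors; if they stay flat, the problem is in how I hold on to the outputs or hidden state elsewhere.

```python
import torch
import torch.nn as nn
from torch.autograd import Variable

from ilmt.utils.gpu_profile import counter

# Hypothetical dims taken from the logged shapes: 3 layers,
# hidden size 500, batch 383, sequence length 10.
gru = nn.GRU(500, 500, num_layers=3).cuda()

for step in range(5):
    # Fresh inputs and hidden state every iteration; nothing is kept around,
    # so the per-size counts should stay flat if the GRU itself is fine.
    inputs = Variable(torch.randn(10, 383, 500).cuda())
    hidden = Variable(torch.zeros(3, 383, 500).cuda())
    output, hidden = gru(inputs, hidden)
    counter()
```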
What could I do further to isolate where this is happening and fix it?