Potential history leak


(Movses) #1

Hello, I’m getting an “out of memory” error when running my network on the GPU, and I found that the main problem is in the EncoderBlock layer. Every iteration of the first loop in that layer allocates new memory, so it seems there is a history leak. As I am a novice in PyTorch, I don’t really understand what is actually wrong. Please advise which part of the code is wrong.


class EncoderBlock(nn.Module):
    """Encoder block: positional encoding, then (repeated ``num_rep`` times with
    shared weights) a residual convolution stack, residual self-attention and a
    residual position-wise linear layer.

    Args:
        batch_size: unused; accepted only so existing callers keep working.
        d_model: feature size of the input and output.
        conv_num: number of convolution sub-layers.
        n_heads: number of self-attention heads.
        kern_sz: convolution kernel size.
        p: dropout probability.

    NOTE: all submodules are created on the default device; move the whole
    block (or model) with ``.to(device)`` rather than individual layers.
    """

    def __init__(self, batch_size, d_model, conv_num, n_heads, kern_sz, p=0.1):
        super(EncoderBlock, self).__init__()
        self.convs = nn.ModuleList([DepthwiseSeparableConv(d_model, d_model, kern_sz)
                                    for _ in range(conv_num)])
        self.self_att = SelfAttention(n_heads, d_model)
        # Created on the default device like every other submodule here; moving
        # only this one layer to the global ``device`` was inconsistent.
        self.W = nn.Linear(d_model, d_model)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p)

    def forward(self, x, num_rep=1):
        """Add the positional encoding once, then run the block ``num_rep`` times.

        Args:
            x: ``(batch, seq_len, d_model)`` tensor (layout fixed by ``pos_encoding``).
            num_rep: number of passes through the block; weights are shared.

        Returns:
            Tensor in the same ``(batch, seq_len, d_model)`` layout as ``x``.
        """
        # NOTE(review): every op below allocates a new activation tensor, and in
        # training mode autograd keeps those tensors alive until ``backward()``,
        # so memory grows with ``conv_num * num_rep``. That is expected and not a
        # leak in this block. A real "history leak" usually lives in the training
        # loop (e.g. ``total_loss += loss`` instead of ``loss.item()``). Run
        # evaluation under ``torch.no_grad()``.
        out = self.pos_encoding(x)
        for _ in range(num_rep):
            # (batch, seq_len, d_model) -> (batch, d_model, seq_len) for the convs.
            out = res = out.transpose(1, 2)

            # Residual conv stack: norm -> conv -> ReLU -> add skip, with dropout
            # after every second sub-layer.
            # NOTE(review): ``norm`` is defined outside this view. Here its input
            # has seq_len as the last axis. If it normalizes over the last
            # dimension, it normalizes over positions rather than features. If it
            # builds a fresh nn.LayerNorm per call, its weights are never trained.
            # TODO confirm both.
            for i, conv in enumerate(self.convs):
                out = norm(out)
                out = conv(out)
                out = self.relu(out)
                out = res + out
                if (i + 1) % 2 == 0:
                    out = self.dropout(out)
                res = out

            # Back to (batch, seq_len, d_model) for residual self-attention.
            res = out = norm(out).transpose(1, 2)
            out = self.self_att(out, out)
            out = res + out
            out = self.dropout(out)

            # Residual position-wise linear layer.
            res = out
            out = norm(out)
            out = self.W(out)
            out = self.relu(out)
            out = res + out
            out = self.dropout(out)
        return out

    @staticmethod
    def pos_encoding(x):
        """Return ``x`` plus a fixed sinusoidal positional encoding.

        Entry ``(pos, i)`` is ``pos / 10000 ** (2 * i / model_dim)``. It is
        passed through ``sin`` for even ``i`` and ``cos`` for odd ``i``.

        Two details deliberately differ from the textbook (Vaswani et al.)
        formulation and are kept to preserve this model's behaviour:
        - row ``pos == 0`` stays all zeros, with no ``cos`` applied;
        - the exponent uses the column index ``i`` itself rather than
          ``2 * (i // 2)``.

        The table is built vectorised (no Python-level loops). It is placed on
        ``x``'s device as float32, so an empty sequence (``seq_len == 0``) also
        works.
        """
        _, max_len, model_dim = x.shape
        positions = np.arange(max_len)[:, None]
        divisors = np.power(10000, 2 * np.arange(model_dim) / model_dim)
        # 0 / anything == 0, so row 0 comes out all zeros as before.
        encoding = positions / divisors

        encoding[1:, 0::2] = np.sin(encoding[1:, 0::2])
        encoding[1:, 1::2] = np.cos(encoding[1:, 1::2])
        return x + torch.from_numpy(encoding).float().to(x.device)


class ScaledDotProductAttention(nn.Module):
    """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(size)) @ v``.

    Args:
        size: query/key feature size used for the ``1 / sqrt(size)`` scaling.
        p: dropout probability applied to the attention weights.
    """

    def __init__(self, size, p=0.1):
        super(ScaledDotProductAttention, self).__init__()

        self.scaling = 1 / (np.sqrt(size))
        self.dropout = nn.Dropout(p)

    def forward(self, q, k, v, mask=None):
        """Attend ``q`` over ``k``/``v`` using batched matrix products.

        Args:
            q: ``(batch, q_len, size)`` queries.
            k: ``(batch, seq_len, size)`` keys.
            v: ``(batch, seq_len, v_size)`` values.
            mask: optional boolean tensor broadcastable to
                ``(batch, q_len, seq_len)``; ``True`` marks positions that get
                zero attention weight.

        Returns:
            ``(batch, q_len, v_size)`` tensor.
        """
        attention = torch.bmm(q, k.transpose(1, 2)) * self.scaling

        if mask is not None:
            # Out-of-place masked_fill keeps the op visible to autograd. Writing
            # in place through ``.data`` bypasses its version tracking.
            # NOTE: a query row whose keys are all masked softmaxes to NaN.
            attention = attention.masked_fill(mask, -float('inf'))

        attention = F.softmax(attention, dim=2)
        return torch.bmm(self.dropout(attention), v)


class MultiHeadAttention(nn.Module):
    """Multi-head attention with per-head projection matrices, an output
    projection, dropout, and a residual connection followed by LayerNorm.

    Args:
        n_heads: number of attention heads.
        h_size: model feature size of the inputs and the output.
        k_size: per-head query size.
        v_size: per-head key/value size. Keys are projected with ``k_v_proj``,
            so ``q @ k^T`` requires ``k_size == v_size``.
        p: dropout probability.

    NOTE: parameters are created on the default device; move the whole model
    with ``.to(device)``.
    """

    def __init__(self, n_heads, h_size, k_size, v_size, p=0.1):
        super(MultiHeadAttention, self).__init__()

        self.n_heads = n_heads
        self.h_size = h_size

        self.q_proj = nn.Parameter(torch.empty((n_heads, h_size, k_size)))
        # Heads [0, n_heads) project keys; heads [n_heads, 2 * n_heads) project values.
        self.k_v_proj = nn.Parameter(torch.empty((2 * n_heads, h_size, v_size)))
        # torch.empty returns uninitialised memory (possibly NaN/huge values).
        # Give each (h_size x out) head matrix a Glorot/Xavier-normal scale.
        nn.init.normal_(self.q_proj, std=(2.0 / (h_size + k_size)) ** 0.5)
        nn.init.normal_(self.k_v_proj, std=(2.0 / (h_size + v_size)) ** 0.5)

        self.attention = ScaledDotProductAttention(k_size, p)

        self.out = nn.Linear(n_heads * v_size, h_size)
        self.layer_norm = nn.LayerNorm(h_size)

        self.dropout = nn.Dropout(p)

    def repeat_n_heads(self, input):
        """Tile ``input`` once per head and flatten it for per-head ``bmm``:
        ``(batch, len, h_size) -> (n_heads, batch * len, h_size)``."""
        return input.repeat(self.n_heads, 1, 1).view(self.n_heads, -1, self.h_size)

    def forward(self, q, k, v, mask=None):
        """Multi-head attention of ``q`` over ``k``/``v``.

        Args:
            q: ``(batch, q_len, h_size)`` queries; also used as the residual.
            k: ``(batch, seq_len, h_size)`` keys.
            v: ``(batch, seq_len, h_size)`` values.
            mask: optional boolean mask shaped for one head
                (presumably ``(batch, q_len, seq_len)`` — TODO confirm with
                callers); it is tiled across heads.

        Returns:
            ``(batch, q_len, h_size)`` tensor.
        """
        batch_size = q.size(0)
        q_len = q.size(1)
        seq_len = k.size(1)

        residual = q

        q = self.repeat_n_heads(q)
        k = self.repeat_n_heads(k)
        v = self.repeat_n_heads(v)
        # Project keys and values in a single batched matmul.
        k_v = torch.cat([k, v], dim=0)

        # Reshape to (n_heads * batch, len, size); row index is head * batch + b.
        q = torch.bmm(q, self.q_proj).view(-1, q_len, self.q_proj.size(2))
        k, v = torch.split(torch.bmm(k_v, self.k_v_proj), self.n_heads, 0)
        k = k.view(-1, seq_len, self.k_v_proj.size(2))
        # Bug fix: this previously reshaped ``k``, so the values were discarded.
        v = v.view(-1, seq_len, self.k_v_proj.size(2))

        if mask is not None:
            mask = mask.repeat(self.n_heads, 1, 1)

        result = self.attention(q, k, v, mask)
        # n_heads chunks of (batch, q_len, v_size), concatenated on the feature axis.
        result = torch.cat(torch.split(result, batch_size, dim=0), dim=-1)

        result = self.out(result)
        result = self.dropout(result)
        return self.layer_norm(result + residual)


class SelfAttention(nn.Module):
    """Multi-head attention whose keys and values are the same tensor.

    Every head uses ``input_dim // n_heads`` features for its query/key and
    value projections.
    """

    def __init__(self, n_heads, input_dim):
        super().__init__()
        head_dim = input_dim // n_heads
        self.d_k = head_dim
        self.multihead = MultiHeadAttention(
            n_heads, input_dim, k_size=head_dim, v_size=head_dim
        )

    def forward(self, context, questions, mask=None):
        """Attend from ``context`` (queries) over ``questions`` (keys = values)."""
        keys = values = questions
        return self.multihead(context, keys, values, mask=mask)