MultiHeadAttention throwing unknown CUDA errors

This works fine on CPU, and yes, I have read the related Stack Overflow and PyTorch Discuss posts about common CUDA errors. No, my input does not have more classes than expected. No, my tensor shapes are not mismatched. I am running out of things to try.

ERROR:

<ipython-input-12-0a931e7c38e6> in forward(self, query, key, value, attention_mask)
     74 #         print(self.weights_query)
     75         # print(self.weights_query(query))
---> 76         query_score = self.weights_query(query).view(batch_size, -1, self.number_of_heads, self.dimension_query).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
     77         key_score = self.weights_key(key).view(batch_size, -1, self.number_of_heads, self.dimension_key).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
     78         value_score = self.weights_value(value).view(batch_size, -1, self.number_of_heads, self.dimension_value).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

/usr/local/lib/python3.7/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    487             result = self._slow_forward(*input, **kwargs)
    488         else:
--> 489             result = self.forward(*input, **kwargs)
    490         for hook in self._forward_hooks.values():
    491             hook_result = hook(self, input, result)

/usr/local/lib/python3.7/site-packages/torch/nn/modules/linear.py in forward(self, input)
     65     @weak_script_method
     66     def forward(self, input):
---> 67         return F.linear(input, self.weight, self.bias)
     68 
     69     def extra_repr(self):

/usr/local/lib/python3.7/site-packages/torch/nn/functional.py in linear(input, weight, bias)
   1352         ret = torch.addmm(torch.jit._unwrap_optional(bias), input, weight.t())
   1353     else:
-> 1354         output = input.matmul(weight.t())
   1355         if bias is not None:
   1356             output += torch.jit._unwrap_optional(bias)

RuntimeError: CUDA error: device-side assert triggered

Relevant Code:

def gaussian_error_linear_unit_activation(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class TokenPositionSegmentEmbedding(nn.Module):
    def __init__(self,
        vocabulary_size:int,
        embedding_size:int,
        maximum_length:int,
    ):
        super(TokenPositionSegmentEmbedding, self).__init__()
        # embedding for the tokens
        self.token_embedding    = nn.Embedding(vocabulary_size, embedding_size)
        # embedding for corresponding position
        self.position_embedding = nn.Embedding(maximum_length, embedding_size)
        self.norm = nn.LayerNorm(embedding_size)

    def forward(self, tokens):
        sequence_length = tokens.size(1)
        position = torch.arange(sequence_length, dtype=torch.long).to(tokens.device)
        # (sequence_length,) -> (batch_size, sequence_length)
        position = position.unsqueeze(0).expand_as(tokens)
        embedding = (
            self.token_embedding(tokens) + \
            self.position_embedding(position)
        )
        return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dimension_key:int):
        super(ScaledDotProductAttention, self).__init__()
        # dimension of key is the same as query
        self.dimension_key = dimension_key

    def forward(self, query, key, value, attention_mask):
        # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores = torch.matmul(query, key.transpose(-1, -2)) / np.sqrt(self.dimension_key)

        # Fills elements of self tensor with value where mask is one.
        scores.masked_fill_(attention_mask, -1e9)

        attention = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attention, value)
        return context, attention

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        embedding_size:int,
        dimension_query:int,
        dimension_key:int,
        dimension_value:int,
        number_of_heads:int,
        batch_size: int,
    ):
        assert dimension_query == dimension_key, 'query and key do not share the same dimension!'
        super(MultiHeadAttention, self).__init__()
        self.embedding_size = embedding_size
        self.dimension_query = dimension_query
        self.dimension_key = dimension_key
        self.dimension_value = dimension_value
        self.number_of_heads = number_of_heads
        self.batch_size = batch_size
        self.weights_query = nn.Linear(embedding_size, number_of_heads * dimension_query)
        self.weights_key   = nn.Linear(embedding_size, number_of_heads * dimension_key)
        self.weights_value = nn.Linear(embedding_size, number_of_heads * dimension_value)
        self.attention = ScaledDotProductAttention(dimension_key=dimension_key)

    def forward(self, query, key, value, attention_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = query, query.size(0)
        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)
#         print(query.shape, query)
#         print(self.embedding_size, self.number_of_heads * self.dimension_query)
#         print(self.weights_query)
        # print(self.weights_query(query))
        query_score = self.weights_query(query).view(batch_size, -1, self.number_of_heads, self.dimension_query).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        key_score = self.weights_key(key).view(batch_size, -1, self.number_of_heads, self.dimension_key).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        value_score = self.weights_value(value).view(batch_size, -1, self.number_of_heads, self.dimension_value).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attention_mask = attention_mask.unsqueeze(1).repeat(1, self.number_of_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attention = self.attention(query_score, key_score, value_score, attention_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.number_of_heads * self.dimension_value) # context: [batch_size x len_q x n_heads * d_v]
        output = nn.Linear(self.number_of_heads * self.dimension_value, self.embedding_size)(context)
        return nn.LayerNorm(self.embedding_size)(output + residual), attention # output: [batch_size x len_q x d_model]

Rerun the code with CUDA_LAUNCH_BLOCKING=1 python script.py args and check the stacktrace to isolate the failing operation.
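If you are in a notebook rather than a script (your trace suggests so), the equivalent is to set the variable before CUDA is initialized, i.e. in the very first cell, before torch is imported or any tensor is moved to the GPU:

import os
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # must be set before CUDA is initialized

import torch  # import torch only after the variable is set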

It is the same error. That isn't the issue.

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
# definition is below
bert = BERT(
    vocabulary_size                = 10,
    embedding_size                 = 64,
    number_of_classes              = 10,
    nontext_input_dimensions       = 10,
    dimension_query                = 64,
    dimension_key                  = 64,
    dimension_value                = 64,
    number_of_heads                = 12,
    batch_size                     = 1,
    number_of_layers               = 6,
    learning_rate                  = 0.001,
    maximum_length                 = 0,
    calculate_maximum_length       = True,
    document_corpus                = {
        'a': np.zeros(10),
        'b': np.zeros(23),
        'c': np.zeros(15),
    }
)
ids = torch.LongTensor(np.zeros((batch_size, 64))).to(device)
pos = torch.LongTensor(np.zeros((batch_size, 1))).to(device)
bert(ids, pos)

Raises:

<ipython-input-4-1ef0f6bf9ef2> in forward(self, tokens)
     25         embedding = (
     26             self.token_embedding(tokens) + \
---> 27             self.position_embedding(position)
     28         ).to(tokens.device)
     29         return self.norm(embedding)

RuntimeError: CUDA error: device-side assert triggered

please help

Big chunky code:

def gaussian_error_linear_unit_activation(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

class TokenPositionSegmentEmbedding(nn.Module):
    def __init__(self,
        vocabulary_size:int,
        embedding_size:int,
        maximum_length:int,
    ):
        super(TokenPositionSegmentEmbedding, self).__init__()
        # embedding for the tokens
        self.token_embedding    = nn.Embedding(vocabulary_size, embedding_size).to(device)
        # embedding for corresponding position
        self.position_embedding = nn.Embedding(maximum_length, embedding_size).to(device)
        self.norm = nn.LayerNorm(embedding_size).to(device)

    def forward(self, tokens):
        sequence_length = tokens.size(1)
        
        position = torch.arange(sequence_length, dtype=torch.long).to(tokens.device)
        # (sequence_length,) -> (batch_size, sequence_length)
        
        position = position.unsqueeze(0).expand_as(tokens).to(tokens.device)
        print(tokens.device, position.device)
        embedding = (
            self.token_embedding(tokens) + \
            self.position_embedding(position)
        ).to(tokens.device)
        return self.norm(embedding)

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dimension_key:int):
        super(ScaledDotProductAttention, self).__init__()
        # dimension of key is the same as query
        self.dimension_key = dimension_key

    def forward(self, query, key, value, attention_mask):
        # scores : [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        scores = torch.matmul(query, key.transpose(-1, -2)) / np.sqrt(self.dimension_key)

        # Fills elements of self tensor with value where mask is one.
        scores.masked_fill_(attention_mask.byte(), -1e9)

        attention = nn.Softmax(dim=-1)(scores)
        context = torch.matmul(attention, value)
        return context, attention

class MultiHeadAttention(nn.Module):
    def __init__(
        self,
        embedding_size:int,
        dimension_query:int,
        dimension_key:int,
        dimension_value:int,
        number_of_heads:int,
        batch_size: int,
    ):
        assert dimension_query == dimension_key, 'query and key do not share the same dimension!'
        super(MultiHeadAttention, self).__init__()
        self.embedding_size = embedding_size
        self.dimension_query = dimension_query
        self.dimension_key = dimension_key
        self.dimension_value = dimension_value
        self.number_of_heads = number_of_heads
        self.batch_size = batch_size
        self.weights_query = nn.Linear(embedding_size, number_of_heads * dimension_query).to(device)
        self.weights_key   = nn.Linear(embedding_size, number_of_heads * dimension_key).to(device)
        self.weights_value = nn.Linear(embedding_size, number_of_heads * dimension_value).to(device)
        self.attention = ScaledDotProductAttention(dimension_key=dimension_key).to(device)
        
        self.linear = nn.Linear(number_of_heads * dimension_value, embedding_size).to(device)
        self.norm = nn.LayerNorm(embedding_size).to(device)

    def forward(self, query, key, value, attention_mask):
        # q: [batch_size x len_q x d_model], k: [batch_size x len_k x d_model], v: [batch_size x len_k x d_model]
        residual, batch_size = query, query.size(0)
        query_score = self.weights_query(query).view(batch_size, -1, self.number_of_heads, self.dimension_query).transpose(1,2)  # q_s: [batch_size x n_heads x len_q x d_k]
        key_score = self.weights_key(key).view(batch_size, -1, self.number_of_heads, self.dimension_key).transpose(1,2)  # k_s: [batch_size x n_heads x len_k x d_k]
        value_score = self.weights_value(value).view(batch_size, -1, self.number_of_heads, self.dimension_value).transpose(1,2)  # v_s: [batch_size x n_heads x len_k x d_v]

        attention_mask = attention_mask.unsqueeze(1).repeat(1, self.number_of_heads, 1, 1) # attn_mask : [batch_size x n_heads x len_q x len_k]

        # context: [batch_size x n_heads x len_q x d_v], attn: [batch_size x n_heads x len_q(=len_k) x len_k(=len_q)]
        context, attention = self.attention(query_score, key_score, value_score, attention_mask)
        context = context.transpose(1, 2).contiguous().view(batch_size, -1, self.number_of_heads * self.dimension_value) # context: [batch_size x len_q x n_heads * d_v]
        output = self.linear(context)
        return self.norm(output + residual), attention # output: [batch_size x len_q x d_model]

class GaussianErrorLinearUnit(nn.Module):
    def __init__(self, embedding_size:int, output_size:int):
        super(GaussianErrorLinearUnit, self).__init__()
        self.linear = nn.Linear(embedding_size, output_size).to(device)

    def activation(self, x):
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

    def forward(self, x):
        return self.activation(self.linear(x))

class PositionWiseFeedForwardNetwork(nn.Module):
    def __init__(self, embedding_size:int, squeeze_size:int):
        super(PositionWiseFeedForwardNetwork, self).__init__()
        self.gaussian_error_linear_unit = GaussianErrorLinearUnit(embedding_size, squeeze_size).to(device)
        self.fully_connected_unit = nn.Linear(squeeze_size, embedding_size).to(device)

    def forward(self, x):
        # (batch_size, len_seq, d_model) -> (batch_size, len_seq, d_ff) -> (batch_size, len_seq, d_model)
        return self.fully_connected_unit(self.gaussian_error_linear_unit(x))

class EncoderLayer(nn.Module):
    def __init__(
        self,
        embedding_size:int,
        dimension_query:int,
        dimension_key:int,
        dimension_value:int,
        number_of_heads:int,
        batch_size:int,
    ):
        super(EncoderLayer, self).__init__()
        self.encoder_self_attention = MultiHeadAttention(
            embedding_size,
            dimension_query,
            dimension_key,
            dimension_value,
            number_of_heads,
            batch_size
        )
        self.position_wise_feed_forward_network = PositionWiseFeedForwardNetwork(
            embedding_size,
            embedding_size * 4  #4 * embedding_size, Feed Forward dimension
        )

    def forward(self, encoder_inputs, encoder_self_attention_mask):
        encoder_outputs, attention = self.encoder_self_attention(
            encoder_inputs.float(),
            encoder_inputs.float(),
            encoder_inputs.float(),
            encoder_self_attention_mask.float()
        ) # encoder_inputs to same query, key, value

        # enc_outputs: [batch_size x len_q x d_model]
        encoder_outputs = self.position_wise_feed_forward_network(encoder_outputs)
        return encoder_outputs, attention
    
    
class BERT(nn.Module):
    def __init__(
        self,
        vocabulary_size:int,
        embedding_size:int,
        number_of_classes:int,
        nontext_input_dimensions:int,
        dimension_query:int,
        dimension_key:int,
        dimension_value:int,
        number_of_heads:int,
        batch_size:int,
        number_of_layers:int,
        learning_rate:float,
        maximum_length:int = 0,
        calculate_maximum_length: bool = False,
        document_corpus:dict = None
    ):
        super(BERT, self).__init__()
        # store model parameters
        self.vocabulary_size = vocabulary_size
        self.embedding_size = embedding_size
        
        
        if calculate_maximum_length:
            if document_corpus is not None:
                max_len = max(list(map(len, list(document_corpus.values()))))
                max_len += 2
            else:
                # NOTE: we add 2 here for ["CLS"] + tokens + ["SEP"]
                max_len = maximum_length + 2
            self.maximum_length = max_len
        else:
            self.maximum_length = maximum_length
        
        self.number_of_classes = number_of_classes
        self.nontext_input_dimensions = nontext_input_dimensions
        self.dimension_query = dimension_query
        self.dimension_key = dimension_key
        self.dimension_value = dimension_value
        self.number_of_heads = number_of_heads
        self.batch_size = batch_size
        self.number_of_layers = number_of_layers
        self.learning_rate = learning_rate

        # define model
        self.embedding = TokenPositionSegmentEmbedding(
            vocabulary_size,
            embedding_size,
            maximum_length            
        )
        self.encoder_layers = nn.ModuleList([
            EncoderLayer(
                embedding_size,
                dimension_query,
                dimension_key,
                dimension_value,
                number_of_heads,
                batch_size
            )
            for _ in range(number_of_layers)
        ])
        self.fully_connected_unit = nn.Linear(embedding_size, embedding_size)
        self.tanh = nn.Tanh()
        self.gaussian_error_linear_unit = GaussianErrorLinearUnit(embedding_size, embedding_size)
        self.norm = nn.LayerNorm(embedding_size)
        
        self.nontext_fully_connected_unit = nn.Linear(nontext_input_dimensions, embedding_size)
        self.nontext_fully_connected_unit_norm = nn.LayerNorm(embedding_size)        
        self.classifier = nn.Linear(embedding_size, number_of_classes)

        # decoder is shared with embedding layer
        embed_weight = self.embedding.token_embedding.weight
        n_vocab, n_dim = embed_weight.size()
        self.decoder = nn.Linear(n_dim, n_vocab, bias=False)
        self.decoder.weight = embed_weight
        self.decoder_bias = nn.Parameter(torch.zeros(n_vocab))

    def get_padded_attention_mask(self, sequence_query, sequence_key):
        batch_size, length_query = sequence_query.size()
        batch_size, length_key = sequence_key.size()
        # eq(zero) is PAD token
        padded_attention_mask = sequence_key.data.eq(0).unsqueeze(1)  # batch_size x 1 x len_k(=len_q), one is masking
        return padded_attention_mask.expand(batch_size, length_query, length_key)  # batch_size x len_q x len_k

    def forward(self, input_ids, masked_positions, nontext=None):
        output = self.embedding(input_ids)
        encoder_self_attention_mask = self.get_padded_attention_mask(input_ids, input_ids)
        print(encoder_self_attention_mask)
        for encoder_layer in self.encoder_layers:
            output, encoder_self_attention = encoder_layer(output, encoder_self_attention_mask)
        # output : [batch_size, len, d_model], attn : [batch_size, n_heads, d_mode, d_model]
        # it will be decided by first token(CLS)
        h_pooled = self.tanh(self.fully_connected_unit(output[:, 0])) # [batch_size, d_model]
        
        if nontext is None:
            blank = [0 for i in range(self.nontext_input_dimensions)]            
            nontext = torch.LongTensor([blank for i in range(input_ids.shape[0])])
            nontext = nontext.to(input_ids.device)
            
        if nontext.type() == 'torch.LongTensor':
            nontext = nontext.float()
        
        nontext_output = self.nontext_fully_connected_unit(nontext)
        nontext_output = self.nontext_fully_connected_unit_norm(nontext_output)
        nt_pooled = self.tanh(nontext_output)
        logits_clsf = self.classifier(h_pooled + nt_pooled) # [batch_size, 2]

        masked_positions = masked_positions[:, :, None].expand(-1, -1, output.size(-1)) # [batch_size, max_pred, d_model]
        # get masked position from final output of transformer.
        h_masked = torch.gather(output, 1, masked_positions) # masking position [batch_size, max_pred, d_model]
        h_masked = self.norm(self.gaussian_error_linear_unit(h_masked))
        logits_lm = self.decoder(h_masked) + self.decoder_bias # [batch_size, max_pred, n_vocab]

        return logits_lm, logits_clsf

Using CUDA_LAUNCH_BLOCKING=1 isn't a fix, but it should point to the failing operation in the stack trace. If you set it with os.environ inside the notebook, that cell has to run before anything touches the GPU; otherwise the variable has no effect and the launches stay asynchronous.
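If the blocking launch still does not give you a usable trace, another option is to run one forward pass on CPU: the same out-of-range access then surfaces as an ordinary Python exception with a full traceback. A rough sketch, assuming the bert, ids and pos objects from your snippet:

# run once on CPU so the failing op raises a normal Python error
# instead of a deferred device-side assert
bert_cpu = bert.cpu()
try:
    bert_cpu(ids.cpu(), pos.cpu())
except (IndexError, RuntimeError) as e:
    print(e)  # an out-of-range embedding lookup reports "index out of range in self"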

Based on your latest update, check the embedding usage, i.e. its setup and its inputs. A device-side assert raised from an nn.Embedding lookup almost always means an index >= num_embeddings. Your new trace fails inside self.position_embedding(position), and in the posted BERT.__init__ the TokenPositionSegmentEmbedding is constructed from the raw maximum_length argument (0 in your call) rather than the computed self.maximum_length; even the computed value (23 + 2 = 25) is smaller than the sequence length of 64 in ids, so the position lookup goes out of range.
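As a concrete sanity check (a sketch using the bert and ids objects from your snippet; adjust the names to whatever your notebook actually has), compare each embedding table's num_embeddings against the indices you feed it:

emb = bert.embedding
# every token id must be smaller than the token table size,
# and the sequence length must not exceed the position table size
print('token rows:   ', emb.token_embedding.num_embeddings,    '| max token id:', ids.max().item())
print('position rows:', emb.position_embedding.num_embeddings, '| sequence len:', ids.size(1))
assert ids.max().item() < emb.token_embedding.num_embeddings
assert ids.size(1) <= emb.position_embedding.num_embeddings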