RuntimeError: Tensor Device Mismatch in Custom PyTorch Model with Dynamic Relative Position Bias

Hi…

I am facing the following error:
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

After checking everything, it turns out the problem is this tensor:
relative_position_bias.relative_attention_bias.weight: cpu
Everything else is correctly on the GPU.

I have tried to fix the problem but to no avail. Below I've marked the parts of the code where I move things to the GPU.
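
For anyone hitting the same error, this is roughly how the devices of all registered tensors can be listed (a minimal sketch; the helper below is only illustrative and not part of my actual code):

import torch.nn as nn

def report_devices(model: nn.Module) -> None:
    # Only registered parameters and buffers are moved by model.to(device).
    for name, param in model.named_parameters():
        print(f"param  {name}: {param.device}")
    for name, buf in model.named_buffers():
        print(f"buffer {name}: {buf.device}")
    # Plain tensor attributes assigned in __init__ do not appear here at all.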

RelativePositionBias:

class RelativePositionBias(nn.Module):
    """
    Translate relative position to a bucket number for relative attention.

    The relative position is defined as memory_position - query_position, i.e.
    the distance in tokens from the attending position to the attended-to
    position. If bidirectional=False, then positive relative positions are
    invalid.

    We use smaller buckets for small absolute relative_position and larger buckets
    for larger absolute relative_positions. All relative positions >=max_distance
    map to the same bucket. All relative positions <=-max_distance map to the
    same bucket. This should allow for more graceful generalization to longer
    sequences than the model has been trained on.

    Args:
        bidirectional (bool): Whether the attention is bidirectional.
        num_buckets (int): Number of buckets.
        max_distance (int): Maximum distance for relative positions.
        num_heads (int): Number of attention heads.

    # Reference: https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
    """
    def __init__(self, config):
        super(RelativePositionBias, self).__init__()
        self.bidirectional = config.bidirectional
        self.num_buckets = config.num_buckets
        self.max_distance = config.max_distance
        self.num_heads = config.num_heads
        self.relative_attention_bias = nn.Embedding(self.num_buckets, self.num_heads)

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """
        Translate relative position to a bucket number.

        Args:
            relative_position (torch.Tensor): Relative position tensor.
            bidirectional (bool): Whether the attention is bidirectional.
            num_buckets (int): Number of buckets.
            max_distance (int): Maximum distance for relative positions.

        Returns:
            torch.Tensor: Bucket number tensor.
        """
        ret = 0 * relative_position  # Zero tensor with the same shape, dtype, and device as relative_position
        if bidirectional:
            num_buckets //= 2  # Halve the buckets for bidirectional case
            ret += (relative_position < 0).long() * num_buckets
            relative_position = relative_position.abs()
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact

        # For larger distances, assign buckets logarithmically and clamp to the last bucket (num_buckets - 1)
        val_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact) /
            torch.log(torch.tensor(max_distance / max_exact, dtype=torch.float, device=relative_position.device)) *
            (num_buckets - max_exact)
        ).long()
        val_if_large = torch.minimum(val_if_large, torch.full_like(val_if_large, num_buckets - 1))

        # Combine small and large relative positions
        ret += torch.where(is_small, relative_position, val_if_large)

        return ret

    def compute_bias(self, qlen, klen):
        """
        Compute binned relative position bias.

        Args:
            qlen (int): Length of the query sequence.
            klen (int): Length of the key sequence.

        Returns:
            torch.Tensor: Relative position bias tensor.
        """
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(qlen, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(klen, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position

        rp_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=self.bidirectional,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance
        )

        values = self.relative_attention_bias(rp_bucket)
        values = values.permute([2, 0, 1]).unsqueeze(0)

        return values

 
    def forward(self, qlen, klen):
        """
        Forward pass.

        Args:
            qlen (int): Length of the query sequence.
            klen (int): Length of the key sequence.

        Returns:
            torch.Tensor: Relative position bias tensor.
        """

        return self.compute_bias(qlen, klen)
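
For reference, this is how I expect the module to behave on its own (a quick sketch with made-up config values in a SimpleNamespace instead of my real config class):

from types import SimpleNamespace

cfg = SimpleNamespace(bidirectional=True, num_buckets=32, max_distance=128, num_heads=8)  # illustrative values

bias_module = RelativePositionBias(cfg)
bias = bias_module(16, 16)

print(bias.shape)                                         # torch.Size([1, 8, 16, 16])
print(bias_module.relative_attention_bias.weight.device)  # cpu, until .to(...) is called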

T5 model:

class T5Model(nn.Module):
    """
    """
    def __init__(self, config):
        """
        """
        super(T5Model, self).__init__()

        self.num_blocks: int = config.num_blocks
        self.vocab_size: int = config.vocab_size
        self.hidden_size: int = config.hidden_size

        self.embed_layer: Embeddings = Embeddings(config)

        ########################################################
        ####### Here I create an instance of RelativePositionBias #########
        ########################################################

        self.relative_position_bias = RelativePositionBias(config)
        self.biases = self.relative_position_bias(config.max_token_len, config.max_token_len).to(self.embed_layer.token_embeddings.weight.device)

        ########################################################
        ########################################################
        ########################################################

        self.encoder: nn.ModuleList = nn.ModuleList([Encoder(config) for _ in range(self.num_blocks)])
        self.decoder: nn.ModuleList = nn.ModuleList([Decoder(config) for _ in range(self.num_blocks)])

        # The output of the final decoder block is fed into a dense layer with a softmax output, whose weights are shared with the input embedding matrix.
        self.prediction_layer: nn.Linear = nn.Linear(self.hidden_size, self.vocab_size)
        self.prediction_layer.weight = self.embed_layer.token_embeddings.weight
        self.softmax: nn.LogSoftmax = nn.LogSoftmax(dim=-1)

    def forward(self, input_ids: torch.Tensor, labels: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """
        Run the encoder and decoder stacks and return log-probabilities over the vocabulary.
        """
        print(self.embed_layer.token_embeddings.weight.device)
        print(self.biases.device)
        x_enc: torch.Tensor  = self.embed_layer(input_ids)
        x_dec: torch.Tensor  = self.embed_layer(labels)

        for encoder_layer in self.encoder:
            x_enc: torch.Tensor = encoder_layer(x_enc, self.biases, mask)

        for decoder_layer in self.decoder:
            x_dec = decoder_layer(x_dec, x_enc, self.biases, mask)

        x_logits: torch.Tensor = self.prediction_layer(x_dec)

        return self.softmax(x_logits)

    ############################################################
    # I have done this to ensure the biases tensor is correctly moved to the target device.
    ############################################################
    ############################################################
    def to(self, *args, **kwargs):
        super(T5Model, self).to(*args, **kwargs)
        
        self.biases = self.biases.to(*args, **kwargs)
        self.relative_position_bias.to(*args, **kwargs)
        
        return self
    ############################################################
    ############################################################
    ############################################################

Run function:

def run(config):
    """
    Main function to run T5Model training.

    Args:
        config (Config): Configuration object containing parameters.
    """
    # Set random seeds
    set_seeds(config)

    print("Loading Train Dataset...")

    # Load training dataset
    train_dataset = CustomTextDataset(config.data_dir, tokenizer=config.tokenizer_path, max_token_len=config.max_token_len)

    # Load test dataset if provided
    test_dataset = CustomTextDataset(config.data_dir, tokenizer=config.tokenizer_path, max_token_len=config.max_token_len) if config.test_dataset is not None else None

    # Setup cuda device for T5 training
    cuda_condition: bool = torch.cuda.is_available() and config.with_cuda
    device: torch.device = torch.device("cuda:0" if cuda_condition else "cpu")

    ############ Here I create the model and move it to GPU #########
    ########################################################
    ########################################################
    # Initialize T5Model
    t5 = T5Model(config).to(device)
    ########################################################
    ########################################################
    ########################################################

    # Distributed GPU training if CUDA can detect more than 1 GPU
    if config.with_cuda and torch.cuda.device_count() > 1:
        print("Using %d GPUs for T5Model" % torch.cuda.device_count())
        t5: nn.DataParallel = nn.DataParallel(t5, device_ids=config.cuda_devices)

    # Initialize optimizer and scheduler
    optim = Adam(t5.parameters(), lr=config.lr, betas=config.betas, weight_decay=config.weight_decay)
    optim_schedule = ScheduledOptim(config, optim)

    # Create data loaders
    batch_size = config.batch_size
    train_data_loader = DataLoader(train_dataset, batch_size=batch_size, worker_init_fn=np.random.seed(config.seed), shuffle=True)
    test_data_loader = DataLoader(test_dataset, batch_size=batch_size, worker_init_fn=np.random.seed(config.seed)) if test_dataset is not None else None

    # Initialize t5 trainer
    trainer = T5Trainer(config, t5, optim_schedule, device, train_data_loader, test_data_loader)

    # Training loop
    for epoch in range(config.epochs):
        # Train the model
        trainer.train(epoch)

        # Save the model
        trainer.save(epoch)

        # Test the model if test data is available
        if test_data_loader is not None:
            trainer.test(epoch)

self.relative_attention_bias = nn.Embedding(self.num_buckets, self.num_heads) seems to be properly registered, but I don’t fully understand the forward pass in this __init__ method:

self.biases = self.relative_position_bias(config.max_token_len, config.max_token_len).to(self.embed_layer.token_embeddings.weight.device)

since the device attribute won’t be set at this point and the layer would still be on the CPU.
Also, self.biases is not a leaf variable and will thus also stay on the CPU, which is most likely why you are trying to override the to() method.
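
To illustrate the difference with a small standalone example (not your model): Module.to() moves registered parameters and buffers, but plain tensor attributes are left where they were created.

import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)                  # parameter -> moved by .to()
        self.register_buffer("table", torch.zeros(4))  # buffer    -> moved by .to()
        self.plain = torch.zeros(4)                    # plain attribute -> NOT moved by .to()

if torch.cuda.is_available():
    m = Demo().to("cuda:0")
    print(m.linear.weight.device)  # cuda:0
    print(m.table.device)          # cuda:0
    print(m.plain.device)          # cpu, it stays behind just like self.biases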

Thank you…

You’re right. The .to(self.embed_layer.token_embeddings.weight.device) call on this line, self.biases = self.relative_position_bias(config.max_token_len, config.max_token_len).to(self.embed_layer.token_embeddings.weight.device), is just a typo.

As for the biases tensor (the relative position bias): it is shared across all layers, which is what the T5 paper describes. It's true that building it inside the model's forward function would solve the problem, but that would mean recomputing the same tensor on every forward pass, which isn't necessary since it is identical every time. Specifically, I'm using it here:

att_scores: torch.Tensor = (torch.matmul(query, key.transpose(1, 2)) + relative_biases) / self.head_dim ** 0.5

So can I still compute it just once, inside __init__, or is there another solution? I could work around the problem by passing the device to __init__, but I wonder whether there is a more elegant way.
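
To make the alternatives concrete, here is a rough, hypothetical sketch of the "compute it in forward" idea with a simple cache so the tensor is only built once (assuming the bias is static and not meant to be trained):

import torch
import torch.nn as nn

class T5ModelLazyBias(nn.Module):
    """Hypothetical variant: build the relative position bias lazily and cache it."""
    def __init__(self, config):
        super().__init__()
        self.relative_position_bias = RelativePositionBias(config)
        self.max_token_len = config.max_token_len
        self._cached_biases = None  # filled on the first forward call

    def _get_biases(self) -> torch.Tensor:
        if self._cached_biases is None:
            # By the time forward runs, .to(device) has already moved the embedding,
            # so the cached tensor is created on the correct device.
            self._cached_biases = self.relative_position_bias(
                self.max_token_len, self.max_token_len
            ).detach()
        return self._cached_biases
    # forward(...) would then use self._get_biases() instead of self.biases.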

I don’t fully understand where the typo is in this line of code.

You are currently creating this tensor by calling the forward method of self.relative_position_bias. If this tensor is static, detach it and create a new leaf-variable instead.
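
Something along these lines, for example (a minimal sketch, assuming the bias really is static): detach the result and register it as a buffer, then the usual model.to(device) call moves it and the to() override is no longer needed.

import torch.nn as nn

class T5ModelWithBufferBias(nn.Module):
    """Hypothetical sketch, trimmed to the relevant lines of __init__."""
    def __init__(self, config):
        super().__init__()
        self.relative_position_bias = RelativePositionBias(config)
        bias = self.relative_position_bias(config.max_token_len, config.max_token_len)
        # detach() gives a leaf tensor with no autograd history from __init__;
        # register_buffer lets model.to(device) and the state_dict handle it automatically.
        self.register_buffer("biases", bias.detach())

forward can then keep using self.biases exactly as before.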

Thank you, I got it. I was just a little confused.