RuntimeError: Expected all tensors to be on the same device

Hello, I know it’s because the tensors are on different devices, but I don’t know why. Here is my code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist


# Xavier initializer
def xavier(shape):
    return torch.normal(mean=0.0, std=(2.0 / sum(shape)) ** 0.5, size=shape)

class BayesianDenseLayer(nn.Module):
    """A fully-connected Bayesian neural network layer

    Parameters
    ----------
    d_in : int
        Dimensionality of the input (# input features)
    d_out : int
        Output dimensionality (# units in the layer)
    name : str
        Name for the layer

    Attributes
    ----------
    losses : torch.Tensor
        Sum of the Kullback–Leibler divergences between
        the posterior distributions and their priors

    Methods
    -------
    forward : torch.Tensor
        Perform the forward pass of the data through
        the layer
    """

    def __init__(self, d_in, d_out, name):

        super(BayesianDenseLayer, self).__init__()
        self.d_in = d_in
        self.d_out = d_out

        self.w_loc = nn.Parameter(xavier([d_in, d_out]))
        self.w_std = nn.Parameter(xavier([d_in, d_out]) - 6.0)
        self.b_loc = nn.Parameter(xavier([1, d_out]))
        self.b_std = nn.Parameter(xavier([1, d_out]) - 6.0)

    def forward(self, x, sampling=True):
        """Perform the forward pass"""

        if sampling:
            # Flipout-estimated weight samples
            s = torch.bernoulli(torch.full(x.shape, 0.5)) * 2 - 1
            r = torch.bernoulli(torch.full((x.shape[0], self.d_out), 0.5)) * 2 - 1
            w_samples = F.softplus(self.w_std) * torch.randn(self.d_in, self.d_out)
            w_perturbations = r * torch.matmul(x * s, w_samples)
            w_outputs = torch.matmul(x, self.w_loc) + w_perturbations

            # Flipout-estimated bias samples
            r = torch.bernoulli(torch.full((x.shape[0], self.d_out), 0.5)) * 2 - 1
            b_samples = F.softplus(self.b_std) * torch.randn(self.d_out)
            b_outputs = self.b_loc + r * b_samples

            return w_outputs + b_outputs

        else:
            return torch.matmul(x, self.w_loc) + self.b_loc

    @property
    def losses(self):
        """Sum of the KL divergences between priors + posteriors"""
        weight = dist.Normal(self.w_loc, F.softplus(self.w_std))
        bias = dist.Normal(self.b_loc, F.softplus(self.b_std))
        prior = dist.Normal(0, 1)
        return (dist.kl_divergence(weight, prior).sum() +
                dist.kl_divergence(bias, prior).sum())



class BayesianDenseNetwork(nn.Module):
    """A multilayer fully-connected Bayesian neural network

    Parameters
    ----------
    dims : List[int]
        List of units in each layer
    name : str
        Name for the network

    Attributes
    ----------
    losses : torch.Tensor
        Sum of the Kullback–Leibler divergences between
        the posterior distributions and their priors,
        over all layers in the network

    Methods
    -------
    forward : torch.Tensor
        Perform the forward pass of the data through
        the network
    """

    def __init__(self, dims, name):

        super(BayesianDenseNetwork, self).__init__()
        self.steps = nn.ModuleList()
        self.acts = []

        for i in range(len(dims) - 1):
            layer_name = name + '_Layer_' + str(i)
            self.steps.append(BayesianDenseLayer(dims[i], dims[i + 1], name=layer_name))
            self.acts.append(F.relu)

        self.acts[-1] = lambda x: x

    def forward(self, x, sampling=True):
        """Perform the forward pass"""

        for i in range(len(self.steps)):
            x = self.steps[i](x, sampling=sampling)
            x = self.acts[i](x)

        return x

    @property
    def losses(self):
        """Sum of the KL divergences between priors + posteriors"""
        return torch.sum(torch.stack([s.losses for s in self.steps]))

dims = [10, 20, 30, 40]
network = BayesianDenseNetwork(dims=dims, name="network")
x = torch.randn(5, 10)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
network.to(device)
x = x.to(device)
output = network(x)
print(f'output: {output.shape}')
print(network.losses)

The problem comes from this for loop:

for i in range(len(self.steps)):
    x = self.steps[i](x, sampling=sampling)
    x = self.acts[i](x)

return x

But I can’t figure out why it causes tensors to end up on different devices. Could you please help me figure it out?

The error output is:

File "/var/lib/condor/execute/slot1/dir_1700200/Process/bnn.py", line 217, in <module>
    output = network(x)
  File "/var/lib/condor/execute/slot1/dir_1700200/miniconda3/envs/studio/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/var/lib/condor/execute/slot1/dir_1700200/Process/bnn.py", line 120, in forward
    x = self.steps[i](x, sampling=sampling)
  File "/var/lib/condor/execute/slot1/dir_1700200/miniconda3/envs/studio/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/var/lib/condor/execute/slot1/dir_1700200/Process/bnn.py", line 54, in forward
    w_samples = F.softplus(self.w_std) * torch.randn(self.d_in, self.d_out)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!

You are creating new tensors in forward (via torch.full and torch.randn), and these are placed on the CPU by default.
Use the .device attribute of any parameter and move the newly created tensors to that device:

    def forward(self, x, sampling=True):
        """Perform the forward pass"""

        if sampling:
            device = next(self.parameters()).device
            # Flipout-estimated weight samples
            s = torch.bernoulli(torch.full(x.shape, 0.5, device=device)) * 2 - 1
            r = torch.bernoulli(torch.full((x.shape[0], self.d_out), 0.5, device=device)) * 2 - 1
            w_samples = F.softplus(self.w_std) * torch.randn(self.d_in, self.d_out, device=device)
            w_perturbations = r * torch.matmul(x * s, w_samples)
            w_outputs = torch.matmul(x, self.w_loc) + w_perturbations

            # Flipout-estimated bias samples
            r = torch.bernoulli(torch.full((x.shape[0], self.d_out), 0.5, device=device)) * 2 - 1
            b_samples = F.softplus(self.b_std) * torch.randn(self.d_out, device=device)
            b_outputs = self.b_loc + r * b_samples

            return w_outputs + b_outputs

        else:
            return torch.matmul(x, self.w_loc) + self.b_loc
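
With that change the rest of your script should run as-is. A quick sanity check, just a sketch reusing your original driver code (it falls back to the CPU when no GPU is available):

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    network = BayesianDenseNetwork(dims=[10, 20, 30, 40], name="network").to(device)
    x = torch.randn(5, 10, device=device)  # input created directly on the target device
    output = network(x)                    # forward now builds its helper tensors on that same device
    print(output.shape, output.device)     # (5, 40), device matches the parameters
    print(network.losses)

An equivalent option is to read the device off the input instead (device = x.device at the top of forward), since x has already been moved to the right device before it is passed in.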