Not all model parameters are transferred to GPU

I’m noticing that if I run model.to(device), the tensors used in the forward pass are not necessarily transferred to the GPU. Below is a minimal example:

import torch
from torch.distributions import MultivariateNormal

class SimpleModel(torch.nn.Module):
    def __init__(self, q):
        super(SimpleModel, self).__init__()
        self._zero_mean = torch.zeros(2 * q)
        self._eye_covar = torch.eye(2 * q)
        self.mvn = MultivariateNormal(self._zero_mean, self._eye_covar)
    def reparameterize(self, mu, logv):
        eps = torch.randn_like(mu)  
        z = mu + eps * torch.exp(logv)  
        logp = self.mvn.log_prob(eps) 
        return z, logp        
    def forward(self, x):
        return self.reparameterize(x, x)

model = SimpleModel(10).to('cuda')
x = torch.ones(20).to('cuda')
model(x)

This gives the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-15-bd0200007a4a> in <module>
----> 1 model(x)

~/miniconda3/envs/mavi/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    548             result = self._slow_forward(*input, **kwargs)
    549         else:
--> 550             result = self.forward(*input, **kwargs)
    551         for hook in self._forward_hooks.values():
    552             hook_result = hook(self, input, result)

<ipython-input-10-7cdded9571ea> in forward(self, x)
     14         return z, logp
     15     def forward(self, x):
---> 16         return self.reparameterize(x, x)

<ipython-input-10-7cdded9571ea> in reparameterize(self, mu, logv)
     11         eps = torch.randn_like(mu)
     12         z = mu + eps * torch.exp(logv)
---> 13         logp = self.mvn.log_prob(eps)
     14         return z, logp
     15     def forward(self, x):

~/miniconda3/envs/mavi/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py in log_prob(self, value)
    205         if self._validate_args:
    206             self._validate_sample(value)
--> 207         diff = value - self.loc
    208         M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
    209         half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)

RuntimeError: expected device cuda:0 but got device cpu

I know that I can manually override the to() and cuda() methods, so there is a workaround. But I’m curious why to() and cuda() can’t already handle these sorts of edge cases (or whether there is something I am overlooking).
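
For reference, the kind of manual override I have in mind is roughly this (a sketch only; it goes inside SimpleModel, .cuda() would need the same treatment, and the distribution has to be rebuilt because it keeps its own copies of the tensors):

    def to(self, *args, **kwargs):
        # Let nn.Module move the registered parameters/buffers first
        super().to(*args, **kwargs)
        # Then move the plain tensor attributes and rebuild the distribution
        self._zero_mean = self._zero_mean.to(*args, **kwargs)
        self._eye_covar = self._eye_covar.to(*args, **kwargs)
        self.mvn = MultivariateNormal(self._zero_mean, self._eye_covar)
        return self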

Hi

Yes, I don’t think the distribution objects are moved around by nn.Modules; only registered parameters and buffers follow .to().
But if you want _zero_mean and _eye_covar to be moved around, you need to register them as buffers in the nn.Module (via register_buffer()).
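
A minimal sketch of that suggestion applied to the example above. Note that the MultivariateNormal still caches the tensors it was built with, so here it is rebuilt from the buffers inside forward (building it lazily once after .to() would also work):

import torch
from torch.distributions import MultivariateNormal

class SimpleModel(torch.nn.Module):
    def __init__(self, q):
        super().__init__()
        # Buffers are moved by model.to(device), unlike plain tensor attributes
        self.register_buffer("_zero_mean", torch.zeros(2 * q))
        self.register_buffer("_eye_covar", torch.eye(2 * q))

    def reparameterize(self, mu, logv):
        eps = torch.randn_like(mu)
        z = mu + eps * torch.exp(logv)
        # Rebuild the distribution from the (now device-correct) buffers,
        # since the distribution object itself is not moved by .to()
        mvn = MultivariateNormal(self._zero_mean, self._eye_covar)
        logp = mvn.log_prob(eps)
        return z, logp

    def forward(self, x):
        return self.reparameterize(x, x)

model = SimpleModel(10).to('cuda')
x = torch.ones(20).to('cuda')
model(x)  # runs without the device mismatch now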
