I’m noticing that if I run model.to(device)
, the tensors used in the forward
are not necessarily transferred over the GPU. Below is a minimal example of this
import torch
from torch.distributions import MultivariateNormal
class SimpleModel(torch.nn.Module):
def __init__(self, q):
super(SimpleModel, self).__init__()
self._zero_mean = torch.zeros(2 * q)
self._eye_covar = torch.eye(2 * q)
self.mvn = MultivariateNormal(self._zero_mean, self._eye_covar)
def reparameterize(self, mu, logv):
eps = torch.randn_like(mu)
z = mu + eps * torch.exp(logv)
logp = self.mvn.log_prob(eps)
return z, logp
def forward(self, x):
return self.reparameterize(x, x)
model = SimpleModel(10).to('cuda')
x = torch.ones(20).to('cuda')
model(x)
This gives the following error
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-15-bd0200007a4a> in <module>
----> 1 model(x)
~/miniconda3/envs/mavi/lib/python3.8/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
548 result = self._slow_forward(*input, **kwargs)
549 else:
--> 550 result = self.forward(*input, **kwargs)
551 for hook in self._forward_hooks.values():
552 hook_result = hook(self, input, result)
<ipython-input-10-7cdded9571ea> in forward(self, x)
14 return z, logp
15 def forward(self, x):
---> 16 return self.reparameterize(x, x)
<ipython-input-10-7cdded9571ea> in reparameterize(self, mu, logv)
11 eps = torch.randn_like(mu)
12 z = mu + eps * torch.exp(logv)
---> 13 logp = self.mvn.log_prob(eps)
14 return z, logp
15 def forward(self, x):
~/miniconda3/envs/mavi/lib/python3.8/site-packages/torch/distributions/multivariate_normal.py in log_prob(self, value)
205 if self._validate_args:
206 self._validate_sample(value)
--> 207 diff = value - self.loc
208 M = _batch_mahalanobis(self._unbroadcasted_scale_tril, diff)
209 half_log_det = self._unbroadcasted_scale_tril.diagonal(dim1=-2, dim2=-1).log().sum(-1)
RuntimeError: expected device cuda:0 but got device cpu
I know that I can manually override the to()
and cuda()
methods, so there is a solution. But I’m curious why the to()
and cuda()
can’t already handle these sorts of edge cases (or if there is something that I am overlooking).