RuntimeError: log_vml_cpu not implemented for 'Long'

Hello,
I am working on the dSprites dataset and have built a causal variational autoencoder. I am trying to answer counterfactual queries such as “given this image of a heart with this orientation, position, and scale, what would it have looked like if it were a square?”

While building the structural causal model (SCM) and conditioning on it, I get the runtime error log_vml_cpu not implemented for 'Long' when running inference. It looks like a GPU-to-CPU issue in Pyro or PyTorch, but I am not sure. Here is the error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-35-a6f1c970e088> in <module>()
     13 #posterior = MCMC(kernel, num_samples=1000, warmup_steps=50)
     14 
---> 15 posterior = pyro.infer.Importance(conditioned_model, num_samples = 1).run(vae, mu, sigma)
     16 #posterior.run(vae, mu, sigma)
     17 

/usr/local/lib/python3.6/dist-packages/pyro/infer/abstract_infer.py in run(self, *args, **kwargs)
    222         self._reset()
    223         with poutine.block():
--> 224             for i, vals in enumerate(self._traces(*args, **kwargs)):
    225                 if len(vals) == 2:
    226                     chain_id = 0

/usr/local/lib/python3.6/dist-packages/pyro/infer/importance.py in _traces(self, *args, **kwargs)
     48             model_trace = poutine.trace(
     49                 poutine.replay(self.model, trace=guide_trace)).get_trace(*args, **kwargs)
---> 50             log_weight = model_trace.log_prob_sum() - guide_trace.log_prob_sum()
     51             yield (model_trace, log_weight)
     52 

/usr/local/lib/python3.6/dist-packages/pyro/poutine/trace_struct.py in log_prob_sum(self, site_filter)
    189                 else:
    190                     try:
--> 191                         log_p = site["fn"].log_prob(site["value"], *site["args"], **site["kwargs"])
    192                     except ValueError:
    193                         _, exc_value, traceback = sys.exc_info()

/usr/local/lib/python3.6/dist-packages/pyro/distributions/delta.py in log_prob(self, x)
     58     def log_prob(self, x):
     59         v = self.v.expand(self.shape())
---> 60         log_prob = (x == v).type(x.dtype).log()
     61         log_prob = sum_rightmost(log_prob, self.event_dim)
     62         return log_prob + self.log_density

RuntimeError: log_vml_cpu not implemented for 'Long' 

Here is the code where I intervene on, condition, and query my SCM:

import torch
import pyro
from pyro.infer.importance import Importance
from pyro.infer.mcmc import MCMC, HMC


intervened_model = pyro.do(SCM, data={"Y_shape": torch.tensor(1)})
conditioned_model = pyro.condition(intervened_model, data={
                                       "X": recon_x1, 
                                       "Y_shape": torch.tensor(0),
                                       "Z":z1})

#kernel = HMC(conditioned_model, step_size=0.8, num_steps=4)
#posterior = MCMC(kernel, num_samples=1000, warmup_steps=50)

posterior = pyro.infer.Importance(conditioned_model, num_samples = 1).run(vae, mu, sigma)
#posterior.run(vae, mu, sigma)

# EmpiricalMarginal is a class in pyro.infer, not a method of the posterior
marginal = pyro.infer.EmpiricalMarginal(posterior)

print(type(posterior))
print(posterior)

result = []
for i in range(10):
  trace = posterior()
  x = trace.nodes['Nx']['value']
  y = trace.nodes['Ny']['value']
  z = trace.nodes['Nz']['value']
  con_obj = pyro.condition(intervened_model, data = {"Nx": x,"Ny": y, "Nz": z})
#   result.append(con_obj()[2])
  
# recon_x2,y2,z2 = con_obj(vae, mu, sigma)
# print(y2)
# recon_check(recon_x1.reshape(-1, 64, 64)[0], recon_x2.reshape(-1, 64, 64)[0])

Please let me know how to debug this or what the issue is. Any help is highly appreciated.

Could you post the type() of the data you are passing, e.g. recon_x1, vae, mu, etc.?
It seems the internal cast to x.dtype is causing the error, but I’m not sure which input value corresponds to x.
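The cast mentioned above is visible in the last frame of the traceback: Pyro's Delta.log_prob compares the site value x with the Delta's value v, casts the boolean result to x.dtype, and takes the log. A minimal plain-PyTorch sketch of that line (the helper name delta_log_prob is hypothetical, not a Pyro API) shows why the dtype of x matters: with a Long x the log runs on an integer tensor, which the PyTorch build in the traceback had no CPU kernel for, while with a float x it is well defined. Newer PyTorch releases may promote integer inputs to float for log, so the Long failure may not reproduce there.

```python
import torch


def delta_log_prob(x, v):
    # Hypothetical helper mirroring line 60 of pyro/distributions/delta.py in the
    # traceback: the comparison is cast to x's dtype, then the log is taken.
    return (x == v).type(x.dtype).log()


v = torch.tensor(0.0)

# Float site value that matches v: the cast yields 1.0, so the log-prob is 0.
lp_match = delta_log_prob(torch.tensor(0.0), v)
# Float site value that does not match v: the cast yields 0.0, so the log-prob is -inf.
lp_miss = delta_log_prob(torch.tensor(1.0), v)

print(lp_match, lp_miss)  # tensor(0.) tensor(-inf)

# A Long site value such as torch.tensor(0) would make the cast produce an
# integer tensor; that is where older PyTorch raised
# "log_vml_cpu not implemented for 'Long'". Casting to float avoids it.
x_long = torch.tensor(0)
lp_fixed = delta_log_prob(x_long.float(), v)
```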

Thanks! Could you post recon_x1.type() etc.?
This gives more information than the built-in Python type() function.

PS: It’s better to post code snippets than images. :wink:
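A small sketch of the difference described above: Python's built-in type() only reports torch.Tensor, while the tensor's own .type() method (and its .dtype attribute) report the element type, which is what the 'Long' in the error message refers to. The tensor names below are hypothetical stand-ins. Note that an integer literal such as torch.tensor(0), as used in the condition dictionary in the question, is a LongTensor.

```python
import torch

recon_x1 = torch.rand(1, 4096)  # hypothetical stand-in for a reconstruction
y_shape = torch.tensor(0)       # integer literal, as in the condition dict

print(type(recon_x1))   # <class 'torch.Tensor'>
print(recon_x1.type())  # torch.FloatTensor
print(y_shape.type())   # torch.LongTensor
print(y_shape.dtype)    # torch.int64
```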


Hey, I think I have resolved that particular error for now. Thank you for your support, much appreciated!

The issue was in dist.Delta(Nx < p.cpu()). @neerajprad helped with this on the Pyro forum: https://forum.pyro.ai/t/runtimeerror-log-vml-cpu-not-implemented-for-long/1206/7?u=viralpandey
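The corrected SCM is not shown in the thread. A minimal sketch, assuming the problem was that the Delta site's value came from a comparison (a non-float tensor) and that the fix is to cast it to float before passing it to dist.Delta. Nx, p, and the site name "Y" are hypothetical stand-ins for the SCM's noise, threshold, and site; the Pyro call is left as a comment so the snippet runs with PyTorch alone.

```python
import torch

torch.manual_seed(0)
Nx = torch.rand(3)                 # hypothetical exogenous noise
p = torch.tensor([0.2, 0.5, 0.8])  # hypothetical threshold from the model

raw = Nx < p              # comparison result: dtype torch.bool
value = (Nx < p).float()  # same 0/1 pattern as float32

# In the SCM this would become, e.g.:
#   pyro.sample("Y", dist.Delta(value))
# and values passed to pyro.condition / pyro.do should be float as well,
# e.g. torch.tensor(0.) rather than torch.tensor(0).
y_obs = torch.tensor(0.)

print(raw.dtype, value.dtype, y_obs.dtype)  # torch.bool torch.float32 torch.float32
```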
