Hi,
Could someone help me with this issue? Thanks!
If I change the device from CPU to CUDA in the following code snippet, it throws the error below:
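import torch
from diffusers import AutoencoderKL

device = "cuda"  # runs fine with "cpu"; fails with "cuda"

# Load the fine-tuned Stable Diffusion VAE in full precision and put it in training mode
vae_2 = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=torch.float32)
vae_2.train()
classes = 2
vae_2.to(device)

# Decode a random latent (1 x 4 x 64 x 64 latent -> 1 x 3 x 512 x 512 image)
latents = torch.randn(1, 4, 64, 64).to(torch.float32).to(device)
decoded_mask = vae_2.decode(latents, return_dict=False)[0]
decoded_mask.requires_grad_(True)

# Backpropagate an MSE loss against a random target; the CUDA error is raised here
target = torch.randn(decoded_mask.shape).to(device)
loss = torch.nn.MSELoss()(decoded_mask, target)
loss.backward()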
Error:
│ │
│ 25 │
│ 26 print(loss) │
│ 27 │
│ ❱ 28 loss.backward() │
│ 29 │
│ 30 print(vae_2.decoder.conv_out.weight.grad) │
│ 31 print(vae_2.decoder.conv_out.bias.grad) │
│ │
│ /home/nkondapa/anaconda3/envs/neurips2023_env/lib/python3.8/site-packages/to │
│ rch/_tensor.py:487 in backward │
│ │
│ 484 │ │ │ │ create_graph=create_graph, │
│ 485 │ │ │ │ inputs=inputs, │
│ 486 │ │ │ ) │
│ ❱ 487 │ │ torch.autograd.backward( │
│ 488 │ │ │ self, gradient, retain_graph, create_graph, inputs=inputs │
│ 489 │ │ ) │
│ 490 │
│ │
│ /home/nkondapa/anaconda3/envs/neurips2023_env/lib/python3.8/site-packages/to │
│ rch/autograd/__init__.py:200 in backward │
│ │
│ 197 │ # The reason we repeat same the comment below is that │
│ 198 │ # some Python versions print out the first line of a multi-line fu │
│ 199 │ # calls in the traceback and some print out the last line │
│ ❱ 200 │ Variable._execution_engine.run_backward( # Calls into the C++ eng │
│ 201 │ │ tensors, grad_tensors_, retain_graph, create_graph, inputs, │
│ 202 │ │ allow_unreachable=True, accumulate_grad=True) # Calls into th │
│ 203 │
╰──────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA error: invalid argument
CUDA kernel errors might be asynchronously reported at some other API call, so
the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
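As the message says, CUDA kernels run asynchronously, so the traceback above may not point at the operation that actually failed. One way to get a more accurate trace is to re-run the repro with CUDA_LAUNCH_BLOCKING=1 set in the environment before CUDA is initialized. Below is a minimal, standard-library-only sketch, assuming the snippet above is saved as a script (`repro.py` and `run_with_blocking_launches` are hypothetical names used only for illustration):

```python
import os
import subprocess
import sys


def run_with_blocking_launches(args):
    """Run a Python child process with CUDA_LAUNCH_BLOCKING=1 so every CUDA
    kernel launch is synchronous and errors surface at the failing op."""
    # Copy the current environment and add the debugging flag; setting it in
    # the child's environment guarantees it is present before CUDA initializes.
    env = dict(os.environ)
    env["CUDA_LAUNCH_BLOCKING"] = "1"
    # Capture stdout/stderr as text so the (now synchronous) traceback can be inspected.
    return subprocess.run([sys.executable, *args], env=env,
                          capture_output=True, text=True)
```

For example, `result = run_with_blocking_launches(["repro.py"])` followed by `print(result.stderr)` should show the traceback with launches serialized. From a shell, `CUDA_LAUNCH_BLOCKING=1 python repro.py` does the same thing.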