.grad should not equal None here

TL;DR: I followed the FAQ titled [Why are my tensor’s gradients unexpectedly None or not None?], but test_tensor.grad is still None.

Hi, I have a question about the behaviour of autograd. The problem involves LLaVA 1.5 13B, loaded directly from the official model repository, and the full code to reproduce it is below. I am trying to compute gradients for test_tensor, and I’ve followed the guide above by:

  1. Making sure that I call retain_grad() on it and that it requires gradient computation
  2. Making sure there are no non-differentiable operations in the path

Thing is, I don’t believe any of the model’s operations are non-differentiable. Is there anything I am missing here? Thanks in advance :slight_smile:

import torch 
import torch.nn.functional as F 
import torch.nn as nn

from llava.mm_utils import get_model_name_from_path
from llava.model.builder import load_pretrained_model

import numpy as np
import os 

device = ("cuda" if torch.cuda.is_available() else "cpu")
scaler = torch.cuda.amp.GradScaler()  # for mixed precision training (unused below)

# loads llava 1.5 model from repo 
model_id = "liuhaotian/llava-v1.5-13b"
model_path = os.path.expanduser(model_id)
model_name = get_model_name_from_path(model_id)
_, model, _, _ = load_pretrained_model(model_path, None, model_name)
model_path, model_name  # echoed as notebook cell output

# init dummy inputs to represent the image and text tokens:
model.train()
test_id = torch.randint(low=0, high=10, size=(1, 58), device=device)
test_tensor = torch.randn(1, 3, 336, 336, requires_grad=True, device=device).half()
answer_ids = torch.randint(low=0, high=10, size=(1, 58), device=device)

print(f"test_tensor is a leaf: {test_tensor.is_leaf}")

test_tensor.requires_grad_(True).retain_grad()

# forward pass through the model
output_ids = model(input_ids=test_id, images=test_tensor)
output = output_ids.logits 

dummy_loss = F.cross_entropy(output[:, -1, :], answer_ids[:, 0])

dummy_loss.backward()
print(f"test_tensor grad: {test_tensor.grad}")
print(f"test_tensor grad_fn: {test_tensor.grad_fn}")

This yields the following outputs:

/home/w/weiyong/miniconda3/envs/llava/lib/python3.10/site-packages/huggingface_hub/file_download.py:795: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
You are using a model of type llava to instantiate a model of type llava_llama. This is not supported for all configurations of models and can yield errors.
Loading checkpoint shards: 100% 3/3 [00:04<00:00,  1.39s/it]
/home/w/weiyong/miniconda3/envs/llava/lib/python3.10/site-packages/torch/_utils.py:831: UserWarning: TypedStorage is deprecated. It will be removed in the future and UntypedStorage will be the only storage class. This should only matter to you if you are using storages directly.  To access UntypedStorage directly, use tensor.untyped_storage() instead of tensor.storage()
  return self.fget.__get__(instance, owner)()
('liuhaotian/llava-v1.5-13b', 'llava-v1.5-13b')
test_tensor is a leaf: False
test_tensor grad: None
test_tensor grad_fn: <ToCopyBackward0 object at 0x7f72debe9b70>

test_tensor is not a leaf tensor, since half() is a differentiable operation.
The same pattern also raises the expected warning in a minimal example:

device = "cuda"
test_tensor = torch.randn(1, 3, 336, 336, requires_grad=True, device=device).half()

print(f"test_tensor is a leaf: {test_tensor.is_leaf}")
# test_tensor is a leaf: False

out = test_tensor * 2
out.mean().backward()

print(test_tensor.grad)
# UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.

The warning itself already points to the workaround.
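For reference, here is a minimal sketch (mine, not from the thread) of where the gradient ends up in this pattern; leaf is just a name I’ve introduced for the float32 tensor before the cast:

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

leaf = torch.randn(1, 3, 336, 336, requires_grad=True, device=device)  # true leaf
test_tensor = leaf.half()        # differentiable cast -> non-leaf, grad_fn=ToCopyBackward0
test_tensor.retain_grad()        # without this, test_tensor.grad stays None

out = test_tensor * 2
out.mean().backward()

print(test_tensor.grad is None)  # False, because retain_grad() was called
print(leaf.grad is None)         # False, the float32 leaf receives the gradient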

Ah yes, I do recall seeing that warning, but I thought a good workaround would simply be to call retain_grad(). It seems like the problem might be with the model loading instead; do you think that’s the case?

No, the model loading is irrelevant, and my simple code snippet already reproduces the error. Either initialize the tensor directly in the right dtype or create a new leaf after transforming it.
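To spell out both options on the original shapes, here is a sketch (not tested against the full LLaVA pipeline; device is assumed to be defined as above):

# Option 1: create the leaf directly in the dtype the model expects
test_tensor = torch.randn(1, 3, 336, 336, dtype=torch.float16,
                          device=device, requires_grad=True)
print(test_tensor.is_leaf)  # True, so .grad is populated after backward()

# Option 2: transform first, then detach to get a fresh leaf
test_tensor = torch.randn(1, 3, 336, 336, device=device).half().detach().requires_grad_(True)
print(test_tensor.is_leaf)  # True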