torch.no_grad and autocast not working together (weight gradients come out as None)

import torch
from torch import nn
import torch.nn.functional as F

net = nn.Sequential(
    nn.Linear(30, 10),
    nn.ReLU(),
    nn.Linear(10, 30),
).cuda()

def latent_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    x_dim = x.shape[-1]
    y_latent_dim = y_latent.shape[-1]
    z_latent_dim = z_latent.shape[-1]
    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        input_tensor = output_tensor + input_tensor
    y = output_tensor[:, x_dim:x_dim+y_latent_dim]
    z = output_tensor[:, x_dim+y_latent_dim:x_dim+y_latent_dim+z_latent_dim]
    return y, z

def deep_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    # Don't modify y_latent and z_latent in place within no_grad
    for _ in range(t_recursion_steps - 1):
        with torch.no_grad():
            y_latent_new, z_latent_new = latent_recursion(x, y_latent.detach(), z_latent.detach())
        y_latent = y_latent_new
        z_latent = z_latent_new
    # Re-enable grad on the detached latents so the final recursion is recorded by autograd
    y_latent = y_latent.requires_grad_(True)
    z_latent = z_latent.requires_grad_(True)
    y_latent, z_latent = latent_recursion(x, y_latent, z_latent)
    return y_latent, z_latent

x = torch.randn(1, 10).cuda()
y_latent = torch.randn(1, 10).cuda()
z_latent = torch.randn(1, 10).cuda()
scaler = torch.amp.GradScaler()
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
    y_latent, z_latent = deep_recursion(x, y_latent, z_latent)
    example_class = torch.randint(0, 10, (1,)).cuda()

    loss = F.cross_entropy(y_latent, example_class)
    scaler.scale(loss).backward()
    # loss.backward()
    print(net[0].weight.grad)

Output: None

The same code with the autocast context and the scaled backward commented out (plain loss.backward() instead) produces gradients:

import torch
from torch import nn
import torch.nn.functional as F

net = nn.Sequential(
    nn.Linear(30, 10),
    nn.ReLU(),
    nn.Linear(10, 30),
).cuda()

def latent_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    x_dim = x.shape[-1]
    y_latent_dim = y_latent.shape[-1]
    z_latent_dim = z_latent.shape[-1]
    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        input_tensor = output_tensor + input_tensor
    y = output_tensor[:, x_dim:x_dim+y_latent_dim]
    z = output_tensor[:, x_dim+y_latent_dim:x_dim+y_latent_dim+z_latent_dim]
    return y, z

def deep_recursion(
    x: torch.Tensor, 
    y_latent: torch.Tensor, 
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    # Don't modify y_latent and z_latent in place within no_grad
    for _ in range(t_recursion_steps - 1):
        with torch.no_grad():
            y_latent_new, z_latent_new = latent_recursion(x, y_latent.detach(), z_latent.detach())
        y_latent = y_latent_new
        z_latent = z_latent_new
    # Re-enable grad on the detached latents so the final recursion is recorded by autograd
    y_latent = y_latent.requires_grad_(True)
    z_latent = z_latent.requires_grad_(True)
    y_latent, z_latent = latent_recursion(x, y_latent, z_latent)
    return y_latent, z_latent

x = torch.randn(1, 10).cuda()
y_latent = torch.randn(1, 10).cuda()
z_latent = torch.randn(1, 10).cuda()
scaler = torch.amp.GradScaler()
# with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
y_latent, z_latent = deep_recursion(x, y_latent, z_latent)
example_class = torch.randint(0, 10, (1,)).cuda()

loss = F.cross_entropy(y_latent, example_class)
# scaler.scale(loss).backward()
loss.backward()
print(net[0].weight.grad)

Output: tensor([[-3.9366e-01, -4.4403e-01, 2.9051e-01, 4.8551e-01, -1.5548e-01, -1.0718e-01, -2.3545e-01, 1.6691e-01, 3.5807e-01, 4.8783e-02, -1.7829e-02, 2.6798e-01, 1.1731e-01, 1.4514e-01, -1.2646e-01, -1.9012e-02, 1.9878e-01, -1.5690e-02, -3.9625e-02, -1.4597e-01, 1.9006e-01, -1.1753e-01, 3.0696e-02, -2.9123e-02, 2.5830e-01, 6.5135e-03, -7.6369e-04, 2.8767e-02, 4.2457e-02, 1.0483e-01],…
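
For what it's worth, I think the failing case reduces to this pattern: one forward pass under no_grad inside the autocast region, followed by a grad-enabled forward through the same module. The layer size and names below are arbitrary, just to illustrate the shape of the problem:

import torch
from torch import nn

lin = nn.Linear(8, 8).cuda()
x = torch.randn(1, 8).cuda()

with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
    # warm-up forward without autograd, like the first (t - 1) recursions above
    with torch.no_grad():
        h = lin(x)
    # grad-enabled forward through the same module, like the final recursion
    out = lin(h.requires_grad_(True))
    out.sum().backward()
    # check whether any gradient reached the fp32 weight
    print(lin.weight.grad)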

I am confused about why autocast causes this gradient mismatch. Any advice, or is this a bug in how the autocast context interacts with no_grad?
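
One guess I had is that autocast caches the fp16 casts of the weights for the duration of the autocast block, and that a cast created under no_grad might be reused later without autograd history. torch.autocast does accept a cache_enabled argument, so the variation below is what I would try in order to rule that in or out; I have not verified that it actually changes the result:

import torch
from torch import nn

lin = nn.Linear(8, 8).cuda()
x = torch.randn(1, 8).cuda()

# same pattern as the snippet above, but with autocast's weight-cast cache disabled
with torch.amp.autocast(device_type='cuda', dtype=torch.float16, cache_enabled=False):
    with torch.no_grad():
        h = lin(x)
    out = lin(h.requires_grad_(True))
    out.sum().backward()
    print(lin.weight.grad)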

This post might be related.