Here is a minimal example that runs a small latent-recursion loop under torch.amp.autocast with a GradScaler:

import torch
from torch import nn
import torch.nn.functional as F
net = nn.Sequential(
    nn.Linear(30, 10),
    nn.ReLU(),
    nn.Linear(10, 30),
).cuda()
def latent_recursion(
    x: torch.Tensor,
    y_latent: torch.Tensor,
    z_latent: torch.Tensor,
    n_latent_reasoning_steps: int = 3,
    net: nn.Module = net
):
    x_dim = x.shape[-1]
    y_latent_dim = y_latent.shape[-1]
    z_latent_dim = z_latent.shape[-1]
    input_tensor = torch.cat([x, y_latent, z_latent], dim=-1)
    for _ in range(n_latent_reasoning_steps):
        output_tensor = net(input_tensor)
        # residual update of the combined (x, y, z) state
        input_tensor = output_tensor + input_tensor
    # split the last output back into the y and z latents
    y = output_tensor[:, x_dim:x_dim + y_latent_dim]
    z = output_tensor[:, x_dim + y_latent_dim:x_dim + y_latent_dim + z_latent_dim]
    return y, z
def deep_recursion(
    x: torch.Tensor,
    y_latent: torch.Tensor,
    z_latent: torch.Tensor,
    t_recursion_steps: int = 2,
    net: nn.Module = net
):
    # Don't modify y_latent and z_latent in place within no_grad
    # First T-1 recursion steps run without gradient tracking
    for _ in range(t_recursion_steps - 1):
        with torch.no_grad():
            y_latent_new, z_latent_new = latent_recursion(x, y_latent.detach(), z_latent.detach())
        y_latent = y_latent_new
        z_latent = z_latent_new
    # Final recursion step runs with gradients enabled
    y_latent = y_latent.requires_grad_(True)
    z_latent = z_latent.requires_grad_(True)
    y_latent, z_latent = latent_recursion(x, y_latent, z_latent)
    return y_latent, z_latent
x = torch.randn(1,10).cuda()
y_latent = torch.randn(1,10).cuda()
z_latent = torch.randn(1,10).cuda()
scaler = torch.amp.GradScaler()
with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
    y_latent, z_latent = deep_recursion(x, y_latent, z_latent)
    example_class = torch.randint(0, 10, (1,)).cuda()
    loss = F.cross_entropy(y_latent, example_class)
scaler.scale(loss).backward()
# loss.backward()
print(net[0].weight.grad)
Output: None
Running the exact same script, but with autocast disabled and a plain loss.backward() (the imports, net, latent_recursion and deep_recursion are identical to the code above):
x = torch.randn(1,10).cuda()
y_latent = torch.randn(1,10).cuda()
z_latent = torch.randn(1,10).cuda()
scaler = torch.amp.GradScaler()
# with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
y_latent, z_latent = deep_recursion(x, y_latent, z_latent)
example_class = torch.randint(0, 10, (1,)).cuda()
loss = F.cross_entropy(y_latent, example_class)
# scaler.scale(loss).backward()
loss.backward()
print(net[0].weight.grad)
Output: tensor([[-3.9366e-01, -4.4403e-01, 2.9051e-01, 4.8551e-01, -1.5548e-01, -1.0718e-01, -2.3545e-01, 1.6691e-01, 3.5807e-01, 4.8783e-02, -1.7829e-02, 2.6798e-01, 1.1731e-01, 1.4514e-01, -1.2646e-01, -1.9012e-02, 1.9878e-01, -1.5690e-02, -3.9625e-02, -1.4597e-01, 1.9006e-01, -1.1753e-01, 3.0696e-02, -2.9123e-02, 2.5830e-01, 6.5135e-03, -7.6369e-04, 2.8767e-02, 4.2457e-02, 1.0483e-01],…
I am confused about why autocast creates this gradient mismatch: with autocast plus GradScaler the gradient of net[0].weight is None, while the identical code without autocast produces a normal gradient. Any advice, or is this a bug in the autocast context?
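In case it helps to narrow things down, here is a stripped-down sketch of what I think is the relevant pattern in my code: the first forward pass through the fp32 module happens under torch.no_grad() inside the autocast region, and a later grad-enabled forward pass reuses the same module in the same region. This tiny version is only meant to illustrate the pattern; I have not confirmed that it reproduces exactly the same behavior as the full example.

import torch
from torch import nn

lin = nn.Linear(8, 8).cuda()
x = torch.randn(1, 8).cuda()

with torch.amp.autocast(device_type='cuda', dtype=torch.float16):
    # first forward: no gradient tracking (mirrors the T-1 "free" recursion steps)
    with torch.no_grad():
        h = lin(x)
    # second forward: gradients enabled (mirrors the final latent_recursion call)
    out = lin(h.detach().requires_grad_(True))
    loss = out.sum()

loss.backward()
print(lin.weight.grad)  # should this be populated here, or does autocast change that?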