Zero gradient in first layer

Hey guys,

I have a very simple model, but I’m struggling to understand why I get zero gradients for the first layer and very low gradients for the last one. My code is:

import os

import torch
from torch import nn
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda' if torch.cuda.is_available() else 'cpu'

training_set = datasets.MNIST(
    root="MNIST",
    train=True,
    download=True,
    transform=transforms.ToTensor()
)

test_set = datasets.MNIST(
    root="MNIST",
    train=False,
    download=True,
    transform=transforms.ToTensor()
)

train_dataloader = DataLoader(training_set, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_set, batch_size=64, shuffle=True)

# Two-layer MLP: 784 input pixels -> 50 hidden units -> 10 class logits
model = nn.Sequential(
    nn.Linear(28*28, 50),
    nn.ReLU(),
    nn.Linear(50, 10)
)

optimizer = optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

# Take a single batch and flatten each 1x28x28 image into a 784-element vector
input, output = next(iter(train_dataloader))
input = input.view(input.shape[0], input.shape[2]*input.shape[3])

# One training step
optimizer.zero_grad()
logits = model(input)
loss = criterion(logits, output)
loss.backward()
optimizer.step()

Can someone spot the mistake? Thank you.

I’m not sure how you’ve checked the gradient in the first layer, but running your code for a single step I get:

print(model[0].weight.grad)
> tensor([[0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          ...,
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.],
          [0., 0., 0.,  ..., 0., 0., 0.]])

While this looks like an all-zero gradient, note that the output is truncated.
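One way to check this without relying on the truncated printout is to summarize the gradient numerically. The snippet below is only a rough sketch: to keep it self-contained it uses a random stand-in batch with a zero border instead of real MNIST data (an assumption, not part of the original code), and the names `images`, `labels`, and `nonzero` are made up for illustration.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for an MNIST batch (assumption): 64 images with a 4-pixel zero border
images = torch.zeros(64, 1, 28, 28)
images[:, :, 4:24, 4:24] = torch.rand(64, 1, 20, 20)
labels = torch.randint(0, 10, (64,))

# Same architecture as in the question
model = nn.Sequential(nn.Linear(28 * 28, 50), nn.ReLU(), nn.Linear(50, 10))
loss = nn.CrossEntropyLoss()(model(images.view(64, -1)), labels)
loss.backward()

# Summarize the first-layer gradient instead of printing the full tensor
grad = model[0].weight.grad
nonzero = torch.count_nonzero(grad).item()
print(f"non-zero entries: {nonzero} / {grad.numel()}")
print(f"max |grad|: {grad.abs().max().item():.3e}")
```

If the non-zero count is above zero, the gradient is not actually all zeros; the printed corners just happen to fall on zero entries.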
Visualizing an example gradient of the first layer via:

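import matplotlib.pyplot as plt

f, axarr = plt.subplots(1, 2)
# Left: gradient of the first hidden unit's weights, reshaped to the image grid
axarr[0].imshow(model[0].weight.grad[0].view(28, 28))
# Right: mask of the entries that are exactly zero
axarr[1].imshow(model[0].weight.grad[0].view(28, 28) == 0)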

gives:
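[image: the first hidden unit's weight gradient as a 28x28 map (left) and the mask of its zero entries (right)]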

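So I would assume that the zeros in the raw inputs (the borders of the MNIST images) are creating the zero gradients: the gradient of each first-layer weight is proportional to the input pixel it multiplies, so a pixel that is zero in every sample of the batch gives an exactly zero gradient for all weights connected to it.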
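This assumption can be tested directly. For a linear layer, the weight gradient is the sum over the batch of (upstream gradient) x (input value), so a pixel that is zero in every sample should produce an all-zero column in `model[0].weight.grad`. The following is a minimal sketch of that check, again using a synthetic zero-bordered batch rather than real MNIST data (an assumption); `dead_pixels` and `zero_columns` are hypothetical names.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for an MNIST batch (assumption): 64 images with a 4-pixel zero border
images = torch.zeros(64, 1, 28, 28)
images[:, :, 4:24, 4:24] = torch.rand(64, 1, 20, 20)
labels = torch.randint(0, 10, (64,))
x = images.view(64, -1)

model = nn.Sequential(nn.Linear(28 * 28, 50), nn.ReLU(), nn.Linear(50, 10))
loss = nn.CrossEntropyLoss()(model(x), labels)
loss.backward()

grad = model[0].weight.grad              # shape (50, 784)
dead_pixels = (x == 0).all(dim=0)        # pixels that are zero in every sample
zero_columns = (grad == 0).all(dim=0)    # pixels whose gradient is zero for every hidden unit

# Every always-zero pixel should map to an all-zero gradient column
print("always-zero pixels:", int(dead_pixels.sum()))
print("all of them have zero gradient:", bool(zero_columns[dead_pixels].all()))
```

Zero rows are possible too: a hidden unit whose ReLU is inactive for every sample in the batch passes no gradient back to its weights.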

Hey, thanks @ptrblck for your answer. It is very helpful, never thought of plotting the gradients :smile: