I am trying to replicate the tutorial model from "What is torch.nn really?" (PyTorch Tutorials 2.1.0+cu121 documentation), but in a functional style rather than an OO style. Training sometimes works. In other runs, though, the model gets into a state where it returns all 0s as the output for some training and validation samples. Once it is in this state, it fails to learn properly.
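One mechanism that would produce exactly this symptom is worth checking, though it is an assumption and not a confirmed diagnosis. The model applies `relu` after the last convolution, as the tutorial does. If all 10 final pre-activations for a sample are negative, that sample's outputs are all 0, and ReLU passes no gradient back through it. Here is a minimal sketch with a hand-made, hypothetical `pre_activation` tensor:

```python
import math

import torch
from torch.nn.functional import adaptive_avg_pool2d, cross_entropy, relu

# Hypothetical output of the last conv layer for one sample (10 channels, 4x4),
# with every value negative.
pre_activation = torch.full((1, 10, 4, 4), -0.5, requires_grad=True)

# Same tail as the model: ReLU, global average pool, flatten to 10 logits.
logits = adaptive_avg_pool2d(relu(pre_activation), 1).view(-1, 10)
target = torch.tensor([3])  # arbitrary label, for illustration only

loss = cross_entropy(logits, target)
loss.backward()

print(logits)               # all zeros
print(loss.item())          # log(10), about 2.3026: a uniform prediction
print(pre_activation.grad)  # all zeros: ReLU passes no gradient for negative inputs
```

If this is what happens, the loss for such a sample stays at log(10). Only other samples in the batch can move the weights out of that region, which would fit the run-to-run variability.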
Other times I get lucky with the random initial weights and the training-data shuffle, and the model trains without any problems.
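Whether a run fails seems to depend on the initial weights and the shuffle order. Fixing the random seeds should therefore make a failing run reproducible for debugging. This is a minimal sketch that assumes the data goes through a `DataLoader` with `shuffle=True`. The tiny `TensorDataset` and the helper `make_run` are hypothetical stand-ins:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset


def make_run(seed):
    """Hypothetical helper: draw initial weights and one shuffled batch for a seed."""
    torch.manual_seed(seed)  # seeds PyTorch's default generators on all devices
    # First-layer weights, using the post's bound: sqrt(1 / (1 * 3 * 3)) = 1/3.
    weight = torch.empty(16, 1, 3, 3).uniform_(-1 / 3, 1 / 3)

    # Hypothetical stand-in for the MNIST training set: 8 samples with labels 0..7.
    data = TensorDataset(torch.arange(8.0).view(8, 1), torch.arange(8))
    # An explicit generator pins the shuffle order independently of other RNG use.
    loader = DataLoader(data, batch_size=4, shuffle=True,
                        generator=torch.Generator().manual_seed(seed))
    first_labels = next(iter(loader))[1]
    return weight, first_labels


w_a, order_a = make_run(0)
w_b, order_b = make_run(0)
print(torch.equal(w_a, w_b), order_a, order_b)
```

A seed that leads to the all-0 state can then be rerun and inspected step by step. Some CUDA kernels are nondeterministic, so GPU runs may still diverge slightly.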
My guess is that a bug in my model is pushing values to extremes (0 or inf, for example), but I don’t know what it might be. I believe I have replicated the tutorial model apart from syntax and style, but I may be missing something.
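One way to test that guess directly is to run a small check after each `loss.backward()`. It counts all-zero output rows and flags non-finite parameters or gradients. This is a minimal sketch. `report_health` is a hypothetical helper, not a PyTorch API:

```python
import torch


def report_health(output_batch, parameters):
    """Hypothetical helper: count the symptoms suspected in the post."""
    # Rows (samples) where every one of the 10 outputs is exactly 0.
    all_zero_rows = (output_batch == 0).all(dim=1).sum().item()
    # Parameters containing any inf or NaN.
    nonfinite_params = sum(
        int(not torch.isfinite(p).all().item()) for p in parameters
    )
    # Gradients containing any inf or NaN (skipped if no gradient exists yet).
    nonfinite_grads = sum(
        int(p.grad is not None and not torch.isfinite(p.grad).all().item())
        for p in parameters
    )
    return {
        "all_zero_rows": all_zero_rows,
        "nonfinite_params": nonfinite_params,
        "nonfinite_grads": nonfinite_grads,
    }
```

For usage, call it between `loss.backward()` and `optimizer.step()` in the training step below, for example as `report_health(output_batch, [weight1, bias1, weight2, bias2, weight3, bias3])`. The result shows whether the zeros coincide with infinite or NaN values.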
The gist of my model is the following (reformatted a bit to facilitate troubleshooting):
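import math
from functools import partial

import torch
from torch.nn import Parameter
from torch.nn.functional import adaptive_avg_pool2d, conv2d, cross_entropy, relu
from torch.nn.init import uniform_

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # assumed; `device` is not shown in the gist
empty_tensor = partial(torch.empty, device=device)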
def gen_conv2d_parameters(chan_in, chan_out, kernel_size, groups):
    weight_tensor = empty_tensor(chan_out, chan_in // groups, kernel_size, kernel_size)
    bias_tensor = empty_tensor(chan_out)
    # bound = 1 / sqrt(fan_in) with fan_in = (chan_in // groups) * k * k,
    # the same range nn.Conv2d's default initialization uses
    bound = math.sqrt(groups / (chan_in * kernel_size * kernel_size))
    weight_parameter = Parameter(uniform_(weight_tensor, -bound, bound))
    bias_parameter = Parameter(uniform_(bias_tensor, -bound, bound))
    # ... store parameters ...
    return weight_parameter, bias_parameter
weight1, bias1 = gen_conv2d_parameters(1, 16, 3, 1)
weight2, bias2 = gen_conv2d_parameters(16, 16, 3, 1)
weight3, bias3 = gen_conv2d_parameters(16, 10, 3, 1)
def model(x):
    x = x.view(-1, 1, 28, 28)
    x = relu(conv2d(x, weight1, bias1, stride=2, padding=1))  # -> 16 x 14 x 14
    x = relu(conv2d(x, weight2, bias2, stride=2, padding=1))  # -> 16 x 7 x 7
    x = relu(conv2d(x, weight3, bias3, stride=2, padding=1))  # -> 10 x 4 x 4
    x = adaptive_avg_pool2d(x, 1)  # -> 10 x 1 x 1
    return x.view(-1, 10)
def train_one_batch(input_batch, expected_batch):
    output_batch = model(input_batch)
    loss = cross_entropy(output_batch, expected_batch)
    loss.backward()
    optimizer.step()  # optimizer is built from the stored parameters (not shown)
    optimizer.zero_grad()
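For reference, the sketch below is a self-contained version of the same functional model and training step that runs end to end. It rests on three assumptions. The optimizer uses the tutorial's settings (`torch.optim.SGD`, `lr=0.1`, `momentum=0.9`). The parameters are passed to `model` explicitly instead of through globals. A random batch stands in for MNIST.

```python
import math

import torch
from torch.nn import Parameter
from torch.nn.functional import adaptive_avg_pool2d, conv2d, cross_entropy, relu
from torch.nn.init import uniform_

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


def gen_conv2d_parameters(chan_in, chan_out, kernel_size, groups=1):
    # Same bound as the post: 1 / sqrt(fan_in), fan_in = (chan_in // groups) * k * k.
    fan_in = (chan_in // groups) * kernel_size * kernel_size
    bound = 1 / math.sqrt(fan_in)
    weight = torch.empty(chan_out, chan_in // groups, kernel_size, kernel_size, device=device)
    bias = torch.empty(chan_out, device=device)
    return Parameter(uniform_(weight, -bound, bound)), Parameter(uniform_(bias, -bound, bound))


params = [
    *gen_conv2d_parameters(1, 16, 3),
    *gen_conv2d_parameters(16, 16, 3),
    *gen_conv2d_parameters(16, 10, 3),
]


def model(x, params):
    w1, b1, w2, b2, w3, b3 = params
    x = x.view(-1, 1, 28, 28)
    x = relu(conv2d(x, w1, b1, stride=2, padding=1))  # 28x28 -> 14x14
    x = relu(conv2d(x, w2, b2, stride=2, padding=1))  # 14x14 -> 7x7
    x = relu(conv2d(x, w3, b3, stride=2, padding=1))  # 7x7 -> 4x4
    x = adaptive_avg_pool2d(x, 1)  # 4x4 -> 1x1, equivalent to avg_pool2d(x, 4) here
    return x.view(-1, 10)


# Assumed: the tutorial's optimizer settings.
optimizer = torch.optim.SGD(params, lr=0.1, momentum=0.9)


def train_one_batch(input_batch, expected_batch):
    output_batch = model(input_batch, params)
    loss = cross_entropy(output_batch, expected_batch)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    return loss.item(), output_batch.detach()


# Hypothetical stand-in for one MNIST batch: 64 flattened 28x28 images.
xb = torch.rand(64, 784, device=device)
yb = torch.randint(0, 10, (64,), device=device)
loss, out = train_one_batch(xb, yb)
print(loss, out.shape)
```

Passing `params` explicitly keeps the forward function free of hidden state. It also guarantees that the list handed to the optimizer is the same list the model reads from.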