I haven’t used PyTorch in a long time and thought I should get back into it, as it might be super useful for an upcoming project.
So I was just experimenting to see how the nn.Module API works, but unfortunately I somehow get weird results (compared to using the autograd functions manually). I have uploaded the code as a Jupyter notebook to GitHub (https://github.com/rasbt/bugreport/blob/master/pytorch/grad_q/grad_q.ipynb) so that I don’t have to paste all of the lengthy code here.
So, I implemented a simple logistic regression classifier using nn.Module:
```python
import torch
import torch.nn.functional as F
from torch.autograd import Variable

# X_train, y_train: NumPy arrays defined in the linked notebook (not shown here)


class LogisticRegression3(torch.nn.Module):

    def __init__(self, num_features):
        super(LogisticRegression3, self).__init__()
        self.linear = torch.nn.Linear(num_features, 1)
        # initialize weights to zeros here:
        self.linear.weight.data.zero_()
        self.linear.bias.data.zero_()

    def forward(self, x):
        logits = self.linear(x)
        probas = F.sigmoid(logits)
        return probas


model = LogisticRegression3(num_features=2)
cost_fn = torch.nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

num_epochs = 10

X_train_var = Variable(torch.Tensor(X_train), requires_grad=False)
# reshape targets to (n, 1) so they match the shape of the model output
y_train_var = Variable(torch.Tensor(y_train), requires_grad=False).view(-1, 1)

for epoch in range(num_epochs):
    out = model(X_train_var)
    cost = cost_fn(out, y_train_var)
    optimizer.zero_grad()
    cost.backward()
    optimizer.step()

print('\nModel parameters:')
print('  Weights: %s' % model.linear.weight.data)
print('  Bias: %s' % model.linear.bias.data)
```

Output:

```
Model parameters:
  Weights: 
 0.3552  0.3401
[torch.FloatTensor of size 1x2]

  Bias: 
1.00000e-02 *
-3.8595
[torch.FloatTensor of size 1]
```
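For reference, assuming `BCELoss` keeps its default averaging over the n training examples, the training loop above is plain full-batch gradient descent with the following quantities (σ is the logistic sigmoid, η is the learning rate passed to `SGD`, and `SGD` is used without momentum):

$$
\begin{aligned}
p_i &= \sigma(\mathbf{w}^\top \mathbf{x}_i + b), \\
L(\mathbf{w}, b) &= -\frac{1}{n}\sum_{i=1}^{n}\bigl[y_i \log p_i + (1 - y_i)\log(1 - p_i)\bigr], \\
\frac{\partial L}{\partial \mathbf{w}} &= \frac{1}{n}\sum_{i=1}^{n}(p_i - y_i)\,\mathbf{x}_i, \qquad
\frac{\partial L}{\partial b} = \frac{1}{n}\sum_{i=1}^{n}(p_i - y_i), \\
\mathbf{w} &\leftarrow \mathbf{w} - \eta\,\frac{\partial L}{\partial \mathbf{w}}, \qquad
b \leftarrow b - \eta\,\frac{\partial L}{\partial b}, \qquad \eta = 0.1.
\end{aligned}
$$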
When I run the nn.Module code, everything seems to work fine. However, when I compute the gradients manually, or when I use the autograd functions directly, I get different weights and biases than with the nn.Module API (the results from the manual gradient computations and from autograd are consistent with each other, though).
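A minimal sketch for comparing the two approaches side by side. It assumes hypothetical synthetic data in place of the notebook's `X_train`/`y_train`, and it uses the current PyTorch API (plain tensors and `torch.sigmoid` instead of `Variable` and `F.sigmoid`). Version 1 is equivalent to `LogisticRegression3` trained with `BCELoss` and `SGD`; version 2 applies the hand-derived gradient of the mean BCE loss. If both are set up identically (same zero init, same averaging, same target shape), the final parameters should agree up to floating-point rounding:

```python
import torch

torch.manual_seed(0)

# Hypothetical toy data (NOT the notebook's X_train/y_train): 100 samples, 2 features
n, num_features = 100, 2
X = torch.randn(n, num_features)
y = (X[:, 0] + X[:, 1] > 0).float().view(-1, 1)  # shape (n, 1), same as the predictions

lr, num_epochs = 0.1, 10

# --- Version 1: nn.Linear + sigmoid + BCELoss + SGD (same model as LogisticRegression3) ---
model = torch.nn.Linear(num_features, 1)
torch.nn.init.zeros_(model.weight)
torch.nn.init.zeros_(model.bias)
cost_fn = torch.nn.BCELoss()  # default reduction='mean'
optimizer = torch.optim.SGD(model.parameters(), lr=lr)

for epoch in range(num_epochs):
    probas = torch.sigmoid(model(X))
    cost = cost_fn(probas, y)
    optimizer.zero_grad()
    cost.backward()
    optimizer.step()

# --- Version 2: manual gradients of the mean BCE loss, no autograd ---
w = torch.zeros(num_features, 1)
b = torch.zeros(1)
for epoch in range(num_epochs):
    probas = torch.sigmoid(X @ w + b)  # shape (n, 1)
    err = probas - y                   # shape (n, 1); no broadcasting to (n, n)
    grad_w = X.t() @ err / n           # dL/dw, shape (num_features, 1)
    grad_b = err.sum() / n             # dL/db, scalar
    w -= lr * grad_w
    b -= lr * grad_b

print('nn.Module weights:', model.weight.data, 'manual:', w.t())
print('nn.Module bias:   ', model.bias.data, 'manual:', b)
```

If the two versions disagree on the real data, comparing `grad_w` (transposed) against `model.weight.grad` after the very first step would show whether the difference is in the gradient itself or in the update.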
I was just wondering whether the nn.Module usage above is correct, and whether there’s some “magic” in the background that would make the results differ from, e.g., doing things manually.
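On the “magic” question: one thing `torch.nn.BCELoss` does implicitly is average the per-sample losses (default `reduction='mean'`; older releases expressed this as `size_average=True`), so a manual version that sums instead effectively uses an n-times-larger learning rate. The sketch below checks that on hypothetical random values and also shows one shape pitfall that can make manual versions diverge. This is an assumption worth ruling out, not a diagnosis of the notebook:

```python
import torch

torch.manual_seed(0)

# Hypothetical predicted probabilities and 0/1 targets, both of shape (n, 1)
n = 8
probas = torch.rand(n, 1).clamp(0.01, 0.99)
targets = (torch.rand(n, 1) > 0.5).float()

# Default BCELoss (reduction='mean') averages the per-sample losses
loss_module = torch.nn.BCELoss()(probas, targets)
loss_manual = -(targets * torch.log(probas)
                + (1 - targets) * torch.log(1 - probas)).mean()

# reduction='sum' makes the loss, and therefore every gradient, n times larger,
# which for plain SGD acts like an n-times-larger learning rate
loss_sum = torch.nn.BCELoss(reduction='sum')(probas, targets)

print(loss_module.item(), loss_manual.item(), loss_sum.item() / n)

# Shape pitfall in manual versions: (n, 1) probas minus (n,) targets broadcasts
y_flat = targets.view(-1)
print((probas - y_flat).shape)  # torch.Size([8, 8]), not torch.Size([8, 1])
```

This is also why the `.view(-1, 1)` on the targets in the nn.Module code matters: with matching `(n, 1)` shapes, `p - y` stays a per-sample error vector.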