I have been trying to understand how backprop works in PyTorch. For that purpose, I implemented a simple 4 layer network to predict the labels of mnist dataset. Below is the code.
from torch import nn
import torch
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
# -- network parameters
self.fc1 = nn.Linear(784, 512, bias=False)
self.fc2 = nn.Linear(512, 264, bias=False)
self.fc3 = nn.Linear(264, 128, bias=False)
self.fc4 = nn.Linear(128, 1, bias=False)
def forward(self, y0):
# -- adjacency network (shared)
y1 = self.fc1(y0.squeeze(1).reshape(-1, 784))
y2 = self.fc2(y1)
y3 = self.fc3(y2)
return self.fc4(y3)
def load_data(batch):
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transforms.ToTensor())
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transforms.ToTensor())
TrainDataset = DataLoader(mnist_trainset, batch_size=batch, shuffle=False)
TestDataset = DataLoader(mnist_testset, batch_size=batch, shuffle=False)
return TrainDataset, TestDataset
# -- model
epochs = 10
model = MyModel()
optimizer = optim.SGD(model.parameters(), lr=1e-3)
# -- data
batch_size = 20
TrainDataset, TestDataset = load_data(batch_size)
# -- train
def trainepoch(epoch):
model.train()
train_loss = 0
for batch_idx, (data, target) in enumerate(TrainDataset):
if batch_idx < 10:
optimizer.zero_grad()
y4 = model(data)
e = y4 - target.unsqueeze(1)
loss = 0.5 * e.T @ e
loss.backward()
train_loss += loss.item()
optimizer.step()
print('Train Epoch: {}\tLoss: {:.6f}'.format(epoch, train_loss / 200))
for epoch in range(1, epochs + 1):
trainepoch(epoch)
Please be advised that I purposefully decided to predict the label as a value rather than a classification problem. Next, I tried to reimplement backprop machinery using numpy. Here is the implementation:
import numpy as np
import torch
from keras.datasets import mnist
n_train = 200
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = np.reshape(X_train[:n_train, :, :], (n_train, 784)).T
y_train = np.reshape(y_train[:n_train], (n_train, 1)).T
# -- model
def init_w(input_size, output_size):
w = 1. / np.sqrt(input_size)
return np.random.uniform(-w, w, (output_size, input_size))
def model(x, w1, w2, w3, w4):
y1 = np.matmul(w1, x)
y2 = np.matmul(w2, y1)
y3 = np.matmul(w3, y2)
y4 = np.matmul(w4, y3)
return x, y1, y2, y3, y4
batch_size = 10
batch_n = -(-len(X_train.T) // batch_size)
epochs = 10
eta = 1e-3
w1 = init_w(784, 512)
w2 = init_w(512, 264)
w3 = init_w(264, 128)
w4 = init_w(128, 1)
for epoch in range(1, epochs + 1):
train_loss = 0
for idx in range(batch_n):
# -- training data
y0 = X_train[:, idx * batch_size:(idx + 1) * batch_size] / 256
y_target = y_train[:, idx * batch_size:(idx + 1) * batch_size]
# -- predict
y0, y1, y2, y3, y4 = model(y0, w1, w2, w3, w4)
# -- compute error
e4 = y_target - y4
e3 = np.matmul(w4.T, e4)
e2 = np.matmul(w3.T, e3)
e1 = np.matmul(w2.T, e2)
# -- weight update
w1 = w1 - eta * np.matmul(e1, y0.T)
w2 = w2 - eta * np.matmul(e2, y1.T)
w3 = w3 - eta * np.matmul(e3, y2.T)
w4 = w4 - eta * np.matmul(e4, y3.T)
# -- compute loss
train_loss += 0.5 * np.matmul(e4, e4.T).item()
print(train_loss)
# -- log
print('Train Epoch: {}\tLoss: {:.6f}'.format(epoch, train_loss / n_train))
Unfortunately, my numpy reimplementation blows up after few iteration while the pytorch implementation starts to converge. I was wondering if pytorch is using anything different than simple chain rule in the SGD optimizer or if my numpy implementation is off.
Thanks.