I was able to successfully bypass loss.backward() in a simple network, for both MSE loss and CrossEntropy loss, by replacing it with a backward network. For future reference, I am sharing the code. Thanks albanD!
import torch.nn as nn
import torch.optim as optim
import torchvision
import numpy as np
import scipy.stats as ss
import scipy
import torch
from torchvision import datasets, transforms
import matplotlib.pylab as plt
imagesetdir = './'
use_cuda = True
batch_size = 1024
kwargs = {'num_workers': 0, 'pin_memory': True, 'drop_last': True} if use_cuda else {}
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(imagesetdir, train=True, download=True,
                   transform=transforms.Compose([
                       transforms.Resize(32),
                       transforms.ToTensor(),
                   ])),
    batch_size=batch_size, shuffle=True, **kwargs)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(imagesetdir, train=False, transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])),
    batch_size=batch_size, shuffle=False, **kwargs)
class Forward(nn.Module):
    # main two-layer linear network: 1024 -> 40 -> 10 (no biases)
    def __init__(self):
        super(Forward, self).__init__()
        self.fc_0 = nn.Linear(1024, 40, bias=False)
        self.fc_1 = nn.Linear(40, 10, bias=False)

    def forward(self, x):
        x0 = self.fc_0(x)
        x1 = self.fc_1(x0)
        # return the output together with all activations (needed for the manual gradients)
        return x1, [x, x0, x1]
class Backward(nn.Module):
    # backward network: the transposed mirror of Forward (10 -> 40 -> 1024),
    # used to propagate the output error back towards the input
    def __init__(self):
        super(Backward, self).__init__()
        self.fc_1 = nn.Linear(10, 40, bias=False)
        self.fc_0 = nn.Linear(40, 1024, bias=False)

    def forward(self, x):
        x1 = self.fc_1(x)
        x0 = self.fc_0(x1)
        # error signals returned in the same order as Forward's parameters (fc_0, fc_1)
        return x0, [x1, x]
def transpose_weights(state_dict):
    # transpose every weight so Forward's state_dict can be loaded into Backward
    state_dict_new = {}
    for k, item in state_dict.items():
        state_dict_new.update({k: item.t()})
    return state_dict_new
def corr(t0, t1):
    return ss.pearsonr(t0.view(-1).detach().cpu().numpy(), t1.view(-1).detach().cpu().numpy())
# A simple hook class that returns the input and output of a layer during forward/backward pass
class Hook():
    def __init__(self, module, backward=False):
        if backward == False:
            self.hook = module.register_forward_hook(self.hook_fn)
        else:
            self.hook = module.register_backward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        self.input = input
        self.output = output

    def close(self):
        self.hook.remove()
modelF = Forward().cuda()   # main model
modelB = Backward().cuda()  # backward network to compute gradients for modelF
modelC = Forward().cuda()   # Forward control model to compare to BP
modelC.load_state_dict(modelF.state_dict())
modelB.load_state_dict(transpose_weights(modelF.state_dict()))
optimizerC = optim.Adam(modelC.parameters(), lr=0.0001)
optimizer = optim.Adam(modelF.parameters(), lr=0.0001)
criterion = nn.CrossEntropyLoss()  # or nn.MSELoss()
# -------Implementing BP without using loss.backward() ----------
hookC = [Hook(layer[1], backward=True) for layer in list(modelC._modules.items())]
n_classes = 10
onehot = torch.zeros(train_loader.batch_size, n_classes).cuda()
Softmax = nn.Softmax(dim=1)
for epoch in range(20):
    loss_running = 0
    lossC_running = 0
    for i, (inputs, target) in enumerate(train_loader):
        inputs = inputs.view(train_loader.batch_size, -1).cuda()
        target = target.cuda()
        onehot.zero_()
        onehot.scatter_(1, target.view(train_loader.batch_size, -1), 1)
        # ------------- BP Control ------------------------------------------
        outputsC, activationsC = modelC(inputs)
        lossC = criterion(outputsC, target)
        optimizerC.zero_grad()
        lossC.backward()
        optimizerC.step()
        lossC_running += lossC.item()
        ParamsC = [p for p in modelC.parameters()]
        # ------------- Implementing BP bypassing loss.backward() -------------
        modelB.load_state_dict(transpose_weights(modelF.state_dict()))
        outputs, activationsF = modelF(inputs)
        probs = Softmax(outputs.detach())
        # The gradient of CrossEntropy w.r.t. the logits is pi - yi (pi: softmax of the
        # output, yi: one-hot label); its negative is propagated here and the sign is
        # flipped again when pF.grad is formed below. CrossEntropyLoss averages over the
        # batch, so divide by the batch size to match what loss.backward() computes.
        grad_input = (onehot - probs) / train_loader.batch_size  # for CrossEntropyLoss
        # grad_input = (2 / (n_classes * train_loader.batch_size)) * (onehot - outputs)  # for MSELoss
        recons, activationsB = modelB(grad_input)
        loss = criterion(outputs, target)
        optimizer.zero_grad()
        ParamsF = [p for p in modelF.parameters() if p.requires_grad]
        # copy the manually computed gradients into the parameter grads
        for ip, pF in enumerate(ParamsF):
            pC = ParamsC[ip]                # parameters of the control model
            hC = hookC[::-1][ip].output[0]  # gradients recorded by the backward hooks (not used below)
            aF = activationsF[ip]           # forward activations
            aB = activationsB[ip]           # backward (error) activations
            pF.grad = -torch.matmul(aB.t().detach(), aF.clone().detach())
            # pF.grad should be close to pC.grad
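            # Optional sanity check: corr(pF.grad, pC.grad), using the corr() helper
            # defined above, should give a Pearson r close to 1.0; uncomment to verify.
            # print(ip, corr(pF.grad, pC.grad))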
        optimizer.step()
        loss_running += loss.item()
    print('Epoch %d: Loss= %.3f, Loss_Control= %.3f' % (epoch, loss_running / (i + 1), lossC_running / (i + 1)))
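
If anyone wants to double-check the analytic gradient used above, here is a minimal standalone sketch (independent of the training code, CPU-only) comparing (softmax - onehot) / batch_size against what autograd computes for nn.CrossEntropyLoss:

import torch
import torch.nn as nn

torch.manual_seed(0)
logits = torch.randn(8, 10, requires_grad=True)
target = torch.randint(0, 10, (8,))

# autograd gradient of the mean-reduced cross-entropy w.r.t. the logits
loss = nn.CrossEntropyLoss()(logits, target)
loss.backward()

# analytic gradient: (softmax(logits) - onehot) / batch_size
probs = torch.softmax(logits, dim=1)
onehot = torch.zeros_like(probs).scatter_(1, target.unsqueeze(1), 1)
manual = (probs - onehot) / logits.size(0)

print(torch.allclose(logits.grad, manual, atol=1e-6))  # should print True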