I am trying to write the forward pass and backpropagation for PyTorch's MultiLabel Soft Margin Loss from a computational-graph point of view, as described in this CS231n page. I have figured out the forward pass, and it gives the same result as PyTorch's loss function. The backward pass, however, is giving me trouble.
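For reference, this is the quantity the graph computes, read off the commented `loss` line in the code (C is the number of classes, σ the logistic sigmoid). It also includes the identity that explains why the node `k = g + l` equals logsigmoid(-x):

```latex
f = -\frac{1}{C}\sum_{i=1}^{C}\Big[\,y_i\log\sigma(x_i) + (1-y_i)\log\sigma(-x_i)\Big],
\qquad \sigma(z) = \frac{1}{1+e^{-z}}

\log\sigma(-x) = \log\frac{e^{-x}}{1+e^{-x}} = \log\sigma(x) - x
\quad\Rightarrow\quad k = g + l = \log\sigma(-x),
\qquad b = h + k = (1-y)\log\sigma(-x),
\qquad a = y\log\sigma(x)
```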
```python
import torch

x = torch.tensor([[0.1778, 0.1203, -0.2264]], requires_grad=True)  # 1 row, 3 columns (N=1, C=3) (batch_size=1, num_classes=3)
y = torch.tensor([[0., 1., 0.]], requires_grad=True)
# logsigmoid(x) = torch.log(1 / (1 + torch.exp(-x)))
# loss = torch.mean(-(target * logsigmoid(x) + (1 - target) * logsigmoid(-x)), dim=1)

# forward pass
w = -1
l = torch.log(1 / (1 + torch.exp(-x)))
g = w * x
a = y * l
k = g + l
h = k * y * w
b = h + k
### no change below this line
c = a + b                # (7)
d = -1 * c               # (8)
e = torch.sum(d, dim=1)  # (9)
f = e / x.size(1)        # (10)
print(f)  # tensor([0.6690], grad_fn=<DivBackward0>), same as the value from the built-in loss
f.backward()
print(x.grad.data)  # tensor([[ 0.1814, -0.1567, 0.1479]])
print(y.grad.data)  # tensor([[-0.0593, -0.0401, 0.0755]])

# back propagation steps from scratch
dfde = 1 / x.size(1)
dfdd = torch.ones(x.size(1)) * dfde
dfdc = w * dfdd
dfda = 1 * dfdc
dfdb = 1 * dfdc
### no change above this line
dfdh = 1 * dfdb
dfdk = 1 * dfdb      # ---
#                      | ----> df/dk is being repeated. This is a problem.
dfdk = y * w * dfdh  # ---
dfdy = k * w * dfdh
dfdg = 1 * dfdk
dfdl = 1 * dfdk      # ---
#                      | ----> df/dl is being repeated. This is a problem.
dfdl = y * dfda      # ---
dfdy = l * dfda  # df/dy = logsigmoid(x) * (-1) * torch.ones(x.size(1)) * 1/x.size(1) ----> this should be correct, but it doesn't match y.grad.data
dfdw = x * dfdg
dldx = torch.sigmoid(-x)  # local gradient dl/dx, since d/dx logsigmoid(x) = sigmoid(-x)
dfdx = w * dfdg      # ---
#                      | ----> df/dx is being repeated. This is a problem.
dfdx = dldx * dfdl   # ---
```
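The repeated assignments are the crux. When one node feeds several downstream nodes (here `k`, `l`, `x` and `y` are each used twice), the multivariate chain rule says the gradients arriving from each use must be summed, not overwritten. A minimal sketch of the backward pass rewritten that way, reusing the question's node names; `C` and `dldx` are names introduced here, and `dldx` relies on the identity d/dx logsigmoid(x) = sigmoid(-x). Since `w` is a plain Python constant rather than a tensor, `dfdw` is not needed.

```python
import torch

# Same inputs as in the question (N=1 sample, C=3 classes).
x = torch.tensor([[0.1778, 0.1203, -0.2264]], requires_grad=True)
y = torch.tensor([[0., 1., 0.]], requires_grad=True)
C = x.size(1)
w = -1

# Forward pass: the same graph as in the question.
l = torch.log(1 / (1 + torch.exp(-x)))  # logsigmoid(x)
g = w * x
a = y * l
k = g + l                               # mathematically equal to logsigmoid(-x)
h = k * y * w
b = h + k
c = a + b
d = -1 * c
e = torch.sum(d, dim=1)
f = e / C
f.backward()                            # autograd reference gradients in x.grad, y.grad

# Manual backward pass. Rule: if a node feeds several downstream nodes,
# its gradient is the SUM of the contributions coming back from each of them.
with torch.no_grad():
    dfdd = torch.ones_like(d) / C       # f = sum(d) / C
    dfdc = -1 * dfdd                    # d = -c
    dfda = dfdc                         # c = a + b
    dfdb = dfdc
    dfdh = dfdb                         # b = h + k
    dfdk = dfdb + y * w * dfdh          # k feeds b directly AND through h
    dfdg = dfdk                         # k = g + l
    dfdl = y * dfda + dfdk              # l feeds a AND k
    dfdy = l * dfda + k * w * dfdh      # y feeds a AND h
    dldx = torch.sigmoid(-x)            # local gradient of logsigmoid(x)
    dfdx = dldx * dfdl + w * dfdg       # x feeds l AND g

print(dfdx)  # should agree with x.grad
print(dfdy)  # should agree with y.grad
```

Simplifying the sums by hand gives `dfdy = -x / C` and `dfdx = (torch.sigmoid(x) - y) / C`. The guess `dfdy = l * dfda` in the question only keeps the path through `a` and drops the path through `h`, which is why it does not match.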
I have marked in the code comments where I am running into problems, but I don’t know how to resolve them.
I just need to compute df/dx and df/dy and check whether they equal x.grad.data and y.grad.data.
The reason is that I want to know what functions df/dx and df/dy actually are.
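Carrying the chain rule through symbolically gives closed forms for both gradients. This is a derivation from the loss written in the code's comments, with C = x.size(1); the middle line is the standard derivative of log σ:

```latex
\frac{\partial f}{\partial y_i}
  = -\frac{1}{C}\Big[\log\sigma(x_i) - \log\sigma(-x_i)\Big]
  = -\frac{x_i}{C}

\frac{d}{dz}\log\sigma(z) = 1 - \sigma(z) = \sigma(-z)

\frac{\partial f}{\partial x_i}
  = -\frac{1}{C}\Big[y_i\,\sigma(-x_i) - (1-y_i)\,\sigma(x_i)\Big]
  = \frac{\sigma(x_i) - y_i}{C}
```

With the question's inputs and C = 3, -x/C evaluates to [-0.0593, -0.0401, 0.0755] and (σ(x) - y)/C to [0.1814, -0.1567, 0.1479], the same values autograd prints for y.grad and x.grad.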
Can someone please help me out with this particular task?
And, if possible, I would also appreciate some general advice on how to get past beginner problems like this in the future.
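One general debugging technique, offered as a suggestion rather than the only approach: call `retain_grad()` on intermediate tensors so autograd keeps their `.grad` after `backward()`, then compare each hand-derived gradient against the matching node, one node at a time. The first node that disagrees localizes the mistake. A minimal sketch on a small hypothetical graph (not the question's), where node `u` is used twice:

```python
import torch

x = torch.tensor([0.5, -1.0, 2.0], requires_grad=True)

# Tiny graph in which u feeds two downstream nodes: directly, and through v.
u = x * 2
u.retain_grad()   # non-leaf tensors drop .grad by default; keep it for comparison
v = u * u
v.retain_grad()
s = (u + v).sum()
s.backward()

# Hand-derived gradients, checked node by node against autograd.
dsdv = torch.ones(3)
dsdu = torch.ones(3) + 2 * u.detach() * dsdv  # sum of contributions from both uses of u
dsdx = 2 * dsdu

assert torch.allclose(dsdv, v.grad)
assert torch.allclose(dsdu, u.grad)
assert torch.allclose(dsdx, x.grad)
print("all node gradients match")
```

For checking a whole function at once, `torch.autograd.gradcheck` compares autograd's gradients against finite differences; it is intended to be used with double-precision inputs.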