Hello,
I'll keep the context brief: for a model I need a custom backward pass, because the forward uses the binary version of my weights while the backward should use the real-valued version.
The function I need is:
(W @ W.t()) / diag(W @ W.t())
where W is a binary matrix of size n×m.
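To make the broadcasting concrete, here is what that operation computes on a small example (the same matrix as in the tests below); diag(...) is a 1-D vector, so each column j of the product gets divided by its diagonal entry:

import torch

W = torch.tensor([[1., 0., 1.], [1., 0., 0.], [1., 1., 1.], [0., 0., 1.]])
P = W @ W.t()           # Gram matrix, shape (4, 4)
O = P / torch.diag(P)   # broadcasting: O[i, j] = P[i, j] / P[j, j]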
Below are some tests that try to recreate this backward and compare it against autograd. test3 and test3b are not working and I can't understand why.
Can someone help?
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Function
from itertools import combinations
# Test 1: multiplication of a matrix with its transpose
class Test1(Function):
    @staticmethod
    def forward(ctx, weight, weightB):
        transposed_matrix = weightB.t()
        product = torch.matmul(weightB, transposed_matrix)
        ctx.save_for_backward(weight, weightB)
        return product

    @staticmethod
    def backward(ctx, grad_output):
        weight, weightB = ctx.saved_tensors
        grad_weight = torch.matmul(grad_output, 2 * weight)
        return grad_weight, None

def test1():
    W = torch.Tensor([[1, 0, 1], [1, 0, 0], [1, 1, 1], [0, 0, 1]])
    W1 = W.detach()
    W.requires_grad = True
    W1.requires_grad = True
    transposed_matrix = W.t()
    product = torch.matmul(W, transposed_matrix)
    product = product.sum()
    product2 = Test1.apply(W1, W1.detach())
    product2 = product2.sum()
    torch.autograd.backward([product, product2])
    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test1 -matmul- ok")

# Test 2: divide a matrix by a vector over the columns
class Test2(Function):
    @staticmethod
    def forward(ctx, weight, weightB):
        diag = torch.tensor([1, 2, 3])
        divid = weightB / diag
        ctx.save_for_backward(weight, weightB, diag)
        return divid

    @staticmethod
    def backward(ctx, grad_output):
        weight, weightB, diag = ctx.saved_tensors
        grad_weight = grad_output / diag
        return grad_weight, None

def test2():
    W = torch.Tensor([[1, 0, 1], [1, 0, 0], [1, 1, 1], [0, 0, 1]])
    W1 = W.detach()
    W.requires_grad = True
    W1.requires_grad = True
    diag = torch.tensor([1, 2, 3])
    divid = W / diag
    divid = divid.sum()
    divid2 = Test2.apply(W1, W1.detach())
    divid2 = divid2.sum()
    torch.autograd.backward([divid, divid2])
    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test2 -divid- ok")

class Test2b(Function):  # version with the diagonal taken from the input
    @staticmethod
    def forward(ctx, weight):
        diag = torch.diag(weight.detach())
        divid = weight / diag
        ctx.save_for_backward(weight, diag)
        return divid

    @staticmethod
    def backward(ctx, grad_output):
        weight, diag = ctx.saved_tensors
        grad_weight = grad_output / diag
        return grad_weight, None

def test2b():
    W = torch.Tensor([[1, 0, 2, 1], [3, 0, 0, 1], [1, 1, 2, 0], [0, 0, 4, 1]])
    W1 = W.detach()
    W.requires_grad = True
    W1.requires_grad = True
    diag = torch.diag(W.detach())
    divid = W / diag
    divid = divid.sum()
    divid2 = Test2b.apply(W1)
    divid2 = divid2.sum()
    torch.autograd.backward([divid, divid2])
    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test2b -divid- ok")

# Test 3: chain of tests 1 and 2
class Test3(Function):
    @staticmethod
    def forward(ctx, weight, weightB):
        transposed_matrix = weightB.t()
        product = torch.matmul(weightB, transposed_matrix)
        diag = torch.diag(product.detach())
        orthogonality = product / diag
        ctx.save_for_backward(weight, weightB, diag)
        return orthogonality

    @staticmethod
    def backward(ctx, grad_output):
        weight, weightB, diag = ctx.saved_tensors
        f1 = grad_output / diag
        grad_weight = torch.matmul(f1, 2 * weight)
        return grad_weight, None

def test3():
    W = torch.Tensor([[1, 0, 1], [1, 0, 0], [1, 1, 1], [0, 0, 1]])
    W1 = W.detach()
    W.requires_grad = True
    W1.requires_grad = True
    transposed_matrix = W.t()
    product = torch.matmul(W, transposed_matrix)
    diag = torch.diag(product.detach())
    orthogonality = product / diag
    orthogonality = orthogonality.sum()
    ortho2 = Test3.apply(W1, W1.detach())
    ortho2 = ortho2.sum()
    torch.autograd.backward([orthogonality, ortho2])
    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test3 -chain- ok")

# Test 3 bis: the same chain, but built from the two separate Functions
def test3b():
    W = torch.Tensor([[1, 0, 1], [1, 0, 0], [1, 1, 1], [0, 0, 1]])
    W1 = W.detach()
    W.requires_grad = True
    W1.requires_grad = True
    transposed_matrix = W.t()
    product = torch.matmul(W, transposed_matrix)
    diag = torch.diag(product.detach())
    orthogonality = product / diag
    orthogonality = orthogonality.sum()
    ortho2 = Test1.apply(W1, W1.detach())
    ortho2 = Test2b.apply(ortho2)
    ortho2 = ortho2.sum()
    torch.autograd.backward([orthogonality, ortho2])
    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test3b -chain- ok")

if __name__ == '__main__':
    test1()
    test2()
    test2b()
    test3b()
    test3()