Can't recreate backward prop with autograd.Function

Hello,
I'll keep the context brief.
For a model I need a custom backward pass because the forward uses the binary version of my weights while the backward uses the real-valued version.

The function I need is:

(W @ W.t()) / diag(W @ W.t())

where W is a binary matrix of size n × m.
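
For concreteness, the same function written with plain autograd ops (this matches the reference computation in test3 below; the diagonal is detached, so autograd treats it as a constant):

P = torch.matmul(W, W.t())     # n x n product
diag = torch.diag(P.detach())  # 1-D diagonal, treated as a constant
ortho = P / diag               # column j of P is divided by diag[j]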
Below is the code of some tests that try to recreate this backward and compare it against autograd. test3 and test3b are not working and I can't understand why.

Can someone help?

import torch
from torch.autograd import Function

# Test 1: multiply a matrix by its transpose
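# For reference, the matrix-calculus identity for this product: if P = B @ B.t()
# and G = dL/dP is the gradient flowing into P, then dL/dB = (G + G.t()) @ B.
# This collapses to 2 * G @ B only when G is symmetric (e.g. G = ones for a
# plain .sum() loss).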
class Test1(Function):
    @staticmethod
    def forward(ctx, weight, weightB):
        # ctx.save_for_backward(weight, weightB)

        transposed_matrix = weightB.t()
        product = torch.matmul(weightB, transposed_matrix)
        
        ctx.save_for_backward(weight, weightB)
        return product

    @staticmethod
    def backward(ctx, grad_output):
        weight, weightB = ctx.saved_tensors
        grad_weight = torch.matmul(grad_output,  2*weight)
        return grad_weight, None
    
def test1():
    W = torch.Tensor([[1,0,1],[1,0,0], [1,1,1], [0,0,1]]) 
    W1 = W.detach()

    
    W.requires_grad = True
    W1.requires_grad = True

    transposed_matrix = W.t()
    product = torch.matmul(W, transposed_matrix)
    product=product.sum()

    product2 = Test1.apply(W1, W1.detach())
    product2 = product2.sum()

    torch.autograd.backward([product, product2])

    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test1 -matmul- ok")

# Test 2: divide a matrix by a vector, broadcast over the columns
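# (With diag = torch.tensor([1, 2, 3]) as below, broadcasting gives
# (W / diag)[i, j] == W[i, j] / diag[j], i.e. each column j is divided by diag[j].)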
class Test2(Function):
    @staticmethod
    def forward(ctx, weight, weightB):
        diag = torch.tensor([1,2,3])
        divid = weightB / diag
        ctx.save_for_backward(weight, weightB, diag)
        return divid

    @staticmethod
    def backward(ctx, grad_output):
        weight, weightB, diag = ctx.saved_tensors
        grad_weight =  grad_output / diag
        return grad_weight, None
    
def test2():
    W = torch.Tensor([[1,0,1],[1,0,0], [1,1,1], [0,0,1]]) 
    W1 = W.detach()

    
    W.requires_grad = True
    W1.requires_grad = True

    diag = torch.tensor([1,2,3])
    divid = W / diag
    divid = divid.sum()

    divid2 = Test2.apply(W1, W1.detach())
    divid2 = divid2.sum()

    torch.autograd.backward([divid, divid2])


    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test2 -divid- ok")

class Test2b(Function):  # version with diagonal from the input
    @staticmethod
    def forward(ctx, weight):
        diag = torch.diag(weight.detach())
        divid = weight / diag
        ctx.save_for_backward(weight, diag)
        return divid

    @staticmethod
    def backward(ctx, grad_output):
        weight, diag = ctx.saved_tensors
        grad_weight =  grad_output / diag
        return grad_weight, None

def test2b():
    W = torch.Tensor([[1,0,2,1],[3,0,0,1], [1,1,2,0], [0,0,4,1]])  
    W1 = W.detach()

    W.requires_grad = True
    W1.requires_grad = True

    diag = torch.diag(W.detach())
    divid = W / diag
    divid = divid.sum()

    divid2 = Test2b.apply(W1)
    divid2 = divid2.sum()

    torch.autograd.backward([divid, divid2])

    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test2b -divid- ok")

# Test 3: chain of test 1 and test 2
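# For reference, chaining the two pieces: with P = B @ B.t(), d = diag(P) treated
# as a constant (it is detached in the forward), and G the gradient flowing into
# P / d, the gradient into P is G / d and the gradient w.r.t. B is
# ((G / d) + (G / d).t()) @ B.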
class Test3(Function):
    @staticmethod
    def forward(ctx, weight, weightB):
        transposed_matrix = weightB.t()
        product = torch.matmul(weightB, transposed_matrix)

        diag = torch.diag(product.detach())

        orthogonality = product / diag
        
        ctx.save_for_backward(weight, weightB, diag)
        return orthogonality

    @staticmethod
    def backward(ctx, grad_output):
        weight, weightB, diag = ctx.saved_tensors
        f1 =  grad_output / diag
        grad_weight = torch.matmul(f1,  2*weight)
        return grad_weight, None

    
def test3():
    W = torch.Tensor([[1,0,1],[1,0,0], [1,1,1], [0,0,1]])  
    W1 = W.detach()
    
    W.requires_grad = True
    W1.requires_grad = True

    transposed_matrix = W.t()
    product = torch.matmul(W, transposed_matrix)

    diag = torch.diag(product.detach())

    orthogonality = product / diag
    orthogonality = orthogonality.sum()

    ortho2 = Test3.apply(W1, W1.detach())
    ortho2 = ortho2.sum()

    torch.autograd.backward([orthogonality, ortho2])

    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test3 -chain- ok")

    
# Test 3 bis: the same chain, but split across two Functions


def test3b():
    W = torch.Tensor([[1,0,1],[1,0,0], [1,1,1], [0,0,1]])  
    W1 = W.detach()
    
    W.requires_grad = True
    W1.requires_grad = True

    transposed_matrix = W.t()
    product = torch.matmul(W, transposed_matrix)

    diag = torch.diag(product.detach())

    orthogonality = product / diag
    orthogonality = orthogonality.sum()

    ortho2 = Test1.apply(W1, W1.detach())
    ortho2 = Test2b.apply(ortho2)
    ortho2 = ortho2.sum()

    torch.autograd.backward([orthogonality, ortho2])


    assert torch.prod(torch.round(W.grad, decimals=4) == torch.round(W1.grad, decimals=4))
    print("test3b -chain- ok")

if __name__ == '__main__':

    test1()
    test2()
    test2b()
    test3b()
    test3()
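
Side note on the comparison itself: rounding both gradients and testing exact equality can be brittle for floats; torch.allclose is the usual tolerance-based check. A minimal sketch, using the same W / W1 as in the tests:

assert torch.allclose(W.grad, W1.grad, rtol=1e-5, atol=1e-6)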