Tying weights of an attention matrix


Is there a way to tie the weights of certain parts of a (square) attention matrix?
For example, if I divide the matrix into 4 blocks (2 diagonal and 2 off-diagonal) and want to tie the weights of the two off-diagonal blocks, how could I do that?


Say the full attention matrix is N x N and each block is block_N x block_N (so N = 2 * block_N):

import torch

N, block_N = 8, 4

attention_W_diag1 = torch.randn(block_N, block_N).requires_grad_(True)
attention_W_diag2 = torch.randn(block_N, block_N).requires_grad_(True)
attention_W_off_diag = torch.randn(block_N, block_N).requires_grad_(True)

# Assemble the full matrix. Writing the same tensor into both
# off-diagonal slots ties them: backward() accumulates the gradients
# from both blocks into the single attention_W_off_diag tensor.
# Note: rebuild attention_W on every forward pass so autograd tracks
# the current parameter values.
attention_W = torch.zeros(N, N)
attention_W[:block_N, :block_N] = attention_W_diag1
attention_W[block_N:, block_N:] = attention_W_diag2
attention_W[block_N:, :block_N] = attention_W_off_diag
attention_W[:block_N, block_N:] = attention_W_off_diag
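If you want this inside a model, a minimal sketch of the same idea wrapped in an nn.Module (the class name TiedBlockAttention and the weight()/forward() structure are my own illustration, not from the question; torch.cat is used instead of slice assignment, which is equivalent for tying):

```python
import torch
import torch.nn as nn

class TiedBlockAttention(nn.Module):
    """Sketch: 2x2 block weight matrix whose two off-diagonal
    blocks share a single parameter."""

    def __init__(self, block_N):
        super().__init__()
        self.diag1 = nn.Parameter(torch.randn(block_N, block_N))
        self.diag2 = nn.Parameter(torch.randn(block_N, block_N))
        # one parameter, used in both off-diagonal positions
        self.off_diag = nn.Parameter(torch.randn(block_N, block_N))

    def weight(self):
        # Reassemble the full N x N matrix on every call so autograd
        # routes gradients from both off-diagonal blocks into the
        # single shared parameter.
        top = torch.cat([self.diag1, self.off_diag], dim=1)
        bottom = torch.cat([self.off_diag, self.diag2], dim=1)
        return torch.cat([top, bottom], dim=0)

    def forward(self, x):
        return x @ self.weight()

m = TiedBlockAttention(block_N=3)
x = torch.randn(2, 6)
m(x).sum().backward()
# m.off_diag.grad now holds the summed gradient contributions
# from both off-diagonal blocks
```

Because the blocks are registered as nn.Parameter, an optimizer sees only three tensors, so the tie is preserved across updates automatically.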

Thank you for explaining with the example! I think this is what I was looking for.