Biaffine module in PyTorch

Hi everyone,

I have been reading papers on the biaffine parser for several days, but I still can't understand how the biaffine module should be implemented. My parser uses an affine MLP classifier, and I can't figure out how to change it into a biaffine one.

The structure of my parser is as follows:

Parser(
  (dropout): Dropout(p=0.6, inplace=False)
  (word_embedding): Embedding(381, 100, padding_idx=0)
  (tag_embedding): Embedding(17, 40, padding_idx=0)
  (bilstm): LSTM(908, 600, num_layers=3, batch_first=True, dropout=0.3, bidirectional=True)
  (bilstm_to_hidden1): Linear(in_features=1200, out_features=500, bias=True)
  (hidden1_to_hidden2): Linear(in_features=500, out_features=150, bias=True)
  (hidden2_to_pos): Linear(in_features=150, out_features=45, bias=True)
  (hidden2_to_dep): Linear(in_features=150, out_features=41, bias=True)
)
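To make the question concrete, here is a minimal sketch of a module that would print the structure above. Several parts are assumptions, not taken from the post: the 908 LSTM input features are read as 100 word + 40 tag embedding dimensions plus a hypothetical `extra` input of 768 features (the post does not say where they come from), and the ReLU activations and the placement of dropout are guesses. What it shows is that `hidden2_to_pos` and `hidden2_to_dep` score each token independently against a fixed set of classes, i.e. an affine classifier.

```python
import torch
import torch.nn as nn


class Parser(nn.Module):
    """Sketch of the affine-MLP parser printed above (forward pass is assumed)."""

    def __init__(self, n_words=381, n_tags=17, extra_dim=768, n_pos=45, n_deps=41):
        super().__init__()
        self.dropout = nn.Dropout(p=0.6)
        self.word_embedding = nn.Embedding(n_words, 100, padding_idx=0)
        self.tag_embedding = nn.Embedding(n_tags, 40, padding_idx=0)
        # 100 + 40 + extra_dim = 908 input features; extra_dim is hypothetical.
        self.bilstm = nn.LSTM(100 + 40 + extra_dim, 600, num_layers=3,
                              batch_first=True, dropout=0.3, bidirectional=True)
        self.bilstm_to_hidden1 = nn.Linear(1200, 500)
        self.hidden1_to_hidden2 = nn.Linear(500, 150)
        self.hidden2_to_pos = nn.Linear(150, n_pos)
        self.hidden2_to_dep = nn.Linear(150, n_deps)

    def forward(self, words, tags, extra):
        # words, tags: (batch, seq_len) int ids; extra: (batch, seq_len, extra_dim)
        x = torch.cat([self.word_embedding(words), self.tag_embedding(tags), extra], dim=-1)
        h, _ = self.bilstm(self.dropout(x))          # (batch, seq_len, 1200)
        h = torch.relu(self.bilstm_to_hidden1(h))    # ReLU is an assumption
        h = torch.relu(self.hidden1_to_hidden2(h))   # (batch, seq_len, 150)
        # One score per class for each token on its own: an affine classifier.
        return self.hidden2_to_pos(h), self.hidden2_to_dep(h)


model = Parser()
words = torch.randint(1, 381, (2, 7))
tags = torch.randint(1, 17, (2, 7))
extra = torch.randn(2, 7, 768)
pos_scores, dep_scores = model(words, tags, extra)
print(pos_scores.shape, dep_scores.shape)  # torch.Size([2, 7, 45]) torch.Size([2, 7, 41])
```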

where Embedding = nn.Embedding and Linear = nn.Linear.
Now, what do I need to change to implement a biaffine layer?
I think that these two layers:

self.bilstm_to_hidden1 = nn.Linear(...)
self.hidden1_to_hidden2 = nn.Linear(...)

are simply MLPs with a single Linear layer each.
So I think the right thing to do is to replace

self.hidden2_to_pos = nn.Linear(...)
self.hidden2_to_dep = nn.Linear(...)

with a BiAffine module. What I found as a "standard" implementation in PyTorch is the following:

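import torch
import torch.nn as nn


class BiAffine(nn.Module):
    """Biaffine attention layer."""
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        # One (input_dim x input_dim) matrix per output channel
        self.U = nn.Parameter(torch.empty(output_dim, input_dim, input_dim))
        nn.init.xavier_uniform_(self.U)  # xavier_uniform (no underscore) is deprecated

    def forward(self, Rh, Rd):
        # Rh, Rd: (batch, seq_len, input_dim) -> (batch, 1, seq_len, input_dim)
        Rh = Rh.unsqueeze(1)
        Rd = Rd.unsqueeze(1)
        # S: (batch, output_dim, seq_len, seq_len), S[b, k, i, j] = Rh_i^T U_k Rd_j
        S = Rh @ self.U @ Rd.transpose(-1, -2)
        # Removes the output dimension only when output_dim == 1
        return S.squeeze(1)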

but it doesn’t seem to work.
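For comparison, a biaffine score is usually written as a bilinear term plus an affine term. In the notation of the module above (h_i is row i of Rh, d_j is row j of Rd, U_k is the k-th slice of U, and w_k, b_k are extra weights that the posted code does not have), one general form is the first line below; the posted forward computes only the bilinear term h_i^T U_k d_j. The second line shows the common trick of appending a constant 1 to each vector, which folds the affine terms into a single enlarged matrix, with w_k = [a_k; c_k].

$$
S_{k,i,j} = h_i^\top U_k\, d_j + w_k^\top \begin{bmatrix} h_i \\ d_j \end{bmatrix} + b_k
$$

$$
\tilde h_i = \begin{bmatrix} h_i \\ 1 \end{bmatrix},\qquad
\tilde d_j = \begin{bmatrix} d_j \\ 1 \end{bmatrix},\qquad
\tilde U_k = \begin{bmatrix} U_k & a_k \\ c_k^\top & b_k \end{bmatrix},\qquad
S_{k,i,j} = \tilde h_i^\top \tilde U_k\, \tilde d_j
$$

Expanding the second line gives h_i^T U_k d_j + a_k^T h_i + c_k^T d_j + b_k, which matches the first.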

What am I doing wrong?
Do you have any ideas / suggestions?
Do you know if there is any code / module that I can use as a baseline to implement this module?
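As a possible baseline, here is a minimal sketch loosely following the Dozat and Manning biaffine parser. One thing it makes visible: the posted module takes two sequences, Rh and Rd, so it cannot be a drop-in replacement for a Linear layer applied to the single hidden2 output. Assumptions, all hypothetical: separate ReLU MLPs (sizes `arc_dim=500`, `rel_dim=150` chosen arbitrarily) produce head and dependent views of the 1200-dim BiLSTM output, the constant-1 trick supplies the bias terms, and the 41 outputs of `hidden2_to_dep` are dependency relation labels. The names `Biaffine`, `BiaffineHead`, `gold_heads` and `gold_rels` are made up for illustration, and a ROOT token, padding masks, and tree decoding are omitted.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Biaffine(nn.Module):
    """Bilinear score with optional bias terms via appended constant-1 features."""

    def __init__(self, in_dim, out_dim=1, bias_head=True, bias_dep=True):
        super().__init__()
        self.bias_head, self.bias_dep = bias_head, bias_dep
        self.U = nn.Parameter(torch.empty(out_dim, in_dim + int(bias_head), in_dim + int(bias_dep)))
        nn.init.xavier_uniform_(self.U)

    def forward(self, h, d):
        # h: (batch, n, in_dim) head views, d: (batch, n, in_dim) dependent views
        if self.bias_head:
            h = torch.cat([h, torch.ones_like(h[..., :1])], dim=-1)
        if self.bias_dep:
            d = torch.cat([d, torch.ones_like(d[..., :1])], dim=-1)
        # S[b, k, i, j] = h_i^T U_k d_j  (i = candidate head, j = dependent)
        return torch.einsum("bip,kpq,bjq->bkij", h, self.U, d)


class BiaffineHead(nn.Module):
    """Hypothetical replacement for hidden2_to_dep: arc and relation scorers."""

    def __init__(self, lstm_dim=1200, arc_dim=500, rel_dim=150, n_rels=41):
        super().__init__()
        def mlp(out_dim):
            return nn.Sequential(nn.Linear(lstm_dim, out_dim), nn.ReLU(), nn.Dropout(0.33))
        self.arc_head, self.arc_dep = mlp(arc_dim), mlp(arc_dim)
        self.rel_head, self.rel_dep = mlp(rel_dim), mlp(rel_dim)
        # Arcs: a 1 on the dependent side only adds a head-only term h_i^T u.
        self.arc_scorer = Biaffine(arc_dim, 1, bias_head=False, bias_dep=True)
        self.rel_scorer = Biaffine(rel_dim, n_rels, bias_head=True, bias_dep=True)

    def forward(self, lstm_out):
        # lstm_out: (batch, n, lstm_dim) BiLSTM states
        arc = self.arc_scorer(self.arc_head(lstm_out), self.arc_dep(lstm_out)).squeeze(1)
        rel = self.rel_scorer(self.rel_head(lstm_out), self.rel_dep(lstm_out))
        return arc, rel  # (batch, n, n) and (batch, n_rels, n, n)


torch.manual_seed(0)
scorer = BiaffineHead()
lstm_out = torch.randn(2, 7, 1200)
arc, rel = scorer(lstm_out)
gold_heads = torch.randint(0, 7, (2, 7))  # gold_heads[b, j] = head index of token j
gold_rels = torch.randint(0, 41, (2, 7))
# Dimension 1 of arc indexes candidate heads, so cross_entropy applies directly.
arc_loss = F.cross_entropy(arc, gold_heads)
# Relation scores of each dependent at its gold head: (batch, n_rels, n)
idx = gold_heads[:, None, None, :].expand(-1, rel.size(1), 1, -1)
rel_at_gold = rel.gather(2, idx).squeeze(2)
rel_loss = F.cross_entropy(rel_at_gold, gold_rels)
print(arc.shape, rel.shape)  # torch.Size([2, 7, 7]) torch.Size([2, 41, 7, 7])
```

In this sketch `hidden2_to_pos` can stay an ordinary Linear classifier, since only heads and relations need pairwise scores. A bias term that depends only on the dependent would be the same for every candidate head and would not change the softmax over heads, which is why the arc scorer skips it.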

I no longer know where to turn.
Thanks a lot to everyone.