How to implement a 3rd-order (triplet) hyperedge EdgeConv?

I need a message function that sees three nodes per hyperedge, something like:

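```python
import torch
def message(self, x_i, x_j, x_k):
    h = torch.cat([x_i - x_j, x_i - x_k], dim=-1)  # concat along the feature dim
    return self.mlp(h)  # self.mlp: hypothetical MLP, as in EdgeConv
```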

`HypergraphConv` could have been an option, but I can't figure out a simple way to adapt it to this use case. Any help would be highly appreciated. Thanks.
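One possible approach is to skip the pairwise `edge_index` entirely and store each hyperedge as an explicit (i, j, k) triplet. You then gather the three feature rows yourself, compute the EdgeConv-style message, and scatter it back to node i. The sketch below uses plain PyTorch only. The `[3, T]` triplet layout, the `TripletEdgeConv` name, the two-layer MLP, and max aggregation (chosen to mirror EdgeConv) are all assumptions, not an existing PyG API. The same gather/scatter logic could probably be moved into a custom `MessagePassing` subclass, but that is not shown here.

```python
import torch
import torch.nn as nn


class TripletEdgeConv(nn.Module):
    """Hypothetical 3rd-order EdgeConv-style layer (a sketch, not a PyG class).

    Each hyperedge is an ordered triplet (i, j, k). Node i receives the message
    mlp(cat([x_i - x_j, x_i - x_k])), and messages are max-aggregated per node.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()
        # Input is two concatenated difference vectors -> 2 * in_channels features.
        self.mlp = nn.Sequential(
            nn.Linear(2 * in_channels, out_channels),
            nn.ReLU(),
            nn.Linear(out_channels, out_channels),
        )

    def forward(self, x: torch.Tensor, triplets: torch.Tensor) -> torch.Tensor:
        # x: [N, F] node features; triplets: [3, T] long tensor, rows are i, j, k.
        i, j, k = triplets
        x_i, x_j, x_k = x[i], x[j], x[k]  # gather: each is [T, F]
        msg = self.mlp(torch.cat([x_i - x_j, x_i - x_k], dim=-1))  # [T, out]

        # Max-aggregate messages into their target node i. With include_self=False,
        # nodes that receive no triplet keep their initial value of 0.
        out = x.new_zeros(x.size(0), msg.size(-1))
        index = i.unsqueeze(-1).expand_as(msg)
        out.scatter_reduce_(0, index, msg, reduce="amax", include_self=False)
        return out


# Usage: 5 nodes with 8 features; three triplets (0,1,2), (0,2,3), (1,2,4).
x = torch.randn(5, 8)
triplets = torch.tensor([[0, 0, 1],
                         [1, 2, 2],
                         [2, 3, 4]])
conv = TripletEdgeConv(8, 16)
out = conv(x, triplets)
print(out.shape)  # torch.Size([5, 16])
```

Building the triplets (for example, each node plus two of its k nearest neighbours) is left to you. The number of triplets grows quickly with neighbourhood size, so memory is the main thing to watch.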