Could someone help me implement this code in PyTorch, where the inputs have shape (batch_size, M, 192) and there are two classes?
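The code being referred to isn't included here, so the following is only a minimal sketch of one common setup, not a port of that code. It assumes each of the M positions carries a 192-dimensional feature vector and that a single label (0 or 1) is predicted per sample. `SequenceClassifier`, `hidden_dim`, the ReLU encoder and the mean-pooling over M are all hypothetical choices made for illustration. The model outputs raw logits of shape (batch_size, 2), which is what `nn.CrossEntropyLoss` expects.

```python
import torch
import torch.nn as nn


class SequenceClassifier(nn.Module):
    """Hypothetical classifier mapping (batch_size, M, 192) -> (batch_size, 2) logits."""

    def __init__(self, feat_dim: int = 192, hidden_dim: int = 64, num_classes: int = 2):
        super().__init__()
        # nn.Linear operates on the last dimension, so it is applied to each of the M positions
        self.encoder = nn.Sequential(
            nn.Linear(feat_dim, hidden_dim),
            nn.ReLU(),
        )
        self.head = nn.Linear(hidden_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.encoder(x)   # (batch_size, M, hidden_dim)
        h = h.mean(dim=1)     # mean-pool over M -> (batch_size, hidden_dim)
        return self.head(h)   # (batch_size, num_classes) raw logits


torch.manual_seed(0)
batch_size, M = 8, 10
x = torch.randn(batch_size, M, 192)          # dummy inputs (random, for illustration only)
y = torch.randint(0, 2, (batch_size,))       # dummy labels in {0, 1}

model = SequenceClassifier()
criterion = nn.CrossEntropyLoss()            # takes logits and integer class labels
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# One training step
model.train()
optimizer.zero_grad()
logits = model(x)
loss = criterion(logits, y)
loss.backward()
optimizer.step()

preds = logits.argmax(dim=1)                 # predicted class per sample
print(logits.shape, loss.item())
```

Because the pooling averages over dimension 1, M can differ from batch to batch. If shorter sequences are zero-padded to a common M, a padding mask would be needed so the padded positions don't enter the mean.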