Implementing SFA in PyTorch

Could someone help me implement this code in PyTorch? The inputs have shape (batch_size, M, 192), and there are two output classes.

[image: the code to implement]
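
To make the shapes concrete, here is a minimal sketch of the interface I am after. It is not SFA itself: the SFA block is left as an `nn.Identity` stand-in because that is the part I need help with. The class name `TwoClassHead`, the mean over the M axis, and the single linear layer are all my own assumptions for illustration, not something taken from the code in the image.

```python
import torch
import torch.nn as nn


class TwoClassHead(nn.Module):
    """Hypothetical wrapper: (batch_size, M, 192) -> (batch_size, 2) logits."""

    def __init__(self, feature_dim: int = 192, num_classes: int = 2):
        super().__init__()
        # Stand-in for the SFA block; replace with the real implementation.
        self.sfa = nn.Identity()
        self.classifier = nn.Linear(feature_dim, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.sfa(x)            # (batch_size, M, 192), unchanged by the stand-in
        x = x.mean(dim=1)          # assumption: average over the M axis -> (batch_size, 192)
        return self.classifier(x)  # (batch_size, 2) logits, one per class


batch_size, M = 4, 10  # arbitrary example sizes
logits = TwoClassHead()(torch.randn(batch_size, M, 192))
print(logits.shape)  # torch.Size([4, 2])
```

The logits could then go into `nn.CrossEntropyLoss` with integer labels 0 or 1, but how the M axis should really be reduced depends on the SFA code.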