Trying to convert Keras code to PyTorch

Something like this might work for your case: