Skimming through the code, it seems that `tf.nn.moments` can be replaced by `torch.var`, while the other functions can be mapped to the `torch.` namespace without name changes.
Could you post your current approach and explain where you are stuck at the moment?
import torch
import torch.nn as nn
import torch.nn.functional as F
class EvoNorm1ds0(nn.Module):
    """EvoNorm-S0 normalization-activation layer for ``(N, C, L)`` inputs.

    With ``nonlinearity=True`` the layer computes

        x * sigmoid(v * x) / group_std(x) * weight + bias

    where ``group_std`` is the standard deviation taken jointly over the
    channels of each group and all positions along the last axis.  With
    ``nonlinearity=False`` it is a plain per-channel affine transform
    ``x * weight + bias``.

    Args:
        num_features: Number of channels ``C`` of the input.
        eps: Value added to the variance before the square root.
        nonlinearity: Whether to apply the gated, group-normalized branch.
        groups: Number of channel groups used for the group statistics;
            ``C`` must be divisible by it when ``nonlinearity`` is true.
    """

    __constants__ = ['num_features', 'eps', 'nonlinearity', 'groups']

    def __init__(self, num_features, eps=1e-5, nonlinearity=True, groups=8):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.nonlinearity = nonlinearity
        self.groups = groups
        # Per-channel affine parameters, shaped to broadcast over (N, C, L).
        self.weight = nn.Parameter(torch.empty(1, num_features, 1))
        self.bias = nn.Parameter(torch.empty(1, num_features, 1))
        if self.nonlinearity:
            # Per-channel gate scale used inside the sigmoid.
            self.v = nn.Parameter(torch.empty(1, num_features, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Initialise ``weight`` (and ``v``) to ones and ``bias`` to zeros."""
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.nonlinearity:
            nn.init.ones_(self.v)

    def group_std(self, x, groups=8):
        """Return the per-group standard deviation of ``x`` as ``(N, C, 1)``.

        The variance is taken over every channel of a group together with
        all positions along the last axis, so all channels belonging to the
        same group share one value.  The biased estimator is used and
        ``eps`` is added *inside* the square root, which keeps the result
        finite and strictly positive even for length-1 inputs.

        Args:
            x: Tensor of shape ``(N, C, L)``; ``C`` must be divisible by
                ``groups``.
            groups: Number of channel groups.
        """
        N, C, H = x.shape
        grouped = torch.reshape(x, (N, groups, C // groups, H))
        # Reduce over the in-group channel axis AND the length axis.
        var = torch.var(grouped, dim=(2, 3), unbiased=False, keepdim=True)
        std = torch.sqrt(var + self.eps)
        # Broadcast the shared group statistic back onto each channel.
        std = std.expand(N, groups, C // groups, 1)
        return torch.reshape(std, (N, C, 1))

    def forward(self, x):
        """Apply the layer to ``x`` of shape ``(N, C, L)``."""
        if not self.nonlinearity:
            return x * self.weight + self.bias
        # torch.sigmoid replaces the deprecated F.sigmoid.
        num = x * torch.sigmoid(self.v * x)
        return num / self.group_std(x, self.groups) * self.weight + self.bias
But I am interested in the batch version.
What I have trouble with is handling the training flag and using `register_buffer` to keep the running mean and running std. The specific part I’m not sure about is how to update those values.