Thank you @ptrblck, and sorry for the ambiguity; I will try to make it clear.
This is your implementation of BatchNorm with some modifications (a call to a saving function):
class BatchNorm2d(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(BatchNorm2d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    # @torch.jit.script_method
    def forward(self, input):
        self._check_input_dim(input)

        exponential_average_factor = 0.0
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum

        # calculate running estimates
        if self.training:
            save_statistics(self.running_mean)  # dump the running mean before it is updated
            mean = input.mean([0, 2, 3])
            # use biased var in train
            var = input.var([0, 2, 3], unbiased=False)
            n = input.numel() / input.size(1)
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                # update running_var with unbiased var
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        input = (input - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + self.eps)
        if self.affine:
            input = input * self.weight[None, :, None, None] + self.bias[None, :, None, None]
        return input
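In plain eager mode the layer itself works; a minimal sanity check along these lines (my sketch, assuming CPU tensors and both layers starting from the same state) should match the built-in nn.BatchNorm2d:

    import torch
    import torch.nn as nn

    ref = nn.BatchNorm2d(4)
    custom = BatchNorm2d(4)                    # the class above
    custom.load_state_dict(ref.state_dict())   # same weight/bias/running stats
    x = torch.randn(2, 4, 8, 8)
    print(torch.allclose(ref(x), custom(x), atol=1e-6))  # expected: True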
The saving function:
import os
import random
import string

import numpy as np
import torch

@torch.jit.ignore
def save_statistics(x: torch.Tensor):
    # random file name so successive calls don't overwrite each other
    random_string = ''.join(random.choice(string.ascii_lowercase) for i in range(16))
    current_directory = os.path.dirname(os.path.realpath(__file__))
    final_directory = os.path.join(current_directory, 'batchStatistics')
    if not os.path.exists(final_directory):
        os.makedirs(final_directory)
    save_x = x.detach().cpu().numpy()  # .cpu() so this also works for models on the GPU
    path = os.path.join(final_directory, random_string)
    np.save(path, save_x)
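np.save appends the '.npy' extension, so the dumps can be listed and loaded back with something like this (a small sketch, assuming it runs from the directory containing batchStatistics):

    import glob
    import os

    import numpy as np

    for f in sorted(glob.glob(os.path.join('batchStatistics', '*.npy'))):
        stats = np.load(f)
        print(f, stats.shape)  # one per-channel statistics vector per call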
And this is the model I am using:
import torch.nn as nn
import torch.nn.functional as F
# syft_nn refers to the module that contains the custom BatchNorm2d above

class CNNNet_CIFAR_MyBN(nn.Module):
    def __init__(self, conv1_dim=100, conv2_dim=150, conv3_dim=250, conv4_dim=500):
        super(CNNNet_CIFAR_MyBN, self).__init__()
        self.conv4_dim = conv4_dim
        self.conv1 = nn.Conv2d(3, conv1_dim, 5, stride=1, padding=2)
        self.conv2 = nn.Conv2d(conv1_dim, conv2_dim, 3, stride=1, padding=2)
        self.conv3 = nn.Conv2d(conv2_dim, conv3_dim, 3, stride=1, padding=2)
        self.conv4 = nn.Conv2d(conv3_dim, conv4_dim, 3, stride=1, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        # 3x3 is the precomputed spatial size after the four conv/pool stages;
        # recompute it if you change the conv settings
        self.fc1 = nn.Linear(conv4_dim * 3 * 3, 270)
        self.fc2 = nn.Linear(270, 150)
        self.fc3 = nn.Linear(150, 10)
        self.normalize1 = syft_nn.BatchNorm2d(conv1_dim)
        self.normalize2 = syft_nn.BatchNorm2d(conv2_dim)
        self.normalize3 = syft_nn.BatchNorm2d(conv3_dim)
        self.normalize4 = syft_nn.BatchNorm2d(conv4_dim)

    def forward(self, x):
        # each stage: conv -> custom batch norm -> relu -> max pool
        x = self.pool(F.relu(self.normalize1(self.conv1(x))))
        x = self.pool(F.relu(self.normalize2(self.conv2(x))))
        x = self.pool(F.relu(self.normalize3(self.conv3(x))))
        x = self.pool(F.relu(self.normalize4(self.conv4(x))))
        x = x.view(-1, self.conv4_dim * 3 * 3)  # flatten the features
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
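A forward pass with a CIFAR-sized input confirms the flattened 3x3 size (just a quick shape check, assuming 32x32 inputs):

    m = CNNNet_CIFAR_MyBN()
    m.eval()  # avoid the save_statistics side effect for this check
    out = m(torch.randn(1, 3, 32, 32))
    print(out.shape)  # torch.Size([1, 10]); the spatial size goes 32 -> 16 -> 9 -> 5 -> 3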
And here is how I trace it:
model = CNNNet_CIFAR_MyBN()
traced_model = torch.jit.trace(model, torch.zeros([1, 3, 32, 32], dtype=torch.float))
With tracing like this, the custom BatchNorm code is not executed when the model is received by the workers. I tried to fix that by adding @torch.jit.script_method to the forward method of the BatchNorm (the commented line above). However, I get several errors:

torch.jit.frontend.UnsupportedNodeError: with statements aren't supported

and even when I remove the with statement, I get this error:

TypeError: 'ScriptMethodStub' object is not callable
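For completeness, the second error already shows up in plain eager mode once the decorator is active; a minimal reproduction (a sketch, assuming the @torch.jit.script_method line above is uncommented) is:

    bn = BatchNorm2d(3)                 # the custom class with the decorator enabled
    out = bn(torch.randn(1, 3, 4, 4))   # TypeError: 'ScriptMethodStub' object is not callable

My suspicion is that @torch.jit.script_method is only meant for methods of a torch.jit.ScriptModule subclass, while my class derives from nn.BatchNorm2d, but I am not sure how to combine that with the @torch.jit.ignore'd saving function.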