I am trying to implement Spatial BatchNorm in PyTorch. The model contains a number of layers of the shape Conv - [BatchNorm?] - ReLU - [MaxPool?], where the BN and Pooling layers are used only if the corresponding flags are set. The model runs properly when the BatchNorm layer is not used, but when the BN flag is set, there is a RunTime error in the 2nd layer.
RuntimeError: set_sizes_and_strides is not allowed on a Tensor created from .data or .detach().
If your intent is to change the metadata of a Tensor (such as sizes / strides / storage / storage_offset)
without autograd tracking the change, remove the .data / .detach() call and wrap the change in a `with torch.no_grad():` block.
For example, change:
x.data.set_(y)
to:
with torch.no_grad():
x.set_(y)
Here are the implementations of the model and the different layers used.
class FastConv(object):
@staticmethod
def forward(x, w, b, conv_param):
N, C, H, W = x.shape
F, _, HH, WW = w.shape
stride, pad = conv_param['stride'], conv_param['pad']
layer = torch.nn.Conv2d(C, F, (HH, WW), stride=stride, padding=pad)
layer.weight = torch.nn.Parameter(w)
layer.bias = torch.nn.Parameter(b)
tx = x.detach()
tx.requires_grad = True
out = layer(tx)
cache = (x, w, b, conv_param, tx, out, layer)
return out, cache
class BatchNorm(object):
@staticmethod
def forward(x, gamma, beta, bn_param):
"""
Forward pass for batch normalization.
Input:
- x: Data of shape (N, D)
- gamma: Scale parameter of shape (D,)
- beta: Shift paremeter of shape (D,)
- bn_param: Dictionary with the following keys:
- mode: 'train' or 'test'; required
- eps: Constant for numeric stability
- momentum: Constant for running mean / variance.
- running_mean: Array of shape (D,) giving running mean of features
- running_var Array of shape (D,) giving running variance of features
Returns a tuple of:
- out: of shape (N, D)
- cache: A tuple of values needed in the backward pass
"""
mode = bn_param['mode']
eps = bn_param.get('eps', 1e-5)
momentum = bn_param.get('momentum', 0.9)
N, D = x.shape
running_mean = bn_param.get('running_mean', torch.zeros(D, dtype=x.dtype, device=x.device))
running_var = bn_param.get('running_var', torch.zeros(D, dtype=x.dtype, device=x.device))
out, cache = None, None
if mode == 'train':
mean = torch.mean(x, dim = 0)
var = torch.var(x, dim = 0, unbiased = False)
# var = (1./N) * torch.sum((x-mean)**2, 0)
running_mean = momentum*running_mean + (1-momentum)*mean
running_var = momentum*running_var + (1-momentum)*var
x_hat = (x-mean)/torch.sqrt(var + eps)
out = x_hat * gamma + beta
cache = {'x':x, 'mean':mean, 'var':var, 'gamma':gamma, \
'beta':beta, 'eps':eps, 'x_hat':x_hat, 'mode':mode, \
'running_mean':running_mean, 'running_var': running_var}
elif mode == 'test':
x_hat = (x-running_mean)/torch.sqrt(running_var + eps)
out = x_hat * gamma + beta
cache = {'x':x, 'gamma':gamma, \
'beta':beta, 'eps':eps, 'x_hat':x_hat, 'mode':mode, \
'running_mean':running_mean, 'running_var': running_var}
else:
raise ValueError('Invalid forward batchnorm mode "%s"' % mode)
# Store the updated running means back into bn_param
bn_param['running_mean'] = running_mean.detach()
bn_param['running_var'] = running_var.detach()
return out, cache
The Spatial BN class makes use of the normal BN created for the Linear layers by reshaping the tensor into a 2-d (NHW, C) tensor. This is what is causing an issue when I call FastConv in the 2nd layer.
class SpatialBatchNorm(object):
@staticmethod
def forward(x, gamma, beta, bn_param):
"""
Computes the forward pass for spatial batch normalization.
Inputs:
- x: Input data of shape (N, C, H, W)
- gamma: Scale parameter, of shape (C,)
- beta: Shift parameter, of shape (C,)
- bn_param: Dictionary with the following keys:
- mode: 'train' or 'test'; required
- eps: Constant for numeric stability
- momentum: Constant for running mean / variance. momentum=0 means that
old information is discarded completely at every time step, while
momentum=1 means that new information is never incorporated. The
default of momentum=0.9 should work well in most situations.
- running_mean: Array of shape (C,) giving running mean of features
- running_var Array of shape (C,) giving running variance of features
Returns a tuple of:
- out: Output data, of shape (N, C, H, W)
- cache: Values needed for the backward pass
"""
out, cache = None, None
# Replace "pass" statement with your code
N, C, H, W = x.shape
# x = x.clone()
x_temp = x.permute(0, 2, 3, 1).reshape(N*H*W, C)
out, cache = BatchNorm.forward(x_temp, gamma, beta, bn_param)
out = out.reshape(N, H, W, C).permute(0, 3, 1, 2)
return out, cache
Loop for creating model - The first iteration runs properly, but when the self.batchnorm flag is set to True, the 2nd iteration fails at the FastConv forward function call with the RunTime error shown above. This suggests that since the loop runs perfectly without the BN layer, the issue is caused by the BN layer.
X_copy = X
caches = []
for i in range(1, self.num_layers):
cache = []
X_copy, conv_cache = FastConv.forward(X_copy,
self.params['W'+str(i)],
self.params['b'+str(i)],
conv_param
)
# print("Completed Conv of " + str(i))
cache.append(conv_cache)
if self.batchnorm:
X_copy, bn_cache = SpatialBatchNorm.forward(X_copy,
self.params['gamma'+str(i)],
self.params['beta'+str(i)],
self.bn_params[i-1]
)
# print("Completed BN of " + str(i))
cache.append(bn_cache)
X_copy, relu_cache = ReLU.forward(X_copy)
cache.append(relu_cache)
if (i-1) in self.max_pools:
X_copy, max_pool_cache = FastMaxPool.forward(X_copy, pool_param)
cache.append(max_pool_cache)
caches.append(cache)
# print("Completed Entire Process of Layer " + str(i))
scores, linear_cache = Linear.forward(X_copy,
self.params['W'+str(self.num_layers)],
self.params['b'+str(self.num_layers)]
)
caches.append(linear_cache)