It turned out that I got NaN values when I used the standard BatchNorm1d. My time-series data contains sequences of different lengths. Following that discussion, I wrote a custom class implementing a masked BatchNorm, but I still get NaN values.
class MaskedNorm(nn.Module):
    """BatchNorm1d applied only to the valid (unpadded) time steps of a
    batch of variable-length sequences.

    Input ``y`` has shape ``[batch_size, time_length, num_features]``
    (channels last).  An optional ``mask`` (any shape reshapeable to
    ``[batch_size * time_length]``, e.g. ``[B, T]`` or ``[B, 1, T]``)
    marks valid steps with nonzero entries.  When ``mask`` is omitted,
    each sequence's valid length is inferred as 1 + the index of the
    last nonzero value in channel 0.

    The process during masked training:
      1. Merge the batch and time axes with reshape -> [B*T, C, 1].
      2. Select the valid time steps using the mask.
      3. Apply BatchNorm1d to the valid steps only.
      4. Scatter the normalized values back to their positions
         (padded steps keep their original values).
      5. Unmerge the batch and time axes.

    NOTE(review): in eval mode (or with ``mask_on=False``) *every* step,
    including padding, is normalized with the running statistics — this
    train/eval asymmetry is intentional here but worth confirming.
    """

    def __init__(self, num_features, mask_on=True):
        super().__init__()
        self.norm = nn.BatchNorm1d(num_features=num_features)
        self.num_features = num_features
        self.mask_on = mask_on

    def _infer_mask(self, y):
        """Vectorized [B, T] validity mask from the last nonzero entry of channel 0.

        Replaces the original per-sequence Python loop.  An all-zero row is
        clamped to length 1 so BatchNorm never receives an empty selection
        (the original code raised on ``max`` over an empty tensor there).
        """
        positions = torch.arange(y.shape[1], device=y.device)
        nonzero = (y[:, :, 0] != 0)
        # seq_len[i] = 1 + index of the last nonzero step of sequence i.
        seq_len = (nonzero.long() * (positions + 1)).max(dim=1).values.clamp(min=1)
        return positions.unsqueeze(0) < seq_len.unsqueeze(1)

    def forward(self, y, mask=None):
        """Normalize ``y`` ([B, T, C]); returns a tensor of the same shape."""
        # Kept as an attribute for backward compatibility with callers
        # that may read it; only used locally for the final reshape.
        self.sequence_length = y.shape[1]

        if self.training and self.mask_on:
            if mask is None:
                mask = self._infer_mask(y)
            # Merge batch/time; dummy trailing axis so BatchNorm1d
            # sees [N, C, 1].
            reshaped = y.reshape(-1, self.num_features, 1)
            reshaped_mask = (mask.reshape(-1, 1, 1) != 0)
            # Keep only valid steps, normalize them with batch statistics.
            selected = torch.masked_select(reshaped, reshaped_mask)
            selected = selected.reshape(-1, self.num_features, 1)
            batch_normed = self.norm(selected)
            # Write normalized values back; padded steps are untouched.
            scattered = reshaped.masked_scatter(reshaped_mask, batch_normed)
            return scattered.reshape(-1, self.sequence_length, self.num_features)

        # Eval mode (or masking disabled): normalize every step with
        # the running statistics.
        reshaped = y.reshape(-1, self.num_features, 1)
        batch_normed = self.norm(reshaped)
        return batch_normed.reshape(-1, self.sequence_length, self.num_features)
The input data fed into this layer doesn't contain any NaN values. I am wondering why I still get NaN values in the output?