For documentation and for others who find this post:
I solved my problem with the approach I mentioned in the previous reply.
Although, for my question, RNN-style padding is not the best thing to integrate into collate_fn, I used a dynamic padding strategy and also return a mask tensor from collate_fn:
import torch

def collate_fn(batch):  # something like
    DNA_tf_x, targets = zip(*batch)
    DNA_x, tf_x = zip(*DNA_tf_x)
    # pad every tf tensor along dim 1 up to the longest one in the batch
    max_size_dim1 = max(tf.size(1) for tf in tf_x)
    tf_padded_x = [torch.cat((tf, torch.zeros(tf.size(0), max_size_dim1 - tf.size(1))), dim=1)
                   if tf.size(1) < max_size_dim1 else tf for tf in tf_x]
    # mask is False at valid positions and True at padded positions
    mask_tensor = [torch.tensor([False] * tf.size(1) + [True] * (max_size_dim1 - tf.size(1))) for tf in tf_x]
    return ((torch.stack(DNA_x), torch.stack(tf_padded_x).long()),
            torch.stack([torch.tensor(i) for i in targets]),
            torch.stack(mask_tensor))
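As a rough usage sketch (my_dataset and the tensor shapes here are hypothetical, not from my actual code), the collate_fn is simply handed to the DataLoader and each batch unpacks into the same structure that collate_fn returns:

from torch.utils.data import DataLoader

# hypothetical dataset: each item is ((DNA_x, tf_x), target), where
# DNA_x has a fixed shape and tf_x has a variable size along dim 1
loader = DataLoader(my_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
for (DNA_x, tf_x), targets, mask in loader:
    # mask: [bs, max_size_dim1], True where the position was created by padding
    ...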
In my network, I stack every conv output and use the mask to zero out the unwanted (padded) positions, so that I can get the real average output:
import torch
import torch.nn as nn

class MyNetwork(nn.Module):  # something like
    def __init__(self):
        super().__init__()
        ...

    def forward(self, DNA_x, tf_x, mask):
        conv_out_list = []
        for tf_1, mask_1 in zip(tf_x.unbind(dim=-2), mask.unbind(dim=-1)):
            # tf_1 will be transformed into a kernel inside custom_conv1d;
            # some tf_1 are generated by padding rather than being a valid tf
            conv_out_list.append(custom_conv1d(DNA_x, tf_1))
            # and if MaskedBatchNorm1d is enabled, it would be like this instead:
            # conv_out_list.append(self.mask_bn(custom_conv1d(DNA_x, tf_1), mask_1))
        # stack all conv outputs along the last dim
        conv_out = torch.stack(conv_out_list, dim=-1)
        # number of valid (non-padded) tfs per sample; avoid 0 so we never divide by 0
        mask_true = torch.sum(~mask, dim=-1)
        mask_true = mask_true.masked_fill(mask_true == 0, 1)
        # fill the padded positions with 0, so for every sample in the batch,
        # if tf_1 is padding, its output (e.g. [bs, :, :, i]) contributes nothing
        conv_out = conv_out.masked_fill(mask.unsqueeze(1).unsqueeze(2), 0)
        # take the average of conv_out over the valid tfs only
        conv_out = torch.sum(conv_out, dim=-1)
        conv_out = torch.div(conv_out, mask_true.view(-1, 1, 1))
        return conv_out
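To illustrate just the masked-average trick on its own (toy shapes, not my real model), zeroing the padded slots and dividing by the count of valid ones reproduces the mean over only the valid outputs:

import torch

# toy example: 2 samples, 3 channels, length 4, up to 5 tfs
bs, C, L, n_tf = 2, 3, 4, 5
conv_out = torch.randn(bs, C, L, n_tf)
# sample 0 has 5 valid tfs, sample 1 has only 2 (the last 3 are padding)
mask = torch.tensor([[False] * 5, [False] * 2 + [True] * 3])

valid = torch.sum(~mask, dim=-1).clamp(min=1)                      # [5, 2]
masked = conv_out.masked_fill(mask.unsqueeze(1).unsqueeze(2), 0)
avg = masked.sum(dim=-1) / valid.view(-1, 1, 1)                    # [bs, C, L]

# for sample 1 this matches averaging over only its first 2 tfs
assert torch.allclose(avg[1], conv_out[1, :, :, :2].mean(dim=-1))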
The batch norm also needs the mask. In my case, I need to mask out batch-level data after conv_out is generated, so I modified a masked batch norm layer that was posted about 2 years ago:
class MaskedBatchNorm1d(nn.BatchNorm1d):
    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MaskedBatchNorm1d, self).__init__(
            num_features, eps, momentum, affine, track_running_stats)

    def forward(self, inp, mask):
        self._check_input_dim(inp)
        exponential_average_factor = 0.0
        # n is the number of entries selected by the mask
        n = mask.sum()
        if n.item() != 0:
            mask = mask / n
            mask = mask.unsqueeze(1).unsqueeze(1).expand(inp.shape)
            process_inp = inp * mask
        else:
            process_inp = inp
        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use exponential moving average
                    exponential_average_factor = self.momentum
        if not self.track_running_stats:  # should raise an exception if n == 1
            mean = process_inp.sum([0, 2])
            var = ((process_inp ** 2).sum([0, 2]) - mean ** 2) * n / (n - 1)
        elif self.training and n > 1:
            mean = process_inp.sum([0, 2])
            var = (process_inp ** 2).sum([0, 2]) - mean ** 2
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                # the running variance uses the unbiased estimate
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var
        inp = (inp - mean[None, :, None]) / torch.sqrt(var[None, :, None] + self.eps)
        if self.affine:
            inp = inp * self.weight[None, :, None] + self.bias[None, :, None]
        return inp
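A minimal usage sketch, under my own assumptions rather than a checked reference: the shapes are made up, and I read the mask convention off the code above (larger mask values are included in the statistics, so 1 = keep, 0 = exclude); adjust it to however you build your mask:

# hypothetical shapes: [bs, C, L] activations and a per-sample 0/1 mask
bs, C, L = 8, 16, 100
x = torch.randn(bs, C, L)
keep = torch.ones(bs)   # assumed convention: 1 = include in the batch statistics
keep[-2:] = 0           # e.g. the last two samples come from padded tfs

bn = MaskedBatchNorm1d(num_features=C)
bn.train()
out = bn(x, keep)       # same shape as x: [bs, C, L]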
I have not checked the code thoroughly, so if anyone looks it over and finds anything wrong, please let me know!
Thanks.