Data-dependent network - problems in dataloader and train loop

For documentation, and for others who come across this post:

I solved my problem with the approach I mentioned in the previous reply.

Although, in my case, the RNN-style padding utilities are not the best fit for the collate_fn, so I used a dynamic padding strategy and also return a mask tensor from collate_fn:

import torch

def collate_fn(batch):  # something like
    DNA_tf_x, targets = zip(*batch)
    DNA_x, tf_x = zip(*DNA_tf_x)
    # pad each tf tensor along dim 1 up to the largest size in this batch
    max_size_dim1 = max(tf.size(1) for tf in tf_x)
    tf_padded_x = [torch.cat((tf, torch.zeros(tf.size(0), max_size_dim1 - tf.size(1))), dim=1)
                   if tf.size(1) < max_size_dim1 else tf for tf in tf_x]
    # the mask is False at valid positions and True at padded positions
    mask_tensor = [torch.tensor([False] * tf.size(1) + [True] * (max_size_dim1 - tf.size(1)))
                   for tf in tf_x]
    return ((torch.stack(DNA_x), torch.stack(tf_padded_x).long()),
            torch.stack([torch.tensor(t) for t in targets]),
            torch.stack(mask_tensor))
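
For context, this is roughly how the collate_fn plugs into the DataLoader and the train loop; a minimal sketch, where my_dataset, model, criterion and optimizer are placeholders rather than my actual objects:

from torch.utils.data import DataLoader

loader = DataLoader(my_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

for (DNA_x, tf_x), targets, mask in loader:
    out = model(DNA_x, tf_x, mask)   # the network consumes the mask as well
    loss = criterion(out, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()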

In my network I stack every conv output and use the mask to zero out the unwanted (padded) positions, so that I get the real average output:

import torch.nn as nn

class MyNetwork(nn.Module):  # something like
    def __init__(self):
        super().__init__()
        ...

    def forward(self, DNA_x, tf_x, mask):
        conv_out_list = []
        for tf_1, mask_1 in zip(tf_x.unbind(dim=-2), mask.unbind(dim=-1)):
            # tf_1 is transformed into a kernel inside custom_conv1d,
            # and some tf_1 are generated by padding rather than being a valid tf
            conv_out_list.append(custom_conv1d(DNA_x, tf_1))
            # if MaskedBatchNorm1d is enabled, it would be like this instead
            # (~mask_1 so the statistics only cover the valid, non-padded tf):
            # conv_out_list.append(mask_bn(custom_conv1d(DNA_x, tf_1), ~mask_1))

        # stack all conv outputs along a new last dim -> [bs, C, L, n_tf]
        conv_out = torch.stack(conv_out_list, dim=-1)

        # number of valid (non-padded) tf per sample; the mask is True at padding,
        # so count the False entries
        mask_true = torch.sum(~mask, dim=-1)

        # avoid a division by 0 if a sample somehow has no valid tf
        mask_true = mask_true.masked_fill(mask_true == 0, 1)
        mask = mask.unsqueeze(1).unsqueeze(2)  # [bs, 1, 1, n_tf]

        # fill the padded positions with 0, so for every sample in the batch
        # a tf_1 that came from padding contributes nothing to the output
        conv_out = conv_out.masked_fill(mask, 0)

        # take the average of conv_out over the valid tf only
        conv_out = torch.sum(conv_out, dim=-1)
        conv_out = torch.div(conv_out, mask_true.view(-1, 1, 1))

        return conv_out
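
To sanity-check just the masked-average step, here is a small self-contained sketch with dummy tensors; the [bs, C, L, n_tf] shapes are assumptions matching the layout above:

import torch

bs, C, L, n_tf = 2, 4, 8, 3
conv_out = torch.randn(bs, C, L, n_tf)
# True marks the tf slots that were added by padding in collate_fn
mask = torch.tensor([[False, False, True],
                     [False, True,  True]])

valid_count = (~mask).sum(dim=-1).clamp(min=1)             # [bs]
conv_out = conv_out.masked_fill(mask[:, None, None, :], 0)
avg = conv_out.sum(dim=-1) / valid_count.view(-1, 1, 1)    # [bs, C, L]
print(avg.shape)  # torch.Size([2, 4, 8])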

The batch norm also needs the mask: in my case I have to block part of the batch-level data after conv_out is generated, so I modified a masked batch norm layer that was posted here two years ago:

class MaskedBatchNorm1d(nn.BatchNorm1d):
    """BatchNorm1d whose batch statistics only cover the entries selected by
    `mask` (the mask should be nonzero for the entries to include)."""

    def __init__(self, num_features, eps=1e-5, momentum=0.1,
                 affine=True, track_running_stats=True):
        super(MaskedBatchNorm1d, self).__init__(
            num_features,
            eps,
            momentum,
            affine,
            track_running_stats
        )

    def forward(self, inp, mask):
        self._check_input_dim(inp)
        exponential_average_factor = 0.0

        # weight each entry by mask / n so the masked sums below act as means
        n = mask.sum()
        if n.item() != 0:
            mask = mask / n
            mask = mask.unsqueeze(1).unsqueeze(1).expand(inp.shape)
            process_inp = inp * mask
        else:
            process_inp = inp

        if self.training and self.track_running_stats:
            if self.num_batches_tracked is not None:
                self.num_batches_tracked += 1
                if self.momentum is None:  # use a cumulative moving average
                    exponential_average_factor = 1.0 / float(self.num_batches_tracked)
                else:  # use an exponential moving average
                    exponential_average_factor = self.momentum

        if not self.track_running_stats:  # should raise an exception if n == 1
            mean = process_inp.sum([0, 2])
            var = ((process_inp ** 2).sum([0, 2]) - mean ** 2) * n / (n - 1)
        elif self.training and n > 1:
            mean = process_inp.sum([0, 2])
            var = (process_inp ** 2).sum([0, 2]) - mean ** 2
            with torch.no_grad():
                self.running_mean = exponential_average_factor * mean \
                    + (1 - exponential_average_factor) * self.running_mean
                self.running_var = exponential_average_factor * var * n / (n - 1) \
                    + (1 - exponential_average_factor) * self.running_var
        else:
            mean = self.running_mean
            var = self.running_var

        # normalize with the masked statistics, then apply the affine transform
        inp = (inp - mean[None, :, None]) / torch.sqrt(var[None, :, None] + self.eps)
        if self.affine:
            inp = inp * self.weight[None, :, None] + self.bias[None, :, None]

        return inp
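
A quick usage sketch with dummy shapes; the feature count and the mask values here are just assumptions, and note that the layer expects the mask to be nonzero for the samples it should include in the statistics:

import torch

bn = MaskedBatchNorm1d(num_features=4)
bn.train()

x = torch.randn(3, 4, 10)             # [bs, C, L]
keep = torch.tensor([1.0, 1.0, 0.0])  # include only the first two samples
out = bn(x, keep)
print(out.shape)                      # torch.Size([3, 4, 10])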

I have not checked this code thoroughly, so if anyone looks it over and finds anything wrong, please let me know!
Thanks!