DataParallel with a custom helper module

Hi,

I am wondering whether torch.nn.DataParallel works with a helper module that doesn't define forward() and has no parameters that need gradients. The code is below (it is an image reconstruction module using optical flow). I tried it, but it didn't work and reported an error.

Can this module be written in a way that works with nn.DataParallel?

import torch
import torch.nn.functional as F
from typing import Union

# Resample2d (the flow-warping layer) and Normalizer are defined elsewhere and not shown.


class Reconstructor(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Assumes Resample2d is an nn.Module class, so it must be instantiated here
        self.resampler = Resample2d()
        mean = [0.5, 0.5, 0.5]
        mean = torch.nn.Parameter(torch.FloatTensor(mean), requires_grad=False)
        std = [0.5, 0.5, 0.5]
        std = torch.nn.Parameter(torch.FloatTensor(std), requires_grad=False)
        self.normalizer = Normalizer(mean=mean, std=std)

    def reconstruct_images(self, image_batch: torch.Tensor, flows: Union[tuple, list]):
        # SSIM does not work with z-scored images: it requires images in the range [0, 1],
        # so we have to denormalize first.
        image_batch = self.normalizer.denormalize(image_batch)
        if image_batch.ndim == 4:
            # Frames stacked along the channel axis: (N, 3 * T, H, W)
            N, C, H, W = image_batch.shape
            num_images = C // 3 - 1
            t0 = image_batch[:, :num_images * 3, ...].contiguous().view(N * num_images, 3, H, W)
            t1 = image_batch[:, 3:, ...].contiguous().view(N * num_images, 3, H, W)
        elif image_batch.ndim == 5:
            # Frames stacked along a time axis: (N, C, T, H, W)
            N, C, T, H, W = image_batch.shape
            num_images = T - 1
            t0 = image_batch[:, :, :num_images, ...]
            t0 = t0.transpose(1, 2).reshape(N * num_images, C, H, W)
            t1 = image_batch[:, :, 1:, ...]
            t1 = t1.transpose(1, 2).reshape(N * num_images, C, H, W)
        else:
            raise ValueError('unexpected batch shape: {}'.format(image_batch.shape))

        reconstructed = []
        t1s = []
        t0s = []
        flows_reshaped = []
        for flow in flows:
            # upsampled_flow = F.interpolate(flow, (h,w), mode='bilinear', align_corners=False)
            if flow.ndim == 4:
                n, c, h, w = flow.size()
                flow = flow.view(N * num_images, 2, h, w)
            else:
                n, c, t, h, w = flow.shape
                flow = flow.transpose(1, 2).reshape(n * t, c, h, w)

            # Resize both frames to this flow's resolution, then warp t1 with the flow
            downsampled_t1 = F.interpolate(t1, (h, w), mode='bilinear', align_corners=False)
            downsampled_t0 = F.interpolate(t0, (h, w), mode='bilinear', align_corners=False)
            t0s.append(downsampled_t0)
            t1s.append(downsampled_t1)
            reconstructed.append(self.resampler(downsampled_t1, flow))
            flows_reshaped.append(flow)

        return tuple(t0s), tuple(reconstructed), tuple(flows_reshaped)

    # Define forward() instead of overriding __call__: nn.Module.__call__, which
    # DataParallel invokes on every replica, dispatches to forward() and runs hooks.
    def forward(self, image_batch: torch.Tensor, flows: Union[tuple, list]):
        return self.reconstruct_images(image_batch, flows)
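For comparison, here is a minimal, self-contained sketch of the same pattern that runs without the external pieces. It rests on several assumptions: a hypothetical `warp` helper built on `F.grid_sample` stands in for `Resample2d` (flow in pixel units, x displacement in channel 0, y in channel 1), `register_buffer` stands in for `Normalizer`, and frames are passed as already-split `t0`/`t1` pairs. The simplified `forward(t0, t1, flow)` signature is illustrative, not the original API.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def warp(image, flow):
    """Hypothetical stand-in for Resample2d: sample `image` at (x + u, y + v)."""
    n, _, h, w = flow.shape
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=flow.dtype, device=flow.device),
        torch.arange(w, dtype=flow.dtype, device=flow.device),
        indexing="ij",
    )
    x = xs + flow[:, 0]  # (N, H, W), pixel coordinates
    y = ys + flow[:, 1]
    # grid_sample expects (x, y) in [-1, 1]; with align_corners=True, -1/1 are the corner pixel centers
    grid = torch.stack((2 * x / (w - 1) - 1, 2 * y / (h - 1) - 1), dim=-1)  # (N, H, W, 2)
    return F.grid_sample(image, grid, mode="bilinear", padding_mode="border", align_corners=True)


class Reconstructor(nn.Module):
    def __init__(self):
        super().__init__()
        # Buffers are copied to every device by DataParallel but are never trained
        self.register_buffer("mean", torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1))
        self.register_buffer("std", torch.tensor([0.5, 0.5, 0.5]).view(1, 3, 1, 1))

    def forward(self, t0, t1, flow):
        # Undo the z-scoring so both frames are back in [0, 1]
        t0 = t0 * self.std + self.mean
        t1 = t1 * self.std + self.mean
        return t0, warp(t1, flow)


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.DataParallel(Reconstructor()).to(device)
```

When no GPU is available, `nn.DataParallel` simply calls the wrapped module, so the sketch also runs on CPU. With several GPUs, each replica receives a slice of the batch along dim 0 together with its own copy of `mean` and `std`, and the output tuples are gathered back onto the first device.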

Hi, can you please be more specific about what you mean by the “helper module”?
Could you also share the code where you wrap it with torch.nn.DataParallel? We can see the trace, but it would help to look at the code as well. Thanks.

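Thanks for the response, I figured it out: the mean and std need to be torch.nn.Parameter objects registered on the module, so that DataParallel copies them to every GPU along with the rest of the module.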
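The effect of that fix can be checked without a GPU. The sketch below (the class names are hypothetical) compares a plain tensor attribute, a frozen `nn.Parameter`, and a registered buffer. It uses `Module.to(dtype)` as a CPU-only proxy: like DataParallel's replication step, it only converts tensors the module tracks as parameters or buffers, while plain tensor attributes are left as they are (under DataParallel they would stay on the original device).

```python
import torch
import torch.nn as nn


class Plain(nn.Module):
    def __init__(self):
        super().__init__()
        self.mean = torch.tensor([0.5, 0.5, 0.5])  # ordinary attribute, not tracked


class AsParameter(nn.Module):
    def __init__(self):
        super().__init__()
        # The approach from the thread: a Parameter that is excluded from autograd
        self.mean = nn.Parameter(torch.tensor([0.5, 0.5, 0.5]), requires_grad=False)


class AsBuffer(nn.Module):
    def __init__(self):
        super().__init__()
        # Alternative: a buffer is tracked and saved, but never returned by parameters()
        self.register_buffer("mean", torch.tensor([0.5, 0.5, 0.5]))


for cls in (Plain, AsParameter, AsBuffer):
    m = cls().to(torch.float64)  # converts only tracked parameters and buffers
    print(cls.__name__, m.mean.dtype, len(list(m.parameters())), "mean" in m.state_dict())
```

Both tracked variants follow the module, while the plain attribute stays `float32` and is missing from `state_dict()`. A buffer has the small advantage that it does not show up in `parameters()`, so an optimizer never receives it, yet it is still saved and moved with the module.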