Hi,
I am wondering if torch.nn.dataparallel applies to a helper module that doesn’t have forward and parameters need to backward. I attached the code example here (it is an image reconstruction module using optical flow). I tried it and it didn’t work. reporting
Can this module be written in some ways that work for nn.data parallel
Class Reconstructor(torch.nn.Module):
def __init__(self):
super().__init__()
self.resampler = Resample2d
mean = [0.5,0.5,0.5]
mean = torch.nn.Parameter(torch.FloatTensor(mean), requires_grad=False)
std = [0.5,0.5,0.5]
std = torch.nn.Parameter(torch.FloatTensor(std), requires_grad=False)
self.normalizer = Normalizer(mean=mean, std=std)
def reconstruct_images(self, image_batch: torch.Tensor, flows: Union[tuple, list]):
# SSIM DOES NOT WORK WITH Z-SCORED IMAGES
# requires images in the range [0,1]. So we have to denormalize for it to work!
image_batch = self.normalizer.denormalize(image_batch)
if image_batch.ndim == 4:
N, C, H, W = image_batch.shape
num_images = int(C / 3) - 1
t0 = image_batch[:, :num_images * 3, ...].contiguous().view(N * num_images, 3, H, W)
t1 = image_batch[:, 3:, ...].contiguous().view(N * num_images, 3, H, W)
elif image_batch.ndim == 5:
N, C, T, H, W = image_batch.shape
num_images = T - 1
t0 = image_batch[:, :, :num_images, ...]
t0 = t0.transpose(1, 2).reshape(N * num_images, C, H, W)
t1 = image_batch[:, :, 1:, ...]
t1 = t1.transpose(1, 2).reshape(N * num_images, C, H, W)
else:
raise ValueError('unexpected batch shape: {}'.format(image_batch))
reconstructed = []
t1s = []
t0s = []
flows_reshaped = []
for flow in flows:
# upsampled_flow = F.interpolate(flow, (h,w), mode='bilinear', align_corners=False)
if flow.ndim == 4:
n, c, h, w = flow.size()
flow = flow.view(N * num_images, 2, h, w)
else:
n, c, t, h, w = flow.shape
flow = flow.transpose(1, 2).reshape(n * t, c, h, w)
downsampled_t1 = F.interpolate(t1, (h, w), mode='bilinear', align_corners=False)
downsampled_t0 = F.interpolate(t0, (h, w), mode='bilinear', align_corners=False)
t0s.append(downsampled_t0)
t1s.append(downsampled_t1)
reconstructed.append(self.resampler(downsampled_t1, flow))
del (downsampled_t1, downsampled_t0)
flows_reshaped.append(flow)
return tuple(t0s), tuple(reconstructed), tuple(flows_reshaped)
def __call__(self, image_batch: torch.Tensor, flows: Union[tuple, list]):
return self.reconstruct_images(image_batch, flows)