EfficientNet encoder in Unet: Size mismatch while Upsampling

import torch
import torch.nn as nn
import torch.nn.functional as F

class Unet(nn.Module):
    def __init__(self):
        super().__init__()
        fs = 16
        self.up1 = UnetBlock(448, 272, fs)
        self.up2 = UnetBlock(fs, 112, fs)
        self.up3 = UnetBlock(fs, 32, fs)
        self.up4 = nn.ConvTranspose2d(fs, fs, 2, stride=2)
        self.up5 = nn.ConvTranspose2d(fs, fs, 2, stride=2)
        self.logit = nn.Sequential(
            nn.Conv2d(fs, fs, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(fs, 2, kernel_size=1, padding=0))
        self.drop = nn.Dropout2d(0.1)

    def forward(self, e0, e1, e2, e3, e4, img):
        # e0..e4 are the EfficientNet-B4 feature maps, shallowest to deepest
        img_sz = img.size(2)
        d1 = self.drop(self.up1(e4, e3))
        d2 = self.drop(self.up2(d1, e2))
        d3 = self.drop(self.up3(d2, e1))
        d4 = self.drop(self.up4(d3))
        d5 = self.drop(self.up5(d4))
        out = self.logit(d5)
        return out

class UnetBlock(nn.Module):
    def __init__(self, up_in, x_in, n_out):
        super().__init__()
        up_out = x_out = n_out // 2
        self.x_conv = nn.Conv2d(x_in, x_out, 1)
        # kernel_size=3 with stride=1 grows H and W by 2 rather than upsampling
        self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 3, stride=1)
        self.bn = nn.BatchNorm2d(n_out)

    def forward(self, up_p, x_p):
        up_p = self.tr_conv(up_p)  # spatial size changes here
        x_p = self.x_conv(x_p)     # 1x1 conv keeps the spatial size
        cat_p = torch.cat([up_p, x_p], dim=1)
        return self.bn(F.relu(cat_p))
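
Calling learn.fit_one_cycle then fails with the following traceback: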
RuntimeError                              Traceback (most recent call last)
<ipython-input-73-83d934609393> in <module>
----> 1 learn.fit_one_cycle(3, lr, callbacks=[AccumulateStep(learn,n_acc)])

/opt/conda/lib/python3.6/site-packages/fastai/train.py in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
     20     callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
     21                                        final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
---> 22     learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
     23 
     24 def lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None):

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
    198         callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
    199         if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
--> 200         fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
    201 
    202     def create_opt(self, lr:Floats, wd:Floats=0.)->None:

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
     99             for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
    100                 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101                 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
    102                 if cb_handler.on_batch_end(loss): break
    103 

/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
     24     if not is_listy(xb): xb = [xb]
     25     if not is_listy(yb): yb = [yb]
---> 26     out = model(*xb)
     27     out = cb_handler.on_loss_begin(out)
     28 

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

<ipython-input-70-7476c07fb28b> in forward(self, img)
     61     def forward(self,img):
     62         e0,e1,e2,e3,e4 = self.rn(img)
---> 63         out = self.unet(e0,e1,e2,e3,e4,img)
     64         return out
     65 

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

<ipython-input-70-7476c07fb28b> in forward(self, e0, e1, e2, e3, e4, img)
     44         #pdb.set_trace()
     45         img_sz = img.size(2)
---> 46         d1 = self.drop(self.up1(e4, e3))
     47         d2 = self.drop(self.up2(d1, e2))
     48         d3 = self.drop(self.up3(d2, e1))

/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
    491             result = self._slow_forward(*input, **kwargs)
    492         else:
--> 493             result = self.forward(*input, **kwargs)
    494         for hook in self._forward_hooks.values():
    495             hook_result = hook(self, input, result)

<ipython-input-70-7476c07fb28b> in forward(self, up_p, x_p)
     81         x_p = self.x_conv(x_p)
     82         temp_conv.append(x_p.size())
---> 83         cat_p = torch.cat([up_p,x_p], dim=1)
     84         temp_cat.append(cat_p.size())
     85         return self.act(self.bn(F.relu(cat_p)))

RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 8 and 10 in dimension 2 at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THC/generic/THCTensorMath.cu:71

I am trying to build a U-Net architecture with EfficientNet-B4 as the backbone. The error comes from the nn.ConvTranspose2d inside UnetBlock: the nn.Conv2d branch produces the expected size of (8, 8, 8, 8), but the nn.ConvTranspose2d branch produces (8, 8, 10, 10), so torch.cat fails with a size mismatch in the spatial dimensions. Can someone help me set the dimensions correctly?
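To make the shape behaviour concrete, here is a small standalone repro (the 8x8 size and the channel counts are stand-ins for my e4/e3, not exact values from the encoder). With kernel_size=3 and stride=1, nn.ConvTranspose2d gives H_out = (H_in - 1)*stride - 2*padding + kernel_size = (8 - 1)*1 - 0 + 3 = 10, which is exactly the 10 vs. 8 mismatch in the traceback:

import torch
import torch.nn as nn
import torch.nn.functional as F

x = torch.randn(8, 448, 8, 8)     # stand-in for e4: (batch, channels, H, W)
skip = torch.randn(8, 272, 8, 8)  # stand-in for e3

# Current UnetBlock setting: kernel 3, stride 1 adds 2 to H and W
tr = nn.ConvTranspose2d(448, 8, 3, stride=1)
print(tr(x).shape)   # torch.Size([8, 8, 10, 10])

# kernel 2, stride 2 doubles H and W instead (only matches a skip that is 2x larger)
tr2 = nn.ConvTranspose2d(448, 8, 2, stride=2)
print(tr2(x).shape)  # torch.Size([8, 8, 16, 16])

# Resizing to the skip connection's spatial size guarantees the concat works
up_p = F.interpolate(tr(x), size=skip.shape[2:], mode='bilinear', align_corners=False)
print(up_p.shape)    # torch.Size([8, 8, 8, 8])
print(torch.cat([up_p, nn.Conv2d(272, 8, 1)(skip)], dim=1).shape)  # torch.Size([8, 16, 8, 8])

Since both branches in my traceback start from 8x8 inputs (e3 and e4 seem to be at the same resolution), I suspect resizing with F.interpolate to the skip connection's size is the safer fix, while kernel_size=2, stride=2 only works when the skip is twice the size. Is one of these the right way to set the dimensions?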


I'm trying to do the same, did you have any success on this?