class Unet(nn.Module):
def __init__(self):
super().__init__()
fs = 16
self.up1 = UnetBlock(448,272,fs)
self.up2 = UnetBlock(fs,112,fs)
self.up3 = UnetBlock(fs,32,fs)
self.up4 = nn.ConvTranspose2d(fs, fs, 2, stride=2)
self.up5 = nn.ConvTranspose2d(fs, fs, 2, stride=2)
self.logit = nn.Sequential(nn.Conv2d(fs,fs,kernel_size=3,padding=1),nn.ReLU(inplace=True),
nn.Conv2d(fs,2,kernel_size=1,padding=0))
self.drop = nn.Dropout2d(0.1)
def forward(self,e0,e1,e2,e3,e4,img):
#pdb.set_trace()
img_sz = img.size(2)
d1 = self.drop(self.up1(e4, e3))
d2 = self.drop(self.up2(d1, e2))
d3 = self.drop(self.up3(d2, e1))
d4 = self.drop(self.up4(d4))
d5 = self.drop(self.up5(d5))
out = self.logit(d5)
return out
class UnetBlock(nn.Module):
def __init__(self, up_in, x_in, n_out):
super().__init__()
up_out = x_out = n_out//2
self.x_conv = nn.Conv2d(x_in, x_out, 1)
self.tr_conv = nn.ConvTranspose2d(up_in, up_out, 3, stride=1)
self.bn = nn.BatchNorm2d(n_out)
def forward(self, up_p, x_p):
up_p = self.tr_conv(up_p)
x_p = self.x_conv(x_p)
cat_p = torch.cat([up_p,x_p], dim=1)
return (self.bn(F.relu(cat_p)))
RuntimeError Traceback (most recent call last)
<ipython-input-73-83d934609393> in <module>
----> 1 learn.fit_one_cycle(3, lr, callbacks=[AccumulateStep(learn,n_acc)])
/opt/conda/lib/python3.6/site-packages/fastai/train.py in fit_one_cycle(learn, cyc_len, max_lr, moms, div_factor, pct_start, final_div, wd, callbacks, tot_epochs, start_epoch)
20 callbacks.append(OneCycleScheduler(learn, max_lr, moms=moms, div_factor=div_factor, pct_start=pct_start,
21 final_div=final_div, tot_epochs=tot_epochs, start_epoch=start_epoch))
---> 22 learn.fit(cyc_len, max_lr, wd=wd, callbacks=callbacks)
23
24 def lr_find(learn:Learner, start_lr:Floats=1e-7, end_lr:Floats=10, num_it:int=100, stop_div:bool=True, wd:float=None):
/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in fit(self, epochs, lr, wd, callbacks)
198 callbacks = [cb(self) for cb in self.callback_fns + listify(defaults.extra_callback_fns)] + listify(callbacks)
199 if defaults.extra_callbacks is not None: callbacks += defaults.extra_callbacks
--> 200 fit(epochs, self, metrics=self.metrics, callbacks=self.callbacks+callbacks)
201
202 def create_opt(self, lr:Floats, wd:Floats=0.)->None:
/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in fit(epochs, learn, callbacks, metrics)
99 for xb,yb in progress_bar(learn.data.train_dl, parent=pbar):
100 xb, yb = cb_handler.on_batch_begin(xb, yb)
--> 101 loss = loss_batch(learn.model, xb, yb, learn.loss_func, learn.opt, cb_handler)
102 if cb_handler.on_batch_end(loss): break
103
/opt/conda/lib/python3.6/site-packages/fastai/basic_train.py in loss_batch(model, xb, yb, loss_func, opt, cb_handler)
24 if not is_listy(xb): xb = [xb]
25 if not is_listy(yb): yb = [yb]
---> 26 out = model(*xb)
27 out = cb_handler.on_loss_begin(out)
28
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
<ipython-input-70-7476c07fb28b> in forward(self, img)
61 def forward(self,img):
62 e0,e1,e2,e3,e4 = self.rn(img)
---> 63 out = self.unet(e0,e1,e2,e3,e4,img)
64 return out
65
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
<ipython-input-70-7476c07fb28b> in forward(self, e0, e1, e2, e3, e4, img)
44 #pdb.set_trace()
45 img_sz = img.size(2)
---> 46 d1 = self.drop(self.up1(e4, e3))
47 d2 = self.drop(self.up2(d1, e2))
48 d3 = self.drop(self.up3(d2, e1))
/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
491 result = self._slow_forward(*input, **kwargs)
492 else:
--> 493 result = self.forward(*input, **kwargs)
494 for hook in self._forward_hooks.values():
495 hook_result = hook(self, input, result)
<ipython-input-70-7476c07fb28b> in forward(self, up_p, x_p)
81 x_p = self.x_conv(x_p)
82 temp_conv.append(x_p.size())
---> 83 cat_p = torch.cat([up_p,x_p], dim=1)
84 temp_cat.append(cat_p.size())
85 return self.act(self.bn(F.relu(cat_p)))
RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 8 and 10 in dimension 2 at /opt/conda/conda-bld/pytorch_1556653099582/work/aten/src/THC/generic/THCTensorMath.cu:71
I am trying to create a Unet architeture with EfficinetNetB4 as a backbone. The main error is coming from the Unet block in the nn.ConvTranspose2d. The nn.Conv2d is giving the expected size which is (8,8,8,8), but Conv2DTranspose is giving a size of (8,8,10,10). So while using nn.concat the dimension error is occuring. Can someone please help me in specifying the required setting of dimensions?