Hi experts and community folks, I am trying a hack that uses list(model.children()) to change the first layer of a ResNet:
import fastai
from torchsummary import summary
from fastai.vision.models.unet import *
from fastai.callbacks import *
from fastai.utils.mem import *
arch = torchvision.models.resnet34(pretrained=False)
li = list(model.children()) #list of the children
print(li[0][0][0]) gives: Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
So I replace it with li[0][0][0] = nn.Conv2d(1, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
because I want the ResNet to take a single-channel input.
But calling summary(model.to('cuda'), input_size=(1, 320, 320)) raises this error:
NotImplementedError Traceback (most recent call last)
in
----> 1 summary(model.to('cuda'),input_size=(1, 320, 320))
~/anaconda3/envs/harsimar/lib/python3.6/site-packages/torchsummary/torchsummary.py in summary(model, input_size, batch_size, device)
70 # make a forward pass
71 # print(x.shape)
---> 72 model(*x)
73
74 # remove these hooks
~/anaconda3/envs/harsimar/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
~/anaconda3/envs/harsimar/lib/python3.6/site-packages/torch/nn/modules/container.py in forward(self, input)
98 def forward(self, input):
99 for module in self:
--> 100 input = module(input)
101 return input
102
~/anaconda3/envs/harsimar/lib/python3.6/site-packages/torch/nn/modules/module.py in __call__(self, *input, **kwargs)
530 result = self._slow_forward(*input, **kwargs)
531 else:
--> 532 result = self.forward(*input, **kwargs)
533 for hook in self._forward_hooks.values():
534 hook_result = hook(self, input, result)
~/anaconda3/envs/harsimar/lib/python3.6/site-packages/torch/nn/modules/module.py in forward(self, *input)
94 registered hooks while the latter silently ignores them.
95 """
---> 96 raise NotImplementedError
97
98 def register_buffer(self, name, tensor):
NotImplementedError:
Any ideas on what is causing this?