I am trying to implement GradCAM for Efficientnet-B5. Could you please help debug the following error?

ENet = EfficientNet.from_pretrained(‘efficientnet-b5’).to(device)
in_features = ENet._fc.in_features
ENet._fc = nn.Linear(in_features= in_features, out_features=2)
ENet = ENet.to(device)

class ENET(nn.Module):
def init(self):
super(ENET, self).init()

    self.ENet = ENet

    #disect the network to access its last convolutional layer
    self.head0 = nn.Sequential(*list(self.ENet.children())[:4])
    #self.head = nn.Sequential(*list(self.ENet.children())[2:4])
    
    self.BN =  nn.BatchNorm2d(2048, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)

    self.tail = nn.Sequential(*list(self.ENet.children())[-4:])

    # placeholder for the gradients
    self.gradients = None

# hook for the gradients of the activations
def activations_hook(self, grad):
    self.gradients = grad
    
def forward(self, *x):
    x = self.head0(x)

    #x = self.head(x)

    #register the hook
    h = x.register_hook(self.activations_hook)
    
    x = self.BN(x)

    x = x.view((1, -1))

    x = self.tail(x)

    return x

# method for the gradient extraction
def get_activations_gradient(self):
    return self.gradients

# method for the activation exctraction
def get_activations(self, x):
    return self.head0(x)

model = ENET()

set the evaluation mode

model.eval()

get the image from the dataloader

img, _ = next(iter(train_loader))
img = img.to(device)
print(img.shape)

get the most likely prediction of the model

pred = model(img)

OUTPUT:
torch.Size([1, 3, 299, 299])

AttributeError Traceback (most recent call last)
in ()
9 print(img.shape)
10 # get the most likely prediction of the model
—> 11 pred = model(img)

8 frames
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
–> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

in forward(self, *x)
26
27 def forward(self, *x):
—> 28 x = self.head0(x)
29
30 #x = self.head(x)

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
–> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py in forward(self, input)
115 def forward(self, input):
116 for module in self:
–> 117 input = module(input)
118 return input
119

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
–> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/efficientnet_pytorch/utils.py in forward(self, x)
268
269 def forward(self, x):
–> 270 x = self.static_padding(x)
271 x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, self.dilation, self.groups)
272 return x

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
725 result = self._slow_forward(*input, **kwargs)
726 else:
–> 727 result = self.forward(*input, **kwargs)
728 for hook in itertools.chain(
729 _global_forward_hooks.values(),

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/padding.py in forward(self, input)
19
20 def forward(self, input: Tensor) -> Tensor:
—> 21 return F.pad(input, self.padding, ‘constant’, self.value)
22
23 def extra_repr(self) -> str:

/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py in _pad(input, pad, mode, value)
3549 _pad, (input,), input, pad, mode=mode, value=value)
3550 assert len(pad) % 2 == 0, ‘Padding length must be divisible by 2’
-> 3551 assert len(pad) // 2 <= input.dim(), ‘Padding length too large’
3552 if mode == ‘constant’:
3553 return _VF.constant_pad_nd(input, pad, value)

AttributeError: ‘tuple’ object has no attribute ‘dim’