I made this simple class as follows:
from vgg import VGG16
import torch.nn as nn
class NewNet(nn.Module):
    """Prediction heads on top of a ``VGG16`` backbone.

    Three 3x3 convolution heads are built:

    * ``Loc``: ``4`` output channels per entry of ``List``.
    * ``Conf``: ``n_classes + 1`` output channels per entry of ``List``.
    * ``ConfMap``: ``1`` output channel per entry of ``List``.

    Only ``Loc`` is used in :meth:`forward`.

    Args:
        n_classes: number of object classes. ``Conf`` reserves one extra
            channel per entry (presumably a background class; TODO confirm).
        List: any sized sequence. Only its length is used, as the number
            of predictions per spatial location (presumably anchor/default
            boxes; confirm with the caller).
        Input: example input tensor. It is kept as a fallback for
            ``forward()`` when no input is passed. Its ``size(1)`` is the
            default for ``in_channels``.
        in_channels: number of channels produced by ``self.Base``. The heads
            consume the backbone's output, not the raw input, so this must
            match the backbone's feature channels. It defaults to
            ``Input.size(1)`` for backward compatibility. That default is
            only correct if the backbone preserves the channel count.
    """

    def __init__(self, n_classes, List, Input, in_channels=None):
        super().__init__()
        self.Input = Input
        self.n_classes = n_classes
        self.Base = VGG16()
        if in_channels is None:
            # Original behaviour: size the heads from the raw input's channels.
            in_channels = self.Input.size(1)
        n_boxes = len(List)
        self.Loc = nn.Conv2d(in_channels, n_boxes * 4, 3, padding=1)
        self.Conf = nn.Conv2d(in_channels, n_boxes * (self.n_classes + 1), 3, padding=1)
        self.ConfMap = nn.Conv2d(in_channels, n_boxes, 3, padding=1)

    def forward(self, x=None):
        """Run the backbone, then the localisation head.

        Args:
            x: input batch. If omitted, the ``Input`` tensor given at
                construction is used. This keeps ``net()`` working as
                before while also allowing the usual ``net(x)``.

        Returns:
            The output of ``self.Loc`` applied to ``self.Base(x)``.
        """
        if x is None:
            x = self.Input
        features = self.Base(x)
        return self.Loc(features)
I am not sure why I get this error when I try to use it:
object is not iterable
Also, I am new to classes, so I'm not sure whether what I did in `__init__` is allowed.
If it is, how should I call my class?