class TNet(nn.Module):
    """Predict a ``number x number`` alignment matrix from a point/feature set.

    Architecture: shared per-point layers (1x1 ``Conv1d`` + BatchNorm + ReLU)
    -> max pool over the point dimension -> fully connected layers producing
    ``number * number`` values, to which the identity matrix is added so an
    all-zero last layer yields exactly the identity transform.

    Args:
        number: Number of input channels; also the side length of the
            predicted square matrix (e.g. 3 for an input transform, 64 for a
            feature transform). Any positive integer is supported.
    """

    def __init__(self, number=3):
        super(TNet, self).__init__()
        self.number = number
        self.conv1 = torch.nn.Conv1d(self.number, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        # One output per matrix entry, for any ``number`` (9 for 3, 4096 for 64).
        self.fc3 = nn.Linear(256, self.number * self.number)
        # NOTE(review): relu, bn4 and bn5 are registered but never used in
        # forward(); kept so module attributes / saved state_dicts stay
        # compatible. Decide whether fc1/fc2 were meant to go through bn4/bn5.
        self.relu = nn.ReLU()
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.bn4 = nn.BatchNorm1d(512, affine=False)
        self.bn5 = nn.BatchNorm1d(256, affine=False)

    def forward(self, x):
        """Return the predicted transform for ``x``.

        Args:
            x: Tensor of shape ``(batch, number, num_points)`` (the channel
                count must match ``conv1``'s input channels).

        Returns:
            Tensor of shape ``(batch, number, number)``.
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max pool over the point dimension -> (batch, 1024).
        x = torch.max(x, 2, keepdim=True)[0].view(-1, 1024)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        # Add the flattened identity, created directly on x's device/dtype
        # and broadcast over the batch.
        # NOTE(review): the 64-channel path used to call generate_dim(); this
        # assumes that helper produced the flattened 64x64 identity — confirm.
        iden = torch.eye(self.number, dtype=x.dtype, device=x.device).view(1, -1)
        x = x + iden
        return x.view(-1, self.number, self.number)
class Mnet(nn.Module):
    """PointNet-style encoder with a 3x3 input transform and a 64x64 feature transform.

    ``forward(x)`` takes ``x`` of shape ``(batch, 3, num_points)`` and returns a
    tuple:
        - the max-pooled 1024-d global feature, repeated to every point and
          concatenated with the 64-d per-point features:
          ``(batch, 1024 + 64, num_points)``;
        - the ``(batch, 64, 64)`` feature transform predicted by ``tnet2``.
    """

    def __init__(self):
        # This must be spelled ``__init__``: a misspelled ``__init`` is never
        # invoked, so tnet1/tnet2/conv*/bn* would never be registered and
        # forward() fails with "'Mnet' object has no attribute 'tnet1'".
        super(Mnet, self).__init__()
        self.tnet1 = TNet()
        self.tnet2 = TNet(64)
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)

    def forward(self, x):
        num_points = x.size(2)
        # Input transform: (B, 3, N) -> (B, N, 3) @ (B, 3, 3) -> back to (B, 3, N).
        trans_in = self.tnet1(x)
        x = torch.bmm(x.transpose(2, 1), trans_in).transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        # Feature transform on the 64-channel per-point features.
        trans_feat = self.tnet2(x)
        x = torch.bmm(x.transpose(2, 1), trans_feat).transpose(2, 1)
        point_feat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        # Max pool over points, then tile the global feature to every point.
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024, 1).repeat(1, 1, num_points)
        return torch.cat([x, point_feat], 1), trans_feat
The error said that Mnet has no attribute 'tnet1'.