I am currently trying to implement a UNet-like architecture.
The tensors I am concatenating have the same shape and are both of the class torch.Tensor.
However, the operation
x = torch.cat((x, enc_out4), dim=1)
returns a tuple of length 1 containing the concatenated tensor, rather than the tensor itself.
I am using torch 1.8.1.
Any help would be greatly appreciated.
Code of model:
class TestModel(nn.Module):
    """UNet-like 3D encoder/decoder with skip connections.

    Four ``Conv3d`` encoder stages feed a fully connected bottleneck; four
    ``ConvTranspose3d`` decoder stages then upsample again, each one
    consuming the channel-wise concatenation of the previous decoder output
    and the encoder output of matching resolution.

    The bottleneck flattens the last encoder output and feeds it to a
    ``Linear`` layer with ``num_features * 8`` inputs, so the encoder must
    reduce the spatial extent to 1x1x1. With the strides and kernel sizes
    used here a 32x32x32 input volume satisfies this.

    Args:
        num_features: Number of channels produced by the first encoder
            stage; later stages use multiples of it. Defaults to 80.
    """

    def __init__(self, num_features: int = 80) -> None:
        super().__init__()
        self.num_features = num_features
        # Channel count of the deepest encoder stage. It is also the width
        # of the bottleneck, because the spatial extent there is 1x1x1.
        bottleneck_features = num_features * 8

        # NOTE: the Sequential layouts below are kept exactly as they were
        # so that parameter names (state_dict keys) do not change.
        self.encoder1 = nn.Sequential(
            nn.Conv3d(in_channels=2, out_channels=num_features,
                      kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.encoder2 = nn.Sequential(
            nn.Conv3d(in_channels=num_features, out_channels=num_features * 2,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_features=num_features * 2),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.encoder3 = nn.Sequential(
            nn.Conv3d(in_channels=num_features * 2,
                      out_channels=num_features * 4,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_features=num_features * 4),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.encoder4 = nn.Sequential(
            nn.Conv3d(in_channels=num_features * 4,
                      out_channels=bottleneck_features,
                      kernel_size=4, stride=1),
            nn.BatchNorm3d(num_features=bottleneck_features),
            nn.LeakyReLU(negative_slope=0.2),
        )

        self.bottleneck = nn.Sequential(
            nn.Linear(in_features=bottleneck_features,
                      out_features=bottleneck_features),
            nn.ReLU(),
            nn.Linear(in_features=bottleneck_features,
                      out_features=bottleneck_features),
            nn.ReLU(),
        )

        # Every decoder stage takes twice the channels of the matching
        # encoder stage because of the skip-connection concatenation.
        self.decoder1 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=bottleneck_features * 2,
                               out_channels=num_features * 4,
                               kernel_size=4, stride=1),
            nn.BatchNorm3d(num_features=num_features * 4),
            nn.ReLU(),
        )
        self.decoder2 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=num_features * 4 * 2,
                               out_channels=num_features * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_features=num_features * 2),
            nn.ReLU(),
        )
        self.decoder3 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=num_features * 2 * 2,
                               out_channels=num_features,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm3d(num_features=num_features),
            nn.ReLU(),
        )
        self.decoder4 = nn.Sequential(
            nn.ConvTranspose3d(in_channels=num_features * 2, out_channels=1,
                               kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder, bottleneck and decoder on ``x``.

        Args:
            x: Input of shape ``(batch, 2, D, H, W)``. The spatial size must
                be such that the encoder output is 1x1x1 (e.g. 32x32x32).

        Returns:
            ``log(1 + y)`` where ``y`` is the single-channel decoder output
            with its channel dimension squeezed away, i.e. a tensor of
            shape ``(batch, D, H, W)``.
        """
        batch_size = x.shape[0]

        # Encode, keeping every stage's output for the skip connections.
        enc_out1 = self.encoder1(x)
        enc_out2 = self.encoder2(enc_out1)
        enc_out3 = self.encoder3(enc_out2)
        enc_out4 = self.encoder4(enc_out3)

        # Bottleneck operates on a flat (batch, features) tensor; restore
        # the three singleton spatial dims afterwards.
        x = enc_out4.view(batch_size, -1)
        x = self.bottleneck(x)
        x = x.view(x.shape[0], x.shape[1], 1, 1, 1)

        # Decode. No trailing comma after torch.cat(...): `y = expr,` builds
        # a 1-tuple in Python, which is not a valid input for the next layer.
        x = torch.cat((x, enc_out4), dim=1)
        x = self.decoder1(x)
        x = torch.cat((x, enc_out3), dim=1)
        x = self.decoder2(x)
        x = torch.cat((x, enc_out2), dim=1)
        x = self.decoder3(x)
        x = torch.cat((x, enc_out1), dim=1)
        x = self.decoder4(x)

        x = torch.squeeze(x, dim=1)
        x = torch.log(torch.add(x, 1.0))
        return x