# Decoder network: a linear layer, a reshape, then a stack of transposed
# convolutions upsampling to an image; Sigmoid squashes output to [0, 1].
# NOTE(review): the ConvTranspose2d() calls are placeholders — real channel/
# kernel arguments are required; UnFlatten is defined elsewhere in the project.
nn.Sequential(
    nn.Linear(in_features, out_features),  # bug fix: missing commas between layers were a SyntaxError
    UnFlatten(),  # reshape the output of previous layers so that its input shape to the next layers is proper
    nn.ConvTranspose2d(),
    nn.ReLU(),
    nn.ConvTranspose2d(),
    nn.ReLU(),
    nn.ConvTranspose2d(),
    nn.ReLU(),
    nn.ConvTranspose2d(),
    nn.Sigmoid(),
)
In the regular case, it takes data with X.shape = (batch_size, in_features) and outputs a batch of images.
But right now I have an input of shape X.shape = (batch_size, another_features, in_features), and I want to output a batch of batches of images.
Implemented with a for loop, it goes like this:
# Loop-based version: run each of the `another_features` slices through the
# same network NN, then stack the per-slice outputs.
result = []
for i in range(another_features):
    # feed each slice of another_features through the shared network
    result.append(NN(X[:, i, :]))  # bug fix: original appended to undefined `temp` instead of `result`
# leads to a 5D tensor of shape (another_features, batch_size, out_channels, out_features, out_features)
result = torch.stack(result)
# then reshape into (batch_size, another_features, out_channels, out_features, out_features)
# NOTE(review): torch.stack(result, dim=1) would produce that layout directly,
# skipping the extra permute/reshape step.
I’m wondering if there is a more efficient way, because technically the process isn’t sequential and could be parallelized.
I don’t understand at all.
Which architecture are you using?
You can probably fold everything into the batch dimension, then undo that reshape before the 3D processing.
i.e.
# Example: fold the time/feature axis into the batch axis so a 2D module can
# process every slice at once, then restore the 5D layout for 3D processing.
input = torch.rand(1, 5, 3, 224, 224)  # a 5D tensor of shape (B, T, C, H, W)
# fold T into the batch dimension: (B*T, C, H, W)
input = input.view(-1, 3, 224, 224)
# apply the 2D processing to all B*T slices in one call
output = conv(input)
# undo the reshaping; bug fix: C', H', W' was invalid syntax —
# take the new channel/height/width directly from the conv output
input3d = output.view(1, 5, *output.shape[1:])
output3d = conv3d(input3d)
class Conv2DExt(nn.Module):
    """ConvTranspose2d wrapper that also accepts 5D (B, F, C, H, W) input.

    A 5D input is flattened to (B*F, C, H, W) so the underlying 2D transposed
    convolution can process every F-slice in one batched call.  Constructor
    arguments are forwarded unchanged to ``nn.ConvTranspose2d``.

    NOTE: the output is left 4D — the caller must view it back to
    (B, F, C', H', W') after this module.
    """

    def __init__(self, *args, **kwargs):
        # bug fix: super().__init__(self) raises TypeError — Module.__init__
        # takes no positional argument beyond the implicit self
        super().__init__()
        self.conv2d = nn.ConvTranspose2d(*args, **kwargs)

    def forward(self, input, output_size=None):
        # type: (Tensor, Optional[List[int]]) -> Tensor
        """Apply the transposed conv; folds a 5D input into the batch dim."""
        # bug fix: Tensor has no .ndimensions() — use .dim()
        if input.dim() == 5:
            B, F, C, H, W = input.shape
            input = input.view(-1, C, H, W)  # fold F into the batch axis
        return self.conv2d(input, output_size=output_size)
# Same decoder as before, with the first transposed conv swapped for the
# 5D-aware Conv2DExt wrapper defined above.
# NOTE(review): the remaining layer calls are placeholders — real channel/
# kernel arguments are required; UnFlatten is defined elsewhere in the project.
nn.Sequential(
    nn.Linear(in_features, out_features),  # bug fix: missing commas between layers were a SyntaxError
    UnFlatten(),  # reshape the output of previous layers so that its input shape to the next layers is proper
    Conv2DExt(),  # bug fix: Conv2DExt is the custom class above, not a member of torch.nn
    nn.ReLU(),
    nn.ConvTranspose2d(),
    nn.ReLU(),
    nn.ConvTranspose2d(),
    nn.ReLU(),
    nn.ConvTranspose2d(),
    nn.Sigmoid(),
)
# You have to reshape in the forward after this module.