Can I split my batch in the forward method?

Hey, I am creating a DeepFake model in PyTorch. My model has the following architecture:
x := input-image
x -> encoder -> if: label_of_x == 0: -> decoder1, elif label_of_x == 1: -> decoder2

It's a conditional flow in which all images pass through the same encoder but then go through different decoders depending on their class. Until now I have used batch_size = 1 because I don't know if/how I can split the batch into two smaller batches (each containing all the input images of the same class) and then merge them back together for backpropagation. Is this possible? And if so: how?
I really appreciate any help!

IMO, yes, you can do it by adding another flag (the class label) to your dataset (CSV file) and selecting images based on it. Also make sure that each batch contains a similar number of images of each class. Then, in the forward pass, you can easily route the images accordingly.
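For example, the Dataset could simply return the class label together with each image (a rough sketch; the class name FaceDataset and the CSV columns filename/label are only assumptions about your data layout):

import pandas as pd
from PIL import Image
from torch.utils.data import Dataset


class FaceDataset(Dataset):
    """Returns (image, class_label) so the forward pass can route samples."""

    def __init__(self, csv_file, transform=None):
        # the CSV is assumed to have 'filename' and 'label' (0 or 1) columns
        self.frame = pd.read_csv(csv_file)
        self.transform = transform

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        row = self.frame.iloc[idx]
        image = Image.open(row["filename"]).convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, int(row["label"])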

So do you think this is possible (pseudo-code):

...
def _split_classes(self, batch, targets):
    # boolean-mask indexing keeps the sub-batches attached to the autograd
    # graph; rebuilding them with torch.Tensor() would detach them and break
    # backpropagation
    mask = targets == 1
    return batch[mask], batch[~mask]


def forward(self, x, targets):
    x = self.encoder(x)
    x_a, x_b = self._split_classes(x, targets)

    x_a = self.decoderA(x_a)
    x_b = self.decoderB(x_b)

    x = torch.cat((x_a, x_b), dim=0)
    return x

Yep, something very similar to that pseudo-code would be my initial attempt. Note that at test time you won't have the targets available.

Okay, thank you, I'll try that!

Wait, but how is backpropagation possible then?

Hey, apologies for the delayed response. Why do you think backprop won't happen or will have issues?

We can always break a bigger batch up into smaller ones and still treat all of those smaller ones as a single batch. Similarly, once we have split our batch, different inputs pass through different parts of the model; after the forward pass, the gradients from the two decoder branches are combined and sent back together through everything before the splitting point (here, the shared encoder).
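If you want to convince yourself, a toy check along these lines should show non-zero encoder gradients after the split-and-merge forward pass (the layer sizes, random data, and squared-error loss are made up purely for illustration):

import torch
import torch.nn as nn

encoder = nn.Linear(8, 4)
decoderA = nn.Linear(4, 8)
decoderB = nn.Linear(4, 8)

x = torch.randn(6, 8)
targets = torch.tensor([1, 0, 1, 1, 0, 0])

z = encoder(x)                      # shared encoder for the whole batch
mask = targets == 1
out_a = decoderA(z[mask])           # class-1 samples
out_b = decoderB(z[~mask])          # class-0 samples
out = torch.cat((out_a, out_b), dim=0)

loss = out.pow(2).mean()
loss.backward()

# gradients from both decoder branches are accumulated in the shared encoder
print(encoder.weight.grad.abs().sum())  # non-zero -> backprop went through the split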

Also, you might get slightly worse results, but that depends on your batch size as well.

Also you might want to look into collate_fn!
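For instance, a custom collate_fn could hand the two per-class sub-batches to the model already split (just a sketch; it assumes your Dataset yields (image_tensor, label) pairs, and the function name is made up):

import torch
from torch.utils.data import DataLoader


def split_by_class_collate(samples):
    # samples is a list of (image_tensor, label) pairs from the Dataset
    images = torch.stack([img for img, _ in samples])
    labels = torch.tensor([label for _, label in samples])
    # return the per-class sub-batches directly, if you prefer to do the
    # splitting here instead of inside forward()
    mask = labels == 1
    return images[mask], images[~mask]


# hypothetical usage:
# loader = DataLoader(dataset, batch_size=16, shuffle=True,
#                     collate_fn=split_by_class_collate)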