Splitting batch/tensor

I have a set of 256 images that gets augmented to a size of 512.

Then, in a custom loss function, I unpack the batch:
images, list = batch
where images has shape (512, 3, 96, 96), i.e. (batch size, channels, height, width).

After the convolution and normalization layers, size() returns [512, 128].
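The model itself is not shown, so the sketch below uses a hypothetical encoder (conv, ReLU, global average pooling, linear layer) and assumes "normalization" means per-row L2 normalization via F.normalize. It only illustrates one point: layers like these keep the batch dimension first, so row i of the [512, 128] output belongs to image i. If the normalization is BatchNorm, the row values depend on batch statistics in training mode, but the row order still follows the image order.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# HYPOTHETICAL encoder: the actual model from the question is not shown.
encoder = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),  # [N, 16, 48, 48]
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),                               # [N, 16, 1, 1]
    nn.Flatten(),                                          # [N, 16]
    nn.Linear(16, 128),                                    # [N, 128]
)

images = torch.rand(512, 3, 96, 96)  # (batch, channels, height, width)

with torch.no_grad():
    # ASSUMPTION: "normalization" = L2-normalizing each feature vector (row).
    features = F.normalize(encoder(images), dim=1)
    # Running only the first image gives (numerically) the same first row,
    # so features[i] corresponds to images[i].
    single = F.normalize(encoder(images[:1]), dim=1)

print(features.shape)  # torch.Size([512, 128])
```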

I would like to separate the original images from the augmented ones so I can compute a custom loss, and I was wondering how to go about this.

Is it better to split the tensor, or to split the images in the list? I tried split(), but it gave me tensors of size [512, 64], and I don't know if this is the right way to split them.
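A [512, 64] result suggests the split ran along dim 1 (the feature dimension), which halves every feature vector instead of separating samples. A minimal sketch of splitting along dim 0 instead, under the ASSUMPTION that the first 256 rows come from the original images and the last 256 from their augmented copies in the same order (the question does not say how the batch is ordered). The MSE loss at the end is a hypothetical placeholder, not the custom loss from the question.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in for the [512, 128] features from the question.
# ASSUMPTION: rows 0..255 = original images, rows 256..511 = their augmented copies.
features = torch.randn(512, 128)

# Splitting along dim=1 cuts each feature vector in half -> two [512, 64] tensors.
left, right = torch.split(features, 64, dim=1)

# Splitting along dim=0 (the batch dimension) separates the samples instead.
originals, augmented = torch.split(features, 256, dim=0)
# Equivalent: originals, augmented = features.chunk(2, dim=0)

# If the batch were interleaved (original, augmented, original, ...),
# strided slicing would be needed instead:
# originals, augmented = features[0::2], features[1::2]

# HYPOTHETICAL loss: pull each original's features toward its augmented copy.
loss = F.mse_loss(originals, augmented)
print(left.shape, originals.shape, augmented.shape)
```

Whether dim 0 halves or strided slicing is correct depends entirely on how the augmented images were added to the batch, so it is worth checking that ordering before writing the loss.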

If I understand your use case correctly, you are upsampling the original images of shape [512, 3, 96, 96] into [512, 3, 512, 512] and creating an activation of shape [512, ?, 512, 128] after the convolution and normalization layers?

It sounds as if you are concatenating the augmented images into the larger spatial size instead of upsampling them. Could you describe this augmentation in more detail, please?
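To make the two readings concrete, here is a small sketch contrasting them. The horizontal flip is a HYPOTHETICAL augmentation chosen for illustration; the actual augmentation in the question is unknown. Only a few images are upsampled to keep memory small.

```python
import torch
import torch.nn.functional as F

originals = torch.rand(256, 3, 96, 96)
# HYPOTHETICAL augmentation: horizontal flip (reverse the width axis, dim 3).
augmented = torch.flip(originals, dims=[3])

# Reading A: augmented copies appended along the batch dimension.
stacked = torch.cat([originals, augmented], dim=0)
print(stacked.shape)  # torch.Size([512, 3, 96, 96])

# Reading B: images upsampled to a larger spatial size.
# Only 4 images here to save memory; all 256 would give [256, 3, 512, 512].
upsampled = F.interpolate(originals[:4], size=(512, 512),
                          mode="bilinear", align_corners=False)
print(upsampled.shape)  # torch.Size([4, 3, 512, 512])
```

The (512, 3, 96, 96) shape given in the question matches Reading A, in which case the two halves can be recovered by slicing along dim 0.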