Mini batches with DataLoader and a 3D input

I have been struggling to create and manage batches for a 3D tensor. I have used DataLoader before to create batches from a 1D tensor. However, in my current research, I need to create batches out of a tensor with shape (1024, 1024, 2).

I created a custom dataset to use as the input for PyTorch’s DataLoader. I wrote the following for the 1D case:

from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self, x_tensor, y_tensor):
        self.xdomain = x_tensor
        self.ydomain = y_tensor

    def __getitem__(self, index):
        return (self.xdomain[index], self.ydomain[index])

    def __len__(self):
        return len(self.xdomain)

It works pretty well. However, I realized that this doesn’t work for tensors x_tensor and y_tensor of shape (1024, 1024, 2) and (1024, 1024, 1) respectively. I understand that I have to change the __getitem__ and __len__ methods so that the tensors can be divided into batches.

I tried many things, but one approach I know could work is to flatten these tensors into shapes (1024*1024, 2) and (1024*1024, 1). However, I would then have to change not only my NN definition but most of my code.
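
For reference, a minimal sketch of that flattening approach (assuming the trailing dimension holds the features, so every (row, col) position becomes one sample):

import torch

x_tensor = torch.randn(1024, 1024, 2)
y_tensor = torch.randn(1024, 1024, 1)

# Collapse the first two dimensions so each (row, col) pair is one sample
x_flat = x_tensor.reshape(-1, 2)   # shape: (1024*1024, 2)
y_flat = y_tensor.reshape(-1, 1)   # shape: (1024*1024, 1)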

So I want to keep my code as is and try to understand how to write these functions, if that is possible. What I understand of these functions is:
__len__ so that len(dataset) returns the size of the dataset.
__getitem__ to support indexing such that dataset[i] can be used to get the i-th sample.

With this knowledge, I created a class that finds the indices of the first two dimensions (to find the i-th sample). However, this made the NN input (1024*1024, 2) and the output (1024*1024, 1), and I want them to be (1024, 1024, 2) and (1024, 1024, 1).
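
The class itself isn’t shown above, but a sketch of what such an index-mapping dataset might look like (the class name and attribute layout are hypothetical) illustrates why the effective dataset becomes 1024*1024 samples of shapes (2,) and (1,):

import torch
from torch.utils.data import Dataset

class FlatIndexDataset(Dataset):  # hypothetical name
    def __init__(self, x_tensor, y_tensor):
        self.xdomain = x_tensor           # shape: (1024, 1024, 2)
        self.ydomain = y_tensor           # shape: (1024, 1024, 1)
        self.ncols = x_tensor.shape[1]

    def __getitem__(self, index):
        # Map a flat index to a (row, col) position in the first two dimensions
        row, col = divmod(index, self.ncols)
        return (self.xdomain[row, col], self.ydomain[row, col])

    def __len__(self):
        # 1024 * 1024 samples in total, each of shape (2,) and (1,)
        return self.xdomain.shape[0] * self.xdomain.shape[1]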

If someone with a better understanding of DataLoader and mini-batches could explain what I am missing, that would be amazing. And first of all, is this even possible?

Thanks for reading this, and sorry if this question is too basic. I hope it is clear.

If you don’t understand my question, please ask; it will help me improve.

I don’t completely understand the issue given the posted shapes, since I get the expected output of [batch_size, 1024, 2] and [batch_size, 1024, 1], as seen here:

import torch
from torch.utils.data import Dataset, DataLoader

class CustomDataset(Dataset):
    def __init__(self, x_tensor, y_tensor):
        self.xdomain = x_tensor
        self.ydomain = y_tensor
        
    def __getitem__(self, index):
        return (self.xdomain[index], self.ydomain[index])
    
    def __len__(self):
        return len(self.xdomain)
    
    
x_tensor = torch.randn(1024, 1024, 2)
y_tensor = torch.randn(1024, 1024, 1)
dataset = CustomDataset(x_tensor, y_tensor)
loader = DataLoader(dataset, batch_size=2)

for x, y in loader:
    print(x.shape, y.shape)
> torch.Size([2, 1024, 2]) torch.Size([2, 1024, 1])
  torch.Size([2, 1024, 2]) torch.Size([2, 1024, 1])
  torch.Size([2, 1024, 2]) torch.Size([2, 1024, 1])
  [...]
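
If the goal is instead for each batch to contain the full (1024, 1024, 2) grid (an assumption about the intent, not something stated explicitly), one option, continuing from the snippet above, is to add a leading sample dimension so the dataset holds the grid as a single sample:

# Treat the whole grid as one sample via a leading sample dimension
x_single = x_tensor.unsqueeze(0)  # shape: (1, 1024, 1024, 2)
y_single = y_tensor.unsqueeze(0)  # shape: (1, 1024, 1024, 1)

dataset = CustomDataset(x_single, y_single)
loader = DataLoader(dataset, batch_size=1)

for x, y in loader:
    print(x.shape, y.shape)
> torch.Size([1, 1024, 1024, 2]) torch.Size([1, 1024, 1024, 1])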