How does one do data transforms or a collate_fn on the GPU?

I have to do some computations on the GPU before passing my data to my models. Let's say, for simplicity, that I compute the mean (1st moment) and the std (square root of the 2nd central moment). How do I make sure this is done on the GPU during the dataset/dataloader processing?

(I am assuming that doing such computations in, say, NumPy is slower than doing them directly in PyTorch on the GPU.)
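A minimal sketch of the computation itself, assuming each sample arrives as a NumPy array: `torch.as_tensor` copies it onto the chosen device once, and `mean()`/`std()` then run there (on the GPU if one is present, otherwise on the CPU). The helper name `moments_on_device` is hypothetical.

```python
import numpy as np
import torch

def moments_on_device(sample, device=None):
    """Hypothetical helper: return (mean, std) of a NumPy array, computed by PyTorch on `device`.

    Falls back to the CPU when no GPU is available.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Copy the array to the device once; the reductions below run there.
    x = torch.as_tensor(sample, dtype=torch.float32, device=device)
    # With no `dim`, both reduce over all elements; std() uses the unbiased
    # (N - 1) estimator by default, matching np.std(..., ddof=1).
    return x.mean(), x.std()

sample = np.random.rand(3, 6, 7, 8)
mean, std = moments_on_device(sample)
print(mean.device, mean.item(), std.item())
```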

For example, here is some code I just made up:

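```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def get_data():
    D_emb1 = 6
    D_emb2 = 7
    D_emb3 = 8
    # Samples share trailing dims but differ in their first dim (3, 5, 2, 1),
    # so the default collate_fn could not stack them into one tensor.
    a = np.random.rand(3, D_emb1, D_emb2, D_emb3)
    b = np.random.rand(5, D_emb1, D_emb2, D_emb3)
    c = np.random.rand(2, D_emb1, D_emb2, D_emb3)
    d = np.random.rand(1, D_emb1, D_emb2, D_emb3)
    return a, b, c, d

class DataSet(Dataset):
    """Minimal dataset wrapping a list of numpy arrays."""
    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn_gpu(batch):
    """Collate function that maps data samples to Tensor then calls the ops on GPU."""
    batch_size = len(batch)
    trailing_dims = batch[0].shape[1:]
    #new_batch = torch.zeros(batch_size,*trailing_dims)
    new_batch = torch.zeros(batch_size, device=device)
    for i in range(batch_size):
        # Copy the sample to the device and reduce it to its std there.
        new_batch[i] = torch.as_tensor(batch[i], dtype=torch.float32, device=device).std()
    return new_batch

def main():
    transform = None
    #transform, trainset, trainloader, testset, testloader = get_cifar10(collate_fn)
    ## generate fake data
    a, b, c, d = get_data()
    data = [a, b, c, d]
    ## create data set
    dataset = DataSet(data)
    ## create data loader
    # num_workers=0 keeps collate_fn in the main process; CUDA calls inside
    # forked worker processes fail once the parent has initialized CUDA.
    trainloader = DataLoader(dataset, batch_size=3, shuffle=True, num_workers=0, collate_fn=collate_fn_gpu)
    for batch in trainloader:
        # batch holds one std per sample, already on `device`
        print(batch)

if __name__ == '__main__':
    main()
```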
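A common alternative pattern, sketched here under the assumption that the per-sample std is all the model needs: keep `collate_fn` on the CPU (it only converts samples to tensors and returns a list, since their first dimensions differ) and move each batch to the GPU inside the training loop. Because the worker processes never touch CUDA in this layout, it also works with `num_workers > 0`. The names `collate_fn_cpu` and `batch_stds` are hypothetical.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_fn_cpu(batch):
    """Hypothetical CPU-only collate: convert each array to a float32 tensor.

    Samples have different first dimensions, so they are returned as a list
    instead of being stacked into one tensor.
    """
    return [torch.as_tensor(x, dtype=torch.float32) for x in batch]

def batch_stds(samples):
    """Hypothetical helper: move each sample to `device` and compute its std there."""
    return torch.stack([s.to(device, non_blocking=True).std() for s in samples])

# A plain list of arrays is a valid map-style dataset (it has __len__/__getitem__).
data = [np.random.rand(n, 6, 7, 8) for n in (3, 5, 2, 1)]
loader = DataLoader(data, batch_size=3, shuffle=True, num_workers=0,
                    collate_fn=collate_fn_cpu, pin_memory=torch.cuda.is_available())
for samples in loader:
    stds = batch_stds(samples)  # one std per sample, on `device`
    print(stds.shape, stds.device)
```

The design choice is to keep the DataLoader device-agnostic: the CPU side only prepares tensors, and every CUDA op happens in the main process where the model lives.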

By the way, how do I make sure these ops are not trainable (i.e., not tracked by autograd)?
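A sketch of one way to keep such ops out of autograd, assuming the statistics are only used as inputs to the model: tensors built from NumPy arrays have `requires_grad=False` by default, so nothing is tracked unless a tensor that requires grad is involved. Wrapping the computation in `torch.no_grad()` makes that explicit and also covers tensors that do require grad; `detach()` is the per-tensor equivalent. Mean and std have no parameters of their own, so there is nothing to train either way.

```python
import torch

# Stand-in for a tensor that does require grad (e.g. an intermediate activation).
x = torch.randn(4, 6, 7, 8, requires_grad=True)

with torch.no_grad():
    # Ops inside this block are not recorded in the autograd graph.
    stats = torch.stack([x.mean(), x.std()])

print(stats.requires_grad)  # False

# Alternatively, detach() returns a tensor cut off from the graph.
stats_detached = torch.stack([x.mean(), x.std()]).detach()
print(stats_detached.requires_grad)  # False
```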