Code works on one GPU but raises a "gradient computation" error under DDP

I am running a mix of code (my own, code from a Medium article, ChatGPT, etc.) to train a SimCLR model. The code runs without any errors on a single GPU, but raises the following error when running on multiple GPUs.

“RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [2048]] is at version 3; expected version 2 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!”

Can anyone help me find the source of this error in my code? I can't see where the in-place operation is coming from, especially since everything works on a single GPU. The file I run for DDP training is below; I launched it with the command torchrun /path_to_file.
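My general understanding of the message is that a tensor autograd saved during the forward pass was modified in place before backward ran, so its version counter no longer matches. A minimal, self-contained sketch (nothing to do with my model, just the general pattern as I understand it) that raises the same class of error:

import torch

# Toy example only: a tensor saved for the backward pass is modified in place
# afterwards, so its version counter no longer matches what autograd recorded.
w = torch.randn(2048, requires_grad=True)
loss = (w * w).sum()   # multiplication saves w for backward (w is at version 0)
with torch.no_grad():
    w.add_(1.0)        # in-place update bumps w to version 1
loss.backward()        # RuntimeError: ... modified by an inplace operation

What I can't work out is which tensor in my script is being modified like this, and why it only happens under DDP.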


import os
import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import Dataset
from torch.utils.data.distributed import DistributedSampler
from torchvision import models, transforms

class PairDataset(Dataset):
    def __init__(self, image_folder_path, dataset,transform):
        self.image_folder_path=f"{image_folder_path}"#/{dataset}
        self.transform=transform
        self.files_list = [x for x in os.listdir(self.image_folder_path) if '._' not in x and 'DS' not in x
                          and '.jpg' in x]

    def __len__(self):
        return len(self.files_list)

    def __getitem__(self, idx):
        file_name = self.files_list[idx]
        img=Image.open(f"{self.image_folder_path}/{file_name}")
        img1 = self.transform(img)
        img2 = self.transform(img)
        return img1, img2

def get_dataloaders(rank, world_size, batch_size, resize_crop, means, stds,image_folder_path, dataset):
    simclr_transform = transforms.Compose([
        transforms.RandomResizedCrop(resize_crop),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.ColorJitter(brightness=0.8, contrast=0.8, saturation=0.8, hue=0.2)
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.GaussianBlur(kernel_size=9),
        transforms.ToTensor(),
        transforms.Normalize(means, stds)])
    image_dataset=PairDataset(image_folder_path=image_folder_path, dataset=dataset,
                              transform=simclr_transform)
    sampler = DistributedSampler(image_dataset, num_replicas=world_size,rank=rank)
    dataloader = torch.utils.data.DataLoader(image_dataset, batch_size=batch_size,
                                              sampler = sampler)
    return image_dataset, dataloader

class SimCLRResNet(nn.Module):
    def __init__(self, base_model=models.resnet50, out_dim=128):
        super(SimCLRResNet, self).__init__()
        self.encoder = base_model(weights=None)
        self.encoder.fc = nn.Identity()  # Remove classification head

        self.projection_head = nn.Sequential(
            nn.Linear(2048, 512),
            nn.ReLU(),
            nn.Linear(512, out_dim)
        )

    def forward(self, x):
        features = self.encoder(x)
        projections = self.projection_head(features)
        return projections

class SimCLR_Loss(nn.Module):
    def __init__(self, batch_size, temperature):
        super().__init__()
        self.batch_size = batch_size
        self.temperature = temperature

        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def mask_correlated_samples(self, batch_size):
        N = 2 * batch_size
        mask = torch.ones((N, N), dtype=bool)
        mask = mask.fill_diagonal_(0)
        
        for i in range(batch_size):
            mask[i, batch_size + i] = 0
            mask[batch_size + i, i] = 0
        return mask

    def forward(self, z_i, z_j):

        N = 2 * self.batch_size

        z = torch.cat((z_i, z_j), dim=0)

        sim = self.similarity_f(z.unsqueeze(1), z.unsqueeze(0)) / self.temperature

        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)
        
        # We have 2N samples, but with Distributed training every GPU gets N examples too, resulting in: 2xNxN
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[self.mask].reshape(N, -1)
        
        #SIMCLR
        labels = torch.from_numpy(np.array([0]*N)).reshape(-1).to(positive_samples.device).long() #.float()
        
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss_avg = loss/N
        
        return loss_avg

def setup(rank, world_size):
    # Initialize the process group
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

def main(rank, world_size, image_folder_path = '/home/jupyter/AP_Data/VT_R4_pieces', temperature=0.5,
        batch_size=64,resize_crop=224, means =[0.485, 0.456, 0.406], stds=[0.229, 0.224, 0.225],
        learning_rate = 0.0003, num_epochs=100):
    setup(rank, world_size)
    torch.cuda.empty_cache()
    criterion = SimCLR_Loss(batch_size = batch_size, temperature = temperature)
    criterion.to(rank)
    image_dataset, dataloader = get_dataloaders(rank=rank, world_size=world_size, batch_size=batch_size,
                                                 resize_crop=resize_crop, means=means, stds=stds,
                                                 image_folder_path=image_folder_path, dataset='train')
    model = SimCLRResNet()
    model = model.to(rank)
    model=DDP(model, device_ids = [rank])
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    for epoch in range(num_epochs):
        model.train()
        total_loss = 0
        for i, (img1,img2) in enumerate(dataloader):
            if i%20==0:
                print('.',end='')
            img1=img1.to(rank)
            img2=img2.to(rank)
            
            optimizer.zero_grad()
            z1=model(img1)
            z2=model(img2)

            loss = criterion(z1, z2)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {total_loss / len(dataloader):.4f}")
    torch.save(model.module.encoder.state_dict(), "simclr_resnet50.pth")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = torch.cuda.device_count()
    print(f"Using {world_size} GPUs for training.")
    mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)
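
If it helps narrow things down, I can rerun with anomaly detection enabled; this was not turned on in the run that produced the error above, but as I understand it autograd then reports the traceback of the forward call whose backward failed, which should show where the modified tensor was used:

import torch

# Debugging sketch only (not part of the training run above): with anomaly
# detection on, the RuntimeError is accompanied by the traceback of the
# forward-pass call whose gradient computation failed.
torch.autograd.set_detect_anomaly(True)

# ...then the same training step as in main():
#     z1 = model(img1)
#     z2 = model(img2)
#     loss = criterion(z1, z2)
#     loss.backward()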