Training diffusion autoencoder with multiple GPUs

Hello, I have a diffusion model defined as follows:


import torch
import torchvision
from generative.networks.nets import DiffusionModelUNet  # assumes MONAI Generative Models


class Diffusion_AE(torch.nn.Module):
    def __init__(self, embedding_dimension=64, num_channels=1):
        super().__init__()
        self.unet = DiffusionModelUNet(
            spatial_dims=2,
            in_channels=num_channels,
            out_channels=num_channels,
            num_channels=(128, 256, 256),
            attention_levels=(False, True, True),
            num_res_blocks=1,
            num_head_channels=64,
            with_conditioning=True,
            cross_attention_dim=1,
        )

        self.semantic_encoder = torchvision.models.resnet18()
        # self.semantic_encoder = torchvision.models.efficientnet_v2_s()

        # Accept num_channels input channels and output an embedding_dimension vector
        self.semantic_encoder.conv1 = torch.nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.semantic_encoder.fc = torch.nn.Linear(512, embedding_dimension)

        print(f'num_channels is {num_channels}')

    def forward(self, xt, x_cond, t):
        latent = self.semantic_encoder(x_cond)
        noise_pred = self.unet(x=xt, timesteps=t, context=latent.unsqueeze(2))
        return noise_pred, latent

I then train it like this:

while iteration < max_epochs:  # note: counts iterations (batches), not epochs
    for batch in train_loader:
        iteration += 1
        model.train()
        optimizer.zero_grad(set_to_none=True)
        images = batch["image"].to(device)
        images = images[:, 0:num_channels, :, :]
        noise = torch.randn_like(images)  # already on the same device as images
        # Create one random timestep per sample (the last batch may be smaller than batch_size)
        timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device)
        # Get model prediction
        # Cross attention expects shape [batch size, sequence length, channels]; unsqueeze(2) gives
        # sequence length = latent dimension and channels = 1, matching cross_attention_dim=1
        latent = model.semantic_encoder(images)
        noise_pred = inferer(
            inputs=images, diffusion_model=model.unet, noise=noise, timesteps=timesteps, condition=latent.unsqueeze(2)
        )
        loss = F.mse_loss(noise_pred.float(), noise.float())

        loss.backward()
        optimizer.step()

        iter_loss += loss.item()
        print(f"Iteration {iteration}/{max_epochs} - train Loss {loss.item():.4f}", end="\r")
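The shape comment can be checked in isolation. The sketch below uses a random tensor as a stand-in for the ResNet-18 output (batch size 4 and embedding dimension 64 are illustrative values): `latent.unsqueeze(2)` turns `[batch, embedding]` into `[batch, embedding, 1]`, so the latent dimension becomes the sequence length and each token has a single channel, which matches `cross_attention_dim=1` in the UNet.

```python
import torch

batch_size, embedding_dimension = 4, 64
# Stand-in for semantic_encoder(images): one embedding vector per image
latent = torch.randn(batch_size, embedding_dimension)
# Cross-attention context layout: [batch size, sequence length, channels]
context = latent.unsqueeze(2)
seq_len, channels = context.shape[1], context.shape[2]
print(tuple(context.shape))  # (4, 64, 1)
```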

My question is how I can train this model with multiple GPUs. If I try the standard approach of wrapping the model with model = nn.DataParallel(model), I get an error that the model has no attribute semantic_encoder (in the line latent = model.semantic_encoder(images)). The wrapper returned by nn.DataParallel does not seem to expose the attributes and methods of the original model; perhaps it is only meant for the simpler case. Does anyone know how to use multiple GPUs in my case?

The original model, including all its submodules, is wrapped into the .module attribute and can be accessed through it (e.g. model.module.semantic_encoder). However, DDP (like nn.DataParallel) parallelizes the forward method of the parent module, so you should not call internal modules directly in your training loop; otherwise you would be responsible for the distributed training logic yourself.
A proper approach is to move the submodule execution into the forward method of the parent module and let DDP parallelize the training.
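A minimal sketch of this, with a hypothetical two-part toy model (`TwoPartModel`) standing in for `Diffusion_AE`: the wrapper registers the original model only under the name `module`, so `model.semantic_encoder` is missing while `model.module.semantic_encoder` works. Calling a submodule through `model.module` runs on a single device and bypasses the parallelism, which is why the submodule calls belong in `forward`. On a machine without GPUs, `nn.DataParallel` simply calls the wrapped module, so the sketch also runs there.

```python
import torch
from torch import nn


class TwoPartModel(nn.Module):
    """Hypothetical toy stand-in for Diffusion_AE: an encoder feeding a decoder."""

    def __init__(self):
        super().__init__()
        self.semantic_encoder = nn.Linear(8, 4)
        self.unet = nn.Linear(4, 8)

    def forward(self, x):
        # Everything that should be parallelized happens inside forward
        return self.unet(self.semantic_encoder(x))


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.DataParallel(TwoPartModel()).to(device)

# The wrapper registers the original model under the single name "module"
wrapper_has_encoder = hasattr(model, "semantic_encoder")
encoder = model.module.semantic_encoder  # works, but bypasses DataParallel

# Calling the wrapper itself splits the batch across all visible GPUs
out = model(torch.randn(6, 8, device=device))
print(wrapper_has_encoder, tuple(out.shape))  # False (6, 8)
```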

Hello, thank you for your kind and timely response. I am quite new to PyTorch, but I believe I understand what you are saying: I need to rewrite my Diffusion_AE class so that its forward pass simply returns the output that is fed into the loss, in this case noise_pred. Specifically, it seems these lines need to move into the forward pass:

timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device)
# Get model prediction
# Cross attention expects shape [batch size, sequence length, channels]; unsqueeze(2) gives
# sequence length = latent dimension and channels = 1
latent = model.semantic_encoder(images)
noise_pred = inferer(
    inputs=images, diffusion_model=model.unet, noise=noise, timesteps=timesteps, condition=latent.unsqueeze(2)
)

I am not totally sure how to do this, since I need to pass several arguments to the model beyond the standard input (images), namely noise and timesteps. Doesn't forward typically take only the input? If so, how would I get this information into forward? Thanks very much in advance.

The forward method accepts any arguments you define, so your approach should work.
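As a sketch, with a hypothetical toy model (`ToyDenoiser`, whose noising step is arbitrary and only for illustration): `forward` can declare as many arguments as needed. When the model is wrapped in `nn.DataParallel`, every tensor argument is split along dimension 0, so `x`, `noise` and `timesteps` must share the same batch size; non-tensor arguments such as the `inferer` object are not split, and every replica receives the same Python object.

```python
import torch
from torch import nn


class ToyDenoiser(nn.Module):
    """Hypothetical model whose forward takes several arguments."""

    def __init__(self, channels=1):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x, noise, timesteps, noise_scale=1.0):
        # x, noise and timesteps are all batch-first tensors, so DataParallel
        # can split them consistently along dimension 0
        assert x.shape[0] == noise.shape[0] == timesteps.shape[0]
        scale = noise_scale * timesteps.float().view(-1, 1, 1, 1) / 1000
        return self.conv(x + scale * noise)


model = ToyDenoiser()
x = torch.randn(4, 1, 16, 16)
noise = torch.randn_like(x)
timesteps = torch.randint(0, 1000, (x.shape[0],))
noise_pred = model(x, noise, timesteps, noise_scale=0.5)
print(tuple(noise_pred.shape))  # (4, 1, 16, 16)
```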

Hello, you are correct; I believe I misunderstood another post. I have modified my code as follows and can now train on one GPU. The class is now:


import torch
import torchvision
from generative.networks.nets import DiffusionModelUNet  # assumes MONAI Generative Models


class Diffusion_AE(torch.nn.Module):
    def __init__(self, embedding_dimension=64, num_channels=1):
        super().__init__()
        self.unet = DiffusionModelUNet(
            spatial_dims=2,
            in_channels=num_channels,
            out_channels=num_channels,
            num_channels=(128, 256, 256),
            attention_levels=(False, True, True),
            num_res_blocks=1,
            num_head_channels=64,
            with_conditioning=True,
            cross_attention_dim=1,
        )

        self.semantic_encoder = torchvision.models.resnet18()
        # self.semantic_encoder = torchvision.models.efficientnet_v2_s()

        self.semantic_encoder.conv1 = torch.nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.semantic_encoder.fc = torch.nn.Linear(512, embedding_dimension)

        print(f'num_channels is {num_channels}')

    def forward(self, x, noise, timesteps, inferer):
        # The semantic encoder and the inferer call now both run inside forward
        latent = self.semantic_encoder(x)
        # Cross-attention context: [batch, sequence length = latent dimension, channels = 1]
        noise_pred = inferer(
            inputs=x, diffusion_model=self.unet, noise=noise, timesteps=timesteps, condition=latent.unsqueeze(2)
        )

        # latent = self.semantic_encoder(x_cond)
        # noise_pred = self.unet(x=xt, timesteps=t, context=latent.unsqueeze(2))
        return noise_pred

The training code is now


while iteration < max_epochs:  # note: counts iterations (batches), not epochs
    for batch in train_loader:
        iteration += 1
        model.train()
        optimizer.zero_grad(set_to_none=True)
        images = batch["image"].to(device)
        images = images[:, 0:num_channels, :, :]
        noise = torch.randn_like(images)  # already on the same device as images
        # Create one random timestep per sample (the last batch may be smaller than batch_size)
        timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device)
        # Get model prediction; the semantic encoder and the inferer now run inside forward
        noise_pred = model(images, noise, timesteps, inferer)

        loss = F.mse_loss(noise_pred.float(), noise.float())

        loss.backward()
        optimizer.step()

        iter_loss += loss.item()
        print(f"Iteration {iteration}/{max_epochs} - train Loss {loss.item():.4f}", end="\r")

        if iteration % val_interval == 0:
            model.eval()
            val_iter_loss = 0
            for val_step, val_batch in enumerate(val_loader):
                with torch.no_grad():
                    images = val_batch["image"].to(device)
                    images = images[:, 0:num_channels, :, :]
                    timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (images.shape[0],), device=images.device)
                    noise = torch.randn_like(images)
                    noise_pred = model(images, noise, timesteps, inferer)
                    val_loss = F.mse_loss(noise_pred.float(), noise.float())

                val_iter_loss += val_loss.item()
            iter_loss_list.append(iter_loss / val_interval)
            val_iter_loss_list.append(val_iter_loss / (val_step + 1))
            iterations.append(iteration)
            iter_loss = 0
            print(
                f"Iteration {iteration} - Interval Loss {iter_loss_list[-1]:.4f}, Interval Loss Val {val_iter_loss_list[-1]:.4f}"
            )

I have now added these lines after the model instantiation:

model = Diffusion_AE(embedding_dimension=embedding_dimension, num_channels=num_channels)

if torch.cuda.device_count() > 1:
    print("Let's use", torch.cuda.device_count(), "GPUs!")
    model = nn.DataParallel(model)

model.to(device)

It still works with a single GPU; with several GPUs, however, I now get an error that seems to come from the scheduler not being on the correct device. Am I doing something incorrectly? I get the same error whether device is set to “cuda:0” or “cuda”:


RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:2)

Any ideas with this? Thanks again for your help
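One plausible explanation, offered as an assumption rather than a confirmed diagnosis: `nn.DataParallel` does not split or copy the `inferer` argument, so all replicas share the same inferer, scheduler and schedule tensors while running in parallel threads. If the scheduler moves its schedule tensor to the device of whichever replica called it last, another replica can end up indexing a tensor that lives on a different GPU, which would match the message above. A possible workaround, sketched below with hypothetical names (`ScheduledDiffusionAE`, `ToyUNet`) and an assumed DDPM-style linear beta schedule, is to let the model own the schedule as a registered buffer, since DataParallel copies buffers to every GPU along with the weights, and to add the noise inside `forward` instead of calling the inferer. The explicit noising step is meant to mirror what the inferer presumably does during training (add noise according to the scheduler, then run the UNet); the toy encoder and UNet replace the real ResNet-18 and `DiffusionModelUNet`.

```python
import torch
from torch import nn
import torch.nn.functional as F


class ScheduledDiffusionAE(nn.Module):
    """Hypothetical variant of Diffusion_AE that owns its noise schedule.

    The schedule is a registered buffer, so DataParallel copies it to every
    GPU along with the weights, instead of all replicas sharing one tensor
    that belongs to an external scheduler object.
    """

    def __init__(self, semantic_encoder, unet, num_train_timesteps=1000,
                 beta_start=1e-4, beta_end=2e-2):
        super().__init__()
        self.semantic_encoder = semantic_encoder
        self.unet = unet
        # Assumed DDPM-style linear beta schedule
        betas = torch.linspace(beta_start, beta_end, num_train_timesteps)
        self.register_buffer("alphas_cumprod", torch.cumprod(1.0 - betas, dim=0))

    def forward(self, x, noise, timesteps):
        latent = self.semantic_encoder(x)
        # Forward diffusion q(x_t | x_0) = sqrt(a_t) * x_0 + sqrt(1 - a_t) * noise,
        # with every tensor on this replica's device
        a = self.alphas_cumprod[timesteps].view(-1, *([1] * (x.dim() - 1)))
        noisy = a.sqrt() * x + (1.0 - a).sqrt() * noise
        # Same positional order as DiffusionModelUNet: (x, timesteps, context)
        return self.unet(noisy, timesteps, latent.unsqueeze(2))


class ToyUNet(nn.Module):
    """Hypothetical stand-in for DiffusionModelUNet; ignores timesteps and context."""

    def __init__(self, channels=1):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, x, timesteps, context):
        return self.conv(x)


device = "cuda" if torch.cuda.is_available() else "cpu"
encoder = nn.Sequential(nn.Flatten(), nn.Linear(16 * 16, 64))  # toy semantic encoder
model = nn.DataParallel(ScheduledDiffusionAE(encoder, ToyUNet())).to(device)

images = torch.randn(8, 1, 16, 16, device=device)
noise = torch.randn_like(images)
timesteps = torch.randint(0, 1000, (images.shape[0],), device=device)
noise_pred = model(images, noise, timesteps)
loss = F.mse_loss(noise_pred, noise)
loss.backward()
```

Another option, which the earlier reply already points to, is `DistributedDataParallel`: it runs one process per GPU, so each process holds its own copy of the inferer and its scheduler, and this kind of sharing should not arise.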

I added a print in the forward pass that reports the device of x, noise and timesteps. The first two iterations seem to work fine, but after that the printouts are scrambled and the tensors no longer seem to be on the same device:


noise device is cuda:0
timesteps device is cuda:0
x device is cuda:0
noise device is cuda:1
timesteps device is cuda:1
x device is cuda:1
noise device is cuda:2
timesteps device is cuda:2
x device is cuda:2
noise device is cuda:3
timesteps device is cuda:3
x device is cuda:3
Iteration 1/10000 - train Loss 0.9895
noise device is cuda:0
timesteps device is cuda:0
x device is cuda:0
noise device is cuda:1
timesteps device is cuda:1
x device is cuda:1
noise device is cuda:2
timesteps device is cuda:2
x device is cuda:2
noise device is cuda:3
timesteps device is cuda:3
x device is cuda:3
Iteration 2/10000 - train Loss 0.9922
noise device is cuda:1
timesteps device is cuda:1
x device is cuda:1noise device is cuda:3
timesteps device is cuda:3noise device is cuda:0
x device is cuda:3timesteps device is cuda:0noise device is cuda:2
x device is cuda:0timesteps device is cuda:2
x device is cuda:2

Traceback (most recent call last):