Distributed Training with Complex Wrapper Model (Unet and Conditional First Stage)

Hi there! I'm wondering whether someone more experienced could look over this setup and tell me if gradients will flow properly through both the UNet and the conditional first stage (`First_Stage`, attached to the UNet as `unet.cond_first_stage`).

I am using the Accelerate library for distributed training on 10 nodes with 4 GPUs per node. I want to train both the UNet weights and the conditional-first-stage weights, effectively aligning the image embeddings with an unconventional input modality (EEG) for the model.

Happy to share more of the code (i.e. the training script) to provide a better understanding of the workflow.

Thanks in advance!!!

Below is code:

import torch
from torch import nn as nn
from transformers import CLIPVisionModelWithProjection
import torch.nn.functional as F

class WrapperModel(nn.Module):
    """Bundle UNet, VAE, noise scheduler and EEG conditioning stage for one training step.

    The conditioning stage (``First_Stage``) is registered as
    ``unet.cond_first_stage``. Its parameters are therefore submodules of
    ``self.unet`` and appear in ``self.parameters()``.

    Freezing:
    - The VAE and the image embedder are frozen here.
    - The EEG encoder is frozen inside ``First_Stage``.

    What remains trainable is the UNet's own weights plus ``cond_first_stage``'s
    conv/fc layers and its ``mapper``.

    NOTE(review): build the optimizer from the parameters with
    ``requires_grad=True`` and add ``clip_loss`` to the loss you backpropagate.
    ``cond_first_stage.mapper`` is only used by ``get_clip``. If ``clip_loss``
    is dropped, the mapper gets no gradient and DDP will typically complain
    about unused parameters. Confirm this in the training script.
    """

    def __init__(self, unet, vae, noise_scheduler, encoder, config, frozen_image_embedder, accelerator):
        super().__init__()
        self.unet = unet
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.encoder = encoder
        self.config = config
        self.frozen_image_embedder = frozen_image_embedder
        self.accelerator = accelerator
        self.weight_dtype = torch.float32

        # Attach the conditional first stage to the UNet so its trainable
        # layers are part of unet.parameters().
        self.unet.cond_first_stage = First_Stage(encoder=self.encoder)

        # Freeze the VAE and the image embedder. Without this, the diffusion
        # loss back-propagates through vae.encode() into the VAE, because
        # noisy_latents is built from the sampled latents. Likewise the CLIP
        # loss would back-propagate into the image embedder.
        self.vae.requires_grad_(False)
        if isinstance(self.frozen_image_embedder, nn.Module):
            self.frozen_image_embedder.requires_grad_(False)

        # Place the VAE on this process's device in the training dtype.
        self.vae.to(self.accelerator.device, dtype=self.weight_dtype)

    def forward(self, batch, accelerator=None):
        """Compute the diffusion prediction and the CLIP-alignment loss for a batch.

        Args:
            batch: mapping with two entries.
                ``'pixel_values'``: images fed to the VAE and the image embedder.
                ``'eeg'``: input to ``unet.cond_first_stage``.
            accelerator: object exposing ``.device``. Defaults to the
                accelerator given at construction.

        Returns:
            Tuple ``(model_pred, clip_loss, timesteps, target)``.
        """
        if accelerator is None:
            accelerator = self.accelerator
        device = accelerator.device
        eeg = batch["eeg"].to(device)

        # The VAE is frozen: encode without autograd to save memory and to
        # guarantee no gradient reaches it.
        with torch.no_grad():
            pixel_values = batch["pixel_values"].to(device=device, dtype=self.weight_dtype)
            image_embedding = self.vae.encode(pixel_values).latent_dist.sample()
            image_embedding = image_embedding * self.vae.config.scaling_factor

        noise = torch.randn_like(image_embedding)
        # Noise offset: one extra random shift per (sample, channel).
        noise += self.config.noise_offset * torch.randn(
            (image_embedding.shape[0], image_embedding.shape[1], 1, 1), device=image_embedding.device
        )
        # Input perturbation: the latents are noised with new_noise, but the
        # target is still derived from the unperturbed noise.
        new_noise = noise + self.config.input_perturbation * torch.randn_like(noise)

        bsz = image_embedding.shape[0]
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=image_embedding.device
        ).long()
        noisy_latents = self.noise_scheduler.add_noise(image_embedding, new_noise, timesteps)

        # EEG -> cross-attention conditioning. This is the trainable path into
        # cond_first_stage's conv/fc layers.
        encoder_hidden_states, _ = self.unet.cond_first_stage(eeg)

        # The target depends on what the UNet is trained to predict. The
        # scheduler config may not declare a prediction type; fall back to
        # epsilon (noise) in that case, which was the original behaviour.
        prediction_type = getattr(self.noise_scheduler.config, "prediction_type", "epsilon")
        if prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(image_embedding, noise, timesteps)
        else:
            target = noise

        model_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]

        # The image embedder is frozen, so it also runs without autograd.
        # NOTE(review): VAE and CLIP encoders usually expect different image
        # sizes and normalisation. Confirm the same pixel_values are valid
        # for both.
        with torch.no_grad():
            img_embd = self.frozen_image_embedder(batch["pixel_values"].to(device))
        # If the embedder is transformers' CLIPVisionModelWithProjection, it
        # returns an output object; use its projected ``image_embeds``.
        # Plain tensors pass through unchanged.
        img_embd = getattr(img_embd, "image_embeds", img_embd)
        clip_loss = self.unet.cond_first_stage.get_clip(eeg, img_embd)
        return model_pred, clip_loss, timesteps, target


class Dim_Mapper(nn.Module):
    """Collapse a sequence of encoder features into one vector and project it.

    The input is treated as (batch, in_channels, in_features). A 1x1 ``Conv1d``
    mixes the ``in_channels`` axis down to a single channel, which is then
    squeezed away. A ``Linear`` layer maps the result to ``out_features``.
    The output shape is (batch, out_features).

    The defaults (128, 1024, 768) reproduce the original hard-coded sizes.
    ``in_features`` should match the encoder's last dimension, and
    ``out_features`` the image-embedding size it is compared against.
    """

    def __init__(self, in_channels=128, in_features=1024, out_features=768):
        # Required before assigning submodules; without it nn.Module raises
        # "cannot assign module before Module.__init__() call".
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, 1, 1, stride=1)
        # Fully connected layer that transforms the pooled feature vector.
        self.fc1 = nn.Linear(in_features, out_features)

    def forward(self, x):
        # (batch, in_channels, in_features) -> (batch, 1, in_features)
        x = self.conv1(x)
        # Drop the singleton channel dimension left by the convolution.
        x = x.squeeze(1)
        # (batch, in_features) -> (batch, out_features)
        return self.fc1(x)
# The point of having a separate Dim_Mapper class is so that we can swap between using and not using CLIP alignment.
class First_Stage(nn.Module):
    """Conditional first stage: a frozen encoder followed by a trainable projection head.

    ``forward`` turns encoder features into a (batch, 1, output_dim) tensor,
    used by the wrapper as the UNet's ``encoder_hidden_states``.
    ``get_clip`` returns ``1 - mean cosine similarity`` between the mapped
    encoder features and the given image embeddings.

    Args:
        encoder: feature extractor. Its weights are frozen here. It must expose
            ``num_patches``, which is used as the sequence length. It is
            assumed to return (batch, num_patches, input_dim); TODO confirm.
        input_dim: last-dimension size of the encoder output (default 1024).
        output_dim: width of the produced conditioning vector (default 768).
        mapper: module applied to the encoder output in ``get_clip``. When
            None, ``Dim_Mapper()`` is used. Pass another module to swap the
            CLIP-alignment head.
    """

    def __init__(self, encoder, input_dim=1024, output_dim=768, mapper=None):
        # Required before assigning submodules to an nn.Module.
        super().__init__()
        self.encoder = encoder
        # Freeze the encoder weights; only the layers created below train.
        self.encoder.requires_grad_(False)
        self.seq_len = encoder.num_patches
        self.input_dim = input_dim
        self.output_dim = output_dim
        # Mapper used by get_clip.
        self.mapper = Dim_Mapper() if mapper is None else mapper

        # Conv1d runs over the sequence axis: input is (batch, input_dim, seq_len).
        # NOTE(review): there is no activation between these convolutions, so
        # the stack is a composition of linear maps. Consider inserting
        # nn.ReLU() between them (this changes state_dict keys).
        self.conv_block = nn.Sequential(
            nn.Conv1d(in_channels=self.input_dim, out_channels=128, kernel_size=3, stride=2, padding=1),
            nn.Conv1d(in_channels=128, out_channels=256, kernel_size=3, stride=2, padding=1),
            nn.Conv1d(in_channels=256, out_channels=self.output_dim, kernel_size=3, stride=2, padding=1),
        )

        # Sequence length after the conv stack, using the Conv1d output-length
        # formula. For k=3, s=2, p=1 each layer gives ceil(L / 2), so the
        # result is ceil(seq_len / 8), not seq_len // 8. The two differ when
        # seq_len is not a multiple of 8.
        conv_seq_len = self.seq_len
        for conv in self.conv_block:
            conv_seq_len = (
                conv_seq_len + 2 * conv.padding[0] - conv.dilation[0] * (conv.kernel_size[0] - 1) - 1
            ) // conv.stride[0] + 1

        self.fc = nn.Linear(in_features=conv_seq_len * self.output_dim, out_features=self.output_dim)

    def forward(self, x):
        """Return ``(conditioning, latent)``.

        ``conditioning`` has shape (batch, 1, output_dim). ``latent`` is the
        raw encoder output.
        """
        # Call the module rather than .forward() so registered hooks still run.
        latent = self.encoder(x)

        # Rearrange to (batch, feature_dim, sequence_length) for Conv1d.
        x = latent.transpose(1, 2)
        x = F.relu(self.conv_block(x))
        x = torch.flatten(x, start_dim=1)
        x = self.fc(x).unsqueeze(1)
        return x, latent

    def get_clip(self, x, image_embeddings):
        """Return ``1 - mean cosine similarity`` between mapped encoder features and ``image_embeddings``."""
        features = self.mapper(self.encoder(x))
        return 1 - F.cosine_similarity(features, image_embeddings, dim=-1).mean()