Hello, I have a diffusion model as defined by
import torch
import torchvision
from generative.networks.nets import DiffusionModelUNet  # MONAI GenerativeModels

class Diffusion_AE(torch.nn.Module):
    def __init__(self, embedding_dimension=64, num_channels=1):
        super().__init__()
        self.unet = DiffusionModelUNet(
            spatial_dims=2,
            in_channels=num_channels,
            out_channels=num_channels,
            num_channels=(128, 256, 256),
            attention_levels=(False, True, True),
            num_res_blocks=1,
            num_head_channels=64,
            with_conditioning=True,
            cross_attention_dim=1,
        )
        self.semantic_encoder = torchvision.models.resnet18()
        # self.semantic_encoder = torchvision.models.efficientnet_v2_s()
        self.semantic_encoder.conv1 = torch.nn.Conv2d(num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.semantic_encoder.fc = torch.nn.Linear(512, embedding_dimension)
        print(f'num_channels is {num_channels}')

    def forward(self, xt, x_cond, t):
        latent = self.semantic_encoder(x_cond)
        noise_pred = self.unet(x=xt, timesteps=t, context=latent.unsqueeze(2))
        return noise_pred, latent
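For completeness, the training loop below also relies on a scheduler, inferer, and optimizer. A minimal sketch of that setup, assuming the DDPMScheduler and DiffusionInferer from MONAI's generative package (the hyperparameter values here are only illustrative, not necessarily what I use):

import torch.nn.functional as F
from generative.inferers import DiffusionInferer
from generative.networks.schedulers import DDPMScheduler

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
num_channels = 1
batch_size = 64      # illustrative; matches the DataLoader batch size
max_epochs = 1000    # illustrative
iteration = 0
iter_loss = 0.0

model = Diffusion_AE(embedding_dimension=64, num_channels=num_channels).to(device)
scheduler = DDPMScheduler(num_train_timesteps=1000)
inferer = DiffusionInferer(scheduler)  # provides inferer.scheduler.num_train_timesteps
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)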
I then train it via this:
while iteration < max_epochs:
    for batch in train_loader:
        iteration += 1
        model.train()
        optimizer.zero_grad(set_to_none=True)
        images = batch["image"].to(device)
        images = images[:, 0:num_channels, :, :]
        noise = torch.randn_like(images).to(device)
        # Create timesteps
        timesteps = torch.randint(0, inferer.scheduler.num_train_timesteps, (batch_size,)).to(device).long()
        # Get model prediction.
        # Cross attention expects context of shape [batch size, sequence length, channels];
        # latent.unsqueeze(2) gives [batch size, embedding_dimension, 1], i.e. sequence length =
        # embedding_dimension and channels = 1, matching cross_attention_dim=1.
        latent = model.semantic_encoder(images)
        noise_pred = inferer(
            inputs=images, diffusion_model=model.unet, noise=noise, timesteps=timesteps, condition=latent.unsqueeze(2)
        )
        loss = F.mse_loss(noise_pred.float(), noise.float())
        loss.backward()
        optimizer.step()
        iter_loss += loss.item()
        print(f"Iteration {iteration}/{max_epochs} - train Loss {loss.item():.4f}", end="\r")
My question is: how can I train this model on multiple GPUs? If I try the standard approach of wrapping it with model = nn.DataParallel(model), I get an error that the model has no attribute semantic_encoder (at the line latent = model.semantic_encoder(images)). The object returned by nn.DataParallel does not seem to expose the submodules and custom methods of the wrapped model; it apparently only covers the basic case of calling forward. Does anyone know how to use multiple GPUs in my case?
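For reference, a minimal sketch of the change that triggers the error (not my full script):

import torch.nn as nn

model = Diffusion_AE(embedding_dimension=64, num_channels=1)
model = nn.DataParallel(model).to(device)

# ... later, inside the training loop ...
latent = model.semantic_encoder(images)
# AttributeError: 'DataParallel' object has no attribute 'semantic_encoder'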