Hi, I’m trying to train a custom projection module that aligns image and text embeddings in a common embedding space. An MLP module learns positional embeddings for the coordinates of the image patches. The dataset consists of very large images, so each image is split into patches, and the patches are processed together with their coordinates to produce image-level embeddings.
At the end of the first iteration, immediately after optimizer.step(), the model parameters become NaN.
I have attached the code snippet below. Your support is highly appreciated, thanks!
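For reference, this is the kind of check I'm planning to wrap around optimizer.step() to narrow down where things blow up. It's only a minimal diagnostic sketch; vlm and optimizer refer to the objects defined in the training code below.

import torch

def report_non_finite(model, stage):
    # Purely diagnostic: print any gradient or parameter that is no longer finite.
    for name, p in model.named_parameters():
        if p.grad is not None and not torch.isfinite(p.grad).all():
            print(f"[{stage}] non-finite grad in {name}")
        if not torch.isfinite(p).all():
            print(f"[{stage}] non-finite param in {name}")

# Intended usage inside the training loop further below:
#     report_non_finite(vlm, "before step")
#     optimizer.step()
#     report_non_finite(vlm, "after step")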
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.nn.utils as nn_utils
from tqdm import tqdm

class CustomVLM(nn.Module):
    def __init__(self, coord_mlp, projection_module, device="cuda"):
        super(CustomVLM, self).__init__()
        self.device = device
        self.coord_mlp = coord_mlp
        self.projection_module = projection_module
        # Load the pretrained projector weights and strip the "mm_projector." prefix from the keys
        checkpoint = torch.load("mm_projector.bin")
        checkpoint = {k.replace("mm_projector.", ""): v for k, v in checkpoint.items()}
        self.projection_module.load_state_dict(checkpoint)
        self.projection_module = self.projection_module.to(self.device).half()
    def forward(self, last_minus1_k_features, last_k_features, patch_coords_pt, text_embeddings):
        if torch.isnan(last_minus1_k_features).any() or torch.isnan(last_k_features).any() or torch.isnan(patch_coords_pt).any() or torch.isnan(text_embeddings).any():
            raise ValueError("Input tensors contain NaN")
        patch_coords_pt = patch_coords_pt.half().to(self.device)
        if torch.isnan(patch_coords_pt).any():
            raise ValueError("patch_coords_pt contains NaN")
        text_embeddings = text_embeddings.squeeze(0)
        # Generate coordinate embeddings
        coord_embed = self.coord_mlp(patch_coords_pt).unsqueeze(1)
        if torch.isnan(coord_embed).any():
            raise ValueError("coord_embed contains NaN")
        # Add patch and coordinate embeddings --> stack to get the image-level embeddings
        image_level_embeddings_kMinus1 = last_minus1_k_features + coord_embed[:-1, :, :]
        image_level_embeddings_kth = last_k_features + coord_embed[-1, :, :]
        image_level_embeddings_kMinus1_concat = torch.cat([t for t in image_level_embeddings_kMinus1], dim=0)
        combined_image_embeddings = torch.cat([image_level_embeddings_kMinus1_concat, image_level_embeddings_kth], dim=0).unsqueeze(0)  # torch.Size([1, 729, 1152])
        if torch.isnan(combined_image_embeddings).any():
            raise ValueError("combined_image_embeddings contains NaN")
        # Pass the resulting embeddings through the projection module
        # projected_embeddings = self.projection_module(combined_image_embeddings.half()).squeeze(0)
        projected_embeddings = self.projection_module(combined_image_embeddings).squeeze(0)
        # Ensure no NaNs/Infs in projected embeddings
        if torch.isnan(projected_embeddings).any():
            raise ValueError("projected_embeddings contains NaN")
        if torch.isinf(projected_embeddings).any():
            raise ValueError("projected_embeddings contains Inf")
        loss = contrastive_loss(projected_embeddings, text_embeddings)
        return loss
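For context, this is how I understand the broadcasting and concatenation in forward. The patch and token counts below are made up just to illustrate the shapes (the real values come from my dataloader), but the final shape matches the [1, 729, 1152] I see in practice.

import torch

coord_embed = torch.randn(3, 1, 1152)      # as if from coord_mlp(...).unsqueeze(1); k = 3 patches (illustrative)
last_minus1_k = torch.randn(2, 243, 1152)  # (k-1, tokens_per_patch, 1152); token count is illustrative
last_k = torch.randn(243, 1152)            # features of the k-th patch

kminus1 = last_minus1_k + coord_embed[:-1, :, :]  # broadcast add -> (2, 243, 1152)
kth = last_k + coord_embed[-1, :, :]              # broadcast add -> (243, 1152)
combined = torch.cat([torch.cat(list(kminus1), dim=0), kth], dim=0).unsqueeze(0)
print(combined.shape)                             # torch.Size([1, 729, 1152]) with these made-up counts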
class MLPBlock(nn.Module):
    def __init__(self, input_dim: int, embedding_dim: int) -> None:
        super().__init__()
        self.lin1 = nn.Linear(input_dim, embedding_dim * 2)  # input_dim = 2 (row, col) --> hidden width = 1152 * 2 = 2304
        self.lin2 = nn.Linear(embedding_dim * 2, embedding_dim)  # embedding_dim = 1152 (output embedding vector)
        self.act = nn.GELU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.lin2(self.act(self.lin1(x)))
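Since the patch coordinates of large images can be fairly big numbers and this block runs in fp16, a quick sanity check I can run on the MLPBlock above (a sketch: the coordinate values are just a guess at the range, and it needs a CUDA device like the rest of my setup):

import torch

mlp = MLPBlock(2, 1152).to("cuda").half()
coords = torch.tensor([[0.0, 0.0], [3000.0, 2000.0]], device="cuda").half()  # hypothetical (row, col) values
out = mlp(coords)
print(out.dtype, torch.isfinite(out).all().item(), out.abs().max().item())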
class ProjectionModule(nn.Module):
    def __init__(self, mm_hidden_size, hidden_size):
        super(ProjectionModule, self).__init__()
        # Directly set up the sequential model
        self.model = nn.Sequential(
            nn.Linear(mm_hidden_size, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, hidden_size),
        )

    def forward(self, x):
        return self.model(x)
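Since the pretrained mm_projector.bin weights are loaded into this module (in CustomVLM.__init__ above), inspecting the checkpoint itself is also on my list. A minimal sketch, using the same file name as in my code:

import torch

ckpt = torch.load("mm_projector.bin", map_location="cpu")
for k, v in ckpt.items():
    # Report shape, dtype and whether every value in the tensor is finite
    print(k, tuple(v.shape), v.dtype, "finite:", torch.isfinite(v).all().item())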
def cross_entropy(preds, targets, reduction='none'):
    log_softmax = nn.LogSoftmax(dim=-1)
    loss = (-targets * log_softmax(preds)).sum(1)
    if reduction == "none":
        return loss
    elif reduction == "mean":
        return loss.mean()
def contrastive_loss(image_embeddings, text_embeddings, temperature=1.0):
    image_embeddings = nn.functional.normalize(image_embeddings, dim=-1)
    text_embeddings = nn.functional.normalize(text_embeddings, dim=-1)
    print("image_embeddings:", image_embeddings)
    print("text_embeddings:", text_embeddings)
    logits = (text_embeddings @ image_embeddings.T) / temperature
    images_similarity = image_embeddings @ image_embeddings.T
    texts_similarity = text_embeddings @ text_embeddings.T
    targets = F.softmax((images_similarity + texts_similarity) / 2 * temperature, dim=-1)
    texts_loss = cross_entropy(logits, targets, reduction='none')
    images_loss = cross_entropy(logits.T, targets.T, reduction='none')
    loss = (images_loss + texts_loss) / 2.0  # shape: (batch_size)
    return loss.mean()
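As a standalone check of the loss itself, I can feed it small random fp32 embeddings (a sketch; the row count and the 4096-dim width are just placeholders matching the projection output):

import torch

img = torch.randn(8, 4096)  # stand-in for projected_embeddings
txt = torch.randn(8, 4096)  # stand-in for text_embeddings
print(contrastive_loss(img, txt))  # expect a finite scalar in fp32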
coord_mlp = MLPBlock(2, 1152)
coord_mlp = coord_mlp.to('cuda').half()
projection_module = ProjectionModule(1152, 4096)
vlm = CustomVLM(coord_mlp, projection_module)
for param in vlm.parameters():
    param.requires_grad = True
vlm.train()
optimizer = optim.AdamW(list(vlm.parameters()), lr=1e-5)
for epoch in range(num_epochs):
    epoch_loss = 0.0  # Accumulates loss for the entire epoch
    # Open the log file for the current epoch
    with open(f"training_logs/epoch_{epoch+1}_loss.txt", "a") as loss_file:
        for batchID, batch in tqdm(enumerate(data_loader)):
            last_minus1_k_features, last_k_features, patch_coords_pt, text_embeddings = batch
            last_minus1_k_features = last_minus1_k_features.squeeze(0)
            last_k_features = last_k_features.squeeze(0)
            patch_coords_pt = patch_coords_pt.squeeze(0)
            loss = vlm(last_minus1_k_features, last_k_features, patch_coords_pt, text_embeddings)

            # Backward pass and weight update
            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            # Clip gradients using the L2 norm (default)
            max_norm = 1.0  # Adjust this value as needed
            nn_utils.clip_grad_norm_(vlm.parameters(), max_norm, norm_type=2.0)
            optimizer.step()
            # scheduler.step()

            epoch_loss += loss.item()  # Accumulate loss for the current batch
            # Write the loss for the current batch to the log file
            loss_file.write(f"Epoch {epoch+1}, Batch {batchID+1}, Loss: {loss.item():.4f}\n")

    epoch_loss /= len(data_loader)  # Average loss for the epoch
    # Print the epoch and average loss
    print(f"\nEpoch: {epoch + 1}/{num_epochs}, Average Loss: {epoch_loss:.4f}")