TRAINING LOOP
I am fairly new to this field and am trying to build a GAN to generate ECG signals. Any help would be much appreciated.
# Training loop: per batch, run n_critic discriminator updates followed by one
# generator update; once per epoch, log losses / DTW distance and plot samples.
# Uses names defined elsewhere in the script: generator, discriminator,
# optimizer_D/G, criterion_d/g, normalized_record_tensors, the helpers
# train_discriminator, train_generator, compute_gradient_penalty and
# calculate_dtw, and the hyper-parameters num_epochs, batch_size, seq_length,
# n_critic, lambda_gp and device.
global_step = 0  # number of batches processed so far, across all epochs
outputs_list = []  # one batch of generator outputs (on CPU) per epoch
d_losses = []  # mean discriminator loss per epoch
g_losses = []  # mean generator loss per epoch
dtw_distances = []  # DTW distance per epoch, measured on the epoch's last batch
plot_length = 350  # number of time steps plotted per generated segment

num_samples = len(normalized_record_tensors)
if num_samples == 0:
    raise ValueError("normalized_record_tensors is empty; nothing to train on")

for epoch in tqdm(range(num_epochs)):
    total_d_loss = 0.0
    total_g_loss = 0.0
    num_batches = 0
    for start in range(0, num_samples, batch_size):
        # Current batch of real signals.
        real_signals = normalized_record_tensors[start:start + batch_size].to(device)

        # Fake batch for the critic steps. No autograd graph is needed: the
        # discriminator loss must not flow into the generator, and
        # train_generator draws its own noise for the generator update.
        noise = torch.randn(real_signals.shape[0], seq_length, 1, device=device)
        with torch.no_grad():
            fake_signals = generator(noise)

        batch_d_loss = 0.0
        for _ in range(n_critic):
            d_loss = train_discriminator(discriminator, generator, optimizer_D, criterion_d, real_signals, fake_signals)
            penalty = lambda_gp * compute_gradient_penalty(discriminator, real_signals, fake_signals)
            # train_discriminator has already stepped optimizer_D, so adding the
            # penalty to the returned loss would not influence training. Apply
            # it with its own backward/step. This is only possible when the
            # helper returns a differentiable tensor (standard WGAN-GP uses
            # create_graph=True; TODO confirm); otherwise it is only logged.
            if torch.is_tensor(penalty) and penalty.requires_grad:
                optimizer_D.zero_grad()
                penalty.backward()
                optimizer_D.step()
            # Accumulate plain floats so no autograd history is kept alive.
            batch_d_loss += float(d_loss) + float(penalty)
        # Average over critic steps so D and G losses are on a per-batch scale.
        total_d_loss += batch_d_loss / max(n_critic, 1)

        # Ensure the generator is trainable in case its parameters were frozen
        # elsewhere; otherwise optimizer_G.step() would be a no-op.
        generator.requires_grad_(True)
        g_loss = train_generator(generator, discriminator, optimizer_G, criterion_g, real_signals)
        total_g_loss += float(g_loss)
        num_batches += 1
        global_step += 1

    # Keep exactly one sample batch per epoch, so outputs_list[epoch] refers
    # to this epoch.
    outputs_list.append(fake_signals.cpu())

    # Evaluation-only forward passes on the epoch's last batch.
    with torch.no_grad():
        avg_dtw_distance = calculate_dtw(real_signals, fake_signals)
        real_outputs = discriminator(real_signals)
        fake_outputs = discriminator(fake_signals)
    dtw_distances.append(avg_dtw_distance)
    print(f"Epoch [{epoch+1}/{num_epochs}], Real Outputs: {real_outputs.mean().item()}, Fake Outputs: {fake_outputs.mean().item()}")

    # Plot the first `plot_length` time steps of every generated segment.
    generated_ecgs = outputs_list[-1]
    plt.figure(figsize=(10, 6))
    for idx, ecg in enumerate(generated_ecgs):
        plt.plot(ecg[:plot_length].squeeze(), label=f"Segment {idx+1}")
    plt.title(f"Generator Outputs - Epoch {epoch+1}")
    plt.xlabel("Time")
    plt.ylabel("Amplitude")
    plt.legend()
    plt.show()

    # Per-epoch averages; num_batches >= 1 because the dataset is non-empty.
    avg_d_loss = total_d_loss / num_batches
    avg_g_loss = total_g_loss / num_batches
    d_losses.append(avg_d_loss)
    g_losses.append(avg_g_loss)
    print(f"Epoch [{epoch+1}/{num_epochs}], D_loss: {avg_d_loss:.4f}, G_loss: {avg_g_loss:.4f}, DTW distance: {avg_dtw_distance:.4f}")
def train_generator(generator, discriminator, optimizer_G, criterion, real_signals, *, apply_sigmoid=True, grad_clip=None):
    """Run one optimisation step for the generator and return its loss.

    Fresh noise of shape ``(batch, seq_len, 1)`` is drawn and fed to the
    generator. The result is scored by the discriminator, and the generator is
    updated so that the discriminator's verdict moves towards the "real"
    label (a tensor of ones).

    Args:
        generator: module mapping noise ``(batch, seq_len, 1)`` to signals.
        discriminator: module scoring generated signals.
        optimizer_G: optimizer over the generator's parameters.
        criterion: loss comparing the (optionally sigmoid-squashed)
            discriminator output with ones, e.g. ``nn.BCELoss``.
        real_signals: 3-D tensor ``(batch, seq_len, channels)``. Only its
            shape and device are used.
        apply_sigmoid: apply ``torch.sigmoid`` to the discriminator output
            before ``criterion`` (the original behaviour). Set it to False when
            ``criterion`` expects logits (``nn.BCEWithLogitsLoss``) or the
            discriminator already ends in a sigmoid. The choice must match how
            the discriminator's own loss treats its outputs.
        grad_clip: max gradient norm for the generator. Defaults to the
            module-level ``max_grad_norm``.

    Returns:
        The scalar loss, detached from the autograd graph.
    """
    batch_size, seq_length, _ = real_signals.shape
    # Create the noise on the data's device (the training loop moves data to `device`).
    noise = torch.randn(batch_size, seq_length, 1, device=real_signals.device)

    optimizer_G.zero_grad()
    generated_signals = generator(noise)
    fake_outputs = discriminator(generated_signals)
    if apply_sigmoid:
        fake_outputs = torch.sigmoid(fake_outputs)
    # Push the discriminator's verdict on generated samples towards "real" (1).
    g_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))
    g_loss.backward()
    max_norm = max_grad_norm if grad_clip is None else grad_clip
    torch.nn.utils.clip_grad_norm_(generator.parameters(), max_norm)
    optimizer_G.step()
    # Detach so callers can accumulate the value without retaining autograd nodes.
    return g_loss.detach()
def train_discriminator(discriminator, generator, optimizer_D, criterion, real_signals, fake_signals, *, penalty_fn=None, penalty_weight=0.0):
    """Run one optimisation step for the discriminator and return its loss.

    Real samples are labelled 1 and fake samples 0. The loss is
    ``criterion(D(real), 1) + criterion(D(fake), 0)``, plus an optional
    penalty term, and is minimised with ``optimizer_D``.

    Args:
        discriminator: module scoring signals.
        generator: unused; kept for backward compatibility. Earlier versions
            froze its parameters here, which stopped the generator from ever
            learning.
        optimizer_D: optimizer over the discriminator's parameters.
        criterion: loss comparing discriminator outputs with 0/1 targets.
        real_signals: batch of real signals.
        fake_signals: batch of generated signals. It is detached internally,
            so it may still be attached to the generator's graph.
        penalty_fn: optional ``f(discriminator, real, fake) -> tensor``, e.g. a
            WGAN-GP gradient penalty. It is added to the loss before
            ``backward`` so that it actually shapes the update.
        penalty_weight: multiplier for ``penalty_fn``'s result.

    Returns:
        The detached scalar loss (including any penalty), or None if the
        discriminator produced None for either batch.
    """
    optimizer_D.zero_grad()
    # Never backpropagate the discriminator loss into the generator. This keeps
    # generator gradients clean and allows the same fake batch to be reused for
    # several critic steps without "backward through the graph a second time".
    fake_signals = fake_signals.detach()
    real_outputs = discriminator(real_signals)
    fake_outputs = discriminator(fake_signals)
    if real_outputs is None or fake_outputs is None:
        return None
    # Targets take the shape, dtype and device of each output.
    d_loss_real = criterion(real_outputs, torch.ones_like(real_outputs))
    d_loss_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))
    d_loss = d_loss_real + d_loss_fake
    if penalty_fn is not None:
        d_loss = d_loss + penalty_weight * penalty_fn(discriminator, real_signals, fake_signals)
    d_loss.backward()
    optimizer_D.step()
    # Detach so callers can accumulate the value without retaining autograd nodes.
    return d_loss.detach()