I’m training a small CNN (code below) to predict sparse amplitude maps from binary masks.
- Input: a 60×60 image with exactly 15 pixels set to 1; the rest are 0.
- Target: same size, 0 everywhere except those 15 pixels, which hold values in the range 0.6–1.0.
The CNN is trained on ~1800 images and tested on ~400. The goal is for it to predict the amplitude at the 15 known locations, given the mask as input.
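To make the structure concrete, here is a minimal synthetic sketch of one (input, target) pair (the locations and amplitudes below are made up for illustration, not my real data):

import numpy as np

rng = np.random.default_rng(0)

# Choose 15 distinct pixel locations in a 60x60 frame.
flat = rng.choice(60 * 60, size=15, replace=False)
rows, cols = np.unravel_index(flat, (60, 60))

# Input: binary mask, 1 at the 15 locations, 0 elsewhere.
input_mask = np.zeros((60, 60), dtype=np.float32)
input_mask[rows, cols] = 1.0

# Target: 0 everywhere except the same locations, which get amplitudes in [0.6, 1.0].
target_ampl = np.zeros((60, 60), dtype=np.float32)
target_ampl[rows, cols] = rng.uniform(0.6, 1.0, size=15)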
Here’s an example output, with the detailed predicted vs. target values at the 15 pixels:
Index (row, col) | Predicted | Target
(40, 72) | 0.9177 | 0.9143
(40, 90) | 0.9177 | 1.0000
(43, 52) | 0.9177 | 0.8967
(50, 32) | 0.9177 | 0.9205
(51, 70) | 0.9177 | 0.9601
(53, 45) | 0.9177 | 0.9379
(56, 88) | 0.9177 | 0.8906
(61, 63) | 0.9177 | 0.9280
(62, 50) | 0.9177 | 0.9154
(65, 29) | 0.9177 | 0.9014
(65, 91) | 0.9177 | 0.8941
(68, 76) | 0.9177 | 0.9043
(76, 80) | 0.9177 | 0.9206
(80, 31) | 0.9177 | 0.8872
(80, 61) | 0.9177 | 0.9019
As you can see, the network collapses to a constant output, despite the targets being quite different.
By playing around with the CNN I can get values that are not all identical, but they still barely vary around a single constant:
Index (row, col) | Predicted | Target
(40, 72) | 0.9559 | 0.9143
(40, 90) | 0.9563 | 1.0000
(43, 52) | 0.9476 | 0.8967
(50, 32) | 0.9515 | 0.9205
(51, 70) | 0.9512 | 0.9601
(53, 45) | 0.9573 | 0.9379
(56, 88) | 0.9514 | 0.8906
(61, 63) | 0.9604 | 0.9280
(62, 50) | 0.9519 | 0.9154
(65, 29) | 0.9607 | 0.9014
(65, 91) | 0.9558 | 0.8941
(68, 76) | 0.9560 | 0.9043
(76, 80) | 0.9555 | 0.9206
(80, 31) | 0.9620 | 0.8872
(80, 61) | 0.9563 | 0.9019
I’ve tried many things:
- Scaling the amplitudes to [-5, 5], [-3, 3], and [-1, 1] (with both linear and nonlinear mappings), then unscaling them in the test() function
- Different optimizers: Adam and AdamW
- Different loss criteria: SmoothL1Loss() and MSELoss()
- A large grid search over epochs and learning rates
- Computing the MSE per pixel individually instead of over all pixels together (see the sketch after this list)
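For the last point, here is roughly the masked variant I mean (a sketch; masked_mse is my own helper name, and since the input mask is 1 exactly at the 15 pixels, it can double as the loss mask):

import torch

def masked_mse(pred, target, mask):
    # MSE computed only over the active (mask == 1) pixels.
    # pred, target, mask: tensors of shape (B, 1, H, W).
    diff = (pred - target) ** 2
    return (diff * mask).sum() / mask.sum().clamp(min=1)

# Usage inside the training loop, the input mask doubling as the loss mask:
# loss = masked_mse(pred, targets, inputs)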
What’s interesting is that I trained the same architecture for phase prediction, where values range from -π to π, and it learns beautifully:
Index (row, col) | Predicted | Target
(40, 72) | -0.1235 | -0.1235
(40, 90) | 0.5146 | 0.5203
(43, 52) | -1.0479 | -1.0490
(50, 32) | -0.3166 | -0.3165
(51, 70) | -1.5540 | -1.5521
(53, 45) | 0.5990 | 0.6034
(56, 88) | -0.4752 | -0.4752
(61, 63) | -2.4576 | -2.4600
(62, 50) | 2.0495 | 2.0526
(65, 29) | -2.6678 | -2.6681
(65, 91) | -1.9935 | -1.9961
(68, 76) | -1.9096 | -1.9142
(76, 80) | -1.7976 | -1.8025
(80, 31) | -2.7799 | -2.7795
(80, 61) | 0.5338 | 0.5393
So this proves that the CNN can learn; I just can’t figure out why it won’t work for amplitudes. The only difference I see is that the input phase values are the same values the loss function compares against. Here is what I mean:
When being trained (let’s take one pixel with a phase value of -1.2):
-1.2 → CNN → output gets compared to -1.2
Whereas for the amplitude, one pixel looks like this:
1.0 → CNN → output gets compared to a true value such as 0.9143
So maybe the phase has an “easier” life. Nonetheless, I’m struggling with the CNN for the amplitude, and I would really appreciate some insight if anyone can help!
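For what it’s worth, the constant prediction looks like exactly what MSE would pick when the input carries no per-pixel information: the mean target amplitude over the active pixels. Here’s a quick sanity check along those lines (a hypothetical snippet; it reuses train_dataset from the script below):

import torch
from torch.utils.data import DataLoader

# Mean target amplitude over the active pixels of the training set.
# An MSE-trained net whose input carries no per-pixel information
# minimizes its loss by predicting exactly this constant.
loader = DataLoader(train_dataset, batch_size=64)
total, count = 0.0, 0
for inputs, targets in loader:
    active = inputs > 0  # the input mask marks the 15 pixels
    total += targets[active].sum().item()
    count += int(active.sum().item())
print(f"Mean active amplitude: {total / count:.4f}")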
import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split

class AmplitudeDataset(Dataset):
    def __init__(self, input_dir_ampl, target_dir_ampl):
        self.input_dir_ampl = input_dir_ampl
        self.target_dir_ampl = target_dir_ampl
        # File names look like 'tweezer_<id>_ampl_input.npy'; keep only the <id>.
        self.fileids = sorted([f.split('_')[1] for f in os.listdir(input_dir_ampl)])

    def __len__(self):
        return len(self.fileids)

    def __getitem__(self, idx):
        input_ampl = np.load(os.path.join(self.input_dir_ampl, f'tweezer_{self.fileids[idx]}_ampl_input.npy')).astype(np.float32)
        target_ampl = np.load(os.path.join(self.target_dir_ampl, f'tweezer_{self.fileids[idx]}_wgs_ampl.npy')).astype(np.float32)
        # Add the channel dimension: (H, W) -> (1, H, W).
        return torch.tensor(input_ampl).unsqueeze(0), torch.tensor(target_ampl).unsqueeze(0)

class SingleOutputFCN(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding='same'),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 16, kernel_size=3, padding='same'),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 1, kernel_size=3, padding='same'),
        )

    def forward(self, x):
        return self.net(x)

def train(model, dataloader, device, epochs=4, lr=1e-3):
    model.to(device)
    model.train()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()
    for epoch in range(epochs):
        total_loss = 0.0
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)
            optimizer.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, targets)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            pred_std = pred.std().item()  # spread of the last batch's predictions
        print(f"[AMPL] Epoch {epoch+1}/{epochs} | Avg Loss: {total_loss / len(dataloader):.6f} | Pred std: {pred_std:.4f}")

def test(model, dataloader, device):
    model.eval()
    output_dir = './cnn_predictions/'
    os.makedirs(output_dir, exist_ok=True)
    with torch.no_grad():
        for i, (inputs, _) in enumerate(dataloader):
            inputs = inputs.to(device)
            outputs = model(inputs).cpu()
            input_np = inputs.cpu().numpy()
            output_np = outputs.numpy()
            for j in range(inputs.shape[0]):
                idx = i * dataloader.batch_size + j
                # Map the flat loader index back to the original file id,
                # whether or not the dataset is wrapped in a Subset.
                if isinstance(dataloader.dataset, Subset):
                    file_id = dataloader.dataset.dataset.fileids[dataloader.dataset.indices[idx]]
                else:
                    file_id = dataloader.dataset.fileids[idx]
                np.save(os.path.join(output_dir, f'tweezer_{file_id}_predicted_ampl.npy'), output_np[j, 0])
                np.save(os.path.join(output_dir, f'tweezer_{file_id}_input_ampl.npy'), input_np[j, 0])

if __name__ == '__main__':
    input_dir_ampl = './WGS_CNN_Cropped/ampl_input'
    target_dir_ampl = './WGS_CNN_Cropped/wgs_amplitude'
    dataset = AmplitudeDataset(input_dir_ampl, target_dir_ampl)
    train_ids, test_ids = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
    train_dataset = Subset(dataset, train_ids)
    test_dataset = Subset(dataset, test_ids)
    train_loader = DataLoader(train_dataset, batch_size=15, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=15, shuffle=False)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SingleOutputFCN()
    print("\n=== Training Amplitude Model ===")
    train(model, train_loader, device)
    print("\n=== Testing Amplitude Model ===")
    test(model, test_loader, device)