CNN predicts constant values for sparse amplitude regression — can't learn true pixel values

I’m training a small CNN (code below) to predict sparse amplitude maps from binary masks.

  • Input: 60×60 image with exactly 15 pixels set to 1, rest are 0.

  • Target: Same size, 0 everywhere except those 15 pixels, which have values in the range 0.6–1.0.

The CNN is trained on ~1800 images and tested on ~400. The goal is for it to predict the amplitude at the 15 known locations, given the mask as input.
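
To make the format concrete, here is a minimal synthetic sketch of one input/target pair (the random values are arbitrary, just to show the shapes):

import numpy as np

# One synthetic input/target pair: a 60x60 binary mask with 15 active pixels,
# and a target carrying amplitudes in [0.6, 1.0] at those same 15 pixels.
rng = np.random.default_rng(0)
size, n_active = 60, 15

flat = rng.choice(size * size, size=n_active, replace=False)
rows, cols = np.unravel_index(flat, (size, size))

input_mask = np.zeros((size, size), dtype=np.float32)
input_mask[rows, cols] = 1.0

target_ampl = np.zeros((size, size), dtype=np.float32)
target_ampl[rows, cols] = rng.uniform(0.6, 1.0, n_active)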

Here are the detailed predicted vs. target values from an example test output:

Index (row, col) |  Predicted |     Target

        (40, 72) |     0.9177 |     0.9143
        (40, 90) |     0.9177 |     1.0000
        (43, 52) |     0.9177 |     0.8967
        (50, 32) |     0.9177 |     0.9205
        (51, 70) |     0.9177 |     0.9601
        (53, 45) |     0.9177 |     0.9379
        (56, 88) |     0.9177 |     0.8906
        (61, 63) |     0.9177 |     0.9280
        (62, 50) |     0.9177 |     0.9154
        (65, 29) |     0.9177 |     0.9014
        (65, 91) |     0.9177 |     0.8941
        (68, 76) |     0.9177 |     0.9043
        (76, 80) |     0.9177 |     0.9206
        (80, 31) |     0.9177 |     0.8872
        (80, 61) |     0.9177 |     0.9019

As you can see, the network collapses to a constant output, despite the targets being quite different.
I have been able to play around with the CNN and get values that are not all identical, although they still vary far less than the targets:

Index (row, col) | Predicted | Target

        (40, 72) |     0.9559 |     0.9143
        (40, 90) |     0.9563 |     1.0000
        (43, 52) |     0.9476 |     0.8967
        (50, 32) |     0.9515 |     0.9205
        (51, 70) |     0.9512 |     0.9601
        (53, 45) |     0.9573 |     0.9379
        (56, 88) |     0.9514 |     0.8906
        (61, 63) |     0.9604 |     0.9280
        (62, 50) |     0.9519 |     0.9154
        (65, 29) |     0.9607 |     0.9014
        (65, 91) |     0.9558 |     0.8941
        (68, 76) |     0.9560 |     0.9043
        (76, 80) |     0.9555 |     0.9206
        (80, 31) |     0.9620 |     0.8872
        (80, 61) |     0.9563 |     0.9019

I’ve tried many things:

  1. Rescaling the amplitudes to [-5, 5], [-3, 3], and [-1, 1] (with both linear and nonlinear mappings), then undoing the scaling in the test() function
  2. Different optimizers: Adam and AdamW
  3. Different loss criteria: SmoothL1Loss() and MSELoss()
  4. A large for loop over epochs and learning rates
  5. Computing the loss at the active pixels individually instead of one MSE over all pixels together (see the sketch right after this list)
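
For point 5, this is roughly the idea: restrict the loss to the active pixels rather than averaging over the whole 60×60 image (a sketch with placeholder names, not my exact code):

import torch

# Masked MSE: only score the pixels where the binary input mask is 1,
# so the zero background doesn't dominate the average.
def masked_mse(pred, target, mask):
    diff2 = (pred - target) ** 2
    return (diff2 * mask).sum() / mask.sum().clamp(min=1)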

What’s interesting is that I trained the same architecture for phase prediction, where values range from -π to π, and it learns beautifully:

Index (row, col) |  Predicted |     Target

        (40, 72) |    -0.1235 |    -0.1235
        (40, 90) |     0.5146 |     0.5203
        (43, 52) |    -1.0479 |    -1.0490
        (50, 32) |    -0.3166 |    -0.3165
        (51, 70) |    -1.5540 |    -1.5521
        (53, 45) |     0.5990 |     0.6034
        (56, 88) |    -0.4752 |    -0.4752
        (61, 63) |    -2.4576 |    -2.4600
        (62, 50) |     2.0495 |     2.0526
        (65, 29) |    -2.6678 |    -2.6681
        (65, 91) |    -1.9935 |    -1.9961
        (68, 76) |    -1.9096 |    -1.9142
        (76, 80) |    -1.7976 |    -1.8025
        (80, 31) |    -2.7799 |    -2.7795
        (80, 61) |     0.5338 |     0.5393

So the CNN clearly can learn; I just can’t figure out how to make it work for the amplitudes. The only difference is that the input phase values are the same values used as targets in the loss function. Here is what I mean:

When being trained (let’s just take 1 pixel value of -1.2 for the phase):

-1.2 → CNN → output gets compared to -1.2

Whereas the amplitude of 1 pixel is like this:

1.0 → CNN → output gets compared to the true value, e.g. 0.9143

So maybe the phase has an “easier” life; nonetheless, I am struggling with the CNN for the amplitude and I would really appreciate some insight if anyone can help!

import os
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from sklearn.model_selection import train_test_split

class AmplitudeDataset(Dataset):
    def __init__(self, input_dir_ampl, target_dir_ampl):
        self.input_dir_ampl = input_dir_ampl
        self.target_dir_ampl = target_dir_ampl
        self.fileids = sorted([f.split('_')[1] for f in os.listdir(input_dir_ampl)])

    def __len__(self):
        return len(self.fileids)

    def __getitem__(self, idx):
        input_ampl = np.load(os.path.join(self.input_dir_ampl, f'tweezer_{self.fileids[idx]}_ampl_input.npy')).astype(np.float32)
        target_ampl = np.load(os.path.join(self.target_dir_ampl, f'tweezer_{self.fileids[idx]}_wgs_ampl.npy')).astype(np.float32)

        return torch.tensor(input_ampl).unsqueeze(0), torch.tensor(target_ampl).unsqueeze(0)

class SingleOutputFCN(nn.Module):
    def __init__(self):
        super(SingleOutputFCN, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3, padding='same'),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 16, kernel_size=3, padding='same'),
            nn.LeakyReLU(0.1),
            nn.Conv2d(16, 1, kernel_size=3, padding='same'),
        )

    def forward(self, x):
        return self.net(x)

def train(model, dataloader, device, epochs=4, lr=1e-3):
    model.to(device)
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.MSELoss()

    for epoch in range(epochs):
        total_loss = 0
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            optimizer.zero_grad()
            pred = model(inputs)
            loss = criterion(pred, targets)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            pred_std = pred.std().item()  # track prediction spread to spot constant collapse

        print(f"[AMPL] Epoch {epoch+1}/{epochs} | Avg Loss: {total_loss / len(dataloader):.6f} | Pred Std: {pred_std:.6f}")

def test(model, dataloader, device):
    model.eval()
    output_dir = './cnn_predictions/'
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for i, (inputs, _) in enumerate(dataloader):
            inputs = inputs.to(device)
            outputs = model(inputs).cpu()
            input_np = inputs.cpu().numpy()
            output_np = outputs.numpy()  # outputs are already on the CPU


            for j in range(inputs.shape[0]):
                idx = i * dataloader.batch_size + j
                # Recover the original file id, whether or not the dataset is wrapped in a Subset
                if isinstance(dataloader.dataset, Subset):
                    file_id = dataloader.dataset.dataset.fileids[dataloader.dataset.indices[idx]]
                else:
                    file_id = dataloader.dataset.fileids[idx]

                np.save(os.path.join(output_dir, f'tweezer_{file_id}_predicted_ampl.npy'), output_np[j, 0])
                np.save(os.path.join(output_dir, f'tweezer_{file_id}_input_ampl.npy'), input_np[j, 0])

if __name__ == '__main__':
    input_dir_ampl = './WGS_CNN_Cropped/ampl_input'
    target_dir_ampl = './WGS_CNN_Cropped/wgs_amplitude'

    dataset = AmplitudeDataset(input_dir_ampl, target_dir_ampl)
    train_ids, test_ids = train_test_split(range(len(dataset)), test_size=0.2, random_state=42)
    train_dataset = Subset(dataset, train_ids)
    test_dataset = Subset(dataset, test_ids)

    train_loader = DataLoader(train_dataset, batch_size=15, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=15, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model = SingleOutputFCN()
    print("\n=== Training Amplitude Model ===")
    train(model, train_loader, device)

    print("\n=== Testing Amplitude Model ===")
    test(model, test_loader, device)

Hi Thunlok!

As I understand it, a given input sample is a 60x60 image that is mostly zero pixels and
15 pixels that are 1. Furthermore, the location of the 1 pixels differs from one input sample
to the next. So conceptually the problem is to predict the amplitude (and phase) for the
pixels that are 1 in the input mask based solely on the locations of those 1 pixels.

This is very confusing to me as I don’t see how your model could possibly predict the
phases.

If I understand your code correctly, your model pumps the input image through three
convolutions with kernel size 3. After the first 3x3 convolution, a given pixel that has
value 1 will only affect a 3x3 square of surrounding pixels. After two such convolutions,
the affected square is 5x5, and after the third, it is 7x7.

Saying the same thing in a slightly different way, the prediction for a given output pixel
can only depend on input pixels that are within three steps of the output-pixel location
in both the x and y directions.
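
You can check this empirically with a quick throwaway sketch (my own code, using the same three-conv architecture as yours): toggle one input pixel and see which output pixels change.

import torch
import torch.nn as nn

# Compare the output for an all-zero input with the output when a single
# pixel is switched on; only a 7x7 patch of the output should change.
model = nn.Sequential(
    nn.Conv2d(1, 16, kernel_size=3, padding='same'), nn.LeakyReLU(0.1),
    nn.Conv2d(16, 16, kernel_size=3, padding='same'), nn.LeakyReLU(0.1),
    nn.Conv2d(16, 1, kernel_size=3, padding='same'),
).eval()

with torch.no_grad():
    zeros = torch.zeros(1, 1, 60, 60)
    delta = zeros.clone()
    delta[0, 0, 30, 30] = 1.0                         # one active pixel in the middle
    diff = (model(delta) - model(zeros)).abs()[0, 0]
    rows, cols = torch.nonzero(diff > 1e-8, as_tuple=True)
    print(rows.min().item(), rows.max().item())       # should lie within 27..33 (7 rows)
    print(cols.min().item(), cols.max().item())       # should lie within 27..33 (7 cols)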

For the specific example you posted, I see pairs of input pixels that are no closer to one
another (in both the x and y directions) than 8.

That is to say, the predicted output at the location of a given input pixel depends only on a
7x7 field of 0 pixels with a single 1 pixel in the middle. (This is because no other 1 pixel in
the input image is close enough to fall within that 7x7 field.) No information about the pattern
and relative locations of the 1 pixels in the input ever makes it to the prediction for a specific
output pixel.

There’s something I must be missing here, because the above analysis tells me that the
predictions for inputs that look like the sample you posted – 1 pixels for which the separation
in either the x or y direction is at least 8 – should just be a translationally-invariant constant
that is about equal to the average of the target values. (For the amplitudes you posted, the
constant prediction is 0.9177 and the average target is 0.9182)

Can you clarify what is going on here? Are your input images truly binary – all the pixels are
either exactly 0 or 1 (with 15 1 pixels in an input image)?

Best.

K. Frank

Hi K. Frank,

Firstly, many thanks for taking the time to reply. I really do appreciate it.

I think your interpretation is spot on given the code structure I originally shared.

Let me clarify the intent and what I’m actually trying to achieve:

Yes, you’re correct that each input is a 60×60 binary mask with exactly 15 pixels set to 1 (and the rest 0), and those 1-pixel positions vary per sample. The target is a corresponding 60×60 array where the same 15 pixel locations have values between 0.6 and 1.0, representing amplitudes, and all other pixels are 0. So, in principle, the model’s task is to predict which value to assign at each active pixel location, based solely on the global arrangement of those 15 pixels.

You’re also absolutely right that in my initial CNN architecture the receptive field was too small to capture anything beyond the immediate 7×7 neighborhood. So essentially, the model was only learning to recognize the presence of a 1 pixel and defaulting to predicting the mean amplitude for all of them, which is why I saw the constant collapse around ~0.9177.

However, the goal isn’t to predict amplitude purely from the local pattern around a 1 pixel; it’s to infer amplitude from the global configuration of all 15 active pixels. The assumption is that the amplitude at a given location depends on where the other 14 are placed, possibly due to interference or overlap effects (e.g. in a holography context).

Does that make more sense? I think my CNN is too crude to do what I want it to do. Something like a U-Net could work, but I don’t know enough about it to implement it at this point.

Did that help? If not, let me know and I can happily re-explain.

Sincerely,
thunlok

Also, I want to mention that I think a U-Net, a generative adversarial network, or a graph neural network could be applicable here, but I don’t know enough about them to implement them.

Hi Thunlok!

Yes, this is what I understood the use case to be.

I’m still baffled as to why you didn’t get the same result when training on phases. I would have
expected (as with the amplitudes) the phase prediction to be a translationally-invariant constant.

Yes.

U-Net also occurred to me as a possible approach. You would want the receptive field of your
U-Net, of course, to be large enough to contain all of the relevant active pixels, perhaps simply
by having the receptive field contain the entire 60x60 input image.
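
As a very rough sketch (the layer sizes are arbitrary choices of mine, not a recommendation): a small U-Net-style encoder-decoder with two 2x downsampling stages already has a much larger receptive field than your current three 3x3 convolutions; to truly cover the whole 60x60 image you would add more downsampling stages or a global-pooling branch at the bottleneck.

import torch
import torch.nn as nn

# Tiny U-Net-style model: encoder with two 2x downsamplings, a bottleneck,
# a decoder with skip connections, and a one-channel output the size of the input.
class TinyUNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.enc1 = nn.Sequential(nn.Conv2d(1, 16, 3, padding=1), nn.ReLU())
        self.enc2 = nn.Sequential(nn.Conv2d(16, 32, 3, padding=1), nn.ReLU())
        self.pool = nn.MaxPool2d(2)
        self.bott = nn.Sequential(nn.Conv2d(32, 64, 3, padding=1), nn.ReLU())
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.dec2 = nn.Sequential(nn.Conv2d(64 + 32, 32, 3, padding=1), nn.ReLU())
        self.dec1 = nn.Sequential(nn.Conv2d(32 + 16, 16, 3, padding=1), nn.ReLU())
        self.out = nn.Conv2d(16, 1, kernel_size=1)

    def forward(self, x):                                     # x: [B, 1, 60, 60]
        e1 = self.enc1(x)                                     # 60x60
        e2 = self.enc2(self.pool(e1))                         # 30x30
        b = self.bott(self.pool(e2))                          # 15x15
        d2 = self.dec2(torch.cat([self.up(b), e2], dim=1))    # 30x30
        d1 = self.dec1(torch.cat([self.up(d2), e1], dim=1))   # 60x60
        return self.out(d1)

print(TinyUNet()(torch.zeros(1, 1, 60, 60)).shape)            # torch.Size([1, 1, 60, 60])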

Another approach – which to my mind might be more promising – would be to use as input
to your model the coordinates of the active pixels. That is, a single input sample would be
a tensor of shape [15, 2] – the x and y coordinates of the 15 active pixels. The model could
then be a series of fully-connected layers.

This scheme would seemingly be easier for the model – the model doesn’t have to learn how
to locate the active pixels and doesn’t have to learn that the locations of the active pixels are
the things that matter.

Regardless of what approach you take, I would recommend training on both the amplitudes
and phases simultaneously – more shared information. Internal to the model, the learned
“features” that help predict the amplitudes are likely to be able to help predict the phases,
as well.

So in the scheme where you input the coordinates of the 15 active pixels (with shape [15, 2]),
the target and predicted output would be the amplitude and phase of each of the active pixels,
again with shape [15, 2].
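
A minimal sketch of that scheme (the layer widths, normalization, and names are my assumptions, not a tested recipe):

import torch
import torch.nn as nn

# Coordinate-based model: input is the (x, y) coordinates of the 15 active
# pixels, output is (amplitude, phase) for each of those pixels.
class CoordMLP(nn.Module):
    def __init__(self, n_pixels=15):
        super().__init__()
        self.n_pixels = n_pixels
        self.net = nn.Sequential(
            nn.Linear(n_pixels * 2, 256), nn.ReLU(),
            nn.Linear(256, 256), nn.ReLU(),
            nn.Linear(256, n_pixels * 2),        # amplitude and phase per active pixel
        )

    def forward(self, coords):                   # coords: [batch, n_pixels, 2], scaled to [0, 1]
        out = self.net(coords.flatten(1))
        return out.view(-1, self.n_pixels, 2)

model = CoordMLP()
dummy = torch.rand(4, 15, 2)                     # 4 samples of 15 (x, y) coordinates
print(model(dummy).shape)                        # torch.Size([4, 15, 2])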

Best.

K. Frank

Hi K. Frank,

Thanks for the help! About the phase part: the CNN is just learning the identity transformation, because the input phase and the phase used in the loss function are the same values. We can disregard it for now, since that result isn’t actually useful; I would like to build an impactful and robust method for this.

To give you more insight, I would like to do something like this paper: https://arxiv.org/pdf/2401.06014

They use a U-Net, but I want to go beyond their paper and try graph or generative neural networks. Does that make sense? It would be for the same use case as in the paper.

Sincerely,
Thunlok

Also, I didn’t mention it, but the CNN should eventually accept an arbitrary number of activated pixels as input. I am training on a set of 1800 images with exactly 15 pixels just to show the simplest case; once that works, my goal is to expand to a random number of activated pixels (probably up to 500) at random positions. The goal would then be to say: “Hey CNN, take this amplitude input (and maybe also phase) of these X pixels at these positions and give me the best estimate of what the true result is.” Does that make sense?

Hi Thunlok!

I hate to rule anything out, but I don’t see anything in (what I think is) your use case that
would make a graph or generative model appropriate.

It would be helpful if you could give an overview of your actual real-world use case.

From the various things you’ve said, I’m now imagining that your input is the locations of
the active pixels together with some sort of measure of the amplitude and phase and you
want to predict some kind of cleaned-up, maybe noise-reduced amplitude and phase for
those pixels.

It still seems plausible to me that inputting the x and y coordinates of the active pixels
(rather than force the model to learn to locate them) could be a better way to go.

As I understand it, samples of your real data have varying numbers of active pixels.
Even so, you could still specify some maximum number of active pixels, say 500, and
pad your input with “dummy” pixels to get an input length of 500. You might also input
a flag field that indicates which pixels are active.

So your input might have shape [500, 5] with each pixel (active or dummy) having
the values x, y, amplitude, phase, and flag. (I would probably use flag = 1
for active pixels and flag = 0 for dummy pixels.)
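
A rough sketch of building one such padded sample (the field order and dummy values are just my choices):

import numpy as np

# Pad a variable-length sample to a fixed [max_pixels, 5] array whose columns
# are x, y, amplitude, phase, flag (flag = 1 for active pixels, 0 for dummies).
def pad_sample(coords, ampl, phase, max_pixels=500):
    n = len(coords)                              # number of active pixels in this sample
    out = np.zeros((max_pixels, 5), dtype=np.float32)
    out[:n, 0:2] = coords                        # (x, y) per active pixel
    out[:n, 2] = ampl
    out[:n, 3] = phase
    out[:n, 4] = 1.0                             # flag
    return out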

It is true that with such a scheme the network would need to learn to pick the active
pixels out of the 500 total pixels, active and dummy, that were input. But this doesn’t
seem any harder than having the network learn to find the active pixels in a 60x60
image that contains 3600 pixels.

One last thought: If your problem is in some sense translationally invariant – that is, if
the prediction for a set of three or four nearby active pixels in the upper left of the image
would be the same regardless of where other active pixels are in the image, as it would
for the same pattern of three or four pixels in the lower right – then a (fully) convolutional
network would be a natural fit, as the convolutions have that translational invariance built
in. But if the prediction for active pixels in the upper right depends globally on the locations
of all of the active pixels – not just their nearby neighbors – then a convolutional network
wouldn’t necessarily have any special benefit for your use case.

Best.

K. Frank

Hi K. Frank,

Thank you for the reply. I think you missed my additional comment where I mention the paper that I am trying to improve on. I am doing the same thing as them in terms of use case, but I want a better solution, which would most likely require a more advanced network. Let me know what you think of the paper!

I made a new post that goes into better detail of the use case if you would like to take a gander!