Denoising Autoencoder Testing mode for Multiclass Classification

I am training an autoencoder for a multiclass classification problem: 16 equiprobable messages are passed through a denoising autoencoder that models a noisy channel, and the receiver has to classify which message was sent. Specifically, I am trying to reproduce the result in Fig. 3b (with a slight modification) of this paper; please refer to Fig. 2 in https://arxiv.org/pdf/1702.00832.pdf for the model.
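The 16 messages are one-hot encoded before being fed to the transmitter. A minimal sketch of how such a dataset can be built (the names trainloader, train_data and train_labels match what I use later; the set size and batch size are placeholders):

import torch
from torch.utils.data import TensorDataset, DataLoader

k = 4                       # 2**k = 16 messages
num_samples = 10 ** 5       # placeholder training-set size

# Equiprobable message indices and their one-hot representations
train_labels = torch.randint(0, 2 ** k, (num_samples,))
train_data = torch.eye(2 ** k)[train_labels]        # shape: (num_samples, 16)

trainloader = DataLoader(TensorDataset(train_data, train_labels),
                         batch_size=256, shuffle=True)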

Here is my autoencoder class:

import torch
import torch.nn as nn
from math import sqrt


class FullyConnectedAutoencoder(nn.Module):
    def __init__(self, k, n_channel, EbN0_dB):
        super(FullyConnectedAutoencoder, self).__init__()
        self.k = k
        self.n_channel = n_channel
        self.EbN0_dB = EbN0_dB

        self.transmitter = nn.Sequential(
            nn.Linear(in_features=2 ** k, out_features=2 ** k, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=2 ** k, out_features=n_channel, bias=True))
        self.receiver = nn.Sequential(
            nn.Linear(in_features=n_channel, out_features=2 ** k, bias=True),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=2 ** k, out_features=2 ** k, bias=True))

    def forward(self, x):
        x = self.transmitter(x)

        # Normalization: enforce the energy constraint ||x||^2 = n_channel per codeword
        n = x.norm(dim=-1, keepdim=True)
        x = sqrt(self.n_channel) * (x / n)

        # AWGN channel at the configured Eb/N0 (I train at 3 dB)
        training_SNR = 10 ** (self.EbN0_dB / 10)   # Eb/N0 in linear scale
        R = self.k / self.n_channel                # communication rate k/n
        noise = torch.randn(x.size()) / ((2 * R * training_SNR) ** 0.5)
        x = x + noise

        x = self.receiver(x)
        return x
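For completeness, the network, loss function, and optimizer used in the training loop below are created roughly like this (a minimal sketch; the learning rate is a placeholder):

k, n_channel = 4, 7                     # 16 messages over 7 channel uses, as in Fig. 3b
net = FullyConnectedAutoencoder(k, n_channel, EbN0_dB=3.0)

loss_func = nn.CrossEntropyLoss()       # expects raw logits and integer class labels
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)   # learning rate is a placeholder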

and my training loop is as follows:

# TRAINING
for epoch in range(epochs):
    for step, (x, y) in enumerate(trainloader):  # x: one-hot messages, y: message labels
            
        # Forward pass
        output = net(x)  # output
        y = (y.long()).view(-1)
        loss = loss_func(output, y)  # cross entropy loss

        # Backward and optimize
        optimizer.zero_grad()  # clear gradients for this training step
        loss.backward()  # backpropagation, compute gradients
        optimizer.step()  # apply gradients

        if step % 100 == 0:
            train_output = net(train_data)
            pred_labels = torch.max(train_output, 1)[1].data.squeeze()
            accuracy = sum(pred_labels == train_labels) / float(train_labels.size(0))
            print('Epoch: ', epoch, '| train loss: %.4f' % loss.item(), '| train accuracy: %.4f' % accuracy)

However, I want to test my approach across different signal-to-noise ratios, and I am having trouble doing that. Here are the two approaches I have tried:

Approach 1: Declare a new object every time I test the autoencoder

for p in range(len(EbNo_test)):
    with torch.no_grad():
        for test_data, test_labels in testloader:  
            
            net = FullyConnectedAutoencoder(k, n_channel, EbNo_test[p])
            decoded_signal = net(test_data)
            
            # encoded_signal = net.transmitter(test_data) 
            # noisy_signal = encoded_signal + test_noise
            # decoded_signal =  net.receiver(noisy_signal)
            
            
            pred_labels = torch.max(decoded_signal, 1)[1].data.squeeze()
            test_BLER[p] = sum(pred_labels != test_labels) / float(test_labels.size(0))
            
    print('Eb/N0:',EbNo_test[p].numpy(), '| test BLER: %.4f' % test_BLER[p])

Approach 2: This is more intuitive. Use the transmitter and receiver parts separately and add noise after transmitting the signal.

for p in range(len(EbNo_test)):
    EcNo_test_sqrt[p] = 1/(2*R*(10**(EbNo_test[p]/20)))
    test_noise = EcNo_test_sqrt[p] * torch.randn(batch_size, n_channel)
    with torch.no_grad():
        for test_data, test_labels in testloader:  
            
            encoded_signal = net.transmitter(test_data) 
            noisy_signal = encoded_signal + test_noise
            decoded_signal =  net.receiver(noisy_signal)
            
            pred_labels = torch.max(decoded_signal, 1)[1].data.squeeze()
            test_BLER[p] = sum(pred_labels != test_labels) / float(test_labels.size(0))
            
    print('Eb/N0:',EbNo_test[p].numpy(), '| test BLER: %.4f' % test_BLER[p])
   

Strangely, I am getting wrong answers: the block error rate stays around 90%, whereas it should follow a trend like this.

[image: expected BLER vs. Eb/N0 curve]

Am I doing something wrong? Any help is much appreciated.

Your first approach would reinitialize your model, so I assume this won’t work, as you are not testing your trained model.
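A minimal sketch of how the trained model could be reused instead (this relies on your posted class reading self.EbN0_dB inside forward, so changing that attribute changes only the test noise level; the error accumulation is just one possible way to compute the BLER):

test_BLER = torch.zeros(len(EbNo_test))

for p in range(len(EbNo_test)):
    net.EbN0_dB = EbNo_test[p]      # keep the trained weights, only change the channel Eb/N0
    errors, total = 0, 0
    with torch.no_grad():
        for test_data, test_labels in testloader:
            decoded_signal = net(test_data)
            pred_labels = torch.max(decoded_signal, 1)[1]
            errors += (pred_labels != test_labels.long().view(-1)).long().sum().item()
            total += test_labels.size(0)
    test_BLER[p] = errors / total
    print('Eb/N0:', float(EbNo_test[p]), '| test BLER: %.4f' % test_BLER[p])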

The second approach looks alright. Could you check the data type of (pred_labels != test_labels) and make sure the sum does not overflow?
This shouldn’t be the case in the current stable release, but might be a problem if you are using an old PyTorch version.
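For example, something like this would make the check explicit and avoid a possible wrap-around (Python's built-in sum accumulates element-wise in the comparison's dtype, so a uint8 result can overflow at 256 on older versions):

mismatch = (pred_labels != test_labels)
print(mismatch.dtype)            # torch.bool on recent releases, torch.uint8 on older ones

# Cast before reducing so the error count cannot overflow
test_BLER[p] = mismatch.float().mean().item()
# equivalently: mismatch.long().sum().item() / float(test_labels.size(0))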
