Implementing a k-sparse autoencoder on FastText embeddings, the output is strange

Hi, I’m implementing k-Sparse Autoencoders (Makhzani & Frey, 2013), and with the implementation I’m trying to sparse-code my pre-trained word embeddings. In my understanding, the algorithm in the paper is basically an overcomplete autoencoder with the constraint that only the top-k activations in the hidden layer are kept.
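Roughly, my understanding of the forward pass in code (a sketch with illustrative shapes, not the paper’s exact notation):

import torch

x = torch.randn(8, 300)               # a batch of 300-d embeddings
encoder = torch.nn.Linear(300, 1000)  # overcomplete hidden layer
decoder = torch.nn.Linear(1000, 300)

z = encoder(x)                        # hidden activations, shape (8, 1000)
k = 15
values, indices = torch.topk(z, k)    # the k largest activations per row
z_sparse = torch.zeros_like(z).scatter(-1, indices, values)
x_hat = decoder(z_sparse)             # reconstruct from the k survivors only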

What I’m struggling with is that most of the resulting sparse embeddings (the top-k activations for each embedding row) are all 0s, even though the loss decreases properly. Of course the goal of sparse coding is to preserve only the information (activations) that matters, but the problem is that there’s no information left at all!

I’m attaching the code segments below for advice. Here’s the full code, and the data. The data is an .hdf5 file with the fields ['words', 'vectors'].
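For reference, I load it roughly like this (the file name is a placeholder):

import h5py

with h5py.File('embeddings.hdf5', 'r') as f:  # placeholder path
    vectors = f['vectors'][:]  # (N, 300) float32 ('f4') matrix
    words = f['words'][:]      # (N,) words (h5py may return these as bytes)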

# EmbeddingDataset
It consists of vectors (np array, ‘f4’) and words (np array, str; not used in training).

from torch.utils import data


class EmbeddingDataset(data.Dataset):
    def __init__(self, vectors, words):
        self.vectors = vectors  # (N, 300) float32 embedding matrix
        self.words = words      # (N,) word strings, kept for inspection only

    def __len__(self):
        return len(self.words)

    def __getitem__(self, index):
        return self.vectors[index], self.words[index]
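I wrap this in a DataLoader to get the data_generator used in the training loop below (the batch size is just an example):

from torch.utils.data import DataLoader

dataset = EmbeddingDataset(vectors, words)
data_generator = DataLoader(dataset, batch_size=256, shuffle=True)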

# KSparseAutoencoder

  • Hidden size = 1000, k = e.g. 15 (turning each 300-dimensional embedding into a sparse 1000-dimensional one with at most k non-zero dimensions)
    • starting directly at k = 15 might produce many dead activations before useful gradients form, so k gradually decreases from 100 to 15 (the authors’ suggestion; see the schedule demo after the class code)
  • In the encode stage, I use a capped ReLU (min = 0, max = 1) so the resulting embedding is non-negative and bounded.
  • I select the top-k activations in the forward method and zero out the others, then proceed to the decode stage with only the surviving values. Please check the code for any errors.
import logging
import math

import numpy as np
import torch
import torch.nn as nn


class KSparseAutoencoder(nn.Module):
    def __init__(self, D_in_out, H, k, total_epoch, use_activation=True):
        """
        In the constructor we instantiate two nn.Linear modules and assign them as
        member variables.
        """
        super().__init__()
        self.encoder = nn.Linear(D_in_out, H)
        self.decoder = nn.Linear(H, D_in_out)
        self.H = H
        self.k = k
        self.use_activation = use_activation
        # Anneal k linearly from k_start down to k over the first half of
        # training, then hold it at k.
        k_grace_period = total_epoch // 2
        k_start = max(k * 2, 100)
        logging.info(f"k = {k_start}~{k}")
        self.k_list = [math.floor(v) for v in np.linspace(k_start, k, num=k_grace_period)] + [k] * (total_epoch - k_grace_period)
        
    def forward(self, x, epoch, is_final=False):
        """
        Accept a Tensor of input data, keep only the k largest hidden
        activations per row, and return both the sparse code and the
        reconstruction.
        """
        # Encode
        encoded = self.encoder(x)
        if self.use_activation:
            encoded = encoded.clamp(min=0, max=1)  # capped ReLU

        # Select top k: zero out the H - k smallest activations of each row
        k = int(self.k * 1.5) if is_final else self._get_k(epoch)
        for e in encoded:
            _, indices_to_erase = torch.topk(e, self.H - k, largest=False)
            e[indices_to_erase] = 0.0

        # Decode only from the surviving values
        decoded = self.decoder(encoded)
        return encoded, decoded

    def _get_k(self, epoch):
        # k_list already holds ints, so no extra rounding is needed
        return self.k_list[epoch]
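For reference, with e.g. total_epoch = 30 and k = 15, the schedule anneals like this:

import math
import numpy as np

total_epoch, k = 30, 15
k_grace_period = total_epoch // 2
k_start = max(k * 2, 100)
k_list = [math.floor(v) for v in np.linspace(k_start, k, num=k_grace_period)] + [k] * (total_epoch - k_grace_period)
print(k_list[:k_grace_period])
# [100, 93, 87, 81, 75, 69, 63, 57, 51, 45, 39, 33, 27, 21, 15]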

# Training

UPDATE: I suspect that calculating the loss like this, loss = criterion(vectors, decoded), is causing the problem. Is this a correct way to evaluate a batch of embeddings?
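At least the value seems symmetric in the two arguments, though the documented order is criterion(input, target), i.e. criterion(decoded, vectors):

import torch
import torch.nn as nn

criterion = nn.MSELoss()
a, b = torch.randn(4, 300), torch.randn(4, 300)
assert torch.isclose(criterion(a, b), criterion(b, a))  # (x - y)^2 is symmetric
# The gradient w.r.t. the reconstruction is also the same either way,
# as long as the targets don't require grad.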

with torch.cuda.device(params['gpu_idx']):
    model = KSparseAutoencoder(D_in_out, H, k, num_epochs, use_activation).to(device)
    learning_rate = 0.1
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    criterion = nn.MSELoss()
    for epoch in range(num_epochs):
        # Training
        epoch_loss = 0.0
        max_sparsity = 0.0
        for vectors, _ in data_generator:
            # Get batch (Variable is deprecated; plain tensors suffice)
            vectors = vectors.type(dtype)

            # Denoising
            if denoising and noise_level > 0.0:
                noise = get_noise_features(vectors.shape[0], vectors.shape[1], noise_level)
                noise = torch.from_numpy(noise.astype('f4')).type(dtype)
                vectors += noise

            # Move data to GPU
            vectors = vectors.to(device)

            # Forward
            encoded, decoded = model(vectors, epoch)
            max_sparsity = max(max_sparsity, check_sparsity(encoded))

            # Loss
            loss = criterion(vectors, decoded)
            epoch_loss += loss.item()  # .item() so the graph isn't kept alive

            # Zero gradients, perform a backward pass, and update the weights.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Note: MSELoss already averages over each batch, so this is an
        # average of batch averages rather than a true per-word loss.
        avg_epoch_loss = epoch_loss / len(words)
        N = 20
        print(vectors[0][:N])
        print(decoded[0][:N])
        print(f"After epoch {epoch}, Avg. Loss = {avg_epoch_loss:.6f}, Sparsity = {max_sparsity}")

If you manually set some values of a tensor to zero like that, it will not compute the gradients properly.

I implemented a k-sparse filter; you can test it out.

import torch


class KSparse(torch.nn.Module):
    def __init__(self, k):
        super().__init__()
        self.k = k

    def forward(self, x):
        # Indices of the k largest activations along the last dimension
        _, indices = torch.topk(x, self.k)
        # Build a 0/1 mask and multiply instead of assigning zeros in place
        mask = torch.zeros_like(x)
        mask.scatter_(-1, indices, 1.0)
        return x * mask


test = torch.tensor([[[1., 2., 4.], [2., 3., 1.]]]).cuda()
k_sparse = KSparse(1)
print(k_sparse(test))
# tensor([[[0., 0., 4.],
#          [0., 3., 0.]]], device='cuda:0')
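In your model, the same masking idea would replace the per-row loop in forward, e.g.:

import torch

H, k = 1000, 15
encoded = torch.rand(8, H)              # e.g. the clamped encoder output
_, indices = torch.topk(encoded, k)     # the k largest activations per row
mask = torch.zeros_like(encoded)
mask.scatter_(-1, indices, 1.0)         # 1s at the surviving positions
encoded = encoded * mask                # masking instead of in-place assignment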

Much appreciated! Thanks!