NaN embedding layer

Hi dear forum!

I’m dealing with intensive care data at the moment (see MIMIC-IV on physionet.org). The data contains many missing values, and I’m trying out methods that handle these NaNs directly by embedding them properly, instead of imputing the missing values (e.g. with the mean or median).

I created a NanEmbedding layer, see below. It takes in a batch of 1-dimensional feature vectors that can contain NaNs. Each feature is projected to an out_size-dimensional vector by its own linear layer. All feature embedding vectors are then summed up (the code below actually takes the mean), and the vectors of features that are NaN are set to 0 (i.e. ignored) in the summation. This allows the embedding to distinguish between a regular value, an input of 0, and a NaN:

import torch

class NanEmbed(torch.nn.Module):
    def __init__(self, in_size, out_size, use_conv=True):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        # create embedding weights
        self.emb_layers = torch.nn.ModuleList([torch.nn.Linear(1, out_size) for _ in range(in_size)])
        
    def forward(self, x):
        # embed each feature into a larger embedding vector of size out_size
        out = torch.stack([layer(x[:, i].unsqueeze(1)) for i, layer in enumerate(self.emb_layers)], dim=-1)
        # method 1 for setting NaNs to 0
        #with torch.no_grad():
        #    out[torch.isnan(out)] = 0
        # method 2 (current method)
        out = torch.nan_to_num(out)
        emb = out.mean(dim=-1)
        # method 3 (slow and also yields NaN grads)
        #mask = torch.isnan(x)
        #bs = x.shape[0]
        #emb = torch.stack([out[i][:, torch.where(mask[i])[0]].sum(dim=-1) / self.in_size
        #                    for i in range(bs)])
        return emb

Now, I have two big problems:

  1. Applying these individual linear layers in a for-loop is slow and ugly. So I tried to adapt the solution from "How to apply different kernels to each example in a batch when using convolution? - #3 by postBG" to 1-D. I replaced my list of linear layers with: conv = torch.nn.Conv1d(in_size, in_size * out_size, 1, stride=1, padding=0, groups=in_size, bias=True). This projects my input of shape (batch_size, feature_num==in_size, 1) to (batch_size, in_size * out_size, 1), so it seems to work if I reshape the output. My question: does this do the right thing, i.e. apply individual linear weights to each single feature?
  2. The bigger problem: this approach does not work. The gradients of the weights of a feature's linear layer become NaN as soon as a single value of that feature is NaN anywhere in the batch, so the network does not learn at all. The forward pass above shows the three methods I tried. I don’t understand at all why the gradients turn to NaN… I would be happy about any hints. As far as I can see, I might need to write a custom backward function, but I’ve never done that before, so some help would be greatly appreciated.
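A minimal sketch of where the NaN weight gradients come from, assuming a recent PyTorch version. A single torch.nn.Linear(1, 3) stands in for one of the per-feature layers, and the input values are made up for illustration. The point it shows: nan_to_num on the layer *output* hides the NaN in the forward pass, but Linear still uses its raw input to compute the weight gradient (grad_out^T @ x), and 0 * NaN is NaN. The bias gradient never touches the input, which is consistent with the finite bias norms in the output of the minimal example in this thread.

```python
import torch

torch.manual_seed(0)
# One Linear(1, 3) layer standing in for a single per-feature layer of NanEmbed.
layer = torch.nn.Linear(1, 3)
# One feature for a batch of 3 samples; the second value is missing.
x = torch.tensor([[0.5], [float("nan")], [2.0]])

# Forward as in "method 2": NaN outputs are replaced by 0 *after* the layer.
out = torch.nan_to_num(layer(x))
out.sum().backward()
# weight.grad = grad_out^T @ x: the NaN row of x enters the product, and 0 * NaN = NaN.
nan_weight_grad = layer.weight.grad.clone()
bias_grad = layer.bias.grad.clone()   # does not involve x, so it stays finite
print(nan_weight_grad)                # all entries are NaN

# Alternative: clean the *input* before the layer, then zero the NaN outputs.
layer.zero_grad()
mask = torch.isnan(x)                                    # (3, 1) boolean mask
out = layer(torch.nan_to_num(x)).masked_fill(mask, 0.0)  # mask broadcasts over out_size
out.sum().backward()
print(layer.weight.grad)              # finite
```

Removing the NaNs from the input before the layer, and masking the outputs afterwards, keeps every gradient finite without a custom backward function.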

Hope you’re having a nice day and I’m looking forward to any responses!
Anton


Here’s a minimal example to get the NaNs in the gradients using the layer above with a batch size of 4 and 5 input features:

import numpy as np
import torch

# create inputs with NaNs
emb_inputs = torch.rand(4, 5)
mask = torch.rand(*emb_inputs.shape) > 0.5
emb_inputs[mask] = np.nan
print("NaNs in input: ", torch.isnan(emb_inputs).sum().item())
# create NaN embedding layer
embed = NanEmbed(5, 8)
# apply layer
embedded_input = embed(emb_inputs)
print("NaNs in embedding: ", torch.isnan(embedded_input).sum().item())
# backward
embedded_input.mean().backward()
# check for NaN in grads
for n, p in embed.named_parameters():
    print(n, "\t", p.grad.norm())

The output is:

NaNs in input:  12
NaNs in embedding:  0
emb_layers.0.weight 	 tensor(nan)
emb_layers.0.bias 	 tensor(0.0354)
emb_layers.1.weight 	 tensor(nan)
emb_layers.1.bias 	 tensor(0.0354)
emb_layers.2.weight 	 tensor(nan)
emb_layers.2.bias 	 tensor(0.0177)
emb_layers.3.weight 	 tensor(nan)
emb_layers.3.bias 	 tensor(0.0177)
emb_layers.4.weight 	 tensor(nan)
emb_layers.4.bias 	 tensor(0.0354)

Paging @ptrblck, do you have any idea how to fix this? It seems like I’m not detaching the NaN results properly…

@NotNANtoN Very interesting problem. I am not sure how you would solve this with an embedding, but for your current problem, the following could be one solution:

  • You cannot insert NaNs into the tensor that acts as the input to the model. If you do, you will have to add hooks etc. to alter the data during each gradient computation. This is a lengthy solution.

  • We can leverage the fact that you know which positions in your input data are NaN. We use this mask itself as an extension of your embedding, which the model can then use to distinguish between NaNs and non-NaNs.

class NanEmbed(torch.nn.Module):
    def __init__(self, in_size, out_size, use_conv=True):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        # create embedding weights
        self.emb_layers = torch.nn.ModuleList([torch.nn.Linear(1, out_size) for _ in range(in_size)])
        
    def forward(self, x, nan_indices):
        # embed each feature into a larger embedding vector of size out_size
        cur_mask = torch.zeros(x.size())
        cur_mask[nan_indices] = 1
        out = torch.stack([layer(x[:, i].unsqueeze(1)) for i, layer in enumerate(self.emb_layers)], dim=-1)
        emb = out.mean(dim=-1)
        emb = torch.cat((emb, x), 1)
        return emb
    

    
emb_inputs = torch.rand(4, 5)
mask = torch.rand(*emb_inputs.shape) > 0.5
emb_inputs[mask] = np.nan
print("NaNs in input: ", torch.isnan(emb_inputs).sum().item())
# create NaN embedding layer
embed = NanEmbed(5, 8)

# apply layer
emb_inputs = torch.nan_to_num(emb_inputs)
embedded_input = embed(emb_inputs, mask)
print("NaNs in embedding: ", torch.isnan(embedded_input).sum().item())
# backward
embedded_input.mean().backward()
# check for NaN in grads
for n, p in embed.named_parameters():
    print(n, "\t", p.grad.norm())

The output is:

NaNs in input:  10
[torch.Size([4, 8]), torch.Size([4, 5])]
torch.Size([4, 13])
NaNs in embedding:  0
emb_layers.0.weight 	 tensor(0.0239)
emb_layers.0.bias 	 tensor(0.0435)
emb_layers.1.weight 	 tensor(0.0018)
emb_layers.1.bias 	 tensor(0.0435)
emb_layers.2.weight 	 tensor(0.0187)
emb_layers.2.bias 	 tensor(0.0435)
emb_layers.3.weight 	 tensor(0.0036)
emb_layers.3.bias 	 tensor(0.0435)
emb_layers.4.weight 	 tensor(0.0066)
emb_layers.4.bias 	 tensor(0.0435)

Hi, thanks for the response!

In your modified forward() you define the cur_mask depending on the nan_indices but then it is not used anywhere.
But I think I know what you mean; I tried it out here:

class NanEmbedOld(torch.nn.Module):
    def __init__(self, in_size, out_size, use_conv=True):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        # create embedding weights
        self.emb_layers = torch.nn.ModuleList([torch.nn.Linear(1, out_size) for _ in range(in_size)])
        
    def forward(self, x):
        # create mask to later fill with zeros
        mask = torch.isnan(x)
        x = torch.nan_to_num(x)
        # embed each feature into a larger embedding vector of size out_size
        out = torch.stack([layer(x[:, i].unsqueeze(1)) for i, layer in enumerate(self.emb_layers)], dim=-1)
        # shape [batch size, out_size, in_size]
        # fill embedding with 0 where we had a NaN before
        repeated_mask = mask.unsqueeze(1).repeat(1, self.out_size, 1)
        out[repeated_mask] = 0
        # average the embedding
        emb = out.mean(dim=-1) 
        return emb

This actually works and, most importantly, it can differentiate between a 0 and a NaN in its input :slight_smile:

Thanks! Now I just need to check whether it works with the convolution too.
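As an alternative to the grouped convolution, the per-feature linear layers can also be written with plain broadcasting: each Linear(1, out_size) only computes x_i * w_i + b_i, so the parameters of all features fit into two (in_size, out_size) tensors. The class name NanEmbedBroadcast is hypothetical, and this is a sketch that mirrors the masking and averaging of NanEmbedOld above, not a tested drop-in replacement.

```python
import torch


class NanEmbedBroadcast(torch.nn.Module):
    """Hypothetical vectorized variant of NanEmbedOld: one weight row and one bias
    row per feature, applied by broadcasting instead of a Python loop."""

    def __init__(self, in_size, out_size):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        # weight[i] / bias[i] play the role of emb_layers[i].weight[:, 0] / emb_layers[i].bias
        self.weight = torch.nn.Parameter(torch.empty(in_size, out_size))
        self.bias = torch.nn.Parameter(torch.empty(in_size, out_size))
        # nn.Linear(1, out_size) initializes both from U(-1, 1), since fan_in = 1
        torch.nn.init.uniform_(self.weight, -1.0, 1.0)
        torch.nn.init.uniform_(self.bias, -1.0, 1.0)

    def forward(self, x):
        # x: (batch, in_size), may contain NaNs
        mask = torch.isnan(x)          # remember where the NaNs were
        x = torch.nan_to_num(x)        # NaN -> 0 *before* any multiplication
        # (batch, in_size, 1) * (in_size, out_size) + (in_size, out_size)
        # -> (batch, in_size, out_size)
        out = x.unsqueeze(-1) * self.weight + self.bias
        # zero the embedding vectors of NaN features
        out = out.masked_fill(mask.unsqueeze(-1), 0.0)
        # average over features -> (batch, out_size), as in NanEmbedOld
        return out.mean(dim=1)
```

Because feature i always uses row i of weight and bias, two features with the same value still get different embeddings, and an input of 0 (embedding bias[i]) stays distinguishable from a NaN (embedding 0).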

@NotNANtoN ohh yes my bad

The forward function should end like this:

emb = torch.cat((emb, cur_mask), 1)
return emb

So now I’m trying to make it faster by using a grouped convolution to avoid iterating over all layers with a for-loop:

class NanEmbedFast(torch.nn.Module):
    def __init__(self, in_size, out_size, use_conv=True):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        # create embedding weights
        #self.emb_layers = torch.nn.ModuleList([torch.nn.Linear(1, out_size) for _ in range(in_size)])
        self.emb_layers = torch.nn.Conv1d(in_size, in_size * out_size, 1, 
                                          stride=1, padding=0, groups=in_size, bias=True)
        
    def forward(self, x):
        # create mask to later fill with zeros
        mask = torch.isnan(x)
        x = torch.nan_to_num(x)
        # embed each feature into a larger embedding vector of size out_size
        out = self.emb_layers(x.unsqueeze(-1)).reshape(x.shape[0], self.out_size, x.shape[1])
        # shape [batch size, out_size, in_size]
        # fill embedding with 0 where we had a NaN before
        repeated_mask = mask.unsqueeze(1).repeat(1, self.out_size, 1)
        out[repeated_mask] = 0
        # average the embedding
        emb = out.mean(dim=-1) 
        return emb

This runs and does not give any shape errors, as you can see below. Then I wanted to test whether it embeds correctly. I want it to have different weights for each single feature, so if feature A is set to one value for the whole batch (let’s say 1) and feature B is also set to 1, the resulting embeddings should still differ because of the different weights.

# create inputs with NaNs
emb_inputs = torch.rand(4, 5)
mask = torch.rand(*emb_inputs.shape) > 0.5
emb_inputs[mask] = np.nan
print("NaNs in input: ", torch.isnan(emb_inputs).sum().item())
# create NaN embedding layer
fast_embed = NanEmbedFast(5, 8)
# apply layer
embedded_input = fast_embed(emb_inputs)
print(embedded_input)
print("NaNs in embedding: ", torch.isnan(embedded_input).sum().item())
# backward
embedded_input.mean().backward()
# check for NaN in grads
for n, p in fast_embed.named_parameters():
    print(n, "\t", p.grad.norm())

# create input batch with two identical feature values
test_input = emb_inputs.clone()
test_input[:, 0] = 1.0
test_input[:, 1] = 1.0
# make raw embedding prediction without summarization
out = fast_embed.emb_layers(torch.nan_to_num(test_input).unsqueeze(-1)).reshape(test_input.shape[0], fast_embed.out_size, test_input.shape[1])
print(out.shape)
print(out)

leads to

NaNs in input:  11
tensor([[-0.2859, -0.1354, -0.1115,  0.6052,  0.0163, -0.3774,  0.2210, -0.1135],
        [-0.1800, -0.1466,  0.1830,  0.1339, -0.0303,  0.0381, -0.0929, -0.1570],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [-0.2862, -0.1329, -0.1177,  0.5968,  0.1051, -0.3774,  0.1886, -0.1138]],
       grad_fn=<MeanBackward1>)
NaNs in embedding:  0
emb_layers.weight 	 tensor(0.0313)
emb_layers.bias 	 tensor(0.0810)
torch.Size([4, 8, 5])
tensor([[[-0.8431, -0.1716, -1.1135, -0.6864,  0.2439],
         [-1.4712,  0.5458,  0.0835,  0.4813,  0.0860],
         [-0.0350,  0.5864, -0.7791, -0.1491, -0.4795],
         [-0.1469,  1.2145,  0.6291, -1.0555,  0.7335],
         [-0.1232,  0.1440,  0.4199,  1.1142, -0.3590],
         [ 0.1903, -0.8523, -0.4901,  0.2158, -0.7350],
         [-0.4645,  0.3499,  0.9803,  0.5351,  0.2395],
         [-0.6844, -0.6714, -0.4080,  0.9382,  1.1966]],

        [[-0.8431, -0.1716, -1.1135, -0.6864,  0.2439],
         [-1.4712,  0.5458,  0.0835,  0.4813,  0.0860],
         [-0.0350,  0.5864, -0.7791, -0.1491, -0.4795],
         [-0.1469,  0.9791,  0.9197, -0.7255,  0.5619],
         [-0.1517, -0.1407,  0.1001,  0.7892, -0.3590],
         [ 0.1903, -0.8523, -0.4901,  0.2158, -0.7350],
         [-0.4645,  0.3499,  0.8218,  0.5664, -0.1773],
         [-0.7849, -0.5737, -0.0603,  0.8486,  0.8463]],

        [[-0.8431, -0.1716, -1.1135, -0.6864,  0.2439],
         [-1.4712,  0.5458,  0.0835,  0.4813,  0.0860],
         [-0.0350,  0.5864, -0.7791, -0.1491, -0.4795],
         [-0.1469,  0.9791,  0.9197, -0.7255,  0.5619],
         [-0.1517, -0.1407,  0.1001,  0.7892, -0.3590],
         [ 0.1903, -0.8523, -0.4901,  0.2158, -0.7350],
         [-0.4645,  0.3499,  0.8218,  0.5664, -0.1773],
         [-0.7849, -0.5737, -0.0603,  0.8486,  0.8463]],

        [[-0.8431, -0.1716, -1.1135, -0.6864,  0.2439],
         [-1.4712,  0.5458,  0.0835,  0.4813,  0.0860],
         [-0.0350,  0.5864, -0.7791, -0.1491, -0.4795],
         [-0.1469,  1.3797,  0.4252, -1.2869,  0.8538],
         [-0.1031,  0.3436,  0.6442,  1.3421, -0.3590],
         [ 0.1903, -0.8523, -0.4901,  0.2158, -0.7350],
         [-0.4645,  0.3499,  0.9357,  0.5439,  0.1221],
         [-0.7127, -0.6439, -0.3101,  0.9129,  1.0980]]],
       grad_fn=<ReshapeAliasBackward0>)

Unfortunately, you can see that all outputs for the batch are the same :-/
I mean, basically I want to create a PyTorch Embedding layer that takes a continuous value as input instead of an integer. Maybe I need to write that myself?
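One possible explanation, worth checking: with groups=in_size, Conv1d gives each input channel (feature) its own block of out_size consecutive output channels, so output channel c comes from feature c // out_size. Reshaping the (batch, in_size * out_size, 1) output directly to (batch, out_size, in_size), as NanEmbedFast does, mixes channels of different features within each row. In the printed tensor, the first three rows of every sample only contain channels of features 0 and 1, which were set to 1.0 in every sample, so they are identical across the batch. The sketch below compares both reshapes against explicit per-feature linear maps built from the conv parameters; it is a check under these assumptions, not a verified fix for the whole model.

```python
import torch

torch.manual_seed(0)
batch, in_size, out_size = 4, 5, 8

# Same grouped convolution as in NanEmbedFast: one group per feature.
conv = torch.nn.Conv1d(in_size, in_size * out_size, 1, groups=in_size, bias=True)
x = torch.rand(batch, in_size)

raw = conv(x.unsqueeze(-1))                 # (batch, in_size * out_size, 1)

# Reference: explicit per-feature linear maps taken from the conv parameters.
# conv.weight has shape (in_size * out_size, 1, 1); the channels of feature i
# are the consecutive block [i * out_size, (i + 1) * out_size).
w = conv.weight.view(in_size, out_size)     # w[i] = weights of feature i
b = conv.bias.view(in_size, out_size)       # b[i] = bias of feature i
ref = (x.unsqueeze(-1) * w + b).transpose(1, 2)  # (batch, out_size, in_size)

# Reshape used in NanEmbedFast: interleaves channels of different features.
wrong = raw.reshape(batch, out_size, in_size)
# Feature-major reshape first, then swap the last two dims.
right = raw.reshape(batch, in_size, out_size).transpose(1, 2)

print(torch.allclose(wrong, ref))   # False
print(torch.allclose(right, ref))   # True
```

If this holds, the grouped convolution does apply individual weights to each feature, and in NanEmbedFast.forward only the reshape would need to change to .reshape(x.shape[0], x.shape[1], self.out_size).transpose(1, 2); the masking and averaging already expect (batch, out_size, in_size).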