GRU Output adjustment for rare event detection

So I'm working on rare event detection using a GRU network.
I feed multiple timelines of varying length into my network, so I use padding and packing.

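import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader


class MY_FIRST_GRU(nn.Module):
    def __init__(self):
        super(MY_FIRST_GRU, self).__init__()
        self.gru = nn.GRU(input_size=32,
                          hidden_size=20,
                          num_layers=gru_layers,
                          batch_first=True)  # Note that "batch_first" is set to "True"
        self.l_out = nn.Linear(in_features=20*1,
                               out_features=2)

    def forward(self, batch):
        x, x_length, _ = batch
        # Cast before packing; x_length must be a CPU tensor or list, sorted in
        # descending order (or pass enforce_sorted=False)
        x_pack = pack_padded_sequence(x.float(), x_length, batch_first=True)
        packed_x, hidden = self.gru(x_pack)
        # Re-pad to [batch, max_len_in_batch, 20]; padded positions are filled with zeros
        output_padded, input_sizes = pad_packed_sequence(packed_x, batch_first=True)
        output = self.l_out(output_padded)  # [batch, max_len_in_batch, 2]
        return output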



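def train(train_loader):
    dl_model.train()
    total_loss = 0
    correct = 0
    for data_list in train_loader:
        # data_list:
        # 0 = padded data [batch, max_len, features]
        # 1 = original sequence lengths
        # 2 = label
        x, x_length, y = data_list
        x, y = x.to(device), y.to(device)  # lengths stay on the CPU for packing
        optimizer.zero_grad()
        output = dl_model((x, x_length, y))
        loss = loss_fn(output, y.float())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        # Accuracy: the model returns logits, so apply sigmoid before thresholding
        pred = (output.sigmoid() > 0.5).float()
        correct += (pred == y).float().sum().item()  # accumulate over batches
    return total_loss / len(train_loader), correct / len(train_loader.dataset)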


def test(loader):
    dl_model.eval()
    correct = 0
    f1_s = 0
    with torch.no_grad():
        for data_list in loader:  # iterate the loader passed in, not the global test_loader
            x, x_length, y = data_list
            x, y = x.to(device), y.to(device)
            output = dl_model((x, x_length, y))
            pred = (output.sigmoid() > 0.5).float()  # logits -> probabilities -> 0/1
            correct += (pred == y).float().sum().item()
            f1_s += f1_score_calculation(y.cpu(), output.sigmoid().cpu() > 0.5)
    # Accuracy averaged over samples, F1 averaged over batches
    return correct / len(loader.dataset), f1_s / len(loader)

dl_model = MY_FIRST_GRU()
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
dl_model = dl_model.to(device)

optimizer = torch.optim.Adam(dl_model.parameters(), lr=5e-4, weight_decay=5e-4)
  
if __name__ == '__main__':
    dataset = Timeline_Dataset('./data/', lookback)

    train_size = int(0.7 * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = torch.utils.data.random_split(dataset,[train_size, test_size])

    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=PadSequences())  # custom collate function pads each batch
    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             collate_fn=PadSequences())

    # Note: these ratios come from the train/test split sizes, not from the class frequencies
    class_weights = torch.tensor([len(test_dataset)/len(train_dataset), len(train_dataset)/len(test_dataset)],
                                 dtype=torch.float32, device=device)  # match the dtype/device of the logits
    loss_fn = nn.BCEWithLogitsLoss(pos_weight=class_weights)
    # loss_fn = nn.BCEWithLogitsLoss(reduction='none', pos_weight=class_weights)

    for epoch in range(1, epochs + 1):
        loss, train_acc = train(train_loader)
        test_acc, f1_score = test(test_loader)
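PadSequences is not shown in the thread. A hypothetical version consistent with how the model unpacks each batch (padded data, lengths, label) might look like the sketch below, assuming each dataset item is an (x, y) pair with x of shape [seq_len, num_features] and y a per-sequence label:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


class PadSequences:
    """Hypothetical collate function: pads a list of (sequence, label) pairs.

    Assumption: each dataset item is (x, y), x of shape [seq_len, num_features].
    """

    def __call__(self, batch):
        # Sort longest first so pack_padded_sequence works with enforce_sorted=True
        batch = sorted(batch, key=lambda item: item[0].shape[0], reverse=True)
        sequences = [item[0] for item in batch]
        labels = torch.stack([torch.as_tensor(item[1]) for item in batch])
        # pack_padded_sequence expects the lengths as a 1D CPU int64 tensor (or list)
        lengths = torch.tensor([s.shape[0] for s in sequences], dtype=torch.long)
        # Zero-pad to the longest sequence in this batch: [B, T_max, F]
        padded = pad_sequence(sequences, batch_first=True)
        return padded, lengths, labels
```

Sorting inside the collate function is one design choice; the alternative is to skip it and pass enforce_sorted=False to pack_padded_sequence.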

Let's take a batch size of 8 as an example.
For each sequence in the batch, the network outputs a tensor whose length equals that of the longest series in the batch (which is understandable, as they are all padded to the same length).
An example output would be:

lengths = tensor([25, 11,  4,  4,  2,  2,  2,  1])
output = tensor([[[ 0.0868, -0.2623],
         [ 0.0716, -0.2668],
         [ 0.0584, -0.2549],
         [ 0.0482, -0.2386],
         [ 0.0410, -0.2234],
         [ 0.0362, -0.2111],
         [ 0.0333, -0.2018],
         [ 0.0318, -0.1951],
         [ 0.0311, -0.1904],
         [ 0.0310, -0.1873],
         [ 0.0312, -0.1851],
         [ 0.0315, -0.1837],
         [ 0.0318, -0.1828],
         [ 0.0322, -0.1822],
         [ 0.0324, -0.1819],
         [ 0.0327, -0.1817],
         [ 0.0328, -0.1815],
         [ 0.0330, -0.1815],
         [ 0.0331, -0.1814],
         [ 0.0332, -0.1814],
         [ 0.0333, -0.1814],
         [ 0.0333, -0.1814],
         [ 0.0334, -0.1814],
         [ 0.0334, -0.1814],
         [ 0.0334, -0.1814]],

        [[ 0.0868, -0.2623],
         [ 0.0716, -0.2668],
         [ 0.0584, -0.2549],
         [ 0.0482, -0.2386],
         [ 0.0410, -0.2234],
         [ 0.0362, -0.2111],
         [ 0.0333, -0.2018],
         [ 0.0318, -0.1951],
         [ 0.0311, -0.1904],
         [ 0.0310, -0.1873],
         [ 0.0312, -0.1851],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164]],

        [[ 0.0868, -0.2623],
         [ 0.0716, -0.2668],
         [ 0.0584, -0.2549],
         [ 0.0482, -0.2386],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164]],

        [[ 0.0868, -0.2623],
         [ 0.0716, -0.2668],
         [ 0.0584, -0.2549],
         [ 0.0482, -0.2386],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164]],

        [[ 0.0868, -0.2623],
         [ 0.0716, -0.2668],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164]],

        [[ 0.0868, -0.2623],
         [ 0.0716, -0.2668],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164]],

        [[ 0.0868, -0.2623],
         [ 0.0716, -0.2668],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164]],

        [[ 0.0868, -0.2623],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164],
         [ 0.1003, -0.2164]]])

To get the last valid step of each sequence, I tried:

output[ : ,lengths-1]

where output is the network output and lengths holds the original lengths of the data.
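This also explains the repeated [0.1003, -0.2164] rows in the printout: pad_packed_sequence fills padded positions with zeros (its default padding_value=0.0), and a linear layer maps a zero vector to its bias. A quick check with made-up sizes (input_size=4, hidden_size=3):

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
gru = nn.GRU(input_size=4, hidden_size=3, batch_first=True)
linear = nn.Linear(3, 2)

x = torch.randn(2, 5, 4)          # [batch, max_len, features]
lengths = torch.tensor([5, 2])    # second sequence is padding after step 2

packed = pack_padded_sequence(x, lengths, batch_first=True)
packed_out, h_n = gru(packed)
out_padded, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# GRU outputs at padded positions are exactly zero...
padded_is_zero = bool((out_padded[1, 2:] == 0).all())
# ...so the linear layer maps every padded position to its bias vector
logits = linear(out_padded)
equals_bias = torch.allclose(logits[1, 2:], linear.bias.expand(3, 2))
print(padded_is_zero, equals_bias)  # True True
```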
Now my problem is: I'm predicting a pretty rare event that occurs at most once per series (and sometimes not at all), so even with a weight for the imbalanced class distribution, the predictions are not very accurate.
And that's understandable: my accuracy can be pretty high, since I correctly identify all the time steps where the event does not happen, and the few times it does happen and I miss it hardly make a difference.
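For a single per-sequence label, pos_weight in nn.BCEWithLogitsLoss is commonly set to the ratio of negatives to positives in the training labels (one value per output logit). The sketch below assumes the 0/1 sequence labels have already been collected into a 1D tensor; make_pos_weight and accuracy_and_recall are hypothetical helpers. It also shows why accuracy is misleading here while recall is not:

```python
import torch
import torch.nn as nn


def make_pos_weight(labels: torch.Tensor) -> torch.Tensor:
    """pos_weight = #negatives / #positives for 0/1 labels of shape [N]."""
    labels = labels.float()
    num_pos = labels.sum()
    num_neg = labels.numel() - num_pos
    # clamp avoids a division by zero if a split contains no positives
    return (num_neg / num_pos.clamp(min=1.0)).reshape(1)


def accuracy_and_recall(logits, labels):
    preds = (logits.sigmoid() > 0.5).float()
    accuracy = (preds == labels).sum().item() / labels.numel()
    true_pos = ((preds == 1) & (labels == 1)).sum().item()
    recall = true_pos / max(labels.sum().item(), 1)
    return accuracy, recall


# Example: 2 positive sequences out of 20
labels = torch.zeros(20)
labels[:2] = 1.0
pos_weight = make_pos_weight(labels)                  # tensor([9.])
loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

# A model that always predicts "no event" still reaches 90% accuracy, but 0% recall
always_negative = torch.full((20,), -5.0)
print(accuracy_and_recall(always_negative, labels))  # (0.9, 0.0)
loss = loss_fn(always_negative, labels)
```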

Now my question is: can I modify my network so that it only produces a binary output, i.e. does the event happen in this time series or not?
So instead of the output shape [8, 25, 2], I would like to get [8, 1, 2].

You could use the last time step of your RNN output and feed only this tensor to the linear layer, or alternatively try to reduce the temporal dimension (e.g. via mean()).
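A minimal sketch of the last-time-step idea, assuming one binary label per sequence. Because the input is packed, the hidden state h_n returned by nn.GRU already holds each sequence's state at its own last valid step, so no manual indexing is needed. The class name SequenceLevelGRU and the single-logit head are illustrative choices, not from the thread:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence


class SequenceLevelGRU(nn.Module):
    """Sketch: one logit per sequence ("does the event occur in this series?")."""

    def __init__(self, input_size=32, hidden_size=20, num_layers=1):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.l_out = nn.Linear(hidden_size, 1)

    def forward(self, x, lengths):
        # enforce_sorted=False: PyTorch sorts internally and restores the batch order in h_n
        packed = pack_padded_sequence(x.float(), lengths.cpu(), batch_first=True,
                                      enforce_sorted=False)
        _, h_n = self.gru(packed)      # h_n: [num_layers, B, hidden_size]
        last_state = h_n[-1]           # top layer, last valid step of each sequence
        return self.l_out(last_state)  # [B, 1] logits
```

With one logit per sequence, the loss becomes nn.BCEWithLogitsLoss()(logits.squeeze(1), y.float()) with y of shape [batch]. If two outputs per sequence are preferred, nn.Linear(hidden_size, 2) together with nn.CrossEntropyLoss and integer class labels is the usual pairing.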

I’m not sure which approach would work better, so you might need to try out different approaches.


When using only the last time step, though, wouldn't the varying lengths of my data be a problem?
I would need to use my original lengths to get the last valid position for each sequence in my batch.
So far I haven't figured out how to do this.

output[ : ,lengths-1]

was the first thing I tried, but it did not return the correct values. Am I just indexing it wrong?
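A small sketch of why output[:, lengths-1] misbehaves, using a toy tensor with made-up shapes: indexing dimension 1 with a vector selects those time steps for every sequence in the batch, giving shape [B, B, C]. Pairing each batch index with its own last index via torch.arange picks one step per sequence; x_length can be used in place of lengths here:

```python
import torch

B, T, C = 3, 5, 2
output = torch.arange(B * T * C, dtype=torch.float32).reshape(B, T, C)
lengths = torch.tensor([5, 2, 1])

# Only dim 1 is indexed: every sequence gets *all* selected time steps -> [B, B, C]
wrong = output[:, lengths - 1]
print(wrong.shape)  # torch.Size([3, 3, 2])

# Pair batch index i with time index lengths[i] - 1 -> [B, C]
batch_idx = torch.arange(B)
last = output[batch_idx, lengths - 1]
print(last.shape)   # torch.Size([3, 2])
```

If output lives on the GPU, moving batch_idx and lengths - 1 to output.device keeps all tensors on one device.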

Yeah, I think you are right. Would it be possible to reuse x_length to get the last valid state?

I think it should be possible.
In the example I provided earlier, lengths and x_length hold the same values. I'm trying to index with them, but it doesn't work yet, as I seem to have made an error in the indexing.
I also have another question about this:
If I use this indexing to get only the last time step (or the mean over only the valid time steps), will it still be possible to compute the loss and backpropagate?

Yes, that should still be possible, but I’m not an expert in NLP, so let’s also wait for others to chime in.
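On the backprop question: both the advanced indexing and a masked mean are ordinary differentiable tensor operations, so loss.backward() works through them. A sketch with made-up shapes (masked_mean is a hypothetical helper) that checks padded time steps receive zero gradient:

```python
import torch


def masked_mean(output, lengths):
    """Mean over valid time steps only. output: [B, T, H], lengths: [B] -> [B, H]."""
    T = output.shape[1]
    lengths = lengths.to(output.device)
    # mask[b, t] = 1.0 if t < lengths[b] else 0.0
    mask = (torch.arange(T, device=output.device)[None, :] < lengths[:, None]).float()
    summed = (output * mask.unsqueeze(-1)).sum(dim=1)
    return summed / lengths.clamp(min=1).unsqueeze(-1).float()


torch.manual_seed(0)
output = torch.randn(3, 5, 4, requires_grad=True)  # stands in for the GRU output
lengths = torch.tensor([5, 2, 1])

pooled = masked_mean(output, lengths)          # [3, 4]
last = output[torch.arange(3), lengths - 1]   # [3, 4]
loss = pooled.sum() + last.sum()
loss.backward()

# Gradients reach only the valid time steps; padded positions get zero gradient
print(output.grad[1, 2:].abs().sum().item())  # 0.0
```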