GRU event prediction architecture

Hi

I’m new to working with timelines, and I have a problem for which I haven’t been able to find any good resources.
So I would appreciate it if anyone could give me some pointers.

In my case, I’m interested in predicting an event for a user of a website.
For each user, I have a timeline of their usage of the website in the form of a dataframe.
Say I want to predict whether a user leaves the website.

My goal is to look at a certain time frame (let’s say the last 5 minutes, which would be 5 rows of my dataframe if I’m aggregating data per minute) and predict the probability of the user leaving the website.
A data sample (for one user) would look like this:

Minute  Feature_1  Feature_2  Feature_3  ...  Target_0  Target_1
-4      0.4        0.23       0.64            1         0
-3      0.24       0.23       0.64            1         0
-2      0.34       0.1        0.64            1         0
-1      0.56       0.2        0.64            1         0
0       0.64       0.3        0.64            0         1

The last row is the most recent observation of this user. The target (one-hot encoded as Target_0 / Target_1) describes whether the user has left the website at that point.

Now my goal is to give the network a fixed number of timesteps (let’s say 4) and then let it predict how likely it is that a given user stays on the website or leaves.
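One way to turn a single user’s timeline into fixed-length inputs like this is to cut it into overlapping windows of `lookback` rows and label each window with the target of its last row. A minimal sketch, assuming the features are already a `[T, F]` tensor and the one-hot Target_0/Target_1 columns are converted to class indices (0 = stays, 1 = leaves); `make_windows` is a hypothetical helper, not part of the original code:

```python
import torch


def make_windows(timeline: torch.Tensor, targets: torch.Tensor, lookback: int):
    """Cut one user's timeline into overlapping windows of `lookback` rows.

    timeline: [T, F] features, one row per minute (assumed layout)
    targets:  [T] class index per minute (assumed: 0 = stays, 1 = left)
    Returns windows [T - lookback + 1, lookback, F] and, for each window,
    the label of its last row, shape [T - lookback + 1].
    """
    # unfold(dim, size, step) gives [num_windows, F, lookback]; put time back on dim 1
    windows = timeline.unfold(0, lookback, 1).permute(0, 2, 1)
    labels = targets[lookback - 1:]
    return windows, labels


# Toy timeline: 5 minutes x 3 features, one-hot labels converted to class indices
timeline = torch.rand(5, 3)
onehot = torch.tensor([[1., 0.], [1., 0.], [1., 0.], [1., 0.], [0., 1.]])
windows, labels = make_windows(timeline, onehot.argmax(dim=1), lookback=4)
print(windows.shape, labels)  # torch.Size([2, 4, 3]) tensor([0, 1])
```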

Currently I’m feeding my network a batch and then predicting the target values for the last timestep in the series, using this code:

import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, random_split

lookback = 4
batch_size = 8
layer_size = 64
learning_rate = 5e-4
dataset_name = 'Timeline_Dataset'

class MY_FIRST_GRU(nn.Module):
    def __init__(self):
        super(MY_FIRST_GRU, self).__init__()
        self.gru = nn.GRU(input_size=32,
                          hidden_size=20,
                          num_layers=2,
                          batch_first=True)
        self.l_out = nn.Linear(in_features=20,
                               out_features=2)

    def forward(self, batch):
        x, x_length, _ = batch
        # Cast before packing; the lengths tensor must stay on the CPU
        x_pack = pack_padded_sequence(x.float(), x_length, batch_first=True)
        packed_x, hidden = self.gru(x_pack)
        output_padded, input_sizes = pad_packed_sequence(packed_x, batch_first=True)
        output = self.l_out(output_padded)  # [batch, time, 2]: one prediction per timestep
        # Normalise over the class dimension (dim=1 would normalise over time).
        # The trailing comma returns a 1-tuple, hence output[0] in the loops below.
        return F.log_softmax(output, dim=-1),

def train(train_loader):
    dl_model.train()
    total_loss = 0
    correct = 0
    total = 0
    for data_list in train_loader:
        optimizer.zero_grad()
        output = dl_model(data_list)  # output[0].shape = [8, 4, 2]
        y = data_list[2]              # padded one-hot labels, [8, 4, 2]
        # NOTE: hinge_embedding_loss expects targets of 1 or -1; for log-softmax
        # outputs, F.nll_loss on class indices is the usual pairing.
        loss = F.hinge_embedding_loss(output[0], y.long())
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        with torch.no_grad():
            pred = output[0].argmax(dim=-1)  # predicted class per timestep
        correct += pred.eq(y.argmax(dim=-1)).sum().item()  # padded steps included
        total += pred.numel()
    return total_loss / len(train_loader), correct / total


def test(loader):
    dl_model.eval()
    actuals = []
    correct = 0
    total = 0
    with torch.no_grad():
        for data_list in loader:  # use the argument, not the global test_loader
            output = dl_model(data_list)
            y = data_list[2]
            pred = output[0].argmax(dim=-1)
            actuals.extend(y.cpu().numpy())
            correct += pred.eq(y.argmax(dim=-1)).sum().item()
            total += pred.numel()
    return correct / total, actuals

# Timeline_Dataset and PadSequences are defined in the follow-up post
dataset = Timeline_Dataset('./data/', lookback)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True,
                          collate_fn=PadSequences(),  # custom collate function
                          pin_memory=True,
                          num_workers=1)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         collate_fn=PadSequences(),
                         pin_memory=True,
                         num_workers=1)

# Model and optimizer creation were not shown in the post; Adam is an assumption here
dl_model = MY_FIRST_GRU()
optimizer = torch.optim.Adam(dl_model.parameters(), lr=learning_rate)

epochs = 100
if torch.cuda.is_available():
    print('Using GPU ' + str(torch.cuda.current_device()) + ' as main GPU')
else:
    print('Using CPU')
t0 = time.time()
for epoch in range(1, epochs + 1):
    loss, train_acc = train(train_loader)
    test_acc, actuals = test(test_loader)

Is there a way to predict for every timestep and then calculate the loss based on all of the outputs?


Is your Timeline_Dataset randomly selecting time points on every iteration? If not, your network will only see and predict fifth-minute exits.

To “predict for every timestep”, you normally don’t use a “lookback” (except as a limit on sequence length); instead you feed longer sequences and rely on the RNN’s implicit discarding of older information as the “lookback”.

In other words, your implementation should already produce 4 per-timestep predictions, if you stop thinking in terms of 5-minute chunks.
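To see why: with `batch_first=True`, `nn.GRU` returns one hidden state per timestep, and `nn.Linear` acts on the last dimension, so a model shaped like the one in the question already yields a `[batch, time, 2]` tensor of per-timestep predictions. A small shape check with random inputs (layer sizes borrowed from the question’s model; batch size and sequence length are arbitrary):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
gru = nn.GRU(input_size=32, hidden_size=20, num_layers=2, batch_first=True)
head = nn.Linear(20, 2)

x = torch.randn(8, 4, 32)                  # [batch, time, features]
out, h_n = gru(x)                          # out: [8, 4, 20], one hidden state per timestep
logits = head(out)                         # Linear acts on the last dim: [8, 4, 2]
log_probs = F.log_softmax(logits, dim=-1)  # normalise over the 2 classes, not over time

print(out.shape, h_n.shape, log_probs.shape)
# h_n is [num_layers, batch, hidden] even with batch_first=True;
# its last layer equals the output at the final timestep.
```

Training on every entry of `log_probs` (with padding masked out) is what “predicting for every timestep” amounts to; using only `log_probs[:, -1]` is the last-timestep special case.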


I’m only randomly selecting part of a timeline if it’s longer than the set lookback interval.
So with a lookback of 4, every timeline that is shorter than 4 timesteps should not be affected, but all the timelines that are longer should be sliced at some random place.

Here is the dataset code for reference. I’m loading from two different files that both cover the same time period.
The file terminator_data contains all users who have terminated their session on the website,
and the other file contains data from users who have not done so.
The ratio is 1 termination timeline for every 43 non-termination timelines.

from random import randrange

import pandas as pd
import torch
from torch.utils.data import Dataset


class Timeline_Dataset(Dataset):

    def __init__(self, root, lookback, transform=None):

        self.root = root
        self.lookback = lookback
        self.transform = transform
        self.data = []
        self.__preprocess_to_list__()

        
    def __len__(self):
        return len(self.data)


    def __getitem__(self, idx):
        sample = self.data[idx]
        
        if self.lookback > len(sample):
            sample = [sample[0][-self.lookback:], sample[1][-self.lookback:]]
        else:
            rnd = randrange(0, len(sample), 1)
            sample = [sample[0][:rnd][-self.lookback:], sample[1][-self.lookback:]]


        if self.transform:
            sample = self.transform(sample)
            
        return sample
    
            
    def __preprocess_to_list__(self):
        files = ['terminator_data','user_data']
        for part in files:
            csv_file = self.root + part + '_processed.csv'
            print('Loading ' + part + ' dataset...')
            df = pd.read_csv(csv_file, sep=',',
                    index_col = 0,
                    engine='python',
                    on_bad_lines='skip')  # replaces error_bad_lines=False (removed in pandas 2.0)
            user_list = split_data_by_customer(df)
            
            for x in user_list:
                label = x[['Target_0.0', 'Target_1.0']]
                x.drop(['USER_ID','Target_0.0', 'Target_1.0'], axis=1, inplace = True)
            
                x = torch.tensor(x[x.columns].values)
                y = torch.tensor(label[label.columns].values)
                
                self.data.append([x,y])
            print('Preprocessing complete.')
            print('Loaded ' + str(self.__len__()) + ' datapoints.')


class PadSequences:
    def __call__(self, batch):
        # Each element in "batch" is a tuple (data, label).
        # Sort the batch in the descending order
        sorted_batch = sorted(batch, key=lambda x: x[0].shape[0], reverse=True)
        # Get each sequence and pad it
        sequences = [x[0] for x in sorted_batch]
        sequences_padded = torch.nn.utils.rnn.pad_sequence(sequences, batch_first=True)
        # Also need to store the length of each sequence
        # This is later needed in order to unpad the sequences
        lengths = torch.LongTensor([len(x) for x in sequences])
        # Don't forget to grab the labels of the *sorted* batch
        labels = [x[1] for x in sorted_batch]
        labels_padded = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True)
        return sequences_padded, lengths, labels_padded
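As an aside, the manual sort in `PadSequences` is there because `pack_padded_sequence` originally required sequences sorted by decreasing length. PyTorch 1.1 and later also accept `enforce_sorted=False`, in which case `pad_packed_sequence` restores the original batch order, so labels never need reordering. A small sketch of that alternative; `pad_batch` is a hypothetical replacement for `PadSequences`, not the poster’s code:

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence


def pad_batch(batch):
    """Collate (features, labels) pairs of different lengths without sorting.

    Keeps the original batch order and relies on
    pack_padded_sequence(..., enforce_sorted=False) to deal with unsorted lengths.
    """
    sequences = [x for x, _ in batch]
    labels = [y for _, y in batch]
    lengths = torch.tensor([len(s) for s in sequences])  # CPU int64, as packing requires
    return (pad_sequence(sequences, batch_first=True),
            lengths,
            pad_sequence(labels, batch_first=True))


# Two toy timelines of length 2 and 4 with 3 features and 2 label columns
batch = [(torch.ones(2, 3), torch.zeros(2, 2)), (torch.ones(4, 3), torch.zeros(4, 2))]
x, lengths, y = pad_batch(batch)
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
unpacked, out_lengths = pad_packed_sequence(packed, batch_first=True)
print(x.shape, lengths.tolist(), out_lengths.tolist())
```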

OK, as I said, try just removing the “lookback” variable, or setting it to a high value to limit the maximum sequence length. Note that batches will then contain sequences of variable length, combined with padding, so a loss mask should be added, something like:

loss = (F.hinge_embedding_loss(output[0], y.long(), reduction='none') * lossmask).sum() / lossmask.sum()

Actually, you need such a mask even with lookback=4 if shorter sequences exist.
I use this to create the loss mask, but it is for time-major RNNs:

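import torch
from torch import Tensor

def create_seq_mask(seq_lengths: Tensor):
    # Returns a [max_seq_length, batch] float mask: 1.0 on real timesteps, 0.0 on padding
    max_seq_length = int(seq_lengths.max())
    n = len(seq_lengths)
    r = torch.zeros(max_seq_length, n, dtype=torch.float32)
    for i in range(n):
        r[:seq_lengths[i], i] = 1.0
    return r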
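Combining such a mask with a per-timestep loss might look like the sketch below. It is batch-first (matching the question’s `batch_first=True` GRU) and swaps in `F.nll_loss` for `hinge_embedding_loss`, since the model returns `log_softmax` outputs and `hinge_embedding_loss` is meant for targets of 1 or -1. `masked_nll_loss` is a hypothetical helper, and the one-hot labels are assumed to look like the Target_0/Target_1 columns:

```python
import torch
import torch.nn.functional as F


def masked_nll_loss(log_probs: torch.Tensor, onehot_targets: torch.Tensor,
                    lengths: torch.Tensor) -> torch.Tensor:
    """Mean negative log-likelihood over real (unpadded) timesteps only.

    log_probs:      [B, T, C] log-softmax outputs, batch-first
    onehot_targets: [B, T, C] padded one-hot labels
    lengths:        [B] true sequence lengths
    """
    _, T, _ = log_probs.shape
    targets = onehot_targets.argmax(dim=-1)  # [B, T] class indices
    # nll_loss expects the class dimension second, i.e. [B, C, T]
    per_step = F.nll_loss(log_probs.transpose(1, 2), targets, reduction='none')  # [B, T]
    lengths = lengths.to(log_probs.device)
    steps = torch.arange(T, device=log_probs.device)
    mask = (steps[None, :] < lengths[:, None]).to(per_step.dtype)  # 1.0 on real steps
    return (per_step * mask).sum() / mask.sum()
```

In the question’s training loop this would replace the `hinge_embedding_loss` call, e.g. `loss = masked_nll_loss(output[0], data_list[2], data_list[1])`, assuming the padded labels have the same time length as the model output.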

Also, I think your __getitem__ is buggy: sample = [x, y], so len(sample) == 2, whereas you intended to use len(sample[0]).
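A corrected version of that slicing logic, written as a standalone function for clarity. Besides using `len(sample[0])`, it cuts features and labels at the same place (in the original `else` branch the labels are always taken from the end of the timeline, regardless of where the features are cut) and never returns an empty window, which `randrange(0, ...)` combined with `[:rnd]` can produce when `rnd == 0`. `sample_window` is a hypothetical name; this is a sketch, not the poster’s final code:

```python
import random

import torch


def sample_window(x: torch.Tensor, y: torch.Tensor, lookback: int, rng=random):
    """Pick a window of at most `lookback` timesteps from one user's timeline.

    x: [T, F] features, y: [T, C] labels, sharing the same time axis (assumed).
    Short timelines are returned whole; longer ones are cut at a random end
    point, and features and labels are cut identically so they stay aligned.
    """
    T = len(x)  # number of timesteps, i.e. len(sample[0]), not len(sample)
    if T <= lookback:
        return x, y
    end = rng.randrange(lookback, T + 1)  # end index in [lookback, T]: always a full window
    return x[end - lookback:end], y[end - lookback:end]
```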

Good luck.


I adapted my code in the ways you suggested, and now it works.
Thank you again.
I modified the masking function based on another thread, so that it now looks like this:

import torch

def create_seq_mask(seq_lengths: torch.Tensor, device):
    # Batch-first mask of shape [batch, max_length, 1]: True on real timesteps
    length = int(seq_lengths.max())
    res = torch.stack([torch.arange(length, device=device) < x for x in seq_lengths])
    return res.unsqueeze(2)
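The `[B, T, 1]` shape returned here is what lets the mask broadcast against an element-wise loss of shape `[B, T, 2]`, such as `hinge_embedding_loss(..., reduction='none')` in the snippet above; note that the numerator then sums over both class columns while `lossmask.sum()` counts each timestep once. The sketch below is a broadcasting variant of the list comprehension (not the poster’s code) and checks that it is the transpose of the time-major mask from earlier in the thread:

```python
import torch


def create_seq_mask(seq_lengths: torch.Tensor, device=None) -> torch.Tensor:
    """Batch-first padding mask of shape [B, T, 1]; True where t < length."""
    device = seq_lengths.device if device is None else device
    lengths = seq_lengths.to(device)
    steps = torch.arange(int(lengths.max()), device=device)
    # [1, T] < [B, 1] broadcasts to [B, T] without a Python loop
    return (steps[None, :] < lengths[:, None]).unsqueeze(2)


def create_seq_mask_time_major(seq_lengths: torch.Tensor) -> torch.Tensor:
    """Loop-based [T, B] float mask, like the time-major helper earlier in the thread."""
    r = torch.zeros(int(seq_lengths.max()), len(seq_lengths))
    for i, n in enumerate(seq_lengths.tolist()):
        r[:n, i] = 1.0
    return r


lengths = torch.tensor([3, 1, 2])
mask = create_seq_mask(lengths)  # [3, 3, 1]
assert torch.equal(mask.squeeze(2).T.float(), create_seq_mask_time_major(lengths))

elementwise_loss = torch.ones(3, 3, 2)  # stand-in for a [B, T, 2] per-element loss
masked = elementwise_loss * mask        # bool mask broadcasts over the class dimension
print(masked.sum() / mask.sum())        # 12 / 6 = 2: summed over both classes
```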