GPU does not accelerate deep neural model

I am currently looking at this implementation of Joint Entity and Relational Extraction on GitHub (vedantc6/mtl-dts, see here). I can run the model they have built, but training for 10 epochs takes about 30 hours, regardless of whether I specify CUDA or run on the CPU instead. Another user has observed the same behaviour.

Digging further into their code, I could see that the original authors did move the computation onto the GPU device. While monitoring my GPU with nvidia-smi, I did observe activity, but the overall process is no faster.

I am rather unfamiliar with optimising PyTorch implementations for the GPU and with building models from scratch. Any advice on how I can achieve faster computation would be appreciated, as I am looking to use this as a baseline model. Thank you so much.

The model structure code can be seen here, and it is well documented. The architecture has one RNN shared between the NER RNN and the RE RNN, which forms the main structure of the model. The code that takes the device parameter is an attempt to shift training onto the GPU.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence
from utils import load_glove_embeddings, load_elmo_embeddings, load_onehot_embeddings, get_boundaries
from crf import CRFLoss
import math
import itertools
from collections import defaultdict, Counter
import numpy as np

class MTLArchitecture(nn.Module):
    """
    The main class where all successive architectures are initialised and a forward pass is done through each
    of them.
    """

    def __init__(self, num_word_types, shared_layer_size, num_char_types,
                 char_dim, hidden_dim, dropout, re_dropout, num_layers_shared, num_layers_ner, 
                 num_layers_re, num_tag_types, num_rel_types, init, label_embeddings_size, re_ff1_size,
                 re_lambda, e1_activation_type, r1_activation_type, recurrent_unit="gru", device='cuda'):
        """
        Initialise.

        :param num_word_types: vocabulary size, to be used as input
        :param shared_layer_size: final output size of the shared layers, to be as inputs to task-specific layers
        :param num_char_types: vocabulary of characters, to be used for CharRNN
        :param char_dim: character dimension
        :param hidden_dim: hidden dimensions of biRNN
        :param dropout: dropout values for nodes in biRNN
        :param num_layers_shared: number of shared biRNN layers
        :param num_layers_ner: number of NER biRNN layers
        :param num_layers_re: number of RE biRNN layers
        :param num_tag_types: unique tags of the model, will be used by NER specific layers
        :param num_rel_types: unique relations of the model, will be used by RE specific layers
        :param init: uniform initialization used in NER CRF
        :param label_embeddings_size: label embedding size to be used in NER and RE
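        :param re_dropout: dropout applied before RE scoring
        :param re_ff1_size: output size of the first RE feed-forward layer
        :param re_lambda: weight of the RE loss in the combined loss
        :param e1_activation_type: the type of activation function to use in the NER layers
        :param r1_activation_type: the type of activation function to use in the RE layers
        :param recurrent_unit: the type of recurrent unit to use for biRNN - GRU or LSTM
        :param device: device on which tensors are placed, e.g. 'cuda' or 'cpu'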
        """

        super(MTLArchitecture, self).__init__()

        self.RELossLambda = re_lambda
        self.shared_layers = SharedRNN(num_word_types, shared_layer_size, num_char_types,
                                       char_dim, hidden_dim, dropout, num_layers_shared,
                                       recurrent_unit, device)

        self.ner_layers = NERSpecificRNN(shared_layer_size, num_tag_types, hidden_dim, dropout,
                                         num_layers_ner, init, label_embeddings_size,
                                         e1_activation_type, recurrent_unit)

        self.re_layers = RESpecificRNN(shared_layer_size, num_rel_types, hidden_dim, dropout, re_dropout,
                                       num_layers_re, label_embeddings_size, re_ff1_size,
                                       r1_activation_type, recurrent_unit, device)

        self.loss = nn.CrossEntropyLoss()

    def score(self, X, Y, C, C_lengths, rstartseqs, rendseqs, rseqs, sents):
        """
        Evaluation through all the shared, NER and RE RNNs.

        :param X: encoded sentences
        :param Y: encoded tags
        :param C: encoded characters
        :param C_lengths: lengths of characters in the words
        :param rstartseqs: the start indices of the relations for RE
        :param rendseqs: the end indices of relations for RE
        :param rseqs:
        :param sents: raw non-encoded sentences
        :return:
        """

        shared_representations = self.shared_layers(C, C_lengths, sents)
        ner_preds, ner_tag_embeddings = self.ner_layers.scorer(shared_representations, Y)
        re_scores = self.re_layers.scorer(shared_representations, ner_tag_embeddings, rstartseqs, rendseqs, rseqs)
        return ner_preds, re_scores

    def forward(self, X, Y, C, C_lengths, rstartseqs, rendseqs, rseqs, sents):
        """
        Do a single forward pass on the entire architecture - through all the shared, NER and RE RNNs.

        :param X: encoded sentences
        :param Y: encoded tags
        :param C: encoded characters
        :param C_lengths: lengths of characters in the words
        :param rstartseqs: the start indices of the relations for RE
        :param rendseqs: the end indices of relations for RE
        :param rseqs:
        :param sents: raw non-encoded sentences
        :return:
        """

        shared_representations = self.shared_layers(C, C_lengths, sents)
        ner_score, ner_tag_embeddings = self.ner_layers(shared_representations, Y)
        re_score = self.re_layers(shared_representations, ner_tag_embeddings, rstartseqs, rendseqs, rseqs)
        return ner_score, re_score

    def do_epoch(self, epoch_num, train_batches, clip, optim, logger=None, check_interval=200):
        """
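        Run one training epoch over the training batches.

        :param epoch_num: index of the current epoch, used for logging
        :param train_batches: the training data batches
        :param clip: maximum gradient norm used for gradient clipping
        :param optim: the optimiser used for minimising the loss
        :param logger: optional logger for progress messages
        :param check_interval: log the running average losses every this many batches
        :return: dictionary of average losses for the epoch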
        """

        self.train()
        print("\nTraining...")

        output = {}
        for batch_num, (X, Y, C, C_lengths, rstartseqs, rendseqs, rseqs, sents) in enumerate(train_batches):
            # print(batch_num)
            optim.zero_grad()
            NER_forward_result, RE_forward_result = self.forward(X, Y, C, C_lengths, rstartseqs, rendseqs, rseqs, sents)
            loss_NER, loss_RE = NER_forward_result["loss"], RE_forward_result["loss"]
            final_loss = loss_NER + self.RELossLambda * loss_RE
            final_loss.backward()

            nn.utils.clip_grad_norm_(self.parameters(), clip)
            optim.step()

            output["loss"] = NER_forward_result['loss'] + RE_forward_result['loss'] if 'loss' not in output else \
                                output["loss"] + (NER_forward_result['loss'] + RE_forward_result['loss'])
            output["ner_loss"] = NER_forward_result['loss'] if 'ner_loss' not in output else \
                                output["ner_loss"] + NER_forward_result['loss']
            output["re_loss"] = RE_forward_result['loss'] if 're_loss' not in output else \
                                output["re_loss"] + RE_forward_result['loss']


            if logger and (batch_num + 1) % check_interval == 0:
                logger.log('Epoch {:3d} | Batch {:5d}/{:5d} | '
                           'Average Loss {:8.4f} | '
                           'Average NER Loss {:8.4f} | '
                           'Average RE Loss {:8.4f} \n'.format(epoch_num, batch_num + 1, len(train_batches), 
                            output['loss'] / (batch_num + 1), output['ner_loss'] / (batch_num + 1), 
                            output['re_loss'] / (batch_num + 1)))

            if math.isnan(output['loss']):
                print('Stopping training since objective is NaN')
                break
        for key in output:
            output[key] /= (batch_num + 1)

        return output

    def evaluate(self, eval_batches, logger=None, tag2y=None, rel2y=None):
        self.eval()
        print("Evaluating...")
        if 'O' in tag2y:
            y2tag = [None for tag in tag2y]
            for tag in tag2y:
                y2tag[tag2y[tag]] = tag
            tp = Counter()
            fp = Counter()
            fn = Counter()

        num_preds = 0
        num_correct = 0
        num_rel_total = 0
        re_tp = 0
        re_fp = 0
        re_fn = 0
        output = dict()
        gold_entities = {}
        for (X, Y, C, C_lengths, rstartseqs, rendseqs, rseqs, sents) in eval_batches:
            try:
                B, T = Y.size()
                ner_preds, re_scores = self.score(X, Y, C, C_lengths, rstartseqs, rendseqs, rseqs, sents)  # B x T x L

                num_preds += B * T
                num_correct += (ner_preds == Y).sum().item()

                if 'O' in tag2y:
                    for i in range(B):
                        gold_bio_labels = [y2tag[Y[i, j].item()]
                                        for j in range(T)]
                        pred_bio_labels = [y2tag[ner_preds[i, j].item()]
                                        for j in range(T)]
                        gold_boundaries = set(get_boundaries(gold_bio_labels))
                        pred_boundaries = set(get_boundaries(pred_bio_labels))
                        for (s, t, entity) in gold_boundaries:
                            gold_entities[entity] = True
                            if (s, t, entity) in pred_boundaries:
                                tp[entity] += 1
                                tp['<all>'] += 1
                            else:
                                fn[entity] += 1
                                fn['<all>'] += 1
                        for (s, t, entity) in pred_boundaries:
                            if not (s, t, entity) in gold_boundaries:
                                fp[entity] += 1
                                fp['<all>'] += 1

                    for ner_actual, ner_pred, rstartseq, rendseq, rseq, re_score in \
                            zip(Y, ner_preds, rstartseqs, rendseqs, rseqs, re_scores):

                        rstart_list = rstartseq.tolist()
                        rend_list = rendseq.tolist()
                        rseq_list = rseq.tolist()
                        num_rel_total += len(rseq_list)
                        gold_bio_labels = [y2tag[ner_actual[j].item()] for j in range(T)]
                        pred_bio_labels = [y2tag[ner_pred[j].item()] for j in range(T)]
                        gold_boundaries = set(get_boundaries(gold_bio_labels))
                        pred_boundaries = set(get_boundaries(pred_bio_labels))
                        for rel_start, rel_end, rel_ind, re_sc in zip(rstart_list, rend_list, \
                                                                        rseq_list, re_score):

                            score = np.asarray(re_sc.tolist())
                            max_score = np.max(score)
                            arg_max = np.argmax(score)

                            first_entity_success, second_entity_success = False, False
                            for (s, t, entity) in gold_boundaries:
                                if t == rel_start and (s, t, entity) in pred_boundaries:
                                    first_entity_success = True
                                if t == rel_end and (s, t, entity) in pred_boundaries:
                                    second_entity_success = True
                            ner_successful = first_entity_success and second_entity_success
                            print(ner_successful, rel_ind, re_sc.tolist()[rel_ind])

                            if max_score >= 0.9:
                                if arg_max == rel_ind and ner_successful is True:
                                    re_tp += 1
                                else:
                                    re_fp += 1
                            else:
                                re_fn += 1
                
            except Exception as e:
                logger.log('-' * 89)
                logger.log('X {}, Y: {}, C: {}, C_len: {} \n Error: {}'.format(X, Y, C, C_lengths, e))
                continue

        output["ner_acc"] = 100 * num_correct / num_preds
        output["re_precision"] = 100 * re_tp / (re_tp + re_fp + 1e-16)
        output["re_recall"] = 100 * re_tp / (re_tp + re_fn + 1e-16)
        output["re_f1"] = (2*output["re_recall"]*output["re_precision"])/(output["re_recall"] + output["re_precision"] + 1e-16)

        if 'O' in tag2y:
            for e in list(gold_entities) + ['<all>']:
                p_denom = tp[e] + fp[e]
                r_denom = tp[e] + fn[e]
                p_e = 100 * tp[e] / p_denom if p_denom > 0 else 0
                r_e = 100 * tp[e] / r_denom if r_denom > 0 else 0
                f1_denom = p_e + r_e
                f1_e = 2 * p_e * r_e / f1_denom if f1_denom > 0 else 0
                output['ner_p_%s' % e] = p_e
                output['ner_r_%s' % e] = r_e
                output['ner_f1_%s' % e] = f1_e

        logger.log("NER: P {}, R {}, F1 {} | RE: P {}, R {}, F1: {}".format(output['ner_p_<all>'],
                    output['ner_r_<all>'], output['ner_f1_<all>'], output['re_precision'],
                    output['re_recall'], output['re_f1']))
        return output

class SharedRNN(nn.Module):
    """
    A shared Bidirectional GRU layer that takes in the initial words, converts them to respective embeddings
    and passes them to the respective NER and RE specific layers.
    """

    ELMODim = 1024
    GloveDim = 300
    OneHotDim = 7

    def __init__(self, num_word_types, shared_layer_size, num_char_types, \
                 char_dim, hidden_dim, dropout, num_layers, recurrent_unit="gru", \
                 device="cpu"):
        """
        :param num_word_types:
        :param shared_layer_size:
        :param num_char_types:
        :param char_dim:
        :param hidden_dim:
        :param dropout:
        :param num_layers:
        :param recurrent_unit:
        """

        super(SharedRNN, self).__init__()
        self.CharDim = char_dim
        self.Pad_ind = 0
        self.device = device
        word_dim = self.ELMODim + self.GloveDim + self.CharDim + self.OneHotDim

        # Initialise char-embedding BiRNN
        self.cemb = nn.Embedding(num_char_types, self.CharDim, padding_idx=self.Pad_ind)
        self.charRNN = CharRNN(self.cemb, 1, recurrent_unit)
        self.dropout = nn.Dropout(p=dropout)

        if recurrent_unit == "gru":
            self.wordRNN = nn.GRU(word_dim, shared_layer_size, num_layers, bidirectional=True)
        else:
            self.wordRNN = nn.LSTM(word_dim, shared_layer_size, num_layers, bidirectional=True)

    def forward(self, char_encoded, C_lengths, raw_sentences):
        """
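        Pass the input sentences through the GRU layers.
        :param char_encoded: padded character encodings of the words
        :param C_lengths: lengths of the words in characters
        :param raw_sentences: raw non-encoded sentences, used to load the ELMo, GloVe and one-hot embeddings
        :return: output of the shared biRNN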
        """

        batch_size = len(raw_sentences)
        elmo_embeddings = load_elmo_embeddings(raw_sentences).to(self.device)
        glove_embeddings = load_glove_embeddings(raw_sentences).to(self.device)
        char_embeddings = self.charRNN(char_encoded, C_lengths).to(self.device)
        one_hot_embeddings = load_onehot_embeddings(raw_sentences).to(self.device)
        num_words, char_dim = char_embeddings.size()
        char_embeddings = char_embeddings.view(batch_size, num_words // batch_size, char_dim)
        final_embeddings = torch.cat([elmo_embeddings, glove_embeddings, char_embeddings, one_hot_embeddings], dim=2)

        # Dropout pre BiRNN
        final_embeddings = self.dropout(final_embeddings)

        # Get the shared layer representations.
        shared_output, _ = self.wordRNN(final_embeddings)
        return shared_output


class NERSpecificRNN(nn.Module):
    """
    NER specific bidirectional GRU layers that take in the shared representations from the shared layers and calculates
    the respective NER scores.
    """

    def __init__(self, shared_layer_size, num_tag_types, hidden_dim, dropout, num_layers, \
                 init, label_embeddings_size, activation_type="relu", recurrent_unit="gru"):
        """
        Initialise.

        :param shared_layer_size: final output size of the shared layers, to be as inputs to task-specific layers
        :param num_tag_types: unique tags of the model, will be used by NER specific layers
        :param hidden_dim: the NER biRNN hidden layer dimension
        :param dropout: dropout values for nodes in biRNN
        :param num_layers: number of layers in this biRNN
        :param init: uniform initialization used in NER CRF
        :param label_embeddings_size: label embedding size
        :param activation_type: the type of activation function to use
        :param recurrent_unit: the type of recurrent unit to use for biRNN - GRU or LSTM
        """

        super(NERSpecificRNN, self).__init__()

        self.Pad_ind = 0
        self.tag_embeddings = nn.Embedding(num_tag_types, label_embeddings_size, padding_idx=self.Pad_ind)
        nn.init.xavier_uniform_(self.tag_embeddings.weight.data)

        if recurrent_unit == "gru":
            self.birnn = nn.GRU(2 * shared_layer_size, shared_layer_size, num_layers, bidirectional=True)
        else:
            self.birnn = nn.LSTM(2 * shared_layer_size, shared_layer_size, num_layers, bidirectional=True)

        self.dropout = nn.Dropout(p=dropout)
        self.FFNNe1 = nn.Linear(2 * shared_layer_size, hidden_dim)
        if activation_type == "relu":
            self.activation = nn.ReLU()
        elif activation_type == "tanh":
            self.activation = nn.Tanh()
        elif activation_type == "gelu":
            self.activation = nn.GELU()

        self.FFNNe2 = nn.Linear(hidden_dim, num_tag_types)
        self.loss = CRFLoss(num_tag_types, init)

    def forward(self, shared_representations, Y):
        """
        Do a forward pass by taking input from the shared layers and generating the NER scores for the input
        sentences.

        :param shared_representations: tensor of shared representations from the shared layer.
        :param Y: the label NER tags for the input sentences
        :return: NER scores
        """
        # Dropout before biRNN
        shared_representations = self.dropout(shared_representations)
        ner_representation, _ = self.birnn(shared_representations)
        scores = self.FFNNe2(self.activation(self.FFNNe1(ner_representation)))
        loss = self.loss(scores, Y)
        tag_embeddings = self.tag_embeddings(Y)
        return {'loss': loss}, tag_embeddings

    def scorer(self, shared_representations, Y):
        """
        Score the representation at evaluation time

        :param shared_representations: tensor of shared representations from the shared layer.
        :param Y: the label NER tags for the input sentences
        :return: NER scores
        """
        ner_representation, _ = self.birnn(shared_representations)
        scores = self.FFNNe2(self.activation(self.FFNNe1(ner_representation)))
        _, preds = self.loss.decode(scores)  # B x T
        tag_embeddings = self.tag_embeddings(preds)
        # print("Actual Predictions: ", Y)
        # print("NER Predictions: ", preds)
        return preds, tag_embeddings


class RESpecificRNN(nn.Module):
    """
    RE specific bidirectional GRU layers that take in the shared representations from the shared layers and calculates
    the respective RE scores.
    """

    def __init__(self, shared_layer_size, num_rel_types, hidden_dim, dropout, re_dropout, num_layers, \
                    label_embeddings_size, re_ff1_size, activation_type="relu", \
                    recurrent_unit="gru", device="cpu"):
        """
        Initialise.

        :param shared_layer_size:
        :param num_rel_types:
        :param hidden_dim:
        :param dropout:
        :param num_layers:
        :param label_embeddings_size:
        :param activation_type:
        :param recurrent_unit:
        """

        super(RESpecificRNN, self).__init__()

        self.device = device

        # Add check for 0 task-specific layers, it'll be used while hyperparameter tuning
        if recurrent_unit == "gru":
            self.birnn = nn.GRU(2 * shared_layer_size, shared_layer_size, num_layers, bidirectional=True)
        else:
            self.birnn = nn.LSTM(2 * shared_layer_size, shared_layer_size, num_layers, bidirectional=True)

        final_re_entity_embedding_size = 2 * shared_layer_size + label_embeddings_size

        self.dropout = nn.Dropout(p=dropout)
        self.re_dropout = nn.Dropout(p=re_dropout)

        self.FFNNr1 = nn.Linear(final_re_entity_embedding_size, re_ff1_size)

        if activation_type == "relu":
            self.activation = nn.ReLU()
        elif activation_type == "tanh":
            self.activation = nn.Tanh()
        elif activation_type == "gelu":
            self.activation = nn.GELU()

        self.FFNNr2 = nn.Linear((2 * re_ff1_size) + 1 + num_rel_types, num_rel_types)

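        # Initialise matrix for DistMult score calculation.
        # Registered as nn.Parameter so that self.parameters() returns it and the optimiser updates it;
        # a plain tensor with requires_grad=True receives gradients but is never stepped.
        self.M = nn.Parameter(torch.stack(
            [torch.diag(torch.rand(size=(final_re_entity_embedding_size,))) for _ in range(num_rel_types)]
        ).to(self.device))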

        self.loss = nn.BCELoss()

    def _trim_embeddings(self, embeddings, rstartseqs, rendseqs, rseqs):
        """
        :param embeddings:
        :param rstartseqs:
        :param rendseqs:
        :param rseqs:
        :return:
        """

        B, T, E = embeddings.shape
        batches = []
        for i in range(B):  # Each sentence
            rstart_list = rstartseqs[i].tolist()
            rend_list = rendseqs[i].tolist()
            rseq_list = rseqs[i].tolist()

            end_tokens_of_first_entities = embeddings[i][rstartseqs[i]]
            end_tokens_of_second_entities = embeddings[i][rendseqs[i]]
            all_end_tokens_embeddings = torch.cat([end_tokens_of_first_entities, end_tokens_of_second_entities], dim=0)
            all_end_tokens_indices = rstart_list + rend_list

            # All possible entity pairs
            all_possible_entity_pairs_embeddings = list(itertools.combinations(all_end_tokens_embeddings, 2))
            all_possible_entity_pairs_indices = list(itertools.combinations(all_end_tokens_indices, 2))

            # True RE labels
            true_RE_labels = defaultdict(list)
            for (start, end, relation) in zip(rstart_list, rend_list, rseq_list):
                true_RE_labels[relation].append((start, end))

            batches.append((all_possible_entity_pairs_embeddings, all_possible_entity_pairs_indices, true_RE_labels))

        return batches

    def _RE_scoring_layers(self, first_entity_embedding, second_entity_embedding):
        """

        :param first_entity_embedding:
        :param second_entity_embedding:
        :return:
        """

        # Calculate DistMult score
        first_entity_embedding = first_entity_embedding.unsqueeze(0).to(self.device)  # (1 x p)
        second_entity_embedding = second_entity_embedding.unsqueeze(0).to(self.device)  # (1 x p)
        distmult_scores = torch.matmul(first_entity_embedding, torch.matmul(self.M, second_entity_embedding.T))
        distmult_scores = distmult_scores.squeeze(2)
        distmult_scores = distmult_scores.T  # 1 x num_rel_types

        # Hidden representations of entities
        first_entity_hidden_repr = self.activation(self.FFNNr1(first_entity_embedding))
        second_entity_hidden_repr = self.activation(self.FFNNr1(second_entity_embedding))

        # Cosine similarity (the variable keeps the name cosine_distance)
        cosine_distance = torch.cosine_similarity(first_entity_hidden_repr, second_entity_hidden_repr)
        cosine_distance = cosine_distance.unsqueeze(1)  # 1 x 1

        # Concatenate everything
        final_embedding = torch.cat([first_entity_hidden_repr, second_entity_hidden_repr,
                                     cosine_distance, distmult_scores], dim=1).to(self.device)

        # RE Scores (Sij)
        RE_scores_for_entity_pair = self.FFNNr2(final_embedding)
        sigmoid_RE_scores = torch.sigmoid(RE_scores_for_entity_pair)
        return sigmoid_RE_scores

    def _calculate_RE_scores(self, batches):
        """
        :param batches:
        :return:
        """

        batch_loss = 0
        for (entity_pairs_embeddings, entity_pairs_indices, true_RE_labels) in batches:
            for i, (first_entity_embedding, second_entity_embedding) in enumerate(entity_pairs_embeddings):
                predicted_RE_scores_for_entity_pair = self._RE_scoring_layers(first_entity_embedding,
                                                                              second_entity_embedding).to(self.device)

                # Create ground truth RE labels for current entity pair
                first_entity_end_index = entity_pairs_indices[i][0]
                second_entity_end_index = entity_pairs_indices[i][1]
                target_RE_Labels_for_entity_pair = torch.zeros(predicted_RE_scores_for_entity_pair.shape).to(self.device)
                # Use a separate loop variable so the pair index i from enumerate is not shadowed
                for rel in range(predicted_RE_scores_for_entity_pair.shape[1]):
                    # RE_score_for_current_relation = torch.stack([1 - predicted_RE_scores_for_entity_pair[:, rel],
                    #                                              predicted_RE_scores_for_entity_pair[:, rel]])
                    # RE_score_for_current_relation = RE_score_for_current_relation.T

                    # If the particular relation exists between the current entity pair then y = 1 else 0
                    if (first_entity_end_index, second_entity_end_index) in true_RE_labels[rel]:
                        target_RE_Labels_for_entity_pair[:, rel] = 1

                # print(predicted_RE_scores_for_entity_pair, target_RE_Labels_for_entity_pair)
                batch_loss += self.loss(predicted_RE_scores_for_entity_pair, target_RE_Labels_for_entity_pair)
        return batch_loss

    def forward(self, shared_representations, ner_tag_embeddings, rstartseqs, rendseqs, rseqs):
        """
        :param shared_representations:
        :param ner_tag_embeddings:
        :param rstartseqs:
        :param rendseqs:
        :param rseqs:
        :return:
        """

        # Pre biRNN dropout
        shared_representations = self.dropout(shared_representations)
        re_representation, _ = self.birnn(shared_representations)
        re_representation = torch.cat([re_representation, ner_tag_embeddings], dim=2)

        # Pre RE Scoring Dropout
        re_representation = self.re_dropout(re_representation)

        batches = self._trim_embeddings(re_representation, rstartseqs, rendseqs, rseqs)
        loss = self._calculate_RE_scores(batches)
        return {'loss': loss}

    def scorer(self, shared_representations, ner_tag_embeddings, rstartseqs, rendseqs, rseqs):
        """
        :param shared_representations:
        :param ner_tag_embeddings:
        :param rstartseqs:
        :param rendseqs:
        :param rseqs:
        :return:
        """

        re_representation, _ = self.birnn(shared_representations)
        re_representation = torch.cat([re_representation, ner_tag_embeddings], dim=2)

        # Pre RE Scoring Dropout. In model.eval dropout won't work
        re_representation = self.dropout(re_representation)

        batched = []
        batches = self._trim_embeddings(re_representation, rstartseqs, rendseqs, rseqs)
        for i, (entity_pairs_embeddings, entity_pairs_indices, true_RE_labels) in enumerate(batches):
            rstart_list = rstartseqs[i].tolist()
            rend_list = rendseqs[i].tolist()
            filter_r = set(zip(rstart_list, rend_list))
            # print(filter_r)
            pairs = []
            visited = set()
            for j, (first_entity_embedding, second_entity_embedding) in enumerate(entity_pairs_embeddings):
                predicted_RE_scores_for_entity_pair = self._RE_scoring_layers(first_entity_embedding,
                                                                              second_entity_embedding)

                first_entity_end_index = entity_pairs_indices[j][0]
                second_entity_end_index = entity_pairs_indices[j][1]
                if (first_entity_end_index, second_entity_end_index) in filter_r and \
                    (first_entity_end_index, second_entity_end_index) not in visited:
                    # print(first_entity_end_index, second_entity_end_index)
                    pairs.append(predicted_RE_scores_for_entity_pair.squeeze(0))
                visited.add((first_entity_end_index, second_entity_end_index))
            batched.append(pairs)
        return batched


class CharRNN(nn.Module):
    """
    Trains character level embeddings via a bidirectional GRU or LSTM.
    """

    def __init__(self, cemb, num_layers=1, recurrent_unit="gru"):
        """
        Initialise.

        :param cemb: nn.Embedding for the characters
        :param num_layers: number of layers in this biRNN
        :param recurrent_unit: the type of recurrent unit to use for biRNN - GRU or LSTM
        """

        super(CharRNN, self).__init__()
        self.cemb = cemb
        self.num_layers = num_layers
        if recurrent_unit == "gru":
            self.birnn = nn.GRU(cemb.embedding_dim, cemb.embedding_dim, num_layers, bidirectional=True)
        else:
            self.birnn = nn.LSTM(cemb.embedding_dim, cemb.embedding_dim, num_layers, bidirectional=True)

    def forward(self, padded_chars, char_lengths):
        """
        Do a forward pass to learn the character embeddings.

        :param padded_chars: the padded character encodings
        :param char_lengths: lengths of the words
        :return: learned character embeddings in the form of biRNN hidden vector
        """
        B = len(char_lengths)

        packed = pack_padded_sequence(self.cemb(padded_chars), char_lengths,
                                      batch_first=True, enforce_sorted=False)
        _, (final_h, _) = self.birnn(packed)
        return final_h

I would generally recommend taking a look at the performance guide and checking whether you can avoid common performance-degrading issues.
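One issue of this kind that shows up in the posted code is scoring items one at a time in a Python loop: _calculate_RE_scores calls _RE_scoring_layers once per entity pair, so each call launches a few very small kernels, and launch overhead rather than arithmetic typically dominates. Below is a minimal sketch of scoring all pairs in one call, limited to the DistMult term. The tensor sizes and the function names distmult_loop and distmult_batched are made up for illustration; this is not the repository's code.

```python
import torch


def distmult_loop(e1, e2, M):
    """Score one pair at a time, mirroring the matmul chain in _RE_scoring_layers."""
    rows = []
    for a, b in zip(e1, e2):
        a = a.unsqueeze(0)  # 1 x p
        b = b.unsqueeze(0)  # 1 x p
        s = torch.matmul(a, torch.matmul(M, b.T))  # R x 1 x 1
        rows.append(s.squeeze(2).T)  # 1 x R
    return torch.cat(rows, dim=0)  # N x R


def distmult_batched(e1, e2, M):
    """Score all N pairs against all R relation matrices in a single einsum call."""
    # scores[n, r] = sum over p, q of e1[n, p] * M[r, p, q] * e2[n, q]
    return torch.einsum("np,rpq,nq->nr", e1, M, e2)


torch.manual_seed(0)
num_pairs, p, num_rel = 64, 32, 5  # hypothetical sizes
device = "cuda" if torch.cuda.is_available() else "cpu"

# Same construction as self.M in the posted code: one diagonal matrix per relation
M = torch.stack([torch.diag(torch.rand(p)) for _ in range(num_rel)]).to(device)
e1 = torch.randn(num_pairs, p, device=device)
e2 = torch.randn(num_pairs, p, device=device)

loop_scores = distmult_loop(e1, e2, M)
batched_scores = distmult_batched(e1, e2, M)
print((loop_scores - batched_scores).abs().max().item())  # difference is floating-point noise
```

The same idea would apply to the feed-forward layers and the loss: stack the pair embeddings into one N x p tensor per sentence or per batch, then call FFNNr1, FFNNr2 and the loss once on the stacked tensor.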
To get a better picture of what the bottleneck in your code might be, you could profile it with the PyTorch profiler or e.g. Nsight Systems.
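A minimal sketch of the profiler approach, assuming a small stand-in nn.GRU instead of the full MTLArchitecture (the model, tensor sizes and number of iterations are placeholders). In your case you would wrap a few iterations of the training loop in the same context manager.

```python
import torch
import torch.nn as nn
from torch.profiler import profile, ProfilerActivity

device = "cuda" if torch.cuda.is_available() else "cpu"

# Stand-in model and input; replace with the real model and a real batch
model = nn.GRU(64, 32, bidirectional=True).to(device)
x = torch.randn(20, 8, 64, device=device)  # T x B x features

# Record CPU activity always, and CUDA kernels when a GPU is present
activities = [ProfilerActivity.CPU]
if device == "cuda":
    activities.append(ProfilerActivity.CUDA)

with profile(activities=activities, record_shapes=True) as prof:
    for step in range(3):
        out, h = model(x)
        out.sum().backward()

# Operators sorted by total CPU time; long CPU-side entries with little GPU work
# usually point at Python loops or data preparation
report = prof.key_averages().table(sort_by="cpu_time_total", row_limit=10)
print(report)
```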
As a quick smoke test, check the .device attribute of any parameter of the model as well as of the input and make sure it's on the GPU. If so, then the GPU will be used.
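The check could look like the sketch below. The helper name devices_in_use is made up for illustration, and nn.Linear stands in for the real model.

```python
import torch
import torch.nn as nn


def devices_in_use(model, *tensors):
    """Return the set of device types holding the model's parameters, its buffers and the given tensors."""
    found = {p.device.type for p in model.parameters()}
    found |= {b.device.type for b in model.buffers()}
    found |= {t.device.type for t in tensors}
    return found


model = nn.Linear(4, 2)    # created on the CPU
batch = torch.randn(3, 4)  # also on the CPU
print(devices_in_use(model, batch))

if torch.cuda.is_available():
    model = model.to("cuda")
    # The batch was not moved, so both device types are reported
    print(devices_in_use(model, batch))
```

Note that this only covers registered parameters and buffers plus the tensors you pass in. Tensors created inside forward, such as the embeddings returned by the load_* helpers in the posted code, would have to be checked separately.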

Thank you so much for your advice, it is extremely helpful. I suspect that the original author created their own custom dataset class without using the torch.utils.data classes, which is likely the bottleneck. Cheers.

The data loading is indeed often the bottleneck. As a quick test you could remove the entire data loading pipeline, create random tensors on the GPU once, and train the model with it to check the speed and GPU utilization.
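A minimal sketch of that test, assuming a small stand-in model and made-up tensor sizes; the function name time_training_steps is hypothetical. The random tensors are created on the device once, so the timing excludes any data loading, and torch.cuda.synchronize() makes sure pending GPU work is included in the measurement.

```python
import time

import torch
import torch.nn as nn


def time_training_steps(model, x, y, steps=20):
    """Average seconds per training step on data that already lives on the device."""
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    loss_fn = nn.MSELoss()
    use_cuda = x.is_cuda
    if use_cuda:
        torch.cuda.synchronize()  # finish pending GPU work before starting the clock
    start = time.perf_counter()
    for _ in range(steps):
        optimizer.zero_grad()
        loss = loss_fn(model(x), y)
        loss.backward()
        optimizer.step()
    if use_cuda:
        torch.cuda.synchronize()  # wait for the queued kernels to finish
    return (time.perf_counter() - start) / steps


device = "cuda" if torch.cuda.is_available() else "cpu"
model = nn.Sequential(nn.Linear(256, 512), nn.ReLU(), nn.Linear(512, 1)).to(device)

# Random inputs and targets created once, directly on the device
x = torch.randn(128, 256, device=device)
y = torch.randn(128, 1, device=device)

seconds_per_step = time_training_steps(model, x, y)
print(f"{device}: {seconds_per_step * 1000:.2f} ms per step")
```

If the step time with synthetic data is much lower than with the real pipeline, the time is being spent in data preparation rather than in the model itself.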