How to slice the output of a BiLSTM and pass only selected vectors

I am doing citation function analysis with a BiLSTM, and one of my inputs looks like:

["I", "compared", "agglomeration", "to", "a", "top-down", "method", "that", "CITSEG", "call", "partitioning", "around", "medoids", "."]

I pass each word through an embedding layer followed by the BiLSTM. From the BiLSTM output I need to pass only the "CITSEG" vector to the linear layer; in other words, after running the BiLSTM I need to mask out every token except "CITSEG" and feed only the CITSEG vector to the next layers, as in the sketch below.
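
To make the selection step concrete, here is a minimal sketch of what I think I need, with toy shapes (citseg_id and citseg_positions are hypothetical names; citseg_id would hold the vocabulary index of "CITSEG"):

import torch

# Toy example: batch of 2 sequences of length 6, BiLSTM output size 8.
citseg_id = 4                                      # hypothetical vocab id of "CITSEG"
x_in = torch.tensor([[1, 2, 4, 3, 0, 0],
                     [5, 4, 2, 2, 3, 0]])          # (batch, seq_len) token ids
lstm_out = torch.randn(2, 6, 8)                    # (batch, seq_len, 2 * hidden_dim)

# Position of CITSEG in each sequence (argmax returns the first match).
citseg_positions = (x_in == citseg_id).long().argmax(dim=1)   # -> tensor([2, 1])
batch_idx = torch.arange(x_in.size(0))
citseg_vec = lstm_out[batch_idx, citseg_positions]            # (batch, 2 * hidden_dim)
# citseg_vec is what I want to feed to the linear layer instead of the
# flattened sequence, so fc1 would become nn.Linear(2 * hidden_dim, num_classes).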
Could anybody share an idea of how to proceed?
Thank you.

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class CitationClassifier(nn.Module):
    def __init__(self, embedding_size, num_embeddings, num_layers,
                 hidden_dim, num_classes, dropout_p,
                 pretrained_embeddings=None, padding_idx=0):
        """
        Args:
            embedding_size (int): size of the embedding vectors
            num_embeddings (int): number of embedding vectors
            num_layers (int): number of stacked LSTM layers
            hidden_dim (int): the size of the hidden dimension
            num_classes (int): the number of classes in classification
            dropout_p (float): a dropout parameter
            pretrained_embeddings (numpy.array): previously trained word
                embeddings; default is None. If provided, they are used to
                initialize the embedding layer.
            padding_idx (int): an index representing a null position
        """
        super(CitationClassifier, self).__init__()
        self.hidden_dim = hidden_dim
        if pretrained_embeddings is None:
            self.emb = nn.Embedding(embedding_dim=embedding_size,
                                    num_embeddings=num_embeddings,
                                    padding_idx=padding_idx)
        else:
            pretrained_embeddings = torch.from_numpy(pretrained_embeddings).float()
            self.emb = nn.Embedding(embedding_dim=embedding_size,
                                    num_embeddings=num_embeddings,
                                    padding_idx=padding_idx,
                                    _weight=pretrained_embeddings)

        self.lstm = nn.LSTM(embedding_size, hidden_dim, num_layers,
                            bidirectional=True, batch_first=True)
        # 193 is my padded sequence length and 2 * hidden_dim is the size of
        # each BiLSTM output vector (100 when hidden_dim=50).
        self.fc1 = nn.Linear(193 * 2 * hidden_dim, num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_p)
        self.act = nn.Softmax(dim=1)

    def forward(self, x_in, apply_softmax=False):
        """The forward pass of the classifier

        Args:
            x_in (torch.Tensor): an input data tensor.
                x_in.shape should be (batch, dataset._max_seq_length)
            apply_softmax (bool): a flag for the softmax activation;
                should be False if used with the cross-entropy loss
        Returns:
            the resulting tensor. tensor.shape should be (batch, num_classes)
        """
        batch_size = x_in.size(0)
        h_embedding = self.emb(x_in)           # (batch, seq_len, embedding_size)

        lstm_out, (hidden, cell) = self.lstm(h_embedding)
        # lstm_out: (batch, seq_len, 2 * hidden_dim)
        # hidden, cell: (2 * num_layers, batch, hidden_dim)

        # Currently I flatten the whole sequence; this is the step where I
        # want to keep only the CITSEG vector instead.
        lstm_out = lstm_out.contiguous().view(batch_size, -1)

        out = self.dropout(lstm_out)
        out = self.fc1(out)                    # (batch, num_classes)
        if apply_softmax:
            out = self.act(out)
        return out
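
For completeness, this is how I instantiate and call the model; 30 and 193 are my batch size and padded length, the other sizes are placeholders:

model = CitationClassifier(embedding_size=100, num_embeddings=5000,
                           num_layers=1, hidden_dim=50, num_classes=13,
                           dropout_p=0.3)
x = torch.randint(1, 5000, (30, 193))    # (batch=30, max_seq_length=193)
logits = model(x)                        # (30, 13), for nn.CrossEntropyLoss
probs = model(x, apply_softmax=True)     # (30, 13) class probabilities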