[Help!] Simple Generate / Predict / Inference Function for Attention Transformer

Following the code here exactly, I am trying to write a function that generates tokens from a model trained on a custom dataset. So far my training and validation losses have been abysmal, and I’ve had no luck greedy-decoding the model’s forward output. I believe that’s because my src_mask isn’t working properly: I don’t think it grows in size, because the argument shapes for the model’s forward method are

    Args: src: Tensor, shape [seq_len, batch_size], src_mask: Tensor, shape [seq_len, seq_len]

which

    Returns: output Tensor of shape [seq_len, batch_size, ntoken]

and if I build the src_mask from y_input.size(0) on every iteration, the mask is always 1×1. If I have the mask grow with y_input.size(1) instead, I get an error because the shape OUGHT to be [1,1]. Transposing y_input where pred is computed lets the mask grow, but I still get the same output as the code below gives me: a single repeating token. Can someone please redirect me and get me on track with what I need to get this basic transformer working? I’ve been scouring the web for days, but maybe I’m using the wrong keywords?
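A minimal sketch of the shape issue, assuming the tutorial’s convention that `src` is `[seq_len, batch_size]` (the token ids below are made up). `torch.tensor([entry])` is batch-first, `[1, seq_len]`, so `size(0)` is always 1 and the model effectively sees `seq_len` separate one-token sequences with no context between them. Transposing to `[seq_len, 1]` lets a mask built from `size(0)` grow with the sequence. The mask helper is re-implemented locally in the spirit of the tutorial’s `generate_square_subsequent_mask` so the snippet runs on its own.

```python
import torch


def generate_square_subsequent_mask(sz: int) -> torch.Tensor:
    # -inf above the diagonal, 0 on and below it: position i may only attend to positions <= i
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)


entry = [1, 5, 7, 9]  # hypothetical ids: <sos> followed by a three-token prompt

batch_first = torch.tensor([entry])  # [1, 4]: [batch, seq], as built in the post
seq_first = batch_first.t()          # [4, 1]: [seq, batch], what forward() expects

mask_wrong = generate_square_subsequent_mask(batch_first.size(0))  # always [1, 1]
mask_right = generate_square_subsequent_mask(seq_first.size(0))    # [4, 4], grows with the prompt

print(batch_first.shape, mask_wrong.shape)  # torch.Size([1, 4]) torch.Size([1, 1])
print(seq_first.shape, mask_right.shape)    # torch.Size([4, 1]) torch.Size([4, 4])
```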

Here is my code below:

def generate(model: nn.Module, src_text: str):
    src = BeatleSet.encode(src_text.lower())
    SOS = BeatleSet.textTokDict['<sos>']
    EOS = BeatleSet.textTokDict['<eos>']
    print(src)
    model.eval()
    entry = [SOS] + src
    # shape [1, len(entry)]: batch along dim 0, sequence along dim 1
    y_input = torch.tensor([entry], dtype=torch.long, device=device)
    num_tokens = len(BeatleSet)
    for i in range(50):
        # y_input.size(0) is the batch size (1), so this mask is always [1, 1]
        src_mask = generate_square_subsequent_mask(y_input.size(0)).to(device)
        if i % 49 == 0: print(y_input.shape, y_input, '\n', src_mask.shape, src_mask)
        pred = model(y_input, src_mask)
        if i % 49 == 0: print(pred)
        # top-1 id at the last position along dim 1
        next_item = pred.topk(1)[1].view(-1)[-1].item()
        next_item = torch.tensor([[next_item]], device=device)
        if i % 49 == 0: print(next_item, next_item.shape)
        y_input = torch.cat((y_input, next_item), dim=1)  # grows along dim 1, so size(0) stays 1
        if next_item.view(-1).item() == EOS:
            break
    return " ".join(BeatleSet.decode(y_input.view(-1).tolist()))
    
print(generate(model, "Oh yeah I"))
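For comparison, a minimal greedy-decoding sketch that keeps the tutorial’s sequence-first layout: the prompt is `[seq_len, 1]`, the mask is rebuilt from `size(0)` each step, and the new token is concatenated along dim 0. `greedy_decode` and the `ToyNextTokenModel` used to exercise it are hypothetical names, not part of the tutorial; the real model and encoded prompt would be swapped in. `torch.no_grad()` doesn’t change the predictions, it just stops autograd from recording the forward passes, which saves memory and time at inference.

```python
from typing import List

import torch
import torch.nn.functional as F
from torch import nn, Tensor


def generate_square_subsequent_mask(sz: int) -> Tensor:
    return torch.triu(torch.full((sz, sz), float("-inf")), diagonal=1)


@torch.no_grad()  # inference only: skip building the autograd graph
def greedy_decode(model: nn.Module, prompt_ids: List[int], eos_id: int,
                  max_new_tokens: int = 50, device: str = "cpu") -> List[int]:
    model.eval()  # disables dropout
    # [seq_len, 1]: the sequence runs along dim 0, a batch of one along dim 1
    y = torch.tensor(prompt_ids, dtype=torch.long, device=device).unsqueeze(1)
    for _ in range(max_new_tokens):
        mask = generate_square_subsequent_mask(y.size(0)).to(device)  # grows every step
        logits = model(y, mask)                # [seq_len, 1, ntoken]
        next_id = int(logits[-1, 0].argmax())  # scores at the last position predict the next token
        y = torch.cat([y, torch.tensor([[next_id]], device=device)], dim=0)
        if next_id == eos_id:
            break
    return y.view(-1).tolist()


class ToyNextTokenModel(nn.Module):
    """Hypothetical stand-in for the real model: always predicts (token + 1) % vocab."""

    def __init__(self, vocab: int = 10):
        super().__init__()
        self.vocab = vocab

    def forward(self, src: Tensor, src_mask: Tensor) -> Tensor:
        return F.one_hot((src + 1) % self.vocab, self.vocab).float()  # [seq_len, batch, vocab]


out_ids = greedy_decode(ToyNextTokenModel(), [0, 1], eos_id=4)
print(out_ids)  # [0, 1, 2, 3, 4]
```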

Thank you.

Okay, gang, I’m getting real impatient at this point. In the meantime I’ll show you how I use a custom dataset, because the tutorial works with iterators and something about that doesn’t jibe with me. Maybe you can point out the flaws in my program, or convince me to use iterators.

So what I’ve gone and done is write a class that scrapes the lyrics for a list of songs and splits the text into char, word, or line tokens, whichever you ask for. It returns a tuple of train/test/val data, and its key components look like this:

import spotipy
import random
import regex as re
import operator
import requests
from bs4 import BeautifulSoup as bs
from functools import reduce
import unicodedata as ucd
import unidecode as udc
from torch.utils.data import Dataset, random_split
from spotipy.oauth2 import SpotifyClientCredentials
from concurrent.futures import ThreadPoolExecutor

class TheBeast:  # holds the train, test, and val token lists; index it like a tuple
    def __init__(self, artistURI: str, tokType: str = 'word'):
        self.tokType = tokType.lower()
        self.textTokDict: dict
        self.tokTextDict: dict
        self.testToks: list
        self.trainToks: list
        self.valToks: list
        self.artistURI = artistURI
        # Credentials are read from the SPOTIPY_CLIENT_ID / SPOTIPY_CLIENT_SECRET environment
        # variables; don't paste real keys into a public post
        ccm = SpotifyClientCredentials()
        self.sp = spotipy.Spotify(client_credentials_manager=ccm, requests_timeout=60)
        self.artistName = self.sp.artist(artistURI)['name']
        self.songList = self.__catalogueCollector()  # IMPORTANT METHOD 1
        self.sesh = requests.Session()
        # this stuff scrapes FAST thanks to thread pooling; the with-block shuts the pool down
        with ThreadPoolExecutor(max_workers=100) as executor:
            scraped = list(executor.map(self.__geniusScraper, self.songList))  # IMPORTANT METHOD 2
        self.lyricList = [s for s in scraped if s != "<!--INVALID URL-->"]
        self.nSongs = len(self.lyricList)  # recount, because some URLs may have been invalid
        # split songs roughly 60/20/20 (I bet you could think of a more elegant way);
        # tstlen takes the remainder, so the three lengths always add up to nSongs
        tlen = 3 * self.nSongs // 5
        vlen = (self.nSongs - tlen) // 2
        tstlen = self.nSongs - tlen - vlen
        print(f"{self.nSongs} valid song lyrics scraped.")
        self.trainLyrics, self.testLyrics, self.valLyrics = random_split(self.lyricList, [tlen, tstlen, vlen])
        print(f"Split to {tlen} train songs, {tstlen} test songs, {vlen} val songs... ", end='')
        self.tokenTuple = self.__tokenify()  # IMPORTANT METHOD 3
        print("Tokenized.")

    def __len__(self): return (len(self.tokTextDict)) # gives nTokens, you'll see SOS, EOS, and UNK tokens are included

    def __str__(self): # just to have something to print
        return f"Train text tokens: {len(self.trainToks)}, Test text tokens: {len(self.testToks)}, Validation text tokens: {len(self.valToks)}... Dictionary has {len(self)} tokens."

    def __getitem__(self, idx): # not sure if this is the best get implementation, but idx gets you the training, test or val data set
        return self.tokenTuple[idx] 

    #Moving on to the important methods!

    def __catalogueCollector(self) -> list:
        #just collect the top ten songs to keep it easy
        songList=[track['name'] for track in self.sp.artist_top_tracks(self.artistURI, country='US')['tracks']]
        print(songList)
        # we don't want redundancies, so turn our list into a set and cut all that extra crap out, like " - REMASTERED" etc.
        songList = list(set([re.sub(r'(.+)\s\-\s.+',r'\1',song) for song in songList]))
        print(f'{len(songList)} UNIQUE songs acquired.')
        return songList

    def __normalfy(self, string: str)->str: # gotta keep it clean
        return udc.unidecode(ucd.normalize('NFKD', string))

    def __geniusify(self, string: str) -> str: # gotta put dashes between words... check that dolla sign
        return re.sub(r'(-)+',r'\1',re.sub(r'[^\w-]','',re.sub(r'[\s\/\\$\_]+','-',self.__normalfy(string))))
        
    def __geniusScraper(self, trackName: str) -> str:
        geniusName=f"{self.__geniusify(self.artistName)}-{self.__geniusify(trackName)}"
        url=f"https://genius.com/{geniusName}-lyrics"
        buffer=""
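        page = self.sesh.get(url)
        if not page:  # a requests Response is falsy for 4xx/5xx status codes
            buffer += f"INVALID URL -- {geniusName}"
            print(buffer)  # report the skipped song before bailing out
            return "<!--INVALID URL-->"
        buffer += f"{url}"
        soup = bs(page.text, 'html.parser')
        for tag in soup.find_all('br'): tag.replace_with('\n' + tag.text)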
        divlist=soup.find_all("div", class_="Lyrics__Container-sc-1ynbvzw-6 jYfhrf") #The div class genius uses to hold their lyrics... May change periodically
        lyrics='<sos>'+"".join([p.get_text() for p in divlist])+'<eos>'
        print(buffer)
        return self.__normalfy(lyrics)

    def __splitUp(self, text: str) -> list:
        # raw strings, so "\)" and "\s" reach the regex engine instead of being invalid Python escapes
        Brackets = r"[\[\<].*?[\]\>]"
        Words = r"\b[^\s\<\[]+\b"
        Acronyms = r"(?:[A-Za-z]\.){2,}[a-zA-Z]?"
        Lines = r"\S.*\n{1,}"
        if self.tokType == 'word':   return re.findall(rf"{Brackets}|{Acronyms}|{Words}|[\)\(,.!?\n]", text.lower())
        elif self.tokType == 'line': return re.findall(rf"{Brackets}|{Lines}", text)
        elif self.tokType == 'char': return re.findall(rf"{Brackets}|[\s\S]", text)
        else:
            print('INVALID TOKTYPE... SELECT "word" "char" OR "line"')
            return []
    
    def __tokenify(self)->tuple:
        trainText = "".join(self.trainLyrics)
        testText = "".join(self.testLyrics)
        valText = "".join(self.valLyrics)
        self.trainToks = self.__splitUp(trainText)
        self.testToks = self.__splitUp(testText)
        self.valToks = self.__splitUp(valText)
        ###QUESTION TIME: Will a set be in the same order every time enumerate is performed upon it?
        ###IF NOT, why does this work?
        fullTok = self.trainToks + self.testToks + self.valToks
        fullSet = set(fullTok)
        self.textTokDict = {**{'<unk>': 0}, **{tex: tok for tok, tex in enumerate(fullSet, 1)}}
        self.tokTextDict = {**{0: '<unk>'}, **{tok: tex for tok, tex in enumerate(fullSet, 1)}}
        trnTokens = [*map(self.textTokDict.get, self.trainToks)]
        tstTokens = [*map(self.textTokDict.get, self.testToks)]
        valTokens = [*map(self.textTokDict.get, self.valToks)]
        return trnTokens, tstTokens, valTokens
        
    def __call__(self):
        return self.tokenTuple

ChatSet=TheBeast("spotify:artist:1aQ7P3HtKOQFW16ebjiks1")
print(ChatSet,"\nThe first 15 tokens of the train set:",ChatSet[0][:15])
print("Decoded:",[*map(ChatSet.tokTextDict.get,ChatSet[0][:15])])
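On the “more elegant way” comment in `__init__`: one option is to let one split take whatever remains (so the lengths always sum to the number of songs) and pass a seeded `torch.Generator` to `random_split`, so the same songs land in the same split on every run. A sketch under those assumptions; the helper name `split_60_20_20` is made up for illustration.

```python
from typing import List, Sequence, Tuple

import torch
from torch.utils.data import random_split


def split_60_20_20(items: Sequence, seed: int = 0) -> Tuple[List, List, List]:
    n = len(items)
    n_train = 3 * n // 5
    n_val = (n - n_train) // 2
    n_test = n - n_train - n_val               # the remainder, so the three lengths always sum to n
    gen = torch.Generator().manual_seed(seed)  # fixed seed -> identical split on every run
    train, test, val = random_split(items, [n_train, n_test, n_val], generator=gen)
    # each Subset records which positions of `items` it holds
    return ([items[i] for i in train.indices],
            [items[i] for i in test.indices],
            [items[i] for i in val.indices])


songs = [f"song {i}" for i in range(10)]
train, test, val = split_60_20_20(songs)
print(len(train), len(test), len(val))  # 6 2 2
```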

I’m using The Chats as an example:

>>>python proof.py
['Smoko', 'Bus Money', 'Pub Feed', 'Drunk n Disorderly', 'Struck By Lightning', 'Identity Theft', 'The Clap', 'Do What I Want', 'Mum Stole My Darts', 'AC/DC CD']
10 UNIQUE songs acquired.
https://genius.com/The-Chats-The-Clap-lyrics
https://genius.com/The-Chats-AC-DC-CD-lyrics
https://genius.com/The-Chats-Struck-By-Lightning-lyrics
https://genius.com/The-Chats-Bus-Money-lyrics
https://genius.com/The-Chats-Mum-Stole-My-Darts-lyrics
https://genius.com/The-Chats-Identity-Theft-lyrics
https://genius.com/The-Chats-Do-What-I-Want-lyrics
https://genius.com/The-Chats-Drunk-n-Disorderly-lyrics
https://genius.com/The-Chats-Pub-Feed-lyrics
https://genius.com/The-Chats-Smoko-lyrics
10 valid song lyrics scraped.
Split to 6 train songs, 2 test songs, 2 val songs... Tokenized.
Train text tokens: 1485, Test text tokens: 496, Validation text tokens: 394... Dictionary has 449 tokens.
The first 15 tokens of the train set: [294, 414, 394, 186, 219, 404, 327, 28, 29, 394, 357, 313, 419, 197, 286]
Decoded: ['<sos>', '[verse 1]', '\n', 'last', 'week', ',', 'pulled', 'a', 'root', '\n', 'in', 'the', 'back', 'of', 'my']
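About the `QUESTION TIME` comment in `__tokenify`: in CPython, iterating the same set twice without modifying it in between yields the same order, which is why the two `enumerate(fullSet, 1)` calls agree. Across separate runs, though, string hashes are randomized (see `PYTHONHASHSEED`), so the same word can get a different id next time; that matters if the vocabulary is rebuilt instead of being saved alongside the model. A sketch that sidesteps the question by building both dictionaries from one sorted list. `build_vocab`, `encode`, and `decode` are hypothetical helpers (the post’s own `encode`/`decode` methods aren’t shown), and this `encode` takes an already-split token list.

```python
from typing import Dict, List, Tuple


def build_vocab(tokens: List[str]) -> Tuple[Dict[str, int], Dict[int, str]]:
    # One sorted list fixes the order, independent of hash seeds and set iteration
    vocab = ["<unk>"] + sorted(set(tokens) - {"<unk>"})
    text_to_tok = {tex: tok for tok, tex in enumerate(vocab)}  # "<unk>" -> 0, as in the post
    tok_to_text = {tok: tex for tex, tok in text_to_tok.items()}
    return text_to_tok, tok_to_text


def encode(tokens: List[str], text_to_tok: Dict[str, int]) -> List[int]:
    unk = text_to_tok["<unk>"]
    return [text_to_tok.get(t, unk) for t in tokens]  # unseen tokens map to <unk>


def decode(ids: List[int], tok_to_text: Dict[int, str]) -> List[str]:
    return [tok_to_text.get(i, "<unk>") for i in ids]


toks = ["<sos>", "last", "week", ",", "last", "<eos>"]
t2i, i2t = build_vocab(toks)
print(decode(encode(["last", "banana"], t2i), i2t))  # ['last', '<unk>']
```

Falling back to `<unk>` also matters for prompts: a plain `textTokDict.get` lookup returns `None` for a word that never appeared in the lyrics.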

Boom, dataset. You want it batched so it can be transformed? 'Ere.

from torch import nn, Tensor
from typing import Tuple
import torch

class BatchMeData:  # turns the (train, test, val) token lists into [seq_len, batchSize] tensors
    def __init__(self, tokTuple: tuple, batchSize: int = 8):
        self.tokTuple = tokTuple
        self.batchSize = batchSize
        self.dataTuple = self.__tensorfy()
        self.batchedTuple = self.__batchify()

    def __str__(self):
        return f"Okay, here's the first 10 items of the first batch in each set:\nTrain: {self.batchedTuple[0][:10,0]}\nTest: {self.batchedTuple[1][:10,0]}\nValidation: {self.batchedTuple[2][:10,0]}"

    def __getitem__(self, idx):
        return self.batchedTuple[idx]

    def __tensorfy(self) -> Tuple[Tensor, Tensor, Tensor]:
        dt = ()
        for toks in self.tokTuple:
            pad = (-len(toks)) % self.batchSize  # pad with 0 (<unk>) up to a multiple of batchSize
            # build a new list so the caller's token lists aren't modified in place
            dt = (*dt, torch.tensor(list(toks) + [0] * pad, dtype=torch.long))
        return dt

    def __batchify(self) -> Tuple[Tensor, Tensor, Tensor]:
        bt = ()
        for data in self.dataTuple:
            seqLen = data.size(0) // self.batchSize
            # column j holds the j-th contiguous chunk of the token stream
            bt = (*bt, data.view(self.batchSize, seqLen).t().contiguous())
        return bt

    def __call__(self) -> Tuple[Tensor, Tensor, Tensor]:
        return self.batchedTuple

ChatBatch=BatchMeData(ChatSet())
print(ChatBatch)
train_data, test_data, val_data = ChatBatch()

It batches your split data! You can plug it straight into the tutorial’s transformer model using their get_batch function. Printed, it looks like this:

>>>
Okay, here's the first 10 items of the first batch in each set:
Train: tensor([294, 414, 394, 186, 219, 404, 327,  28,  29, 394])
Test: tensor([294, 414, 394,  45,  85,  54, 224, 257, 321, 313])
Validation: tensor([294, 414, 394, 164, 411, 407, 440, 313, 128, 394])
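For reference, a sketch of a `batchify`/`get_batch` pair that roughly mirrors the tutorial’s helpers; treat the details as approximate. This `batchify` trims the tail instead of padding it with `<unk>` (0) as `BatchMeData` does, and `get_batch` cuts `bptt` rows of inputs plus the same rows shifted down by one token as targets. The toy stream `torch.arange(20)` stands in for real token ids.

```python
from typing import Tuple

import torch
from torch import Tensor


def batchify(data: Tensor, bsz: int) -> Tensor:
    # Trim the tokens that don't fill a whole column, then lay the stream out as [seq_len, bsz]
    seq_len = data.size(0) // bsz
    return data[: seq_len * bsz].view(bsz, seq_len).t().contiguous()


def get_batch(source: Tensor, i: int, bptt: int = 35) -> Tuple[Tensor, Tensor]:
    # Inputs are rows i .. i+seq_len-1; targets are the same rows shifted down by one token
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i : i + seq_len]                        # [seq_len, bsz]
    target = source[i + 1 : i + 1 + seq_len].reshape(-1)  # [seq_len * bsz], flat for CrossEntropyLoss
    return data, target


stream = torch.arange(20)       # toy token ids 0..19
batched = batchify(stream, 4)   # [5, 4]; column j holds tokens 5j .. 5j+4
x, y = get_batch(batched, 0, bptt=3)
print(x.shape, y.shape)         # torch.Size([3, 4]) torch.Size([12])
```

In a training loop, `i` would step through `range(0, source.size(0) - 1, bptt)`.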

Hope you found this informative or cool, or felt a need to reach out and critique my code. Either way, any help with the original post would be much appreciated. I’ve had to take a break from this problem for the time being. Thank you.

Okay gang, lemme try and break this down the best I can. I am struggling and it is not for a lack of trying.

I have a pickled model and dataset you can download here to test this code on. The model was trained on all Kendrick Lamar songs for 5 epochs with bptt=35 and batchsize=32. You can see the model summary if you print it, or play around with the architecture and build your own model. Below is my generation attempt.

with open(f'{aToken}-model.pkl', 'rb') as f:
    pickledModel = pickle.load(f)
print(pickledModel)  # model summary... Is this pickled right?
with open(f'{aToken}-set.pkl', 'rb') as f:
    growingSet = pickle.load(f)

def generate(input_sentence: str) -> str:
    pickledModel.eval()  # Go into evaluation mode
    print(len(growingSet))  # Vocab size (nTokens)
    src = growingSet.encode(input_sentence.lower())  # converts your input to numerical tokens
    SOS = growingSet.textTokDict['<sos>']  # grab the sos and eos tokens
    EOS = growingSet.textTokDict['<eos>']
    entry = [SOS] + src  # add sos token to front of entry
    seqLen = len(entry)  # starting sequence length
    y_input = torch.tensor([entry], device=n_net.device).t()  # tensorfies the input, shape [seqLen, 1]: seqlen along dim 0
    with torch.no_grad():  # I feel this is an appropriate time to use no grad but not sure why?
        for i in range(0, 100):  # Enough range to hopefully get an eos token
            # mask is growing with the sequence length... Is this my misunderstanding?
            src_mask = n_net.generate_square_subsequent_mask(seqLen + i).to(n_net.device)
            output = pickledModel(y_input, src_mask)  # output shape [seqlen+i, 1, nTokens]
            # So yes, output is the same size as y_input, with the added dimensionality of vocab (nTokens)
            print(output.shape)  # proof
            # Okay, another point of potential misunderstanding, I'll try and be concise:
            # is this output a tensor of probabilities in the 2 dim? Does the index correspond to vocab?
            nextIndex = output[-1, 0].argmax()  # if so, wouldn't the weights of the desired next item
            # be at the very bottom of the output?
            print(nextIndex)
            # cat along dim 0; nextIndex is a 0-dim tensor, so reshape it to [1, 1] first
            y_input = torch.cat((y_input, nextIndex.view(1, 1)), dim=0)
            if nextIndex.item() == EOS:  # if output is eos token, we break
                break
    # returns a decoded string of text, starting with your input
    return " ".join(growingSet.decode(y_input.view(-1).tolist()))
  
print(generate("Play us something different")) # an attempt generates repeating text, no matter the length
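On the questions in the comments, assuming the model ends in a plain `nn.Linear` decoder trained with `CrossEntropyLoss` as in the tutorial: `output[-1, 0]` holds unnormalized scores (logits), not probabilities. Index k is the score for vocabulary id k, and the last row is indeed the one that predicts the next token. Softmax doesn’t change which entry is largest, so `argmax` on logits is fine; but always taking the argmax (greedy decoding) is a common way to end up in loops like the output below. A minimal sketch of temperature plus top-k sampling as an alternative; `sample_next` is a hypothetical helper.

```python
from typing import Optional

import torch
from torch import Tensor


def sample_next(logits: Tensor, temperature: float = 1.0, top_k: int = 0,
                generator: Optional[torch.Generator] = None) -> int:
    """Draw the next token id from a 1-D vector of logits, e.g. output[-1, 0]."""
    logits = logits / max(temperature, 1e-8)  # < 1 sharpens the distribution, > 1 flattens it
    if top_k > 0:
        kth_best = torch.topk(logits, top_k).values[-1]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))  # keep only the k best
    probs = torch.softmax(logits, dim=-1)     # logits -> probabilities that sum to 1
    return int(torch.multinomial(probs, num_samples=1, generator=generator))


logits = torch.tensor([2.0, 1.0, 0.5, -1.0])  # hypothetical scores over a 4-token vocabulary
print(sample_next(logits, top_k=1))           # 0: top_k=1 reduces to greedy argmax
g = torch.Generator().manual_seed(0)
print([sample_next(logits, temperature=0.8, top_k=3, generator=g) for _ in range(10)])
```

Inside the loop, `nextIndex = output[-1, 0].argmax()` would become something like `nextIndex = torch.tensor(sample_next(output[-1, 0], temperature=0.8, top_k=20), device=n_net.device)`; 0.8 and 20 are arbitrary starting points to tune.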

And there I’m stuck, folks. If anyone has any interest in NLP, attention transformers, and seq2seq text generation or translation, this would be the thread to post in. I think I’m gonna put some more attention into learning PHP and SQL cuz I need a break from PyTorch, oi.

Oh you want some output? 'Ere:

>>>
...
<sos> play us something different
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't kill my vibe
 i don't don't don't don't don't don't don't don't don't don't don't don't don't don't don't don't don't don't don't don't don't kill my vibe
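That loop is a typical greedy-decoding failure: once “i don’t kill my vibe” is the most likely continuation of itself, argmax keeps choosing it. Besides sampling, a common heuristic is to forbid any token that would repeat an n-gram already in the output (Hugging Face’s `generate` exposes this idea as `no_repeat_ngram_size`). A pure-Python sketch of the idea; `banned_next_tokens` is a hypothetical name and n=3 is an arbitrary choice.

```python
from typing import List, Set


def banned_next_tokens(ids: List[int], n: int = 3) -> Set[int]:
    """Token ids that would complete an n-gram already present in `ids`."""
    if n < 1 or len(ids) < n:
        return set()
    prefix = ids[len(ids) - n + 1:]  # the last n-1 tokens (empty when n == 1)
    banned = set()
    for start in range(len(ids) - n + 1):
        if ids[start:start + n - 1] == prefix:  # same (n-1)-token context seen before...
            banned.add(ids[start + n - 1])      # ...so its old continuation is off-limits
    return banned


ids = [1, 2, 3, 1, 2]          # hypothetical generated ids; (1, 2) was followed by 3 before
scores = [0.1, 0.2, 0.3, 0.9]  # hypothetical next-token scores; token 3 is the favourite
banned = banned_next_tokens(ids, n=3)
best = max((s, t) for t, s in enumerate(scores) if t not in banned)[1]
print(banned, best)            # {3} 2
```

With the tutorial model, the banned ids would be set to `float('-inf')` in `output[-1, 0]` before taking the argmax or sampling.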

Thank you, I’ll be over in the corner holding my breath.