Katz's backoff implementation

I’ve been staring at the Wikipedia article on Katz’s backoff model for quite some time, and I’m interested in implementing it in my PyTorch model as a loss function. Unfortunately, I have no sample code for the loss. To make things easier, I could set d = 0 and start with a trigram model.

I was wondering if anyone out there has implemented something like this? My idea: for a given sentence, calculate the Katz probability of each trigram in the sentence, then take the negative log of their product (i.e., the sum of their negative logs) as the loss. Has anyone done this or something similar?

Here’s a simplification of my thought process:

x = torch.tensor(input_text)  # input_text is a tokenized training example
new_text = model(x)           # new_text is the text generated by the model
loss = katz_loss(new_text)    # the Katz backoff probability is calculated for each
                              # trigram in new_text; the final loss is the sum of
                              # their negative logs
loss.backward()
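To make the question concrete, here’s a rough pure-Python sketch of what I mean by `katz_loss` (all the names here are my own, not from any library). Counts come from a reference corpus, `d` is a fixed absolute discount (with d = 0 no mass is left for unseen trigrams, so the example uses d = 0.5), and the loss is the standard negative log-likelihood, i.e. the sum of negative log trigram probabilities. I realize that counts over discrete tokens won’t give gradients on their own — this is just to pin down the probability and loss computation:

```python
import math
from collections import Counter

def ngram_counts(tokens, n):
    """Count all n-grams in a token list."""
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))

class KatzTrigram:
    """Katz backoff trigram model with a fixed absolute discount d.

    With d = 0 this reduces to plain MLE: no probability mass is
    reserved for backoff, so unseen trigrams get probability 0.
    """

    def __init__(self, tokens, d=0.5):
        self.d = d
        self.uni = ngram_counts(tokens, 1)
        self.bi = ngram_counts(tokens, 2)
        self.tri = ngram_counts(tokens, 3)
        self.total = sum(self.uni.values())

    def p_uni(self, w):
        return self.uni[(w,)] / self.total

    def p_bi(self, w1, w2):
        c = self.bi[(w1, w2)]
        if c > 0:  # seen bigram: discounted relative frequency
            return (c - self.d) / self.uni[(w1,)]
        # unseen: back off to the unigram, scaled by the leftover mass
        return self._alpha(self.bi, (w1,), self.uni[(w1,)], self.p_uni) * self.p_uni(w2)

    def p_trigram(self, w1, w2, w3):
        c = self.tri[(w1, w2, w3)]
        if c > 0:  # seen trigram: discounted relative frequency
            return (c - self.d) / self.bi[(w1, w2)]
        # unseen: back off to the bigram P(w3 | w2)
        alpha = self._alpha(self.tri, (w1, w2), self.bi[(w1, w2)],
                            lambda w: self.p_bi(w2, w))
        return alpha * self.p_bi(w2, w3)

    def _alpha(self, counts, ctx, ctx_count, p_lower):
        """Backoff weight: the mass removed by discounting in this context,
        renormalised over the lower-order probabilities of the unseen words."""
        if ctx_count == 0:
            return 1.0  # no evidence at this order: back off fully
        seen = [k[-1] for k in counts if k[:-1] == ctx]
        leftover = self.d * len(seen) / ctx_count
        denom = 1.0 - sum(p_lower(w) for w in seen)
        return leftover / denom if denom > 0 else 0.0

def katz_nll(lm, tokens):
    """Negative log-likelihood: sum of -log P over every trigram in tokens."""
    return -sum(math.log(lm.p_trigram(*tokens[i:i + 3]))
                for i in range(len(tokens) - 2))
```

One sanity check I’d want: for a seen context, the trigram probabilities over the whole vocabulary should sum to 1, e.g. `sum(lm.p_trigram("a", "b", w) for w in vocab)`.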

Thanks in advance for your help!