Building a loss function without neural nets

Hi, I’m using PyTorch for a character-based language model. I had to redefine the loss function to match a customized model (no neural nets involved), and as a result optimization has become extremely slow, especially in the nested for loops below. Is there any way to convert NLL() and coemitNumerators() here to use PyTorch’s built-in NLLLoss? I’d appreciate any thoughts, and I’m happy to share the full code if anyone is interested. (I’ve added a couple of rough sketches right after the two functions to show what I’m imagining.)

import random
import time

import torch


def coemitNumerators(word, ix_inventory, automata):
    # For each position in `word` and each segment in the inventory, multiply
    # together the transition probabilities that the automata assign to that
    # segment in their current states, advancing each automaton's state along
    # the observed word.
    states = [0] * len(automata)
    numerators = {}

    for pos in range(len(word)):
        for seg in ix_inventory:
            numerators[pos, seg] = torch.tensor(1.)
    # Start every (position, segment) numerator at 1 before multiplying in
    # the transition probabilities from each automaton below.

    for somePos, someSeg in enumerate(word):
        for i, M in enumerate(automata):
            for seg in ix_inventory:
                numerators[somePos, seg] = numerators[somePos, seg] * M.tProb[(states[i], seg)]
            states[i] = M.D[(states[i], someSeg.item())]
    return numerators
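
Just to illustrate what I mean by "vectorizing" the inner loop, here is a rough sketch (coemitNumerators_rows is my own name, and it assumes ix_inventory is an ordered list so the columns have a fixed order); it returns a (seq_len, num_segments) tensor instead of a dict:

import torch

def coemitNumerators_rows(word, ix_inventory, automata):
    # Same quantity as coemitNumerators() above, but each automaton
    # contributes one stacked row of transition probabilities per position;
    # columns follow the order of ix_inventory.
    states = [0] * len(automata)
    rows = []
    for someSeg in word:
        row = torch.ones(len(ix_inventory))
        for i, M in enumerate(automata):
            row = row * torch.stack([M.tProb[(states[i], seg)] for seg in ix_inventory])
            states[i] = M.D[(states[i], someSeg.item())]
        rows.append(row)
    return torch.stack(rows)   # shape: (seq_len, num_segments)

This still builds each row with a Python comprehension, so it only replaces a pile of scalar multiplications with one vector multiply per automaton; pre-stacking each automaton's tProb into a (num_states, num_segments) table once per batch would presumably help more.
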
def NLL(wlist, automata, inventory):
    # Negative log-likelihood of a batch of words: at each position, the
    # probability of the observed segment is its numerator divided by the sum
    # of the numerators of all segments at that position.
    nll = 0
    for w in wlist:
        numerators = coemitNumerators(w, inventory, automata)
        for realpos, realseg in enumerate(w):
            denom = sum(numerators[realpos, seg] for seg in inventory)
            nll = nll - torch.log(numerators[realpos, realseg.item()] / denom)
    return nll
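
And here is roughly what I had in mind for the NLLLoss part: once the numerators for a word are stacked into a (seq_len, num_segments) matrix, the per-position division and log can be handed to torch.nn.functional.nll_loss. A minimal sketch, assuming the inventory is a plain Python list of segment indices so that inventory.index(...) recovers the column of the observed segment (NLL_builtin and that assumption are mine, not part of the original code; if the inventory is just range(num_segments), the targets are simply the word itself):

import torch
import torch.nn.functional as F

def NLL_builtin(wlist, automata, inventory):
    # Intended to compute the same value as NLL() above, with the
    # per-position normalization and log handled by the built-in loss.
    nll = 0.
    for w in wlist:
        numerators = coemitNumerators(w, inventory, automata)
        # stack into (seq_len, num_segments), columns in inventory order
        mat = torch.stack([
            torch.stack([numerators[pos, seg] for seg in inventory])
            for pos in range(len(w))
        ])
        log_probs = torch.log(mat / mat.sum(dim=1, keepdim=True))
        targets = torch.tensor([inventory.index(seg.item()) for seg in w])
        nll = nll + F.nll_loss(log_probs, targets, reduction='sum')
    return nll

I’m not convinced this buys much on its own, since coemitNumerators() still does all the Python-level looping, but it at least shows where the built-in loss would slot in. Is this the right way to think about it, or is there a cleaner mapping onto NLLLoss?
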
def train_lm(dataset, dev, params, automata, ix_inventory, phone2ix, ix2phone, prev_perplexity=1e10):
    # Collect every transition probability (each is a scalar tensor) as the
    # parameters for the optimizer to update.
    variables = []
    for automaton in automata:
        variables += automaton.tProb.values()
    optimizer = torch.optim.Adam(variables, lr=params['learning_rate'])
    num_examples, seq_len = dataset.size()
    batches = [
        (start, start + params['batch_size'])
        for start in range(0, num_examples, params['batch_size'])
    ]

    for epoch in range(params['epochs']):
        ep_loss = 0.
        start_time = time.time()
        random.shuffle(batches)
        for b_idx, (start, end) in enumerate(batches):
            loss = NLL(dataset[start:end], automata, ix_inventory)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            # project the transition probabilities back into [0, 1] after the update
            with torch.no_grad():
                for v in variables:
                    v.clamp_(0, 1)
            ep_loss += loss.detach()