Help writing a differentiable custom loss function with conditional logic

I’m trying to write a custom loss function with some specific logic, and I’m having a hard time finding a reformulation or approximation of that logic that is differentiable. I would love any help or suggestions on how I might proceed.

The loss function calculates how much money is gained or lost from betting on games. The logic I’ve implemented (but can’t figure out how to reformulate or approximate in a differentiable way) is as follows. For each game, the model chooses one of three actions: [no_bet, bet_on_home, bet_on_away]. Choosing no_bet adds zero to the loss. Betting on the team that loses adds +1 to the loss. Betting on the team that wins subtracts (odds - 1) for that team from the loss.
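Written as a formula, the rule above is the following. The symbols are introduced here only for illustration: z_i are the three output scores for game i, a_i is the chosen action (0 = no_bet, 1 = bet_on_home, 2 = bet_on_away), y_i is the winner (0 = home, 1 = away), and o_{i,0}, o_{i,1} are the home and away odds.

```latex
a_i = \arg\max_{k \in \{0,1,2\}} z_{i,k}, \qquad
\ell_i(a) =
\begin{cases}
0 & \text{if } a = 0 \ (\text{no\_bet}),\\
+1 & \text{if } a \neq 0 \text{ and } a - 1 \neq y_i \ (\text{lost bet}),\\
-(o_{i,y_i} - 1) & \text{if } a \neq 0 \text{ and } a - 1 = y_i \ (\text{won bet}),
\end{cases}
\qquad
L = \sum_i \ell_i(a_i).
```

Because L depends on the scores z only through the argmax, it is piecewise constant in z, so its gradient is zero wherever it is defined.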

For example, given a model output, a target y (where 0 is a home win and 1 is an away win), and a tensor of odds (column 0 holds the odds on a home win and column 1 the odds on an away win):

output: tensor([[ 0.2370,  0.2679,  0.3858],
        [ 0.0401,  0.5379, -0.1788],
        [-0.0537,  0.4639, -0.6986],
        [-0.0818,  0.6218, -0.5953],
        [ 0.1292,  0.1033, -0.1198],
        [-0.2797, -0.4663, -0.4223],
        [ 0.4070,  0.8124, -0.3414],
        [-0.5048,  1.0760,  0.3167]], device='cuda:0',
       grad_fn=<AddmmBackward0>)
y: tensor([1, 0, 1, 0, 0, 0, 0, 1], device='cuda:0')
odds: tensor([[2.3000, 1.6667],
        [2.3000, 1.6667],
        [1.9091, 1.9091],
        [1.1429, 6.0000],
        [2.2000, 1.7143],
        [1.8333, 2.0000],
        [1.0905, 8.0500],
        [5.0000, 1.2000]])
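Applying the rules by hand to this batch shows where the final number comes from. This is a plain-Python sketch; the values are copied from the printout above, so they are rounded to 4 decimals, and the helper name game_loss is hypothetical.

```python
# Hand check of the hard (argmax) betting loss on the example batch, in plain Python.
# Values are copied from the printed tensors, so they are rounded to 4 decimals.
output = [
    [0.2370, 0.2679, 0.3858],
    [0.0401, 0.5379, -0.1788],
    [-0.0537, 0.4639, -0.6986],
    [-0.0818, 0.6218, -0.5953],
    [0.1292, 0.1033, -0.1198],
    [-0.2797, -0.4663, -0.4223],
    [0.4070, 0.8124, -0.3414],
    [-0.5048, 1.0760, 0.3167],
]
y = [1, 0, 1, 0, 0, 0, 0, 1]
odds = [
    [2.3000, 1.6667],
    [2.3000, 1.6667],
    [1.9091, 1.9091],
    [1.1429, 6.0000],
    [2.2000, 1.7143],
    [1.8333, 2.0000],
    [1.0905, 8.0500],
    [5.0000, 1.2000],
]

def game_loss(action, winner, game_odds):
    """0 for no_bet, +1 for a losing bet, -(odds - 1) for a winning bet."""
    if action == 0:                     # no_bet
        return 0.0
    if action - 1 != winner:            # bet on the team that lost
        return 1.0
    return -(game_odds[winner] - 1.0)   # bet on the team that won

per_game = []
for row, winner, game_odds in zip(output, y, odds):
    action = max(range(3), key=lambda k: row[k])  # argmax over [no_bet, home, away]
    per_game.append(game_loss(action, winner, game_odds))

total = sum(per_game)
print([round(v, 4) for v in per_game])  # [-0.6667, -1.3, 1.0, -0.1429, 0.0, 0.0, -0.0905, 1.0]
print(round(total, 4))                  # -0.2001
```

Games 0, 1, 3 and 6 are winning bets, games 2 and 7 are losing bets, and games 4 and 5 are no_bet, so the total is 2 - (0.6667 + 1.3 + 0.1429 + 0.0905), about -0.2001. The -0.2000 printed by the original code presumably reflects odds stored with more decimals than the printout shows.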

The following two loss functions (vectorized and loop form) implement the logic described above. Neither is differentiable, because of the argmax (which returns integer indices with no gradient) and the boolean logic built on top of it.

import torch

def betting_loss(output, y, odds):
  # pick the highest-scoring action: 0 = no_bet, 1 = bet_on_home, 2 = bet_on_away
  # (argmax of the softmax probabilities equals argmax of the raw logits)
  whichbets = torch.argmax(output, dim = 1)
  odds = odds.to(output.device)

  # keep only the games where a bet was placed (not no_bet)
  betgames = whichbets != 0
  # convert whichbets to the [0, 1] scale of y by subtracting 1 (2 -> 1, and 1 -> 0)
  whichbets = whichbets[betgames] - 1
  y = y[betgames]
  odds = odds[betgames]
  # bets that were won
  wins = whichbets == y

  # each lost bet adds 1; each won bet subtracts (odds - 1) for the winning team
  losses_loss = torch.sum(~wins)
  wins_loss = torch.sum(odds[wins, y[wins]] - 1)

  return losses_loss - wins_loss

The same logic in loop form, for easier reading:

def betting_loss_loop(output, y, odds):
  # pick the highest-scoring action (argmax of softmax == argmax of logits)
  whichbets = torch.argmax(output, dim = 1)
  # convert whichbets to the [0, 1] scale of y by subtracting 1 (2 -> 1, 1 -> 0, and no_bet 0 -> -1)
  whichbets = whichbets - 1

  loss = 0
  for ii, whichbet in enumerate(whichbets):
    if whichbet != -1:                 # a bet was placed
      if whichbet != y[ii]:
        loss += 1                      # lost bet
      else:
        loss -= odds[ii, y[ii]] - 1    # won bet: subtract (odds - 1)

  return loss

print(betting_loss(output, y, odds))
print(betting_loss_loop(output, y, odds))

This gives:

tensor(-0.2000, device='cuda:0')
tensor(-0.2000)
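Since the only non-differentiable step is the hard argmax choice, one common workaround (a sketch, not tested on real betting data) is to store the loss of each action per game as an (N, 3) table and weight it by the softmax probabilities p[i, k] instead of the one-hot argmax: the expected loss, the sum over games i and actions k of p[i, k] * loss[i, k], is differentiable in output. The same sketch also shows a straight-through estimator, a swapped-in technique whose forward value equals betting_loss but whose backward pass borrows the softmax gradient; that gradient is a heuristic, not the true gradient of the hard loss (which is zero almost everywhere). The names payoff_table, expected_betting_loss, straight_through_betting_loss and the temperature parameter are hypothetical, and the code assumes output and y are on the same device.

```python
import torch
import torch.nn.functional as F

def payoff_table(odds, y):
    """Loss of each action per game, shape (N, 3): [no_bet, bet_on_home, bet_on_away].

    no_bet -> 0, a losing bet -> +1, a winning bet -> -(odds - 1).
    """
    won = F.one_hot(y, num_classes=2).bool()        # won[i, t] is True if team t won game i
    bets = torch.where(won, 1 - odds, torch.ones_like(odds))
    no_bet = torch.zeros_like(odds[:, :1])
    return torch.cat([no_bet, bets], dim=1)

def expected_betting_loss(output, y, odds, temperature=1.0):
    """Softmax-weighted (expected) loss; differentiable with respect to output."""
    odds = odds.to(output.device)
    probs = F.softmax(output / temperature, dim=1)
    return (probs * payoff_table(odds, y)).sum()

def straight_through_betting_loss(output, y, odds):
    """Hard argmax loss in the forward pass, softmax gradient in the backward pass."""
    odds = odds.to(output.device)
    probs = F.softmax(output, dim=1)
    hard = F.one_hot(probs.argmax(dim=1), num_classes=3).to(probs.dtype)
    weights = hard + probs - probs.detach()   # numerically equal to hard, but carries gradient
    return (weights * payoff_table(odds, y)).sum()

# Usage with random data (hypothetical: 8 games, 3 actions, odds between 1 and 5)
output = torch.randn(8, 3, requires_grad=True)
y = torch.randint(0, 2, (8,))
odds = 1 + 4 * torch.rand(8, 2)
loss = expected_betting_loss(output, y, odds)
loss.backward()
print(output.grad.shape)   # torch.Size([8, 3])
```

Lowering the temperature pushes the expected loss toward the hard argmax loss: on the example batch above, temperature=1e-3 gives about -0.2, matching betting_loss. Which of the two variants trains better on this problem is not something this sketch establishes.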

Any help would be greatly appreciated. Thank you.