Pytorch equivalent of this tensorflow focal loss

This implementation of focal loss in tensorflow is giving good results with my dataset.

from keras import backend as K

def focal_loss(gamma=2., alpha=.25):
	def focal_loss_fixed(y_true, y_pred):
		pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
		pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
		return -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1+K.epsilon())) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
	return focal_loss_fixed

I want to test in pytorch, so I created custom loss function in pytorch as below, but seems getting error.

def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25):
    pt_1 = torch.where(y_true == 1, y_pred, torch.ones_like(y_pred))
    pt_0 = torch.where(y_true == 0, y_pred, torch.zeros_like(y_pred))
    epsilon = 1e-7

    loss = -torch.mean(alpha * (1.0 - pt_1).pow(gamma) * torch.log(pt_1 + epsilon))
    loss -= torch.mean((1 - alpha) * pt_0.pow(gamma) * torch.log(1.0 - pt_0 + epsilon))

    return loss

clf = ts_learner(dls, loss_func=focal_loss, arch="InceptionTimePlus", metrics=[precision, recall], cbs=[save_best_model])

can somebody please let me know how to use this focal loss in above code that uses pytorch.