Triplet loss become zero after first batch

i am using siamese network with triplet loss using resnet50,
when i train the network after second iteration loss becoming ‘0’.


res_model = resnet50(pretrained=True)
# Unfreeze model weights
for param in res_model.parameters():
    param.requires_grad = True

class SiameseNetwork(nn.Module):
    def __init__(self):
        super(SiameseNetwork, self).__init__()
        self.res = nn.Sequential(OrderedDict([
            ('res', res_model)
        self.fc_layers = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(1000, 1000, bias=False)),
            ('fbn', nn.BatchNorm1d(1000)),
            ('frelu', nn.ReLU()),
            ('fc2', nn.Linear(1000, 512)),
    def forward(self, inputs):
        x = self.res(inputs)
        x = x.view(x.shape[0], 1000)
        x = self.fc_layers(x)
        return x

class siaNet(nn.Module):
    def __init__(self):
        super(siaNet, self).__init__()
        self.anchor = nn.Sequential(OrderedDict([
            ('anchor', SiameseNetwork())
        self.positive = nn.Sequential(OrderedDict([
            ('positive', SiameseNetwork())
        self.negative = nn.Sequential(OrderedDict([
            ('negative', SiameseNetwork())

    def forward(self,input1,input2,input3):
        output1 = self.anchor(input1)
        output2 = self.positive(input2)
        output3 = self.negative(input3)
        return output1,output2,output3

model = siaNet()

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin
    def calc_euclidean(self, x1, x2):
        return (x1 - x2).pow(2).sum(1)
    def forward(self, anchor: torch.Tensor, positive: torch.Tensor, negative: torch.Tensor) -> torch.Tensor:
        distance_positive = self.calc_euclidean(anchor, positive)
        distance_negative = self.calc_euclidean(anchor, negative)
        losses = torch.relu(distance_positive - distance_negative + self.margin)

        return losses.mean()

anc, pos, neg = train_x  
batch_size = 8
for epoch in range(25):
    for i in range(0, len(anc), batch_size):
        anchor_img = torch.from_numpy(anc[i:i+2]).to(device).permute(0,3,1,2).float()
        positive_img = torch.from_numpy(pos[i:i+2]).to(device).permute(0,3,1,2).float()
        negative_img = torch.from_numpy(neg[i:i+2]).to(device).permute(0,3,1,2).float()
        anchor_out,positive_out,negative_out  = model(anchor_img,positive_img,negative_img)
        loss = criterion(anchor_out, positive_out, negative_out)
        print("Epoch number {}\n Current loss {}\n".format(epoch,loss.item()))