Issue (Model not getting trained) during Backpropagation in an Adaptive Neuro-Fuzzy Inference System (ANFIS)

```
import numpy as np
import torch
import skfuzzy as fuzz
from torch import nn
import torch.optim as optim
from sklearn.metrics import mean_squared_error as mse

def generate_random_params_antecedent():
    return [np.random.rand(), np.random.rand(), np.random.rand()]

def generate_random_params_consequent():
    while True:
        a, b, c, d = np.random.randint(1, 10, 4)
        if a <= b <= c <= d:
            return [a, b, c, d]

class manfis(nn.Module):
    def __init__(self, membership_function="general_bell"):
        super(manfis, self).__init__()
        self.membership_function = membership_function

        if membership_function == "general_bell":
            self.premise_params = nn.ParameterList([
                nn.Parameter(torch.tensor(generate_random_params_antecedent(), dtype=torch.float)),
                nn.Parameter(torch.tensor(generate_random_params_antecedent(), dtype=torch.float)),
                nn.Parameter(torch.tensor(generate_random_params_antecedent(), dtype=torch.float)),
                nn.Parameter(torch.tensor(generate_random_params_antecedent(), dtype=torch.float)),
            ])

            self.consequent_params = nn.ParameterList([
                nn.Parameter(torch.tensor(generate_random_params_consequent(), dtype=torch.float)),
                nn.Parameter(torch.tensor(generate_random_params_consequent(), dtype=torch.float))
            ])

    def _fuzzification(self, x):
        mu_A1 = fuzz.gbellmf(x.detach().numpy(), c=self.premise_params[0][1].item(), b=self.premise_params[0][0].item(), a=self.premise_params[0][2].item())
        mu_A2 = fuzz.gbellmf(x.detach().numpy(), c=self.premise_params[1][1].item(), b=self.premise_params[1][0].item(), a=self.premise_params[1][2].item())
        mu_B1 = fuzz.gbellmf(x.detach().numpy(), c=self.premise_params[2][1].item(), b=self.premise_params[2][0].item(), a=self.premise_params[2][2].item())
        mu_B2 = fuzz.gbellmf(x.detach().numpy(), c=self.premise_params[3][1].item(), b=self.premise_params[3][0].item(), a=self.premise_params[3][2].item())
        return (mu_A1, mu_A2, mu_B1, mu_B2)

    def _inference(self, membership_degrees):
        rule_firing_strength_1 = membership_degrees[0] * membership_degrees[2]  # mu_A1 * mu_B1
        rule_firing_strength_2 = membership_degrees[1] * membership_degrees[3]  # mu_A2 * mu_B2
        weighted_rule_firing_strength_1 = rule_firing_strength_1 / (rule_firing_strength_1 + rule_firing_strength_2)
        weighted_rule_firing_strength_2 = rule_firing_strength_2 / (rule_firing_strength_1 + rule_firing_strength_2)
        return (weighted_rule_firing_strength_1, weighted_rule_firing_strength_2)

    def _implication(self, rule_firing_strengths):
        output_1 = rule_firing_strengths[0] * self.consequent_params[0][0].item()  # W1 * c1 (from A1)
        output_2 = rule_firing_strengths[1] * self.consequent_params[1][0].item()  # W2 * c2 (from A2)
        return (output_1, output_2)

    def _aggregation(self, consequent_activations):
        return consequent_activations[0] + consequent_activations[1]  # w1c1 + w2c2

    def _defuzzification(self, x, aggregated_output):
        centroid_1 = fuzz.defuzzify.centroid(x.detach().numpy(), fuzz.trapmf(x.detach().numpy(), [self.consequent_params[0][0].item(), self.consequent_params[0][1].item(), self.consequent_params[0][2].item(), self.consequent_params[0][3].item()]))
        area_1 = np.trapz(fuzz.trapmf(x.detach().numpy(), [self.consequent_params[0][0].item(), self.consequent_params[0][1].item(), self.consequent_params[0][2].item(), self.consequent_params[0][3].item()]), x.detach().numpy())
        centroid_2 = fuzz.defuzzify.centroid(x.detach().numpy(), fuzz.trapmf(x.detach().numpy(), [self.consequent_params[1][0].item(), self.consequent_params[1][1].item(), self.consequent_params[1][2].item(), self.consequent_params[1][3].item()]))
        area_2 = np.trapz(fuzz.trapmf(x.detach().numpy(), [self.consequent_params[1][0].item(), self.consequent_params[1][1].item(), self.consequent_params[1][2].item(), self.consequent_params[1][3].item()]), x.detach().numpy())
        centroid = ((centroid_1 * area_1) + (centroid_2 * area_2)) / (area_1 + area_2)  # D
        output = torch.tensor(aggregated_output * centroid, dtype=torch.float)  # O_4 * defuzzified result
        return output

    def forward(self, x):
        membership_degrees = self._fuzzification(x)
        rule_firing_strengths = self._inference(membership_degrees)
        consequent_activations = self._implication(rule_firing_strengths)
        aggregated_output = self._aggregation(consequent_activations)
        final_output = self._defuzzification(x, aggregated_output)
        return final_output

class modeltrain():
    def __init__(self, xMapping, yMapping):
        self.xMapping = xMapping
        self.yMapping = yMapping

    def _fit(self, epochs):
        model = manfis()
        criteria = nn.MSELoss()
        optimizer = optim.Adam(model.parameters(), lr=0.001)

        for e in range(epochs):
            optimizer.zero_grad()
            out = model(self.xMapping)
            loss = criteria(out, self.yMapping)
            loss.backward()
            optimizer.step()

            with torch.no_grad():
                pred = model(self.xMapping)
                pred = pred.cpu().detach().numpy()
                real = self.yMapping.cpu().numpy().reshape(-1, 1)
                test_rmse = mse(pred, real) ** 0.5

            print('[Epoch: {}/{}] [Train Loss: {}] [Test RMSE: {}]'.format(
                e + 1, epochs, str(loss.item())[:6], str(test_rmse)[:6]))
        return model
```

I have debugged the code multiple times, but I cannot figure out why the model's parameters are not updating. It would be really helpful if someone could assist me. Is there an error in the backpropagation code? @ptrblck_de
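
A minimal way to check whether gradients can reach the parameters at all, using the classes above (the input range here is illustrative, not from the original script):

```
import torch

model = manfis()
x = torch.linspace(0.0, 10.0, 50)  # illustrative input range

out = model(x)
# If the forward pass had kept the autograd graph alive, out.grad_fn would be
# set; instead it is None, so backward() can never update the parameters.
print(out.requires_grad, out.grad_fn)  # False None
```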

Your code is a bit hard to read as it’s not properly formatted (you can post code snippets by wrapping them into three backticks ```).

However, it seems you are detaching the computation graph several times:

```
...
centroid_1 = fuzz.defuzzify.centroid(x.detach().numpy(), fuzz.trapmf(x.detach().numpy(), [self.consequent_params[0][0].item(), self.consequent_params[0][1].item(), self.consequent_params[0][2].item(), self.consequent_params[0][3].item()]))
...
```
  • Calling `x.detach()` will obviously detach the tensor `x` from the computation graph, and the parameters previously used to compute `x` will not get any valid gradients anymore.

  • Besides that, operations performed via a 3rd-party library such as numpy will also not be tracked. If you need to use numpy, you would need to create a custom `autograd.Function` as described here.

  • Calling `.item()` will return a plain Python scalar and will also detach the tensor from the computation graph.

Further down:

```
output = torch.tensor(aggregated_output * centroid, dtype=torch.float)
```

  • you are creating a new tensor, which will also detach the wrapped tensor, as a new leaf tensor is created.
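
All three failure modes can be reproduced in a minimal standalone sketch (the names here are illustrative, not taken from the model):

```
import torch

w = torch.tensor([2.0], requires_grad=True)  # stands in for a trainable parameter
x = torch.tensor([3.0])

y = (w * x).detach()          # 1) .detach() cuts the graph (also required before .numpy())
s = (w * x).item()            # 2) .item() returns a plain float with no grad history
z = torch.tensor([s * 2.0])   # 3) torch.tensor() creates a brand-new leaf tensor

print(y.grad_fn, z.grad_fn)   # None None -> no path back to w
print(z.requires_grad)        # False: calling z.backward() would raise an error
```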

To resolve this issue, I reimplemented the methods inside the class and ensured that all of them take and return tensors. Also, following your suggestion, I removed the `.item()` calls so the gradients are tracked. But now, after the 1st iteration, I am getting `Function 'PowBackward1' returned nan values in its 1th output.` and sometimes `Function 'MulBackward0' returned nan values in its 0th output.` I am trying to perform a regression task, and my trainable parameters are the premise and consequent parameters.

Here is the code:

```
import numpy as np  # needed for the random parameter generators below
import torch
from torch import nn
import torch.optim as optim
from sklearn.metrics import mean_squared_error as mse

def generate_random_params_antecedent():
  bi = np.random.randint(1, 10, 1)[0]
  ci = np.random.randint(1, 10, 1)[0]
  di = np.random.randint(1, 10, 1)[0]
  return [bi, ci, di]

def generate_random_params_consequent():
  while True:
    a, b, c, d = np.random.randint(1, 10, 4)
    if a <= b <= c <= d:
      return [a, b, c, d]

class MamdaniANFIS(nn.Module):

  def __init__(self, membership_function="general_bell"):
        super(MamdaniANFIS, self).__init__()
        self.membership_function = membership_function

        if membership_function == "general_bell":
            antecedent_params = [generate_random_params_antecedent() for _ in range(4)]
            self.premise_params = nn.ParameterList([nn.Parameter(torch.tensor(params, dtype=torch.float)) for params in antecedent_params])

            consequent_params = [generate_random_params_consequent() for _ in range(2)]
            self.consequent_params = nn.ParameterList([nn.Parameter(torch.tensor(params, dtype=torch.float)) for params in consequent_params])

  def bell_shaped_function(self, x, ci, di, bi):
    epsilon = 1e-8
    return 1 / (1 + ((x - ci + epsilon) / (di + epsilon))**(2 * bi))


  def centroid(self, x, mfx):
        if x.numel() == 1:
            return x[0] * mfx[0] / torch.max(mfx[0], torch.tensor(torch.finfo(torch.float32).eps, device=x.device, dtype=x.dtype))

        x = x.float()
        mfx = torch.abs(mfx)
        sum_moment_area = torch.tensor(0.0, device=x.device, dtype=x.dtype)
        sum_area = torch.tensor(0.0, device=x.device, dtype=x.dtype)

        for i in range(1, x.numel()):
            x1 = x[i - 1]
            x2 = x[i]
            y1 = mfx[i - 1]
            y2 = mfx[i]

            if not (torch.allclose(torch.tensor([y1, y2], device=x.device, dtype=x.dtype), torch.tensor([0.0], device=x.device, dtype=x.dtype)) or torch.equal(x1, x2)):
                if torch.equal(y1, y2):
                    moment = 0.5 * (x1 + x2)
                    area = (x2 - x1) * y1
                elif torch.equal(y1, torch.tensor(0.0, device=x.device, dtype=x.dtype)) and not torch.equal(y2, torch.tensor(0.0, device=x.device, dtype=x.dtype)):  # triangle, height y2
                    moment = 2.0 / 3.0 * (x2 - x1) + x1
                    area = 0.5 * (x2 - x1) * y2
                elif torch.equal(y2, torch.tensor(0.0, device=x.device, dtype=x.dtype)) and not torch.equal(y1, torch.tensor(0.0, device=x.device, dtype=x.dtype)):  # triangle, height y1
                    moment = 1.0 / 3.0 * (x2 - x1) + x1
                    area = 0.5 * (x2 - x1) * y1
                else:
                    moment = (2.0 / 3.0 * (x2 - x1) * (y2 + 0.5 * y1)) / (y1 + y2) + x1
                    area = 0.5 * (x2 - x1) * (y1 + y2)

                sum_moment_area += moment * area
                sum_area += area

        return sum_moment_area / torch.max(sum_area, torch.tensor(torch.finfo(torch.float32).eps, device=x.device, dtype=x.dtype))


  def trapmf(self, x, abcd):
    """
    Trapezoidal membership function generator.

    Parameters
    ----------
    x : 1d tensor
        Independent variable.
    abcd : 1d tensor, length 4
        Four-element vector.  Ensure a <= b <= c <= d.

    Returns
    -------
    y : 1d tensor
        Trapezoidal membership function.
    """
    assert len(abcd) == 4, 'abcd parameter must have exactly four elements.'
    a, b, c, d = abcd
    assert a <= b and b <= c and c <= d, 'abcd requires the four elements a <= b <= c <= d.'
    y = torch.ones_like(x)

    idx = torch.nonzero(x <= b, as_tuple=True)[0]
    y[idx] = trimf(x[idx], torch.tensor([a, b, b]))

    idx = torch.nonzero(x >= c, as_tuple=True)[0]
    y[idx] = trimf(x[idx], torch.tensor([c, c, d]))

    idx = torch.nonzero(x < a, as_tuple=True)[0]
    y[idx] = torch.zeros_like(y[idx])

    idx = torch.nonzero(x > d, as_tuple=True)[0]
    y[idx] = torch.zeros_like(y[idx])

    return y


  def trimf(x, abc):
      """
      Triangular membership function generator.

      Parameters
      ----------
      x : 1d tensor
          Independent variable.
      abc : 1d tensor, length 3
          Three-element vector controlling the shape of the triangular function.
          Requires a <= b <= c.

      Returns
      -------
      y : 1d tensor
          Triangular membership function.
      """
      assert len(abc) == 3, 'abc parameter must have exactly three elements.'
      a, b, c = abc

      assert a <= b and b <= c, 'abc requires the three elements a <= b <= c.'
      y = torch.zeros_like(x)

      # Left side
      if a != b:
          idx = torch.nonzero(torch.logical_and(a < x, x < b), as_tuple=True)[0]
          y[idx] = (x[idx] - a) / (b - a)

      # Right side
      if b != c:
          idx = torch.nonzero(torch.logical_and(b < x, x < c), as_tuple=True)[0]
          y[idx] = (c - x[idx]) / (c - b)

      idx = torch.nonzero(x == b, as_tuple=True)[0]
      y[idx] = torch.ones_like(y[idx])

      return y

  def fuzzification(self, x):
        mu_A1 = self.bell_shaped_function(x, ci = self.premise_params[0][1], bi = self.premise_params[0][0], di = self.premise_params[0][2])
        mu_A2 = self.bell_shaped_function(x, ci = self.premise_params[1][1], bi = self.premise_params[1][0], di = self.premise_params[1][2])
        mu_B1 = self.bell_shaped_function(x, ci = self.premise_params[2][1], bi = self.premise_params[2][0], di = self.premise_params[2][2])
        mu_B2 = self.bell_shaped_function(x, ci = self.premise_params[3][1], bi = self.premise_params[3][0], di = self.premise_params[3][2])
        print("mua1", mu_A1)
        print("mua2", mu_A2)
        print("mub1", mu_B1)
        print("mub2", mu_B2)
        return (mu_A1, mu_A2, mu_B1, mu_B2)


  def inference(self, membership_degrees):
        rule_firing_strength_1 = membership_degrees[0] * membership_degrees[2] # mu_A1 * mu_B1
        rule_firing_strength_2 = membership_degrees[1] * membership_degrees[3] # mu_A2 * mu_B2
        weighted_rule_firing_strength_1 = rule_firing_strength_1 / (rule_firing_strength_1 + rule_firing_strength_2)
        weighted_rule_firing_strength_2 = rule_firing_strength_2 / (rule_firing_strength_1 + rule_firing_strength_2)
        print("wrls 1", weighted_rule_firing_strength_1)
        print("wrls 2", weighted_rule_firing_strength_2)
        return (weighted_rule_firing_strength_1, weighted_rule_firing_strength_2)

  def implication(self, rule_firing_strengths):
        output_1 = rule_firing_strengths[0] * self.premise_params[0][1] # W1 * c1(from A1)
        output_2 = rule_firing_strengths[1] * self.premise_params[1][1] # W2 * c2(from A2)
        print("output_1", output_1)
        print("output_2", output_2)
        return (output_1, output_2)

  def aggregation(self, consequent_activations):
        print("agg output ", consequent_activations[0] + consequent_activations[1])
        return consequent_activations[0] + consequent_activations[1] # w1c1 + w2c2

  def defuzzification(self, x, aggregated_output):
        centroid_1 = self.centroid(x, self.trapmf(x, self.consequent_params[0]))
        print("centroid_1 ", centroid_1)
        area_1 = torch.trapz(x,  self.trapmf(x, self.consequent_params[0]))
        print("area_1", area_1)
        centroid_2 = self.centroid(x, self.trapmf(x, self.consequent_params[1]))
        print("centroid_2", centroid_2)
        area_2 = torch.trapz(x,  self.trapmf(x, self.consequent_params[1]))
        print("area_2", area_2)
        centroid = ((centroid_1 * area_1) + (centroid_2 * area_2)) / (area_1 + area_2) # D
        output = aggregated_output * centroid # O_4 * Defuzzified_result
        print("output", output)
        return output

  def forward(self, x):
        membership_degrees = self.fuzzification(x)
        rule_firing_strengths = self.inference(membership_degrees)
        consequent_activations = self.implication(rule_firing_strengths)
        aggregated_output = self.aggregation(consequent_activations)
        final_output = self.defuzzification(x, aggregated_output)
        return final_output

class ModelTrainer():
    def __init__(self, x_mapping, y_mapping):
        self.x_mapping = x_mapping.clone().detach()
        self.y_mapping = y_mapping.clone().detach().requires_grad_(True)

    def _fit(self, epochs):
        m = MamdaniANFIS()
        crit = nn.MSELoss()
        para = list(m.parameters())
        optimizer = optim.Adam(para, lr=0.01)
        torch.autograd.set_detect_anomaly(True)

        for e in range(epochs):
            optimizer.zero_grad()
            out = m.forward(self.x_mapping)
            loss = crit(out, self.y_mapping)
            loss.backward()
            # m.float()
            optimizer.step()

            with torch.no_grad():
                test_rmse = torch.sqrt(crit(out, self.y_mapping))

            print('[Epoch: {}/{}] [Train Loss: {}] [Test RMSE: {}]'.format(
                    e+1, epochs, str(loss.item())[:6], str(test_rmse.item())[:6]))

        return m

x = torch.Tensor([1, 2, 3, 4])
y = torch.tensor(list(range(5, 9)), dtype=torch.float32)
model = ModelTrainer(x, y)
model._fit(20)
```

Your code is not executable and fails with:

```
NameError: name 'trimf' is not defined
```

Using `self.trimf` fails with:

```
TypeError: MamdaniANFIS.trimf() takes 2 positional arguments but 3 were given
```

The only `pow` usage I'm seeing is in:

```
return 1 / (1 + ((x - ci + epsilon) / (di + epsilon))**(2 * bi))
```

so you might want to double-check whether this operation is creating the invalid gradients.
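
For context, this is exactly how the reported error can arise: when the base `(x - ci + epsilon) / (di + epsilon)` is zero or negative and the exponent `2 * bi` is a learnable non-integer, the `pow` gradient involves `log(base)` and becomes `nan`. A minimal repro with illustrative values:

```
import torch

torch.autograd.set_detect_anomaly(True)

base = torch.tensor([0.0])                    # (x - ci + eps)/(di + eps) near zero
bi = torch.tensor([1.3], requires_grad=True)  # learnable exponent, 2*bi non-integer

mu = 1 / (1 + base ** (2 * bi))  # forward is finite: 0 ** 2.6 == 0
mu.sum().backward()
# raises: Function 'PowBackward1' returned nan values in its 1th output.
# (the gradient w.r.t. the exponent involves base**exp * log(base), and
#  log(0) * 0 -> nan; a negative base already yields nan in the forward pass)
```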

I changed the `bell_shaped_function` code as follows.
Further, I added the `self` parameter to the `trimf` function. The code now works fine. Thank you very much for your help.

```
  def bell_shaped_function(self, x, ci, di, bi):
    epsilon = 1e-8
    numerator = torch.abs(x - ci) + epsilon
    denominator = torch.abs(di) + epsilon
    return 1 / (1 + (numerator / denominator)**(2 * bi))
```
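
As a sanity check, a minimal standalone version of the fixed function (with illustrative parameter values) confirms that the `abs()` keeps the base of the power positive, so the gradients stay finite even where `x - ci` is negative:

```
import torch

def bell_shaped_function(x, ci, di, bi):
    # fixed version: abs() keeps the base of the power strictly positive
    epsilon = 1e-8
    numerator = torch.abs(x - ci) + epsilon
    denominator = torch.abs(di) + epsilon
    return 1 / (1 + (numerator / denominator) ** (2 * bi))

x = torch.linspace(-5.0, 5.0, 11)
ci = torch.tensor(1.0, requires_grad=True)
di = torch.tensor(2.0, requires_grad=True)
bi = torch.tensor(1.3, requires_grad=True)

mu = bell_shaped_function(x, ci, di, bi)
mu.sum().backward()
print(torch.isnan(bi.grad).any())  # tensor(False): no more PowBackward nans
```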