CNN with Custom Convolutions, loss NAN

I created a custom convolution Product Units (PU, attached image) to be specific. I am concatenating both the PU and CNN but my loss nan. I went through my PU code several time and I still cant figure out what I’m wrong.

Blockquote
class ProductConv2d(nn.Module):
def init(self, in_channels, out_channels, kernel_size, stride=1):
super(ProductConv2d, self).init()
self.weights = nn.Parameter(torch.Tensor(out_channels, in_channels, kernel_size, kernel_size))
self.bias = nn.Parameter(torch.Tensor(out_channels))
self.stride = stride
nn.init.kaiming_uniform_(self.weights, a=math.sqrt(5))
nn.init.constant_(self.bias, -1) # Initialize bias to -1

def forward(self, x):
    unfolded = F.unfold(x, kernel_size=self.weights.shape[2:], stride=self.stride)
    unfolded = unfolded.view(x.size(0), -1, unfolded.size(-1))
    log_unfolded = torch.log(torch.clamp(torch.abs(unfolded), min=1e-10))  # Use clamp to avoid log(0)
    log_weights = torch.log(torch.clamp(torch.abs(self.weights), min=1e-10)).view(self.weights.size(0), -1)  # Use clamp to avoid log(0)
    log_product = torch.matmul(log_weights, log_unfolded)
    return torch.exp(log_product).view(x.size(0), self.weights.size(0), x.size(2) - self.weights.shape[2] + 1, x.size(3) - self.weights.shape[3] + 1)

Define the CombinedCNN class

class CombinedCNN(nn.Module):
def init(self):
super(CombinedCNN, self).init()
self.product_conv1 = ProductConv2d(1, 5, kernel_size=3) # 5 output channels, 1 input channel for MNIST
self.conv1 = nn.Conv2d(1, 5, 3, 1) # 5 output channels, 1 input channel for MNIST
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(10, 5, 3, 1) # 5 output channels, adjusted input channels to match concatenated feature maps
self.conv3 = nn.Conv2d(5, 5, 3, 1) # New conv3 layer with 10 output channels

    # Calculate the size of the input to the first fully connected layer
    self.fc_input_size = self._get_fc_input_size()

    self.fc1 = nn.Linear(self.fc_input_size, 128)  # Adjusted input size for MNIST
    self.fc2 = nn.Linear(128, 10)  # 10 classes for MNIST

    # Initialize hooks dictionary
    self.feature_maps = {}

    # Register hooks
    self.product_conv1.register_forward_hook(self.save_feature_maps('product_conv1'))
    self.conv1.register_forward_hook(self.save_feature_maps('conv1'))
    self.conv2.register_forward_hook(self.save_feature_maps('conv2'))
    # self.conv3.register_forward_hook(self.save_feature_maps('conv3'))

def _get_fc_input_size(self):
    # Create a dummy tensor with the same shape as the input
    dummy_input = torch.zeros(1, 1, 28, 28)
    # Pass the dummy tensor through the convolutional layers
    y = self.product_conv1(dummy_input)
    y = F.tanh(y)
    # y = self.pool(y)
    z = F.relu(self.conv1(dummy_input))
    z = self.pool(z)
    if y.size(2) != z.size(2) or y.size(3) != z.size(3):
        z = F.interpolate(z, size=(y.size(2), y.size(3)))
    combined_features = torch.cat((y, z), dim=1)
    combined_features = F.tanh(self.conv2(combined_features))
    # combined_features = self.pool(combined_features)
    # combined_features = F.relu(self.conv3(combined_features))
    combined_features = self.pool(combined_features)
    # Flatten the tensor and return the size
    return combined_features.view(1, -1).size(1)

def save_feature_maps(self, layer_name):
    def hook(module, input, output):
        self.feature_maps[layer_name] = output.detach()
    return hook


def forward(self, x):
    y = self.product_conv1(x)  # Product units layer
    y = F.tanh(y)
    # y = self.pool(y)


    z = F.relu(self.conv1(x))  # Standard CNN layer
    z = self.pool(z)


    if y.size(2) != z.size(2) or y.size(3) != z.size(3):
        z = F.interpolate(z, size=(y.size(2), y.size(3)))

    # Concatenate feature maps from conv1 and product_conv1
    combined_features = torch.cat((y, z), dim=1)
    combined_features = F.tanh(self.conv2(combined_features))
    combined_features = self.pool(combined_features)


    # combined_features = F.relu(self.conv3(combined_features))  # Apply conv3
    # combined_features = self.pool(combined_features)

    combined_features = combined_features.view(combined_features.size(0), -1)  # Flatten the tensor
    x = F.relu(self.fc1(combined_features))
    x = self.fc2(x)
    x= F.softmax(x, dim=1)
    return x

screenshot