Plotting hidden unit activations

Hi,

I'm working on a network for a project that distinguishes two intertwined rectangular spirals; the user can choose between a model with one hidden layer or one with two hidden layers. I've got both models working, but I'm having trouble plotting the hidden-unit activations at each layer, which is what the graph_hidden() function is meant to do. This is my code so far:

# rect.py

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

class Network(torch.nn.Module):
    """Feed-forward classifier with one or two tanh hidden layers.

    Maps 2-D input points to a single sigmoid output in (0, 1).

    Args:
        layer: number of hidden layers applied in ``forward``; must be 1 or 2.
        hid: number of units in each hidden layer.

    Raises:
        ValueError: if ``layer`` is not 1 or 2.
    """

    def __init__(self, layer, hid):
        super().__init__()
        if layer not in (1, 2):
            # Fail fast: any other value would make forward() silently
            # return None instead of a prediction.
            raise ValueError(f"layer must be 1 or 2, got {layer!r}")
        self.layer = layer

        # Hidden Layer 1: 2-D input -> hid tanh units
        self.hid1 = nn.Sequential(
            nn.Linear(2, hid),
            nn.Tanh()
        )
        # Hidden Layer 2: hid -> hid tanh units. Built even when layer == 1,
        # but forward() only applies it when layer == 2.
        self.hid2 = nn.Sequential(
            nn.Linear(hid, hid),
            nn.Tanh()
        )
        # Output Layer: hid -> single sigmoid unit
        self.out = nn.Sequential(
            nn.Linear(hid, 1),
            nn.Sigmoid()
        )

    def forward(self, input):
        """Return the sigmoid output for a batch of 2-D input points."""
        hidden = self.hid1(input)
        if self.layer == 2:
            hidden = self.hid2(hidden)
        return self.out(hidden)

def graph_hidden(net, layer, node): 
    xrange = torch.arange(start=-8,end=8.1,step=0.01,dtype=torch.float32)
    yrange = torch.arange(start=-8,end=8.1,step=0.01,dtype=torch.float32)
    xcoord = xrange.repeat(yrange.size()[0])
    ycoord = torch.repeat_interleave(yrange, xrange.size()[0], dim=0)
    grid = torch.cat((xcoord.unsqueeze(1),ycoord.unsqueeze(1)),1)

    with torch.no_grad():   # suppress updating of gradients
        net.eval()          # toggle batch norm, dropout
        if (layer == 1):
            output = net.hid1(grid)
        elif (layer == 2):
            output = net.hid2(grid)
        net.train()         # toggle batch norm, dropout back again

        pred = (output >= 0.5).float()

        # plot function computed by model
        plt.clf()
        plt.pcolormesh(xrange,yrange,pred.cpu().view(yrange.size()[0],xrange.size()[0]),
                       cmap='Wistia', shading='auto')
        

I'm really not sure what I'm doing here, so any pointers are greatly appreciated :slight_smile: