Plotting individual hidden nodes using Pyplot

Hi, I want to create a function graph_hidden(net, layer, node) that takes a neural network, a layer (either 1 or 2) and a node index (from 0 to the number of nodes in that layer minus 1), and plots that node's activation over a grid of 2-D input points.
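One common way to get at a single node's activation is to have the network's forward() save each hidden layer's output as attributes, so that after one forward pass over a batch, net.hid1[:, node] holds that node's activation for every input row. The sketch below assumes a small MLP with 2 inputs, two tanh hidden layers and a sigmoid output; the class TwoLayerNet, its layer names (in_hid, hid_hid, hid_out) and the helper hidden_activation are all hypothetical and may not match your network.

```python
import torch
import torch.nn as nn


class TwoLayerNet(nn.Module):
    """Hypothetical MLP: 2 inputs, two tanh hidden layers, sigmoid output.

    forward() records each hidden layer's activations so they can be
    inspected afterwards as self.hid1 and self.hid2.
    """

    def __init__(self, num_hid: int = 10):
        super().__init__()
        self.in_hid = nn.Linear(2, num_hid)         # input -> hidden layer 1
        self.hid_hid = nn.Linear(num_hid, num_hid)  # hidden 1 -> hidden 2
        self.hid_out = nn.Linear(num_hid, 1)        # hidden 2 -> output
        self.hid1 = None  # set by forward(): shape (batch, num_hid)
        self.hid2 = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        self.hid1 = torch.tanh(self.in_hid(x))
        self.hid2 = torch.tanh(self.hid_hid(self.hid1))
        return torch.sigmoid(self.hid_out(self.hid2))


def hidden_activation(net: nn.Module, x: torch.Tensor, layer: int, node: int) -> torch.Tensor:
    """Return one hidden node's activation for every row of x, shape (batch,)."""
    if layer not in (1, 2):
        raise ValueError("layer must be 1 or 2")
    with torch.no_grad():
        net(x)  # the forward pass refreshes net.hid1 and net.hid2
    hid = net.hid1 if layer == 1 else net.hid2
    return hid[:, node]  # column `node` = that node's output for each input row
```

For example, hidden_activation(TwoLayerNet(5), torch.randn(4, 2), layer=1, node=3) returns 4 values, one per input point; passing the 2-D grid instead gives the values to threshold and plot.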

I have this code so far, but it keeps raising errors. Any advice would help greatly!

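import torch
import matplotlib.pyplot as plt

def graph_hidden(net, layer, node):
    # Grid of (x, y) input points covering [-8, 8] x [-8, 8]; x varies fastest
    xrange = torch.arange(start=-8, end=8.1, step=0.01, dtype=torch.float32)
    yrange = torch.arange(start=-8, end=8.1, step=0.01, dtype=torch.float32)
    xcoord = xrange.repeat(yrange.size()[0])
    ycoord = torch.repeat_interleave(yrange, xrange.size()[0], dim=0)
    grid = torch.cat((xcoord.unsqueeze(1), ycoord.unsqueeze(1)), 1)

    with torch.no_grad():  # suppress updating of gradients
        net.eval()         # toggle batch norm, dropout
        # Assumes forward() stores each hidden layer's activations in
        # net.hid1 / net.hid2 with shape (num_points, num_nodes); a layer
        # module such as nn.Linear cannot be indexed as net.hid1[node].
        net(grid)
        if layer == 1:
            output = net.hid1[:, node]
        elif layer == 2:
            output = net.hid2[:, node]
        else:
            raise ValueError("layer must be 1 or 2")
        net.train()        # toggle batch norm, dropout back again

        # 0.5 is the midpoint for sigmoid units; for tanh units 0 is the midpoint
        pred = (output >= 0.5).float()

        # plot the function computed by the chosen hidden node
        plt.clf()
        plt.pcolormesh(xrange, yrange, pred.cpu().view(yrange.size()[0], xrange.size()[0]),
                       cmap='Wistia', shading='auto')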
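It may be worth checking why reshaping with .view(yrange.size()[0], xrange.size()[0]) lines up with the plot axes: xrange.repeat(...) cycles through the x values fastest, while torch.repeat_interleave holds each y value for a full pass over xrange, so each row of the reshaped tensor corresponds to one y value. That (len(yrange), len(xrange)) layout is what plt.pcolormesh(X, Y, C) expects when X and Y are 1-D. The self-contained check below uses tiny hypothetical ranges (xs, ys) and a hypothetical helper make_grid that wraps the same grid construction, and compares the result with torch.meshgrid(..., indexing="ij") (the indexing keyword needs a reasonably recent PyTorch).

```python
import torch


def make_grid(xrange: torch.Tensor, yrange: torch.Tensor) -> torch.Tensor:
    """Return every (x, y) pair as rows of a (len(yrange) * len(xrange), 2) tensor.

    Same construction as in the question: x cycles fastest, y is held
    constant for each full pass over xrange.
    """
    xcoord = xrange.repeat(yrange.size()[0])
    ycoord = torch.repeat_interleave(yrange, xrange.size()[0], dim=0)
    return torch.cat((xcoord.unsqueeze(1), ycoord.unsqueeze(1)), 1)


# Tiny hypothetical ranges so the result is easy to read
xs = torch.tensor([0.0, 1.0, 2.0])
ys = torch.tensor([10.0, 20.0])
grid = make_grid(xs, ys)
# grid rows: (0, 10), (1, 10), (2, 10), (0, 20), (1, 20), (2, 20)

# Reshaping a per-point column to (len(ys), len(xs)) puts y on rows and x on
# columns, matching torch.meshgrid with indexing="ij"
x_2d = grid[:, 0].reshape(ys.size()[0], xs.size()[0])
y_2d = grid[:, 1].reshape(ys.size()[0], xs.size()[0])
Y, X = torch.meshgrid(ys, xs, indexing="ij")
print(torch.equal(x_2d, X), torch.equal(y_2d, Y))  # True True
```

Because row i of the reshaped tensor holds the points with y = yrange[i], passing it as C together with xrange and yrange keeps the node's activation pattern in the right orientation.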