Quick question: how can I visualize the fully connected layers with plt.imshow()? When I try, it says the data has invalid dimensions for an image. Any help is greatly appreciated.
What object are you calling plt.imshow() on? If it's a Tensor, it shouldn't be a problem. E.g.,

```python
import matplotlib.pyplot as plt
import torch
a = torch.Tensor(10, 1).uniform_(0, 1)
plt.imshow(a)
```

should work. If it's a Variable, use the .data attribute. E.g.,

```python
b = torch.autograd.Variable(a)
plt.imshow(b.data)
```
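Since PyTorch 0.4, Variable has been merged into Tensor, so the modern equivalent is a tensor that requires grad. A minimal sketch, assuming a recent PyTorch and Matplotlib: converting such a tensor to NumPy directly (which is what imshow does internally) should raise a RuntimeError, so it is detached first.

```python
import matplotlib.pyplot as plt
import torch

a = torch.rand(10, 1)
# A tensor tracked by autograd, which is what wrapping in Variable used to produce
b = a.clone().requires_grad_()

# plt.imshow(b) would fail because b requires grad and cannot be turned into a NumPy
# array directly; detach it (and move it to the CPU in case it lives on a GPU) first
img = b.detach().cpu().numpy()
plt.imshow(img)
```

`.data` still works, but `.detach()` is the form the PyTorch docs recommend for getting a tensor outside the autograd graph.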
The same should apply if you have a model built from the Module class, e.g.:

```python
import matplotlib.pyplot as plt
import torch

class MyNet(torch.nn.Module):
    def __init__(self, num_features, num_hidden_1, num_classes):
        super(MyNet, self).__init__()
        self.linear_1 = torch.nn.Linear(num_features, num_hidden_1)
        ...

# do training
...
plt.imshow(model.linear_1.weight.data)
```
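For a fully connected layer specifically, the weight matrix is already 2D, which is why it can be passed to imshow as is. The sketch below uses a standalone `torch.nn.Linear` with hypothetical sizes (64 inputs, 10 outputs) in place of a trained model, and adds a colorbar and axis labels so the plot is easier to read.

```python
import matplotlib.pyplot as plt
import torch

# Hypothetical layer sizes, for illustration only
linear_1 = torch.nn.Linear(in_features=64, out_features=10)

# nn.Linear stores its weight as (out_features, in_features): a 2D matrix.
# The parameter requires grad, so detach it before converting to NumPy.
weights = linear_1.weight.detach().cpu().numpy()

fig, ax = plt.subplots()
im = ax.imshow(weights, cmap="coolwarm", aspect="auto")
fig.colorbar(im, ax=ax, label="weight value")
ax.set_xlabel("input feature index")
ax.set_ylabel("output unit index")
# Call plt.show() (or let the notebook render the figure) to view it
```

Each row is one output unit and each column one input feature; `aspect="auto"` avoids squashing the plot when the two dimensions differ a lot.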
One thing to check is what .size() returns for your object: maybe it's missing an axis or has too many.
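A minimal sketch of that check, assuming the layer output ended up as a 1D tensor (the size 800 is only an example): a single axis is rejected by imshow with a TypeError, and adding a leading axis with `unsqueeze(0)` turns it into a 2D (1, 800) array that imshow accepts.

```python
import matplotlib.pyplot as plt
import torch

# Hypothetical 1D activation from a fully connected layer
act = torch.rand(800)
print(act.size())  # only one axis

failed = False
try:
    plt.imshow(act.numpy())
except TypeError as err:
    # Matplotlib reports an invalid shape/dimensions for image data
    failed = True
    print("imshow failed:", err)

# Add the missing axis: shape (800,) becomes (1, 800)
act_2d = act.unsqueeze(0)
print(act_2d.size())
plt.imshow(act_2d.numpy(), aspect="auto")
```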
Its size is 1×800; does that help? I'll try the methods you suggested, thank you!
Hm, imshow should work with 2D arrays, so I'm not sure why you are getting that error. It also works with 3D arrays if the last axis has 1, 3, or 4 elements (i.e., color channels).
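To illustrate those shape rules, here is a sketch with made-up data: a (1, 800) tensor plotted directly, the same 800 values folded into a 20×40 grid (an optional, purely presentational choice, since 20 × 40 = 800), and a channels-first (3, H, W) tensor, PyTorch's usual image layout, permuted to the channels-last (H, W, 3) layout imshow expects.

```python
import matplotlib.pyplot as plt
import torch

# Hypothetical fully connected output with a (1, 800) size
fc_out = torch.rand(1, 800)

fig, axes = plt.subplots(1, 3, figsize=(12, 3))

# (1, 800) is already 2D; aspect="auto" stretches the single row so it stays visible
axes[0].imshow(fc_out.numpy(), aspect="auto")

# Optional: fold the 800 values into a 20 x 40 grid for a more compact view
grid = fc_out.reshape(20, 40)
axes[1].imshow(grid.numpy(), cmap="gray")

# A (3, 16, 16) tensor has 16 on its last axis, so imshow would reject it;
# permuting to (16, 16, 3) puts the color channels last, as imshow expects
chw = torch.rand(3, 16, 16)
hwc = chw.permute(1, 2, 0)
axes[2].imshow(hwc.numpy())
```

Values from `torch.rand` lie in [0, 1), which is the range imshow expects for float RGB data; other ranges would be clipped.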