Hi guys , i’ m newbie to pytorch. Train loss of my LeNet network was not updated. please help me
Thanks in advance
import torch
import torchvision
import torchvision.transforms as transforms
from itertools import islice
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from skorch import NeuralNetClassifier
torch.manual_seed(0)
transform=torchvision.transforms.Compose([
torchvision.transforms.ToTensor(),
])
trainset = torchvision.datasets.FashionMNIST("./data",download=True,train=True,transform=transform)
# X_examples,y_examples = zip(*islice(iter(trainset),5))
def plot_example(X, y, n=5):
"""Plot the images in X and their labels in rows of `n` elements."""
fig = plt.figure()
rows = len(X) // n + 1
for i, (img, y) in enumerate(zip(X, y)):
ax = fig.add_subplot(rows, n, i + 1)
ax.imshow(img.reshape(28, 28))
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(y)
plt.tight_layout()
return fig
# plot_example(torch.stack(X_examples), y_examples, n=5).show()
callbacks = []
USE_TENSORBOARD = 1
if USE_TENSORBOARD:
from torch.utils.tensorboard import SummaryWriter
from skorch.callbacks import TensorBoard
writer = SummaryWriter()
callbacks.append(TensorBoard(writer))
class MyLeNet(nn.Module):
def __init__(self):
super(MyLeNet,self).__init__()
self.conv1 = nn.Conv2d(1,6,5,padding=2)
self.avg_pool1 = nn.AvgPool2d(2,stride=2)
self.conv2 = nn.Conv2d(6,16,5)
self.avg_pool2 = nn.AvgPool2d(2,stride=2)
self.fc1 = nn.Linear(400,200)
self.fc2 = nn.Linear(200,84)
self.fc3 = nn.Linear(84,10)
self.sigmoid = nn.Sigmoid()
def forward(self,X):
image = X
X = self.sigmoid( self.conv1(X))
X = self.avg_pool1(X)
X = self.sigmoid(self.conv2(X))
X = self.avg_pool2(X)
X = X.view((image.shape[0],-1))
X = self.sigmoid(self.fc1(X))
X = self.sigmoid(self.fc2(X))
X = self.fc3(X)
X = F.softmax(X, dim=-1)
return X
net = NeuralNetClassifier(
MyLeNet,
max_epochs=10,
lr=0.01,
device="cuda",
callbacks=callbacks,
)
y_train = np.array([y for x,y in iter(trainset)])
print("y_train.size",y_train.shape)
net.fit(trainset,y=y_train)