PyTorch torch.nn equivalent of TensorFlow (Keras) Dense layers?

I have a reasonable understanding of building networks in TensorFlow, and I am now trying to port one to its PyTorch equivalent.
My TensorFlow example has the following layers:
input -> flatten -> dense(300 nodes) -> dense(100 nodes), but I cannot find a Dense layer definition in torch.nn. Web searches seem to equate nn.Linear with Dense, but I am not sure.

Here are all layers in pytorch nn:
https://pytorch.org/docs/stable/nn.html

Admittedly, the TensorFlow example I have actually uses Keras, so here are the Keras/TensorFlow definitions of its layers:

The example I cited seems to have two Dense layers of size 128 and 10 respectively, but in the Torch example the user defines a single nn.Linear with 128 and 10.

Then I looked at the more detailed definitions.
PyTorch nn.Linear layer inputs:

torch.nn.Linear(in_features, out_features, bias=True, device=None, dtype=None)

Keras Dense:

tf.keras.layers.Dense(units, activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None, **kwargs)

In the Keras Dense layer, units is defined as: “Positive integer, dimensionality of the output space.”
The keyword here is output: does that mean this parameter is equivalent to out_features in PyTorch? And then what should in_features be?

Yes, that’s correct.

It depends on your use case and the number of features of your input.
E.g. if the input to your model has a shape of [batch_size, 10], then the first nn.Linear layer would use in_features=10. If you stack linear layers, then the out_features of the previous layer are the in_features of the next one.
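For example, a minimal sketch of the mapping (the layer sizes here are just placeholders):

import torch
import torch.nn as nn

# Keras Dense(units=300) on a flattened 28x28 input corresponds to
# nn.Linear(784, 300): out_features plays the role of units, while
# in_features has to be given explicitly
model = nn.Sequential(
    nn.Flatten(),            # [batch, 1, 28, 28] -> [batch, 784]
    nn.Linear(784, 300),     # in_features=784, out_features=300
    nn.ReLU(),
    nn.Linear(300, 100),     # in_features matches the previous out_features
    nn.ReLU(),
)

x = torch.randn(32, 1, 28, 28)   # dummy batch
print(model(x).shape)            # torch.Size([32, 100])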

After a day-long struggle, I finally wrote roughly equivalent code in PyTorch.
The code now runs fine, but when the predicted result (an array of 10 in this case) is compared to the target (label), it is a horrible mismatch, whereas the TensorFlow code predicted the classes almost perfectly. Obviously I am doing something wrong. Here is my much-improved Torch equivalent of the Fashion-MNIST code, along with its output:

import torch
import torch.nn as nn
import helper
import sys
import time
import re
import numpy as np
import matplotlib.pyplot as plt
DEBUG=0
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

CONFIG_EPOCHS=2
CONFIG_BATCH_SIZE=64

for i in sys.argv:
    print("Processing ", i)
    try:
        if re.search("epochs", i):
            CONFIG_EPOCHS=int(i.split('=')[1])

        if re.search("batch_size", i):
            CONFIG_BATCH_SIZE=int(i.split('=')[1])

    except Exception as msg:
        print(msg)
        print("No argument provided, default values will be used.")

print("epochs: ", CONFIG_EPOCHS)
print("batch_size: ", CONFIG_BATCH_SIZE)
labels_map = {0 : 'T-Shirt', 1 : 'Trouser', 2 : 'Pullover', 3 : 'Dress', 4 : 'Coat', 5 : 'Sandal', 6 : 'Shirt',
              7 : 'Sneaker', 8 : 'Bag', 9 : 'Ankle Boot'};

trainset = datasets.FashionMNIST('~/.pytorch/F_MNIST_data/', download=True, train=True, transform=ToTensor())

training_data = datasets.FashionMNIST(
    root="data",
    train=True,
    download=True,
    transform=ToTensor()
)

# test split, needed below for testloader
test_data = datasets.FashionMNIST(
    root="data",
    train=False,
    download=True,
    transform=ToTensor()
)

trainloader = torch.utils.data.DataLoader(training_data, batch_size=CONFIG_BATCH_SIZE, shuffle = True)
testloader = torch.utils.data.DataLoader(test_data, batch_size=32, shuffle = True)

print("training_data/test_data: ", type(training_data), len(training_data), type(test_data), len(test_data))
print("type: ", type(training_data[0]))

print("trainloader: ", type(trainloader))
print("testloader:  ", type(testloader))

f1=nn.Flatten()
l1=nn.Linear(28*28, 300)
r1=nn.ReLU()
l2=nn.Linear(300, 100)
r2=nn.ReLU()
l3=nn.Linear(100, 30)
s3=nn.Softmax()
model = nn.Sequential(
    f1,
    l1,
    r1,
    l2,
    r2,
    l3,
    s3,
)

print("Model: ", model)

#print("model: layer0: ", model[0], model[0].weight)
print("l1 info: ", l1, l1.weight.shape)
print("l2 info: ", l2, l2.weight.shape)
print("l3 info: ", l3, l3.weight.shape)

#criterion = torch.nn.MSELoss()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

i=0
for epoch in range(CONFIG_EPOCHS):
    print("")
    print('epoch/i: ', epoch, i, end=' ', flush=True)
    j=0
    for batch in trainloader:

        imgs, lbls = batch

        if j == 0:
            bypass_dots=int(len(training_data)/len(lbls)/80)
            print("batch: ", type(batch), ", ", len(batch))
            print("imgs: ", type(imgs), ", ", len(imgs), imgs.shape)
            print("lbls: ", type(lbls), ", ", len(lbls), lbls.shape)
            if DEBUG:
                print("bypass_dots quantity: ", bypass_dots)

        if DEBUG:
            print("batch: ", type(batch), ", ", len(batch))
            print("imgs: ", type(imgs), ", ", len(imgs), imgs.shape)
            print("lbls: ", type(lbls), ", ", len(lbls), lbls.shape)

        # Forward pass: Compute predicted y by passing x to the model
        y_pred = model(imgs)

        if DEBUG:
            print("y_pred: ", type(y_pred), y_pred.shape)
            print("lbls:   ", type(lbls), lbls.shape)

        # Compute and print loss
        loss = criterion(y_pred, lbls)

        if DEBUG:
            print('epoch/batch: ', epoch, i,' loss: ', loss.item())

        if j%bypass_dots == 0:
            print(".", end='', flush=True)

        # Zero gradients, perform a backward pass, and update the weights.
        optimizer.zero_grad()

        # perform a backward pass (backpropagation)
        loss.backward()

        # Update the parameters
        optimizer.step()
        j+=1
    i+=1

print("Testing...")

i=0

print(len(testloader), type(testloader))

i=0
#for batch in testloader:
#    imgs, lbls = batch
SLICE=10
for imgs, lbls in testloader:
    imgs1=imgs[:SLICE]
    lbls1=lbls[:SLICE]
    print("---", i, "---")
    print("imgs: ", imgs.shape)
    print("lbls: ", lbls.shape)
    print("imgs1: ", imgs1.shape)

    if i >= 0:
        break
    i+=1
y_pred=model(imgs1)
print("y_pred: ", y_pred.shape, type(y_pred))
print(y_pred)

_, pred_class = torch.max(y_pred, 1)
print("pred_class: ", pred_class)
print("lbls1: ", lbls1)

output:

root@nonroot-Standard-PC-i440FX-PIIX-1996:~/dev-learn/gpu/pytorch/port-effort-from-tflow-2nd# python3 p297.py epochs=5
Processing  p297.py
Processing  epochs=5
epochs:  5
batch_size:  64
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
training_data/test_data:  torchvision.datasets.mnist.FashionMNIST 60000 torchvision.datasets.mnist.FashionMNIST 10000
type:  <class 'tuple'>
trainloader:  torch.utils.data.dataloader.DataLoader
testloader:   torch.utils.data.dataloader.DataLoader
Model:  Sequential(
  (0): Flatten(start_dim=1, end_dim=-1)
  (1): Linear(in_features=784, out_features=300, bias=True)
  (2): ReLU()
  (3): Linear(in_features=300, out_features=100, bias=True)
  (4): ReLU()
  (5): Linear(in_features=100, out_features=30, bias=True)
  (6): Softmax(dim=None)
)
l1 info:  Linear(in_features=784, out_features=300, bias=True) torch.Size([300, 784])
l2 info:  Linear(in_features=300, out_features=100, bias=True) torch.Size([100, 300])
l3 info:  Linear(in_features=100, out_features=30, bias=True) torch.Size([30, 100])

epoch/i:  0 0 batch:  <class 'list'> ,  2
imgs:  <class 'torch.Tensor'> ,  64 torch.Size([64, 1, 28, 28])
lbls:  <class 'torch.Tensor'> ,  64 torch.Size([64])
/usr/local/lib/python3.6/dist-packages/torch/nn/modules/container.py:139: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  input = module(input)


......................................................................................
epoch/i:  1 1 batch:  <class 'list'> ,  2
imgs:  <class 'torch.Tensor'> ,  64 torch.Size([64, 1, 28, 28])
lbls:  <class 'torch.Tensor'> ,  64 torch.Size([64])
......................................................................................
epoch/i:  2 2 batch:  <class 'list'> ,  2
imgs:  <class 'torch.Tensor'> ,  64 torch.Size([64, 1, 28, 28])
lbls:  <class 'torch.Tensor'> ,  64 torch.Size([64])
......................................................................................
epoch/i:  3 3 batch:  <class 'list'> ,  2
imgs:  <class 'torch.Tensor'> ,  64 torch.Size([64, 1, 28, 28])
lbls:  <class 'torch.Tensor'> ,  64 torch.Size([64])
......................................................................................
epoch/i:  4 4 batch:  <class 'list'> ,  2
imgs:  <class 'torch.Tensor'> ,  64 torch.Size([64, 1, 28, 28])
lbls:  <class 'torch.Tensor'> ,  64 torch.Size([64])
......................................................................................Testing...
313 torch.utils.data.dataloader.DataLoader
--- 0 ---
imgs:  torch.Size([32, 1, 28, 28])
lbls:  torch.Size([32])
imgs1:  torch.Size([10, 1, 28, 28])
y_pred:  torch.Size([10, 30]) <class 'torch.Tensor'>
tensor([[2.4371e-03, 4.3862e-05, 2.1819e-03, 3.1150e-04, 3.9027e-03, 2.6907e-04,
...
         2.6642e-03, 2.1615e-02, 1.3037e-02, 5.6110e-02, 6.3985e-04, 3.8399e-04,
         1.1296e-03, 5.2669e-04, 1.1069e-03, 4.4081e-04, 3.0513e-04, 6.3576e-04,
         7.3444e-04, 5.2634e-04, 9.4558e-04, 7.3277e-04, 7.7510e-04, 5.5766e-04,
         5.8993e-04, 4.8865e-04, 5.8177e-04, 5.2099e-04, 8.9330e-04, 9.2164e-04]],
       grad_fn=<SoftmaxBackward>)
pred_class:  tensor([9, 9, 0, 9, 9, 4, 9, 1, 1, 4])
lbls1:  tensor([9, 5, 3, 5, 7, 4, 5, 1, 1, 6])

I think you should first check whether the training loss goes down / training accuracy goes up, to know whether the network is learning at all.
Once you see that the network is learning, the next step is to calculate the accuracy on the whole test set (instead of looking at individual test predictions).
Those would be the steps to take, IMHO.
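If it helps, a rough sketch of the test-accuracy calculation, using the model and testloader names from your script:

correct = 0
total = 0
with torch.no_grad():                        # no gradients needed for evaluation
    for imgs, lbls in testloader:
        outputs = model(imgs)
        _, pred_class = torch.max(outputs, 1)    # index of the highest score per sample
        correct += (pred_class == lbls).sum().item()
        total += lbls.size(0)
print("Test accuracy: ", correct / total)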

Thanks for the input. I could not figure out the code for the accuracy calculation, so I went another route and modified an iris (setosa) example for Fashion-MNIST. This code is completely different from the one above, but it appears to work much better. The challenge was converting code written for [4]-feature rows to [28, 28] images.
Now I am getting about 0.64 accuracy, which is pretty low. I tried several different epoch and batch-size settings, but the best I could get was always below 0.7.

root@nonroot-Standard-PC-i440FX-PIIX-1996:~/dev-learn/gpu/pytorch/port-effort-from-tflow-2nd# cat fashion-mnist-example-4.py
import torch
import torch.nn as nn
import helper
import sys
import time
import re
import numpy as np
import matplotlib.pyplot as plt

from numpy import vstack
from numpy import argmax
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import accuracy_score
from torch import Tensor
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torch.nn import Linear
from torch.nn import ReLU
from torch.nn import Softmax
from torch.nn import Module
from torch.optim import SGD
from torch.nn import CrossEntropyLoss
from torch.nn.init import kaiming_uniform_
from torch.nn.init import xavier_uniform_

DEBUG=0
from torchvision import datasets, transforms
from torchvision.transforms import ToTensor

CONFIG_EPOCHS=10
CONFIG_BATCH_SIZE=64

for i in sys.argv:
    print("Processing ", i)
    try:
        if re.search("epochs", i):
            CONFIG_EPOCHS=int(i.split('=')[1])

        if re.search("batch_size", i):
            CONFIG_BATCH_SIZE=int(i.split('=')[1])

    except Exception as msg:
        print(msg)
        print("No argument provided, default values will be used.")

print("epochs: ", CONFIG_EPOCHS)
print("batch_size: ", CONFIG_BATCH_SIZE)
labels_map = {0 : 'T-Shirt', 1 : 'Trouser', 2 : 'Pullover', 3 : 'Dress', 4 : 'Coat', 5 : 'Sandal', 6 : 'Shirt',
              7 : 'Sneaker', 8 : 'Bag', 9 : 'Ankle Boot'};

def prepare_data():
    train = datasets.FashionMNIST(
        root="data",
        train=True,
        download=True,
        transform=ToTensor()
    )

    test = datasets.FashionMNIST(
        root="data",
        train=False,
        download=True,
        transform=ToTensor()
    )

    print("train/test: ", type(train), len(train), type(test), len(test))
    train_dl = torch.utils.data.DataLoader(train, batch_size=CONFIG_BATCH_SIZE, shuffle=True)
    test_dl = torch.utils.data.DataLoader(test, batch_size=CONFIG_BATCH_SIZE, shuffle=False)
    print("train_dl/test_dl: ", type(train_dl), len(train_dl), type(test_dl), len(test_dl))
    return train_dl, test_dl


class MLP(Module):
    # define model elements
    def __init__(self):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten(1, 3)
        # input to first hidden layer
        self.hidden1 = Linear(784, 300)
        kaiming_uniform_(self.hidden1.weight, nonlinearity='relu')
        self.act1 = ReLU()
        # second hidden layer
        self.hidden2 = Linear(300, 100)
        kaiming_uniform_(self.hidden2.weight, nonlinearity='relu')
        self.act2 = ReLU()
        # third hidden layer and output
        self.hidden3 = Linear(100, 30)
#        xavier_uniform_(self.hidden3.weight)
#        self.act3 = Softmax(dim=1)
        self.act3 = Softmax()

    # forward propagate input
    def forward(self, X):

        if DEBUG:
            print("forward entered: X: ", X.size())
        # input to first hidden layer
        X = self.flatten(X)

        if DEBUG:
            print("forward: X (flatten): ", X.size())

        X = self.hidden1(X)

        if DEBUG:
            print("forward: X (hidden1): ", X.size())

        X = self.act1(X)

        if DEBUG:
            print("forward: X (act1/RELU): ", X.size())

        # second hidden layer
        X = self.hidden2(X)
        X = self.act2(X)
        # output layer
        X = self.hidden3(X)
        X = self.act3(X)

        if DEBUG:
            print("forward: X (returned): ", X.size())

        return X

# train the model
def train_model(train_dl, model):
    # define the optimization

    #criterion = torch.nn.CrossEntropyLoss()
    #optimizer = torch.optim.SGD(model.parameters(), lr = 0.01)

    criterion = CrossEntropyLoss()
    optimizer = SGD(model.parameters(), lr=0.01, momentum=0.9)
    # enumerate epochs
    for epoch in range(CONFIG_EPOCHS):
        print("\nepoch:", epoch, end="")
        # enumerate mini batches
        for i, (inputs, targets) in enumerate(train_dl):
            if i % 20 == 0:
                print(".", end="", flush=True)
            if DEBUG:
                print("inputs: ", inputs.size())
                print("targets: ", targets.size())
            # clear the gradients
            optimizer.zero_grad()
            # compute the model output
            yhat = model(inputs)
            if DEBUG:
                print("yhat: ", yhat.size())
                print("targets: ", targets.size())
            # calculate loss
            if DEBUG:
                print("yhat: ", type(yhat), yhat.shape)
                print("targets:   ", type(targets), targets.shape)

            loss = criterion(yhat, targets)
            # credit assignment
            loss.backward()
            # update model weights
            optimizer.step()

# evaluate the model
def evaluate_model(test_dl, model):
    predictions, actuals = list(), list()
    for i, (inputs, targets) in enumerate(test_dl):
        # evaluate the model on the test set
        yhat = model(inputs)
        # retrieve numpy array
        yhat = yhat.detach().numpy()
        actual = targets.numpy()
        # convert to class labels
        yhat = argmax(yhat, axis=1)
        # reshape for stacking
        actual = actual.reshape((len(actual), 1))
        yhat = yhat.reshape((len(yhat), 1))
        # store
        predictions.append(yhat)
        actuals.append(actual)
    predictions, actuals = vstack(predictions), vstack(actuals)
    # calculate accuracy
    acc = accuracy_score(actuals, predictions)
    return acc

# make a class prediction for one row of data
def predict(row, model):
    # convert row to data
    row = Tensor([row])
    # make prediction
    yhat = model(row)
    # retrieve numpy array
    yhat = yhat.detach().numpy()
    return yhat

# prepare the data
#path = 'https://raw.githubusercontent.com/jbrownlee/Datasets/master/iris.csv'
train_dl, test_dl = prepare_data()

print(len(train_dl.dataset), len(test_dl.dataset))
# define the network
model = MLP()
# train the model
print("train_dl: ", len(train_dl))
train_model(train_dl, model)
# evaluate the model
acc = evaluate_model(test_dl, model)
print('Accuracy: %.3f' % acc)
# make a single prediction

output:


Processing  fashion-mnist-example-4.py
Processing  epochs=5
Processing  batch_size=32
epochs:  5
batch_size:  32
/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:498: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  ../torch/csrc/utils/tensor_numpy.cpp:180.)
  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)
train/test:  torchvision.datasets.mnist.FashionMNIST 60000 torchvision.datasets.mnist.FashionMNIST 10000
train_dl/test_dl:  torch.utils.data.dataloader.DataLoader 1875 torch.utils.data.dataloader.DataLoader 313
60000 10000
train_dl:  1875

epoch: 0.fashion-mnist-example-4.py:121: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
  X = self.act3(X)
.............................................................................................
epoch: 1..............................................................................................
epoch: 2..............................................................................................
epoch: 3..............................................................................................
epoch: 4..............................................................................................Accuracy: 0.634

Yes, I think it is a good idea to display the loss along the way in each epoch. But I think there is still something seriously wrong with even the last and best version, as most decent training should result in >90% accuracy, not sub-70%.

nn.CrossEntropyLoss expects raw logits as the model output, so remove the nn.Softmax layer from your model.
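Applied to the nn.Sequential version above, the change would look roughly like this (a sketch only; the layer sizes are kept as in your code):

# the model now returns raw logits; nn.CrossEntropyLoss applies
# log_softmax + nll_loss internally, so no Softmax layer at the end
model = nn.Sequential(
    nn.Flatten(),
    nn.Linear(28*28, 300),
    nn.ReLU(),
    nn.Linear(300, 100),
    nn.ReLU(),
    nn.Linear(100, 30),      # last layer outputs logits
)

criterion = nn.CrossEntropyLoss()

# if you want probabilities for inspection, apply softmax outside the model:
# probs = torch.softmax(model(imgs), dim=1)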

@ptrblck, that was awesome. After making that change, adding a small list-based test, and splitting the data further into train, valid, and test sets, I got approx. 88% accuracy.
Running on a list containing a single batch, all predictions were correct:


# make a single prediction
print("Making prediction...")

enum_test_dl = list(enumerate(test_dl))
enum_test_dl_sub=enum_test_dl[:1]
print("enum_test_dl_sub: ", len(enum_test_dl_sub))
yhat = model(enum_test_dl_sub[0][1][0])
yhat = yhat.detach().numpy()
actual = enum_test_dl_sub[0][1][1].numpy()
#print("actual1: ", actual)
# convert to class labels
yhat = argmax(yhat, axis=1)
# reshape for stacking
#actual = actual.reshape((len(actual), 1))
#yhat = yhat.reshape((len(yhat), 1))
print("yhat:   ", yhat[:10])
print("actual: ", actual[:10])


train/valid/test:  torch.utils.data.dataset.Subset 50000 torch.utils.data.dataset.Subset 10000 torchvision.datasets.mnist.FashionMNIST 10000
train_dl/test_dl:  torch.utils.data.dataloader.DataLoader 782 torch.utils.data.dataloader.DataLoader 157
50000 10000
train_dl:  782
epoch: 0 / 10........................................loss:  tensor(0.5518, grad_fn=<NllLossBackward>)
epoch: 1 / 10........................................loss:  tensor(0.3474, grad_fn=<NllLossBackward>)
epoch: 2 / 10........................................loss:  tensor(0.3830, grad_fn=<NllLossBackward>)
epoch: 3 / 10........................................loss:  tensor(0.3948, grad_fn=<NllLossBackward>)
epoch: 4 / 10........................................loss:  tensor(0.4192, grad_fn=<NllLossBackward>)
epoch: 5 / 10........................................loss:  tensor(0.1886, grad_fn=<NllLossBackward>)
epoch: 6 / 10........................................loss:  tensor(0.0419, grad_fn=<NllLossBackward>)
epoch: 7 / 10........................................loss:  tensor(0.1398, grad_fn=<NllLossBackward>)
epoch: 8 / 10........................................loss:  tensor(0.0961, grad_fn=<NllLossBackward>)
epoch: 9 / 10........................................loss:  tensor(0.3626, grad_fn=<NllLossBackward>)
Accuracy: 0.880
Making prediction...
enum_test_dl_sub:  1
yhat:    [9 2 1 1 6 1 4 6 5 7]
actual:  [9 2 1 1 6 1 4 6 5 7]

However, a lingering question remains: in the TensorFlow/Keras code I was “porting from”, the softmax was there and it worked just fine. I am just wondering why:

...
model=keras.models.Sequential()
model.add(keras.layers.Flatten(input_shape = [28, 28]))
model.add(keras.layers.Dense(300, activation="relu"))
model.add(keras.layers.Dense(100, activation="relu"))
model.add(keras.layers.Dense(30, activation="softmax"))
...

In PyTorch nn.CrossEntropyLoss expects raw logits, since internally F.log_softmax and F.nll_loss will be used. The log_softmax operation is used for better numerical stability compared to splitting these operations. I don’t know if TensorFlow/Keras applies log_softmax for the user automatically (without their knowledge), stabilizes the loss calculation in another way, or just applies the operations as they are.
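A small sketch of that equivalence (random logits and targets just for illustration):

import torch
import torch.nn.functional as F

logits = torch.randn(8, 10)              # raw model output, no softmax applied
targets = torch.randint(0, 10, (8,))     # class indices

loss_ce = F.cross_entropy(logits, targets)
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

print(loss_ce, loss_manual)              # same value up to floating point error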

If you don’t wish to calculate the input size explicitly, you can use a LazyLinear module (instead of a Linear) which will infer it:

fc_out = nn.LazyLinear(out_dim)
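For example (a sketch; the layer sizes are placeholders, and in_features is inferred from the first batch that passes through):

import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Flatten(),
    nn.LazyLinear(300),   # in_features inferred on the first forward pass
    nn.ReLU(),
    nn.LazyLinear(100),
    nn.ReLU(),
    nn.LazyLinear(10),
)

x = torch.randn(2, 1, 28, 28)
out = model(x)            # materializes the lazy layers
print(model)              # now shows the inferred in_features values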