MLP Mixer - Saving the trained model

I’m trying to train the MLP mixer on a custom dataset based on this repository.

The code I have so far is shown below. How can I save the trained model so that I can later use it on test images?

import torch
import numpy as np
from torch import nn
from einops.layers.torch import Rearrange
import glob
import cv2
from torch.utils.data import Dataset, DataLoader

class customDataset(Dataset):
    def __init__(self):
        self.imags_path = '/path_to_dataset/'
        file_list = glob.glob(self.imags_path + '*')
        self.data = []
        for class_path in file_list:
            class_name = class_path.split('/')[-1]
            for img_path in glob.glob(class_path + '/*.jpg'):
                self.data.append([img_path,class_name])
        self.class_map = {'dogs':0, 'cats':1}
        self.img_dim = (416,416)

    def __len__(self):
        return len(self.data)

    def __getitem__(self,idx):
        img_path,class_name = self.data[idx]
        img = cv2.imread(img_path)
        img = cv2.resize(img,self.img_dim)
        class_id = self.class_map[class_name]
        img_tensor = torch.from_numpy(img)
        img_tensor = img_tensor.permute(2, 0, 1)
        class_id = torch.tensor([class_id])
        return img_tensor, class_id

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

class MixerBlock(nn.Module):

    def __init__(self, dim, num_patch, token_dim, channel_dim, dropout = 0.):
        super().__init__()

        self.token_mix = nn.Sequential(
            nn.LayerNorm(dim),
            Rearrange('b n d -> b d n'),
            FeedForward(num_patch, token_dim, dropout),
            Rearrange('b d n -> b n d')
        )

        self.channel_mix = nn.Sequential(
            nn.LayerNorm(dim),
            FeedForward(dim, channel_dim, dropout),
        )

    def forward(self, x):

        x = x + self.token_mix(x)

        x = x + self.channel_mix(x)

        return x


class MLPMixer(nn.Module):

    def __init__(self, in_channels, dim, num_classes, patch_size, image_size, depth, token_dim, channel_dim):
        super().__init__()

        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        self.num_patch =  (image_size // patch_size) ** 2
        self.to_patch_embedding = nn.Sequential(
            nn.Conv2d(in_channels, dim, patch_size, patch_size),
            Rearrange('b c h w -> b (h w) c'),
        )

        self.mixer_blocks = nn.ModuleList([])

        for _ in range(depth):
            self.mixer_blocks.append(MixerBlock(dim, self.num_patch, token_dim, channel_dim))

        self.layer_norm = nn.LayerNorm(dim)

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):

        x = self.to_patch_embedding(x)

        for mixer_block in self.mixer_blocks:
            x = mixer_block(x)

        x = self.layer_norm(x)

        x = x.mean(dim=1)

        return self.mlp_head(x)


if __name__ == '__main__':

    dataset = customDataset()
    train_loader = DataLoader(dataset,batch_size=1,shuffle=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    mixer_model = MLPMixer(in_channels=3,
        image_size=416,
        patch_size=16,
        num_classes=2,
        dim=512,
        depth=8,
        token_dim=256,
        channel_dim=2048).to(device)

    # single pass over the data (one epoch); no loss or optimizer yet
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)
        inputs, labels = inputs.float(), labels.float()
        outputs = mixer_model(inputs)

Thanks.

You could save the model’s state_dict as described in the Saving and Loading Models Tutorial.
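For example, a minimal sketch (the filename 'mlp_mixer.pth' is just a placeholder):

# save only the learned parameters (recommended over pickling the whole model object)
torch.save(mixer_model.state_dict(), 'mlp_mixer.pth')

# later, recreate the model with the same constructor arguments and load the weights
model = MLPMixer(in_channels=3, image_size=416, patch_size=16, num_classes=2,
                 dim=512, depth=8, token_dim=256, channel_dim=2048)
model.load_state_dict(torch.load('mlp_mixer.pth'))
model.eval()  # switch to evaluation mode before running on test images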


Thanks so much for your kind reply. So, should I insert the save statement after the last for-loop?

Yes, you should save it after the training is done.
Currently you are only training for a single epoch, so depending on your use case, you might want to train for longer.
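As a rough sketch of where the save call could go, assuming you add a loss function and an optimizer (neither is in your snippet yet, and num_epochs is a placeholder):

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(mixer_model.parameters(), lr=1e-4)

for epoch in range(num_epochs):  # num_epochs is up to you
    for inputs, labels in train_loader:
        inputs = inputs.to(device).float()
        labels = labels.to(device).squeeze(1).long()  # CrossEntropyLoss expects class indices
        optimizer.zero_grad()
        outputs = mixer_model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

# save once after training has finished
torch.save(mixer_model.state_dict(), 'mlp_mixer.pth')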


Awesome! Thanks so much for your confirmation.