Training a ViT Model for Multiple Classification

Hello, good day! I’m new to ViT models, but I have a project that I plan to use ViT on. I have a different set of images for each parameter (soil pH, moisture, and soil type/classification). Is there any way to handle them in one model when each of the 3 parameters has its own classes (acidic, neutral, alkaline for pH; low, adequate, high or dry, wet for moisture; peat, clay, silt for soil type)? I ask because I haven’t been able to find a dataset whose images are labeled with all of my parameters.

The code below was mostly generated from a prompt, but it should work as a skeleton.

Most parts have comments

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from transformers import ViTModel, ViTConfig
from PIL import Image
import os
from tqdm import tqdm

# Fix the RNG seed so repeated runs start from the same random state
torch.manual_seed(42)

# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Define the multi-task ViT model
class MultiTaskViT(nn.Module):
    """ViT backbone shared by three independent linear classification heads.

    One head each is created for soil pH, moisture, and soil type; all three
    read the same pooled feature vector produced by the backbone.
    """

    def __init__(self, num_ph_classes, num_moisture_classes, num_soil_type_classes):
        """Load the pre-trained backbone and build one linear head per task.

        Args:
            num_ph_classes: Number of output classes for the pH head.
            num_moisture_classes: Number of output classes for the moisture head.
            num_soil_type_classes: Number of output classes for the soil-type head.
        """
        super().__init__()
        # Shared feature extractor, initialised from the pre-trained checkpoint.
        # NOTE(review): forward() relies on `pooler_output`; confirm this
        # checkpoint ships weights for the pooling layer — TODO verify.
        self.vit = ViTModel.from_pretrained('google/vit-base-patch16-224')

        # Every head consumes a vector of the backbone's hidden size.
        feature_dim = self.vit.config.hidden_size
        self.ph_classifier = nn.Linear(feature_dim, num_ph_classes)
        self.moisture_classifier = nn.Linear(feature_dim, num_moisture_classes)
        self.soil_type_classifier = nn.Linear(feature_dim, num_soil_type_classes)

    def forward(self, pixel_values):
        """Return a ``(ph_logits, moisture_logits, soil_type_logits)`` tuple.

        Args:
            pixel_values: Image batch forwarded to the backbone unchanged.
        """
        # Run the backbone once and reuse its pooled features for all tasks.
        features = self.vit(pixel_values=pixel_values).pooler_output
        heads = (
            self.ph_classifier,
            self.moisture_classifier,
            self.soil_type_classifier,
        )
        return tuple(head(features) for head in heads)

# Custom dataset for soil images and multiple labels
class SoilDataset(Dataset):
    """Image dataset that pairs every image with three labels.

    Images are discovered by listing ``image_dir`` and keeping files whose
    names end in ``.png``, ``.jpg`` or ``.jpeg``. Labels are looked up by
    position, so label ``i`` of each sequence belongs to the ``i``-th image
    in **sorted filename order**.
    """

    def __init__(self, image_dir, ph_labels, moisture_labels, soil_type_labels, transform=None):
        """Collect image paths and store the per-image label sequences.

        Args:
            image_dir: Directory containing the image files.
            ph_labels: Indexable collection of pH labels, one per image.
            moisture_labels: Indexable collection of moisture labels, one per image.
            soil_type_labels: Indexable collection of soil-type labels, one per image.
            transform: Optional callable applied to each loaded PIL image.
        """
        # os.listdir() returns entries in arbitrary order, which would make the
        # index-based image/label pairing depend on the filesystem. Sorting
        # gives a stable, reproducible order for the labels to align with.
        self.image_paths = sorted(
            os.path.join(image_dir, img)
            for img in os.listdir(image_dir)
            if img.endswith(('.png', '.jpg', '.jpeg'))
        )
        self.ph_labels = ph_labels
        self.moisture_labels = moisture_labels
        self.soil_type_labels = soil_type_labels
        self.transform = transform

    def __len__(self):
        """Return the number of images found in the directory."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, ph_label, moisture_label, soil_type_label)`` for ``idx``."""
        # Load the image and force three channels
        image = Image.open(self.image_paths[idx]).convert('RGB')
        if self.transform:
            image = self.transform(image)
        # Labels are matched to the image purely by position
        return image, self.ph_labels[idx], self.moisture_labels[idx], self.soil_type_labels[idx]

# Training function
def train_model(model, train_loader, optimizer, criterion, device):
    """Train for one pass over ``train_loader`` and return the mean batch loss.

    Args:
        model: Callable returning ``(ph_logits, moisture_logits, soil_type_logits)``.
        train_loader: Iterable yielding ``(images, ph, moisture, soil_type)`` batches.
        optimizer: Optimizer whose gradients are reset and stepped once per batch.
        criterion: Loss applied separately to each task's logits and labels.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        Sum of the per-batch combined losses divided by ``len(train_loader)``.
    """
    model.train()
    running_loss = 0
    for images, ph_y, moisture_y, soil_y in tqdm(train_loader, desc="Training"):
        # Put the inputs and all three targets on the target device
        images, ph_y, moisture_y, soil_y = (
            tensor.to(device) for tensor in (images, ph_y, moisture_y, soil_y)
        )

        # Clear stale gradients, then run the shared forward pass
        optimizer.zero_grad()
        ph_out, moisture_out, soil_out = model(images)

        # The training objective is the unweighted sum of the three task losses
        batch_loss = (
            criterion(ph_out, ph_y)
            + criterion(moisture_out, moisture_y)
            + criterion(soil_out, soil_y)
        )
        running_loss += batch_loss.item()

        # Backpropagate and update the weights
        batch_loss.backward()
        optimizer.step()

    num_batches = len(train_loader)
    return running_loss / num_batches

# Evaluation function
def evaluate_model(model, val_loader, criterion, device):
    """Evaluate on ``val_loader`` without gradients.

    Args:
        model: Callable returning ``(ph_logits, moisture_logits, soil_type_logits)``.
        val_loader: Iterable yielding ``(images, ph, moisture, soil_type)`` batches.
        criterion: Loss applied separately to each task's logits and labels.
        device: Device every batch tensor is moved to before the forward pass.

    Returns:
        ``(avg_loss, accuracies)``: the summed batch losses divided by
        ``len(val_loader)``, and a dict with keys ``'ph'``, ``'moisture'`` and
        ``'soil_type'`` holding the fraction of correctly predicted samples.
    """
    model.eval()
    task_names = ('ph', 'moisture', 'soil_type')
    hits = dict.fromkeys(task_names, 0)
    loss_sum = 0
    seen = 0

    with torch.no_grad():
        for images, ph_y, moisture_y, soil_y in tqdm(val_loader, desc="Evaluating"):
            images, ph_y, moisture_y, soil_y = (
                tensor.to(device) for tensor in (images, ph_y, moisture_y, soil_y)
            )

            ph_out, moisture_out, soil_out = model(images)

            # Same unweighted sum of task losses that is used for training
            batch_loss = (
                criterion(ph_out, ph_y)
                + criterion(moisture_out, moisture_y)
                + criterion(soil_out, soil_y)
            )
            loss_sum += batch_loss.item()

            # The predicted class is the index of the largest logit per row
            per_task = zip(
                task_names,
                (ph_out, moisture_out, soil_out),
                (ph_y, moisture_y, soil_y),
            )
            for task, logits, target in per_task:
                hits[task] += (logits.argmax(dim=1) == target).sum().item()
            seen += ph_y.size(0)

    avg_loss = loss_sum / len(val_loader)
    accuracies = {task: count / seen for task, count in hits.items()}
    return avg_loss, accuracies

# Main execution
if __name__ == "__main__":
    # Hyperparameters
    num_epochs = 10
    batch_size = 32
    learning_rate = 1e-4

    # The google/vit-base-patch16-224 model card documents inputs normalized
    # with mean 0.5 and std 0.5 per channel, so use the same statistics here
    # rather than the ImageNet mean/std.
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # Training transform: resize, light random augmentation, normalization
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        normalize,
    ])

    # Validation transform: no random augmentation, so validation metrics are
    # deterministic and comparable from epoch to epoch
    val_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        normalize,
    ])

    # Create dataset and data loaders
    # Note: You need to provide your own data and labels here.
    # NOTE: ph_labels / moisture_labels / soil_type_labels are placeholders
    # that must be defined before this point, and each split needs its own
    # label sequences aligned with the images in its own directory.
    train_dataset = SoilDataset('path/to/train/images', ph_labels, moisture_labels, soil_type_labels, transform=train_transform)
    val_dataset = SoilDataset('path/to/val/images', ph_labels, moisture_labels, soil_type_labels, transform=val_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size)

    # Initialize model, optimizer, and loss function
    model = MultiTaskViT(num_ph_classes=3, num_moisture_classes=2, num_soil_type_classes=3).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Training loop: one training pass followed by one validation pass per epoch
    for epoch in range(num_epochs):
        train_loss = train_model(model, train_loader, optimizer, criterion, device)
        val_loss, val_accuracies = evaluate_model(model, val_loader, criterion, device)

        print(f"Epoch {epoch+1}/{num_epochs}")
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Val Loss: {val_loss:.4f}")
        print(f"Val Accuracies: pH: {val_accuracies['ph']:.4f}, "
              f"Moisture: {val_accuracies['moisture']:.4f}, "
              f"Soil Type: {val_accuracies['soil_type']:.4f}")
        print("-" * 50)

    # Save only the learned weights (state dict), not the whole module object
    torch.save(model.state_dict(), 'multi_task_vit_soil_model.pth')

    print("Training completed and model saved!")

Doesn’t the `__init__` method take `ph_labels`, `moisture_labels`, and `soil_type_labels`, suggesting that there is a single dataset of soil images in which each image has three corresponding labels, and not three separate datasets?

Sorry I am a bit confused then. Isn’t your objective just to have three separate classifiers?

Good day! That was initially the plan, but there is no image dataset that is labeled with pH classification, moisture level classification, soil organic matter classification, nitrogen classification, and soil type classification. So, I had to have different datasets. Technically speaking, I know that there should be different models since I have different datasets.

Would it be possible through a multiple model architecture?