Code:
config.py
# Shared configuration dictionary, imported by the scripts below as
# `from config import config`.
config = {
    "model_name": "resnet18", # change to "mobilenet" or "resnet18" or "resnet50"
    # Standard configuration
    # NOTE(review): width/height here are used as the per-PATCH (quadrant)
    # input resolution by test_pth.py, not the full-image size.
    "image_width": 1224,
    "image_height": 1024,
    "batch_size": 16,
    # Dataset roots (ImageFolder layout: one subdirectory per class).
    "train_dir": "./data/train",
    "val_dir": "./data/val",
    "test_dir": "./data/test",
    # "test_dir": "./data",
    "model_save_dir_path": "./artifacts/",
    # training parameters
    "num_epochs": 150,
    "learning_rate": 0.0001,
    # Settings for a ReduceLROnPlateau-style scheduler.
    "reduce_lr_on_plateau": {
        "factor": 0.5,
        "patience": 2,
        "min_lr": 1e-6,
        "threshold": 1e-2 # lower (e.g. 1e-5) for stable learning
    },
    "early_stop_patience": 15, # epochs without improvement before stopping
    "freeze_epochs": 3 # number of epochs to keep backbone frozen
}
Original export_to_onnx.py:
"""Export a quadrant-patching model to ONNX format."""
import argparse
import os
import sys
import torch
from torch import nn
from torchvision import models
from torchvision.models import ResNet18_Weights
import torch.nn as nn
def load_model(image_width=64, image_height=64):
    """Build a ResNet-18 with a single-logit head for binary classification.

    Note: image_width/image_height are accepted for signature compatibility
    but do not affect the construction of the backbone.
    """
    backbone = models.resnet18(weights=ResNet18_Weights.DEFAULT)
    # Single logit (BCEWithLogitsLoss-style). Kept inside a Sequential so
    # checkpoint keys remain "fc.0.weight"/"fc.0.bias".
    head = nn.Sequential(nn.Linear(backbone.fc.in_features, 1))
    backbone.fc = head
    return backbone
# export_model_to_onnx_quadrants.py (only the wrapper shown)
class ONNXWrapper(nn.Module):
"""
NHWC uint8 full image(s) -> TL/TR/BL/BR tiles -> resize to (target_h, target_w)
-> normalize (ImageNet) -> model -> per-tile probs [4B,2] & preds [4B,1].
"""
def __init__(self, model, target_h: int, target_w: int):
super().__init__()
self.model = model
self.register_buffer("mean_imagenet", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
self.register_buffer("std_imagenet", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
self.target_h = target_h
self.target_w = target_w
def forward(self, image):
"""
Splits input image into quadrants, preprocesses, and runs through the model.
Args:
image (torch.Tensor): Input image tensor in NHWC uint8 format.
Returns:
tuple: Preprocessed tiles and predictions.
"""
x = image
tiles = []
tile_h, tile_w = 1024, 1224 # hardcoded for quadrant tiling
for i in range(2):
for j in range(2):
start_h = i * tile_h
end_h = (i + 1) * tile_h
start_w = j * tile_w
end_w = (j + 1) * tile_w
tile = x[:, start_h:end_h, start_w:end_w, :] # [B, tile_h, tile_w, 3]
tiles.append(tile)
tiles = torch.cat(tiles, dim=0) # (4, tile_h, tile_w, 3)
tiles = tiles.permute(0, 3, 1, 2).float() # (4, 3, tile_h, tile_w)
tiles = tiles / 255.0
tiles = (tiles - self.mean_imagenet) / self.std_imagenet
logits = self.model(tiles) # [4, 1] or [4, 2]
if logits.ndim == 1:
logits = logits.unsqueeze(1)
if logits.ndim == 2 and logits.shape[1] == 1:
conf = torch.sigmoid(logits)
preds = (conf > 0.5).to(logits.dtype)
else:
conf = torch.sigmoid(logits)
preds = torch.argmax(conf, dim=1)
return logits, preds # export logits to inspect raw scores
def export_to_onnx(model_name, weights_path, output_path, patch_w, patch_h,
                   dummy_full_w=2448, dummy_full_h=2048):
    """
    Export the quadrant-tiling wrapper around the trained model to ONNX.

    Args:
        model_name: Accepted for CLI compatibility; the architecture is
            currently fixed to the ResNet-18 built by load_model().
        weights_path: Path to the trained .pth state_dict.
        output_path: Destination .onnx file path.
        patch_w, patch_h: Trained PATCH (quadrant tile) size.
        dummy_full_w, dummy_full_h: Spatial size of the dummy tracing input
            (typically 2x the patch size so each quadrant is one patch).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model()
    # weights_only=True: a state_dict needs no arbitrary-pickle execution,
    # so avoid running untrusted pickled code from the checkpoint file.
    model.load_state_dict(torch.load(weights_path, map_location=device, weights_only=True))
    model.eval()
    wrapper = ONNXWrapper(model, target_h=patch_h, target_w=patch_w)
    wrapper.eval()
    dummy = torch.zeros(1, dummy_full_h, dummy_full_w, 3, dtype=torch.uint8)
    torch.onnx.export(
        wrapper,
        dummy,
        output_path,
        input_names=["image"],
        output_names=["tiles", "predictions"],
        # FIX: dynamic axes were never declared even though the CLI help
        # promises "dynamic axes make runtime flexible"; declare them so the
        # exported graph accepts variable batch and image sizes.
        dynamic_axes={
            "image": {0: "batch", 1: "height", 2: "width"},
            "tiles": {0: "tiles_batch"},
            "predictions": {0: "tiles_batch"},
        },
    )
if __name__ == "__main__":
    # CLI entry point: parse arguments and run the export.
    parser = argparse.ArgumentParser(description="Export quadrant-patching model to ONNX")
    parser.add_argument("--model", required=True, help="Model name (e.g. resnet18)")
    parser.add_argument("--weights", required=True, help="Path to .pth")
    parser.add_argument("--output", required=True, help="Path to .onnx")
    parser.add_argument("--width", type=int, default=1224, help="Trained PATCH width")
    parser.add_argument("--height", type=int, default=1024, help="Trained PATCH height")
    # You can leave dummy size at 2448x2048; dynamic axes make runtime flexible
    parser.add_argument("--dummy_full_width", type=int, default=2448)
    parser.add_argument("--dummy_full_height", type=int, default=2048)
    cli = parser.parse_args()
    export_to_onnx(
        cli.model,
        cli.weights,
        cli.output,
        patch_w=cli.width,
        patch_h=cli.height,
        dummy_full_w=cli.dummy_full_width,
        dummy_full_h=cli.dummy_full_height,
    )
test_onnx.py:
# test_onnx_quadrants.py
"""
ONNX test script for RGB image classifier with 4-quadrant patching inside the ONNX graph.
Assumes the ONNX model:
- Accepts NHWC uint8 full images of arbitrary H×W
- Splits into TL/TR/BL/BR, resizes each tile to trained patch size
- Outputs:
class_probabilities: [4B, 2] (per-tile probs: [NoKnot, Knot])
class_predictions: [4B, 1] (per-tile binary after sigmoid>0.5)
Image-level decision = 1 if ANY tile's class-1 probability >= threshold, else 0.
"""
import argparse
import os
from collections import Counter
import numpy as np
from PIL import Image
import onnxruntime as ort
from torchvision.datasets import ImageFolder
from PIL import ImageDraw
from sklearn.metrics import classification_report, confusion_matrix, precision_score, recall_score, f1_score
PATCH_NAMES = ["TL", "TR", "BL", "BR"]
def create_session(onnx_path: str, providers=None) -> ort.InferenceSession:
    """Create an ONNX Runtime session, preferring CUDA with a CPU fallback."""
    requested = ["CUDAExecutionProvider", "CPUExecutionProvider"] if providers is None else providers
    try:
        return ort.InferenceSession(onnx_path, providers=requested)
    except Exception:
        # e.g. the CUDA provider is not available on this machine
        return ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
def load_image_nhwc_uint8(path: str, full_size: tuple[int, int]) -> np.ndarray:
    """
    Open image, convert to RGB, resize to (W,H)=full_size, return NHWC uint8
    with a leading batch dimension: [1, H, W, 3].

    FIX: the previous implementation documented RGB conversion and resizing
    but performed neither (W and H were unused), so grayscale/RGBA files or
    off-size images would feed the ONNX graph malformed input.
    """
    W, H = full_size
    img = Image.open(path).convert("RGB")
    if img.size != (W, H):
        img = img.resize((W, H), Image.BILINEAR)
    arr = np.asarray(img, dtype=np.uint8)  # [H, W, 3]
    return np.expand_dims(arr, axis=0)     # [1, H, W, 3]
def infer_image(sess: ort.InferenceSession, x_nhwc_uint8: np.ndarray, threshold: float):
    """
    Run one image through ONNX. Returns (overall_pred, pos_scores, fired_idxs).

    pos_scores are per-tile positive-class probabilities (length 4 for B=1);
    overall_pred is 1 if any tile's score reaches `threshold`.
    """
    # (debug print of the input shape removed — it ran once per image)
    in_name = sess.get_inputs()[0].name
    outputs = sess.run(None, {in_name: x_nhwc_uint8})
    # Map output names -> arrays
    out_names = [o.name for o in sess.get_outputs()]
    name2 = {o.name: arr for o, arr in zip(sess.get_outputs(), outputs)}
    if "class_probabilities" in out_names:
        probs = name2["class_probabilities"]  # [4B, 2]; B=1 -> [4, 2]
    else:
        # Fallback: assume the first output holds the scores.
        probs = outputs[0]
    if probs.ndim == 1:
        probs = probs[:, None]
    if probs.ndim == 2 and probs.shape[1] == 1:
        # Single-logit binary head (what export_to_onnx in this project
        # emits): convert raw logits to positive-class probabilities.
        pos_scores = 1.0 / (1.0 + np.exp(-probs[:, 0]))
    elif probs.ndim == 2 and probs.shape[1] == 2:
        # Two-class probabilities: take the class-1 column.
        pos_scores = probs[:, 1]
    else:
        raise RuntimeError(f"Unexpected output shape {probs.shape}; expected [4,2].")
    fired_mask = pos_scores >= threshold
    overall_pred = int(np.any(fired_mask))
    fired_idxs = np.where(fired_mask)[0].tolist()
    return overall_pred, pos_scores, fired_idxs
def main():
    """CLI entry: evaluate an ONNX quadrant-tiling classifier over an ImageFolder tree."""
    ap = argparse.ArgumentParser(description="Test ONNX model with quadrant tiling inside the graph")
    ap.add_argument("--onnx", required=True, help="Path to model.onnx")
    ap.add_argument("--root", required=True, help="Path to test ImageFolder root")
    ap.add_argument("--full_width", type=int, default=2448, help="Full image width fed to ONNX")
    ap.add_argument("--full_height", type=int, default=2048, help="Full image height fed to ONNX")
    ap.add_argument("--threshold", type=float, default=0.5, help="Tile-level positive threshold")
    args = ap.parse_args()
    # Build dataset (labels & paths from folder names)
    ds = ImageFolder(args.root)
    classes = ds.classes
    samples = ds.samples  # list of (path, label)
    print(f"Discovered {len(samples)} image(s) across classes: {classes}")
    if len(samples) == 0:
        print("No images found; exiting.")
        return
    # Create session
    sess = create_session(args.onnx)
    in_name = sess.get_inputs()[0].name
    print(f"ONNX input: {in_name} | providers: {sess.get_providers()}")
    y_true, y_pred = [], []
    wrong_details = []  # tuples: (path, true, pred, fired_names, pos_scores)
    # Evaluate each image; inference failures are reported and skipped so a
    # single bad file does not abort the whole run.
    for path, label in samples:
        x = load_image_nhwc_uint8(path, (args.full_width, args.full_height))
        try:
            pred, pos_scores, fired_idxs = infer_image(sess, x, args.threshold)
        except Exception as e:
            print(f"[ERROR] {path}: {e}")
            continue
        y_true.append(int(label))
        y_pred.append(int(pred))
        if pred != int(label):
            fired_names = [PATCH_NAMES[i] for i in fired_idxs]
            wrong_details.append((path, int(label), int(pred), fired_names, pos_scores.tolist()))
        # break
    # Reporting
    print("\n🧪 Test Results")
    print(classification_report(y_true, y_pred, target_names=classes, zero_division=0))
    print("🧮 Confusion Matrix")
    print(confusion_matrix(y_true, y_pred, labels=[0, 1]))
    print("\nClass balance (y_true):", Counter(y_true))
    print("Class balance (y_pred):", Counter(y_pred))
    if wrong_details:
        print("\nMisclassified images (tile probs shown as [TL, TR, BL, BR]):")
        for (path, y, p, fired, pos_scores) in wrong_details:
            # reorder pos_scores to TL,TR,BL,BR if needed (we assume wrapper order is TL,TR,BL,BR)
            tile_str = ", ".join(f"{name}={pos_scores[i]:.3f}" for i, name in enumerate(PATCH_NAMES))
            fired_str = ", ".join(fired) if fired else "none"
            print(f" - {path} | true={y} pred={p} | fired: [{fired_str}] | {tile_str}")
    else:
        print("\nNo misclassified images 🎉")
    # Return summary metrics if you want to call this as a function elsewhere
    acc = float(np.mean(np.array(y_true) == np.array(y_pred))) if y_true else float("nan")
    prec = precision_score(y_true, y_pred, zero_division=0)
    rec = recall_score(y_true, y_pred, zero_division=0)
    f1 = f1_score(y_true, y_pred, zero_division=0)
    print(f"\nSummary: acc={acc:.4f} precision={prec:.4f} recall={rec:.4f} f1={f1:.4f}")

if __name__ == "__main__":
    main()
test_pth.py:
"""
Test script for RGB image classifier with 4-quadrant patching per image.
Each image is split into: top-left (TL), top-right (TR), bottom-left (BL), bottom-right (BR).
Image-level prediction = 1 if any quadrant predicts 1, else 0.
Args:
--model_path: Path to the trained model .pth file
Returns:
Prints classification report and confusion matrix.
"""
import argparse
import numpy as np
from collections import Counter
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
import torchvision.transforms as transforms
from sklearn.metrics import precision_score, recall_score, f1_score, classification_report, confusion_matrix
from torchvision.transforms import InterpolationMode
from config import config
from torchvision import models
from torchvision.models import ResNet18_Weights
import torch.nn as nn
def load_model(image_width=64, image_height=64):
    """Return a ResNet-18 whose classifier emits a single logit (binary task).

    The image_width/image_height parameters are kept for signature
    compatibility; they do not influence model construction.
    """
    net = models.resnet18(weights=ResNet18_Weights.DEFAULT)
    in_features = net.fc.in_features
    # The Sequential wrapper preserves checkpoint key layout ("fc.0.*").
    net.fc = nn.Sequential(nn.Linear(in_features, 1))
    return net
PATCH_NAMES = ["TL", "TR", "BL", "BR"]
def get_test_transforms(width, height):
    """Build the preprocessing pipeline applied to each PATCH (not the whole image)."""
    steps = [
        transforms.Lambda(lambda img: img.convert("RGB")),
        # Resize each quadrant patch to the model input resolution.
        transforms.Resize((height, width), interpolation=InterpolationMode.BILINEAR, antialias=True),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ]
    return transforms.Compose(steps)
class QuadrantPatchDataset(Dataset):
    """Wraps ImageFolder; for each image returns 4 patches (TL, TR, BL, BR).

    __getitem__ yields (patches [4, C, H, W], label, path).
    """

    def __init__(self, root, transform=None):
        self.base = ImageFolder(root)
        self.transform = transform
        # Mirror the underlying ImageFolder metadata for convenient access.
        self.classes = self.base.classes
        self.class_to_idx = self.base.class_to_idx
        self.samples = self.base.samples
        self.image_paths = [p for p, _ in self.base.samples]

    def __len__(self):
        return len(self.base)

    def _quadrants(self, img):
        """Crop a PIL image into its four quadrants (TL, TR, BL, BR).

        For odd dimensions the right/bottom quadrants are one pixel larger
        because the midpoint uses floor division.
        """
        W, H = img.size
        xm, ym = W // 2, H // 2
        # FIX: removed leftover debug print that ran for every sample.
        # Boxes: (left, upper, right, lower)
        boxes = [
            (0, 0, xm, ym),   # TL
            (xm, 0, W, ym),   # TR
            (0, ym, xm, H),   # BL
            (xm, ym, W, H),   # BR
        ]
        return [img.crop(b) for b in boxes]

    def __getitem__(self, idx):
        path, label = self.base.samples[idx]
        img = self.base.loader(path)  # PIL image
        quads = self._quadrants(img)
        if self.transform is not None:
            quads = [self.transform(q) for q in quads]  # list of [C, H, W] tensors
        # Stack so the 4 patches form one batch dimension: [4, C, H, W]
        patches = torch.stack(quads, dim=0)
        return patches, label, path
def test_model(config, model_path: str):
    """Test RGB image classifier with quadrant patching and print metrics.

    Args:
        config: dict providing "image_width", "image_height", "test_dir".
        model_path: Path to the trained model .pth state_dict.

    Returns:
        dict with image-level "accuracy", "precision", "recall", "f1"
        (positive class, as computed by sklearn with zero_division=0).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Dataset & loader (batch_size=1: each sample already carries 4 patches)
    transform = get_test_transforms(config["image_width"], config["image_height"])
    test_set = QuadrantPatchDataset(config["test_dir"], transform=transform)
    test_loader = DataLoader(test_set, batch_size=1, shuffle=False)
    # Load model
    model = load_model()
    state = torch.load(model_path, map_location=device)
    model.load_state_dict(state)
    model.eval().to(device)
    # Evaluation
    y_true, y_pred = [], []
    wrong_images = []
    wrong_details = []  # store which patches fired
    with torch.no_grad():
        for patches, label, path in test_loader:
            # patches: [1, 4, C, H, W] -> [4, C, H, W]
            patches = patches.squeeze(0).to(device)
            # Model forward on the 4 patches as a batch of size 4
            logits = model(patches)  # expect [4, 1] or [4] for binary BCEWithLogitsLoss setup
            if logits.ndim == 2 and logits.shape[1] == 1:
                logits = logits.squeeze(1)  # [4]
            # Convert logits to probabilities for thresholding
            probs = torch.sigmoid(logits)  # [4]
            preds_patch = (probs > 0.5).int().cpu().numpy()  # [4] in {0,1}
            # Aggregate to image-level prediction: 1 if any patch == 1
            img_pred = int(np.any(preds_patch == 1))
            y_pred.append(img_pred)
            y_true.append(int(label))
            if img_pred != int(label):
                # Which patches fired?
                fired = [PATCH_NAMES[i] for i, v in enumerate(preds_patch.tolist()) if v == 1]
                wrong_images.append(path)
                wrong_details.append((path, fired))
    # Reporting
    print("🧪 Test Results")
    print(classification_report(y_true, y_pred, target_names=test_set.classes, zero_division=0))
    print("🧮 Confusion Matrix")
    print(confusion_matrix(y_true, y_pred, labels=[0, 1]))
    print("\nClass balance (y_true):", Counter(y_true))
    print("Class balance (y_pred):", Counter(y_pred))
    if wrong_images:
        print("\nMisclassified images:")
        for (path, fired) in wrong_details:
            fired_str = ", ".join(fired) if fired else "none"
            print(f" - {path} | patches predicted 1: [{fired_str}]")
    else:
        print("\nNo misclassified images 🎉")
    return {
        "accuracy": float(np.mean(np.array(y_true) == np.array(y_pred))) if y_true else float("nan"),
        "precision": precision_score(y_true, y_pred, zero_division=0),
        "recall": recall_score(y_true, y_pred, zero_division=0),
        "f1": f1_score(y_true, y_pred, zero_division=0)
    }
if __name__ == "__main__":
    # CLI entry point: evaluate a trained checkpoint on the configured test set.
    parser = argparse.ArgumentParser(description="Test RGB Classifier Model (Quadrant Patching)")
    parser.add_argument("--model_path", required=True, help="Path to the trained model .pth file")
    cli = parser.parse_args()
    test_model(config, model_path=cli.model_path)