Hi,
I’m trying to train an image classification model (ResNet152). I compiled and executed the code below, but it throws an error. Any guidance will be appreciated.
python version : 3.8.18
pytorch version : 2.1.1
here is the code,
import os, random, torch, torchvision, warnings
random.seed(32)
from tqdm.auto import tqdm
from typing import Dict, List, Tuple
from PIL import Image
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from typing import Tuple, Dict, List
from torchvision import models
from torchinfo import summary
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
warnings.filterwarnings("ignore")
class ImageFolderCustom(Dataset):
    """Dataset backed by a text file of image stems and a text file of class names.

    Each non-blank line of the list file is a relative image stem of the form
    ``<class_name>/<image_id>``; the image path is built as
    ``<image_root>/<stem>.jpg``. The label of a sample is the index of its
    parent directory name within the class file (one class name per line).
    """

    def __init__(self, image_dir: str, class_file: str, transform=None,
                 image_root: str = '/home/food-101/images/') -> None:
        """Read the sample list and the class list.

        Args:
            image_dir: Path to the text file listing one image stem per line
                (despite the name, this is a file, not a directory).
            class_file: Path to the text file listing one class name per line.
            transform: Optional callable applied to the loaded PIL image.
                When ``None``, ``transforms.ToTensor()`` is used instead.
            image_root: Directory that the listed stems are relative to.
        """
        super().__init__()
        with open(image_dir, encoding='utf-8') as img_file:
            # Skip blank lines so they do not turn into a bogus "<root>/.jpg" sample.
            stems = [line.strip() for line in img_file if line.strip()]
        self.image_filename = [os.path.join(image_root, stem + '.jpg') for stem in stems]
        self.transform = transform
        with open(class_file, encoding='utf-8') as file:
            # Skip blank lines so they do not become an empty-named class.
            classes = [line.strip() for line in file if line.strip()]
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)}

    def load_image(self, index: int) -> "Image.Image":
        """Open the image at ``index`` and convert it to RGB."""
        image_path = self.image_filename[index]
        return Image.open(image_path).convert('RGB')

    def __len__(self) -> int:
        """Return the number of listed samples."""
        return len(self.image_filename)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, int]:
        """Return ``(transformed_image, class_index)`` for the sample at ``index``.

        Raises:
            KeyError: If the image's parent directory name is not in the class file.
        """
        img = self.load_image(index)
        # The class name is the directory that directly contains the image.
        class_name = os.path.basename(os.path.dirname(self.image_filename[index]))
        class_idx = self.class_to_idx[class_name]
        if self.transform is not None:
            return self.transform(img), class_idx
        return transforms.ToTensor()(img), class_idx
# Preprocessing only (no augmentation is applied): resize to 64x64 and convert to a tensor.
train_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])
test_transforms = transforms.Compose([
    transforms.Resize((64, 64)),
    transforms.ToTensor()
])
# NOTE(review): the backbone below loads ImageNet weights; those are normally paired with
# ImageNet mean/std normalization, which is not applied here — confirm whether that is intended.

# Text files listing the train/test image stems and the class names.
train_dir = '/home/food-101/meta/train.txt'
test_dir = '/home/food-101/meta/test.txt'
class_file = '/home/food-101/meta/classes.txt'
train_data_custom = ImageFolderCustom(image_dir=train_dir, transform=train_transforms,class_file=class_file)
test_data_custom = ImageFolderCustom(image_dir=test_dir, class_file=class_file,transform=test_transforms)

# NOTE(review): the traceback posted with this script comes from a torch.compile'd model and its
# failing assertion involves a symbolic batch dimension (s0). Keeping the batch size constant
# (e.g. drop_last=True on the training loader) may avoid that path — TODO confirm.
train_dataloader = DataLoader(train_data_custom, batch_size=256,pin_memory=True, num_workers=10, shuffle=True)
test_dataloader = DataLoader(test_data_custom, batch_size=256, num_workers=10,pin_memory=True, shuffle=False)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Labels are indices into the class list, so the head must emit one logit per listed class;
# CrossEntropyLoss rejects targets outside [0, num_classes).
num_classes = len(train_data_custom.class_to_idx)

# `weights=` replaces the deprecated `pretrained=True`; IMAGENET1K_V1 is the checkpoint
# that `pretrained=True` used to load.
res_model = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
# Read the feature width from the existing head instead of hard-coding it.
res_model.fc = nn.Sequential(nn.Linear(res_model.fc.in_features, 128),
               nn.ReLU(inplace=True),
               nn.Linear(128, num_classes))
res_model = res_model.to(device)
optimizer = optim.Adam(res_model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
# Multiplies the learning rate by 0.7 on every scheduler.step() call.
scheduler = StepLR(optimizer, step_size=1, gamma=0.7)
def train_step(model: torch.nn.Module, 
               dataloader: torch.utils.data.DataLoader, 
               loss_fn: torch.nn.Module, 
               optimizer: torch.optim.Optimizer,
               device: torch.device) -> Tuple[float, float]:
    """Run one training epoch and return ``(mean_loss, accuracy)``.

    Both metrics are averaged over samples rather than over batches, so a
    smaller final batch does not get the same weight as a full one.

    Args:
        model: Network to optimize; switched to training mode.
        dataloader: Iterable yielding ``(inputs, integer_targets)`` batches.
        loss_fn: Loss taking ``(logits, targets)``. The per-sample weighting
            assumes it returns the batch mean (the default reduction).
        optimizer: Optimizer updating ``model``'s parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple of mean loss per sample and fraction of correctly classified samples.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    loss_sum, num_correct, num_seen = 0.0, 0, 0
    for X, y in dataloader:
        X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
        optimizer.zero_grad(set_to_none=True)
        y_pred = model(X)
        loss = loss_fn(y_pred, y)
        loss.backward()
        optimizer.step()

        batch_size = y.size(0)
        # Undo the per-batch mean so batches of different sizes are weighted correctly.
        loss_sum += loss.item() * batch_size
        # argmax over logits equals argmax over softmax, so softmax is not needed here.
        num_correct += (y_pred.argmax(dim=1) == y).sum().item()
        num_seen += batch_size
    if num_seen == 0:
        raise ValueError("dataloader yielded no samples")
    return loss_sum / num_seen, num_correct / num_seen
def test_step(model: torch.nn.Module, 
              dataloader: torch.utils.data.DataLoader, 
              loss_fn: torch.nn.Module,
              device: torch.device) -> Tuple[float, float]:
    model.eval() 
    test_loss, test_acc = 0, 0
    with torch.inference_mode():
        for batch, (X, y) in enumerate(dataloader):
            X, y = X.to(device,non_blocking=True), y.to(device,non_blocking=True)
            test_pred_logits = model(X)
            loss = loss_fn(test_pred_logits, y)
            test_loss += loss.item()
            test_pred_labels = test_pred_logits.argmax(dim=1)
            test_acc += ((test_pred_labels == y).sum().item()/len(test_pred_labels))
    test_loss = test_loss / len(dataloader)
    test_acc = test_acc / len(dataloader)
    return test_loss, test_acc
def train(model: torch.nn.Module, 
          train_dataloader: torch.utils.data.DataLoader, 
          test_dataloader: torch.utils.data.DataLoader, 
          optimizer: torch.optim.Optimizer,
          loss_fn: torch.nn.Module,
          epochs: int,
          device: torch.device,
          scheduler=None) -> Dict[str, List]:
    """Train and evaluate ``model`` for ``epochs`` epochs.

    Args:
        model: Network to train; moved to ``device`` before the first epoch.
        train_dataloader: Batches used by ``train_step``.
        test_dataloader: Batches used by ``test_step``.
        optimizer: Optimizer updating ``model``'s parameters.
        loss_fn: Loss taking ``(logits, targets)``.
        epochs: Number of passes over the training data.
        device: Device used for the model and the batches.
        scheduler: Optional learning-rate scheduler; when given, its
            ``step()`` is called once at the end of every epoch.

    Returns:
        Dict with the keys ``train_loss``, ``train_acc``, ``test_loss`` and
        ``test_acc``, each mapping to a list with one value per epoch.
    """
    results = {"train_loss": [],
               "train_acc": [],
               "test_loss": [],
               "test_acc": []
    }
    model.to(device)
    for epoch in tqdm(range(epochs)):
        train_loss, train_acc = train_step(model=model,
                                          dataloader=train_dataloader,
                                          loss_fn=loss_fn,
                                          optimizer=optimizer,
                                          device=device)
        test_loss, test_acc = test_step(model=model,
          dataloader=test_dataloader,
          loss_fn=loss_fn,
          device=device)
        # Advance the LR schedule only after this epoch's optimizer steps have run.
        if scheduler is not None:
            scheduler.step()
        print(
          f"Epoch: {epoch+1} | "
          f"train_loss: {train_loss:.4f} | "
          f"train_acc: {train_acc:.4f} | "
          f"test_loss: {test_loss:.4f} | "
          f"test_acc: {test_acc:.4f}"
        )
        results["train_loss"].append(train_loss)
        results["train_acc"].append(train_acc)
        results["test_loss"].append(test_loss)
        results["test_acc"].append(test_acc)
    return results
# Pass the StepLR scheduler so the learning-rate decay configured above actually takes effect.
train(model=res_model,train_dataloader= train_dataloader, test_dataloader= test_dataloader, optimizer= optimizer,loss_fn= loss_fn,epochs= 10, device=device, scheduler=scheduler)
error message:
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[9], line 1
----> 1 out_result = train(model=compiled_model,train_dataloader= train_dataloader, test_dataloader= test_dataloader, optimizer= optimizer,loss_fn= loss_fn,epochs= 10, device=device)
Cell In[8], line 102, in train(model, train_dataloader, test_dataloader, optimizer, loss_fn, epochs, device)
    100 # Loop through training and testing steps for a number of epochs
    101 for epoch in tqdm(range(epochs)):
--> 102     train_loss, train_acc = train_step(model=model,
    103                                       dataloader=train_dataloader,
    104                                       loss_fn=loss_fn,
    105                                       optimizer=optimizer,
    106                                       device=device)
    107     test_loss, test_acc = test_step(model=model,
    108       dataloader=test_dataloader,
    109       loss_fn=loss_fn,
    110       device=device)
    112     # Print out what's happening
Cell In[8], line 35, in train_step(model, dataloader, loss_fn, optimizer, device)
     28 train_loss += loss.item() 
     30 # 3. Optimizer zero grad
     31 #optimizer.zero_grad(set_to_none=True)
     32 #optimizer.zero_grad()
     33 
     34 # 4. Loss backward
---> 35 loss.backward()
     37 # 5. Optimizer step
     38 optimizer.step()
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_tensor.py:492, in Tensor.backward(self, gradient, retain_graph, create_graph, inputs)
    482 if has_torch_function_unary(self):
    483     return handle_torch_function(
    484         Tensor.backward,
    485         (self,),
   (...)
    490         inputs=inputs,
    491     )
--> 492 torch.autograd.backward(
    493     self, gradient, retain_graph, create_graph, inputs=inputs
    494 )
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/autograd/__init__.py:251, in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
    246     retain_graph = create_graph
    248 # The reason we repeat the same comment below is that
    249 # some Python versions print out the first line of a multi-line function
    250 # calls in the traceback and some print out the last line
--> 251 Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    252     tensors,
    253     grad_tensors_,
    254     retain_graph,
    255     create_graph,
    256     inputs,
    257     allow_unreachable=True,
    258     accumulate_grad=True,
    259 )
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/autograd/function.py:288, in BackwardCFunction.apply(self, *args)
    282     raise RuntimeError(
    283         "Implementing both 'backward' and 'vjp' for a custom "
    284         "Function is not allowed. You should only implement one "
    285         "of them."
    286     )
    287 user_fn = vjp_fn if vjp_fn is not Function.vjp else backward_fn
--> 288 return user_fn(self, *args)
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:3232, in aot_dispatch_autograd.<locals>.CompiledFunction.backward(ctx, *flat_args)
   3230     out = CompiledFunctionBackward.apply(*all_args)
   3231 else:
-> 3232     out = call_compiled_backward()
   3233 return out
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:3204, in aot_dispatch_autograd.<locals>.CompiledFunction.backward.<locals>.call_compiled_backward()
   3199     with tracing(saved_context), context(), track_graph_compiling(aot_config, "backward"):
   3200         CompiledFunction.compiled_bw = aot_config.bw_compiler(
   3201             bw_module, placeholder_list
   3202         )
-> 3204 out = call_func_with_args(
   3205     CompiledFunction.compiled_bw,
   3206     all_args,
   3207     steal_args=True,
   3208     disable_amp=disable_amp,
   3209 )
   3211 out = functionalized_rng_runtime_epilogue(CompiledFunction.metadata, out)
   3212 return tuple(out)
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py:1506, in call_func_with_args(f, args, steal_args, disable_amp)
   1504 with context():
   1505     if hasattr(f, "_boxed_call"):
-> 1506         out = normalize_as_list(f(args))
   1507     else:
   1508         # TODO: Please remove soon
   1509         # https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670
   1510         warnings.warn(
   1511             "Your compiler for AOTAutograd is returning a function that doesn't take boxed arguments. "
   1512             "Please wrap it with functorch.compile.make_boxed_func or handle the boxed arguments yourself. "
   1513             "See https://github.com/pytorch/pytorch/pull/83137#issuecomment-1211320670 for rationale."
   1514         )
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py:328, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    326 dynamic_ctx.__enter__()
    327 try:
--> 328     return fn(*args, **kwargs)
    329 finally:
    330     set_eval_frame(prior)
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_dynamo/external_utils.py:17, in wrap_inline.<locals>.inner(*args, **kwargs)
     15 @functools.wraps(fn)
     16 def inner(*args, **kwargs):
---> 17     return fn(*args, **kwargs)
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_inductor/codecache.py:374, in CompiledFxGraph.__call__(self, inputs)
    373 def __call__(self, inputs) -> Any:
--> 374     return self.get_current_callable()(inputs)
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_inductor/compile_fx.py:628, in align_inputs_from_check_idxs.<locals>.run(new_inputs)
    626 def run(new_inputs):
    627     copy_misaligned_inputs(new_inputs, inputs_to_check)
--> 628     return model(new_inputs)
File ~/miniconda3/envs/pytorch2/lib/python3.8/site-packages/torch/_inductor/codecache.py:401, in _run_from_cache(compiled_graph, inputs)
    391     from .codecache import PyCodeCache
    393     compiled_graph.compiled_artifact = PyCodeCache.load_by_key_path(
    394         compiled_graph.cache_key,
    395         compiled_graph.artifact_path,
   (...)
    398         else (),
    399     ).call
--> 401 return compiled_graph.compiled_artifact(inputs)
File /tmp/torchinductor_iamalien/l2/cl27m74cyvyw6agpvll46v7gax6igmymkivu3ogafbumw3mgdrm3.py:5643, in call(args)
   5641 del primals_4
   5642 buf472 = buf471[0]
-> 5643 assert_size_stride(buf472, (s0, 64, 16, 16), (16384, 1, 1024, 64))
   5644 buf473 = buf471[1]
   5645 assert_size_stride(buf473, (64, 64, 1, 1), (64, 1, 64, 64))
AssertionError: expected size 64==64, stride 256==1 at dim=1
@ptrblck can you please guide me here?