Understanding how forward works

Hi, I’m very new in torch and trying to figure out things.
I created a model (resnet 18) and tested it in RLlib as follows:

import torch
from torch import nn
class resblock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample):
        super().__init__()
        if downsample:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_channels)
            )
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.shortcut = nn.Sequential()

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        shortcut = self.shortcut(input)
        input = nn.ReLU()(self.bn1(self.conv1(input)))
        input = nn.ReLU()(self.bn2(self.conv2(input)))
        input = input + shortcut
        return nn.ReLU()(input)
class ResNet18( TorchModelV2,nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name,resblock,in_channels):
    # def __init__(self, in_channels, resblock, outputs=11):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                      model_config, name)
        nn.Module.__init__(self)

        print(input)
        # ( h, in_channels) = obs_space.shape
        self.in_channels = in_channels
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels ,64 , kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer0x0 = nn.Sequential(
            nn.Conv2d(1 ,64 , kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU()
        )
        self.layer1 = nn.Sequential(
            resblock(64, 64, downsample=False),
            resblock(64, 64, downsample=False)
        )

        self.layer2 = nn.Sequential(
            resblock(64, 128, downsample=True),
            resblock(128, 128, downsample=False)
        )

        self.layer3 = nn.Sequential(
            resblock(128, 256, downsample=True),
            resblock(256, 256, downsample=False)
        )


        self.layer4 = nn.Sequential(
            resblock(256, 512, downsample=True),
            resblock(512, 512, downsample=False)
        )

        self.gap = torch.nn.AdaptiveAvgPool2d(1)
        self.fc = torch.nn.Linear(512, num_outputs)
        self.fcv = torch.nn.Linear(512, 1)
    def forward(
        self,
        input ,
        state,
        seq_lens):
        print((input["obs"]))
        X = torch.unsqueeze(input["obs"].float(), 0)
        input = X.permute(1,0,2,3)
        # print('input',input)
        input = self.layer0(input)
      
        
        # print(type(1))
        # # print(input["obs"].float().size(dim=2))
        # input = torch.unsqueeze(input["obs"].float(), 0)
        # if input["obs"].size(dim=0) == 32:
        #   input = self.layer0(input)
        # else:
        #   input = self.layer0x0(input)
        input = self.layer1(input)
        input = self.layer2(input)
        input = self.layer3(input)
        input = self.layer4(input)
        input = self.gap(input)
        # print(input.size())
        input_last = torch.flatten(input)
        # print(input_last.size())
        output = self.fc(input_last)
        self.value_f = self.fcv(input_last)
        output = torch.unsqueeze(output,0)
        # print(output)

        return output,state

    def value_function(self):
        assert self.value_f is not None, "must call forward() first"
        return self.value_f   

my input is (800,800) and has 1 channel (greyscaled) but in backpropagation time my input in forward method shows (20,800,800) instead of (1,800,800). 20 is my minibatch sgd size but I don’t know why it should use forward. is there any way to handle this? sorry if my question is dumb I’m very new.

2D Conv layer accepts a tensor of dimension B x C x H x W. (B = batch size, C = number of channels, H = height, W = width).
The dataloader (torch.utils.data.DataLoader) collects data samples of dimension C x H x W from the dataset (torch.utils.data.Dataset) and appends a batch dimension (B). In your case, check the dimension of the data instances returned from the dataset. It should be 1 x 800 x 800.

Or if you want a simple solution, do input = input.unsqueeze(dim=1) to add a singleton dimension appropriately and get (20,1,800,800) dimensional tensor.

1 Like