Hi, I’m very new to PyTorch and trying to figure things out.
I created a model (ResNet-18) and tested it in RLlib as follows:
import torch
from torch import nn

from ray.rllib.models.torch.torch_model_v2 import TorchModelV2


class resblock(nn.Module):
    def __init__(self, in_channels, out_channels, downsample):
        super().__init__()
        if downsample:
            # halve the spatial size and project the shortcut to match
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=2),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
            self.shortcut = nn.Sequential()  # identity
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, input):
        shortcut = self.shortcut(input)
        input = nn.ReLU()(self.bn1(self.conv1(input)))
        input = nn.ReLU()(self.bn2(self.conv2(input)))
        input = input + shortcut
        return nn.ReLU()(input)
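(A quick shape check of one block in isolation, with example sizes, in case it helps:)

x = torch.randn(2, 64, 56, 56)
block = resblock(64, 128, downsample=True)
print(block(x).shape)  # torch.Size([2, 128, 28, 28]): spatial size halved, channels doubled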
class ResNet18(TorchModelV2, nn.Module):
    def __init__(self, obs_space, action_space, num_outputs, model_config,
                 name, resblock, in_channels):
        TorchModelV2.__init__(self, obs_space, action_space, num_outputs,
                              model_config, name)
        nn.Module.__init__(self)
        # (h, in_channels) = obs_space.shape
        self.in_channels = in_channels
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        # same stem hard-coded to 1 input channel (used only in a commented-out branch below)
        self.layer0x0 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
        )
        self.layer1 = nn.Sequential(
            resblock(64, 64, downsample=False),
            resblock(64, 64, downsample=False),
        )
        self.layer2 = nn.Sequential(
            resblock(64, 128, downsample=True),
            resblock(128, 128, downsample=False),
        )
        self.layer3 = nn.Sequential(
            resblock(128, 256, downsample=True),
            resblock(256, 256, downsample=False),
        )
        self.layer4 = nn.Sequential(
            resblock(256, 512, downsample=True),
            resblock(512, 512, downsample=False),
        )
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(512, num_outputs)  # policy logits head
        self.fcv = nn.Linear(512, 1)           # value head
    def forward(self, input, state, seq_lens):
        # obs arrives as (B, 800, 800); unsqueeze + permute turns it into
        # (B, 1, 800, 800), i.e. batch first, one channel
        X = torch.unsqueeze(input["obs"].float(), 0)
        input = X.permute(1, 0, 2, 3)
        input = self.layer0(input)
        # what I tried for the batched case:
        # if input["obs"].size(dim=0) == 32:
        #     input = self.layer0(input)
        # else:
        #     input = self.layer0x0(input)
        input = self.layer1(input)
        input = self.layer2(input)
        input = self.layer3(input)
        input = self.layer4(input)
        input = self.gap(input)  # (B, 512, 1, 1)
        # flattens everything, including the batch dim, so the heads below
        # only see the expected shape when B == 1
        input_last = torch.flatten(input)
        output = self.fc(input_last)
        self.value_f = self.fcv(input_last)
        output = torch.unsqueeze(output, 0)
        return output, state

    def value_function(self):
        assert self.value_f is not None, "must call forward() first"
        return self.value_f
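For context, this is roughly how I wire the model into RLlib (a sketch: the env name and config values here are placeholders, not my exact setup):

from ray.rllib.models import ModelCatalog

# register the class under a name RLlib can look up
ModelCatalog.register_custom_model("resnet18", ResNet18)

config = {
    "framework": "torch",
    "env": "my_env",  # placeholder
    "model": {
        "custom_model": "resnet18",
        # extra kwargs forwarded to ResNet18.__init__
        "custom_model_config": {"resblock": resblock, "in_channels": 1},
    },
    "sgd_minibatch_size": 20,
}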
My input is (800, 800) with a single channel (grayscale), but at backpropagation time the input reaching my forward method has shape (20, 800, 800) instead of (1, 800, 800). 20 is my SGD minibatch size, but I don’t understand why it ends up in forward(). Is there any way to handle this? Sorry if my question is dumb; I’m very new.
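In case it clarifies what I’m asking, here is a sketch of the kind of batch-agnostic forward() I think I need (assuming the obs always arrives as (B, 800, 800) with the batch dimension first; I haven’t verified this):

def forward(self, input_dict, state, seq_lens):
    obs = input_dict["obs"].float()         # (B, 800, 800): B is 1 at rollout time, 20 during SGD
    x = obs.unsqueeze(1)                    # (B, 1, 800, 800): add the channel dim, keep the batch
    x = self.layer0(x)
    x = self.layer1(x)
    x = self.layer2(x)
    x = self.layer3(x)
    x = self.layer4(x)
    x = self.gap(x)                         # (B, 512, 1, 1)
    x = torch.flatten(x, start_dim=1)       # (B, 512): flatten per sample, keep the batch dim
    self.value_f = self.fcv(x).squeeze(1)   # (B,): one value per sample, as value_function expects
    return self.fc(x), state                # (B, num_outputs) logits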