I am a beginner in reinforcement learning (RL). I am trying to export a Stable-Baselines3 model (trained on CarRacing-v0 from the Gym module, using `CnnPolicy`), and while doing so I am running into this issue.
Here is the code:
import numpy as np
import torch as th
from torch import nn as nn
import torch.nn.functional as F
from torch import tensor
from stable_baselines3.common.vec_env import VecTransposeImage
class Net(nn.Module):
    """Plain-PyTorch re-creation of a CNN policy, for loading exported weights.

    The convolutional stack maps a 3x96x96 image to 64x8x8 = 4096 features.
    That is the only input size the ``Linear(4096, 512)`` layer accepts.
    The features then go through a 512-unit layer and a 3-output action head.

    NOTE(review): if a Stable-Baselines3 policy state dict is loaded directly,
    key names may not match. ``nn.Sequential`` registers e.g.
    ``action_net.0.weight`` rather than ``action_net.weight``. Remap keys if
    ``load_state_dict`` reports missing/unexpected keys.
    """

    def __init__(self):
        super().__init__()
        self.features_extractor = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=8, stride=4),
            nn.ReLU(),
            nn.Conv2d(in_channels=32, out_channels=64, kernel_size=4, stride=2),
            nn.ReLU(),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1),
            nn.ReLU(),
            nn.Flatten(start_dim=1, end_dim=-1),
            nn.Linear(in_features=4096, out_features=512, bias=True),
            nn.ReLU(),
        )
        # NOTE(review): CarRacing-v0 has a continuous 3-dim action space. For such
        # spaces SB3's deterministic action is presumably the raw action-head
        # output (no trailing ReLU, no argmax). Confirm against the trained policy.
        self.action_net = nn.Sequential(
            nn.Linear(in_features=512, out_features=3, bias=True),
            nn.ReLU(),
        )

    def forward(self, x):
        """Return the argmax of the action head for an image or batch of images.

        Args:
            x: float tensor shaped (C, H, W) for a single image, or
                (N, C, H, W) for a batch, with C=3 and H=W=96.

        Returns:
            A 0-d tensor holding ``argmax()`` over the whole output. For a single
            image (or batch of one) this is the index of the largest of the 3
            outputs.
        """
        # Flatten(start_dim=1) and Linear(4096) assume a leading batch axis, so
        # promote a single (C, H, W) image to a batch of one.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.features_extractor(x)
        x = self.action_net(x)
        # argmax() without `dim` flattens the whole tensor; with batch size > 1
        # the result indexes into the flattened (N * 3) output.
        return x.argmax()
def getMove(obs, weights=None):
    """Build ``Net``, load weights, and return its output for one observation.

    Args:
        obs: a NumPy image observation in (height, width, channel) layout, or a
            stacked batch of them. Pixel values are assumed to be 0-255, since
            they are divided by 255 below (TODO confirm with the env wrapper).
        weights: optional state dict to load into ``Net``. When omitted, the
            module-level ``state_dict`` (defined elsewhere in the file) is used,
            as in the original version.

    Returns:
        The tensor produced by ``Net.forward`` on the preprocessed observation.

    Note:
        The model is rebuilt and the weights reloaded on every call. Hoist that
        work out if this is called once per environment step.
    """
    model = Net().float()
    # `state_dict` is only looked up when no explicit weights are given.
    model.load_state_dict(state_dict if weights is None else weights)
    model = model.to("cpu").eval()

    # HWC -> CHW (or NHWC -> NCHW for a batch), matching the layout the
    # convolutions were trained on.
    obs = VecTransposeImage.transpose_image(obs.copy())
    obs = th.as_tensor(obs).to("cpu").float() / 255
    # Conv2d / Flatten(start_dim=1) need a leading batch axis; a single frame
    # comes out of the transpose as (C, H, W).
    if obs.dim() == 3:
        obs = obs.unsqueeze(0)

    # Pure inference: no autograd graph is needed.
    with th.no_grad():
        action = model(obs)
    return action
How can I fix it?