Swap channel dimension with batch size

I’m developing an RL-based AI to play Connect 4. A plain fully-connected network did not work very well, so I want to try a CNN instead. To do so I have a CNN class and a QTrainer class as follows:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = torch.nn.Sequential()
        # (N, 1, 6, 7) -> (N, 16, 5, 6)
        self.model.add_module(
            "conv_1",
            torch.nn.Conv2d(in_channels=1, out_channels=16, kernel_size=2, stride=1),
        )
        self.model.add_module("relu_1", torch.nn.ReLU())
        # (N, 16, 5, 6) -> (N, 16, 2, 3)
        self.model.add_module("max_pool", torch.nn.MaxPool2d(2))
        # (N, 16, 2, 3) -> (N, 16, 1, 2)
        self.model.add_module(
            "conv_2",
            torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=2, stride=1),
        )
        self.model.add_module("relu_2", torch.nn.ReLU())
        # (N, 16, 1, 2) -> (N, 32)
        self.model.add_module("flatten", torch.nn.Flatten())

        # one Q-value per Connect 4 column
        self.model.add_module("linear", torch.nn.Linear(in_features=32, out_features=7))

    def forward(self, x):
        # add a batch dimension if a single (1, 6, 7) board is passed in
        if len(x.shape) == 3:
            x = x.unsqueeze(0)
        x = self.model(x)
        return x

class QTrainer:
    def __init__(self, model, lr, gamma):
        self.lr = lr
        self.gamma = gamma
        self.model = model
        self.optimizer = optim.Adam(model.parameters(), lr=self.lr)
        self.criterion = nn.MSELoss()

    def train_step(self, state, action, reward, next_state, done):
        state = torch.tensor(np.array(state), dtype=torch.float).unsqueeze(0)
        next_state = torch.tensor(np.array(next_state), dtype=torch.float).unsqueeze(0)
        action = torch.tensor(action, dtype=torch.float)
        reward = torch.tensor(reward, dtype=torch.float)
        # batch of samples: (n, ...)
        if type(done) != tuple:
            # single sample: add a batch dimension -> (1, ...)
            state = torch.unsqueeze(state, 0)
            next_state = torch.unsqueeze(next_state, 0)
            action = torch.unsqueeze(action, 0)
            reward = torch.unsqueeze(reward, 0)
            done = (done, )

        # 1: predicted Q values with the current state
        print(state.shape)  # <--- the print I'm talking about
        print(state)
        pred = self.model(state)

        # 2: target = r + gamma * max(Q(next_state)) for the action taken (only if not done)
        target = pred.clone()
        for idx in range(len(done)):
            Q_new = reward[idx]
            if not done[idx]:
                Q_new = reward[idx] + self.gamma * torch.max(self.model(next_state[idx]))
            target[idx][torch.argmax(action[idx]).item()] = Q_new

        self.optimizer.zero_grad()
        loss = self.criterion(target, pred)
        loss.backward()
        self.optimizer.step()

QTrainer.train_step is called during both short-memory and long-memory training. In short-memory training, state is a 6x7 NumPy array that is converted to a tensor and then unsqueezed (in this scenario it actually gets unsqueezed twice, the second time inside if type(done) != tuple:), so the final shape of state during short-memory training is torch.Size([1, 1, 6, 7]), and I pass it to the model with pred = self.model(state). The algorithm works well.
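For reference, this is roughly what the short-memory shape handling boils down to (a standalone trace on a dummy board, matching the prints above):

import numpy as np
import torch

board = np.zeros((6, 7))                                            # single Connect 4 board
t = torch.tensor(np.array(board), dtype=torch.float).unsqueeze(0)   # torch.Size([1, 6, 7])
t = torch.unsqueeze(t, 0)                                           # torch.Size([1, 1, 6, 7]) because done is not a tuple
print(t.shape)                                                      # batch=1, channels=1, height=6, width=7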

In long-memory training my state is a tuple of 6x7 matrices as before, but the number of matrices varies (it depends on how many moves the game lasted). Therefore the state that I pass to pred = self.model(state) has shape torch.Size([1, 4, 6, 7]), and that second dimension (4) varies with the number of moves in a game. Because of that, when I try to pass the state to pred = self.model(state), I get an error:

RuntimeError: Given groups=1, weight of size [16, 1, 2, 2], expected input[1, 4, 6, 7] to have 1 channels, but got 4 channels instead

Shouldn’t state.shape for long memory be, for example, [4, 1, 6, 7] instead of [1, 4, 6, 7]? I mean, isn’t that 4 the BATCH dimension rather than the channel dimension? If so, is it possible to “swap” the batch size with the channel size, i.e. [1, 4, 6, 7] → [4, 1, 6, 7]? Or maybe I’m packing my states incorrectly?
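To make clear what I mean by the swap, this is roughly what I have in mind (an illustration on a dummy tensor, not code I currently use in the trainer):

import torch

state = torch.zeros(1, 4, 6, 7)      # what I currently end up with in long memory
swapped = state.permute(1, 0, 2, 3)  # torch.Size([4, 1, 6, 7]): 4 boards in the batch, 1 channel each
print(swapped.shape)

And this is how I currently build the long-memory batch before calling train_step: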

def train_long_memory(self):
    if len(self.memory) > BATCH_SIZE:
        mini_sample = random.sample(self.memory, BATCH_SIZE)  # list of tuples
    else:
        mini_sample = self.memory
    states, actions, rewards, next_states, dones = zip(*mini_sample)
    self.trainer.train_step(states, actions, rewards, next_states, dones)
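For completeness, this is the kind of packing I suspect I might need instead, i.e. stacking the boards along the batch dimension and keeping a single channel (just a guess on a dummy batch, not what train_step currently does):

import numpy as np
import torch

states = [np.zeros((6, 7)) for _ in range(4)]               # e.g. 4 boards from one game
batch = torch.tensor(np.array(states), dtype=torch.float)   # torch.Size([4, 6, 7])
batch = batch.unsqueeze(1)                                   # torch.Size([4, 1, 6, 7]): batch=4, channels=1
print(batch.shape)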