Help with A3C PixelRL

Greetings,

I am trying to replicate the A3C PixelRL network, and I have been stuck on an error for a few days. Can anyone help me?

My FCN:

import math

import torch
from torch import nn


class DilatedConvBlock(nn.Module):
    # 3x3 dilated convolution + ReLU; with stride 1, setting padding equal to
    # the dilation keeps the spatial size unchanged.
    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 kernel_size: int = 3,
                 stride: int = 1,
                 padding: int = 1):
        super().__init__()
        self.diconv = nn.Sequential(
            nn.Conv2d(in_channels=in_channels,
                      out_channels=out_channels,
                      kernel_size=kernel_size,
                      stride=stride,
                      padding=padding,
                      dilation=padding),
            nn.ReLU())

    def forward(self, x):
        return self.diconv(x)

class FCN(nn.Module):
    def __init__(self,
                 n_actions: int,
                 input_shape: int = 1,
                 hidden_units: int = 64,
                 output_shape: int = 1):
        super().__init__()
        # Shared feature extractor
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels=input_shape,
                      out_channels=hidden_units,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.ReLU())
        self.diconv2 = DilatedConvBlock(in_channels=hidden_units,
                                        out_channels=hidden_units,
                                        kernel_size=3,
                                        stride=1,
                                        padding=2)
        self.diconv3 = DilatedConvBlock(in_channels=hidden_units,
                                        out_channels=hidden_units,
                                        kernel_size=3,
                                        stride=1,
                                        padding=3)
        self.diconv4 = DilatedConvBlock(in_channels=hidden_units,
                                        out_channels=hidden_units,
                                        kernel_size=3,
                                        stride=1,
                                        padding=4)
        # Policy branch: per-pixel action probabilities
        self.diconv5_pi = DilatedConvBlock(in_channels=hidden_units,
                                           out_channels=hidden_units,
                                           kernel_size=3,
                                           stride=1,
                                           padding=3)
        self.diconv6_pi = DilatedConvBlock(in_channels=hidden_units,
                                           out_channels=hidden_units,
                                           kernel_size=3,
                                           stride=1,
                                           padding=2)
        self.conv7_pi = nn.Sequential(
            nn.Conv2d(in_channels=hidden_units,
                      out_channels=n_actions,
                      kernel_size=3,
                      stride=1,
                      padding=1),
            nn.Softmax(dim=1))
        # Value branch: per-pixel state value
        self.diconv5_v = DilatedConvBlock(in_channels=hidden_units,
                                          out_channels=hidden_units,
                                          kernel_size=3,
                                          stride=1,
                                          padding=3)
        self.diconv6_v = DilatedConvBlock(in_channels=hidden_units,
                                          out_channels=hidden_units,
                                          kernel_size=3,
                                          stride=1,
                                          padding=2)
        self.conv7_v = nn.Conv2d(in_channels=hidden_units,
                                 out_channels=output_shape,
                                 kernel_size=3,
                                 stride=1,
                                 padding=1)

        self.conv1.apply(self.weight_init)
        self.diconv2.apply(self.weight_init)
        self.diconv3.apply(self.weight_init)
        self.diconv4.apply(self.weight_init)
        self.diconv5_pi.apply(self.weight_init)
        self.diconv5_v.apply(self.weight_init)
        self.diconv6_pi.apply(self.weight_init)
        self.diconv6_v.apply(self.weight_init)
        self.conv7_pi.apply(self.weight_init)
        self.conv7_v.apply(self.weight_init)

    def weight_init(self, m):
        # He initialization (fan-out) for conv layers
        classname = m.__class__.__name__
        if classname.find("Conv2d") != -1:
            n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            m.weight.data.normal_(0, math.sqrt(2. / n))
            if m.bias is not None:
                m.bias.data.zero_()
        elif classname.find('Linear') != -1:
            m.weight.data.normal_(0, 0.01)
            m.bias.data = torch.ones(m.bias.data.size())

    def pi_and_v(self, x):
        h = self.conv1(x)
        h = self.diconv2(h)
        h = self.diconv3(h)
        h = self.diconv4(h)
        h_pi = self.diconv5_pi(h)
        h_pi = self.diconv6_pi(h_pi)
        p_out = self.conv7_pi(h_pi)
        h_v = self.diconv5_v(h)
        h_v = self.diconv6_v(h_v)
        v_out = self.conv7_v(h_v)

        return p_out, v_out
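
For reference, a quick shape check of the FCN (the batch size, image size, and number of actions below are arbitrary placeholders, not my real configuration):

import torch

model = FCN(n_actions=9)
x = torch.randn(4, 1, 64, 64)        # (batch, channels, height, width)
p_out, v_out = model.pi_and_v(x)
print(p_out.shape)                   # torch.Size([4, 9, 64, 64]) - per-pixel action probabilities
print(v_out.shape)                   # torch.Size([4, 1, 64, 64]) - per-pixel state value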

Agent (the relevant methods of my A3C agent class; np is numpy and Categorical comes from torch.distributions):

    def update(self, state_var, process_idx=0):
        assert self.t_start < self.t

        pre_train_weights = {}
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                pre_train_weights[name] = param.detach().clone()  # clone, so this snapshot is not updated in place later

        if state_var is None:
            R = torch.zeros(size=(self.batch_size, 1, self.img_size[0], self.img_size[1]), device=self.device)
        else:
            print(f"[{process_idx}] State var update")
            _, vout = self.model.pi_and_v(state_var)
            R = vout.detach().to(self.device)

        pi_loss = 0
        v_loss = 0
        for i in reversed(range(self.t_start, self.t)):
            R = R * self.gamma
            past_reward = self.past_rewards[i]
            reward = np.zeros(R.shape, dtype=np.float32)  # keep the dtype consistent with R (float32)
            for b in range(self.batch_size):
                reward[b,0] += past_reward[b]
            R = R + torch.from_numpy(reward).to(self.device)
            del reward
            if self.use_average_reward:
                R = R - self.average_reward
            v = self.past_values[i]
            advantage = R - v.detach()
            if self.use_average_reward:
                # advantage is a per-pixel map here, so reduce it to a scalar (mean) before updating the running average
                self.average_reward += self.average_reward_tau * float(advantage.detach().mean())

            # Accumulate gradients of policy
            log_prob = self.past_action_log_prob[i]
            entropy = self.past_action_entropy[i]

            # Log probability is increased proportionally to the advantage
            pi_loss = pi_loss - (log_prob * advantage.detach())

            # Entropy is maximized
            pi_loss = pi_loss - (self.beta * entropy)

            # Accumulate gradients of value function
            v_loss = v_loss + ((v - R) ** 2 / 2)

        if self.pi_loss_coef != 1.0:
            pi_loss = pi_loss * self.pi_loss_coef

        if self.v_loss_coef != 1.0:
            v_loss = v_loss * self.v_loss_coef

        # Normalize the loss of sequences truncated by terminal states
        if self.keep_loss_scale_same and self.t - self.t_start < self.t_max:
            factor = self.t_max / (self.t - self.t_start)
            pi_loss = pi_loss * factor
            v_loss = v_loss * factor

        if self.normalize_grad_by_t_max:
            pi_loss = pi_loss/(self.t - self.t_start)
            v_loss = v_loss/(self.t - self.t_start)

        # if process_idx == 0:
        #   print(f"\npi_loss:\n{pi_loss}\n\nv_loss:\n{v_loss}")

        total_loss = torch.mean(pi_loss + v_loss)
        #self.loss_tracking.append(total_loss)
        print(f"[{process_idx}] Loss: {total_loss}")

        # Compute gradients using thread-specific model
        self.optimizer.zero_grad()
        total_loss.backward()

        ensure_shared_grads(self.model, self.shared_model)

        self.optimizer.step()
        post_train_weights = {}
        for name, param in self.model.named_parameters():
            if 'weight' in name:
                post_train_weights[name] = param.detach().clone()  # clone, for the same reason as above

        # self.shared_model.load_state_dict(self.shared_model.state_dict())

        # print(f"Shared model params atualizado? {pre_train_weights != post_train_weights}")

        # if process_idx == 0:
        print(f'[{process_idx}] Update')

        self.clear_memory()

        self.t_start = self.t

    def act_and_train(self, state_var, reward, process_idx=0):
        # print(f"[{process_idx}] Act and train\nt: {self.t} | t_start: {self.t_start} | t_max: {self.t_max}")
        self.model.load_state_dict(self.shared_model.state_dict())

        self.past_rewards[self.t-1] = reward

        if self.t - self.t_start == self.t_max:
            self.update(state_var, process_idx=process_idx)

        self.past_states[self.t] = state_var

        pout, vout = self.model.pi_and_v(state_var)

        p_trans = pout.permute([0, 2, 3, 1])
        dist = Categorical(p_trans)
        action = dist.sample()

        self.past_action_log_prob[self.t] = dist.log_prob(action).unsqueeze(dim=1).to(self.device)
        self.past_action_entropy[self.t] = dist.entropy().unsqueeze(dim=1).to(self.device)
        self.past_values[self.t] = vout

        self.t += 1

        return action.detach().cpu().numpy()

    def stop_episode_and_train(self, state_var, reward, done=False, process_idx=0):
        print(f'[{process_idx}] Stop and Train')

        self.past_rewards[self.t-1] = reward
        if done:
            self.update(None, process_idx=process_idx)
        else:
            self.update(state_var, process_idx=process_idx)
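
For completeness, ensure_shared_grads is meant to copy the worker's gradients into the shared model before the optimizer step; a minimal sketch along the lines of the common pytorch-a3c helper:

def ensure_shared_grads(model, shared_model):
    # Copy each local gradient into the corresponding shared parameter,
    # but only once per update (stop as soon as the shared grads are already set).
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad

The early return means the gradients are transferred only while the shared model's grads are still unset, which matches the usual pytorch-a3c pattern.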