Custom layer output reinitialization, is it a solution?

I have this custom layer:

import torch
import torch.nn as nn

class tiling_layer(nn.Module):
    def __init__(self, x_dim, tile_dim):
        super(tiling_layer, self).__init__()
        self.b, d, self.h, self.w = x_dim
        self.tile_dim = tile_dim
        self.out_d = d // tile_dim**2
        out_h = self.h * tile_dim
        out_w = self.w * tile_dim
        # persistent output buffer, overwritten on every forward pass
        self.tiled_out = torch.zeros(self.b, self.out_d, out_h, out_w, dtype=torch.float)
    def forward(self, x):
        tile_clone = self.tiled_out.clone()  # I temporarily had to .detach() here to avoid the in-place operation error
        for ds in range(self.out_d):
            d_start = ds*(self.tile_dim**2)
            d_end = (ds+1)*(self.tile_dim**2)
            for hs in range(self.h):
                for ws in range(self.w):
                    tile_select = x[:, d_start:d_end, hs, ws]
                    out_tile = tile_select.view(self.b, self.tile_dim, self.tile_dim)
                    h_start = hs*self.tile_dim
                    h_end = (1+hs)*self.tile_dim
                    w_start = ws*self.tile_dim
                    w_end = (1+ws)*self.tile_dim
                    tile_clone[:, ds, h_start:h_end, w_start:w_end] = out_tile
        self.tiled_out = tile_clone
        return self.tiled_out

This always gave me an in-place operation error in loss.backward(). I first detached the clone temporarily to see whether that fixed it, until I found out the layer also wasn't batch aware and failed for batch sizes > 1.
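If I understand the error correctly, the root cause is that the layer keeps reusing a tensor that is still part of an earlier graph, so the in-place writes of the next forward invalidate values autograd already saved. A standalone sketch of that class of error (buf and the shapes are made up by me):

import torch

buf = torch.zeros(3)            # persistent buffer, like self.tiled_out
x = torch.ones(3, requires_grad=True)
buf[:] = x * 2                  # first "forward" writes into the buffer
loss = (buf * buf).sum()        # mul saves buf for its backward
buf[:] = x * 3                  # next "forward" overwrites it in place
loss.backward()                 # RuntimeError: one of the variables needed for
                                # gradient computation has been modified by an
                                # inplace operation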
So I came up with this other temporary solution: reinitialize the output tensor inside forward, like this:

class tiling_layer(nn.Module):
    def __init__(self, x_dim, tile_dim):
        super(tiling_layer, self).__init__()
        self.b, d, self.h, self.w = x_dim
        self.tile_dim = tile_dim
        self.out_d = d // tile_dim**2
    def forward(self, x):
        self.b = x.shape[0]  # take the batch size from the input, so batches > 1 work
        # fresh output tensor on every forward, created on the input's device and dtype
        tile_clone = torch.zeros(self.b, self.out_d, self.h * self.tile_dim, self.w * self.tile_dim, dtype=x.dtype, device=x.device)
        for ds in range(self.out_d):
            # the tile_dim**2 channels that form output channel ds
            d_start = ds * self.tile_dim**2
            d_end = (ds + 1) * self.tile_dim**2
            for hs in range(self.h):
                for ws in range(self.w):
                    # channel values at input position (hs, ws) become one spatial tile
                    tile_select = x[:, d_start:d_end, hs, ws]
                    out_tile = tile_select.view(self.b, self.tile_dim, self.tile_dim)
                    h_start = hs * self.tile_dim
                    h_end = (hs + 1) * self.tile_dim
                    w_start = ws * self.tile_dim
                    w_end = (ws + 1) * self.tile_dim
                    tile_clone[:, ds, h_start:h_end, w_start:w_end] = out_tile
        self.tiled_out = tile_clone  # stored only for inspection; the next forward never reads it
        return self.tiled_out
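
With this version, a quick sanity check (shapes made up by me) runs without the in-place error and handles batches > 1:

layer = tiling_layer((1, 16, 3, 3), tile_dim=2)
x = torch.randn(4, 16, 3, 3, requires_grad=True)
out = layer(x)
print(out.shape)      # torch.Size([4, 4, 6, 6])
out.sum().backward()  # no in-place operation error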

Is this a correct way to write the layer, given that it is the last layer and outputs the prediction?
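
For reference, if I read the indexing right, the loop computes the same depth-to-space rearrangement as torch.nn.PixelShuffle (output channel ds, tile offset (i, j) reads input channel ds*tile_dim**2 + i*tile_dim + j). A quick check, assuming the two orderings really do match:

shuffle = nn.PixelShuffle(2)
print(torch.allclose(layer(x), shuffle(x)))  # True if my loop matches PixelShuffle's ordering

Thanks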