# PixelRNN vs. PixelCNN — what is the difference? Both autoregressive image
# models are implemented below for comparison (van den Oord et al., 2016).

import torch

import torch.nn as nn

import torch.nn.functional as F

class MaskedConv2d(nn.Conv2d):
    """2D convolution with an autoregressive mask (PixelCNN-style).

    The mask zeroes every weight that would let the convolution see a
    pixel's "future" in raster-scan order: all rows below the centre, and
    the positions to the right of the centre on the centre row.

    Type 'A' mask: also hides the centre pixel (use for the first layer,
    so the model never sees the pixel it is predicting).
    Type 'B' mask: keeps the centre pixel (use for all later layers).
    """

    def __init__(
        self,
        mask_type: str,
        in_channels: int,
        out_channels: int,
        kernel_size,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=bias,
        )

        assert mask_type in {"A", "B"}, "mask_type must be 'A' or 'B'"
        self.mask_type = mask_type

        # nn.Conv2d normalizes kernel_size to a (kH, kW) tuple; compute the
        # centre per-axis so non-square kernels are masked correctly.
        center_h = self.kernel_size[0] // 2
        center_w = self.kernel_size[1] // 2

        mask = torch.ones_like(self.weight)
        # Zero out ALL rows strictly below the centre. (The original code
        # zeroed only row `center + 1`, which leaks future pixels for any
        # kernel larger than 3.)
        mask[:, :, center_h + 1 :, :] = 0
        # Zero out the pixels to the right of the centre on the centre row.
        mask[:, :, center_h, center_w + 1 :] = 0
        if mask_type == "A":
            # Type 'A' additionally hides the current pixel itself.
            mask[:, :, center_h, center_w] = 0
        # Registered as a buffer so the mask follows .to(device)/.cuda()
        # and is saved in state_dict; a plain attribute would stay on CPU.
        self.register_buffer("mask", mask)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the convolution with masked weights.

        Args:
            x: Input tensor of shape [B, C, H, W].

        Returns:
            Convolved tensor; spatial size follows stride/padding.
        """
        # Mask is re-applied on every call so gradient updates can never
        # "unmask" future positions.
        return F.conv2d(
            x,
            self.weight * self.mask,
            self.bias,
            self.stride,
            self.padding,
        )

class PixelCNNBlock(nn.Module):
    """One PixelCNN stage: a masked convolution followed by a ReLU."""

    def __init__(
        self,
        mask_type: str,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int,
    ) -> None:
        super().__init__()
        # The masked convolution enforces the autoregressive ordering;
        # ReLU provides the nonlinearity.
        self.conv = MaskedConv2d(
            mask_type,
            in_channels,
            out_channels,
            kernel_size,
            padding=padding,
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve then activate; spatial size is preserved when padded."""
        return self.relu(self.conv(x))

class PixelCNN(nn.Module):
    """Basic PixelCNN: a stack of masked-conv blocks plus a 1x1 logit head."""

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        n_layers: int,
        kernel_size: int,
        n_classes: int = 256,
    ) -> None:
        super().__init__()

        padding = kernel_size // 2  # "same" padding for odd kernel sizes

        # First block uses a type-'A' mask so the model never sees the
        # pixel it is predicting; all remaining blocks use type 'B'.
        blocks = [
            PixelCNNBlock("A", in_channels, hidden_channels, kernel_size, padding)
        ]
        blocks.extend(
            PixelCNNBlock("B", hidden_channels, hidden_channels, kernel_size, padding)
            for _ in range(n_layers - 1)
        )
        self.net = nn.Sequential(*blocks)
        # 1x1 convolution maps hidden features to per-pixel class logits.
        self.output_conv = nn.Conv2d(hidden_channels, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape [B, C, H, W] with values in [0, 255]

        Returns:
            logits: [B, n_classes, H, W]
        """
        scaled = x / 255  # normalize pixel values to [0, 1]
        features = self.net(scaled)
        return self.output_conv(features)

# --------------------------------------------------------------------------
# PixelRNN section
# --------------------------------------------------------------------------
import torch

from torch import nn

from torch.utils.hooks import RemovableHandle

class PixelRNN(nn.Module):
    """
    Skeleton for PixelRNN: alternating column-wise ("vertical") and
    row-wise ("horizontal") LSTM passes, plus a 1x1 logit head.

    References:
        van den Oord et al., 2016 - Pixel Recurrent Neural Networks
        https://arxiv.org/abs/1601.06759
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        n_layers: int,
        n_classes: int = 256,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.n_classes = n_classes

        # One SINGLE-layer LSTM per stage. (The original passed n_layers
        # as nn.LSTM's num_layers while also stacking n_layers entries in
        # the ModuleList, yielding n_layers**2 layers in total.)
        self.vertical_layers = nn.ModuleList()
        self.horizontal_layers = nn.ModuleList()
        for layer_idx in range(n_layers):
            # Vertical stage 0 consumes the raw image channels.
            v_input_size = in_channels if layer_idx == 0 else hidden_channels
            self.vertical_layers.append(
                nn.LSTM(v_input_size, hidden_channels, num_layers=1, batch_first=True)
            )
            # Every horizontal stage consumes the vertical stage's output,
            # which always has hidden_channels features. (The original used
            # in_channels at layer 0, which fails at runtime unless
            # in_channels == hidden_channels.)
            self.horizontal_layers.append(
                nn.LSTM(
                    hidden_channels, hidden_channels, num_layers=1, batch_first=True
                )
            )

        # 1x1 convolution maps hidden features to per-pixel class logits.
        self.output_conv = nn.Conv2d(hidden_channels, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, C, H, W] with pixel values in [0, 255]

        Returns:
            logits: [B, n_classes, H, W]
        """
        B, C, H, W = x.shape
        out = x / 255  # normalize pixel values to [0, 1]

        # NOTE(review): a faithful PixelRNN also shifts the input down
        # (vertical pass) / right (horizontal pass) so a pixel never sees
        # itself; this skeleton omits that causal shift — confirm before
        # using the model for generation.
        for v_lstm, h_lstm in zip(self.vertical_layers, self.horizontal_layers):
            # Vertical pass: each column is a sequence of length H.
            # [B, C', H, W] -> [B*W, H, C']
            cols = out.permute(0, 3, 2, 1).reshape(B * W, H, -1)
            out_v, _ = v_lstm(cols)
            out_v = out_v.reshape(B, W, H, -1).permute(0, 3, 2, 1)

            # Horizontal pass: each row is a sequence of length W.
            # [B, hidden, H, W] -> [B*H, W, hidden]
            rows = out_v.permute(0, 2, 3, 1).reshape(B * H, W, -1)
            out_h, _ = h_lstm(rows)
            # Feed this stage's output to the next stage. (The original
            # discarded out_h, leaving `out` mis-shaped after iteration 1.)
            out = out_h.reshape(B, H, W, -1).permute(0, 3, 1, 2)

        # Logits come from the LSTM FEATURES. (The original applied the
        # head to x, bypassing the LSTM stack entirely.)
        return self.output_conv(out)

    # Optional: pixel-by-pixel sampling function
    @torch.no_grad()
    def sample(
        self, batch_size: int, height: int, width: int, device: torch.device
    ) -> torch.Tensor:
        """Generate images autoregressively, one pixel at a time.

        Args:
            batch_size: Number of images to generate.
            height / width: Spatial size of the generated images.
            device: Device on which to allocate the canvas.

        Returns:
            Tensor [batch_size, in_channels, height, width] with sampled
            pixel values in [0, n_classes - 1].
        """
        x = torch.zeros(
            batch_size, self.in_channels, height, width, device=device
        )

        # Raster-scan order: each pixel is sampled conditioned on the
        # (partially filled) canvas so far.
        for i in range(height):
            for j in range(width):
                logits = self.forward(x)
                probs = nn.functional.softmax(logits[:, :, i, j], dim=1)
                sampled = torch.multinomial(probs, 1).squeeze(1)  # [B]
                # unsqueeze so the [B, 1] value broadcasts over the channel
                # dim of the [B, C] slice. (Assigning a bare [B] tensor
                # fails broadcasting for batch_size > 1.)
                x[:, :, i, j] = sampled.to(x.dtype).unsqueeze(1)

        return x