import torch
import torch.nn as nn
import torch.nn.functional as F
class MaskedConv2d(nn.Conv2d):
    """2D convolution with an autoregressive mask (PixelCNN-style).

    Mask type 'A' excludes the current (center) pixel — used for the first
    layer so the model never sees the pixel it is predicting.
    Mask type 'B' includes the center pixel — used for all later layers.

    Args:
        mask_type: 'A' or 'B'.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Kernel size; an int or an (h, w) pair (nn.Conv2d
            normalizes either form to a tuple).
        stride: Convolution stride.
        padding: Zero padding added to both sides.
        bias: Whether to learn an additive bias.
    """

    def __init__(
        self,
        mask_type: str,
        in_channels: int,
        out_channels: int,
        kernel_size: tuple,
        stride: int = 1,
        padding: int = 0,
        bias: bool = True,
    ) -> None:
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            padding,
            bias=bias,
        )
        assert mask_type in {"A", "B"}, "mask_type must be 'A' or 'B'"
        self.mask_type = mask_type

        # Binary mask with the same shape as the weight tensor.
        mask = torch.ones_like(self.weight)
        # nn.Conv2d normalizes kernel_size to an (h, w) pair; use separate
        # centers so non-square kernels are masked correctly.
        kh, kw = self.kernel_size
        center_h, center_w = kh // 2, kw // 2
        # Zero out ALL rows strictly below the center.  (The previous
        # `center + 1` index zeroed only a single row, leaking future
        # pixels for kernels larger than 3.)
        mask[:, :, center_h + 1 :, :] = 0
        # In the center row, zero out everything to the right of center.
        mask[:, :, center_h, center_w + 1 :] = 0
        # Type 'A' additionally masks the center pixel itself.
        if mask_type == "A":
            mask[:, :, center_h, center_w] = 0
        # Non-persistent buffer: follows .to(device)/.cuda() with the module
        # but is kept out of state_dict (it is deterministic, not learned).
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve x with the weights multiplied by the mask."""
        return F.conv2d(
            x,
            self.weight * self.mask,
            self.bias,
            self.stride,
            self.padding,
        )
class PixelCNNBlock(nn.Module):
    """One PixelCNN stage: a masked convolution followed by a ReLU.

    Args:
        mask_type: 'A' or 'B', forwarded to MaskedConv2d.
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Convolution kernel size.
        padding: Zero padding applied by the convolution.
    """

    def __init__(
        self,
        mask_type: str,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        padding: int,
    ) -> None:
        super().__init__()
        # Masked convolution keeps the autoregressive ordering intact.
        self.conv = MaskedConv2d(
            mask_type, in_channels, out_channels, kernel_size, padding=padding
        )
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the masked convolution, then the nonlinearity."""
        return self.relu(self.conv(x))
class PixelCNN(nn.Module):
    """Basic PixelCNN model for image generation.

    Stacks one mask-'A' block, then n_layers - 1 mask-'B' blocks, then a
    1x1 convolution producing per-pixel class logits.

    Args:
        in_channels: Channels of the input image.
        hidden_channels: Feature channels inside the stack.
        n_layers: Total number of masked-convolution blocks.
        kernel_size: Kernel size of every masked convolution.
        n_classes: Number of discrete pixel values to predict (default 256).
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        n_layers: int,
        kernel_size: int,
        n_classes: int = 256,
    ) -> None:
        super().__init__()
        padding = kernel_size // 2  # "same" spatial size for odd kernels

        # First block must use mask 'A' (cannot see the current pixel);
        # every later block uses mask 'B'.
        blocks = [
            PixelCNNBlock("A", in_channels, hidden_channels, kernel_size, padding)
        ]
        blocks += [
            PixelCNNBlock("B", hidden_channels, hidden_channels, kernel_size, padding)
            for _ in range(n_layers - 1)
        ]
        self.net = nn.Sequential(*blocks)
        # Final 1x1 convolution maps features to per-pixel class logits.
        self.output_conv = nn.Conv2d(hidden_channels, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Tensor of shape [B, C, H, W] with values in [0, 255]
        Returns:
            logits: [B, n_classes, H, W]
        """
        # Scale raw pixel values into [0, 1], run the masked stack, and
        # project to logits.
        scaled = x / 255
        return self.output_conv(self.net(scaled))
# ---------------------------------------------------------------------------
import torch
from torch import nn
from torch.utils.hooks import RemovableHandle
class PixelRNN(nn.Module):
    """
    Skeleton for PixelRNN: alternating vertical (column-wise) and
    horizontal (row-wise) LSTM passes, then a 1x1 conv to pixel logits.

    References:
        van den Oord et al., 2016 – Pixel Recurrent Neural Networks
        https://arxiv.org/abs/1601.06759
    """

    def __init__(
        self,
        in_channels: int,
        hidden_channels: int,
        n_layers: int,
        n_classes: int = 256,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.n_layers = n_layers
        self.n_classes = n_classes

        # Vertical LSTMs: one single-layer LSTM per stage.  (Previously
        # n_layers was passed as nn.LSTM's num_layers argument, stacking
        # n_layers recurrent layers inside EACH of the n_layers modules.)
        self.vertical_layers = nn.ModuleList()
        for layer in range(n_layers):
            input_size = in_channels if layer == 0 else hidden_channels
            self.vertical_layers.append(
                nn.LSTM(input_size, hidden_channels, 1, batch_first=True)
            )

        # Horizontal LSTMs consume the vertical LSTM's output, which always
        # has hidden_channels features — including at stage 0.  (Previously
        # stage 0 expected in_channels and crashed on the size mismatch.)
        self.horizontal_layers = nn.ModuleList()
        for layer in range(n_layers):
            self.horizontal_layers.append(
                nn.LSTM(hidden_channels, hidden_channels, 1, batch_first=True)
            )

        # 1x1 convolution mapping features to per-pixel class logits.
        self.output_conv = nn.Conv2d(hidden_channels, n_classes, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: Input tensor [B, C, H, W] with pixel values in [0, 255]
        Returns:
            logits: [B, n_classes, H, W]

        NOTE(review): the horizontal pass feeds pixel (i, j) its own
        features, so strict autoregressive causality needs an input
        shift/mask as in the paper — confirm before training.
        """
        B, C, H, W = x.shape
        # Normalize raw pixel values into [0, 1].
        out = x / 255
        for v_lstm, h_lstm in zip(self.vertical_layers, self.horizontal_layers):
            # Vertical pass: each column becomes a top-to-bottom sequence.
            cols = out.permute(0, 3, 2, 1).reshape(B * W, H, -1)
            out_v, _ = v_lstm(cols)
            out_v = out_v.reshape(B, W, H, -1).permute(0, 3, 2, 1)
            # Horizontal pass: each row becomes a left-to-right sequence.
            rows = out_v.permute(0, 2, 3, 1).reshape(B * H, W, -1)
            out_h, _ = h_lstm(rows)
            # Feed this stage's output into the next stage.  (Previously
            # out_h was discarded and the 1x1 conv was applied to the raw
            # input x, which also had the wrong channel count.)
            out = out_h.reshape(B, H, W, -1).permute(0, 3, 1, 2)
        # Per-pixel class logits via the 1x1 output convolution.
        return self.output_conv(out)

    # Optional: pixel-by-pixel sampling function
    @torch.no_grad()
    def sample(
        self, batch_size: int, height: int, width: int, device: torch.device
    ) -> torch.Tensor:
        """Generate images pixel by pixel in raster-scan order.

        Args:
            batch_size: Number of images to generate.
            height, width: Spatial size of the generated images.
            device: Device on which to allocate the canvas.
        Returns:
            Tensor [batch_size, in_channels, height, width] with sampled
            pixel values in [0, n_classes - 1].
        """
        x = torch.zeros(batch_size, self.in_channels, height, width, device=device)
        for i in range(height):
            for j in range(width):
                # Full forward pass, then condition on the (i, j) logits.
                logits = self.forward(x)
                probs = nn.functional.softmax(logits[:, :, i, j], dim=1)
                # sampled: [B] class indices.  unsqueeze so it broadcasts
                # over the channel dimension of the [B, C] slice.
                # (Previously the bare [B] tensor failed to broadcast.)
                sampled = torch.multinomial(probs, 1).squeeze(1)
                x[:, :, i, j] = sampled.to(x.dtype).unsqueeze(1)
        return x