from functools import partial
from typing import Callable, List
import torch
import numpy as np
from torch import nn, Tensor
import torch.nn.functional as F
from torch.optim import SGD
# Maps an activation name to a callable that takes a tensor and returns a tensor.
ACTIVATIONS = {
    'relu': F.relu,
    # Must be an *instance*: calling the bare ``nn.Identity`` class on a tensor
    # would construct a new module (its constructor ignores its arguments)
    # instead of returning the tensor unchanged.
    'identity': nn.Identity(),
}
class Network(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with one shared hidden activation.

    Consecutive entries of ``shape`` give the input and output width of each
    linear layer. The activation is applied to the output of every layer
    except the last one, whose output is returned as is. Every layer is also
    registered as a submodule named ``l_<index>``.

    Args:
        shape: Layer widths, input width first.
        has_bias: Whether each linear layer gets a bias term.
        init_fns: Weight initializers, indexed by layer position. Each one is
            called with the layer's weight and must return the tensor to use
            as that layer's weight.
        activation_fn_name: Key into ``ACTIVATIONS`` selecting the activation.

    Raises:
        ValueError: If ``activation_fn_name`` is not a key of ``ACTIVATIONS``.
    """

    def __init__(self, shape: np.ndarray, has_bias: bool, init_fns: List[Callable], activation_fn_name: str):
        super().__init__()
        # Fail fast: a silent ``None`` activation would only surface later,
        # inside ``forward``, as an opaque "NoneType is not callable" error.
        if activation_fn_name not in ACTIVATIONS:
            raise ValueError(
                'Unknown activation {!r}; expected one of {}'.format(
                    activation_fn_name, sorted(ACTIVATIONS)
                )
            )
        self._shape: np.ndarray = shape
        self._has_bias: bool = has_bias
        self._init_fns: List[Callable] = init_fns
        self._activation_fn_name: str = activation_fn_name
        self._activation_fn: Callable = ACTIVATIONS[activation_fn_name]
        self._layers: List[nn.Module] = list()
        self._init_layers()

    def _init_layers(self):
        """Create, initialize and register one linear layer per adjacent pair in ``shape``."""
        for layer_idx in range(len(self._shape) - 1):
            in_dim: int = self._shape[layer_idx]
            out_dim: int = self._shape[layer_idx + 1]
            layer: nn.Module = nn.Linear(in_features=in_dim, out_features=out_dim, bias=self._has_bias)
            init_fn: Callable = self._init_fns[layer_idx]
            # Re-wrap the initializer's return value so it is registered as
            # the layer's weight parameter.
            layer.weight = nn.Parameter(init_fn(layer.weight))
            self._layers.append(layer)
            # Registration is what exposes the layer through ``parameters()``
            # and as the attribute ``l_<index>``.
            self.add_module('l_{}'.format(layer_idx), layer)

    def forward(self, x: Tensor) -> Tensor:
        """Apply all layers in order; the last layer's output is not activated."""
        for layer in self._layers[:-1]:
            x = self._activation_fn(layer(x))
        out_layer: nn.Module = self._layers[-1]
        return out_layer(x)

    @property
    def layers(self):
        """The list of layers applied by ``forward``, in order."""
        return self._layers

    @layers.setter
    def layers(self, value):
        # Drop the ``l_<index>`` registrations of the layers being replaced;
        # otherwise ``parameters()`` / ``state_dict()`` would keep reporting
        # modules that ``forward`` no longer uses.
        stale_names = [
            name for name in self._modules
            if name.startswith('l_') and name[2:].isdigit()
        ]
        for name in stale_names:
            delattr(self, name)
        self._layers = value
        for layer_idx, layer in enumerate(value):
            # Only modules can be registered; any other callable is still
            # kept in the list and used by ``forward``.
            if isinstance(layer, nn.Module):
                self.add_module('l_{}'.format(layer_idx), layer)
class ReluNetwork(Network):
    """``Network`` preset: ReLU hidden activation and no bias terms by default.

    Args:
        shape: Layer widths, input width first.
        init_fns: Weight initializers, indexed by layer position.
        has_bias: Whether each linear layer gets a bias term.
        activation_fn_name: Name of the hidden activation, forwarded to
            ``Network`` unchanged.
    """

    def __init__(self, shape: np.ndarray, init_fns: List[Callable],
                 has_bias: bool = False, activation_fn_name: str = 'relu'):
        # ``Network.__init__`` takes ``has_bias`` before ``init_fns``, so the
        # arguments are forwarded positionally in the parent's order.
        super().__init__(shape, has_bias, init_fns, activation_fn_name)
def single_forward_and_backward(init_fns, shape=(3, 3, 2)):
    """Run one SGD step on a fresh ``ReluNetwork`` and print what happened.

    A random input and a random target are drawn, the MSE loss is
    back-propagated once, and the weights of every layer are printed before
    and after the optimizer step, together with their gradients.

    Args:
        init_fns: Weight initializers, indexed by layer position, passed to
            ``ReluNetwork``.
        shape: Layer widths, input width first. Defaults to the original
            3-3-2 network.
    """
    model = ReluNetwork(shape, init_fns)
    x = torch.randn(shape[0])
    y = torch.randn(shape[-1])
    y_pred = model(x)

    optimizer = SGD(model.parameters(), lr=.01)
    optimizer.zero_grad()
    criterion = nn.MSELoss()
    loss = criterion(y_pred, y)
    loss.backward()

    # (registered name, layer) pairs in registration order, i.e. l_0, l_1, ...
    named_layers = list(model.named_children())

    print("=" * 20)
    print('x', x)
    print('y', y)
    print('y_pred', y_pred)

    # ``backward`` only fills ``.grad``; the weights printed here are still
    # the initial ones because ``step`` has not run yet.
    print("initial weights:")
    for name, layer in named_layers:
        print("{}:".format(name))
        print(layer.weight)

    print('grads:')
    for name, layer in named_layers:
        print("{}:".format(name))
        print(layer.weight.grad)

    optimizer.step()

    print("final weights:")
    for name, layer in named_layers:
        print("{}:".format(name))
        print(layer.weight)
if __name__ == '__main__':
    # Both configurations initialize the second layer identically.
    second_layer_init: Callable = partial(nn.init.uniform_, a=0., b=1.)

    # Tiny but non-zero first-layer weights: hidden units whose
    # pre-activation happens to be positive still let gradient through.
    INIT_FNS_WORK: List[Callable] = [
        partial(nn.init.xavier_normal_, gain=1e-30),
        second_layer_init,
    ]

    # Exactly-zero first-layer weights (and no bias by default): every hidden
    # pre-activation is 0, so the ReLU output is 0 and, because PyTorch uses a
    # zero ReLU gradient at 0, no gradient reaches either layer.
    INIT_FNS_BREAK: List[Callable] = [
        nn.init.zeros_,
        second_layer_init,
    ]

    for init_fns in (INIT_FNS_WORK, INIT_FNS_BREAK):
        single_forward_and_backward(init_fns)
Outputs (the first section is the INIT_FNS_WORK run, the second is the INIT_FNS_BREAK run):
====================
x tensor([1.0183, 1.7116, 1.7090])
y tensor([0.6103, 2.1445])
y_pred tensor([7.5593e-31, 1.5887e-30], grad_fn=<SqueezeBackward3>)
initial weights:
l_0:
Parameter containing:
tensor([[-1.8913e-31, -7.7141e-31, -3.0981e-33],
[ 7.3308e-31, 3.1623e-31, 1.0480e-30],
[ 7.5692e-31, -1.0767e-31, -5.9466e-31]], requires_grad=True)
l_1:
Parameter containing:
tensor([[0.3543, 0.2455, 0.0217],
[0.5538, 0.5160, 0.0304]], requires_grad=True)
grads:
l_0:
tensor([[ 0.0000, 0.0000, 0.0000],
[-1.2795, -2.1506, -2.1473],
[ 0.0000, 0.0000, 0.0000]])
l_1:
tensor([[-0.0000e+00, -1.8789e-30, -0.0000e+00],
[-0.0000e+00, -6.6025e-30, -0.0000e+00]])
final weights:
l_0:
Parameter containing:
tensor([[-1.8913e-31, -7.7141e-31, -3.0981e-33],
[ 1.2795e-02, 2.1506e-02, 2.1473e-02],
[ 7.5692e-31, -1.0767e-31, -5.9466e-31]], requires_grad=True)
l_1:
Parameter containing:
tensor([[0.3543, 0.2455, 0.0217],
[0.5538, 0.5160, 0.0304]], requires_grad=True)
====================
x tensor([-0.6772, 0.3680, 0.1908])
y tensor([-0.1825, -0.9870])
y_pred tensor([0., 0.], grad_fn=<SqueezeBackward3>)
initial weights:
l_0:
Parameter containing:
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], requires_grad=True)
l_1:
Parameter containing:
tensor([[0.6648, 0.8862, 0.0178],
[0.9394, 0.6796, 0.6440]], requires_grad=True)
grads:
l_0:
tensor([[-0., 0., 0.],
[-0., 0., 0.],
[-0., 0., 0.]])
l_1:
tensor([[0., 0., 0.],
[0., 0., 0.]])
final weights:
l_0:
Parameter containing:
tensor([[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]], requires_grad=True)
l_1:
Parameter containing:
tensor([[0.6648, 0.8862, 0.0178],
[0.9394, 0.6796, 0.6440]], requires_grad=True)