Here you go:
import math
import hashlib

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable

use_cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor

def conv3x3(in_planes, out_planes, stride=1):
    """3x3x3 convolution with padding."""
    return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = nn.BatchNorm3d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm3d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        return out
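As a quick sanity check of the block (the input sizes here are mine, not from the original), the identity skip preserves the shape:

block = BasicBlock(8, 8)
x = Variable(torch.randn(1, 8, 4, 8, 8))
print(block(x).size())  # torch.Size([1, 8, 4, 8, 8])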

def gather_weights(m):
    """
    Appends the L2 norm of each parameterised module's weights to the
    global list `normed_weights`. Intended as a callback for
    net.apply(gather_weights), which already visits every submodule,
    so no inner loop over m.modules() is needed.
    """
    if isinstance(m, (nn.Conv2d, nn.Conv3d, nn.BatchNorm2d,
                      nn.BatchNorm3d, nn.Linear)):
        normed_weights.append(m.weight.norm(2).unsqueeze(0))
    elif isinstance(m, (nn.LSTM, nn.LSTMCell)):
        # LSTM modules expose per-gate weights (weight_ih_l0, ...) rather
        # than a single .weight attribute
        for name, p in m.named_parameters():
            if 'weight' in name:
                normed_weights.append(p.norm(2).unsqueeze(0))
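As an aside, the same collection can be done without the module-global list by iterating named_parameters() directly; a minimal sketch (the name l2_norms is mine):

def l2_norms(model):
    # one [1]-shaped norm per weight tensor; also covers LSTM gate weights
    return [p.norm(2).unsqueeze(0) for name, p in model.named_parameters()
            if 'weight' in name]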

class ResNet(nn.Module):
    def __init__(self, block, layers, num_classes=180, name=None):
        """
        `name` controls the hash value of an instance (see __hash__ below).
        """
        self.name = name
        self.inplanes = 64
        self.num_classes = num_classes
        super(ResNet, self).__init__()
        self.conv1 = nn.Conv3d(36, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm3d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool3d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.loss_value_term = []     # (z - v)^2
        self.loss_param_term = []     # c * ||theta||^2
        self.loss_log_prob_term = []  # pi^T log(p)
        # policy head: produces the logits for the angle probabilities
        self.probhead = self._make_layer(block, num_classes, layers[4], stride=1)
        for m in self.modules():
            if isinstance(m, nn.Conv3d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, nn.BatchNorm3d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv3d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm3d(planes * block.expansion),
            )
        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # probability distribution over state-action pairs
        px = self.probhead(x)
        px = px.view(px.size(0), -1)
        s1, s2 = px.size()  # s1 is the batch dimension, s2 the flattened feature size
        # note: these head layers are re-created with fresh random weights
        # on every forward pass and discarded at the end of the call
        linear_layer = nn.Linear(s2, self.num_classes)
        linear_layer = linear_layer.cuda() if use_cuda else linear_layer
        probs = linear_layer(px)
        probs = F.softmax(probs, dim=1)
        valuehead = nn.Sequential(
            nn.Linear(s2, 512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )
        valuehead = valuehead.cuda() if use_cuda else valuehead
        value = F.tanh(valuehead(px))
        del valuehead, linear_layer
        return probs, value

    def __hash__(self):
        return int(hashlib.md5(self.name.encode('utf-8')).hexdigest(), 16)

    def __eq__(self, other):
        return hash(self) == hash(other)
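One thing worth flagging before the training snippet below: linear_layer and valuehead are built inside forward, so they get fresh random weights on every call, are never seen by an optimiser, and are deleted on exit. A sketch of the usual arrangement, registering the heads once in __init__ (ResNetHeads and flat_features are hypothetical names of mine; flat_features must equal s2, the flattened width of px):

class ResNetHeads(ResNet):
    """Illustrative only: heads registered once in __init__, reused in forward."""
    def __init__(self, block, layers, flat_features, num_classes=180, name=None):
        super(ResNetHeads, self).__init__(block, layers, num_classes, name)
        self.policy_fc = nn.Linear(flat_features, num_classes)
        self.value_fc = nn.Sequential(
            nn.Linear(flat_features, 512), nn.ReLU(inplace=True),
            nn.Linear(512, 256), nn.ReLU(inplace=True),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer4(self.layer3(self.layer2(self.layer1(x))))
        px = self.probhead(x).view(x.size(0), -1)
        probs = F.softmax(self.policy_fc(px), dim=1)
        value = F.tanh(self.value_fc(px))
        return probs, value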

resnet = ResNet(BasicBlock, [3, 4, 6, 3, 1], num_classes=36, name='player1')
resnet = resnet.cuda() if use_cuda else resnet
running_state = Variable(torch.randn([1, 36, 122, 64, 64]))
# perform inference
probs, value = resnet(running_state)

normed_weights = []
resnet.apply(gather_weights)
scaled_weights = torch.cat(normed_weights, 1)
scaled_weights = scaled_weights.mean(1, keepdim=True)
print('scaled_weights ', scaled_weights.size())

# note: [[]] * 4 binds all four names to the *same* list object,
# so every append below lands in one shared list
losses, log_prob_term, param_term, value_term = [[]] * 4
for i in range(50):
    value_term.append(value)
    param_term.append(scaled_weights)
    # the backward error below appears to originate from this .dot()
    log_prob_term.append(probs.dot(probs.t().log()))
for val, log_prob, para in zip(value_term[::-1], log_prob_term[::-1], param_term[::-1]):
    losses.append(val - log_prob + para)
losses = torch.cat(losses, 0)
print('losses: ', losses.size())
losses = (losses - losses.mean()) / (losses.std() + float(np.finfo(np.float32).eps))
loss = losses.sum(dim=0, keepdim=True) / losses.mean(dim=0, keepdim=True)
print('loss: ', loss.size())
loss.backward()
Running it gives output like so:
losses: torch.Size([300, 1])
loss: torch.Size([1, 1])
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-9-05723febad27> in <module>()
13 loss = losses.sum(dim=0, keepdim=True)/losses.mean(dim=0, keepdim=True)
14 print('loss: ', loss.size())
---> 15 loss.backward()
~/anaconda3/envs/py35/lib/python3.5/site-packages/torch/autograd/variable.py in backward(self, gradient, retain_graph, create_graph, retain_variables)
165 Variable.
166 """
--> 167 torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
168
169 def register_hook(self, hook):
~/anaconda3/envs/py35/lib/python3.5/site-packages/torch/autograd/__init__.py in backward(variables, grad_variables, retain_graph, create_graph, retain_variables)
97
98 Variable._execution_engine.run_backward(
---> 99 variables, grad_variables, retain_graph)
100
101
RuntimeError: output and gradOutput shapes do not match: output [1 x 36], gradOutput [36 x 36] at /opt/conda/conda-bld/pytorch_1512383260527/work/torch/lib/THNN/generic/SoftMax.c:76
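For what it's worth, two things stand out in the traceback and the shapes above. First, the error comes out of softmax's backward: probs (the softmax output) is [1, 36], but the gradient handed back to it is [36, 36]. That gradient appears to originate in probs.dot(probs.t().log()); on this PyTorch build (0.3-era, judging by the conda-bld path) .dot() on 2-D Variables does not reduce to the scalar the code assumes. Writing pi^T log(p) as an explicit matrix product yields a [1, 1] term with a well-shaped gradient. Second, [[]] * 4 makes all four names share one list, which is why losses prints as [300, 1] rather than [50, 1]. A sketch of both fixes, under the same 0.3-era API as the code above:

losses, log_prob_term, param_term, value_term = [], [], [], []  # four distinct lists
for i in range(50):
    value_term.append(value)
    param_term.append(scaled_weights)
    # pi^T log(p) as an explicit [1, 1] matrix product instead of .dot()
    log_prob_term.append(torch.mm(probs, probs.t().log()))
for val, log_prob, para in zip(value_term[::-1], log_prob_term[::-1], param_term[::-1]):
    losses.append(val - log_prob + para)
losses = torch.cat(losses, 0)  # [50, 1] -- the shapes the rest of the snippet expects

With each loss entry a [1, 1] Variable built from well-shaped gradients, the specific softmax shape mismatch should no longer trigger.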