import torch as T
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
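
# Pin all CUDA work to GPU index 2 when a GPU is available.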
if T.cuda.is_available():
    T.cuda.set_device(2)
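
# Name-to-callable lookup tables for activations and weight initializers.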
function = {'relu': lambda x, **kwargs: F.relu(x, True), 'linear': lambda x, **kwargs: x,
            'lrelu': lambda x, **kwargs: lrelu(x, **kwargs), 'tanh': lambda x, **kwargs: T.tanh(x)}
init = {'He_normal': nn.init.kaiming_normal_, 'He_uniform': nn.init.kaiming_uniform_,
        'Xavier_normal': nn.init.xavier_normal_, 'Xavier_uniform': nn.init.xavier_uniform_}
def lrelu(x, **kwargs):
    alpha = kwargs.get('alpha', 0.1)
    return F.leaky_relu(x, alpha, True)
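
# Base class: tracks trainable/regularizable parameters and defines the layer interface.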
class Layer(nn.Module):
    def __init__(self):
        super(Layer, self).__init__()
        self.trainable = []
        self.regularizable = []

    @property
    def output_shape(self):
        return

    def forward(self, input):
        raise NotImplementedError

    def reset(self):
        pass
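
# Batch normalization (1d or 2d depending on input_shape) followed by an activation.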
class BNLayer(Layer):
    def __init__(self, input_shape, epsilon=1e-4, running_average_factor=1e-1, activation='relu', no_scale=False,
                 show=True, **kwargs):
        """
        :param input_shape: an int (number of features) or a (C, H, W) tuple/list
        :param epsilon: value added to the variance for numerical stability
        :param running_average_factor: momentum of the running mean/variance
        :param activation: key into the module-level function dict
        :param no_scale: if True, exclude the scale (weight) parameter from training
        :param show: if True, print the layer after construction
        :param kwargs: extra keyword arguments forwarded to the activation
        """
        super(BNLayer, self).__init__()
        self.input_shape = tuple(input_shape) if isinstance(input_shape, (tuple, list)) else input_shape
        self.epsilon = np.float32(epsilon)
        self.running_average_factor = running_average_factor
        self.activation = function[activation]
        self.no_scale = no_scale
        self.kwargs = kwargs
        self.bn = nn.BatchNorm1d(self.input_shape, epsilon, running_average_factor) if isinstance(input_shape, int) \
            else nn.BatchNorm2d(self.input_shape[0], epsilon, running_average_factor)
        params = list(self.bn.parameters())
        self.trainable = [params[1]] if self.no_scale else list(params)
        self.regularizable = [params[0]] if not self.no_scale else []
        self.W_values = self.bn.weight.data.numpy().copy()
        self.b_values = self.bn.bias.data.numpy().copy()
        if show:
            print(self.bn)

    def forward(self, input):
        input = self.activation(self.bn(input), **self.kwargs)
        return input

    @property
    def output_shape(self):
        return self.input_shape
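
# 2D convolution with Theano-style border modes ('half', 'valid', 'full') and a fused activation.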
class Conv2DLayer(Layer):
    def __init__(self, input_shape, num_filters, filter_size, init_mode='Xavier_normal', no_bias=True,
                 border_mode='half', stride=(1, 1), dilation=(1, 1), activation='relu', groups=1, show=True, **kwargs):
        """
        :param input_shape: (C, H, W) shape of the input, without the batch dimension
        :param num_filters: number of output channels
        :param filter_size: an int or a (kH, kW) pair
        :param init_mode: Xavier_normal, Xavier_uniform, He_normal, He_uniform
        :param no_bias: if True, the convolution has no bias term
        :param border_mode: 'half', 'valid', 'full', an int, or a (padH, padW) pair
        :param stride: an int or a (sH, sW) pair
        :param dilation: an int or a (dH, dW) pair
        :param activation: key into the module-level function dict
        :param groups: number of blocked connections from input to output channels
        :param show: if True, print the layer after construction
        :param kwargs: extra keyword arguments forwarded to the activation
        """
        assert isinstance(input_shape, list) or isinstance(input_shape, tuple), \
            'input_shape must be list or tuple. Received %s' % type(input_shape)
        assert len(input_shape) == 3, \
            'input_shape must have 3 elements. Received %d' % len(input_shape)
        assert isinstance(num_filters, int) and isinstance(filter_size, (int, list, tuple))
        assert isinstance(border_mode, (int, list, tuple, str)), 'border_mode should be either \'int\', ' \
            '\'list\', \'tuple\' or \'str\', got {}'.format(type(border_mode))
        assert isinstance(stride, (int, list, tuple))
        super(Conv2DLayer, self).__init__()
        self.input_shape = tuple(input_shape)
        self.filter_shape = (num_filters, input_shape[0], filter_size[0], filter_size[1]) \
            if isinstance(filter_size, (list, tuple)) else (num_filters, input_shape[0], filter_size, filter_size)
        self.no_bias = no_bias
        self.activation = function[activation]
        self.init_mode = init_mode
        self.border_mode = border_mode
        self.stride = tuple(stride) if isinstance(stride, (tuple, list)) else (stride, stride)
        self.dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
        self.groups = groups
        self.kwargs = kwargs
        # Effective kernel extent after dilation.
        k1, k2 = self.filter_shape[2] + (self.filter_shape[2] - 1)*(self.dilation[0] - 1), \
                 self.filter_shape[3] + (self.filter_shape[3] - 1)*(self.dilation[1] - 1)
        if isinstance(self.border_mode, str):
            if self.border_mode == 'half':
                self.padding = (k1 // 2, k2 // 2)
            elif self.border_mode == 'valid':
                self.padding = (0, 0)
            elif self.border_mode == 'full':
                self.padding = (k1 - 1, k2 - 1)
            else:
                raise NotImplementedError
        elif isinstance(self.border_mode, (list, tuple)):
            self.padding = tuple(self.border_mode)
        elif isinstance(self.border_mode, int):
            self.padding = (self.border_mode, self.border_mode)
        else:
            raise NotImplementedError
        self.conv = nn.Conv2d(int(input_shape[0]), num_filters, filter_size, stride, self.padding, dilation,
                              bias=not no_bias, groups=groups)
        params = list(self.conv.parameters())
        self.trainable = list(params)
        self.regularizable = [params[0]]
        init[init_mode](self.conv.weight)
        self.W_values = self.conv.weight.data.numpy().copy()
        if not self.no_bias:
            self.b_values = np.zeros(self.filter_shape[0], dtype='float32')
            self.conv.bias.data = T.tensor(self.b_values)
        if show:
            print(self)

    def forward(self, input):
        input = self.activation(self.conv(input), **self.kwargs)
        return input

    @property
    def output_shape(self):
        size = list(self.input_shape)
        assert len(size) == 3, "Shape must consist of 3 elements only"
        k1, k2 = self.filter_shape[2] + (self.filter_shape[2] - 1)*(self.dilation[0] - 1), \
                 self.filter_shape[3] + (self.filter_shape[3] - 1)*(self.dilation[1] - 1)
        size[1] = (size[1] - k1 + 2*self.padding[0]) // self.stride[0] + 1
        size[2] = (size[2] - k2 + 2*self.padding[1]) // self.stride[1] + 1
        size[0] = self.filter_shape[0]
        return tuple(size)
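
# Convolution -> batch norm -> activation, composed from the two layers above.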
class ConvBNAct(Layer):
    def __init__(self, input_shape, num_filters, filter_size, init_mode='Xavier_normal', no_bias=True,
                 border_mode='half', stride=1, activation='relu', dilation=1, epsilon=1e-4,
                 running_average_factor=1e-1, no_scale=False, groups=1, show=True, **kwargs):
        super(ConvBNAct, self).__init__()
        self.input_shape = input_shape
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.activation = activation
        self.stride = stride
        self.padding = border_mode
        self.conv = Conv2DLayer(input_shape, num_filters, filter_size, init_mode, no_bias, border_mode, stride,
                                dilation, activation='linear', groups=groups, show=False, **kwargs)
        self.bn = BNLayer(self.conv.output_shape, epsilon, running_average_factor, activation, no_scale,
                          show=False, **kwargs)
        self.trainable += self.conv.trainable + self.bn.trainable
        self.regularizable += self.conv.regularizable + self.bn.regularizable
        if show:
            print(self)

    def forward(self, input):
        input = self.conv(input)
        input = self.bn(input)
        return input

    @property
    def output_shape(self):
        return self.bn.output_shape
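
# A stack of num_layers identical convolutions; only the last one applies the requested stride.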
class StackingConv(Layer):
    def __init__(self, input_shape, num_filters, filter_size, num_layers, batch_norm=False, init_mode='Xavier_normal',
                 no_bias=True, border_mode='half', stride=(1, 1), dilation=(1, 1), activation='relu', groups=1,
                 **kwargs):
        assert num_layers > 1, 'num_layers must be greater than 1, got %d' % num_layers
        super(StackingConv, self).__init__()
        self.input_shape = input_shape
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.stride = stride
        self.activation = activation
        self.num_layers = num_layers
        self.batch_norm = batch_norm
        self.block = nn.Sequential()
        shape = tuple(input_shape)
        conv_layer = ConvBNAct if batch_norm else Conv2DLayer
        # All but the last layer use stride 1; the requested stride is applied by the final layer.
        for num in range(num_layers - 1):
            self.block.add_module('Stacking_conv%d' % (num + 1),
                                  conv_layer(input_shape=shape, num_filters=num_filters, filter_size=filter_size,
                                             init_mode=init_mode, stride=1, border_mode=border_mode,
                                             dilation=dilation, activation=activation, no_bias=no_bias,
                                             groups=groups, show=False, **kwargs))
            shape = self.block[-1].output_shape
            self.trainable += self.block[-1].trainable
            self.regularizable += self.block[-1].regularizable
        self.block.add_module('Stacking_conv%d' % num_layers,
                              conv_layer(input_shape=self.block[-1].output_shape, num_filters=num_filters,
                                         no_bias=no_bias, filter_size=filter_size, init_mode=init_mode,
                                         stride=stride, border_mode=border_mode, dilation=dilation,
                                         activation=activation, groups=groups, show=False, **kwargs))
        self.trainable += self.block[-1].trainable
        self.regularizable += self.block[-1].regularizable
        print(self)

    def forward(self, input):
        input = self.block(input)
        return input

    @property
    def output_shape(self):
        return self.block[-1].output_shape
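
# Standard two-convolution residual block with an optional 1x1 projection shortcut.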
class ResNetBasicBlock(Layer):
    def __init__(self, input_shape, num_filters, filter_size=3, stride=1, activation='relu', groups=1, **kwargs):
        super(ResNetBasicBlock, self).__init__()
        self.input_shape = input_shape
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.stride = stride
        self.activation = function[activation]
        self.groups = groups
        self.kwargs = kwargs
        self.convbnact1 = ConvBNAct(input_shape, num_filters, filter_size, 'He_normal', stride=stride,
                                    activation=activation, show=False, **kwargs)
        self.convbnact2 = ConvBNAct(self.convbnact1.output_shape, num_filters, filter_size, 'He_normal', stride=1,
                                    activation='linear', show=False, **kwargs)
        self.trainable += self.convbnact1.trainable + self.convbnact2.trainable
        self.regularizable += self.convbnact1.regularizable + self.convbnact2.regularizable
        # Identity shortcut unless the spatial size or channel count changes, then a 1x1 projection.
        self.downsample = lambda x: x
        if stride > 1 or input_shape[0] != num_filters:
            self.downsample = ConvBNAct(input_shape, num_filters, 1, stride=stride, no_bias=True,
                                        activation='linear', show=False)
            self.trainable += self.downsample.trainable
            self.regularizable += self.downsample.regularizable
        print(self)

    def forward(self, input):
        res = input
        input = self.convbnact1(input)
        input = self.convbnact2(input)
        input += self.downsample(res)
        return self.activation(input, **self.kwargs)

    @property
    def output_shape(self):
        return self.convbnact2.output_shape
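
# Toy encoder-decoder network: a ResNet-style encoder and an upsample+conv decoder.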
class MyNet(nn.Sequential):
    def __init__(self, **kwargs):
        super(MyNet, self).__init__()
        self.input_tensor_shape = [3, 240, 320]
        # Note: modules held in a plain dict are not registered as submodules of MyNet,
        # so .cuda() is called on each sub-network explicitly at the end of __init__.
        self.network = {'encoder': nn.Sequential(), 'decoder': nn.Sequential()}
        kwargs = {'alpha': .2}
        subnet = 'encoder'
        self.network[subnet].add_module('first_conv', ConvBNAct(self.input_tensor_shape, 64, 7, 'He_normal',
                                                                stride=2, activation='lrelu', **kwargs))
        for c in list('ab'):
            self.network[subnet].add_module('ResBlock1_%s' % c,
                                            ResNetBasicBlock(self.network[subnet][-1].output_shape,
                                                             64, activation='lrelu', **kwargs))
        # Blocks 2-4 halve the spatial size in their first ('a') layer and increase the channels.
        for block, num_filters in ((2, 128), (3, 256), (4, 512)):
            for c in list('ab'):
                self.network[subnet].add_module('ResBlock%d_%s' % (block, c),
                                                ResNetBasicBlock(self.network[subnet][-1].output_shape, num_filters,
                                                                 stride=2 if c == 'a' else 1, activation='lrelu',
                                                                 **kwargs))
        subnet = 'decoder'
        self.network[subnet].add_module('Resizing1',
                                        UpsamplingLayer(self.network['encoder'][-1].output_shape, scale_factor=2,
                                                        mode='bilinear'))
        self.network[subnet].add_module('Stackingconv1',
                                        StackingConv(self.network[subnet][-1].output_shape, 256, 5,
                                                     3, False, 'He_normal', activation='lrelu', **kwargs))
        self.network[subnet].add_module('Resizing2',
                                        UpsamplingLayer(self.network[subnet][-1].output_shape, scale_factor=2,
                                                        mode='bilinear'))
        self.network[subnet].add_module('Stackingconv2',
                                        StackingConv(self.network[subnet][-1].output_shape, 128, 5,
                                                     5, False, 'He_normal', activation='lrelu', **kwargs))
        self.network[subnet].add_module('Resizing3',
                                        UpsamplingLayer(self.network[subnet][-1].output_shape, scale_factor=2,
                                                        mode='bilinear'))
        self.network[subnet].add_module('Stackingconv3',
                                        StackingConv(self.network[subnet][-1].output_shape, 128, 5,
                                                     7, False, 'He_normal', activation='lrelu', **kwargs))
        self.network[subnet].add_module('Resizing4',
                                        UpsamplingLayer(self.network[subnet][-1].output_shape, scale_factor=2,
                                                        mode='bilinear'))
        self.network[subnet].add_module('Stackingconv4',
                                        StackingConv(self.network[subnet][-1].output_shape, 64, 5,
                                                     9, False, 'He_normal', activation='lrelu', **kwargs))
        self.network[subnet].add_module('Output',
                                        Conv2DLayer(self.network[subnet][-1].output_shape, 3, 5, 'He_normal',
                                                    False, activation='tanh'))
        self.trainable = []
        for subnet in ('encoder', 'decoder'):
            for layer in self.network[subnet]:
                self.trainable += layer.trainable
        self.optimizer = optim.Adam(self.trainable, 1e-3)
        self.network['encoder'].cuda()
        self.network['decoder'].cuda()

    def forward(self, input):
        output = self.network['encoder'](input)
        output = self.network['decoder'](output)
        return output

    def learn(self, loss=None, pred=None, target=None, **kwargs):
        # compute_cost is not defined in this toy example; a precomputed loss is always passed in.
        cost = loss if loss is not None else self.compute_cost(pred, target, **kwargs)
        cost.backward(retain_graph=True)
        self.optimizer.step()
        self.optimizer.zero_grad()
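
# Thin wrapper around nn.Upsample that also tracks the resulting output shape.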
class UpsamplingLayer(Layer):
    def __init__(self, input_shape, new_shape=None, scale_factor=2, mode='bilinear'):
        """
        :param input_shape: (C, H, W) shape of the input, without the batch dimension
        :param new_shape: target (H, W) size; if given, scale_factor is ignored
        :param scale_factor: an int or a (sH, sW) pair
        :param mode: interpolation mode passed to nn.Upsample
        """
        assert len(input_shape) == 3, 'input_shape must have 3 elements. Received %d.' % len(input_shape)
        assert isinstance(scale_factor, (int, list, tuple)), 'scale_factor must be an int, a list or a tuple. ' \
            'Received %s.' % type(scale_factor)
        super(UpsamplingLayer, self).__init__()
        self.input_shape = input_shape
        self.new_shape = new_shape
        self.scale_factor = scale_factor
        self.upsample = nn.Upsample(new_shape, mode=mode, align_corners=False) if new_shape is not None \
            else nn.Upsample(scale_factor=scale_factor, mode=mode, align_corners=False)
        print(self.upsample)

    @property
    def output_shape(self):
        if self.new_shape is not None:
            return (self.input_shape[0], self.new_shape[0], self.new_shape[1])
        if isinstance(self.scale_factor, int):
            return (self.input_shape[0], self.input_shape[1]*self.scale_factor,
                    self.input_shape[2]*self.scale_factor)
        return (self.input_shape[0], self.input_shape[1]*self.scale_factor[0],
                self.input_shape[2]*self.scale_factor[1])

    def forward(self, input):
        input = self.upsample(input)
        return input
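
# Repeatedly rebuilds the network and runs a short dummy training loop on random data.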
def train(**kwargs):
    for i in range(10):
        T.cuda.empty_cache()
        net = MyNet(**kwargs)
        for seq in range(10):
            X = T.rand(2, 3, 240, 320).cuda()
            iteration = 0
            while iteration < 10:
                iteration += 1
                if iteration % 100 == 0:
                    print('Iteration %d' % iteration)
                net.train()
                Y = net(X)
                loss = T.mean((Y - X) ** 2)
                net.learn(loss)
                cost = loss.cpu().data.numpy()
                net.eval()
if __name__ == '__main__':
    train()
This toy code can reproduce the problem: it occupies 10GB of memory on Windows but only a little over 4GB on Linux. It can also reproduce the problem I reported here. The only problem I can't reproduce with this code is this one; it seems to happen only when the training loop is less trivial than this.
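
For what it's worth, here is a minimal sketch of how the allocator's view could be logged on both platforms (assuming torch.cuda.memory_allocated/max_memory_allocated/memory_cached are available in this PyTorch version; note these report the caching allocator's usage, which is lower than the process memory shown by Task Manager or top):

import torch as T

def report_gpu_memory(tag=''):
    # Caching-allocator statistics in MB; OS-level process usage will be higher.
    print('%s allocated %.1f MB, peak %.1f MB, cached %.1f MB' % (
        tag,
        T.cuda.memory_allocated() / 1024 ** 2,
        T.cuda.max_memory_allocated() / 1024 ** 2,
        T.cuda.memory_cached() / 1024 ** 2))

# e.g. call report_gpu_memory('after net %d' % i) at the end of each outer loop in train()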