Memory usage with pytorch-cpu on Windows

import torch as T
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np


if T.cuda.is_available():
    T.cuda.set_device(2)

function = {'relu': lambda x, **kwargs: F.relu(x, inplace=True), 'linear': lambda x, **kwargs: x,
            'lrelu': lambda x, **kwargs: lrelu(x, **kwargs), 'tanh': lambda x, **kwargs: T.tanh(x)}

init = {'He_normal': nn.init.kaiming_normal_, 'He_uniform': nn.init.kaiming_uniform_,
        'Xavier_normal': nn.init.xavier_normal_, 'Xavier_uniform': nn.init.xavier_uniform_}


def lrelu(x, **kwargs):
    alpha = kwargs.get('alpha', 0.1)
    return F.leaky_relu(x, alpha, inplace=True)


class Layer(nn.Module):
    def __init__(self):
        super(Layer, self).__init__()
        self.trainable = []
        self.regularizable = []

    @property
    def output_shape(self):
        return

    def forward(self, input):
        raise NotImplementedError

    def reset(self):
        pass


class BNLayer(Layer):
    def __init__(self, input_shape, epsilon=1e-4, running_average_factor=1e-1, activation='relu', no_scale=False,
                 show=True, **kwargs):
        """
        :param input_shape:
        :param epsilon:
        :param running_average_factor:
        :param activation:
        :param no_scale:
        :param args:
        """
        super(BNLayer, self).__init__()

        self.input_shape = tuple(input_shape) if isinstance(input_shape, (tuple, list)) else input_shape
        self.epsilon = np.float32(epsilon)
        self.running_average_factor = running_average_factor
        self.activation = function[activation]
        self.no_scale = no_scale
        self.kwargs = kwargs

        self.bn = nn.BatchNorm1d(self.input_shape, epsilon, running_average_factor) if isinstance(input_shape, int) \
            else nn.BatchNorm2d(self.input_shape[0], epsilon, running_average_factor)
        params = list(self.bn.parameters())
        self.trainable = [params[1]] if self.no_scale else list(params)
        self.regularizable = [params[0]] if not self.no_scale else []

        self.W_values = self.bn.weight.data.numpy().copy()
        self.b_values = self.bn.bias.data.numpy().copy()

        if show:
            print(self.bn)

    def forward(self, input):
        input = self.activation(self.bn(input), **self.kwargs)
        return input

    @property
    def output_shape(self):
        return self.input_shape


class Conv2DLayer(Layer):
    def __init__(self, input_shape, num_filters, filter_size, init_mode='Xavier_normal', no_bias=True,
                 border_mode='half', stride=(1, 1), dilation=(1, 1), activation='relu', groups=1, show=True, **kwargs):
        """
        :param input_shape:
        :param num_filters:
        :param filter_size:
        :param init_mode: Xavier_normal, Xavier_uniform, He_normal, He_uniform
        :param no_bias:
        :param border_mode:
        :param stride:
        :param dilation:
        :param activation:
        :param groups:
        :param args:
        """
        assert isinstance(input_shape, (list, tuple)), \
            'input_shape must be list or tuple. Received %s' % type(input_shape)
        assert len(input_shape) == 3, \
            'input_shape must have 3 elements. Received %d' % len(input_shape)
        assert isinstance(num_filters, int) and isinstance(filter_size, (int, list, tuple))
        assert isinstance(border_mode, (int, list, tuple, str)), \
            "border_mode should be either 'int', 'list', 'tuple' or 'str', got {}".format(type(border_mode))
        assert isinstance(stride, (int, list, tuple))
        super(Conv2DLayer, self).__init__()
        self.input_shape = tuple(input_shape)
        self.filter_shape = (num_filters, input_shape[0], filter_size[0], filter_size[1]) if isinstance(filter_size, (list, tuple)) \
            else (num_filters, input_shape[0], filter_size, filter_size)
        self.no_bias = no_bias
        self.activation = function[activation]
        self.init_mode = init_mode
        self.border_mode = border_mode
        self.stride = tuple(stride) if isinstance(stride, (tuple, list)) else (stride, stride)
        self.dilation = (dilation, dilation) if isinstance(dilation, int) else dilation
        self.groups = groups
        self.kwargs = kwargs

        k1, k2 = self.filter_shape[2] + (self.filter_shape[2] - 1)*(self.dilation[0] - 1), \
                 self.filter_shape[3] + (self.filter_shape[3] - 1)*(self.dilation[1] - 1)
        if isinstance(self.border_mode, str):
            if self.border_mode == 'half':
                self.padding = (k1 // 2, k2 // 2)
            elif self.border_mode == 'valid':
                self.padding = (0, 0)
            elif self.border_mode == 'full':
                self.padding = (k1 - 1, k2 - 1)
            else:
                raise NotImplementedError
        elif isinstance(self.border_mode, (list, tuple)):
            self.padding = tuple(self.border_mode)
        elif isinstance(self.border_mode, int):
            self.padding = (self.border_mode, self.border_mode)
        else:
            raise NotImplementedError

        self.conv = nn.Conv2d(int(input_shape[0]), num_filters, filter_size, stride, self.padding, dilation, bias=not no_bias, groups=groups)
        params = list(self.conv.parameters())
        self.trainable = list(params)
        self.regularizable = [params[0]]

        init[init_mode](self.conv.weight)
        self.W_values = self.conv.weight.data.numpy().copy()
        if not self.no_bias:
            self.b_values = np.zeros(self.filter_shape[0], dtype='float32')
            self.conv.bias.data = T.tensor(self.b_values)

        if show:
            print(self)

    def forward(self, input):
        input = self.activation(self.conv(input), **self.kwargs)
        return input

    @property
    def output_shape(self):
        size = list(self.input_shape)
        assert len(size) == 3, "Shape must consist of 3 elements only"

        k1, k2 = self.filter_shape[2] + (self.filter_shape[2] - 1)*(self.dilation[0] - 1), \
                 self.filter_shape[3] + (self.filter_shape[3] - 1)*(self.dilation[1] - 1)

        size[1] = (size[1] - k1 + 2*self.padding[0]) // self.stride[0] + 1
        size[2] = (size[2] - k2 + 2*self.padding[1]) // self.stride[1] + 1

        size[0] = self.filter_shape[0]
        return tuple(size)


class ConvBNAct(Layer):
    def __init__(self, input_shape, num_filters, filter_size, init_mode='Xavier_normal', no_bias=True,
                 border_mode='half', stride=1, activation='relu', dilation=1, epsilon=1e-4, running_average_factor=1e-1,
                 no_scale=False, groups=1, show=True, **kwargs):
        super(ConvBNAct, self).__init__()
        self.input_shape = input_shape
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.activation = activation
        self.stride = stride
        self.padding = border_mode
        self.conv = Conv2DLayer(input_shape, num_filters, filter_size, init_mode, no_bias, border_mode, stride,
                                dilation, activation='linear', groups=groups, show=False, **kwargs)
        self.bn = BNLayer(self.conv.output_shape, epsilon, running_average_factor, activation, no_scale, show=False, **kwargs)
        self.trainable += self.conv.trainable + self.bn.trainable
        self.regularizable += self.conv.regularizable + self.bn.regularizable

        if show:
            print(self)

    def forward(self, input):
        input = self.conv(input)
        input = self.bn(input)
        return input

    @property
    def output_shape(self):
        return self.bn.output_shape


class StackingConv(Layer):
    def __init__(self, input_shape, num_filters, filter_size, num_layers, batch_norm=False, init_mode='Xavier_normal',
                 no_bias=True, border_mode='half', stride=(1, 1), dilation=(1, 1), activation='relu', groups=1, **kwargs):
        assert num_layers > 1, 'num_layers must be greater than 1, got %d' % num_layers
        super(StackingConv, self).__init__()
        self.input_shape = input_shape
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.stride = stride
        self.activation = activation
        self.num_layers = num_layers
        self.batch_norm = batch_norm

        self.block = nn.Sequential()
        shape = tuple(input_shape)
        conv_layer = ConvBNAct if batch_norm else Conv2DLayer
        for num in range(num_layers - 1):
            self.block.add_module('Stacking_conv%d' % (num+1), conv_layer(input_shape=shape, num_filters=num_filters,
                                                                          filter_size=filter_size, init_mode=init_mode,
                                                                          stride=1, border_mode=border_mode, dilation=dilation,
                                                                          activation=activation, no_bias=no_bias, show=False, **kwargs))
            shape = self.block[-1].output_shape
            self.trainable += self.block[-1].trainable
            self.regularizable += self.block[-1].regularizable
        self.block.add_module('Stacking_conv%d' % num_layers, conv_layer(input_shape=self.block[-1].output_shape,
                                                                         num_filters=num_filters, no_bias=no_bias,
                                                                         filter_size=filter_size, init_mode=init_mode,
                                                                         stride=stride, border_mode=border_mode,
                                                                         dilation=dilation, activation=activation, show=False, **kwargs))
        self.trainable += self.block[-1].trainable
        self.regularizable += self.block[-1].regularizable

        print(self)

    def forward(self, input):
        input = self.block(input)
        return input

    @property
    def output_shape(self):
        return self.block[-1].output_shape


class ResNetBasicBlock(Layer):
    def __init__(self, input_shape, num_filters, filter_size=3, stride=1, activation='relu', groups=1, **kwargs):
        super(ResNetBasicBlock, self).__init__()
        self.input_shape = input_shape
        self.num_filters = num_filters
        self.filter_size = filter_size
        self.stride = stride
        self.activation = function[activation]
        self.groups = groups
        self.kwargs = kwargs
        self.convbnact1 = ConvBNAct(input_shape, num_filters, filter_size, 'He_normal', stride=stride, activation=activation,
                                    show=False, **kwargs)
        self.convbnact2 = ConvBNAct(self.convbnact1.output_shape, num_filters, filter_size, 'He_normal', stride=1,
                                    activation='linear', show=False, **kwargs)
        self.trainable += self.convbnact1.trainable + self.convbnact2.trainable
        self.regularizable += self.convbnact1.regularizable + self.convbnact2.regularizable

        self.downsample = lambda x: x
        if stride > 1 or input_shape[0] != num_filters:
            self.downsample = ConvBNAct(input_shape, num_filters, 1, stride=stride, no_bias=True, activation='linear')
            self.trainable += self.downsample.trainable
            self.regularizable += self.downsample.regularizable

        print(self)

    def forward(self, input):
        res = input
        input = self.convbnact1(input)
        input = self.convbnact2(input)
        input += self.downsample(res)
        return self.activation(input, **self.kwargs)

    @property
    def output_shape(self):
        return self.convbnact2.output_shape


class MyNet(nn.Sequential):
    def __init__(self, **kwargs):
        super(MyNet, self).__init__()

        self.input_tensor_shape = [3, 240, 320]
        self.network = {'encoder': nn.Sequential(), 'decoder': nn.Sequential()}

        kwargs = {'alpha': .2}
        subnet = 'encoder'
        self.network[subnet].add_module('first_conv', ConvBNAct(self.input_tensor_shape, 64, 7, 'He_normal',
                                                                       stride=2, activation='lrelu', **kwargs))

        block_size = list('ab')
        for c in block_size:
            self.network[subnet].add_module('ResBlock1_%s' % c,
                                            ResNetBasicBlock(self.network[subnet][-1].output_shape,
                                                                    64, activation='lrelu', **kwargs))

        block_size = list('ab')
        for c in block_size:
            if c == 'a':
                self.network[subnet].add_module('ResBlock2_%s' % c,
                                                ResNetBasicBlock(self.network[subnet][-1].output_shape,
                                                                        128, stride=2, activation='lrelu', **kwargs))
            else:
                self.network[subnet].add_module('ResBlock2_%s' % c,
                                                ResNetBasicBlock(self.network[subnet][-1].output_shape,
                                                                        128, stride=1, activation='lrelu', **kwargs))

        block_size = list('ab')
        for c in block_size:
            if c == 'a':
                self.network[subnet].add_module('ResBlock3_%s' % c,
                                                ResNetBasicBlock(self.network[subnet][-1].output_shape,
                                                                        256, stride=2, activation='lrelu', **kwargs))
            else:
                self.network[subnet].add_module('ResBlock3_%s' % c,
                                                ResNetBasicBlock(self.network[subnet][-1].output_shape,
                                                                        256, stride=1, activation='lrelu', **kwargs))

        block_size = list('ab')
        for c in block_size:
            if c == 'a':
                self.network[subnet].add_module('ResBlock4_%s' % c,
                                                ResNetBasicBlock(self.network[subnet][-1].output_shape,
                                                                        512, stride=2, activation='lrelu', **kwargs))
            else:
                self.network[subnet].add_module('ResBlock4_%s' % c,
                                                ResNetBasicBlock(self.network[subnet][-1].output_shape,
                                                                        512, stride=1, activation='lrelu', **kwargs))

        subnet = 'decoder'
        self.network[subnet].add_module('Resizing1',
                                        UpsamplingLayer(self.network['encoder'][-1].output_shape, scale_factor=2,
                                                               mode='bilinear'))
        self.network[subnet].add_module('Stackingconv1',
                                        StackingConv(self.network[subnet][-1].output_shape, 256, 5,
                                                            3, False, 'He_normal', activation='lrelu', **kwargs))
        #
        self.network[subnet].add_module('Resizing2',
                                        UpsamplingLayer(self.network[subnet][-1].output_shape, scale_factor=2,
                                                               mode='bilinear'))
        self.network[subnet].add_module('Stackingconv2',
                                        StackingConv(self.network[subnet][-1].output_shape, 128, 5,
                                                            5, False, 'He_normal', activation='lrelu', **kwargs))

        self.network[subnet].add_module('Resizing3',
                                        UpsamplingLayer(self.network[subnet][-1].output_shape, scale_factor=2,
                                                               mode='bilinear'))
        self.network[subnet].add_module('Stackingconv3',
                                        StackingConv(self.network[subnet][-1].output_shape, 128, 5,
                                                            7, False, 'He_normal', activation='lrelu', **kwargs))

        self.network[subnet].add_module('Resizing4',
                                        UpsamplingLayer(self.network[subnet][-1].output_shape, scale_factor=2,
                                                               mode='bilinear'))
        self.network[subnet].add_module('Stackingconv4',
                                        StackingConv(self.network[subnet][-1].output_shape, 64, 5,
                                                            9, False, 'He_normal', activation='lrelu', **kwargs))

        self.network[subnet].add_module('Output',
                                        Conv2DLayer(self.network[subnet][-1].output_shape, 3, 5, 'He_normal',
                                                           False, activation='tanh'))
        self.trainable = []
        for subnet in ('encoder', 'decoder'):
            for layer in self.network[subnet]:
                self.trainable += layer.trainable
        self.optimizer = optim.Adam(self.trainable, 1e-3)
        if T.cuda.is_available():
            self.network['encoder'].cuda()
            self.network['decoder'].cuda()

    def forward(self, input):
        output = self.network['encoder'](input)
        output = self.network['decoder'](output)
        return output

    def learn(self, loss=None, pred=None, target=None, **kwargs):
        cost = loss if loss is not None else self.compute_cost(pred, target, **kwargs)
        cost.backward(retain_graph=True)
        self.optimizer.step()
        self.optimizer.zero_grad()


class UpsamplingLayer(Layer):
    def __init__(self, input_shape, new_shape=None, scale_factor=2, mode='bilinear'):
        """
        :param input_shape:
        :param new_shape:
        :param scale_factor:
        """
        assert len(input_shape) == 3, 'input_shape must have 3 elements. Received %d.' % len(input_shape)
        assert isinstance(scale_factor, (int, list, tuple)), 'scale_factor must be an int, a list or a tuple. ' \
                                                             'Received %s.' % type(scale_factor)
        super(UpsamplingLayer, self).__init__()
        self.input_shape = input_shape
        self.new_shape = new_shape
        self.scale_factor = scale_factor
        self.upsample = nn.Upsample(new_shape, mode=mode, align_corners=False) if new_shape is not None \
            else nn.Upsample(scale_factor=scale_factor, mode=mode, align_corners=False)
        print(self.upsample)

    @property
    def output_shape(self):
        if self.new_shape is not None:
            return (self.input_shape[0], self.new_shape[0], self.new_shape[1])
        sf = (self.scale_factor, self.scale_factor) if isinstance(self.scale_factor, int) else self.scale_factor
        return (self.input_shape[0], self.input_shape[1]*sf[0], self.input_shape[2]*sf[1])

    def forward(self, input):
        input = self.upsample(input)
        return input


def train(**kwargs):
    for i in range(10):
        if T.cuda.is_available():
            T.cuda.empty_cache()
        net = MyNet(**kwargs)
        for seq in range(10):
            X = T.rand(2, 3, 240, 320)
            if T.cuda.is_available():
                X = X.cuda()
            iteration = 0
            while iteration < 10:
                iteration += 1
                if iteration % 100 == 0:
                    print('Iteration %d' % iteration)

                net.train()
                Y = net(X)

                loss = T.mean((Y - X) ** 2)
                net.learn(loss)
                cost = loss.detach().cpu().numpy()

            net.eval()


if __name__ == '__main__':
    train()

This toy code reproduces the problem: it occupies about 10 GB of memory on Windows but only a little over 4 GB on Linux. It also reproduces the problem I reported here. The only problem I cannot reproduce with this code is this; it seems to happen only when the training loop is less trivial than the one above.
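
For reference, here is a minimal sketch (assuming psutil is installed; it is not part of the script above) that logs the resident set size of the process, which would make the Windows vs. Linux numbers easier to compare at matching iterations:

import os

import psutil


def log_memory(tag=''):
    # Print the resident set size (RSS) of the current process in MB.
    rss = psutil.Process(os.getpid()).memory_info().rss
    print('%s RSS: %.1f MB' % (tag, rss / 1024 ** 2))

Calling log_memory('iteration %d' % iteration) inside the while loop of train() would print the host memory usage as training progresses.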