[Caffe2] SquaredL2Distance as output layer of CNN network. Error: Exception when creating gradient for [Cast]

Hello all,

I am getting the following error when I use the SquaredL2Distance operator as the output layer of my CNN network (1st attempt):

dist = model.net.SquaredL2Distance([label, fc9_], 'dist')
predictions = dist.AveragedLoss([], ['predictions'])
carlos@carlos-ubuntu:~/Documents/git/Caffe2_scripts/caffe2_torcs_predictor$ python CNNTrainer_dpnet_dpnet.py 
GPU mode selected
sgd optimizer selected
WARNING: Logging before InitGoogleLogging() is written to STDERR
W1202 21:31:09.273607 27931 operator.cc:89] Operator Conv does not support the requested feature. Msg: The current padding scheme leads to unequal padding on the left and right, which is not supported by cudnn.. Proto is: input: "data" input: "conv1__w" input: "conv1__b" output: "conv1_" name: "" type: "Conv" arg { name: "kernel" i: 11 } arg { name: "pad_l" i: 4 } arg { name: "pad_b" i: 4 } arg { name: "exhaustive_search" i: 0 } arg { name: "stride" i: 4 } arg { name: "pad_r" i: 3 } arg { name: "order" s: "NHWC" } arg { name: "pad_t" i: 5 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN"
W1202 21:31:09.273897 27931 operator.cc:89] Operator Conv does not support the requested feature. Msg: The current padding scheme leads to unequal padding on the left and right, which is not supported by cudnn.. Proto is: input: "data" input: "conv1__w" input: "conv1__b" output: "conv1_" name: "" type: "Conv" arg { name: "kernel" i: 11 } arg { name: "pad_l" i: 4 } arg { name: "pad_b" i: 4 } arg { name: "exhaustive_search" i: 0 } arg { name: "stride" i: 4 } arg { name: "pad_r" i: 3 } arg { name: "order" s: "NHWC" } arg { name: "pad_t" i: 5 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN"
I1202 21:31:09.273910 27931 operator.cc:167] Engine CUDNN is not available for operator Conv.
W1202 21:31:09.275804 27931 operator.cc:89] Operator MaxPool does not support the requested feature. Msg: The current padding scheme leads to unequal padding on the left and right, which is not supported by cudnn.. Proto is: input: "conv1_" output: "pool1_" name: "" type: "MaxPool" arg { name: "kernel" i: 3 } arg { name: "pad_l" i: 1 } arg { name: "pad_b" i: 1 } arg { name: "cudnn_exhaustive_search" i: 0 } arg { name: "stride" i: 2 } arg { name: "pad_r" i: 0 } arg { name: "order" s: "NHWC" } arg { name: "pad_t" i: 1 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN"
W1202 21:31:09.276072 27931 operator.cc:89] Operator MaxPool does not support the requested feature. Msg: The current padding scheme leads to unequal padding on the left and right, which is not supported by cudnn.. Proto is: input: "conv1_" output: "pool1_" name: "" type: "MaxPool" arg { name: "kernel" i: 3 } arg { name: "pad_l" i: 1 } arg { name: "pad_b" i: 1 } arg { name: "cudnn_exhaustive_search" i: 0 } arg { name: "stride" i: 2 } arg { name: "pad_r" i: 0 } arg { name: "order" s: "NHWC" } arg { name: "pad_t" i: 1 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN"
I1202 21:31:09.276083 27931 operator.cc:167] Engine CUDNN is not available for operator MaxPool.
W1202 21:31:09.276993 27931 operator.cc:89] Operator MaxPool does not support the requested feature. Msg: The current padding scheme leads to unequal padding on the left and right, which is not supported by cudnn.. Proto is: input: "conv5_" output: "pool5_" name: "" type: "MaxPool" arg { name: "kernel" i: 3 } arg { name: "pad_l" i: 1 } arg { name: "pad_b" i: 0 } arg { name: "cudnn_exhaustive_search" i: 0 } arg { name: "stride" i: 2 } arg { name: "pad_r" i: 1 } arg { name: "order" s: "NHWC" } arg { name: "pad_t" i: 1 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN"
W1202 21:31:09.277192 27931 operator.cc:89] Operator MaxPool does not support the requested feature. Msg: The current padding scheme leads to unequal padding on the left and right, which is not supported by cudnn.. Proto is: input: "conv5_" output: "pool5_" name: "" type: "MaxPool" arg { name: "kernel" i: 3 } arg { name: "pad_l" i: 1 } arg { name: "pad_b" i: 0 } arg { name: "cudnn_exhaustive_search" i: 0 } arg { name: "stride" i: 2 } arg { name: "pad_r" i: 1 } arg { name: "order" s: "NHWC" } arg { name: "pad_t" i: 1 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN"
I1202 21:31:09.277199 27931 operator.cc:167] Engine CUDNN is not available for operator MaxPool.
W1202 21:31:13.799096 27931 operator.cc:89] Operator ConvGradient does not support the requested feature. Msg: The current padding scheme leads to unequal padding on the left and right, which is not supported by cudnn.. Proto is: input: "data" input: "conv1__w" input: "conv1__grad" output: "conv1__w_grad" output: "conv1__b_grad" output: "data_grad" name: "" type: "ConvGradient" arg { name: "kernel" i: 11 } arg { name: "pad_l" i: 4 } arg { name: "pad_b" i: 4 } arg { name: "exhaustive_search" i: 0 } arg { name: "stride" i: 4 } arg { name: "pad_r" i: 3 } arg { name: "order" s: "NHWC" } arg { name: "pad_t" i: 5 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN" is_gradient_op: true
W1202 21:31:13.799773 27931 operator.cc:89] Operator ConvGradient does not support the requested feature. Msg: The current padding scheme leads to unequal padding on the left and right, which is not supported by cudnn.. Proto is: input: "data" input: "conv1__w" input: "conv1__grad" output: "conv1__w_grad" output: "conv1__b_grad" output: "data_grad" name: "" type: "ConvGradient" arg { name: "kernel" i: 11 } arg { name: "pad_l" i: 4 } arg { name: "pad_b" i: 4 } arg { name: "exhaustive_search" i: 0 } arg { name: "stride" i: 4 } arg { name: "pad_r" i: 3 } arg { name: "order" s: "NHWC" } arg { name: "pad_t" i: 5 } device_option { device_type: 1 cuda_gpu_id: 0 } engine: "CUDNN" is_gradient_op: true
I1202 21:31:13.799808 27931 operator.cc:167] Engine CUDNN is not available for operator ConvGradient.
== Starting Training for 100 epochs ==
WARNING:caffe2.python.workspace:Original python traceback for operator `28` in network `train_net` in exception above (most recent call last):
WARNING:caffe2.python.workspace:  File "CNNTrainer_dpnet_dpnet.py", line 24, in <module>
WARNING:caffe2.python.workspace:  File "/home/carlos/Documents/git/Caffe2_scripts/caffe2_torcs_predictor/CNNCreator_dpnet_dpnet.py", line 146, in train
WARNING:caffe2.python.workspace:  File "/home/carlos/Documents/git/Caffe2_scripts/caffe2_torcs_predictor/CNNCreator_dpnet_dpnet.py", line 86, in create_model
Traceback (most recent call last):
  File "CNNTrainer_dpnet_dpnet.py", line 24, in <module>
    stepsize=8000
  File "/home/carlos/Documents/git/Caffe2_scripts/caffe2_torcs_predictor/CNNCreator_dpnet_dpnet.py", line 159, in train
    workspace.RunNet(train_model.net)
  File "/home/carlos/Documents/git/pytorch/build/caffe2/python/workspace.py", line 217, in RunNet
    StringifyNetName(name), num_iter, allow_fail,
  File "/home/carlos/Documents/git/pytorch/build/caffe2/python/workspace.py", line 178, in CallWithExceptionIntercept
    return func(*args, **kwargs)
RuntimeError: [enforce fail at tensor.h:495] IsType<T>(). Tensor type mismatch, caller expects elements to be float while tensor contains double Error from operator: 
input: "label" input: "fc9_" output: "dist" name: "" type: "SquaredL2Distance" device_option { device_type: 1 cuda_gpu_id: 0 }
** while accessing input: label
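
To double-check what the DB reader actually produces, the label blob can be fetched right after the failed RunNet call (with the simple net executor the ops before SquaredL2Distance have already run by then, so the blob should still be in the workspace). A minimal check, run from the same Python session:

from caffe2.python import workspace

lbl = workspace.FetchBlob('label')           # copies the blob to a numpy array
print('label dtype: ' + str(lbl.dtype))      # I expected float32, the error suggests float64
print('label shape: ' + str(lbl.shape))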

Then I tried to fix the error by converting the label to float (even though it is already float32, see below) using the Cast operator (2nd attempt):

label_float = model.Cast(label, None, to=core.DataType.FLOAT)
dist = model.net.SquaredL2Distance([label_float, fc9_], 'dist')
predictions = dist.AveragedLoss([], ['predictions'])

However, I got the following error:

carlos@carlos-ubuntu:~/Documents/git/Caffe2_scripts/caffe2_torcs_predictor$ python CNNTrainer_dpnet_dpnet.py 
GPU mode selected
Traceback (most recent call last):
  File "CNNTrainer_dpnet_dpnet.py", line 24, in <module>
    stepsize=8000
  File "/home/carlos/Documents/git/Caffe2_scripts/caffe2_torcs_predictor/CNNCreator_dpnet_dpnet.py", line 153, in train
    self.add_training_operators(train_model, predictions, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum)
  File "/home/carlos/Documents/git/Caffe2_scripts/caffe2_torcs_predictor/CNNCreator_dpnet_dpnet.py", line 103, in add_training_operators
    model.AddGradientOperators([loss])
  File "/home/carlos/Documents/git/pytorch/build/caffe2/python/model_helper.py", line 335, in AddGradientOperators
    self.grad_map = self.net.AddGradientOperators(*args, **kwargs)
  File "/home/carlos/Documents/git/pytorch/build/caffe2/python/core.py", line 1840, in AddGradientOperators
    self._net.op[skip:], ys)
  File "/home/carlos/Documents/git/pytorch/build/caffe2/python/core.py", line 1107, in GetBackwardPass
    return ir.GetBackwardPass(ys)
  File "/home/carlos/Documents/git/pytorch/build/caffe2/python/core.py", line 982, in GetBackwardPass
    forward_op_idx, all_input_to_grad)
  File "/home/carlos/Documents/git/pytorch/build/caffe2/python/core.py", line 932, in _GenerateGradientsForForwardOp
    forward_op, g_output)
  File "/home/carlos/Documents/git/pytorch/build/caffe2/python/core.py", line 1080, in GetGradientForOp
    format(op.type, e, str(op))
Exception: Exception when creating gradient for [Cast]:[enforce fail at cast_op.cc:139] argsHelper.HasSingleArgumentOfType<string>("from_type") || argsHelper.HasSingleArgumentOfType<int>("from_type"). Argument 'from_type' of type int or string is required to get the gradient of CastOp .
Op: 
input: "label"
output: "train_net/Cast"
name: ""
type: "Cast"
arg {
  name: "to"
  i: 1
}
device_option {
  device_type: 1
  cuda_gpu_id: 0
}
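
A workaround I am considering for the Cast gradient failure (not verified on the full network yet) is to wrap the cast label in StopGradient, the same way the data blob is handled in add_input below, so that AddGradientOperators never has to build a gradient for the Cast op. A minimal self-contained sketch on CPU; the blob names and sizes here are made up for the example:

from caffe2.python import core, model_helper, workspace, brew
import numpy as np

workspace.ResetWorkspace()
model = model_helper.ModelHelper(name="cast_grad_test")

# fake float64 labels and float32 features, shaped like a small batch
workspace.FeedBlob("label", np.random.rand(4, 14).astype(np.float64))
workspace.FeedBlob("data", np.random.rand(4, 16).astype(np.float32))

fc = brew.fc(model, "data", "fc", dim_in=16, dim_out=14)
label_float = model.Cast("label", "label_float", to=core.DataType.FLOAT)
label_stop = model.StopGradient(label_float, "label_stop")   # block the backward pass here
dist = model.net.SquaredL2Distance([label_stop, fc], "dist")
loss = dist.AveragedLoss([], ["loss"])
model.AddGradientOperators([loss])   # should no longer request a gradient for Cast

workspace.RunNetOnce(model.param_init_net)
workspace.RunNetOnce(model.net)
print(workspace.FetchBlob("loss"))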

Furthermore, the LMDB training dataset I created stores the image data as uint8 and the label as a 14-element float64 vector:
key: 00000001
image_data: shape: (210, 280, 3) type: uint8
indicators: shape: (14,) type: float64
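
Since the root cause of the first error seems to be that the indicators are serialized as float64, another option I am considering is to write them as float32 when building the LMDB. The writer script is not included in this report, so this is only a rough sketch assuming a TensorProtos-based writer (as in Caffe2's lmdb_create_example); 'indicators' is the per-image float64 numpy array:

import numpy as np
from caffe2.proto import caffe2_pb2

def make_label_proto(indicators):
    # serialize the 14 indicators as FLOAT (float32) instead of DOUBLE (float64)
    proto = caffe2_pb2.TensorProto()
    proto.data_type = caffe2_pb2.TensorProto.FLOAT
    proto.dims.extend(indicators.shape)
    proto.float_data.extend(indicators.astype(np.float32).flatten().tolist())
    return proto

With the labels stored as float32, the Cast from the second attempt would presumably not be needed at all.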

How can I fix this error? Is one of the workarounds sketched above the right approach, or is the missing 'from_type' in the Cast gradient a bug?

To Reproduce

Steps to reproduce the behavior:

  1. The project consists of two files: a trainer and a creator. Run the trainer, which calls the train function of the creator.

Trainer.py:

import logging
import CNNCreator_dpnet_dpnet

if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger()
    handler = logging.FileHandler("train.log", "w", encoding=None, delay=True)
    logger.addHandler(handler)

    dpnet_dpnet = CNNCreator_dpnet_dpnet.CNNCreator_dpnet_dpnet()
    dpnet_dpnet.train(
        num_epoch=100,
        batch_size=32,
        context='gpu',
        opt_type='sgd',
        base_learning_rate=0.01,
        policy='step',
        stepsize=8000
    )

CNNCreator_dpnet_dpnet.py:

from caffe2.python import workspace, core, model_helper, brew, optimizer
from caffe2.python.predictor import mobile_exporter
from caffe2.proto import caffe2_pb2
import numpy as np
import logging
import os
import sys
import lmdb
import leveldb

class CNNCreator_dpnet_dpnet:

    module = None
    _current_dir_ = os.path.join('./')
    _data_dir_    = os.path.join(_current_dir_, 'data', 'dpnet_dpnet')
    _model_dir_   = os.path.join(_current_dir_, 'model', 'dpnet_dpnet')

    INIT_NET    = os.path.join(_model_dir_, 'init_net.pb')
    PREDICT_NET = os.path.join(_model_dir_, 'predict_net.pb')

    def add_input(self, model, batch_size, db, db_type, device_opts):
        with core.DeviceScope(device_opts):
            # load the data
            data_uint8, label = brew.db_input(
                model,
                blobs_out=["data_uint8", "label"],
                batch_size=batch_size,
                db=db,
                db_type=db_type,
            )
            # cast the data to float
            data = model.Cast(data_uint8, "data", to=core.DataType.FLOAT)

            # scale data from [0,255] down to [0,1]
            data = model.Scale(data, data, scale=float(1./256))

            # don't need the gradient for the backward pass
            data = model.StopGradient(data, data)

            return data, label

    def create_model(self, model, data, label, device_opts):
        with core.DeviceScope(device_opts):

            data = data
            # data, output shape: {[3,210,280]}
            conv1_ = brew.conv(model, data, 'conv1_', dim_in=3, dim_out=96, kernel=11, stride=4, pad_t=5, pad_b=4, pad_l=4, pad_r=3) #legacy_pad=1)
            # conv1_, output shape: {[96,53,70]}
            relu1_ = brew.relu(model, conv1_, conv1_)
            pool1_ = brew.max_pool(model, relu1_, 'pool1_', kernel=3, stride=2, pad_t=1, pad_b=1, pad_l=1, pad_r=0) #legacy_pad=1)
            # pool1_, output shape: {[96,27,35]}
            conv2_ = brew.conv(model, pool1_, 'conv2_', dim_in=96, dim_out=256, kernel=5, stride=4, pad_t=1, pad_b=1, pad_l=1, pad_r=1) #legacy_pad=1)
            # conv2_, output shape: {[256,7,9]}
            relu2_ = brew.relu(model, conv2_, conv2_)
            pool2_ = brew.max_pool(model, relu2_, 'pool2_', kernel=3, stride=2, pad_t=1, pad_b=1, pad_l=1, pad_r=1) #legacy_pad=1)
            # pool2_, output shape: {[256,4,5]}
            conv3_ = brew.conv(model, pool2_, 'conv3_', dim_in=256, dim_out=384, kernel=3, stride=1, pad_t=1, pad_b=1, pad_l=1, pad_r=1) #legacy_pad=1)
            # conv3_, output shape: {[384,4,5]}
            relu3_ = brew.relu(model, conv3_, conv3_)
            conv4_ = brew.conv(model, relu3_, 'conv4_', dim_in=384, dim_out=384, kernel=3, stride=1, pad_t=1, pad_b=1, pad_l=1, pad_r=1) #legacy_pad=1)
            # conv4_, output shape: {[384,4,5]}
            relu4_ = brew.relu(model, conv4_, conv4_)
            conv5_ = brew.conv(model, relu4_, 'conv5_', dim_in=384, dim_out=256, kernel=3, stride=1, pad_t=1, pad_b=1, pad_l=1, pad_r=1) #legacy_pad=1)
            # conv5_, output shape: {[256,4,5]}
            relu5_ = brew.relu(model, conv5_, conv5_)
            pool5_ = brew.max_pool(model, relu5_, 'pool5_', kernel=3, stride=2, pad_t=1, pad_b=0, pad_l=1, pad_r=1) #legacy_pad=1)
            # pool5_, output shape: {[256,2,3]}
            fc5_ = brew.fc(model, pool5_, 'fc5_', dim_in=256 * 2 * 3, dim_out=4096)
            # fc5_, output shape: {[4096,1,1]}
            relu6_ = brew.relu(model, fc5_, fc5_)
            dropout6_ = brew.dropout(model, relu6_, 'dropout6_', ratio=0.5, is_test=False)
            fc6_ = brew.fc(model, dropout6_, 'fc6_', dim_in=4096, dim_out=4096)
            # fc6_, output shape: {[4096,1,1]}
            relu7_ = brew.relu(model, fc6_, fc6_)
            dropout7_ = brew.dropout(model, relu7_, 'dropout7_', ratio=0.5, is_test=False)
            fc7_ = brew.fc(model, dropout7_, 'fc7_', dim_in=4096, dim_out=256)
            # fc7_, output shape: {[256,1,1]}
            relu8_ = brew.relu(model, fc7_, fc7_)
            dropout8_ = brew.dropout(model, relu8_, 'dropout8_', ratio=0.5, is_test=False)
            relu9_ = brew.relu(model, dropout8_, dropout8_)
            fc9_ = brew.fc(model, relu9_, 'fc9_', dim_in=256, dim_out=14)
            # fc9_, output shape: {[14,1,1]}

            # FIRST ATTEMPT. Error got: Tensor type mismatch, caller expects elements to be float while tensor contains double
            dist = model.net.SquaredL2Distance([label, fc9_], 'dist')
            predictions = dist.AveragedLoss([], ['predictions'])

            '''
            # SECOND ATTEMPT. Error got: Exception when creating gradient for [Cast]
            label_float = model.Cast(label, None, to=core.DataType.FLOAT)
            dist = model.net.SquaredL2Distance([label_float, fc9_], 'dist')
            predictions = dist.AveragedLoss([], ['predictions'])
            '''

            return predictions

    # this adds the loss and optimizer
    def add_training_operators(self, model, output, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum):
        with core.DeviceScope(device_opts):
            xent = model.LabelCrossEntropy([output, label], 'xent')
            loss = model.AveragedLoss(xent, "loss")

            model.AddGradientOperators([loss])

            if opt_type == 'adam':
                if policy == 'step':
                    opt = optimizer.build_adam(model, base_learning_rate=base_learning_rate, policy=policy, stepsize=stepsize, beta1=beta1, beta2=beta2, epsilon=epsilon)
                elif policy == 'fixed' or policy == 'inv':
                    opt = optimizer.build_adam(model, base_learning_rate=base_learning_rate, policy=policy, beta1=beta1, beta2=beta2, epsilon=epsilon)
                print("adam optimizer selected")
            elif opt_type == 'sgd':
                if policy == 'step':
                    opt = optimizer.build_sgd(model, base_learning_rate=base_learning_rate, policy=policy, stepsize=stepsize, gamma=gamma, momentum=momentum)
                elif policy == 'fixed' or policy == 'inv':
                    opt = optimizer.build_sgd(model, base_learning_rate=base_learning_rate, policy=policy, gamma=gamma, momentum=momentum)
                print("sgd optimizer selected")
            elif opt_type == 'rmsprop':
                if policy == 'step':
                    opt = optimizer.build_rms_prop(model, base_learning_rate=base_learning_rate, policy=policy, stepsize=stepsize, decay=gamma, momentum=momentum, epsilon=epsilon)
                elif policy == 'fixed' or policy == 'inv':
                    opt = optimizer.build_rms_prop(model, base_learning_rate=base_learning_rate, policy=policy, decay=gamma, momentum=momentum, epsilon=epsilon)
                print("rmsprop optimizer selected")
            elif opt_type == 'adagrad':
                if policy == 'step':
                    opt = optimizer.build_adagrad(model, base_learning_rate=base_learning_rate, policy=policy, stepsize=stepsize, decay=gamma, epsilon=epsilon)
                elif policy == 'fixed' or policy == 'inv':
                    opt = optimizer.build_adagrad(model, base_learning_rate=base_learning_rate, policy=policy, decay=gamma, epsilon=epsilon)
                print("adagrad optimizer selected")

    def add_accuracy(self, model, output, label, device_opts, eval_metric):
        with core.DeviceScope(device_opts):
            if eval_metric == 'accuracy':
                accuracy = brew.accuracy(model, [output, label], "accuracy")
            elif eval_metric == 'top_k_accuracy':
                accuracy = brew.accuracy(model, [output, label], "accuracy", top_k=3)
            return accuracy

    def train(self, num_epoch=1000, batch_size=64, context='gpu', eval_metric='accuracy', opt_type='adam', base_learning_rate=0.001, weight_decay=0.001, policy='fixed', stepsize=1, epsilon=1E-8, beta1=0.9, beta2=0.999, gamma=0.999, momentum=0.9):
        if context == 'cpu':
            device_opts = core.DeviceOption(caffe2_pb2.CPU, 0)
            print("CPU mode selected")
        elif context == 'gpu':
            device_opts = core.DeviceOption(caffe2_pb2.CUDA, 0)
            print("GPU mode selected")

        workspace.ResetWorkspace(self._model_dir_)

        arg_scope = {"order": "NHWC"}
        # == Training model ==
        train_model = model_helper.ModelHelper(name="train_net", arg_scope=arg_scope)
        data, label = self.add_input(train_model, batch_size=batch_size, db=os.path.join(self._data_dir_, 'torcs-train-nchw-lmdb'), db_type='lmdb', device_opts=device_opts)
        predictions = self.create_model(train_model, data, label, device_opts=device_opts)
        self.add_training_operators(train_model, predictions, label, device_opts, opt_type, base_learning_rate, policy, stepsize, epsilon, beta1, beta2, gamma, momentum)
        self.add_accuracy(train_model, predictions, label, device_opts, eval_metric)
        with core.DeviceScope(device_opts):
            brew.add_weight_decay(train_model, weight_decay)

        # Initialize and create the training network
        workspace.RunNetOnce(train_model.param_init_net)
        workspace.CreateNet(train_model.net, overwrite=True)

        # Main Training Loop
        print("== Starting Training for " + str(num_epoch) + " epochs ==")
        for i in range(num_epoch):
            workspace.RunNet(train_model.net)

            if i % 50 == 0:
                print('Iter ' + str(i) + ': ' + 'Loss ' + str(workspace.FetchBlob("loss")) + ' - ' + 'Accuracy ' + str(workspace.FetchBlob('accuracy')))
        print("Training done")

        # == Deployment model. ==
        # We simply need the main AddModel part.
        deploy_model = model_helper.ModelHelper(name="deploy_net", arg_scope=arg_scope, init_params=False)
        self.create_model(deploy_model, "data", label, device_opts)

        print("Saving deploy model")
        self.save_net(self.INIT_NET, self.PREDICT_NET, deploy_model)

    def save_net(self, init_net_path, predict_net_path, model):

        init_net, predict_net = mobile_exporter.Export(
            workspace,
            model.net,
            model.params
        )

        try:
            os.makedirs(self._model_dir_)
        except OSError:
            if not os.path.isdir(self._model_dir_):
                raise

        print("Save the model to init_net.pb and predict_net.pb")
        with open(predict_net_path, 'wb') as f:
            f.write(model.net._net.SerializeToString())
        with open(init_net_path, 'wb') as f:
            f.write(init_net.SerializeToString())

        print("Save the model to init_net.pbtxt and predict_net.pbtxt")

        with open(init_net_path.replace('.pb', '.pbtxt'), 'w') as f:
            f.write(str(init_net))
        with open(predict_net_path.replace('.pb', '.pbtxt'), 'w') as f:
            f.write(str(predict_net))
        print("== Saved init_net and predict_net ==")

    def load_net(self, init_net_path, predict_net_path, device_opts):
        if not os.path.isfile(init_net_path):
            logging.error("Network loading failure. File '" + os.path.abspath(init_net_path) + "' does not exist.")
            sys.exit(1)
        elif not os.path.isfile(predict_net_path):
            logging.error("Network loading failure. File '" + os.path.abspath(predict_net_path) + "' does not exist.")
            sys.exit(1)

        init_def = caffe2_pb2.NetDef()
        with open(init_net_path, 'rb') as f:
            init_def.ParseFromString(f.read())
            init_def.device_option.CopyFrom(device_opts)
            workspace.RunNetOnce(init_def.SerializeToString())

        net_def = caffe2_pb2.NetDef()
        with open(predict_net_path, 'rb') as f:
            net_def.ParseFromString(f.read())
            net_def.device_option.CopyFrom(device_opts)
            workspace.CreateNet(net_def.SerializeToString(), overwrite=True)
        print("== Loaded init_net and predict_net ==")

System

  • PyTorch Version (e.g., 1.0): Caffe2 tag v0.4.0
  • OS (e.g., Linux): Ubuntu 16.04
  • How you installed PyTorch (conda, pip, source): Built from source (tag v0.4.0)
  • Build command you used (if compiling from source):
  • Python version: Python 2.7
  • CUDA/cuDNN version: 8.0/7.0.5
  • GPU models and configuration: GTX 1050
  • Any other relevant information:

This issue looks similar to the one reported on AIX with the RNN sample script for Caffe2 ("Issue with caffe2 port on to AIX"):

System Information:

  • PyTorch Version (e.g., 1.0): Caffe2 tag v0.4.0
  • OS (e.g., Linux): AIX
  • How you installed PyTorch (conda, pip, source): Built from source (tag v0.4.0)
  • Build command you used (if compiling from source):
  • Python version: Python 2.7
  • CUDA/cuDNN version: NA
  • GPU models and configuration: NA

Any pointers on this will be helpful.