Compiling PyTorch with CUDA 9.1 inside a Docker container

Hi

I’m trying to compile PyTorch inside a Docker container. It compiles successfully, but when I use the resulting image to train a model, it crashes with the following stack trace (this does not happen when I install the prebuilt binary cu91/torch-0.3.1-cp36-cp36m-linux_x86_64.whl instead):

    model = model.cuda(cuda_device)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 216, in cuda
    return self._apply(lambda t: t.cuda(device))
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 146, in _apply
    module._apply(fn)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 146, in _apply
    module._apply(fn)
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 123, in _apply
    self.flatten_parameters()
  File "/opt/conda/lib/python3.6/site-packages/torch/nn/modules/rnn.py", line 111, in flatten_parameters
    params = rnn.get_parameters(fn, handle, fn.weight_buf)
  File "/opt/conda/lib/python3.6/site-packages/torch/backends/cudnn/rnn.py", line 165, in get_parameters
    assert filter_dim_a.prod() == filter_dim_a[0]
AssertionError
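
The traceback passes through flatten_parameters in torch/nn/modules/rnn.py, so the failure appears to happen while an RNN module is moved to the GPU. A minimal sketch of a reproduction, assuming a small nn.LSTM behaves like the recurrent layer in my model (the layer type and sizes here are placeholders, not taken from the real model):

```python
def try_rnn_on_gpu(cuda_device=0):
    """Move a small LSTM to the GPU, the step where the traceback above fails."""
    try:
        import torch
    except ImportError:
        return "skipped: torch is not installed"
    if not torch.cuda.is_available():
        return "skipped: no CUDA device visible"

    # Assumption: nn.LSTM with arbitrary sizes stands in for the real model's RNN.
    rnn = torch.nn.LSTM(input_size=8, hidden_size=16)
    try:
        # Same call as in the traceback: model.cuda(cuda_device)
        rnn.cuda(cuda_device)
    except AssertionError:
        return "failed: AssertionError while moving the RNN to the GPU"
    return "ok"


if __name__ == "__main__":
    print(try_rnn_on_gpu())
```

Running this in both the source-built image and an environment with the binary wheel should show whether the problem is tied to the build rather than to my model code.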

I’ve copied my Dockerfile below for reference:

FROM nvidia/cuda:9.1-cudnn7-devel-ubuntu16.04

RUN apt-get update && apt-get install -y --no-install-recommends \
		build-essential \
		cmake \
		git \
		curl \
		vim \
		ca-certificates \
		libjpeg-dev \
		libpng-dev && \
	rm -rf /var/lib/apt/lists/*

RUN curl -o ~/miniconda.sh -O  https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh  && \
	chmod +x ~/miniconda.sh && \
	~/miniconda.sh -b -p /opt/conda && \
	rm ~/miniconda.sh && \
	/opt/conda/bin/conda install numpy pyyaml scipy ipython mkl && \
	/opt/conda/bin/conda install -c soumith magma-cuda91 && \
	/opt/conda/bin/conda clean -ya
ENV PATH /opt/conda/bin:$PATH

WORKDIR /opt
RUN git clone --recursive --single-branch -b v0.3.1 https://github.com/pytorch/pytorch
WORKDIR /opt/pytorch
RUN CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" pip install -v .
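
Since the binary wheel works and the source build does not, one thing worth comparing is which CUDA and cuDNN versions each install reports. A minimal sketch of such a check (the function name build_report is only illustrative, and whether a version mismatch is the actual cause is an assumption on my part):

```python
def build_report():
    """Collect the versions this PyTorch install reports for itself."""
    try:
        import torch
    except ImportError:
        # No PyTorch in this environment, so there is nothing to report.
        return {"torch": None, "cuda": None, "cudnn": None, "cuda_available": False}

    cuda_available = bool(torch.cuda.is_available())
    return {
        "torch": str(torch.__version__),
        # CUDA version PyTorch was compiled against (None for CPU-only builds).
        "cuda": torch.version.cuda,
        # Only query cuDNN when a GPU is usable.
        "cudnn": torch.backends.cudnn.version() if cuda_available else None,
        "cuda_available": cuda_available,
    }


if __name__ == "__main__":
    for key, value in sorted(build_report().items()):
        print(key, "=", value)
```

If the two environments print different cuDNN versions, that would at least narrow down where the builds diverge.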

Can anyone help me?

Thanks!