# syntax=docker/dockerfile:1

# NOTE: Building this image requires Docker version >= 23.0.
#
# For reference:
# - https://docs.docker.com/build/dockerfile/frontend/#stable-channel

ARG BASE_IMAGE=ubuntu:22.04
ARG PYTHON_VERSION=3.11

# ---------------------------------------------------------------------------
# dev-base: OS packages and compiler toolchain shared by the conda/build stages.
# ---------------------------------------------------------------------------
FROM ${BASE_IMAGE} AS dev-base
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        build-essential \
        ca-certificates \
        ccache \
        cmake \
        curl \
        git \
        libjpeg-dev \
        libpng-dev && \
    rm -rf /var/lib/apt/lists/*
RUN /usr/sbin/update-ccache-symlinks
# Point ccache at /opt/ccache; the build stage mounts a BuildKit cache at that path.
RUN mkdir /opt/ccache && ccache --set-config=cache_dir=/opt/ccache
ENV PATH=/opt/conda/bin:$PATH

# ---------------------------------------------------------------------------
# conda: Miniforge install plus Python build requirements.
# ---------------------------------------------------------------------------
FROM dev-base AS conda
ARG PYTHON_VERSION=3.11
# Automatically set by buildx
ARG TARGETPLATFORM
# Translate Docker's TARGETPLATFORM into the Miniforge installer arch suffix;
# every platform other than linux/arm64 falls back to x86_64.
# (-o alone names the output file; a second output option such as -O would have no URL to pair with.)
RUN case ${TARGETPLATFORM} in \
         "linux/arm64")  MINICONDA_ARCH=aarch64  ;; \
         *)              MINICONDA_ARCH=x86_64   ;; \
    esac && \
    curl -fsSL -v -o ~/miniconda.sh "https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-${MINICONDA_ARCH}.sh"
COPY requirements.txt .
# Manually invoke bash on miniconda script per https://github.com/conda/conda/issues/10431
RUN chmod +x ~/miniconda.sh && \
    bash ~/miniconda.sh -b -p /opt/conda && \
    rm ~/miniconda.sh && \
    /opt/conda/bin/conda install -y python=${PYTHON_VERSION} cmake conda-build pyyaml numpy ipython && \
    /opt/conda/bin/python -mpip install -r requirements.txt && \
    /opt/conda/bin/conda clean -ya

# ---------------------------------------------------------------------------
# submodule-update: build-context source tree with all git submodules checked out.
# ---------------------------------------------------------------------------
FROM dev-base AS submodule-update
WORKDIR /opt/pytorch
COPY . .
RUN git submodule update --init --recursive

# ---------------------------------------------------------------------------
# build: compile and install PyTorch from source into /opt/conda.
# ---------------------------------------------------------------------------
FROM conda AS build
# NOTE(review): presumably space-separated NAME=VALUE pairs. They are word-split
# unquoted by the `export` below, so values containing spaces will not survive.
ARG CMAKE_VARS
WORKDIR /opt/pytorch
# /opt/conda is already inherited from the conda stage; only the source tree is copied in.
COPY --from=submodule-update /opt/pytorch /opt/pytorch
RUN make triton
# NOTE(review): `export eval ...` does not run eval; it exports a variable named
# `eval` along with each word of ${CMAKE_VARS}.
RUN --mount=type=cache,target=/opt/ccache \
    export eval ${CMAKE_VARS} && \
    TORCH_CUDA_ARCH_LIST="7.0 7.2 7.5 8.0 8.6 8.7 8.9 9.0 9.0a" TORCH_NVCC_FLAGS="-Xfatbin -compress-all" \
    CMAKE_PREFIX_PATH="$(dirname $(which conda))/../" \
    python setup.py install

# ---------------------------------------------------------------------------
# conda-installs: prebuilt torch/torchvision/torchaudio wheels into /opt/conda.
# ---------------------------------------------------------------------------
FROM conda AS conda-installs
ARG PYTHON_VERSION=3.11
ARG CUDA_PATH=cu121
# NOTE(review): CUDA_CHANNEL is not referenced anywhere in this stage.
ARG CUDA_CHANNEL=nvidia
ARG INSTALL_CHANNEL=whl/nightly
RUN /opt/conda/bin/conda update -y -n base -c defaults conda
RUN /opt/conda/bin/conda install -y python=${PYTHON_VERSION}

# Automatically set by buildx
ARG TARGETPLATFORM

# INSTALL_CHANNEL: whl - release, whl/nightly - nightly, whl/test - test channels.
# linux/arm64 installs from the CPU wheel index; every other platform uses the
# ${INSTALL_CHANNEL}/${CUDA_PATH} index.
RUN case ${TARGETPLATFORM} in \
         "linux/arm64")  pip install --extra-index-url https://download.pytorch.org/whl/cpu/ torch torchvision torchaudio ;; \
         *)              pip install --index-url https://download.pytorch.org/${INSTALL_CHANNEL}/${CUDA_PATH#.}/ torch torchvision torchaudio ;; \
    esac && \
    /opt/conda/bin/conda clean -ya
RUN /opt/conda/bin/pip install torchelastic
# Sanity check: fail the build when CUDA_VERSION is non-empty but the installed torch
# was not compiled with CUDA. CUDA_VERSION is not declared as an ARG in this stage,
# so it can only come from the environment inherited from BASE_IMAGE
# (presumably set by CUDA base images — confirm).
RUN IS_CUDA=$(python -c 'import torch ; print(torch.cuda._is_compiled())'); \
    echo "Is torch compiled with cuda: ${IS_CUDA}"; \
    if test "${IS_CUDA}" != "True" -a ! -z "${CUDA_VERSION}"; then \
        exit 1; \
    fi

# ---------------------------------------------------------------------------
# official: slim runtime image built from the prebuilt wheels.
# ---------------------------------------------------------------------------
FROM ${BASE_IMAGE} AS official
ARG PYTORCH_VERSION
ARG TRITON_VERSION
ARG TARGETPLATFORM
ARG CUDA_VERSION
LABEL com.nvidia.volumes.needed="nvidia_driver"
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
        ca-certificates \
        libjpeg-dev \
        libpng-dev \
    && rm -rf /var/lib/apt/lists/*
COPY --from=conda-installs /opt/conda /opt/conda
# When TRITON_VERSION is set (and not on arm64), also install gcc. The apt package
# lists were deleted by the previous RUN, so they must be refreshed here. Chaining
# with && makes an install failure fail the build instead of being masked by the
# trailing rm.
RUN if test -n "${TRITON_VERSION}" -a "${TARGETPLATFORM}" != "linux/arm64"; then \
        apt-get update && \
        DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends gcc && \
        rm -rf /var/lib/apt/lists/*; \
    fi
ENV PATH=/opt/conda/bin:$PATH
ENV NVIDIA_VISIBLE_DEVICES=all
ENV NVIDIA_DRIVER_CAPABILITIES=compute,utility
ENV LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64
ENV PATH=/usr/local/nvidia/bin:/usr/local/cuda/bin:$PATH
ENV PYTORCH_VERSION=${PYTORCH_VERSION}
WORKDIR /workspace

# ---------------------------------------------------------------------------
# dev: official runtime image with the from-source build's /opt/conda on top.
# ---------------------------------------------------------------------------
FROM official AS dev
# Should override the already installed version from the official-image stage
COPY --from=build /opt/conda /opt/conda