#!/bin/bash
# Source: /aosp_15_r20/external/pytorch/.ci/docker/common/install_triton.sh
#   (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
#
# Clones Triton (or Intel's XPU fork when XPU_VERSION is set), checks out the
# pinned commit and pip-installs it in editable mode for the CI Docker images.
#
# Environment variables read by this script:
#   XPU_VERSION              non-empty selects intel-xpu-backend-for-triton
#   ANACONDA_PYTHON_VERSION  conda env used is py_${ANACONDA_PYTHON_VERSION}
#   UBUNTU_VERSION           non-empty enables the apt-get steps
#   GCC_VERSION / CLANG_VERSION
#                            on Ubuntu, gcc 7 or any clang builds with g++-9
#   CONDA_CMAKE              non-empty: restore conda cmake/numpy after build
#   MAX_JOBS                 defaults to $(nproc) when unset or empty
#
# as_jenkins, pip_install and get_pinned_commit are not defined in this file;
# presumably they come from common_utils.sh, sourced below.

set -ex

source "$(dirname "${BASH_SOURCE[0]}")/common_utils.sh"

# Print the installed version of conda package $1 in env py_${ANACONDA_PYTHON_VERSION}.
# The package-name column is compared exactly. A `grep -w` would also hit names
# like msgpack-numpy when looking up numpy, and conda lists packages
# alphabetically, so the wrong line could come first. Only the first exact match
# is printed. Prints nothing if the package is not listed.
get_conda_version() {
  as_jenkins conda list -n "py_${ANACONDA_PYTHON_VERSION}" \
    | awk -v pkg="$1" '$1 == pkg && !found { print $2; found = 1 }'
}

# Force-reinstall the given conda package specs into env py_${ANACONDA_PYTHON_VERSION}.
# Each argument is forwarded as a single spec, so specs containing spaces or glob
# characters (e.g. "numpy >=1.20" or "cmake=3.22.*") reach conda unchanged.
conda_reinstall() {
  as_jenkins conda install -q -n "py_${ANACONDA_PYTHON_VERSION}" -y --force-reinstall "$@"
}

# Choose the Triton source. Upstream Triton is the default; Intel's XPU backend
# fork is used when XPU_VERSION is set. TRITON_TEXT_FILE names the pin file
# passed to get_pinned_commit.
if [[ -z "${XPU_VERSION}" ]]; then
  TRITON_REPO="https://github.com/openai/triton"
  TRITON_TEXT_FILE="triton"
else
  TRITON_REPO="https://github.com/intel/intel-xpu-backend-for-triton"
  TRITON_TEXT_FILE="triton-xpu"
fi

# Resolve the pinned commit; the logic here is copied from .ci/pytorch/common_utils.sh
TRITON_PINNED_COMMIT=$(get_pinned_commit "${TRITON_TEXT_FILE}")

# On Ubuntu images, refresh the package index and install gpg-agent.
# apt-get is used rather than apt because apt's command-line interface is not
# guaranteed stable for scripts (apt warns about this when not on a terminal).
if [[ -n "${UBUNTU_VERSION}" ]]; then
  apt-get update
  apt-get install -y gpg-agent
fi

# When cmake comes from conda, record the currently installed cmake and numpy
# versions before building Triton, so the final section of this script can
# reinstall exactly these versions.
if [[ -n "${CONDA_CMAKE}" ]]; then
  CMAKE_VERSION=$(get_conda_version cmake)
  NUMPY_VERSION=$(get_conda_version numpy)
fi

# Default MAX_JOBS to the CPU count reported by nproc unless the caller already
# set it (presumably consumed by the Triton build as its parallelism limit).
# Assign and export in separate statements so a failing nproc is not masked by
# export's own exit status (ShellCheck SC2155).
if [[ -z "${MAX_JOBS}" ]]; then
  MAX_JOBS=$(nproc)
  export MAX_JOBS
fi

# Git checkout triton: create a jenkins-owned checkout directory, clone as
# jenkins, and check out the pinned commit.
#
# Refuse to continue without a pin. Unquoted, an empty pin degenerates into a
# bare `git checkout`, which succeeds and silently builds the default branch.
if [[ -z "${TRITON_PINNED_COMMIT}" ]]; then
  printf 'error: no pinned Triton commit found for %s\n' "${TRITON_TEXT_FILE}" >&2
  exit 1
fi

mkdir /var/lib/jenkins/triton
chown -R jenkins:jenkins /var/lib/jenkins/triton
pushd /var/lib/jenkins/

as_jenkins git clone "${TRITON_REPO}" triton
cd triton
as_jenkins git checkout "${TRITON_PINNED_COMMIT}"
cd python

# TODO: remove patch setup.py once we have a proper fix for https://github.com/triton-lang/triton/issues/4527
# Redirect the LLVM build download URL in setup.py to the oaitriton blob host.
as_jenkins sed -i -e 's/https:\/\/tritonlang.blob.core.windows.net\/llvm-builds/https:\/\/oaitriton.blob.core.windows.net\/public\/llvm-builds/g' setup.py

# Install Triton in editable mode from the triton/python directory.
# On Ubuntu images, builds with gcc 7 or with clang use g++-9 instead:
# - Triton needs at least gcc-9 to build.
# - <filesystem>, which Triton needs, is not available with the clang-9 toolchain.
# Only the clang case (GCC_VERSION not 7) adds the ubuntu-toolchain-r/test PPA
# before installing g++-9.
if [[ -n "${UBUNTU_VERSION}" && ( "${GCC_VERSION}" == "7" || -n "${CLANG_VERSION}" ) ]]; then
  if [[ "${GCC_VERSION}" != "7" ]]; then
    add-apt-repository -y ppa:ubuntu-toolchain-r/test
  fi
  apt-get install -y g++-9
  CXX=g++-9 pip_install -e .
else
  pip_install -e .
fi

# Reinstall the cmake and numpy versions recorded before the Triton build.
if [[ -n "${CONDA_CMAKE}" ]]; then
  # TODO: Remove once the underlying issue is fixed. The Triton build step
  # downloads a newer cmake (3.25.2) via pip, and that cmake fails to detect
  # conda MKL. Reinstalling the cmake version from the conda install script
  # avoids this.
  #
  # numpy must be pinned back too. Otherwise conda reports an inconsistent
  # environment and tries to install the latest numpy. That numpy fails the
  # ASAN tests with the import error "Numba needs NumPy 1.20 or less".
  conda_reinstall cmake="${CMAKE_VERSION}"
  # numpy goes through pip because conda might not have the version we want.
  pip_install --force-reinstall numpy=="${NUMPY_VERSION}"
fi