Windows에서 JAX 컴파일하기

Windows에서 JAX 컴파일하기

Windows에서 JAX를 직접 컴파일해 GPU도 사용 가능하도록 설치하는 법을 다룬다.

컴파일 시간은 10700 CPU 기준 30분 정도 걸렸으니 컴파일 작업 걸어두고 밥이라도 먹고 오자.

기본적으로 아래의 공식 documentation을 참고하면 된다.

Building from source — JAX documentation
git clone https://github.com/google/jax
cd jax

C++ build tools 설치

chocolatey를 통해 C++ build tools를 설치한다.

choco install visualstudio2017-workload-vctools

CUDA, cudnn 설치

cuda(2021-12-28 기준 11.5 버전)을 설치한다. choco install cuda

cudnn을 설치한다. 구글에 cudnn 검색하고 설치하면 2021-12-28 기준 8.3.1 버전이 나온다. 해당 파일을 다운로드하고(로그인 필요) 압축 풀어서 C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.5같은 경로에 복붙.

https://developer.nvidia.com/rdp/cudnn-download

파이썬 가상환경 준비

anaconda3으로 python3.8을 준비한다. 2021-12-28 기준 컴파일하면 cp38 wheel만 나오니 참고하자.

bazel 설치

choco install bazel

msys2 설치

choco install msys2

이후 msys2에서

pacman -S patch coreutils

공식 튜토리얼에는 위 명령어를 실행하라고 나와 있긴 한데.. 근데 아래 나올 명령어를 git bash가 아니라 msys2에서 돌린다는 가정 하에 써있는 설명인가? 아무튼 잘 모르겠다.

jaxlib 컴파일, 설치

아래 명령어는 Git bash를 관리자 권한으로 열고 jax 경로에서 실행한다. 관리자 권한으로 안 열면 symlink 생성이 안 되더라..

python build/build.py --enable_cuda --cuda_version=11.5 --cudnn_version=8.3.1

pip install dist/*.whl

혹시나 해서 휠파일을 첨부해보는데, 저것만 가지고는 설치가 끝까지 잘 될지는 모르겠다.

jax 설치

pip install -e .  # installs jax

설치 확인

이렇게 jax.devices()를 입력했을 때 Gpu가 나오면 성공인 듯하다.