[조각글] Poetry로 jax, pytorch 설치

jax의 경우 공식 github를 보면 pip install --upgrade "jax[cuda11_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html로 설치하라고 되어있는데, poetry add 명령어에는 -f(--find-links)에 직접적으로 대응되는 인자가 없다. 대신 비슷한 역할을 하는 --source 인자는 있다.

jax 설치 과정에 대한 부연설명을 하자면, https://storage.googleapis.com/jax-releases/jax_cuda_releases.html와 같은 링크(물론, cpu/tpu 버전은 링크가 다르다)에서 jaxlib를 다운받고, jax를 pypi에서 다운받는 원리이다. 그래서 단순히 source add explicit으로 링크를 추가하고 add --source를 사용하는 것만으로는 한번에 jax와 jaxlib를 모두 설치할 수 없다. 그래서 jaxlib가 있는 url을 supplemental priority로 놓고, jax를 설치해야 한다.

아래 명령어처럼 입력하면 된다. pytorch의 경우 poetry 버전과 필요한 cpu/cuda 여부에 따라 바꿔서 실행하면 된다.

poetry source add -p supplemental jaxlib https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
poetry add "jax[cuda11_pip]"
poetry source add pypi
# 마지막 줄이 필요한 이유: Warning: In a future version of Poetry, PyPI will be disabled automatically if at least one custom source is configured with another priority than 'explicit'. In order to avoid a breaking change and make your pyproject.toml forward compatible, add PyPI explicitly via 'poetry source add pypi'. By the way, this has the advantage that you can set the priority of PyPI as with any other source.

# poetry<1.5에서는 secondary, 1.5부터는 explicit 사용
# pytorch url의 경우, cuda/cpu 중 필요한 것으로 변경
# cpu url: https://download.pytorch.org/whl/cpu
poetry source add -p explicit pytorch https://download.pytorch.org/whl/cu118
poetry add --source pytorch torch torchvision torchaudio