windows安装jax和jaxlib的教程(cuda)成功安装
本文你将解决3个问题:1、
jaxlib
没有安装的问题;2、python3.9
以上(不可忽略)、cuda12.1
(可忽略)以上配置要求不满足的问题;3、numpy版本太高的问题。
1、问题描述
当你直接pip install jax
或者conda install jax
后,执行以下代码检查是否错误:
import jax
print(jax.devices()) # 应输出类似 [gpu(id=0)]
总是会报错:ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.
出现该问题的原因是没有安装jaxlib。jaxlib只支持python3.9以上版本,且需要手动安装(直接用
pip install jaxlib
会报错)
ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
ERROR: No matching distribution found for jaxlib
2、解决办法
下面有2种情况,按照你的Windows电脑是否需要cuda来选择对应的教程。
情况1
,你不需要GPU加速,即不用显卡cuda,那么只需要执行以下2步:
1、在虚拟环境中,在python3.9及以上的版本安装jax库,如 pip install jax
或者conda install jax
,可以指定版本,这些就和一般的安装库那样。
2、下载jaxlib
的文件,并手动安装。在https://storage.googleapis.com/jax-releases/jax_releases.html 地址中,键盘快捷键"ctrl + F"
搜索"win"
找到对应python版本的jaxlib文件,jaxlib的版本自行测试吧。将其下载在本地任意文件夹中,然后像一般安装那样,在你的虚拟环境中安装此文件。
情况2
,你需要GPU加速,并且有自己的显卡cuda,而且已经配置了一个cuda11(或者以下的版本;如果你是cuda12及以上的版本,同样按照下面第2个步骤执行),那么只需要执行以下2步:
1、先安装cuda12(12.1以上的版本,必要的操作,不能跳过;无需卸载之前的cuda版本,多个版本的cuda可以共存),具体教程见以下两个教程(如果链接失效,请到我的csdn主页查找同名教程):
a. cuda 安装两个版本 https://blog.csdn.net/AdamCY888/article/details/147516608
b. 驱动支持的最高CUDA版本与实际安装的Runtime版本 https://blog.csdn.net/AdamCY888/article/details/147516543
(截图来自jax教程:https://jax.net.cn/en/latest/installation.html#installation)
2、上面步骤1确保你已经有一个12.1以上版本的cuda。
a. 下载jax:pip install -U "jax[cuda12]"
, 注意,引号不能省略,且建议不指定其jax
版本。
b. 接下来同前面情况1的步骤2一样,下载jaxlib
的whl
文件。自行对应相应的版本。
3、测试jax对应jaxlib的版本
由于并没有找到jax对应jaxlib的版本,于是就安装一个最低版本的jaxlib 0.4.13,按照其报错提示,来得到满足的版本。正确的对应关系是:jax 0.4.21
对应的 jaxlib 0.4.19
;如果安装的其它版本,也可以通过这个方法来解决。
RuntimeError: jaxlib is version 0.4.13, but this version of jax requires version >= 0.4.19.
于是,重新在 https://storage.googleapis.com/jax-releases/jax_releases.html 下载"jaxlib 0.4.19"
,并安装。
接下来进一步测试以下程序:
import jax.numpy as jnp
def selu(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)x = jnp.arange(5.0)
print(selu(x))
报错:
A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.5 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.Traceback (most recent call last): File "d:\Anaconda\envs\jax_cuda12\lib\runpy.py", line 196, in _run_module_as_mainreturn _run_code(code, main_globals, None,File "d:\Anaconda\envs\jax_cuda12\lib\runpy.py", line 86, in _run_codeexec(code, run_globals)...
报错的原因是NumPy版本太高,需要降低版本。执行以下代码即可解决:
# 在虚拟环境中执行
conda activate jax_cuda12
pip uninstall numpy -y
pip install numpy==1.24.4 # 选择广泛兼容的1.x版本
4、安装成功!
import jax
print(jax.devices()) # 应输出类似 [gpu(id=0)]import jax.numpy as jnp
那么,接下来,请享受你的加速计算吧。
import jax.numpy as jnp
def selu(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)x = jnp.arange(5.0)
print(selu(x))
联系我
如果你在Windows系统下安装jax
过程中,有任何困难,请留言或者私信,我将定期回复。
- jax备忘录 https://blog.csdn.net/AdamCY888/article/details/147402803