Timm 加载本地 huggingface 模型
最近使用 Timm 自动加载在线 hf-hub 模型时,由于服务器存在网络限制 huggingface 无法正常连接,导致无法加载模型以及权重。解决办法就是本地电脑下载,再上传到服务器。
以下载 huggingface.co/MahmoodLab/UNI 为例。
import timm
from timm.data import resolve_data_config
from timm.data.transforms_factory import create_transform
from huggingface_hub import loginlogin() # login with your User Access Token, found at https://huggingface.co/settings/tokens# pretrained=True needed to load UNI weights (and download weights for the first time)
# init_values need to be passed in to successfully load LayerScale parameters (e.g. - block.0.ls1.gamma)
model = timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, init_values=1e-5, dynamic_img_size=True)
transform = create_transform(**resolve_data_config(model.pretrained_cfg, model=model))
model.eval()
本地下载
# 在可联网的机器运行,确保模型缓存
from huggingface_hub import snapshot_download# 指定存储路径
download_path = "D:/Research/pre_training_models"models = ["MahmoodLab/uni"
]
for repo in models:snapshot_download(repo_id=repo),
cache_dir=download_path
下载时可以存在一些获取 hf token 和对应模型库的邮箱认证等问题,可以自行 AI 获取解决步骤。
上传到服务器指定缓存目录
将文件复制到你的服务器 [用户名]/.cache/huggingface/hub
中,尝试复制到其他的路径发现 timm.create_model
无法正确识别,尽量还是放在该目录下。
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def load_model(model_name, device):if model_name == 'UNI':model = timm.create_model("hf-hub:MahmoodLab/uni", pretrained=True, init_values=1e-5, dynamic_img_size=True) # PMID:38504018else:raise NotImplementedError(f'Model {model_name} not implemented !')return model.to(device).eval()uni_model = load_model('UNI', device)
uni_transform = create_transform(**resolve_data_config(uni_model.pretrained_cfg, model=uni_model))