什么是 PyTorch Hub?
PyTorch Hub 是一个旨在促进研究可复现性和发现预训练模型的中央平台。它提供了一个简单统一的 API,用于探索和使用由社区和研究人员发布的各种模型。其主要目标是让开发人员和研究人员能够轻松加载和实验最先进的模型,而无需深入研究每个模型的实现细节。
主要特点
- 简单的 API: 使用
torch.hub.load()函数,通过一个命令即可从 hub 加载任何模型。 - 模型发现: 一个精选的模型集合,涵盖计算机视觉、自然语言处理等多种任务。
- 可复现性: hub 上的模型与其依赖项和预训练权重一起发布,确保结果可以轻松复现。
- 无缝集成: 直接在 PyTorch 生态系统内工作,方便对模型进行修改、微调和部署。
- 发布机制: 研究人员可以轻松地将自己的模型发布到 hub,使其工作能够被更广泛的受众所接触。
应用场景
- 迁移学习: 快速加载像 ResNet 或 BERT 这样的预训练模型,并在自定义数据集上针对特定任务进行微调。
- 快速原型开发: 利用强大的现成模型测试想法并构建概念验证应用。
- 基准测试: 在标准化任务上轻松比较不同模型的性能。
- 教育工具: 一个很好的资源,用于学习最先进的模型是如何构建和在实践中使用的。
入门指南
开始使用 PyTorch Hub 非常简单。您只需要安装 PyTorch。这是一个加载用于图像分类的预训练 ResNet18 模型的“Hello World”示例。
```python import torch
从 PyTorch Hub 加载预训练的 ResNet18 模型
model = torch.hub.load(‘pytorch/vision:v0.10.0’, ‘resnet18’, pretrained=True)
将模型设置为评估模式
model.eval()
创建一个虚拟输入张量(例如,用于一个 224x224 的图像)
在实际应用中,您会在这里加载并预处理您的图像
dummy_input = torch.randn(1, 3, 224, 224)
获取模型的预测
with torch.no_grad(): output = model(dummy_input)
输出包含每个类别的原始分数
print(output.shape)
预期输出: torch.Size([1, 1000])
定价
PyTorch Hub 是开源 PyTorch 框架的一个组成部分。在 PyTorch 的宽松 BSD 风格许可下,它完全免费,可用于学术和商业目的。