PyTorchModelHubMixin: Bridging the Gap for Custom AI Models on Hugging Face
Introduction
Hugging Face offers a number of useful features, including free hosting of AI models, version control, and monthly download metrics.
For people working with their own custom AI model architectures, accessing these features can pose a significant hurdle. This is where PyTorchModelHubMixin comes into play, seamlessly enabling the saving, loading, and pushing of these models to the hub.
In this blogpost we will look at how to leverage PyTorchModelHubMixin to build your own Hugging Face-integrated library, using pxia as a reference.
PyTorchModelHubMixin: Key Features and Capabilities
While most well-known AI models are already part of the transformers library, a number of newly emerging models use the mixin class, including lerobot, ibm/biomed.sm.mv-te-84m, Salesforce/moirai-1.1-R-large, nvidia/domain-classifier, ZhengPeng7/BiRefNet, and more.
Although the Transformers library offers an extensive suite of features, PyTorchModelHubMixin strikes a balance, offering a more curated set of functionalities in exchange for greater flexibility. For a detailed comparison, we've compiled a comprehensive table below, using pxia as a reference for what you can accomplish with PyTorchModelHubMixin.
- ✅ part of the library
- 💡 not part of the library, but the feature can be added
- ❌ cannot be added to the library
| feature | PyTorchModelHubMixin | pxia | transformers | comments |
|---|---|---|---|---|
| push_to_hub, save_pretrained, from_pretrained | ✅ | ✅ | ✅ | Both libraries have the same methods and parameters. |
| version control and experiment tracking | ✅ | ✅ | ✅ | If you are a researcher, this might be the most interesting feature for you, since you can push your model to the hub and load it using `model.from_pretrained("repo_id", revision="branch_pr_or_sha")` (see the sketch after this table). |
| automatic model sharding | ✅ | ✅ | ✅ | Hugging Face has an upload limit of 50 GB per file, so sharding is a crucial step in uploading your model weights to the hub; this is a predefined feature in PyTorchModelHubMixin. |
| monthly download metrics | ✅ | ✅ | ✅ | |
| inference API support | ❌ | ❌ | ✅ | For security reasons, and since there is no unified way to run every model, the inference API is not enabled for custom models in general. |
| custom input type | ✅ | ✅ | ❌ | transformers only allows input in a dataclass format, while the mixin allows any input type; the user needs to add the logic for serializing and deserializing the data in the coders parameter. |
| trainer API support | 💡 | ✅ | ✅ | The user needs to make some slight changes to the forward method first. |
| code snippet | 💡 | ✅ | ✅ | For the mixin, you can add your code snippet to the hub by making a PR to huggingface.js to add a button in the top right corner that shows people how to load your model (see the pxia example here). |
| automodel support | 💡 | ✅ | ✅ | If you have defined multiple AI models in the same repository and want automodel support, you can make use of the tags you added in the README file to figure out which architecture you're using (see example). This is the easiest way I came up with, since the transformers library stores the model architecture in a config.json file, which is not easy to handle. |
| pip integration | 💡 | ✅ | ✅ | This can be done by setting up a pip-compatible template; you can refer to the following minimalistic repo or pxia's implementation. |
| peft, quantization, etc. | 💡 | 💡 | ✅ | In a sense the mixin allows users to add whatever features they want; the sky is the limit. As long as the init method respects certain conditions defined below, you can do whatever you want with the rest. |
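As a quick illustration of the version-control row, here is a minimal sketch of loading a specific revision; the repo id and branch name are hypothetical, and ANN is the example model class defined later in this post:

```python
# any class that inherits from PyTorchModelHubMixin exposes from_pretrained with a revision argument
model = ANN.from_pretrained("username/my-ann", revision="experiment-1")  # a branch, PR ref, or commit sha
```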
Usage
Using PyTorchModelHubMixin is very simple: all you have to do is add it to your model's class inheritance. Consider the following minimalistic example for reference:
```diff
  from torch import nn
+ from huggingface_hub import PyTorchModelHubMixin

  class ANN(nn.Module,
+           PyTorchModelHubMixin):
      def __init__(self, a, b):
          super().__init__()
          self.layer = nn.Linear(a, b, bias=False)

      def forward(self, inputs):
          return self.layer(inputs)
```
That is all you need to do :)
If you want to include extra metadata that will be added to the README file, you can pass it in the inheritance, as described in the documentation:
```python
class MyModel(
    nn.Module,
    PyTorchModelHubMixin,
    library_name="keras-nlp",
    repo_url="https://github.com/keras-team/keras-nlp",
    docs_url="https://keras.io/keras_nlp/",
    tags=["demo", "architecture", "tensorflow"],
    # ^ optional metadata to generate the model card
):
    ...
```
Pushing and loading the model
By inheriting from PyTorchModelHubMixin, your model gains three convenient methods: `save_pretrained`, `from_pretrained`, and `push_to_hub`.
These methods function similarly to those found in the Transformers library, enabling seamless model sharing and loading. Let's dive into a practical example to illustrate how to use the mixin in a real-world scenario.
1- Define your model architecture

```python
from torch import nn
from huggingface_hub import PyTorchModelHubMixin


class ANN(nn.Module, PyTorchModelHubMixin):
    def __init__(self, a, b):
        super().__init__()
        self.layer = nn.Linear(a, b, bias=False)

    def forward(self, inputs):
        return self.layer(inputs)
```
2- Initialize and push your model to the hub. You can also use `save_pretrained` to save the model locally (a short sketch follows the snippet below).

```python
model = ANN(1, 2)
model.push_to_hub("repo_id", token=TOKEN)
```
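The local-saving counterpart works the same way; here is a minimal sketch, where the folder name is a hypothetical example:

```python
model = ANN(1, 2)
model.save_pretrained("local_folder")  # writes the weights and a config.json into local_folder/
# the same model can still be pushed later with model.push_to_hub("repo_id", token=TOKEN)
```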
3- Load the model. You can load the model using the model class directly; there is no need to reinitialize your model.

```python
# no init parameters, no manual initialization, no trying to remember the __init__ parameters
HF_model = ANN.from_pretrained("repo_id_or_path")  # you can also pass a revision parameter here to load a specific sha, branch, or PR
```
`from_pretrained` also works if the model is saved locally.

You also might have noticed that in step (3) we are not passing any init parameters. This is because all of our parameters were captured in step (2) and saved in a config.json file. In short, you only need to define your init parameters once.
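For the `ANN(1, 2)` example, that config.json simply records the captured init parameters. A quick way to check, assuming the model was saved to the hypothetical local_folder used in the sketch above:

```python
import json

# the mixin stores the __init__ arguments as plain JSON next to the weights
with open("local_folder/config.json") as f:
    print(json.load(f))  # expected to contain {"a": 1, "b": 2}
```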
Edge cases 🛠️
In this section, we'll outline essential best practices and considerations for developing a Hugging Face integrated AI model.
To ensure your model functions as intended and integrates smoothly, keep the following two key recommendations in mind:
- Avoid loading any local files in the `__init__` method (when someone calls `from_pretrained`, `__init__` runs on their machine, where that local file may not exist); load the weights after initialization instead, as shown below:
```diff
  import torch
  from torch import nn
  from huggingface_hub import PyTorchModelHubMixin

  weights = "folder/weights.pt"

  class ANN(nn.Module, PyTorchModelHubMixin):
      def __init__(self, a, b):
          super().__init__()
          self.layer = nn.Linear(a, b, bias=False)
-         self.load_state_dict(torch.load(weights))

      def forward(self, inputs):
          return self.layer(inputs)

  model = ANN(1, 2)
+ model.load_state_dict(torch.load(weights))
```
- Define a `coders` parameter if you're working with an unserializable input type

The list of serializable `__init__` parameter types can be found here. If your input parameters are not supported by the mixin, you can add serialization and deserialization logic in the `coders` parameter (this defines how these parameters are stored in the config.json file):
```diff
  import torch.nn as nn
  from huggingface_hub import PyTorchModelHubMixin
  from omegaconf import OmegaConf
  from omegaconf.dictconfig import DictConfig

  default_conf = OmegaConf.create({"a": 2, "b": 1})

+ # converts the parameter to a python dict so it can be stored in config.json
+ def serialize(x):
+     return OmegaConf.to_container(x)

+ # when loading the model, converts the init parameter back to its original type
+ def deserialize(x):
+     return OmegaConf.create(x)

  class ANN(
      nn.Module,
      PyTorchModelHubMixin,
+     coders={
+         DictConfig: (
+             lambda x: serialize(x),
+             lambda data: deserialize(data),
+         )
+     },
  ):
      # annotate the parameter's type so the mixin can figure out
      # which serialization logic goes with which parameter
      def __init__(self,
-                  cfg=default_conf):
+                  cfg: DictConfig = default_conf):
          super().__init__()
          self.layer = nn.Linear(cfg.a, cfg.b, bias=False)

      def forward(self, inputs):
          return self.layer(inputs)
```
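To close the loop on this example, here is a hedged usage sketch of the round trip: the `DictConfig` is serialized into config.json through `serialize` when saving and rebuilt through `deserialize` when loading (the folder name is hypothetical):

```python
model = ANN(OmegaConf.create({"a": 3, "b": 1}))
model.save_pretrained("ann-omegaconf")           # cfg is stored in config.json as a plain dict
reloaded = ANN.from_pretrained("ann-omegaconf")  # cfg is rebuilt as a DictConfig on load
```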
Future work
For a comprehensive overview of the features mentioned in this blogpost, refer to pxia. Pxia is an AI library designed as a template to help users develop their own projects through real-world examples.
If you'd like to contribute to features such as PEFT, quantization, or any other functionality you think would be valuable to users, please submit a pull request or open an issue in the GitHub repository.
If you found this blogpost helpful, consider upvoting ❤️ヾ(≧▽≦*)o