Using Hugging Face Transformer models without downloading the pretrained weights

Tags: nlp, code

Author: Hamish Dickson

Published: March 17, 2024

Imagine you have been given the task of classifying product descriptions into categories. You have the data and compute you need and, because you are awesome at your job, you do this in record time.

The Simple Case

You will probably be using Hugging Face’s Transformers library for this and you will probably start off with a simple model which looks like the following:

import transformers

model_name = "microsoft/deberta-v3-xsmall"
where_my_weights_live = "my_save_location"

# create your tokenizer and model from the pretrained weights
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

# let's pretend there are 10 labels
model = transformers.AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=10)

# ... train your model ...

# finally save everything - this is all you need to do
tokenizer.save_pretrained(where_my_weights_live)
model.save_pretrained(where_my_weights_live)

Now it’s time to put this in production, and the engineering team asks for your help to make sure this is done properly. What does this look like?

Well, this is quite easy. You can basically just do this:

import transformers

where_my_weights_live = "my_save_location"

tokenizer = transformers.AutoTokenizer.from_pretrained(where_my_weights_live)
model = transformers.AutoModelForSequenceClassification.from_pretrained(where_my_weights_live)

It’s easy to read, efficient and everyone is happy. Good job!

A More Realistic Example

In reality, while the default models are very good, you are likely to want to modify them. For example, imagine we want to feed some product features into our model alongside the text. We would have to declare our own model in this case.

It could look a little like this:

import torch
import transformers

class ProductClassifier(torch.nn.Module):
    def __init__(self, model_name, num_features, num_labels):
        super(ProductClassifier, self).__init__()
        
        self.model = transformers.AutoModel.from_pretrained(model_name)

        self.fc_features = torch.nn.Linear(num_features, num_features)

        model_hidden_size = self.model.config.hidden_size
        self.fc_out = torch.nn.Linear(model_hidden_size + num_features, num_labels)

    def forward(self, inputs, features):
        model_outputs = self.model(**inputs)
        model_outputs = model_outputs.last_hidden_state[:, 0, :]

        feature_outputs = self.fc_features(features)

        outputs = torch.cat([model_outputs, feature_outputs], dim=1)

        return self.fc_out(outputs)

You would have to save this a little differently from the default models. It would look a bit like this:

# declare our model and train it

where_my_weights_live = "my_save_location"

tokenizer.save_pretrained(where_my_weights_live)

torch.save(model.state_dict(), f"{where_my_weights_live}/pytorch_model.bin")

You can no longer just call model.save_pretrained, because our class is a plain torch.nn.Module. Instead we have to save the state_dict (i.e. the model weights) ourselves.
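To make "the model weights" a bit more concrete, here is a minimal sketch of what a state_dict holds. TinyClassifier is a hypothetical stand-in rather than the ProductClassifier above, and its sizes are made up so the example runs instantly.

```python
import os
import tempfile

import torch


class TinyClassifier(torch.nn.Module):
    """Hypothetical stand-in for a custom model, kept tiny on purpose."""

    def __init__(self, num_features, num_labels):
        super().__init__()
        self.fc_out = torch.nn.Linear(num_features, num_labels)

    def forward(self, features):
        return self.fc_out(features)


model = TinyClassifier(num_features=4, num_labels=3)

# Parameter names mapped to tensor shapes; no Python code is stored.
shapes = {name: tuple(tensor.shape) for name, tensor in model.state_dict().items()}

with tempfile.TemporaryDirectory() as save_dir:
    weights_path = os.path.join(save_dir, "pytorch_model.bin")
    torch.save(model.state_dict(), weights_path)
    # What comes back is a plain mapping of tensors, not a TinyClassifier.
    loaded = torch.load(weights_path, map_location="cpu")

print(shapes)
```

The file only contains named tensors, so whoever loads it needs the class definition as well as the weights.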

Now, how do you use this?

Well, what’s the most obvious way to do this? It’s probably something like this:

model_name = ...   # same as what we trained
num_features = ... # same as what we trained
num_labels = ...   # same as what we trained

# pytorch doesn't save the model definition, so we have to declare it
class ProductClassifier(torch.nn.Module):
    def __init__(self, model_name, num_features, num_labels):
        super(ProductClassifier, self).__init__()
        
        self.model = transformers.AutoModel.from_pretrained(model_name)

        self.fc_features = torch.nn.Linear(num_features, num_features)

        model_hidden_size = self.model.config.hidden_size
        self.fc_out = torch.nn.Linear(model_hidden_size + num_features, num_labels)

    def forward(self, inputs, features):
        model_outputs = self.model(**inputs)
        model_outputs = model_outputs.last_hidden_state[:, 0, :]

        feature_outputs = self.fc_features(features)

        outputs = torch.cat([model_outputs, feature_outputs], dim=1)

        return self.fc_out(outputs)

# now we can load the weights
model = ProductClassifier(model_name, num_features, num_labels)
model.load_state_dict(torch.load(f"{where_my_weights_live}/pytorch_model.bin", map_location=torch.device('cpu')))
model.eval() # don't forget this!

This will probably work, but something subtle happens: constructing the model downloads the original pretrained weights from Hugging Face (unless they are already in the local cache)!

self.model = transformers.AutoModel.from_pretrained(model_name)

This line in our __init__ is the problem. It fetches the pretrained weights from Hugging Face, and we then immediately overwrite them with our own weights in load_state_dict.

This isn’t going to break your code, but it makes initialisation very slow, and you are downloading weights that you will never actually use.
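To see why those initial weights are wasted, here is a small sketch using plain torch layers as stand-ins. The "fresh" layer plays the role of a model built with from_pretrained; that framing is only an analogy, and nothing here talks to Hugging Face.

```python
import torch

torch.manual_seed(0)

# Stand-in for the trained model whose weights we saved.
trained = torch.nn.Linear(4, 3)
saved_state = trained.state_dict()

# Stand-in for a freshly constructed model. Imagine its initial values
# came from an expensive download.
fresh = torch.nn.Linear(4, 3)
initial_weight = fresh.weight.detach().clone()

# Loading the saved state replaces every parameter in the fresh model.
fresh.load_state_dict(saved_state)

initial_values_survive = torch.equal(fresh.weight, initial_weight)
matches_trained = torch.equal(fresh.weight, trained.weight)
```

After load_state_dict, nothing of the initial values is left, so whatever the constructor fetched was thrown away.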

Instead, can we initialise the model without downloading these weights?

Using from_config

We can, like this:

model_name = ...   # same as what we trained
num_features = ... # same as what we trained
num_labels = ...   # same as what we trained

# pytorch doesn't save the model definition, so we have to declare it
class ProductClassifier(torch.nn.Module):
    def __init__(self, model_name, num_features, num_labels):
        super(ProductClassifier, self).__init__()
        self.model_config = transformers.AutoConfig.from_pretrained(model_name)
        self.model = transformers.AutoModel.from_config(config=self.model_config)

        self.fc_features = torch.nn.Linear(num_features, num_features)

        model_hidden_size = self.model.config.hidden_size
        self.fc_out = torch.nn.Linear(model_hidden_size + num_features, num_labels)

    def forward(self, inputs, features):
        model_outputs = self.model(**inputs)
        model_outputs = model_outputs.last_hidden_state[:, 0, :]

        feature_outputs = self.fc_features(features)

        outputs = torch.cat([model_outputs, feature_outputs], dim=1)

        return self.fc_out(outputs)

# now we can load the weights
model = ProductClassifier(model_name, num_features, num_labels)
model.load_state_dict(torch.load(f"{where_my_weights_live}/pytorch_model.bin", map_location=torch.device('cpu')))
model.eval() # don't forget this!

Here, instead of AutoModel.from_pretrained, we use AutoModel.from_config. This builds the model architecture from the config with randomly initialised weights and does not download the pretrained weights, which is much better.
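As a rough illustration of what from_config gives you, the sketch below builds a model from a config created in code. The tiny BertConfig and its sizes are assumptions chosen so the example runs quickly without any network access; the post itself uses a DeBERTa checkpoint.

```python
import torch
import transformers

# Hypothetical tiny config; the sizes are made up for illustration.
tiny_config = transformers.BertConfig(
    vocab_size=100,
    hidden_size=32,
    num_hidden_layers=1,
    num_attention_heads=2,
    intermediate_size=64,
    max_position_embeddings=64,
)

# Two models from the same config: same architecture, independent random weights.
model_a = transformers.AutoModel.from_config(tiny_config)
model_b = transformers.AutoModel.from_config(tiny_config)

same_weights = torch.equal(
    model_a.embeddings.word_embeddings.weight,
    model_b.embeddings.word_embeddings.weight,
)

# The model is fully usable; it just has not been trained.
model_a.eval()
with torch.no_grad():
    hidden = model_a(input_ids=torch.randint(1, 100, (2, 8))).last_hidden_state
```

Because the weights start out random, a model built this way only becomes useful once load_state_dict has filled in the trained values.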

But in this case we still go to the internet to download the config. It would be better to avoid the network entirely.

Ideally, we would change our training setup so it looks more like this:

import torch
import transformers

model_name = ...
num_features = ...
num_labels = ...

model_config = transformers.AutoConfig.from_pretrained(model_name)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)

class ProductClassifier(torch.nn.Module):
    def __init__(self, model_config, num_features, num_labels):
        super(ProductClassifier, self).__init__()
        self.model = transformers.AutoModel.from_config(config=model_config)

        self.fc_features = torch.nn.Linear(num_features, num_features)

        model_hidden_size = self.model.config.hidden_size
        self.fc_out = torch.nn.Linear(model_hidden_size + num_features, num_labels)

    def forward(self, inputs, features):
        model_outputs = self.model(**inputs)
        model_outputs = model_outputs.last_hidden_state[:, 0, :]

        feature_outputs = self.fc_features(features)

        outputs = torch.cat([model_outputs, feature_outputs], dim=1)

        return self.fc_out(outputs)


model = ProductClassifier(model_config, num_features, num_labels)

# train the model

# now you have to save one more thing, the model config
model_config.save_pretrained(where_my_weights_live)
tokenizer.save_pretrained(where_my_weights_live)
torch.save(model.state_dict(), f"{where_my_weights_live}/pytorch_model.bin")

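Finally, our production code is just this!

import torch
import transformers

# ProductClassifier must be declared (or imported) here, exactly as in training

where_my_weights_live = ...
num_features = ... # same as what we trained
num_labels = ...   # same as what we trained

tokenizer = transformers.AutoTokenizer.from_pretrained(where_my_weights_live)
model_config = transformers.AutoConfig.from_pretrained(where_my_weights_live)

model = ProductClassifier(model_config, num_features, num_labels)
model.load_state_dict(torch.load(f"{where_my_weights_live}/pytorch_model.bin", map_location=torch.device('cpu')))
model.eval() # don't forget this!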

And now we can use our model without downloading anything!
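If you want to convince yourself the whole round trip works offline, something like the following sketch can help. It assumes a hypothetical tiny BertConfig instead of the real checkpoint, skips the tokenizer (which would need a download), and passes local_files_only=True so loading the config fails rather than silently reaching for the network.

```python
import os
import tempfile

import torch
import transformers


class ProductClassifier(torch.nn.Module):
    def __init__(self, model_config, num_features, num_labels):
        super().__init__()
        # Architecture only; the weights are randomly initialised.
        self.model = transformers.AutoModel.from_config(model_config)
        self.fc_features = torch.nn.Linear(num_features, num_features)
        self.fc_out = torch.nn.Linear(model_config.hidden_size + num_features, num_labels)

    def forward(self, inputs, features):
        cls_embedding = self.model(**inputs).last_hidden_state[:, 0, :]
        feature_outputs = self.fc_features(features)
        return self.fc_out(torch.cat([cls_embedding, feature_outputs], dim=1))


# Hypothetical tiny config and sizes, so nothing is downloaded.
tiny_config = transformers.BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=64, max_position_embeddings=64,
)
num_features, num_labels = 3, 5

inputs = {
    "input_ids": torch.randint(1, 100, (2, 8)),
    "attention_mask": torch.ones(2, 8, dtype=torch.long),
}
features = torch.randn(2, num_features)

# Pretend this model has been trained.
trained = ProductClassifier(tiny_config, num_features, num_labels).eval()

with tempfile.TemporaryDirectory() as save_dir:
    weights_path = os.path.join(save_dir, "pytorch_model.bin")

    # Training side: save the config next to the weights.
    tiny_config.save_pretrained(save_dir)
    torch.save(trained.state_dict(), weights_path)

    # Production side: everything is read from the local directory.
    loaded_config = transformers.AutoConfig.from_pretrained(save_dir, local_files_only=True)
    restored = ProductClassifier(loaded_config, num_features, num_labels)
    restored.load_state_dict(torch.load(weights_path, map_location="cpu"))
    restored.eval()

with torch.no_grad():
    same_predictions = torch.allclose(trained(inputs, features), restored(inputs, features))
```

As an extra guard in production, setting the HF_HUB_OFFLINE=1 environment variable before starting the process should make any accidental download attempt fail loudly instead of slowing down start-up.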