Spaces:
Runtime error
Runtime error
File size: 1,609 Bytes
2ce40b4 ccfaaf5 2ce40b4 09d9f2a 2ce40b4 ccfaaf5 2ce40b4 ccfaaf5 2ce40b4 1fea6dc 8efe266 1fea6dc 8efe266 1fea6dc c1e5327 1fea6dc 8efe266 1fea6dc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 |
import logging

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoConfig, AutoModel
class CustomModel(PreTrainedModel):
    """Classifier built from a Hugging Face ``AutoModel`` encoder plus a linear head.

    The encoder's hidden state at sequence position 0 (e.g. the CLS token for
    BERT-style models) is projected to ``config.num_labels`` logits.
    """

    # AutoConfig picks the concrete config class for whatever checkpoint is
    # loaded, so this wrapper is not tied to one encoder architecture.
    config_class = AutoConfig

    def __init__(self, config):
        """Build the encoder and the classification head from ``config``.

        Args:
            config: A transformers configuration object. It must expose
                ``hidden_size`` (encoder output width) and ``num_labels``
                (number of output logits).
        """
        super().__init__(config)
        # from_config builds the architecture only; no pretrained weights are loaded.
        self.encoder = AutoModel.from_config(config)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None, **kwargs):
        """Return classification logits for a batch of token ids.

        Args:
            input_ids: Token ids, presumably shaped (batch, seq_len).
            attention_mask: Optional attention mask forwarded to the encoder.
            **kwargs: Any further encoder inputs (for example ``token_type_ids``
                produced by BERT tokenizers), passed through unchanged so that
                ``model(**tokenizer(...))`` works.

        Returns:
            The classifier output, presumably shaped (batch, num_labels).
        """
        outputs = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask, **kwargs
        )
        # Take the hidden state at position 0 of every sequence; the indexing
        # implies a rank-3 tensor, presumably (batch, seq_len, hidden_size).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        return self.classifier(pooled_output)

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Create a model from the configuration stored at a path or hub id.

        NOTE(review): only the configuration is loaded. No checkpoint weights
        are restored, so both the encoder and the classifier start from random
        initialisation. Load a state dict, or the encoder via
        ``AutoModel.from_pretrained``, if pretrained weights are required.

        Args:
            pretrained_model_name_or_path: Hub model id or local directory.
            *args: Forwarded to ``AutoConfig.from_pretrained``.
            **kwargs: Forwarded to ``AutoConfig.from_pretrained``.

        Returns:
            A new ``CustomModel``, or ``None`` if loading the configuration or
            building the model raised. The failure is logged together with its
            traceback.
        """
        try:
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )
            return cls(config)
        except Exception:
            # Best-effort contract kept for existing callers: return None
            # instead of raising. logging.exception records the full
            # traceback, which the previous print() discarded.
            logging.getLogger(__name__).exception(
                "Failed to load model from %s.", pretrained_model_name_or_path
            )
            return None