File size: 1,609 Bytes
2ce40b4
 
ccfaaf5
2ce40b4
 
09d9f2a
2ce40b4
 
 
 
ccfaaf5
2ce40b4
 
ccfaaf5
 
 
 
 
 
 
 
 
2ce40b4
 
1fea6dc
8efe266
1fea6dc
8efe266
1fea6dc
c1e5327
 
1fea6dc
8efe266
 
1fea6dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import logging
import os

import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoConfig, AutoModel

class CustomModel(PreTrainedModel):
    """A classification model: a Hugging Face encoder plus a linear head.

    The encoder architecture is resolved dynamically from the checkpoint's
    configuration via ``AutoConfig``/``AutoModel``, so the same class works
    with any backbone that exposes ``last_hidden_state``.
    """

    # Resolved dynamically so any checkpoint's config class can be used.
    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        # Backbone built from the configuration alone — weights are random
        # here; pretrained weights are loaded in from_pretrained() below.
        self.encoder = AutoModel.from_config(config)
        # Linear classification head over the pooled representation.
        # Assumes config defines hidden_size and num_labels — TODO confirm
        # for non-standard configs.
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, input_ids, attention_mask=None):
        """Return classification logits of shape (batch, num_labels).

        Args:
            input_ids: token id tensor of shape (batch, seq_len).
            attention_mask: optional padding mask of the same shape.
        """
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state at position 0 (the [CLS] token for BERT-style
        # encoders) as the pooled sequence representation.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        logits = self.classifier(pooled_output)
        return logits

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
        """Build the model from a checkpoint directory or hub identifier.

        Returns the model on success, or ``None`` on any failure (the
        original contract, preserved for existing callers).
        """
        try:
            # Load the configuration that describes the backbone.
            config = cls.config_class.from_pretrained(
                pretrained_model_name_or_path, *args, **kwargs
            )
            model = cls(config)
            # Bug fix: the original never loaded saved weights, so callers
            # silently got a randomly initialized model. Load the state dict
            # when a local checkpoint file is present; hub downloads are not
            # handled here — TODO confirm whether that path is needed.
            weights_path = os.path.join(
                str(pretrained_model_name_or_path), "pytorch_model.bin"
            )
            if os.path.isfile(weights_path):
                state_dict = torch.load(weights_path, map_location="cpu")
                model.load_state_dict(state_dict, strict=False)
            return model
        except Exception:
            # Preserve return-None-on-failure, but record the full traceback
            # via logging instead of a bare print().
            logging.getLogger(__name__).exception(
                "Failed to load model from %s", pretrained_model_name_or_path
            )
            return None