Upload folder using huggingface_hub
Browse files- config.json +4 -0
- mexma_siglip.py +126 -0
config.json
CHANGED
@@ -1,4 +1,8 @@
|
|
1 |
{
|
|
|
|
|
|
|
|
|
2 |
"architectures": [
|
3 |
"MexmaSigLIP"
|
4 |
],
|
|
|
1 |
{
|
2 |
+
"auto_map": {
|
3 |
+
"AutoConfig": "mexma_siglip.MexmaSigLIPConfig",
|
4 |
+
"AutoModel": "mexma_siglip.MexmaSigLIP"
|
5 |
+
},
|
6 |
"architectures": [
|
7 |
"MexmaSigLIP"
|
8 |
],
|
mexma_siglip.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List, Union
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from PIL import Image
|
8 |
+
from transformers import (
|
9 |
+
PretrainedConfig,
|
10 |
+
PreTrainedModel,
|
11 |
+
SiglipVisionConfig,
|
12 |
+
SiglipVisionModel,
|
13 |
+
XLMRobertaConfig,
|
14 |
+
XLMRobertaModel,
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
class MexmaSigLIPConfig(PretrainedConfig):
    """Configuration object for ``MexmaSigLIP``.

    Args:
        optimized: Stored as ``self.optimized``. ``MexmaSigLIP.__init__``
            reads this flag to decide whether to request the "sdpa" /
            "flash_attention_2" attention implementations on its towers.
        **kwargs: Passed through untouched to ``PretrainedConfig.__init__``.
    """

    def __init__(self, optimized: bool = False, **kwargs):
        # Let the base class consume the generic config kwargs first, then
        # record the single model-specific switch.
        super().__init__(**kwargs)
        self.optimized = optimized
|
26 |
+
|
27 |
+
|
28 |
+
class MLP(nn.Module):
    """Two-layer feed-forward block computing ``fc2(silu(fc1(x)))``.

    Args:
        hidden_size: Size of the last dimension of the input and the output.
        intermediate_size: Width of the hidden layer between ``fc1`` and ``fc2``.
    """

    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.fc1 = nn.Linear(hidden_size, intermediate_size)
        self.fc2 = nn.Linear(intermediate_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward block to the last dimension of ``hidden_states``."""
        # Use the functional SiLU rather than instantiating a fresh nn.SiLU
        # module on every forward call; the result is identical.
        return self.fc2(F.silu(self.fc1(hidden_states)))
|
39 |
+
|
40 |
+
class MultiheadAttentionPoolingHead(nn.Module):
    """Pool a sequence into one vector with a learned attention probe.

    A single learned query (``probe``) attends over the input sequence; the
    attended vector goes through a residual LayerNorm + MLP step and is then
    linearly projected to ``out_hidden_size``.

    Args:
        hidden_size: Feature size of the input sequence and of the probe.
        out_hidden_size: Feature size of the pooled output.
        num_attention_heads: Number of heads of the attention layer.
        layer_norm_eps: Epsilon of the LayerNorm applied before the MLP.
        intermediate_size: Hidden width of the MLP.
    """

    def __init__(self, hidden_size: int, out_hidden_size: int, num_attention_heads: int, layer_norm_eps: float, intermediate_size: int):
        super().__init__()

        # Submodules are created in the original order so that seeded
        # initialisation and state-dict keys stay unchanged.
        self.probe = nn.Parameter(torch.randn(1, 1, hidden_size))
        self.attention = torch.nn.MultiheadAttention(hidden_size, num_attention_heads, batch_first=True)
        self.layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
        self.mlp = MLP(hidden_size, intermediate_size)
        self.projector = nn.Linear(hidden_size, out_hidden_size)

    def forward(self, hidden_state):
        """Return one pooled vector per batch element.

        Args:
            hidden_state: Batch-first sequence tensor; its last dimension must
                be ``hidden_size`` (the attention layer uses ``batch_first=True``).

        Returns:
            Tensor with the sequence dimension removed and the last dimension
            equal to ``out_hidden_size``.
        """
        batch_size = hidden_state.shape[0]
        # expand() broadcasts the single probe over the batch as a view,
        # avoiding the per-call copy that repeat() makes.
        probe = self.probe.expand(batch_size, -1, -1)

        # The attention weights are never used, so skip computing them.
        # NOTE(review): no key_padding_mask is passed, so every position of
        # hidden_state is attended to, padding included — confirm this
        # matches how the head was trained.
        pooled, _ = self.attention(probe, hidden_state, hidden_state, need_weights=False)

        # Pre-norm residual MLP, then projection to the output width.
        pooled = pooled + self.mlp(self.layernorm(pooled))
        pooled = self.projector(pooled)
        # The query had length 1, so index 0 is the only pooled position.
        return pooled[:, 0]
|
61 |
+
|
62 |
+
|
63 |
+
class MexmaSigLIP(PreTrainedModel):
    """Dual-encoder pairing an XLM-RoBERTa text tower with a SigLIP vision tower.

    Both towers are instantiated from configs fetched with ``from_pretrained``
    ("facebook/MEXMA" and "google/siglip2-so400m-patch16-512"); ``__init__``
    itself only builds the modules from those configs. Text token states are
    pooled and projected by a ``MultiheadAttentionPoolingHead``; image
    features are the vision tower's ``pooler_output``.
    """

    config_class = MexmaSigLIPConfig

    def __init__(self, config: MexmaSigLIPConfig):
        super().__init__(config)
        self.config = config

        # --- text tower -----------------------------------------------------
        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
        if self.config.optimized:
            text_config._attn_implementation = "sdpa"
        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
        self.text_projector = MultiheadAttentionPoolingHead(
            hidden_size=1024,
            out_hidden_size=1152,
            num_attention_heads=16,
            layer_norm_eps=1e-5,
            intermediate_size=4304,
        )

        # --- vision tower ---------------------------------------------------
        vision_config = SiglipVisionConfig.from_pretrained(
            "google/siglip2-so400m-patch16-512"
        )
        if self.config.optimized:
            vision_config._attn_implementation = "flash_attention_2"
            vision_config.torch_dtype = "bfloat16"
        # Keep only the inner vision transformer of the wrapper model.
        self.vision_model = SiglipVisionModel(vision_config).vision_model

        # --- learnable scalars used by get_logits ----------------------------
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)

    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
        """Encode a batch of texts and images.

        Returns:
            Dict with keys ``"image_features"``, ``"text_features"``,
            ``"logit_scale"`` and ``"logit_bias"``.
        """
        text_features = self.encode_texts(input_ids, attention_mask, normalize)
        image_features = self.encode_images(image_inputs, normalize)
        outputs = {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }
        return outputs

    def encode_images(
        self,
        pixel_values,
        normalize=False,
    ):
        """Return the vision tower's pooled output, L2-normalised if ``normalize``."""
        vision_outputs = self.vision_model(pixel_values)
        features = vision_outputs.pooler_output
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def encode_texts(
        self,
        input_ids,
        attention_mask,
        normalize=False,
    ):
        """Return pooled, projected text features, L2-normalised if ``normalize``."""
        text_outputs = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        )
        features = self.text_projector(text_outputs.last_hidden_state)
        if normalize:
            return F.normalize(features, dim=-1)
        return features

    def get_logits(
        self,
        input_ids,
        attention_mask,
        pixel_values,
    ):
        """Compute image-text similarity logits from normalised features.

        Returns:
            ``(image_logits, text_logits)`` where
            ``image_logits = exp(logit_scale) * image @ text.T + logit_bias``
            and ``text_logits`` is its transpose.
        """
        image_features = self.encode_images(pixel_values, normalize=True)
        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
        # Scale the image features first, then take the similarity matrix,
        # matching the original left-to-right evaluation order.
        scaled_image_features = self.logit_scale.exp() * image_features
        image_logits = scaled_image_features @ text_features.T + self.logit_bias
        return image_logits, image_logits.T
|