ZhengPeng7 committed
Commit 4e60c70
1 Parent(s): f6b7155

Add inference endpoint feature in HF model page.
Files changed (4):
  1. README.md +46 -1
  2. birefnet.py +30 -27
  3. handler.py +132 -0
  4. requirements.txt +18 -0
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-library_name: birefnet
+library_name: BiRefNet
 tags:
 - background-removal
 - mask-generation
@@ -141,6 +141,51 @@ plt.show()
 
 ```
 
+### 2. Use the inference endpoint locally:
+> You may need to click *Deploy* and set up the endpoint yourself, which incurs some cost.
+```
+import requests
+import base64
+from io import BytesIO
+from PIL import Image
+
+
+YOUR_HF_TOKEN = 'xxx'
+API_URL = "xxx"
+headers = {
+    "Authorization": "Bearer {}".format(YOUR_HF_TOKEN)
+}
+
+def base64_to_bytes(base64_string):
+    # Remove the data URI prefix if present
+    if "data:image" in base64_string:
+        base64_string = base64_string.split(",")[1]
+
+    # Decode the Base64 string into bytes
+    image_bytes = base64.b64decode(base64_string)
+    return image_bytes
+
+def bytes_to_image(image_bytes):
+    # Wrap the raw bytes in a BytesIO stream
+    image_stream = BytesIO(image_bytes)
+
+    # Open the image using Pillow (PIL)
+    image = Image.open(image_stream)
+    return image
+
+def query(payload):
+    response = requests.post(API_URL, headers=headers, json=payload)
+    return response.json()
+
+output = query({
+    "inputs": "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg",
+    "parameters": {}
+})
+
+# Decode the base64 response into a PIL image.
+output_image = bytes_to_image(base64_to_bytes(output))
+output_image
+```
+
 
 > This BiRefNet for standard dichotomous image segmentation (DIS) is trained on **DIS-TR** and validated on **DIS-TEs and DIS-VD**.
 
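> A possible next step, shown as a sketch rather than part of this commit: since the endpoint returns a base64-encoded RGBA image, you can persist the decoded cutout with Pillow. This assumes `output` is the value returned by `query()` in the snippet above.

```
# Sketch only: save the decoded RGBA cutout from the endpoint response.
# Assumes `output` holds the base64 string returned by query() above.
output_image = bytes_to_image(base64_to_bytes(output))
output_image.save('subject_no_bg.png')  # PNG preserves the alpha channel
```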
birefnet.py CHANGED
@@ -7,7 +7,7 @@ import math
 class Config():
     def __init__(self) -> None:
         # PATH settings
-        self.sys_home_dir = os.path.expanduser('~')  # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
+        self.sys_home_dir = os.path.expanduser('~')  # Make up your file system as: SYS_HOME_DIR/codes/dis/BiRefNet, SYS_HOME_DIR/datasets/dis/xx, SYS_HOME_DIR/weights/xx
 
         # TASK settings
         self.task = ['DIS5K', 'COD', 'HRSOD', 'DIS5K+HRSOD+HRS10K', 'P3M-10k'][0]
@@ -615,6 +615,7 @@ from timm.models.layers import DropPath, to_2tuple, trunc_normal_
 
 # config = Config()
 
+
 class Mlp(nn.Module):
     """ Multilayer perceptron."""
 
@@ -739,7 +740,8 @@ class WindowAttention(nn.Module):
         attn = (q @ k.transpose(-2, -1))
 
         relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
-            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
+            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+        )  # Wh*Ww, Wh*Ww, nH
         relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
         attn = attn + relative_position_bias.unsqueeze(0)
 
@@ -974,8 +976,9 @@ class BasicLayer(nn.Module):
         """
 
         # calculate attention mask for SW-MSA
-        Hp = int(np.ceil(H / self.window_size)) * self.window_size
-        Wp = int(np.ceil(W / self.window_size)) * self.window_size
+        # Turn ints into torch.tensor for compatibility with torch.compile in PyTorch 2.5.
+        Hp = torch.ceil(torch.tensor(H) / self.window_size).to(torch.int64) * self.window_size
+        Wp = torch.ceil(torch.tensor(W) / self.window_size).to(torch.int64) * self.window_size
         img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)  # 1 Hp Wp 1
         h_slices = (slice(0, -self.window_size),
                     slice(-self.window_size, -self.shift_size),
@@ -1961,6 +1964,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from kornia.filters import laplacian
 from transformers import PreTrainedModel
+from einops import rearrange
 
 # from config import Config
 # from dataset import class_labels_TR_sorted
@@ -1974,13 +1978,24 @@ from transformers import PreTrainedModel
 from .BiRefNet_config import BiRefNetConfig
 
 
+def image2patches(image, grid_h=2, grid_w=2, patch_ref=None, transformation='b c (hg h) (wg w) -> (b hg wg) c h w'):
+    if patch_ref is not None:
+        grid_h, grid_w = image.shape[-2] // patch_ref.shape[-2], image.shape[-1] // patch_ref.shape[-1]
+    patches = rearrange(image, transformation, hg=grid_h, wg=grid_w)
+    return patches
+
+def patches2image(patches, grid_h=2, grid_w=2, patch_ref=None, transformation='(b hg wg) c h w -> b c (hg h) (wg w)'):
+    if patch_ref is not None:
+        grid_h, grid_w = patch_ref.shape[-2] // patches[0].shape[-2], patch_ref.shape[-1] // patches[0].shape[-1]
+    image = rearrange(patches, transformation, hg=grid_h, wg=grid_w)
+    return image
+
 class BiRefNet(
     PreTrainedModel
 ):
     config_class = BiRefNetConfig
     def __init__(self, bb_pretrained=True, config=BiRefNetConfig()):
-        super(BiRefNet, self).__init__(config)
-        bb_pretrained = config.bb_pretrained
+        super(BiRefNet, self).__init__()
         self.config = Config()
         self.epoch = 1
         self.bb = build_backbone(self.config.bb, pretrained=bb_pretrained)
@@ -2124,18 +2139,6 @@ class Decoder(nn.Module):
         self.gdt_convs_attn_3 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
         self.gdt_convs_attn_2 = nn.Sequential(nn.Conv2d(_N, 1, 1, 1, 0))
 
-    def get_patches_batch(self, x, p):
-        _size_h, _size_w = p.shape[2:]
-        patches_batch = []
-        for idx in range(x.shape[0]):
-            columns_x = torch.split(x[idx], split_size_or_sections=_size_w, dim=-1)
-            patches_x = []
-            for column_x in columns_x:
-                patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
-            patch_sample = torch.cat(patches_x, dim=1)
-            patches_batch.append(patch_sample)
-        return torch.cat(patches_batch, dim=0)
-
     def forward(self, features):
         if self.training and self.config.out_ref:
             outs_gdt_pred = []
@@ -2146,10 +2149,10 @@ class Decoder(nn.Module):
         outs = []
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, x4) if self.split else x
+            patches_batch = image2patches(x, patch_ref=x4, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             x4 = torch.cat((x4, self.ipt_blk5(F.interpolate(patches_batch, size=x4.shape[2:], mode='bilinear', align_corners=True))), 1)
         p4 = self.decoder_block4(x4)
-        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision else None
+        m4 = self.conv_ms_spvn_4(p4) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p4_gdt = self.gdt_convs_4(p4)
             if self.training:
@@ -2167,10 +2170,10 @@ class Decoder(nn.Module):
         _p3 = _p4 + self.lateral_block4(x3)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p3) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p3, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
            _p3 = torch.cat((_p3, self.ipt_blk4(F.interpolate(patches_batch, size=x3.shape[2:], mode='bilinear', align_corners=True))), 1)
         p3 = self.decoder_block3(_p3)
-        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision else None
+        m3 = self.conv_ms_spvn_3(p3) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p3_gdt = self.gdt_convs_3(p3)
             if self.training:
@@ -2193,10 +2196,10 @@ class Decoder(nn.Module):
         _p2 = _p3 + self.lateral_block3(x2)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p2) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p2, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p2 = torch.cat((_p2, self.ipt_blk3(F.interpolate(patches_batch, size=x2.shape[2:], mode='bilinear', align_corners=True))), 1)
         p2 = self.decoder_block2(_p2)
-        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision else None
+        m2 = self.conv_ms_spvn_2(p2) if self.config.ms_supervision and self.training else None
         if self.config.out_ref:
             p2_gdt = self.gdt_convs_2(p2)
             if self.training:
@@ -2214,17 +2217,17 @@ class Decoder(nn.Module):
         _p1 = _p2 + self.lateral_block2(x1)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk2(F.interpolate(patches_batch, size=x1.shape[2:], mode='bilinear', align_corners=True))), 1)
         _p1 = self.decoder_block1(_p1)
         _p1 = F.interpolate(_p1, size=x.shape[2:], mode='bilinear', align_corners=True)
 
         if self.config.dec_ipt:
-            patches_batch = self.get_patches_batch(x, _p1) if self.split else x
+            patches_batch = image2patches(x, patch_ref=_p1, transformation='b c (hg h) (wg w) -> b (c hg wg) h w') if self.split else x
             _p1 = torch.cat((_p1, self.ipt_blk1(F.interpolate(patches_batch, size=x.shape[2:], mode='bilinear', align_corners=True))), 1)
         p1_out = self.conv_out1(_p1)
 
-        if self.config.ms_supervision:
+        if self.config.ms_supervision and self.training:
            outs.append(m4)
            outs.append(m3)
            outs.append(m2)
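> A minimal sanity check, as a sketch (assuming a recent PyTorch install, not part of the commit): the tensor-based rounding introduced in `BasicLayer` above is numerically identical to the old `int(np.ceil(...))` arithmetic; it only swaps Python ints for 0-dim tensors so `torch.compile` can trace the padding math.

```
import torch

# The padded window sizes match the previous integer math exactly.
H, W, window_size = 50, 70, 7
Hp = torch.ceil(torch.tensor(H) / window_size).to(torch.int64) * window_size
Wp = torch.ceil(torch.tensor(W) / window_size).to(torch.int64) * window_size
print(int(Hp), int(Wp))  # 56 70, same as int(np.ceil(H / 7)) * 7 etc.
```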
 
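> To make the `einops` refactor concrete, here is a small standalone sketch (my illustration, not part of the commit; assumes `torch` and `einops` are installed) of the channel-stacking transformation that `Decoder.forward` now passes to `image2patches`:

```
import torch
from einops import rearrange

x = torch.randn(1, 3, 64, 64)    # stand-in for the input image
ref = torch.randn(1, 8, 32, 32)  # stand-in for a decoder feature map

# Grid inferred from the reference, as image2patches(patch_ref=...) does.
grid_h, grid_w = x.shape[-2] // ref.shape[-2], x.shape[-1] // ref.shape[-1]

# Same transformation string as in Decoder.forward: patches are stacked
# along the channel axis instead of the batch axis.
patches = rearrange(x, 'b c (hg h) (wg w) -> b (c hg wg) h w', hg=grid_h, wg=grid_w)
print(patches.shape)  # torch.Size([1, 12, 32, 32])
```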
handler.py ADDED
@@ -0,0 +1,132 @@
+# These HF deployment codes refer to https://huggingface.co/not-lain/BiRefNet/raw/main/handler.py.
+from typing import Dict, List, Any, Tuple
+import os
+import requests
+from io import BytesIO
+import cv2
+import numpy as np
+from PIL import Image
+import torch
+from torchvision import transforms
+from transformers import AutoModelForImageSegmentation
+
+torch.set_float32_matmul_precision(["high", "highest"][0])
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+### image_proc.py
+def refine_foreground(image, mask, r=90):
+    if mask.size != image.size:
+        mask = mask.resize(image.size)
+    image = np.array(image) / 255.0
+    mask = np.array(mask) / 255.0
+    estimated_foreground = FB_blur_fusion_foreground_estimator_2(image, mask, r=r)
+    image_masked = Image.fromarray((estimated_foreground * 255.0).astype(np.uint8))
+    return image_masked
+
+
+def FB_blur_fusion_foreground_estimator_2(image, alpha, r=90):
+    # Thanks to the source: https://github.com/Photoroom/fast-foreground-estimation
+    alpha = alpha[:, :, None]
+    F, blur_B = FB_blur_fusion_foreground_estimator(image, image, image, alpha, r)
+    return FB_blur_fusion_foreground_estimator(image, F, blur_B, alpha, r=6)[0]
+
+
+def FB_blur_fusion_foreground_estimator(image, F, B, alpha, r=90):
+    if isinstance(image, Image.Image):
+        image = np.array(image) / 255.0
+    blurred_alpha = cv2.blur(alpha, (r, r))[:, :, None]
+
+    blurred_FA = cv2.blur(F * alpha, (r, r))
+    blurred_F = blurred_FA / (blurred_alpha + 1e-5)
+
+    blurred_B1A = cv2.blur(B * (1 - alpha), (r, r))
+    blurred_B = blurred_B1A / ((1 - blurred_alpha) + 1e-5)
+    F = blurred_F + alpha * \
+        (image - alpha * blurred_F - (1 - alpha) * blurred_B)
+    F = np.clip(F, 0, 1)
+    return F, blurred_B
+
+
+class ImagePreprocessor():
+    def __init__(self, resolution: Tuple[int, int] = (1024, 1024)) -> None:
+        self.transform_image = transforms.Compose([
+            transforms.Resize(resolution),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ])
+
+    def proc(self, image: Image.Image) -> torch.Tensor:
+        image = self.transform_image(image)
+        return image
+
+usage_to_weights_file = {
+    'General': 'BiRefNet',
+    'General-Lite': 'BiRefNet_lite',
+    'General-Lite-2K': 'BiRefNet_lite-2K',
+    'General-reso_512': 'BiRefNet-reso_512',
+    'Matting': 'BiRefNet-matting',
+    'Portrait': 'BiRefNet-portrait',
+    'DIS': 'BiRefNet-DIS5K',
+    'HRSOD': 'BiRefNet-HRSOD',
+    'COD': 'BiRefNet-COD',
+    'DIS-TR_TEs': 'BiRefNet-DIS5K-TR_TEs',
+    'General-legacy': 'BiRefNet-legacy'
+}
+
+# Choose the version of BiRefNet here.
+usage = 'General'
+
+# Set resolution
+if usage in ['General-Lite-2K']:
+    resolution = (2560, 1440)
+elif usage in ['General-reso_512']:
+    resolution = (512, 512)
+else:
+    resolution = (1024, 1024)
+
+
+class EndpointHandler():
+    def __init__(self, path=''):
+        self.birefnet = AutoModelForImageSegmentation.from_pretrained(
+            '/'.join(('zhengpeng7', usage_to_weights_file[usage])), trust_remote_code=True
+        )
+        self.birefnet.to(device)
+        self.birefnet.eval()
+
+    def __call__(self, data: Dict[str, Any]):
+        """
+        data args:
+            inputs (:obj:`str`)
+            date (:obj:`str`)
+        Return:
+            A :obj:`list` | `dict`: will be serialized and returned
+        """
+        print('data["inputs"] = ', data["inputs"])
+        image_src = data["inputs"]
+        if isinstance(image_src, str):
+            if os.path.isfile(image_src):
+                image_ori = Image.open(image_src)
+            else:
+                response = requests.get(image_src)
+                image_data = BytesIO(response.content)
+                image_ori = Image.open(image_data)
+        else:
+            image_ori = Image.fromarray(image_src)
+
+        image = image_ori.convert('RGB')
+        # Preprocess the image
+        image_preprocessor = ImagePreprocessor(resolution=tuple(resolution))
+        image_proc = image_preprocessor.proc(image)
+        image_proc = image_proc.unsqueeze(0)
+
+        # Prediction
+        with torch.no_grad():
+            preds = self.birefnet(image_proc.to(device))[-1].sigmoid().cpu()
+        pred = preds[0].squeeze()
+
+        # Show Results
+        pred_pil = transforms.ToPILImage()(pred)
+        image_masked = refine_foreground(image, pred_pil)
+        image_masked.putalpha(pred_pil.resize(image.size))
+        return image_masked
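> A quick local smoke test for the new handler, as a sketch (not part of the commit; assumes the packages in requirements.txt are installed and the chosen weights can be downloaded from the Hub):

```
# Sketch only: exercise handler.py locally, run from the repo root.
from handler import EndpointHandler

handler = EndpointHandler(path='.')
result = handler({"inputs": "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"})
result.save('test_output.png')  # RGBA result with the background removed
```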
requirements.txt ADDED
@@ -0,0 +1,18 @@
+--extra-index-url https://download.pytorch.org/whl/cu118
+torch==2.0.1
+--extra-index-url https://download.pytorch.org/whl/cu118
+torchvision==0.15.2
+numpy<2
+opencv-python
+timm
+scipy
+scikit-image
+kornia
+einops
+
+tqdm
+prettytable
+
+transformers
+huggingface-hub>0.25
+accelerate