Add mixin
Browse files- briarmbg.py +3 -1
- mixin.py +15 -0
briarmbg.py
CHANGED
@@ -2,6 +2,8 @@ import torch
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
|
|
|
|
|
5 |
class REBNCONV(nn.Module):
|
6 |
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
7 |
super(REBNCONV,self).__init__()
|
@@ -344,7 +346,7 @@ class myrebnconv(nn.Module):
|
|
344 |
return self.rl(self.bn(self.conv(x)))
|
345 |
|
346 |
|
347 |
-
class BriaRMBG(nn.Module):
|
348 |
|
349 |
def __init__(self,in_ch=3,out_ch=1):
|
350 |
super(BriaRMBG,self).__init__()
|
|
|
2 |
import torch.nn as nn
|
3 |
import torch.nn.functional as F
|
4 |
|
5 |
+
from huggingface_hub import PyTorchModelHubMixin
|
6 |
+
|
7 |
class REBNCONV(nn.Module):
|
8 |
def __init__(self,in_ch=3,out_ch=3,dirate=1,stride=1):
|
9 |
super(REBNCONV,self).__init__()
|
|
|
346 |
return self.rl(self.bn(self.conv(x)))
|
347 |
|
348 |
|
349 |
+
class BriaRMBG(nn.Module, PyTorchModelHubMixin):
|
350 |
|
351 |
def __init__(self,in_ch=3,out_ch=1):
|
352 |
super(BriaRMBG,self).__init__()
|
mixin.py
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from briarmbg import BriaRMBG
|
2 |
+
import torch
|
3 |
+
from huggingface_hub import hf_hub_download
|
4 |
+
|
5 |
+
model_path = hf_hub_download("briaai/RMBG-1.4", 'model.pth')
|
6 |
+
|
7 |
+
net = BriaRMBG()
|
8 |
+
net.load_state_dict(torch.load(model_path, map_location="cpu"))
|
9 |
+
net.eval()
|
10 |
+
|
11 |
+
# push to hub
|
12 |
+
net.push_to_hub("nielsr/RMBG-1.4")
|
13 |
+
|
14 |
+
# reload
|
15 |
+
net = BriaRMBG.from_pretrained("nielsr/RMBG-1.4")
|