---
license: apache-2.0
---
# SlimSAM: 0.1% Data Makes Segment Anything Slim
<div align="center">
<img src="images/paper/intro.PNG" width="66%">
<img src="images/paper/everything.PNG" width="100%">
</div>

> **0.1% Data Makes Segment Anything Slim**
> [Zigeng Chen](https://github.com/czg1225), [Gongfan Fang](https://fangggf.github.io/), [Xinyin Ma](https://horseee.github.io/), [Xinchao Wang](https://sites.google.com/site/sitexinchaowang/)
> [Learning and Vision Lab](http://lv-nus.org/), National University of Singapore
> Paper: [[Arxiv]](https://arxiv.org/abs/2312.05284)

### Updates
* 🚀 **December 11, 2023**: Released the training code, inference code, and pre-trained models for **SlimSAM**.

## Introduction

<div align="center">
<img src="images/paper/process.PNG" width="100%">
</div>

**SlimSAM** is a novel SAM compression method that efficiently reuses pre-trained SAMs without extensive retraining, through a unified pruning-distillation framework. To enhance knowledge inheritance from the original SAM, we employ an alternate slimming strategy that partitions compression into a progressive procedure. Diverging from prior pruning techniques, we prune and distill decoupled model structures in an alternating fashion. Furthermore, a novel label-free pruning criterion aligns the pruning objective with the optimization target, thereby boosting post-pruning distillation.

![Frame](images/paper/frame.PNG?raw=true)

SlimSAM achieves performance approaching that of the original SAM-H while reducing the parameter count to **0.9\% (5.7M)**, MACs to **0.8\% (21G)**, and requiring a mere **0.1\% (10k)** of the training data. Extensive experiments demonstrate that our method achieves significantly superior performance while using over **10 times** less training data than other SAM compression methods.
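
To make the alternate slimming idea concrete, the sketch below shows one prune-then-distill step in its simplest form: a pruned student encoder is aligned to the frozen original SAM encoder with an MSE loss on image embeddings. This is an illustrative simplification rather than the repository's training code; `student_encoder`, `teacher_encoder`, and `dataloader` are hypothetical stand-ins, and the actual pruning criterion and alternating schedule follow the paper.

```python
import torch
import torch.nn.functional as F

def distill_epoch(student_encoder, teacher_encoder, dataloader, optimizer, device="cuda"):
    """One conceptual distillation pass: align the pruned student's image
    embeddings with the frozen teacher's (original SAM) embeddings."""
    teacher_encoder.eval()
    student_encoder.train()
    for images in dataloader:                 # (B, 3, 1024, 1024) preprocessed image batches
        images = images.to(device)
        with torch.no_grad():
            target = teacher_encoder(images)  # teacher embeddings, no gradients
        pred = student_encoder(images)        # pruned student embeddings
        loss = F.mse_loss(pred, target)       # embedding-alignment objective
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
```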

## Visualization Results

Qualitative comparisons of results obtained using point prompts, box prompts, and segment-everything prompts are shown below.

### Segment Everything Prompts
<div align="center">
<img src="images/paper/everything2.PNG" width="100%">
</div>

### Box Prompts and Point Prompts
<div align="center">
<img src="images/paper/prompt.PNG" width="100%">
</div>

## Quantitative Results

We conducted a comprehensive comparison of performance, efficiency, and training cost against other SAM compression methods and structural pruning methods.

### Comparison with other SAM compression methods
<div align="center">
<img src="images/paper/compare_tab1.PNG" width="100%">
</div>

### Comparison with other structural pruning methods
<div align="center">
<img src="images/paper/compare_tab2.PNG" width="50%">
</div>

## Installation

The code requires `python>=3.8`, as well as `pytorch>=1.7` and `torchvision>=0.8`. Please follow the instructions [here](https://pytorch.org/get-started/locally/) to install both PyTorch and TorchVision dependencies. Installing both PyTorch and TorchVision with CUDA support is strongly recommended.

Install SlimSAM with

```
pip install -e .
```

The following optional dependencies are necessary for mask post-processing and for saving masks in COCO format:

```
pip install opencv-python pycocotools matplotlib
```

## Dataset
We use the original SA-1B dataset in our code. See [here](https://ai.facebook.com/datasets/segment-anything/) for an overview of the dataset. The dataset can be downloaded [here](https://ai.facebook.com/datasets/segment-anything-downloads/).

The downloaded dataset should be organized as follows:

```
<train_data_root>/
    sa_xxxxxxx.jpg
    sa_xxxxxxx.json
    ......
<val_data_root>/
    sa_xxxxxxx.jpg
    sa_xxxxxxx.json
    ......
```
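
A quick way to sanity-check this layout is to enumerate the image/annotation pairs. The helper below is a hypothetical sketch that assumes the one-JSON-per-image naming shown above:

```python
import glob
import os

def list_pairs(data_root):
    """Yield (image_path, annotation_path) pairs from an SA-1B style folder."""
    for img_path in sorted(glob.glob(os.path.join(data_root, "sa_*.jpg"))):
        ann_path = os.path.splitext(img_path)[0] + ".json"
        if os.path.exists(ann_path):
            yield img_path, ann_path

# Example: count pairs in the training split (path is a placeholder).
# print(sum(1 for _ in list_pairs("<train_data_root>")))
```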

To decode a mask in COCO RLE format into a binary mask:

```
from pycocotools import mask as mask_utils
mask = mask_utils.decode(annotation["segmentation"])
```

See [here](https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/mask.py) for more instructions on manipulating masks stored in RLE format.
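
As a slightly fuller example, the sketch below loads one annotation file and decodes every mask into a binary array. It assumes the SA-1B per-image JSON stores its masks under an `annotations` list with COCO-RLE `segmentation` fields; the filename is a placeholder from the layout above.

```python
import json
from pycocotools import mask as mask_utils

with open("sa_xxxxxxx.json") as f:   # placeholder annotation file
    record = json.load(f)

# Decode each RLE mask into a uint8 array of shape (H, W) with values in {0, 1}.
masks = [mask_utils.decode(ann["segmentation"]) for ann in record["annotations"]]
print(len(masks), masks[0].shape if masks else None)
```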

## <a name="Models"></a>Model Checkpoints

The base model of our method is available. To make the encoder compatible with our dependency detection algorithm, we split the original image encoder's fused qkv layer into three distinct linear layers: q, k, and v.
<div align="center">
<img src="images/paper/split.PNG" width="70%">
</div>
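
As a concrete illustration of this change, the sketch below splits a fused `qkv` linear layer into separate q, k, and v layers by slicing its weights. It is a generic PyTorch sketch of the idea, not the exact conversion script used in the repository:

```python
import torch.nn as nn

def split_qkv(qkv):
    """Split a fused qkv projection (out_features == 3 * dim) into q, k, v layers."""
    dim = qkv.in_features
    assert qkv.out_features == 3 * dim
    layers = []
    for i in range(3):
        layer = nn.Linear(dim, dim, bias=qkv.bias is not None)
        layer.weight.data.copy_(qkv.weight.data[i * dim:(i + 1) * dim])
        if qkv.bias is not None:
            layer.bias.data.copy_(qkv.bias.data[i * dim:(i + 1) * dim])
        layers.append(layer)
    return tuple(layers)

# The three new layers reproduce the fused projection's q, k, v chunks.
q_proj, k_proj, v_proj = split_qkv(nn.Linear(768, 3 * 768, bias=True))
```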

Click the links below to download the checkpoint of the original SAM-B:

- `SAM-B`: [SAM-B model.](https://drive.google.com/file/d/1CtcyOm4h9bXgBF8DEVWn3N7g9-3r4Xzz/view?usp=sharing)

The checkpoints of our SlimSAM models are also available. We release two versions: SlimSAM-50 (pruning ratio = 50%) and SlimSAM-77 (pruning ratio = 77%).

Click the links below to download the checkpoint for the corresponding pruning ratio:

- `SlimSAM-50`: [SlimSAM-50 model.](https://drive.google.com/file/d/1iCN9IW0Su0Ud_fOFoQUnTdkC3bFveMND/view?usp=sharing)
- `SlimSAM-77`: [SlimSAM-77 model.](https://drive.google.com/file/d/1L7LB6gHDzR-3D63pH9acD9E0Ul9_wMF-/view)

These models can be instantiated by running

```
import torch
import types

# Load a downloaded SlimSAM checkpoint (the whole model object is pickled).
SlimSAM_model = torch.load(<model_path>)
# Unwrap the parallel-training wrapper (.module) saved around the image encoder.
SlimSAM_model.image_encoder = SlimSAM_model.image_encoder.module

# The pruned blocks also return intermediate embeddings; patch the encoder's
# forward so that it returns only the final image embedding.
def forward(self, x):
    x = self.patch_embed(x)
    if self.pos_embed is not None:
        x = x + self.pos_embed
    for blk in self.blocks:
        x, qkv_emb, mid_emb, x_emb = blk(x)
    x = self.neck(x.permute(0, 3, 1, 2))
    return x

SlimSAM_model.image_encoder.forward = types.MethodType(forward, SlimSAM_model.image_encoder)
```
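
Continuing from the snippet above, a quick sanity check is to run one dummy image through the patched encoder. This assumes the encoder keeps SAM's standard 1024x1024 input resolution; the embedding's channel width depends on the checkpoint.

```python
SlimSAM_model.eval()
device = next(SlimSAM_model.parameters()).device
with torch.no_grad():
    dummy = torch.randn(1, 3, 1024, 1024, device=device)  # standard SAM input resolution
    embedding = SlimSAM_model.image_encoder(dummy)
print(embedding.shape)  # image embedding on a 64x64 spatial grid for 1024x1024 inputs
```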

## <a name="Inference"></a>Inference

First download the [SlimSAM-50 model](https://drive.google.com/file/d/1iCN9IW0Su0Ud_fOFoQUnTdkC3bFveMND/view?usp=sharing) or the [SlimSAM-77 model](https://drive.google.com/file/d/1L7LB6gHDzR-3D63pH9acD9E0Ul9_wMF-/view) for inference.

We provide detailed instructions in 'inference.py' on how to run inference with a range of prompts, including 'point', 'box', and 'everything' prompts.

```
CUDA_VISIBLE_DEVICES=0 python inference.py
```
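
If you prefer to call the model from Python directly, the sketch below shows a minimal point-prompt example. It assumes `model` has already been loaded and prepared exactly as in the Model Checkpoints section (unwrapping `.module` and patching `image_encoder.forward`), and that the checkpoint works with the `SamPredictor` interface from the original `segment_anything` package; the image path and point coordinates are placeholders, and 'inference.py' remains the reference for the full set of prompt options.

```python
import cv2
import numpy as np
from segment_anything import SamPredictor  # predictor interface from the original SAM repo

# `model` is a SlimSAM model prepared as in the Model Checkpoints section.
predictor = SamPredictor(model)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)  # placeholder image
predictor.set_image(image)

# A single foreground point prompt at pixel (x, y); label 1 marks foreground.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)
```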

## <a name="Train"></a>Train

First download a [SAM-B model](https://drive.google.com/file/d/1CtcyOm4h9bXgBF8DEVWn3N7g9-3r4Xzz/view?usp=sharing) into 'checkpoints/' as the base model.

### Step 1: Embedding Pruning + Bottleneck Aligning
The model after step 1 is saved as 'checkpoints/vit_b_slim_step1_.pth'.

```
CUDA_VISIBLE_DEVICES=0 python prune_distill_step1.py --traindata_path <train_data_root> --valdata_path <val_data_root> --prune_ratio <pruning ratio> --epochs <training epochs>
```

### Step 2: Bottleneck Pruning + Embedding Aligning
The model after step 2 is saved as 'checkpoints/vit_b_slim_step2_.pth'.

```
CUDA_VISIBLE_DEVICES=0 python prune_distill_step2.py --traindata_path <train_data_root> --valdata_path <val_data_root> --prune_ratio <pruning ratio> --epochs <training epochs> --model_path 'checkpoints/vit_b_slim_step1_.pth'
```

You can adjust the training settings to meet your specific requirements. While our method demonstrates impressive performance with just 10k training samples, incorporating additional training data will further enhance the model's effectiveness.

## BibTeX of our SlimSAM
If you use SlimSAM in your research, please use the following BibTeX entry. Thank you!

```bibtex
@misc{chen202301,
  title={0.1% Data Makes Segment Anything Slim},
  author={Zigeng Chen and Gongfan Fang and Xinyin Ma and Xinchao Wang},
  year={2023},
  eprint={2312.05284},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```

## Acknowledgement

<details>
<summary>
<a href="https://github.com/facebookresearch/segment-anything">SAM</a> (Segment Anything) [<b>bib</b>]
</summary>

```bibtex
@article{kirillov2023segany,
  title={Segment Anything},
  author={Kirillov, Alexander and Mintun, Eric and Ravi, Nikhila and Mao, Hanzi and Rolland, Chloe and Gustafson, Laura and Xiao, Tete and Whitehead, Spencer and Berg, Alexander C. and Lo, Wan-Yen and Doll{\'a}r, Piotr and Girshick, Ross},
  journal={arXiv:2304.02643},
  year={2023}
}
```
</details>

<details>
<summary>
<a href="https://github.com/VainF/Torch-Pruning">Torch Pruning</a> (DepGraph: Towards Any Structural Pruning) [<b>bib</b>]
</summary>

```bibtex
@inproceedings{fang2023depgraph,
  title={Depgraph: Towards any structural pruning},
  author={Fang, Gongfan and Ma, Xinyin and Song, Mingli and Mi, Michael Bi and Wang, Xinchao},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={16091--16101},
  year={2023}
}
```
</details>