Spaces: Running on Zero

Migrated from GitHub

This view is limited to 50 files because the commit contains too many changes.
- .gitattributes +7 -0
- LICENSE +201 -0
- ORIGINAL_README.md +166 -0
- assets/images/teaser.jpg +0 -0
- assets/videos/apt_exp_1_all.gif +3 -0
- assets/videos/apt_exp_2_all.gif +3 -0
- assets/videos/baodao_exp_1_all.gif +3 -0
- assets/videos/exp_1.gif +3 -0
- assets/videos/exp_2.gif +3 -0
- assets/videos/gf_exp1.gif +3 -0
- assets/videos/gf_exp1.mp4 +3 -0
- demo.ipynb +0 -0
- demo.py +98 -0
- demo/demo.py +98 -0
- demo/requirements.txt +10 -0
- projects/glamm/datasets/__init__.py +7 -0
- projects/glamm/datasets/collate_fns/glamm_collate_fn.py +136 -0
- projects/glamm/datasets/gcg_dataset.py +349 -0
- projects/glamm/datasets/refcoco_segm_dataset.py +195 -0
- projects/glamm/datasets/region_level_dataset.py +297 -0
- projects/glamm/datasets/semantic_seg_dataset.py +424 -0
- projects/glamm/datasets/utils/ade20k_classes.json +30 -0
- projects/glamm/datasets/utils/cocostuff_classes.txt +183 -0
- projects/glamm/datasets/utils/utils.py +131 -0
- projects/glamm/models/glamm.py +183 -0
- projects/glamm/models/region_encoder.py +359 -0
- projects/glamm/utils.py +280 -0
- projects/llava_sam2/configs/sa2va_4b.py +548 -0
- projects/llava_sam2/datasets/ChatUniVi_Dataset.py +389 -0
- projects/llava_sam2/datasets/GCG_Dataset.py +375 -0
- projects/llava_sam2/datasets/Grand_Dataset.py +241 -0
- projects/llava_sam2/datasets/MeVIS_Dataset.py +5 -0
- projects/llava_sam2/datasets/Osprey_Dataset.py +463 -0
- projects/llava_sam2/datasets/ReSAM2_Dataset.py +489 -0
- projects/llava_sam2/datasets/ReVOS_Dataset.py +602 -0
- projects/llava_sam2/datasets/RefCOCO_Dataset.py +338 -0
- projects/llava_sam2/datasets/RefYoutubeVOS_Dataset.py +47 -0
- projects/llava_sam2/datasets/__init__.py +15 -0
- projects/llava_sam2/datasets/collect_fns.py +206 -0
- projects/llava_sam2/datasets/encode_fn.py +144 -0
- projects/llava_sam2/datasets/gcg_process.py +297 -0
- projects/llava_sam2/datasets/grand_process.py +110 -0
- projects/llava_sam2/datasets/utils.py +58 -0
- projects/llava_sam2/datasets/vqa_dataset.py +509 -0
- projects/llava_sam2/deepspeed_zero2_sam2.json +24 -0
- projects/llava_sam2/gradio/app.py +151 -0
- projects/llava_sam2/gradio/app_utils.py +293 -0
- projects/llava_sam2/models/__init__.py +3 -0
- projects/llava_sam2/models/extension/__init__.py +1 -0
- projects/llava_sam2/models/extension/sam2_base.py +281 -0
.gitattributes
CHANGED
```diff
@@ -33,3 +33,10 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/videos/apt_exp_1_all.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/apt_exp_2_all.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/baodao_exp_1_all.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/exp_1.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/exp_2.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/gf_exp1.gif filter=lfs diff=lfs merge=lfs -text
+assets/videos/gf_exp1.mp4 filter=lfs diff=lfs merge=lfs -text
```
LICENSE
ADDED
@@ -0,0 +1,201 @@
Apache License, Version 2.0, January 2004 (http://www.apache.org/licenses/) — standard license text, 201 lines.
ORIGINAL_README.md
ADDED
@@ -0,0 +1,166 @@
# Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos

[\[🏠 Sa2VA\]](https://lxtgh.github.io/project/sa2va) [\[📜 arXiv\]](https://arxiv.org/abs/2501.04001) [\[🤗 HuggingFace\]](https://huggingface.co/collections/ByteDance/sa2va-model-zoo-677e3084d71b5f108d00e093) [\[🎥 Introduction\]]() [\[🧑💻 GitHub\]](https://github.com/magic-research/Sa2VA) [\[Online Demo (Sa2VA-4B)\]](https://5512470799b6b35fbc.gradio.live/)

[**Haobo Yuan**](https://yuanhaobo.me/)<sup>1*</sup> · [**Xiangtai Li**](https://scholar.google.com/citations?user=NmHgX-wAAAAJ)<sup>2*†</sup> · [**Tao Zhang**](https://zhang-tao-whu.github.io/)<sup>2,3*</sup> · [**Zilong Huang**](http://speedinghzl.github.io/)<sup>2</sup> · [**Shilin Xu**](https://xushilin1.github.io/)<sup>4</sup> · [**Shunping Ji**](https://scholar.google.com/citations?user=FjoRmF4AAAAJ&hl=en)<sup>3</sup> · [**Yunhai Tong**](https://scholar.google.com/citations?user=T4gqdPkAAAAJ&hl=zh-CN)<sup>4</sup> ·

[**Lu Qi**](https://luqi.info/)<sup>2</sup> · [**Jiashi Feng**](https://sites.google.com/site/jshfeng/)<sup>2</sup> · [**Ming-Hsuan Yang**](https://faculty.ucmerced.edu/mhyang/)<sup>1</sup>

<sup>1</sup>UC Merced    <sup>2</sup>ByteDance Seed    <sup>3</sup>WHU    <sup>4</sup>PKU

† project lead; * the first three authors contributed equally to this work.

![Teaser](assets/images/teaser.jpg)

## Overview
This repository contains the code for the paper "Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos".

Sa2VA is the first unified model for dense grounded understanding of both images and videos. Unlike existing multi-modal large language models, which are often limited to specific modalities and tasks, Sa2VA supports a wide range of image and video tasks, including referring segmentation and conversation, with minimal one-shot instruction tuning. Sa2VA combines SAM-2, a foundation video segmentation model, with LLaVA, an advanced vision-language model, and unifies text, image, and video into a shared LLM token space.

## Model Zoo
We provide the following models:

| Model Name | Base MLLM | Language Part | HF Link |
|:----------:|:---------:|:-------------:|:-------:|
| Sa2VA-1B | [InternVL2.0-1B](https://huggingface.co/OpenGVLab/InternVL2-1B) | [Qwen2-0.5B-Instruct](https://huggingface.co/Qwen/Qwen2-0.5B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-1B) |
| Sa2VA-4B | [InternVL2.5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B) | [Qwen2.5-3B-Instruct](https://huggingface.co/Qwen/Qwen2.5-3B-Instruct) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-4B) |
| Sa2VA-8B | [InternVL2.5-8B](https://huggingface.co/OpenGVLab/InternVL2_5-8B) | [internlm2_5-7b-chat](https://huggingface.co/internlm/internlm2_5-7b-chat) | [🤗 link](https://huggingface.co/ByteDance/Sa2VA-8B) |

## Gradio Demos

We provide a script that implements an interactive chat interface using Gradio (requires `gradio==4.42.0`). You can use it to quickly build a chat interface locally.
```shell
PYTHONPATH=. python projects/llava_sam2/gradio/app.py ByteDance/Sa2VA-4B
```

## Quick Start

Our Sa2VA models are available on 🤗 HuggingFace. With just a few steps, you can try them on your own data. You can install `demo/requirements.txt` to avoid pulling in training-only packages (a minimal Python sketch of this flow follows this README).

**Option 1 - Scripts:**

Assuming you have a folder (`PATH_TO_FOLDER`) that contains the frames of a video, you can use the following script to chat with the Sa2VA model or segment objects in the video.

```bash
> cd scripts
> python demo.py PATH_TO_FOLDER --model_path ByteDance/Sa2VA-8B --work-dir OUTPUT_DIR --text "<image>Please describe the video content."
```

If the output contains segmentation results, they will be saved to `OUTPUT_DIR`.

**Option 2 - Jupyter Notebook:**

Please refer to `demo.ipynb`.

## Demo

<details open>
<summary>Demo 1</summary>
Input Video (Source: La La Land 2016):

![Error](assets/videos/exp_1.gif)

Instruction: "Please segment the girl wearing the yellow dress."
</details>

<details open>
<summary>Demo 2</summary>
Input Video (Source: La La Land 2016):

![Error](assets/videos/exp_2.gif)

Instruction: "Please segment the main character."
</details>

<details open>
<summary>Demo 3</summary>
Input Video (Source: Internet):

![Error](assets/videos/apt_exp_1_all.gif)

Instruction: "Please segment the person wearing sunglasses."
</details>

<details open>
<summary>Demo 4</summary>
Input Video (Source: Internet):

![Error](assets/videos/apt_exp_2_all.gif)

Instruction: "Please segment the singing girl."
</details>

<details open>
<summary>Demo 5</summary>
Input Video:

![Error](assets/videos/gf_exp1.gif)

Instruction: "What is the atmosphere of the scene?"

Answer: "The scene has a dark and mysterious atmosphere, with the men dressed in suits and ties, and the dimly lit room."
</details>

## Training
<details open>
<summary>Installation</summary>

1. Install Python and PyTorch first:
```bash
> conda create -n vlm python=3.10
> conda activate vlm
> conda install pytorch==2.3.1 torchvision==0.18.1 pytorch-cuda=12.1 cuda -c pytorch -c "nvidia/label/cuda-12.1.0" -c "nvidia/label/cuda-12.1.1"
```

2. Install mmcv:
```bash
> pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu121/torch2.3/index.html
```

3. Install the remaining dependencies:
```bash
> pip install -r requirements.txt
```
</details>

<details open>
<summary>Pretrained Model Preparation</summary>

You are expected to download the following pretrained models and place them in the `./pretrained` directory:
- [sam2_hiera_large.pt](https://huggingface.co/facebook/sam2-hiera-large)
- [InternVL2_5-4B](https://huggingface.co/OpenGVLab/InternVL2_5-4B)

</details>

<details open>
<summary>Data Preparation</summary>

(TODO) Please download the training datasets and place them in the `data` directory. The download link is [here](https://huggingface.co/datasets/Dense-World/Sa2VA-Training).

</details>

<details open>
<summary>Training Script</summary>

Please run the following script to train:
```bash
> bash tools/dist.sh train projects/llava_sam2/configs/sa2va_4b.py 8
```
</details>

## References
If you find this repository useful, please consider citing the following paper:
```
@article{sa2va,
  title={Sa2VA: Marrying SAM2 with LLaVA for Dense Grounded Understanding of Images and Videos},
  author={Yuan, Haobo and Li, Xiangtai and Zhang, Tao and Huang, Zilong and Xu, Shilin and Ji, Shunping and Tong, Yunhai and Qi, Lu and Feng, Jiashi and Yang, Ming-Hsuan},
  journal={arXiv},
  year={2025}
}
```
assets/images/teaser.jpg
ADDED
assets/videos/apt_exp_1_all.gif
ADDED (Git LFS)
assets/videos/apt_exp_2_all.gif
ADDED (Git LFS)
assets/videos/baodao_exp_1_all.gif
ADDED (Git LFS)
assets/videos/exp_1.gif
ADDED (Git LFS)
assets/videos/exp_2.gif
ADDED (Git LFS)
assets/videos/gf_exp1.gif
ADDED (Git LFS)
assets/videos/gf_exp1.mp4
ADDED
```diff
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:272f4246fbb62aa690811e01d5f8aecaac3d157cc01a9859de79675ee5d4f7cf
+size 15332128
```
demo.ipynb
ADDED
The diff for this file is too large to render.
demo.py
ADDED
@@ -0,0 +1,98 @@
```python
import argparse
import os

from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer

import cv2
try:
    from mmengine.visualization import Visualizer
except ImportError:
    Visualizer = None
    print("Warning: mmengine is not installed, visualization is disabled.")


def parse_args():
    parser = argparse.ArgumentParser(description='Video Reasoning Segmentation')
    parser.add_argument('image_folder', help='Path to the folder of video frames')
    parser.add_argument('--model_path', default="ByteDance/Sa2VA-8B")
    parser.add_argument('--work-dir', default=None, help='The dir to save results.')
    parser.add_argument('--text', type=str, default="<image>Please describe the video content.")
    parser.add_argument('--select', type=int, default=-1)
    args = parser.parse_args()
    return args


def visualize(pred_mask, image_path, work_dir):
    # Overlay the predicted binary mask on the original frame and save it.
    visualizer = Visualizer()
    img = cv2.imread(image_path)
    visualizer.set_image(img)
    visualizer.draw_binary_masks(pred_mask, colors='g', alphas=0.4)
    visual_result = visualizer.get_image()

    output_path = os.path.join(work_dir, os.path.basename(image_path))
    cv2.imwrite(output_path, visual_result)


if __name__ == "__main__":
    cfg = parse_args()
    model_path = cfg.model_path
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype="auto",
        device_map="auto",
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(
        model_path,
        trust_remote_code=True
    )

    # Collect the video frames from the input folder in sorted order.
    image_files = []
    image_paths = []
    image_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif", ".tiff"}
    for filename in sorted(list(os.listdir(cfg.image_folder))):
        if os.path.splitext(filename)[1].lower() in image_extensions:
            image_files.append(filename)
            image_paths.append(os.path.join(cfg.image_folder, filename))

    vid_frames = []
    for img_path in image_paths:
        img = Image.open(img_path).convert('RGB')
        vid_frames.append(img)

    # Single-frame inference when --select is given, otherwise run on the whole clip.
    if cfg.select > 0:
        img_frame = vid_frames[cfg.select - 1]

        print(f"Selected frame {cfg.select}")
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            image=img_frame,
            text=cfg.text,
            tokenizer=tokenizer,
        )
    else:
        print(f"The input is:\n{cfg.text}")
        result = model.predict_forward(
            video=vid_frames,
            text=cfg.text,
            tokenizer=tokenizer,
        )

    prediction = result['prediction']
    print(f"The output is:\n{prediction}")

    if '[SEG]' in prediction and Visualizer is not None:
        _seg_idx = 0
        pred_masks = result['prediction_masks'][_seg_idx]
        for frame_idx in range(len(vid_frames)):
            pred_mask = pred_masks[frame_idx]
            if cfg.work_dir:
                os.makedirs(cfg.work_dir, exist_ok=True)
                visualize(pred_mask, image_paths[frame_idx], cfg.work_dir)
            else:
                os.makedirs('./temp_visualize_results', exist_ok=True)
                visualize(pred_mask, image_paths[frame_idx], './temp_visualize_results')
    else:
        pass
```
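When `mmengine` is unavailable, `demo.py` skips visualization entirely. A small, hedged alternative (not part of the repository) is to dump each predicted mask as a grayscale PNG using OpenCV alone; it assumes `pred_masks` and `image_paths` are produced exactly as in the script above.

```python
# Hedged alternative to the mmengine Visualizer: save raw masks as PNGs.
# Assumes `pred_masks` and `image_paths` come from demo.py above.
import os
import cv2
import numpy as np


def save_masks(pred_masks, image_paths, out_dir="./mask_outputs"):
    os.makedirs(out_dir, exist_ok=True)
    for frame_idx, image_path in enumerate(image_paths):
        # Convert the per-frame binary mask to 0/255 and write it next to the frame name.
        mask = np.asarray(pred_masks[frame_idx]).astype(np.uint8) * 255
        name = os.path.splitext(os.path.basename(image_path))[0] + "_mask.png"
        cv2.imwrite(os.path.join(out_dir, name), mask)
```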
demo/demo.py
ADDED
@@ -0,0 +1,98 @@
(Identical in content to `demo.py` at the repository root, shown above.)
demo/requirements.txt
ADDED
@@ -0,0 +1,10 @@
```
torch==2.3.1
torchvision==0.18.1
transformers==4.42.3
opencv-python-headless<4.10
peft<0.14.0
timm==1.0.9
einops==0.8.0
flash_attn
sentencepiece==0.2.0
mmengine<1
```
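As a quick sanity check before running the demo, the hedged snippet below (not part of the repository) prints the installed version of each dependency pinned above, so mismatches with `demo/requirements.txt` are easy to spot.

```python
# Report installed versions of the inference-only dependencies listed above.
from importlib.metadata import version, PackageNotFoundError

packages = ["torch", "torchvision", "transformers", "opencv-python-headless",
            "peft", "timm", "einops", "flash_attn", "sentencepiece", "mmengine"]
for name in packages:
    try:
        print(f"{name}=={version(name)}")
    except PackageNotFoundError:
        print(f"{name}: not installed")
```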
projects/glamm/datasets/__init__.py
ADDED
@@ -0,0 +1,7 @@
```python
from .semantic_seg_dataset import SemanticSegDataset, ADE20kSemanticSegDataset, \
    COCOStuffSemanticSegDataset, PascalPartSemanticSegDataset, PacoSemanticSegDataset
from .gcg_dataset import GCGDataset, GranDfGCGDataset, RefCOCOgGCGDataset, OpenPsgGCGDataset, Flickr30kGCGDataset
from .region_level_dataset import RefCocoGRegionDataset, VisualGenomeRegionDataset
from .refcoco_segm_dataset import ReferSegmDataset
from .utils.utils import *
from .collate_fns.glamm_collate_fn import glamm_collate_fn
```
projects/glamm/datasets/collate_fns/glamm_collate_fn.py
ADDED
@@ -0,0 +1,136 @@
```python
from typing import Dict, Sequence

import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def glamm_collate_fn(instances: Sequence[Dict],
                     pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                     return_hf_format: bool = False,
                     use_varlen_attn: bool = False):
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    has_image = any(inst.get('pixel_values') is not None for inst in instances)
    has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)
    has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
    has_points = any(inst.get('points') is not None for inst in instances)

    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        assert not has_image, ('Currently, it is not configured to '
                               'accommodate the use of varlen Attention in multimodal training')

    if has_image:
        pixel_values = []
    if has_grounding_image:
        grounding_pixel_values = []
    if has_mask:
        object_masks = []
    if has_bboxes:
        object_bboxes = []
    if has_points:
        prompt_points = []

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        if has_image:
            pixel_values.append(example['pixel_values'])
        if has_grounding_image:
            grounding_pixel_values.append(example['g_pixel_values'])
        if has_mask:
            if 'masks' in example.keys() and example['masks'] is not None:
                object_masks.append(example['masks'])
        if has_bboxes:
            if 'bboxes' in example.keys() and example['bboxes'] is not None:
                object_bboxes.append(example['bboxes'])
        if has_points:
            if 'points' in example.keys() and example['points'] is not None:
                prompt_points.append(example['points'])

    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        # Some tokenizers have the same eos token and pad token, so input_ids
        # cannot be masked directly based on the pad token id.
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

    if seq_parallel_world_size > 1:
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)

    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] -  # noqa: W504
            cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }

    if has_image:
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
        data_dict['pixel_values'] = pixel_values

    if has_grounding_image:
        # if all(x.shape == grounding_pixel_values[0].shape for x in grounding_pixel_values):
        #     grounding_pixel_values = torch.stack(grounding_pixel_values, dim=0)
        data_dict['g_pixel_values'] = grounding_pixel_values

    if has_mask:
        data_dict['masks'] = object_masks

    if has_bboxes:
        data_dict['bboxes'] = object_bboxes

    if has_points:
        data_dict['points'] = prompt_points

    if return_hf_format:
        return data_dict
    else:
        return {'data': data_dict, 'data_samples': None}
```
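To see the core padding behaviour in isolation, here is a hedged, self-contained sketch of the scheme `glamm_collate_fn` uses in its default branch: right-pad `input_ids` with the pad index, pad `labels` with `IGNORE_INDEX` (-100), and build the attention mask from the recorded original lengths rather than the pad token id. It is a simplified illustration, not a drop-in replacement, and the toy token ids are invented.

```python
# Simplified illustration of the padding scheme used by glamm_collate_fn.
# Toy input ids are invented for the example; IGNORE_INDEX mirrors xtuner's -100.
import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100
PAD_INDEX = 0

batch = [
    {'input_ids': [5, 6, 7, 8], 'labels': [-100, 6, 7, 8]},
    {'input_ids': [9, 10],      'labels': [-100, 10]},
]

input_ids = [torch.LongTensor(x['input_ids']) for x in batch]
labels = [torch.LongTensor(x['labels']) for x in batch]
lengths = [len(x) for x in input_ids]

input_ids = pad_sequence(input_ids, batch_first=True, padding_value=PAD_INDEX)
labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

# Mask built from the recorded lengths: safe even when pad id == eos id.
attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
for i, length in enumerate(lengths):
    attention_mask[i, :length] = True

print(input_ids)
print(labels)
print(attention_mask)
```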
projects/glamm/datasets/gcg_dataset.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import random
|
3 |
+
import glob
|
4 |
+
import json
|
5 |
+
import logging
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
|
9 |
+
from mmengine import print_log
|
10 |
+
from mmengine.config import Config, ConfigDict
|
11 |
+
from PIL import Image
|
12 |
+
from torch.utils.data import Dataset
|
13 |
+
import numpy as np
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from pycocotools.coco import COCO
|
16 |
+
from pycocotools import mask as mask_utils
|
17 |
+
|
18 |
+
from xtuner.registry import BUILDER
|
19 |
+
|
20 |
+
from xtuner.dataset.utils import encode_fn
|
21 |
+
from xtuner.dataset.map_fns import llava_map_fn
|
22 |
+
|
23 |
+
from projects.glamm.datasets.utils.utils import expand2square
|
24 |
+
|
25 |
+
from projects.glamm.datasets.utils.utils import GCG_QUESTIONS, ANSWER_LIST
|
26 |
+
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
|
27 |
+
class GCGDataset(Dataset):
|
28 |
+
def __init__(self,
|
29 |
+
image_folder,
|
30 |
+
image_processor,
|
31 |
+
data_path=None,
|
32 |
+
tokenizer=None,
|
33 |
+
template_map_fn=None,
|
34 |
+
max_length=2048,
|
35 |
+
pad_image_to_square=False,
|
36 |
+
repeats=1,
|
37 |
+
num_classes_per_sample=3,
|
38 |
+
extra_image_processor=None):
|
39 |
+
super().__init__()
|
40 |
+
self.question_templates = GCG_QUESTIONS
|
41 |
+
if extra_image_processor is not None:
|
42 |
+
self.extra_image_processor = BUILDER.build(extra_image_processor)
|
43 |
+
self.num_classes_per_sample = num_classes_per_sample
|
44 |
+
self.tokenizer = BUILDER.build(tokenizer)
|
45 |
+
|
46 |
+
self.tokenizer.add_tokens(
|
47 |
+
[DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
|
48 |
+
)
|
49 |
+
reg_tokens = ['<bbox>', '<point>']
|
50 |
+
segmentation_tokens = ['[SEG]']
|
51 |
+
phrase_tokens = ['<p>', '</p>']
|
52 |
+
special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
|
53 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
54 |
+
|
55 |
+
self.max_length = max_length
|
56 |
+
self.template_map_fn = BUILDER.build(template_map_fn)
|
57 |
+
|
58 |
+
self.text_data = self.json_file_preprocess(data_path, image_folder)
|
59 |
+
self.image_folder = image_folder
|
60 |
+
|
61 |
+
self.image_processor = BUILDER.build(image_processor)
|
62 |
+
size = self.image_processor.crop_size
|
63 |
+
|
64 |
+
if isinstance(size, dict):
|
65 |
+
self.image_w, self.image_h = size['width'], size['height']
|
66 |
+
elif isinstance(size, int):
|
67 |
+
self.image_h, self.image_w = size, size
|
68 |
+
else:
|
69 |
+
self.image_w, self.image_h = size
|
70 |
+
|
71 |
+
self.pad_image_to_square = pad_image_to_square
|
72 |
+
self.repeats = repeats
|
73 |
+
|
74 |
+
def json_file_preprocess(self, data_path, image_folder=None):
|
75 |
+
with open(data_path, 'r') as f:
|
76 |
+
json_data = json.load(f)
|
77 |
+
return json_data
|
78 |
+
|
79 |
+
@property
|
80 |
+
def modality_length(self):
|
81 |
+
length_list = []
|
82 |
+
for data_dict in self.text_data:
|
83 |
+
cur_len = 100
|
84 |
+
length_list.append(cur_len)
|
85 |
+
return length_list * self.repeats
|
86 |
+
|
87 |
+
def __len__(self):
|
88 |
+
return len(self.text_data) * self.repeats
|
89 |
+
|
90 |
+
def real_len(self):
|
91 |
+
return len(self.text_data)
|
92 |
+
|
93 |
+
def _parse_annotations(self, ann_info):
|
94 |
+
image_path = os.path.join(self.image_folder, ann_info['file_name'])
|
95 |
+
image = Image.open(image_path).convert('RGB')
|
96 |
+
if hasattr(self, 'extra_image_processor'):
|
97 |
+
g_image = np.array(image) # for grounding
|
98 |
+
g_image = self.extra_image_processor.apply_image(g_image)
|
99 |
+
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
|
100 |
+
ann_info['g_pixel_values'] = g_pixel_values
|
101 |
+
|
102 |
+
width, height = image.size
|
103 |
+
if self.pad_image_to_square:
|
104 |
+
image = expand2square(
|
105 |
+
image, tuple(int(x * 255) for x in self.image_processor.image_mean))
|
106 |
+
image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
107 |
+
ann_info['pixel_values'] = image
|
108 |
+
|
109 |
+
caption = ann_info['caption'].strip('"').strip()
|
110 |
+
masks, phrases, tokens_positive = [], [], []
|
111 |
+
for word, grounding in ann_info["groundings"].items():
|
112 |
+
phrases.append(word)
|
113 |
+
tokens_positive.append(grounding["token_positives"])
|
114 |
+
|
115 |
+
# Convert segmentation to binary mask
|
116 |
+
binary_mask = np.zeros((height, width), dtype=np.uint8)
|
117 |
+
for rle in grounding["rle_masks"]:
|
118 |
+
m = mask_utils.decode(rle).astype(np.uint8)
|
119 |
+
binary_mask += m.squeeze()
|
120 |
+
masks.append(binary_mask)
|
121 |
+
|
122 |
+
def sort_by_start_index(items, order):
|
123 |
+
return [items[i] for i in order]
|
124 |
+
|
125 |
+
phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
|
126 |
+
masks = sort_by_start_index(masks, phrase_order)
|
127 |
+
phrases = sort_by_start_index(phrases, phrase_order)
|
128 |
+
tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
|
129 |
+
|
130 |
+
ann_info.update({
|
131 |
+
'image_path': image_path,
|
132 |
+
'caption': caption,
|
133 |
+
'masks': masks,
|
134 |
+
'phrases': phrases,
|
135 |
+
'tokens_positive': tokens_positive,
|
136 |
+
})
|
137 |
+
return ann_info
|
138 |
+
|
139 |
+
def create_conversation(self, caption, tokens_positive):
|
140 |
+
question = random.choice(self.question_templates).strip()
|
141 |
+
|
142 |
+
# Prepare caption with tags
|
143 |
+
def tag_caption(caption, tokens):
|
144 |
+
for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
|
145 |
+
caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
|
146 |
+
return caption
|
147 |
+
|
148 |
+
detailed_answer = tag_caption(caption, tokens_positive)
|
149 |
+
|
150 |
+
question = 'The <image> provides an overview of the picture.\n' + question
|
151 |
+
conversation = [{'input': question, 'output': detailed_answer}]
|
152 |
+
return conversation
|
153 |
+
|
154 |
+
def __getitem__(self, index):
|
155 |
+
index = index % self.real_len()
|
156 |
+
data_dict = {}
|
157 |
+
ann_info = copy.deepcopy(self.text_data[index])
|
158 |
+
ann_info = self._parse_annotations(ann_info)
|
159 |
+
|
160 |
+
data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values')
|
161 |
+
data_dict['pixel_values'] = ann_info.pop('pixel_values')
|
162 |
+
if len(ann_info['masks']) == 0:
|
163 |
+
return self.__getitem__(0)
|
164 |
+
data_dict['masks'] = torch.from_numpy(np.stack(ann_info['masks'], axis=0))
|
165 |
+
|
166 |
+
conversation = self.create_conversation(ann_info['caption'], ann_info['tokens_positive'])
|
167 |
+
data_dict['conversation'] = conversation
|
168 |
+
|
169 |
+
result = self.template_map_fn(data_dict)
|
170 |
+
data_dict.update(result)
|
171 |
+
|
172 |
+
result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
|
173 |
+
data_dict.update(result)
|
174 |
+
|
175 |
+
return data_dict
|
176 |
+
|
177 |
+
class GranDfGCGDataset(GCGDataset):
|
178 |
+
pass
|
179 |
+
class RefCOCOgGCGDataset(GCGDataset):
|
180 |
+
def json_file_preprocess(self, data_path, image_folder=None):
|
181 |
+
with open(data_path, 'r') as f:
|
182 |
+
json_data = json.load(f)
|
183 |
+
return [list(line.values())[0] for line in json_data]
|
184 |
+
|
185 |
+
def _parse_annotations(self, ann_info):
|
186 |
+
image_path = os.path.join(self.image_folder, ann_info['img_file_name'])
|
187 |
+
image = Image.open(image_path).convert('RGB')
|
188 |
+
if hasattr(self, 'extra_image_processor'):
|
189 |
+
g_image = np.array(image) # for grounding
|
190 |
+
g_image = self.extra_image_processor.apply_image(g_image)
|
191 |
+
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
|
192 |
+
ann_info['g_pixel_values'] = g_pixel_values
|
193 |
+
|
194 |
+
width, height = image.size
|
195 |
+
if self.pad_image_to_square:
|
196 |
+
image = expand2square(
|
197 |
+
image, tuple(int(x * 255) for x in self.image_processor.image_mean))
|
198 |
+
image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
|
199 |
+
ann_info['pixel_values'] = image
|
200 |
+
|
201 |
+
caption = ann_info['caption'].strip('"').strip().lower()
|
202 |
+
masks, phrases, tokens_positive = [], [], []
|
203 |
+
for detail in ann_info['refs']:
|
204 |
+
phrase = detail['sentence']
|
205 |
+
if phrase.lower() in caption:
|
206 |
+
phrases.append(phrase)
|
207 |
+
index = caption.find(phrase)
|
208 |
+
end_index = index + len(phrase) if index != -1 else -1
|
209 |
+
tokens_positive.append([index, end_index])
|
210 |
+
|
211 |
+
binary_mask = np.zeros((height, width), dtype=np.uint8)
|
212 |
+
for seg in detail["segmentation"]:
|
213 |
+
rles = mask_utils.frPyObjects([seg], height, width)
|
214 |
+
m = mask_utils.decode(rles)
|
215 |
+
m = m.astype(np.uint8)
|
216 |
+
binary_mask += m.squeeze()
|
217 |
+
masks.append(binary_mask)
|
218 |
+
|
219 |
+
def sort_by_start_index(items, order):
|
220 |
+
return [items[i] for i in order]
|
221 |
+
|
222 |
+
phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
|
223 |
+
masks = sort_by_start_index(masks, phrase_order)
|
224 |
+
phrases = sort_by_start_index(phrases, phrase_order)
|
225 |
+
tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
|
226 |
+
|
227 |
+
ann_info.update({
|
228 |
+
'image_path': image_path,
|
229 |
+
'caption': caption,
|
230 |
+
'masks': masks,
|
231 |
+
'phrases': phrases,
|
232 |
+
'tokens_positive': tokens_positive,
|
233 |
+
})
|
234 |
+
return ann_info
|
235 |
+
|
236 |
+
class OpenPsgGCGDataset(GCGDataset):
|
237 |
+
pass
|
238 |
+
|
239 |
+
class Flickr30kGCGDataset(GCGDataset):
|
240 |
+
|
241 |
+
def json_file_preprocess(self, data_path, image_folder=None):
|
242 |
+
def filter_images(data_infos, min_size):
|
243 |
+
return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]
|
244 |
+
|
245 |
+
self.coco = COCO(data_path)
|
246 |
+
self.image_ids = self.coco.getImgIds()
|
247 |
+
data_infos = []
|
248 |
+
total_ann_ids = []
|
249 |
+
removed_img_count = 0
|
250 |
+
for img_id in self.image_ids:
|
251 |
+
info = self.coco.loadImgs([img_id])[0]
|
252 |
+
if len(info['caption'].split(' ')) < 3:
|
253 |
+
removed_img_count += 1
|
254 |
+
continue
|
255 |
+
info['filename'] = info['file_name'].split('_')[-1]
|
256 |
+
info['height'] = int(info['height'])
|
257 |
+
info['width'] = int(info['width'])
|
258 |
+
data_infos.append(info)
|
259 |
+
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
|
260 |
+
total_ann_ids.extend(ann_ids)
|
261 |
+
        assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
        print(f'Removed {removed_img_count} images.')
        data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]

        return data_infos

    def _parse_annotations(self, img_info):
        ann_ids = self.coco.getAnnIds(imgIds=img_info['id'])
        ann_info = self.coco.loadAnns(ann_ids)

        annotations = {'phrases': [], 'caption': img_info['caption'], 'masks': [], 'tokens_positive': []}
        image_path = os.path.join(self.image_folder, img_info['file_name'])
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
            annotations['g_pixel_values'] = g_pixel_values

        width, height = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        annotations['pixel_values'] = image

        for ann in ann_info:
            if ann.get('ignore', False):
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
            inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
            if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]
            tokens_positive = ann['tokens_positive']
            phrase = [img_info['caption'][span[0]:span[1]] for span in tokens_positive]
            annotations['phrases'].append(phrase[0])
            annotations['tokens_positive'].append(tokens_positive[0])

            rle = ann['sam_mask']
            mask_decoded = mask_utils.decode(rle).astype(np.uint8)
            annotations['masks'].append(mask_decoded)

        def sort_by_start_index(items, order):
            return [items[i] for i in order]

        phrase_order = sorted(range(len(annotations['tokens_positive'])), key=lambda x: annotations['tokens_positive'][x][0])
        annotations['masks'] = sort_by_start_index(annotations['masks'], phrase_order)
        annotations['phrases'] = sort_by_start_index(annotations['phrases'], phrase_order)
        annotations['tokens_positive'] = sort_by_start_index(annotations['tokens_positive'], phrase_order)

        return annotations


if __name__ == '__main__':
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide
    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path)
    image_processor = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
    extra_image_processor = dict(
        type=ResizeLongestSide,
        target_length=1024,
    )
    from xtuner.utils.templates import PROMPT_TEMPLATE
    prompt_template = PROMPT_TEMPLATE.vicuna
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
    dataset = Flickr30kGCGDataset(
        image_folder='data/flickr30k/flickr30k-images/',
        image_processor=image_processor,
        data_path='./data/GranDf/annotations/train/flickr_mergedGT_GCG_train.json',
        tokenizer=tokenizer,
        template_map_fn=dict(
            type=template_map_fn_factory, template=prompt_template),
        max_length=2048,
        pad_image_to_square=True,
        repeats=1,
        num_classes_per_sample=3,
        extra_image_processor=extra_image_processor)

    for i in range(1000):
        print(dataset[i])
projects/glamm/datasets/refcoco_segm_dataset.py
ADDED
@@ -0,0 +1,195 @@
import copy
import random
import glob
import json
import logging
import os
import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

from xtuner.registry import BUILDER

from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square

from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from third_parts.mmdet.datasets.refcoco import RefCocoDataset


class ReferSegmDataset(RefCocoDataset):
    def __init__(self,
                 data_root,
                 ann_file=None,
                 split_file=None,
                 image_processor=None,
                 extra_image_processor=None,
                 data_prefix=dict(img_path='train2014/'),
                 tokenizer=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_classes_per_sample=3):
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            pipeline=None,
            ann_file=ann_file,
            split_file=split_file,
        )
        self.begin_str = f"""{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""

        self.question_templates = SEG_QUESTIONS
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        self.num_classes_per_sample = num_classes_per_sample
        self.tokenizer = BUILDER.build(tokenizer)

        self.tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        reg_tokens = ['<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.max_length = max_length
        self.template_map_fn = BUILDER.build(template_map_fn)

        self.image_processor = BUILDER.build(image_processor)
        size = self.image_processor.crop_size
        if isinstance(size, dict):
            self.image_w, self.image_h = size['width'], size['height']
        self.pad_image_to_square = pad_image_to_square

    @property
    def modality_length(self):
        import pickle
        length_list = []
        for idx in range(len(self)):
            length_list.append(100)
        # for idx in range(len(self)):
        #     if self.serialize_data:
        #         start_addr = 0 if idx == 0 else self.data_address[idx - 1].item()
        #         end_addr = self.data_address[idx].item()
        #         bytes = memoryview(
        #             self.data_bytes[start_addr:end_addr])  # type: ignore
        #         data_dict = pickle.loads(bytes)
        #     else:
        #         data_dict = copy.deepcopy(self.data_list[idx])
        return length_list

    def _parse_annotations(self, ann_info):
        image_path = ann_info['img_path']
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()
            ann_info['g_pixel_values'] = g_pixel_values

        width, height = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        ann_info['pixel_values'] = image

        masks, phrases = [], []
        instances, text = ann_info['instances'], ann_info['text']
        index = np.random.choice(range(len(instances)), min(
            len(instances), self.num_classes_per_sample))
        for idx in index:
            inst = instances[idx]
            phrase = text[idx].lower()
            phrases.append(phrase)
            binary_mask = np.zeros((height, width), dtype=np.uint8)
            for seg in inst["mask"]:
                rles = mask_utils.frPyObjects([seg], height, width)
                m = mask_utils.decode(rles)
                m = m.astype(np.uint8)
                binary_mask += m.squeeze()
            masks.append(binary_mask)

        ann_info.update({
            'masks': masks,
            'phrases': phrases,
        })
        return ann_info

    def __getitem__(self, idx):
        data_dict = {}
        ann_info = super().__getitem__(idx)
        ann_info = self._parse_annotations(ann_info)

        data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values')
        data_dict['pixel_values'] = ann_info.pop('pixel_values')
        if len(ann_info['masks']) == 0:
            return self.__getitem__(0)
        data_dict['masks'] = torch.from_numpy(
            np.stack(ann_info['masks'], axis=0))

        conversation = []
        for i, phrase in enumerate(ann_info['phrases']):
            question = random.choice(SEG_QUESTIONS).format(class_name=phrase)
            conversation.append(
                {'input': question, 'output': random.choice(ANSWER_LIST)})

        data_dict['conversation'] = conversation
        result = self.template_map_fn(data_dict)
        data_dict.update(result)

        result = encode_fn(data_dict, tokenizer=self.tokenizer,
                           max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        return data_dict


if __name__ == '__main__':
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide
    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path)
    image_processor = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
    extra_image_processor = dict(
        type=ResizeLongestSide,
        target_length=1024,
    )
    from xtuner.utils.templates import PROMPT_TEMPLATE
    prompt_template = PROMPT_TEMPLATE.vicuna
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn

    dataset = ReferSegmDataset(
        tokenizer=tokenizer,
        image_processor=image_processor,
        template_map_fn=dict(
            type=template_map_fn_factory, template=prompt_template),
        extra_image_processor=extra_image_processor,
        data_root='data/coco/',
        data_prefix=dict(img_path='train2014/'),
        ann_file='refcoco+/instances.json',
        split_file='refcoco+/refs(unc).p',
    )
    for i in range(1000):
        dataset[i]
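
Note on the referring-segmentation parsing above: the per-instance loop in `_parse_annotations` merges every polygon of a referred object into a single binary mask via pycocotools. A standalone sketch of that step (the image size and polygon coordinates here are invented for illustration):

    import numpy as np
    from pycocotools import mask as mask_utils

    height, width = 480, 640
    polygons = [[100, 100, 200, 100, 200, 200, 100, 200]]  # one flattened x,y polygon

    binary_mask = np.zeros((height, width), dtype=np.uint8)
    for seg in polygons:
        rles = mask_utils.frPyObjects([seg], height, width)  # polygon -> RLE
        m = mask_utils.decode(rles).astype(np.uint8)         # RLE -> H x W x 1 array
        binary_mask += m.squeeze()
    print(binary_mask.sum())  # number of foreground pixels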
projects/glamm/datasets/region_level_dataset.py
ADDED
@@ -0,0 +1,297 @@
import copy
import random
import glob
import json
import logging
import os
import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

from xtuner.registry import BUILDER

from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square

from projects.glamm.datasets.utils.utils import ANSWER_LIST, REGION_QUESTIONS
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class RegionDataset(Dataset):
    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 repeats=1,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        super().__init__()

        self.begin_str = f"""{DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n"""
        self.question_templates = REGION_QUESTIONS

        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        self.num_classes_per_sample = num_classes_per_sample
        self.tokenizer = BUILDER.build(tokenizer)

        self.tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        reg_tokens = ['<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.max_length = max_length
        self.template_map_fn = BUILDER.build(template_map_fn)

        self.text_data = self._load_annotations(data_path, image_folder)
        self.image_folder = image_folder

        self.image_processor = BUILDER.build(image_processor)
        size = self.image_processor.crop_size

        if isinstance(size, dict):
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size

        self.pad_image_to_square = pad_image_to_square
        self.repeats = repeats

    def _load_annotations(self, data_path, image_folder=None):
        self.coco = COCO(data_path)
        img_ids = self.coco.getImgIds()
        data_infos = []
        for img_id in img_ids:
            info = self.coco.loadImgs([img_id])[0]
            info['filename'] = info['file_name'].split('_')[-1]
            info['height'] = int(info['height'])
            info['width'] = int(info['width'])
            if min(info['height'], info['width']) < 32:
                continue
            data_infos.append(info)
        return data_infos

    @property
    def modality_length(self):
        length_list = []
        for data_dict in self.text_data:
            cur_len = 100
            length_list.append(cur_len)
        return length_list * self.repeats

    def __len__(self):
        return len(self.text_data) * self.repeats

    def real_len(self):
        return len(self.text_data)

    def region_processor(self, orig_size, post_size, bboxes, labels):
        orig_h, orig_w = orig_size
        post_h, post_w = post_size
        y_scale = post_h / orig_h
        x_scale = post_w / orig_w
        shuffle_ids = torch.randperm(len(labels))[:self.num_classes_per_sample]
        selected_bboxes = bboxes[shuffle_ids]

        # Ensure selected_bboxes is two-dimensional
        if len(selected_bboxes.shape) == 1:
            selected_bboxes = np.expand_dims(selected_bboxes, axis=0)

        selected_labels = [labels[i] for i in shuffle_ids]
        selected_bboxes[:, [0, 2]] *= x_scale
        selected_bboxes[:, [1, 3]] *= y_scale
        selected_bboxes = torch.tensor(
            selected_bboxes, dtype=torch.float32) / post_h
        return selected_bboxes, selected_labels

    def _parse_annotations(self, img_info):
        data_dict = {}
        bboxes, captions = [], []
        ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
        image_path = os.path.join(self.image_folder, img_info['file_name'])
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()
            data_dict['g_pixel_values'] = g_pixel_values

        orig_w, orig_h = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        post_h, post_w = image.shape[1:3]
        data_dict['pixel_values'] = image

        for ann in ann_info:
            if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
            inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]

            if bbox:
                bboxes.append(bbox)
                captions.append(img_info['caption'])

        if len(bboxes) == 0:
            return self.__getitem__(0)

        bboxes = np.array(bboxes, dtype=np.float32)
        seg_map = img_info['file_name'].replace('jpg', 'png')
        bboxes, captions = self.region_processor((orig_h, orig_w), (post_h, post_w), bboxes, captions)

        data_dict['bboxes'] = bboxes
        data_dict['captions'] = captions
        data_dict['seg_map'] = seg_map
        return data_dict

    def create_conversation(self, captions):
        questions = []
        answers = []
        for i, label in enumerate(captions):
            question = random.choice(self.question_templates).strip().replace('<region>', f'region{i + 1} <bbox>')
            questions.append(question)
            answers.append(label)

        conversation = []
        for i, (question, answer) in enumerate(zip(questions, answers)):
            if i == 0:
                question = self.begin_str + question
            conversation.append({'input': question, 'output': answer})
        return conversation

    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = {}
        ann_info = copy.deepcopy(self.text_data[index])
        ann_info = self._parse_annotations(ann_info)

        data_dict['g_pixel_values'] = ann_info.pop('g_pixel_values', None)
        data_dict['pixel_values'] = ann_info.pop('pixel_values')
        data_dict['bboxes'] = ann_info.pop('bboxes', None)

        conversation = self.create_conversation(ann_info['captions'])
        data_dict['conversation'] = conversation

        result = self.template_map_fn(data_dict)
        data_dict.update(result)

        result = encode_fn(data_dict, tokenizer=self.tokenizer,
                           max_length=self.max_length, with_image_token=True)
        data_dict.update(result)

        return data_dict


class RefCocoGRegionDataset(RegionDataset):
    pass


class VisualGenomeRegionDataset(RegionDataset):
    def _parse_annotations(self, img_info):
        data_dict = {}
        bboxes, captions = [], []
        ann_info = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_info['id']))
        image_path = os.path.join(self.image_folder, img_info['file_name'])
        image = Image.open(image_path).convert('RGB')
        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(
                g_image).permute(2, 0, 1).contiguous()
            data_dict['g_pixel_values'] = g_pixel_values

        orig_w, orig_h = image.size
        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(
            image, return_tensors='pt')['pixel_values'][0]
        post_h, post_w = image.shape[1:3]
        data_dict['pixel_values'] = image

        for ann in ann_info:
            if ann.get('ignore', False) or ann['area'] <= 0 or ann['bbox'][2] < 1 or ann['bbox'][3] < 1:
                continue
            x1, y1, w, h = ann['bbox']
            inter_w = max(0, min(x1 + w, orig_w) - max(x1, 0))
            inter_h = max(0, min(y1 + h, orig_h) - max(y1, 0))
            if inter_w * inter_h == 0:
                continue
            bbox = [x1, y1, x1 + w, y1 + h]

            if bbox:
                bboxes.append(bbox)
                captions.append(ann['caption'].strip())

        if len(bboxes) == 0:
            return self.__getitem__(0)

        bboxes = np.array(bboxes, dtype=np.float32)
        seg_map = img_info['file_name'].replace('jpg', 'png')
        bboxes, captions = self.region_processor((orig_h, orig_w), (post_h, post_w), bboxes, captions)

        data_dict['bboxes'] = bboxes
        data_dict['captions'] = captions
        data_dict['seg_map'] = seg_map
        return data_dict


if __name__ == '__main__':
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide
    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path)
    image_processor = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
    extra_image_processor = dict(
        type=ResizeLongestSide,
        target_length=1024,
    )
    from xtuner.utils.templates import PROMPT_TEMPLATE
    prompt_template = PROMPT_TEMPLATE.vicuna
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
    dataset = VisualGenomeRegionDataset(
        image_folder='./data/visual_genome/images',
        image_processor=image_processor,
        data_path='data/visual_genome/train.json',
        tokenizer=tokenizer,
        template_map_fn=dict(
            type=template_map_fn_factory, template=prompt_template),
        max_length=2048,
        pad_image_to_square=False,
        repeats=1,
        num_classes_per_sample=3,
        extra_image_processor=None)

    for i in range(1000):
        print(dataset[i])
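
Note on the region datasets above: `region_processor` rescales the xyxy boxes from the original image size to the post-preprocessing size and then divides by the processed height, so the coordinates attached to each `<bbox>` token are roughly normalized. A small numeric sketch of that transform (the sizes and box values are illustrative):

    import numpy as np

    orig_h, orig_w = 480, 640        # original image
    post_h, post_w = 336, 336        # size after CLIP preprocessing
    bboxes = np.array([[64., 48., 320., 240.]], dtype=np.float32)  # xyxy

    bboxes[:, [0, 2]] *= post_w / orig_w   # scale x coordinates
    bboxes[:, [1, 3]] *= post_h / orig_h   # scale y coordinates
    bboxes /= post_h                       # normalize as in region_processor
    print(bboxes)                          # ~[[0.1, 0.1, 0.5, 0.5]]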
projects/glamm/datasets/semantic_seg_dataset.py
ADDED
@@ -0,0 +1,424 @@
import copy
import random
import glob
import json
import logging
import os
import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
from pycocotools.coco import COCO

from xtuner.registry import BUILDER

from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square

from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class SemanticSegDataset(Dataset):
    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,
                 gcg_format=False,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        super().__init__()
        self.gcg_format = gcg_format
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)
        self.num_classes_per_sample = num_classes_per_sample
        self.tokenizer = BUILDER.build(tokenizer)

        self.tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
        reg_tokens = ['<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        assert offline_processed_text_folder or (data_path and tokenizer)
        self.lazy = lazy

        self.max_length = max_length
        self.dataset_map_fn = dataset_map_fn
        self.template_map_fn = template_map_fn
        if isinstance(self.template_map_fn, dict) and self.lazy:
            _type = self.template_map_fn['type']
            del self.template_map_fn['type']
            self.template_map_fn = _type(**self.template_map_fn)

        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and '
                '`data_path` are set, and we load dataset from'
                '`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            raise NotImplementedError
        else:
            self.image_label_datas = self.json_file_preprocess(data_path, image_folder)

        self.image_folder = image_folder

        if isinstance(image_processor, dict) or isinstance(image_processor, Config) or isinstance(image_processor, ConfigDict):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor

        size = self.image_processor.crop_size

        if isinstance(size, dict):
            self.image_w, self.image_h = size['width'], size['height']
        elif isinstance(size, int):
            self.image_h, self.image_w = size, size
        else:
            self.image_w, self.image_h = size

        self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    def json_file_preprocess(self, data_path, image_folder):
        # ade20k
        with open(data_path, 'r') as file:
            ade20k_classes = json.load(file)
        ade20k_image_dir = image_folder
        ade20k_images = [os.path.join(ade20k_image_dir, img) for img in os.listdir(ade20k_image_dir) if
                         img.endswith('.jpg')]
        ade20k_labels = [img.replace(".jpg", ".png").replace(
            "images", "annotations") for img in ade20k_images]
        self.classes = np.array(ade20k_classes)

        ret = []
        for image, label in zip(ade20k_images, ade20k_labels):
            ret.append({"image": image, "label": label})
        return ret

    def __len__(self):
        return len(self.image_label_datas) * self.repeats

    @property
    def modality_length(self):
        length_list = []
        for data_dict in self.image_label_datas:
            length_list.append(100)
        length_list = length_list * self.repeats
        return length_list

    def real_len(self):
        return len(self.image_label_datas)

    def decode_mask(self, label_path):
        label = np.array(Image.open(label_path))

        # ade20k
        label = np.where(label == 0, 255, label - 1)
        unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
        if not unique_labels:
            return None, None

        selected_labels = np.random.choice(unique_labels, min(
            len(unique_labels), self.num_classes_per_sample), replace=False)
        label = torch.from_numpy(label).long()
        masks = torch.stack([label == class_id for class_id in selected_labels], dim=0)
        return masks, selected_labels

    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = copy.deepcopy(self.image_label_datas[index])

        assert 'image' in data_dict.keys()
        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            image = Image.open(image_file).convert('RGB')
            if hasattr(self, 'extra_image_processor'):
                g_image = np.array(image)  # for grounding
                g_image = self.extra_image_processor.apply_image(g_image)
                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
                data_dict['g_pixel_values'] = g_pixel_values

            ori_width, ori_height = image.size
            if self.pad_image_to_square:
                image = expand2square(image, tuple(int(x * 255)
                                      for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image

            # process and get masks
            data_dict['masks'], class_id = self.decode_mask(data_dict['label'])
            if class_id is None:
                return self.__getitem__(0)

            if self.gcg_format:
                pass
            else:
                conversation = []
                for i, c_id in enumerate(class_id):
                    question = random.choice(SEG_QUESTIONS).format(
                        class_name=self.classes[c_id].lower())
                    if i == 0:
                        question = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" + question
                    conversation.append(
                        {'input': question, 'output': random.choice(ANSWER_LIST)})

                data_dict.update({'conversation': conversation})
        else:
            if hasattr(self.image_processor, 'crop_size'):
                crop_size = self.image_processor.crop_size
            else:
                crop_size = self.image_processor.size
            data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
                                                    crop_size['width'])
            data_dict['masks'] = None

        if self.lazy:
            result = self.template_map_fn(data_dict)
            data_dict.update(result)

            result = encode_fn(data_dict, tokenizer=self.tokenizer,
                               max_length=self.max_length, with_image_token=True)
            data_dict.update(result)

        return data_dict


class ADE20kSemanticSegDataset(SemanticSegDataset):
    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,
                 gcg_format=False,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        super().__init__(
            image_folder=image_folder,
            image_processor=image_processor,
            data_path=data_path,
            tokenizer=tokenizer,
            offline_processed_text_folder=offline_processed_text_folder,
            max_dataset_length=max_dataset_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            max_length=max_length,
            pad_image_to_square=pad_image_to_square,
            num_proc=num_proc,
            lazy=lazy,
            repeats=repeats,
            gcg_format=gcg_format,
            num_classes_per_sample=num_classes_per_sample,
            extra_image_processor=extra_image_processor,
        )


class COCOStuffSemanticSegDataset(SemanticSegDataset):
    def __init__(self,
                 image_folder,
                 image_processor,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 max_dataset_length=None,
                 dataset_map_fn=None,
                 template_map_fn=None,
                 max_length=2048,
                 pad_image_to_square=False,
                 num_proc=8,
                 lazy=False,
                 repeats=1,
                 label_path=None,
                 gcg_format=False,
                 num_classes_per_sample=3,
                 extra_image_processor=None):
        self.label_path = label_path
        super().__init__(
            image_folder=image_folder,
            image_processor=image_processor,
            data_path=data_path,
            tokenizer=tokenizer,
            offline_processed_text_folder=offline_processed_text_folder,
            max_dataset_length=max_dataset_length,
            dataset_map_fn=dataset_map_fn,
            template_map_fn=template_map_fn,
            max_length=max_length,
            pad_image_to_square=pad_image_to_square,
            num_proc=num_proc,
            lazy=lazy,
            repeats=repeats,
            gcg_format=gcg_format,
            num_classes_per_sample=num_classes_per_sample,
            extra_image_processor=extra_image_processor,
        )
        self.cocostuff_class2index = {c: i for i, c in enumerate(self.classes)}

    def json_file_preprocess(self, data_path, image_folder):
        # coco stuff
        assert self.label_path is not None
        with open(data_path, 'r') as file:
            cocostuff_classes = [line.strip().split(": ")[-1]
                                 for line in file.readlines()[1:]]
        coco_stuff_image_dir = image_folder
        coco_stuff_label_dir = self.label_path
        coco_stuff_labels = glob.glob(
            os.path.join(coco_stuff_label_dir, "*.png"))

        coco_stuff_images = [label.replace(".png", ".jpg").replace(coco_stuff_label_dir, coco_stuff_image_dir)
                             for label in coco_stuff_labels]

        self.classes = np.array(cocostuff_classes)

        ret = []
        for image, label in zip(coco_stuff_images, coco_stuff_labels):
            ret.append({"image": image, "label": label})
        return ret

    def decode_mask(self, label_path):
        label = np.array(Image.open(label_path))

        # coco stuff
        ignored_classes = [index for class_name,
                           index in self.cocostuff_class2index.items() if "-" in class_name]
        label = np.where(np.isin(label, ignored_classes), 255, label)

        unique_labels = [lbl for lbl in np.unique(label) if lbl != 255]
        if not unique_labels:
            print("No valid label !!!")
            return None, None

        # only choose 1
        selected_labels = np.random.choice(unique_labels, min(
            len(unique_labels), self.num_classes_per_sample), replace=False)

        label = torch.from_numpy(label).long()
        masks = torch.stack(
            [label == class_id for class_id in selected_labels], dim=0)
        return masks, selected_labels


class PascalPartSemanticSegDataset(SemanticSegDataset):

    def json_file_preprocess(self, data_path, image_folder):
        self.coco_api = COCO(data_path)
        img_ids = self.coco_api.getImgIds()
        all_classes = self.coco_api.loadCats(self.coco_api.getCatIds())
        class_map_pascal_part = {}
        for cat in all_classes:
            cat_main, cat_part = cat["name"].strip().split(":")
            name = (cat_main, cat_part)
            class_map_pascal_part[cat["id"]] = name
        self.classes = class_map_pascal_part
        return img_ids

    def __getitem__(self, index):
        index = index % self.real_len()
        img_id = self.image_label_datas[index]
        img_info = self.coco_api.loadImgs([img_id])[0]
        file_name = img_info["file_name"]
        data_dict = {}

        image_file = os.path.join(self.image_folder, file_name)
        image = Image.open(image_file).convert('RGB')

        if hasattr(self, 'extra_image_processor'):
            g_image = np.array(image)  # for grounding
            g_image = self.extra_image_processor.apply_image(g_image)
            g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
            data_dict['g_pixel_values'] = g_pixel_values

        if self.pad_image_to_square:
            image = expand2square(
                image, tuple(int(x * 255) for x in self.image_processor.image_mean))
        image = self.image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
        data_dict['pixel_values'] = image

        annotation_ids = self.coco_api.getAnnIds(imgIds=img_info["id"])
        annotations = self.coco_api.loadAnns(annotation_ids)

        if not annotations:
            return self.__getitem__(0)

        sampled_anns = np.random.choice(annotations, min(
            len(annotations), self.num_classes_per_sample), replace=False)

        conversation = []
        for i, ann in enumerate(sampled_anns):
            cat_id = ann['category_id']
            sampled_cls = self.classes[cat_id]
            if isinstance(sampled_cls, tuple):
                obj, part = sampled_cls
                name = f"{obj} {part}" if random.random() < 0.5 else f"the {part} of the {obj}"
            else:
                name = sampled_cls
            question = random.choice(SEG_QUESTIONS).format(class_name=name)
            if i == 0:
                question = f"""The {DEFAULT_IMAGE_TOKEN} provides an overview of the picture.\n""" + question
            conversation.append(
                {'input': question, 'output': random.choice(ANSWER_LIST)})

        masks = [self.coco_api.annToMask(ann) for ann in sampled_anns]
        masks = np.stack(masks, axis=0)
        masks = torch.from_numpy(masks)

        data_dict['masks'] = masks
        data_dict['conversation'] = conversation

        if self.lazy:
            result = self.template_map_fn(data_dict)
            data_dict.update(result)

            result = encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
            data_dict.update(result)

        return data_dict


class PacoSemanticSegDataset(PascalPartSemanticSegDataset):
    def json_file_preprocess(self, data_path, image_folder):
        self.coco_api = COCO(data_path)
        all_classes = self.coco_api.loadCats(self.coco_api.getCatIds())
        class_map_paco = {}
        for cat in all_classes:
            cat_split = cat["name"].strip().split(":")
            if len(cat_split) == 1:
                name = cat_split[0].split("_(")[0]
            else:
                assert len(cat_split) == 2
                obj, part = cat_split
                obj = obj.split("_(")[0]
                part = part.split("_(")[0]
                name = (obj, part)
            class_map_paco[cat["id"]] = name
        self.classes = class_map_paco
        return self.coco_api.getImgIds()
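
Note on the semantic-segmentation datasets above: `decode_mask` shifts ADE20K labels so that pixel value 0 (unlabeled) becomes the 255 ignore value and class ids become 0-based, then builds one boolean mask per sampled class. A compact sketch of that remapping with a toy label map:

    import numpy as np
    import torch

    label = np.array([[0, 1, 1],
                      [2, 2, 0]], dtype=np.uint8)       # toy ADE20K annotation
    label = np.where(label == 0, 255, label - 1)        # 0 -> ignore, ids shift to 0-based
    unique_labels = [l for l in np.unique(label) if l != 255]

    label = torch.from_numpy(label).long()
    masks = torch.stack([label == cid for cid in unique_labels], dim=0)
    print(unique_labels, masks.shape)                   # [0, 1] torch.Size([2, 2, 3])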
projects/glamm/datasets/utils/ade20k_classes.json
ADDED
@@ -0,0 +1,30 @@
[
    "wall", "building", "sky", "floor", "tree", "ceiling", "road",
    "bed", "windowpane", "grass", "cabinet", "sidewalk",
    "person", "earth", "door", "table", "mountain", "plant",
    "curtain", "chair", "car", "water", "painting", "sofa",
    "shelf", "house", "sea", "mirror", "rug", "field", "armchair",
    "seat", "fence", "desk", "rock", "wardrobe", "lamp",
    "bathtub", "railing", "cushion", "base", "box", "column",
    "signboard", "chest of drawers", "counter", "sand", "sink",
    "skyscraper", "fireplace", "refrigerator", "grandstand",
    "path", "stairs", "runway", "case", "pool table", "pillow",
    "screen door", "stairway", "river", "bridge", "bookcase",
    "blind", "coffee table", "toilet", "flower", "book", "hill",
    "bench", "countertop", "stove", "palm", "kitchen island",
    "computer", "swivel chair", "boat", "bar", "arcade machine",
    "hovel", "bus", "towel", "light", "truck", "tower",
    "chandelier", "awning", "streetlight", "booth",
    "television receiver", "airplane", "dirt track", "apparel",
    "pole", "land", "bannister", "escalator", "ottoman", "bottle",
    "buffet", "poster", "stage", "van", "ship", "fountain",
    "conveyer belt", "canopy", "washer", "plaything",
    "swimming pool", "stool", "barrel", "basket", "waterfall",
    "tent", "bag", "minibike", "cradle", "oven", "ball", "food",
    "step", "tank", "trade name", "microwave", "pot", "animal",
    "bicycle", "lake", "dishwasher", "screen", "blanket",
    "sculpture", "hood", "sconce", "vase", "traffic light",
    "tray", "ashcan", "fan", "pier", "crt screen", "plate",
    "monitor", "bulletin board", "shower", "radiator", "glass",
    "clock", "flag"
]
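
Note: this list is what `SemanticSegDataset.json_file_preprocess` loads into `self.classes`; index i in the array corresponds to the shifted label id i in the ADE20K annotation PNGs. A quick sanity check, assuming the file sits at the in-repo path (adjust the path if the data layout differs):

    import json
    import numpy as np

    with open('projects/glamm/datasets/utils/ade20k_classes.json', 'r') as f:
        classes = np.array(json.load(f))
    print(len(classes), classes[0], classes[-1])   # expect 150 entries, 'wall' ... 'flag'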
projects/glamm/datasets/utils/cocostuff_classes.txt
ADDED
@@ -0,0 +1,183 @@
0: unlabeled
1: person
2: bicycle
3: car
4: motorcycle
5: airplane
6: bus
7: train
8: truck
9: boat
10: traffic light
11: fire hydrant
12: street sign
13: stop sign
14: parking meter
15: bench
16: bird
17: cat
18: dog
19: horse
20: sheep
21: cow
22: elephant
23: bear
24: zebra
25: giraffe
26: hat
27: backpack
28: umbrella
29: shoe
30: eye glasses
31: handbag
32: tie
33: suitcase
34: frisbee
35: skis
36: snowboard
37: sports ball
38: kite
39: baseball bat
40: baseball glove
41: skateboard
42: surfboard
43: tennis racket
44: bottle
45: plate
46: wine glass
47: cup
48: fork
49: knife
50: spoon
51: bowl
52: banana
53: apple
54: sandwich
55: orange
56: broccoli
57: carrot
58: hot dog
59: pizza
60: donut
61: cake
62: chair
63: couch
64: potted plant
65: bed
66: mirror
67: dining table
68: window
69: desk
70: toilet
71: door
72: tv
73: laptop
74: mouse
75: remote
76: keyboard
77: cell phone
78: microwave
79: oven
80: toaster
81: sink
82: refrigerator
83: blender
84: book
85: clock
86: vase
87: scissors
88: teddy bear
89: hair drier
90: toothbrush
91: hair brush
92: banner
93: blanket
94: branch
95: bridge
96: building-other
97: bush
98: cabinet
99: cage
100: cardboard
101: carpet
102: ceiling-other
103: ceiling-tile
104: cloth
105: clothes
106: clouds
107: counter
108: cupboard
109: curtain
110: desk-stuff
111: dirt
112: door-stuff
113: fence
114: floor-marble
115: floor-other
116: floor-stone
117: floor-tile
118: floor-wood
119: flower
120: fog
121: food-other
122: fruit
123: furniture-other
124: grass
125: gravel
126: ground-other
127: hill
128: house
129: leaves
130: light
131: mat
132: metal
133: mirror-stuff
134: moss
135: mountain
136: mud
137: napkin
138: net
139: paper
140: pavement
141: pillow
142: plant-other
143: plastic
144: platform
145: playingfield
146: railing
147: railroad
148: river
149: road
150: rock
151: roof
152: rug
153: salad
154: sand
155: sea
156: shelf
157: sky
158: skyscraper
159: snow
160: solid-other
161: stairs
162: stone
163: straw
164: structural-other
165: table
166: tent
167: textile-other
168: towel
169: tree
170: vegetable
171: wall-brick
172: wall-concrete
173: wall-other
174: wall-panel
175: wall-stone
176: wall-tile
177: wall-wood
178: water-other
179: waterdrops
180: window-blind
181: window-other
182: wood
projects/glamm/datasets/utils/utils.py
ADDED
@@ -0,0 +1,131 @@
1 |
+
from PIL import Image
|
2 |
+
|
3 |
+
|
4 |
+
|
5 |
+
def expand2square(pil_img, background_color):
|
6 |
+
width, height = pil_img.size
|
7 |
+
if width == height:
|
8 |
+
return pil_img
|
9 |
+
elif width > height:
|
10 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
11 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
12 |
+
return result
|
13 |
+
else:
|
14 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
15 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
16 |
+
return result
|
17 |
+
|
18 |
+
CAPTION_QUESTIONS = [
|
19 |
+
'Could you please give me a detailed description of the image?',
|
20 |
+
'Can you provide a thorough description of the this image?',
|
21 |
+
'Please provide a thorough description of the this image',
|
22 |
+
'Please provide a thorough description of the this image.',
|
23 |
+
'Please describe in detail the contents of the image.',
|
24 |
+
'Please describe in detail the contents of the image',
|
25 |
+
'Could you give a comprehensive explanation of what can be found within this picture?',
|
26 |
+
'Could you give me an elaborate explanation of this picture?',
|
27 |
+
'Could you provide me with a detailed analysis of this photo?',
|
28 |
+
'Could you please give me a detailed description of the image?',
|
29 |
+
'Can you provide a thorough description of the this image?',
|
30 |
+
'Please describe in detail the contents of the image',
|
31 |
+
'Please describe in detail the contents of the image.',
|
32 |
+
'Can you give a comprehensive explanation of this photo',
|
33 |
+
'Please provide an elaborate explanation of this picture.',
|
34 |
+
'Please provide an elaborate explanation of this picture',
|
35 |
+
'Could you provide me with a detailed analysis of this photo',
|
36 |
+
]
|
37 |
+
|
38 |
+
REGION_QUESTIONS = [
|
39 |
+
'Can you provide me with a detailed description of the region in the picture marked by <region>?',
|
40 |
+
"I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
|
41 |
+
'What can you tell me about the region indicated by <region> in the image?',
|
42 |
+
"I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
|
43 |
+
'Could you describe the region shown as <region> in the picture in great detail?',
|
44 |
+
'What details can you give me about the region outlined by <region> in the photo?',
|
45 |
+
'Please provide me with a comprehensive description of the region marked with <region> in the image.',
|
46 |
+
'Can you give me a detailed account of the region labeled as <region> in the picture?',
|
47 |
+
"I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
|
48 |
+
'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
|
49 |
+
'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
|
50 |
+
"I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
|
51 |
+
'What can you tell me about the region indicated by <region> in the image, exactly?',
|
52 |
+
"I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
|
53 |
+
'Could you describe the region shown as <region> in the picture in great detail, please?',
|
54 |
+
'What details can you give me about the region outlined by <region> in the photo, please?',
|
55 |
+
'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
|
56 |
+
'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
|
57 |
+
"I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
|
58 |
+
'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
|
59 |
+
]
|
60 |
+
|
61 |
+
REGION_GROUP_QUESTIONS = [
|
62 |
+
'Could you please give me a detailed description of these areas <region>?',
|
63 |
+
'Can you provide a thorough description of the regions <region> in this image?',
|
64 |
+
'Please describe in detail the contents of the boxed areas <region>.',
|
65 |
+
'Could you give a comprehensive explanation of what can be found within <region> in the picture?',
|
66 |
+
'Could you give me an elaborate explanation of the <region> regions in this picture?',
|
67 |
+
'Can you provide a comprehensive description of the areas identified by <region> in this photo?',
|
68 |
+
    'Help me understand the specific locations labeled <region> in this picture in detail, please.',
    'What is the detailed information about the areas marked by <region> in this image?',
    'Could you provide me with a detailed analysis of the regions designated <region> in this photo?',
    'What are the specific features of the areas marked <region> in this picture that you can describe in detail?',
    'Could you elaborate on the regions identified by <region> in this image?',
    'What can you tell me about the areas labeled <region> in this picture?',
    'Can you provide a thorough analysis of the specific locations designated <region> in this photo?',
    'I am interested in learning more about the regions marked <region> in this image. Can you provide me with more information?',
    'Could you please provide a detailed description of the areas identified by <region> in this photo?',
    'What is the significance of the regions labeled <region> in this picture?',
    'I would like to know more about the specific locations designated <region> in this image. Can you provide me with more information?',
    'Can you provide a detailed breakdown of the regions marked <region> in this photo?',
    'What specific features can you tell me about the areas identified by <region> in this picture?',
    'Could you please provide a comprehensive explanation of the locations labeled <region> in this image?',
    'Can you provide a detailed account of the regions designated <region> in this photo?',
    'I am curious about the areas marked <region> in this picture. Can you provide me with a detailed analysis?',
    'What important details can you tell me about the specific locations identified by <region> in this image?',
    'Could you please provide a detailed description of the regions labeled <region> in this photo?',
    'What can you tell me about the features of the areas designated <region> in this picture?',
    'Can you provide a comprehensive overview of the regions marked <region> in this image?',
    'I would like to know more about the specific locations identified by <region> in this photo. Can you provide me with more information?',
    'What is the detailed information you have on the areas labeled <region> in this picture?',
    'Could you provide me with a thorough analysis of the regions designated <region> in this image?',
    'Can you provide a detailed explanation of the specific locations marked by <region> in this photo?'
]

GCG_QUESTIONS = [
    'Could you please give me a detailed description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    'Can you provide a thorough description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
    'Please describe in detail the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    'Could you give a comprehensive explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
    'Could you give me an elaborate explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
    'Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
]

SEG_QUESTIONS = [
    "Can you segment the {class_name} in this image?",
    "Please segment {class_name} in this image.",
    "What is {class_name} in this image? Please respond with segmentation mask.",
    "What is {class_name} in this image? Please output segmentation mask.",

    "Can you segment the {class_name} in this image",
    "Please segment {class_name} in this image",
    "What is {class_name} in this image? Please respond with segmentation mask",
    "What is {class_name} in this image? Please output segmentation mask",

    "Could you provide a segmentation mask for the {class_name} in this image?",
    "Please identify and segment the {class_name} in this image.",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask.",
    "Can you highlight the {class_name} in this image with a segmentation mask?",

    "Could you provide a segmentation mask for the {class_name} in this image",
    "Please identify and segment the {class_name} in this image",
    "Where is the {class_name} in this picture? Please respond with a segmentation mask",
    "Can you highlight the {class_name} in this image with a segmentation mask",
]

ANSWER_LIST = [
    "It is [SEG].",
    "Sure, [SEG].",
    "Sure, it is [SEG].",
    "Sure, the segmentation result is [SEG].",
    "[SEG].",
]
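The template lists above are consumed by the referring-segmentation datasets: one question template is sampled, its {class_name} placeholder is filled with the referred object, and a canned [SEG] answer is paired with it. Below is a minimal sketch of that pairing, assuming SEG_QUESTIONS and ANSWER_LIST from the file above are in scope; the helper name and the random sampling are illustrative assumptions, not the repository's exact dataset code.

import random

def build_seg_conversation(class_name):
    # Hypothetical helper: sample one question and one answer template.
    question = random.choice(SEG_QUESTIONS).format(class_name=class_name)
    answer = random.choice(ANSWER_LIST)  # the [SEG] token is later bound to a predicted mask embedding
    return {'input': question, 'output': answer}

# e.g. build_seg_conversation('red umbrella') might return
# {'input': 'Please segment red umbrella in this image.', 'output': 'It is [SEG].'}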
projects/glamm/models/glamm.py
ADDED
@@ -0,0 +1,183 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from xtuner.registry import BUILDER
from xtuner.model.utils import LoadWoInit, guess_load_checkpoint
from xtuner.model.llava import LLaVAModel

from mmengine.model import BaseModel
from mmengine import print_log

from projects.glamm.utils import prepare_inputs_labels_for_multimodal
from projects.glamm.utils import DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN


class GLaMM(LLaVAModel):
    def __init__(self,
                 use_activation_checkpointing=True,
                 tokenizer=None,
                 grounding_encoder=None,
                 region_encoder=None,
                 loss_mask=None,
                 loss_dice=None,
                 *args, **kwargs):
        super(GLaMM, self).__init__(
            *args, use_activation_checkpointing=use_activation_checkpointing, **kwargs)

        self.use_activation_checkpointing = use_activation_checkpointing
        self.tokenizer = BUILDER.build(tokenizer)
        self._add_special_tokens()

        # SAM-style grounding encoder: frozen except for its mask decoder.
        self.grounding_encoder = BUILDER.build(grounding_encoder)
        self.grounding_encoder.requires_grad_(False)
        self.grounding_encoder.mask_decoder.requires_grad_(True)

        if region_encoder is not None:
            self.region_encoder = BUILDER.build(region_encoder)

        # Projects LLM hidden states at [SEG] positions into the mask decoder's prompt space.
        in_dim = self.config.hidden_size
        out_dim = self.grounding_encoder.mask_decoder.transformer_dim
        self.text_hidden_fcs = nn.Sequential(
            nn.Linear(in_dim, in_dim), nn.ReLU(inplace=True),
            nn.Linear(in_dim, out_dim), nn.Dropout(0.0)
        )

        self.loss_mask = BUILDER.build(loss_mask)
        self.loss_dice = BUILDER.build(loss_dice)

    def _add_special_tokens(self):
        reg_tokens = ['<im_start>', '<im_end>', '<bbox>', '<point>']
        segmentation_tokens = ['[SEG]']
        phrase_tokens = ['<p>', '</p>']
        special_tokens = reg_tokens + segmentation_tokens + phrase_tokens
        num_new_tokens = self.tokenizer.add_tokens(
            special_tokens, special_tokens=True)
        if num_new_tokens > 0:
            self.llm.resize_token_embeddings(len(self.tokenizer))
            input_embeddings = self.llm.get_input_embeddings().weight.data
            output_embeddings = self.llm.get_output_embeddings().weight.data

            # Initialize the new token embeddings with the mean of the existing ones.
            input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)
            output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(
                dim=0, keepdim=True)

            input_embeddings[-num_new_tokens:] = input_embeddings_avg
            output_embeddings[-num_new_tokens:] = output_embeddings_avg

        self.seg_token_idx = self.tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
        self.bop_token_idx = self.tokenizer("<p>", add_special_tokens=False).input_ids[0]
        self.eop_token_idx = self.tokenizer("</p>", add_special_tokens=False).input_ids[0]
        self.bbox_token_idx = self.tokenizer("<bbox>", add_special_tokens=False).input_ids[0]

        if self.use_activation_checkpointing or self.use_llm_lora or not self.freeze_llm:
            self.llm.enable_input_require_grads()

    def forward(self, data, data_samples=None, mode='loss'):
        if 'pixel_values' in data:
            visual_outputs = self.visual_encoder(
                data['pixel_values'].to(self.visual_encoder.dtype),
                output_hidden_states=True)
            pixel_values = self.projector(
                visual_outputs.hidden_states[self.visual_select_layer][:, 1:])
            data['pixel_values'] = pixel_values
            bboxes = data.pop('bboxes', None)
            if bboxes is not None:
                # Collect multi-level visual features for the region encoder.
                select_hidden_state_layer = -2
                num_level_reg_features = 4
                mlvl_reg_features = visual_outputs.hidden_states[select_hidden_state_layer::-3]
                mlvl_reg_features = mlvl_reg_features[::-1]
                mlvl_reg_features = mlvl_reg_features[-num_level_reg_features:]
                mlvl_reg_features = [item[:, 1:] for item in mlvl_reg_features]
                mlvl_reg_features = self.region_encoder(mlvl_reg_features, bboxes)
            data = prepare_inputs_labels_for_multimodal(llm=self.llm, **data)

            if bboxes is not None:
                # Scatter the RoI features into the positions of the <bbox> tokens.
                inputs_embeds = data['inputs_embeds']
                for i, reg_feat in enumerate(mlvl_reg_features):
                    reg_mask = data['new_input_ids'][i] == self.bbox_token_idx
                    inputs_embeds[i][reg_mask] = reg_feat
                data['inputs_embeds'] = inputs_embeds

        if mode == 'loss':
            return self.compute_loss(data, data_samples)
        elif mode == 'predict':
            return self.predict(data, data_samples)
        elif mode == 'tensor':
            return self._forward(data, data_samples)
        else:
            raise NotImplementedError

    def compute_loss(self, data, data_samples=None):
        g_pixel_values = data.pop('g_pixel_values', None)
        gt_masks = data.pop('masks', None)
        new_input_ids = data.pop('new_input_ids', None)

        output = self.llm(output_hidden_states=True, **data)
        if gt_masks is None:
            return {'llm_loss': output.loss}

        resize_list = [pixel.shape[-2:] for pixel in g_pixel_values]
        ori_size_list = [mask.shape[-2:] for mask in gt_masks]
        g_pixel_values = torch.stack([
            self.grounding_encoder.preprocess(pixel) for pixel in g_pixel_values
        ])
        image_embeddings = self.grounding_encoder.image_encoder(g_pixel_values)

        # Gather the LLM hidden states at [SEG] positions and project them into prompt embeddings.
        seg_token_mask = new_input_ids == self.seg_token_idx
        hidden_states = output.hidden_states
        hidden_states = self.text_hidden_fcs(hidden_states[-1])
        pred_embeddings = hidden_states[seg_token_mask]

        seg_token_counts = seg_token_mask.int().sum(-1)
        pred_embeddings_list = torch.split(pred_embeddings, seg_token_counts.tolist(), dim=0)

        pred_masks = self._generate_and_postprocess_masks(
            pred_embeddings_list, image_embeddings, resize_list, ori_size_list)

        bs = len(pred_masks)
        loss_mask, loss_dice, accuracy = 0, 0, 0
        for i in range(bs):
            pred_mask = pred_masks[i]
            gt_mask = gt_masks[i]

            sam_loss_mask = self.loss_mask(pred_mask, gt_mask)
            sam_loss_dice = self.loss_dice(pred_mask, gt_mask)
            # Accumulate the pixel accuracy so the logged value averages over the whole batch.
            accuracy += torch.eq((pred_mask.sigmoid() > 0.5), gt_mask).to(pred_mask).mean()
            loss_mask += sam_loss_mask
            loss_dice += sam_loss_dice

        loss_dict = {
            'loss_mask': loss_mask / bs,
            'loss_dice': loss_dice / bs,
            'accuracy': accuracy / bs,
            'llm_loss': output.loss,
        }
        return loss_dict

    def _generate_and_postprocess_masks(self, pred_embeddings, image_embeddings, resize_list=None, orig_size_list=None, infer=False):
        pred_masks = []
        for i, pred_embedding in enumerate(pred_embeddings):
            sparse_embeddings, dense_embeddings = self.grounding_encoder.prompt_encoder(
                points=None, boxes=None, masks=None, text_embeds=pred_embedding.unsqueeze(1)
            )
            sparse_embeddings = sparse_embeddings.to(pred_embedding.dtype)
            low_res_masks, _ = self.grounding_encoder.mask_decoder(
                image_embeddings=image_embeddings[i].unsqueeze(0),
                image_pe=self.grounding_encoder.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings, dense_prompt_embeddings=dense_embeddings,
                multimask_output=False, )

            pred_mask = self.grounding_encoder.postprocess_masks(
                low_res_masks, input_size=resize_list[i], original_size=orig_size_list[i], )
            pred_masks.append(pred_mask[:, 0])
        return pred_masks

    def predict(self, data, data_samples=None):
        pass

    def _forward(self, data, data_samples=None):
        outputs = self.llm(**data)
        return outputs
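The core of GLaMM.compute_loss above is the gather of LLM hidden states at [SEG] positions, which then become SAM prompt embeddings. A self-contained sketch of just that gather is below; the shapes and the token id are made up for illustration and are not the model's real vocabulary.

import torch

batch, seq_len, hidden = 2, 6, 8
seg_token_idx = 32001                                   # illustrative id only

input_ids = torch.randint(0, 100, (batch, seq_len))
input_ids[0, 2] = seg_token_idx                         # sample 0 emits one [SEG]
input_ids[1, 1] = seg_token_idx                         # sample 1 emits two [SEG]
input_ids[1, 4] = seg_token_idx
hidden_states = torch.randn(batch, seq_len, hidden)     # stands in for text_hidden_fcs(output.hidden_states[-1])

seg_token_mask = input_ids == seg_token_idx
pred_embeddings = hidden_states[seg_token_mask]         # (3, hidden): every [SEG] state across the batch
counts = seg_token_mask.int().sum(-1)                   # tensor([1, 2])
per_sample = torch.split(pred_embeddings, counts.tolist(), dim=0)
print([tuple(p.shape) for p in per_sample])             # [(1, 8), (2, 8)]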
projects/glamm/models/region_encoder.py
ADDED
@@ -0,0 +1,359 @@
from abc import ABCMeta, abstractmethod
from typing import List, Optional, Tuple
from torch import Tensor

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmcv import ops
from mmcv.cnn import ConvModule, Linear
from mmengine.model import BaseModule


class BaseRoIExtractor(BaseModule, metaclass=ABCMeta):
    """Base class for RoI extractor.

    Args:
        roi_layer (:obj:`ConfigDict` or dict): Specify RoI layer type and
            arguments.
        out_channels (int): Output channels of RoI layers.
        featmap_strides (list[int]): Strides of input feature maps.
        init_cfg (:obj:`ConfigDict` or dict or list[:obj:`ConfigDict` or \
            dict], optional): Initialization config dict. Defaults to None.
    """

    def __init__(self,
                 roi_layer,
                 out_channels: int,
                 featmap_strides: List[int],
                 init_cfg=None) -> None:
        super().__init__(init_cfg=init_cfg)
        self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
        self.out_channels = out_channels
        self.featmap_strides = featmap_strides

    @property
    def num_inputs(self) -> int:
        """int: Number of input feature maps."""
        return len(self.featmap_strides)

    def build_roi_layers(self, layer_cfg,
                         featmap_strides: List[int]) -> nn.ModuleList:
        """Build RoI operator to extract feature from each level feature map.

        Args:
            layer_cfg (:obj:`ConfigDict` or dict): Dictionary to construct and
                config RoI layer operation. Options are modules under
                ``mmcv/ops`` such as ``RoIAlign``.
            featmap_strides (list[int]): The stride of input feature map w.r.t
                to the original image size, which would be used to scale RoI
                coordinate (original image coordinate system) to feature
                coordinate system.

        Returns:
            :obj:`nn.ModuleList`: The RoI extractor modules for each level
            feature map.
        """

        cfg = layer_cfg.copy()
        layer_type = cfg.pop('type')
        if isinstance(layer_type, str):
            assert hasattr(ops, layer_type)
            layer_cls = getattr(ops, layer_type)
        else:
            layer_cls = layer_type
        roi_layers = nn.ModuleList(
            [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
        return roi_layers

    def roi_rescale(self, rois: Tensor, scale_factor: float) -> Tensor:
        """Scale RoI coordinates by scale factor.

        Args:
            rois (Tensor): RoI (Region of Interest), shape (n, 5)
            scale_factor (float): Scale factor that RoI will be multiplied by.

        Returns:
            Tensor: Scaled RoI.
        """

        cx = (rois[:, 1] + rois[:, 3]) * 0.5
        cy = (rois[:, 2] + rois[:, 4]) * 0.5
        w = rois[:, 3] - rois[:, 1]
        h = rois[:, 4] - rois[:, 2]
        new_w = w * scale_factor
        new_h = h * scale_factor
        x1 = cx - new_w * 0.5
        x2 = cx + new_w * 0.5
        y1 = cy - new_h * 0.5
        y2 = cy + new_h * 0.5
        new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
        return new_rois

    @abstractmethod
    def forward(self,
                feats: Tuple[Tensor],
                rois: Tensor,
                roi_scale_factor: Optional[float] = None) -> Tensor:
        """Extract RoI features.

        Args:
            feats (Tuple[Tensor]): Multi-scale features.
            rois (Tensor): RoIs with the shape (n, 5) where the first
                column indicates batch id of each RoI.
            roi_scale_factor (Optional[float]): RoI scale factor.
                Defaults to None.

        Returns:
            Tensor: RoI feature.
        """
        pass


class MLVLFuseModule(nn.Module):
    def __init__(self, input_dims=1024, embed_dims=1024, num_levels=3, num_fuse=4):
        super(MLVLFuseModule, self).__init__()
        self.embed_dims = embed_dims
        self.num_levels = num_levels
        self.num_fuse = num_fuse
        self.input_dims = input_dims
        self.shuffle_channles = embed_dims // 4

        # contains the tuple of level indices that will do the interaction
        self.fuse_lvl_list = []
        num_levels = self.num_levels
        for lvl in range(num_levels):
            top_lvl = min(lvl + 1, num_levels - 1)
            dow_lvl = max(lvl - 1, 0)
            tar_lvl = lvl
            self.fuse_lvl_list.append((tar_lvl, top_lvl, dow_lvl))

        self.remain_chs = self.embed_dims - self.shuffle_channles * 2
        self._init_layers()

    def generate_coordinate(self, featmap_sizes, device='cuda'):
        # Build a normalized (x, y) coordinate map matching the feature-map size.
        x_range = torch.linspace(-1, 1, featmap_sizes[-1], device=device)
        y_range = torch.linspace(-1, 1, featmap_sizes[-2], device=device)
        y, x = torch.meshgrid(y_range, x_range)
        y = y.expand([featmap_sizes[0], 1, -1, -1])
        x = x.expand([featmap_sizes[0], 1, -1, -1])
        coord_feat = torch.cat([x, y], 1)

        return coord_feat

    def _init_layers(self):
        self.input_conv = nn.ModuleList([nn.Conv2d(self.input_dims + 2,
                                                   self.embed_dims, 1)
                                         for _ in range(self.num_levels)])
        self.fuse_convs = nn.ModuleList()
        for i in range(self.num_fuse):
            self.fuse_convs.append(
                ConvModule(self.embed_dims,
                           self.embed_dims,
                           3,
                           stride=1,
                           padding=3 // 2,
                           conv_cfg=None,
                           norm_cfg=dict(type='GN',
                                         num_groups=64,
                                         requires_grad=True)
                           ))

    def init_weights(self):
        pass

    def _single_shuffle(self, inputs, conv_module):
        if not isinstance(conv_module, (nn.ModuleList, list)):
            conv_module = [conv_module]
        for single_conv_m in conv_module:
            fused_inputs = []
            for fuse_lvl_tuple in self.fuse_lvl_list:
                tar_lvl, top_lvl, dow_lvl = fuse_lvl_tuple
                tar_input = inputs[tar_lvl]
                top_input = inputs[top_lvl]
                down_input = inputs[dow_lvl]
                remain = tar_input[:, :self.remain_chs]
                from_top = top_input[:, self.remain_chs:][:, self.shuffle_channles:]
                from_top = F.interpolate(from_top.to(torch.float32),
                                         size=tar_input.shape[-2:],
                                         mode='bilinear',
                                         align_corners=True)
                from_down = down_input[:, self.remain_chs:][:, :self.shuffle_channles]
                from_down = F.interpolate(from_down.to(torch.float32),
                                          size=tar_input.shape[-2:],
                                          mode='bilinear',
                                          align_corners=True)
                fused_inputs.append(
                    torch.cat([remain, from_top.to(remain.dtype), from_down.to(remain.dtype)], dim=1))
            fused_inputs = [single_conv_m(item) for item in fused_inputs]
            inputs = fused_inputs
        return inputs

    def forward(self, inputs, ):
        feat_size = [item.shape for item in inputs]
        new_inputs = []
        for feat, single_feat_size in zip(inputs, feat_size):
            coord_feat = self.generate_coordinate(
                single_feat_size, device=inputs[0].device)
            # feat = torch.cat([feat, coord_feat], dim=1)
            feat = torch.cat([feat, coord_feat.to(feat.dtype)], dim=1)
            new_inputs.append(feat)
        inputs = new_inputs

        inputs = [self.input_conv[lvl](item)
                  for lvl, item in enumerate(inputs)]

        for conv_m in self.fuse_convs:
            inputs = self._single_shuffle(inputs, [conv_m])
        return inputs


class MlvlRoIExtractor(BaseRoIExtractor):
    def __init__(self,
                 roi_layer,
                 out_channels,
                 featmap_strides,
                 embed_dims=1024,
                 stride=1,
                 norm_init=True,
                 fuse_level=3,
                 finest_scale=56,
                 init_cfg=None):
        super(MlvlRoIExtractor, self).__init__(roi_layer, out_channels,
                                               featmap_strides, init_cfg)
        self.embed_dims = embed_dims
        self.finest_scale = finest_scale
        self.fuse_level = fuse_level
        self.norm_init = norm_init

        self.pconvs = nn.ModuleList(
            nn.Conv2d(self.embed_dims, self.embed_dims, 3, stride=1, padding=1)
            for _ in range(self.fuse_level))
        self.pos_embedd = nn.Sequential(
            nn.Linear(4, 256),
            nn.ReLU(inplace=True),
            nn.LayerNorm(256),
            nn.Linear(256, 1024),
            nn.ReLU(inplace=True),
            nn.LayerNorm(1024),
        )
        self.updims = nn.Linear(1024, 4096)

        self.flatten_linear = nn.Linear(
            self.embed_dims * self.roi_layers[0].output_size[0] ** 2, 1024)

        self.norm_init_weights()

    # self.dtype = torch.float32
    def norm_init_weights(self):
        pass

    def forward(self, feats, rois, roi_scale_factor=None):
        """Forward function."""
        num_imgs = len(rois)
        # feats = [item for item in feats]
        batch_rois = torch.cat(rois, dim=0).to(feats[0].dtype)
        pos_embedd = self.pos_embedd(batch_rois)
        out_size = self.roi_layers[0].output_size
        num_levels = len(feats)
        if feats[0].dim() == 3:
            h = w = int(math.sqrt(feats[0].shape[1]))
            assert h == 16
            assert w == 16
            b, c = feats[0].shape[0], feats[0].shape[-1]
            feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2)
                     for item in feats]
        new_rois = []
        for img_id, single_img_roi in enumerate(rois):
            # rescale to original img scale
            single_img_roi = single_img_roi * 224

            roi_img_id = single_img_roi.new_ones(len(single_img_roi)) * img_id
            single_img_roi = torch.cat(
                [roi_img_id[:, None], single_img_roi], dim=1)
            new_rois.append(single_img_roi)
        rois = torch.cat(new_rois)

        roi_feats = feats[0].new_zeros(self.fuse_level,
                                       rois.size(0), self.out_channels, *out_size)

        for i in range(num_levels):
            if len(rois) > 0:
                rois_ = rois
                ori_dtype = feats[i].dtype
                roi_feats_t = self.roi_layers[i](feats[i].to(
                    torch.float32), rois_.to(torch.float32))

                roi_feats[i] = roi_feats_t.to(ori_dtype)

            else:
                # Keep the graph connected when there are no RoIs in the batch.
                roi_feats += sum(
                    x.view(-1)[0]
                    for x in self.parameters()) * 0. + feats[i].sum() * 0.

        fuse_roi_feats = []
        for i in range(self.fuse_level):
            fuse_roi_feats.append(self.pconvs[i](roi_feats[i]))

        fuse_roi_feats = sum(fuse_roi_feats)
        fuse_roi_feats = F.relu(fuse_roi_feats)
        fuse_roi_feats = fuse_roi_feats.flatten(1, -1)
        fuse_roi_feats = self.flatten_linear(fuse_roi_feats)
        fuse_roi_feats = fuse_roi_feats + pos_embedd
        fuse_roi_feats = self.updims(fuse_roi_feats)
        query_feats = []
        for i in range(num_imgs):
            mask = rois[:, 0] == i
            query_feats.append(fuse_roi_feats[mask])

        return query_feats


class MLVLROIQueryModule(nn.Module):
    def __init__(self, embed_dims=1024, out_dims=4096,
                 num_levels=3):
        super(MLVLROIQueryModule, self).__init__()
        self.mlvl_fuse = MLVLFuseModule(input_dims=embed_dims,
                                        embed_dims=embed_dims,
                                        num_levels=num_levels,
                                        num_fuse=5)
        strids = [14 / 8, 14 / 4, 14 / 2, 14]
        assert len(strids) == num_levels
        bbox_roi_extractor = dict(roi_layer=dict(type='RoIAlign',
                                                 output_size=14,
                                                 sampling_ratio=2),
                                  out_channels=embed_dims,
                                  embed_dims=embed_dims,
                                  fuse_level=num_levels,
                                  featmap_strides=strids)

        self.roi_align = MlvlRoIExtractor(**bbox_roi_extractor)

    def forward(self, mlvl_feats, bboxes):
        if mlvl_feats[0].dim() == 3:
            h = w = int(math.sqrt(mlvl_feats[0].shape[1]))
            assert h == 24
            assert w == 24
            b, c = mlvl_feats[0].shape[0], mlvl_feats[0].shape[-1]
            mlvl_feats = [item.reshape(b, h, w, c).permute(0, 3, 1, 2) for item in mlvl_feats]
        base_shape = mlvl_feats[0].shape[-2:]
        num_level = len(mlvl_feats)
        to_shape = [(base_shape[0] * 2 ** level, base_shape[1] * 2 ** level)
                    for level in range(num_level)]
        to_shape = to_shape[::-1]
        for level in range(num_level):
            feat = mlvl_feats[level]
            shape = to_shape[level]
            # feat = feat
            # mlvl_feats[level] = F.interpolate(feat, size=shape, mode='bilinear', align_corners=True)
            # todo: temporary fix for "upsample_bilinear2d_out_frame" not implemented for 'BFloat16'
            feat = feat.to(torch.float32)
            mlvl_feats[level] = F.interpolate(
                feat, size=shape, mode='bilinear', align_corners=True)
            mlvl_feats[level] = mlvl_feats[level].to(torch.bfloat16)

        mlvl_feats = self.mlvl_fuse(mlvl_feats)

        return self.roi_align(mlvl_feats, bboxes)
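As a usage note for MlvlRoIExtractor.forward above: the incoming boxes are normalized per-image coordinates, which get scaled to pixel coordinates (the hard-coded factor assumes a 224-pixel input) and prefixed with a batch-index column to form the flat (n, 5) RoI tensor that RoIAlign expects. Below is a minimal sketch of just that bookkeeping, with made-up boxes; it does not run the module itself.

import torch

boxes_per_image = [
    torch.tensor([[0.10, 0.20, 0.50, 0.60]]),                          # image 0: one box (x1, y1, x2, y2), normalized
    torch.tensor([[0.00, 0.00, 0.25, 0.25], [0.5, 0.5, 1.0, 1.0]]),    # image 1: two boxes
]

rois = []
for img_id, boxes in enumerate(boxes_per_image):
    boxes = boxes * 224                                    # to pixel coordinates
    batch_idx = boxes.new_full((len(boxes), 1), img_id)    # leading batch-index column
    rois.append(torch.cat([batch_idx, boxes], dim=1))
rois = torch.cat(rois)                                     # shape (3, 5), one row per RoI
print(rois.shape, rois[:, 0].tolist())                     # torch.Size([3, 5]) [0.0, 1.0, 1.0]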
projects/glamm/utils.py
ADDED
@@ -0,0 +1,280 @@
from enum import Enum

import numpy as np
import torch
import torch.distributed as dist

from transformers import PreTrainedModel
from typing import List, Optional


IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200

DEFAULT_EOS_TOKEN = '</s>'
DEFAULT_BOS_TOKEN = '<s>'
DEFAULT_UNK_TOKEN = '<unk>'

DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_BBOX_TOKEN = "<bbox>"


# Modified from https://github.com/haotian-liu/LLaVA/blob/82fc5e0e5f4393a4c26851fa32c69ab37ea3b146/llava/model/llava_arch.py#L99  # noqa: E501
def prepare_inputs_labels_for_multimodal(
        llm: PreTrainedModel,
        input_ids: torch.LongTensor = None,
        position_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        **kwargs):
    if pixel_values is None:
        kwargs.update({
            'input_ids': input_ids,
            'position_ids': position_ids,
            'attention_mask': attention_mask,
            'past_key_values': past_key_values,
            'inputs_embeds': None,
            'labels': labels
        })
        return kwargs

    _labels = labels
    _position_ids = position_ids
    _attention_mask = attention_mask
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids, dtype=torch.bool)
    else:
        attention_mask = attention_mask.bool()
    if position_ids is None:
        position_ids = torch.arange(0, input_ids.shape[1], dtype=torch.long, device=input_ids.device)
    if labels is None:
        labels = torch.full_like(input_ids, IGNORE_INDEX)

    # remove the padding using attention_mask -- TODO: double check
    input_ids = [
        cur_input_ids[cur_attention_mask]
        for cur_input_ids, cur_attention_mask in zip(input_ids, attention_mask)
    ]
    labels = [
        cur_labels[cur_attention_mask]
        for cur_labels, cur_attention_mask in zip(labels, attention_mask)
    ]

    new_inputs_embeds = []
    new_labels = []
    new_input_ids = []
    cur_image_idx = 0
    for batch_idx, cur_input_ids in enumerate(input_ids):
        num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
        if num_images == 0:
            cur_pixel_values = pixel_values[cur_image_idx]
            cur_inputs_embeds_1 = llm.get_input_embeddings()(cur_input_ids)
            cur_inputs_embeds = torch.cat([cur_inputs_embeds_1, cur_pixel_values[0:0]], dim=0)
            new_inputs_embeds.append(cur_inputs_embeds)
            new_labels.append(labels[batch_idx])
            new_input_ids.append(cur_input_ids)
            cur_image_idx += 1
            continue

        image_token_indices = [-1] + torch.where(cur_input_ids == IMAGE_TOKEN_INDEX)[0].tolist() + [cur_input_ids.shape[0]]
        cur_input_ids_noim = []
        cur_labels = labels[batch_idx]
        cur_labels_noim = []
        for i in range(len(image_token_indices) - 1):
            cur_input_ids_noim.append(cur_input_ids[image_token_indices[i] + 1:image_token_indices[i + 1]])
            cur_labels_noim.append(cur_labels[image_token_indices[i] + 1:image_token_indices[i + 1]])

        split_sizes = [x.shape[0] for x in cur_labels_noim]
        cur_inputs_embeds = llm.get_input_embeddings()(torch.cat(cur_input_ids_noim))
        cur_inputs_embeds_no_im = torch.split(cur_inputs_embeds, split_sizes, dim=0)
        cur_new_inputs_embeds = []
        cur_new_labels = []
        cur_new_input_ids = []

        for i in range(num_images + 1):
            cur_new_inputs_embeds.append(cur_inputs_embeds_no_im[i])
            cur_new_labels.append(cur_labels_noim[i])
            cur_new_input_ids.append(cur_input_ids_noim[i])
            if i < num_images:
                cur_pixel_values = pixel_values[cur_image_idx]
                cur_image_idx += 1
                cur_new_inputs_embeds.append(cur_pixel_values)
                cur_new_labels.append(torch.full((cur_pixel_values.shape[0], ), IGNORE_INDEX, device=cur_labels.device, dtype=cur_labels.dtype))
                cur_new_input_ids.append(torch.full((cur_pixel_values.shape[0], ), IMAGE_TOKEN_INDEX, device=cur_input_ids.device, dtype=cur_input_ids.dtype))

        cur_new_inputs_embeds = torch.cat(cur_new_inputs_embeds)
        cur_new_labels = torch.cat(cur_new_labels)
        cur_new_input_ids = torch.cat(cur_new_input_ids)

        new_inputs_embeds.append(cur_new_inputs_embeds)
        new_labels.append(cur_new_labels)
        new_input_ids.append(cur_new_input_ids)

    # Combine them
    max_len = max(x.shape[0] for x in new_inputs_embeds)
    batch_size = len(new_inputs_embeds)

    new_inputs_embeds_padded = []
    new_labels_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_labels[0].dtype, device=new_labels[0].device)
    new_input_ids_padded = torch.full((batch_size, max_len), IGNORE_INDEX, dtype=new_input_ids[0].dtype, device=new_input_ids[0].device)
    attention_mask = torch.zeros((batch_size, max_len), dtype=attention_mask.dtype, device=attention_mask.device)
    position_ids = torch.zeros((batch_size, max_len), dtype=position_ids.dtype, device=position_ids.device)

    for i, (cur_new_embed, cur_new_labels, cur_new_input_ids) in enumerate(zip(new_inputs_embeds, new_labels, new_input_ids)):
        cur_len = cur_new_embed.shape[0]
        new_inputs_embeds_padded.append(torch.cat((cur_new_embed, torch.zeros((max_len - cur_len, cur_new_embed.shape[1]), dtype=cur_new_embed.dtype, device=cur_new_embed.device)), dim=0))
        if cur_len > 0:
            new_labels_padded[i, :cur_len] = cur_new_labels
            new_input_ids_padded[i, :cur_len] = cur_new_input_ids
            attention_mask[i, :cur_len] = True
            position_ids[i, :cur_len] = torch.arange(0, cur_len, dtype=position_ids.dtype, device=position_ids.device)

    new_inputs_embeds = torch.stack(new_inputs_embeds_padded, dim=0)

    if _labels is None:
        new_labels = None
    else:
        new_labels = new_labels_padded

    new_input_ids = new_input_ids_padded

    if _attention_mask is None:
        attention_mask = None
    else:
        attention_mask = attention_mask.to(dtype=_attention_mask.dtype)

    if _position_ids is None:
        position_ids = None

    kwargs.update({
        'input_ids': None,
        'position_ids': position_ids,
        'attention_mask': attention_mask,
        'past_key_values': past_key_values,
        'inputs_embeds': new_inputs_embeds,
        'labels': new_labels,
        'new_input_ids': new_input_ids
    })
    return kwargs


class Summary(Enum):
    NONE = 0
    AVERAGE = 1
    SUM = 2
    COUNT = 3


class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if isinstance(self.sum, np.ndarray):
            total = torch.tensor(
                self.sum.tolist()
                + [
                    self.count,
                ],
                dtype=torch.float32,
                device=device,
            )
        else:
            total = torch.tensor(
                [self.sum, self.count], dtype=torch.float32, device=device
            )

        dist.all_reduce(total, dist.ReduceOp.SUM, async_op=False)
        if total.shape[0] > 2:
            self.sum, self.count = total[:-1].cpu().numpy(), total[-1].cpu().item()
        else:
            self.sum, self.count = total.tolist()
        self.avg = self.sum / (self.count + 1e-5)

    def __str__(self):
        fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return fmtstr.format(**self.__dict__)

    def summary(self):
        fmtstr = ""
        if self.summary_type is Summary.NONE:
            fmtstr = ""
        elif self.summary_type is Summary.AVERAGE:
            fmtstr = "{name} {avg:.3f}"
        elif self.summary_type is Summary.SUM:
            fmtstr = "{name} {sum:.3f}"
        elif self.summary_type is Summary.COUNT:
            fmtstr = "{name} {count:.3f}"
        else:
            raise ValueError("invalid summary type %r" % self.summary_type)

        return fmtstr.format(**self.__dict__)


def intersectionAndUnionGPU(output, target, K, ignore_index=255):
    # 'K' classes, output and target sizes are N or N * L or N * H * W, each value in range 0 to K - 1.
    assert output.dim() in [1, 2, 3]
    assert output.shape == target.shape
    output = output.view(-1)
    target = target.view(-1)
    output[target == ignore_index] = ignore_index
    intersection = output[output == target]
    area_intersection = torch.histc(intersection, bins=K, min=0, max=K - 1)
    area_output = torch.histc(output, bins=K, min=0, max=K - 1)
    area_target = torch.histc(target, bins=K, min=0, max=K - 1)
    area_union = area_output + area_target - area_intersection
    return area_intersection, area_union, area_target


class ProgressMeter(object):
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        print("\t".join(entries))

    def display_summary(self):
        entries = [" *"]
        entries += [meter.summary() for meter in self.meters]
        print(" ".join(entries))

    def _get_batch_fmtstr(self, num_batches):
        num_digits = len(str(num_batches // 1))
        fmt = "{:" + str(num_digits) + "d}"
        return "[" + fmt + "/" + fmt.format(num_batches) + "]"


def dict_to_cuda(input_dict):
    for k, v in input_dict.items():
        if isinstance(input_dict[k], torch.Tensor):
            input_dict[k] = v.cuda(non_blocking=True)
        elif isinstance(v, list) and len(v) > 0:
            input_dict[k] = [ele.cuda(non_blocking=True) if isinstance(ele, torch.Tensor) else ele for ele in v]
    return input_dict
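A short usage sketch for intersectionAndUnionGPU above: the per-class IoU is intersection / union of the histograms it returns. Float tensors are used so torch.histc also works on CPU, and the prediction is cloned because the function writes ignore_index into it in place; the numbers below are made up for illustration.

import torch

pred = torch.tensor([[0., 1., 1., 0.]])    # predicted class per pixel (2 classes: 0 and 1)
target = torch.tensor([[0., 1., 0., 0.]])  # ground-truth class per pixel
inter, union, _ = intersectionAndUnionGPU(pred.clone(), target, K=2)
iou = inter / (union + 1e-10)
print(iou.tolist())                        # [0.666..., 0.5]: class 0 -> 2/3, class 1 -> 1/2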
projects/llava_sam2/configs/sa2va_4b.py
ADDED
@@ -0,0 +1,548 @@
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, IterTimerHook,
                            LoggerHook, ParamSchedulerHook)
from mmengine.optim import AmpOptimWrapper, CosineAnnealingLR, LinearLR
from torch.optim import AdamW
from transformers import AutoTokenizer

from xtuner.dataset import ConcatDataset
from xtuner.dataset.samplers import LengthGroupedSampler
from xtuner.engine.hooks import DatasetInfoHook
from xtuner.engine.runner import TrainLoop
from xtuner.utils import PROMPT_TEMPLATE
from xtuner.dataset.map_fns import template_map_fn_factory

from third_parts.mmdet.models.losses import DiceLoss, CrossEntropyLoss
from peft import LoraConfig

from projects.llava_sam2.models.internvl import InternVL_Slowfast

from projects.llava_sam2.models import VideoLLaVASAMModel, SAM2TrainRunner, VideoLLaVASAMModel_zero3
from projects.llava_sam2.datasets import VideoReVOSDataset, VideoMeVISDataset, VideoRefYoutubeVOSDataset, video_lisa_collate_fn, VideoSAM2Dataset
from projects.llava_sam2.datasets import VideoChatUniViDataset
from projects.llava_sam2.datasets import RefCOCOgGCGDataset, OpenPsgGCGDataset, FlickrGCGDataset, GranDfGCGDataset, OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset
from projects.llava_sam2.datasets import LLaVADataset
from projects.llava_sam2.datasets import ReferSegmDataset
from projects.llava_sam2.models.preprocess.image_resize import DirectResize

#######################################################################
#                          PART 1  Settings                           #
#######################################################################
# Model
path = './pretrained/InternVL2_5-4B'
pretrained_pth = None

# Data
prompt_template = PROMPT_TEMPLATE.phi3_chat
max_length = 8192

# Scheduler & Optimizer
batch_size = 2  # per_device
accumulative_counts = 4
dataloader_num_workers = 4
max_epochs = 1
optim_type = AdamW
# official 1024 -> 4e-5
# lr = 1e-6
lr = 4e-5
betas = (0.9, 0.999)
weight_decay = 0.05
max_norm = 1  # grad clip
warmup_ratio = 0.05

# Save
save_steps = 1000
save_total_limit = 2  # Maximum checkpoints to keep (-1 means unlimited)

special_tokens = ['[SEG]', '<p>', '</p>', '<vp>', '</vp>']

tokenizer = dict(
    type=AutoTokenizer.from_pretrained,
    pretrained_model_name_or_path=path,
    trust_remote_code=True,
    padding_side='right')

extra_image_processor = dict(
    type=DirectResize,
    target_length=1024,
)
#######################################################################
#            PART 2  Model & Tokenizer & Image Processor              #
#######################################################################
model = dict(
    type=VideoLLaVASAMModel_zero3,
    special_tokens=special_tokens,
    frozen_sam2_decoder=False,
    mllm=dict(
        type=InternVL_Slowfast,
        model_path=path,
        freeze_llm=True,
        freeze_visual_encoder=True,
        llm_lora=dict(
            type=LoraConfig,
            r=128,
            lora_alpha=256,
            lora_dropout=0.05,
            bias='none',
            task_type='CAUSAL_LM'),
        special_tokens=special_tokens,
    ),
    tokenizer=tokenizer,
    grounding_encoder=dict(
        type=SAM2TrainRunner,
    ),
    loss_mask=dict(
        type=CrossEntropyLoss,
        use_sigmoid=True,
        reduction='mean',
        loss_weight=2.0),
    loss_dice=dict(
        type=DiceLoss,
        use_sigmoid=True,
        activate=True,
        reduction='mean',
        naive_dice=True,
        eps=1.0,
        loss_weight=0.5),
    pretrained_pth=pretrained_pth,
    loss_sample_points=True,
    # loss_sample_points=False,
    bs=batch_size,
)

#######################################################################
#                   PART 3  Dataset & Dataloader                      #
#######################################################################


VIDEO_DATAS = './data/video_datas/'
IMG_DATAS = './data/image_datas/'

############### video res
data_root_revos = './data/video_datas/revos/'
video_revos_image_folder = data_root_revos
video_revos_expression_file = data_root_revos + 'meta_expressions_train_.json'
video_revos_mask_file = data_root_revos + 'mask_dict.json'

data_root_mevis = './data/video_datas/mevis/train/'
video_mevis_image_folder = data_root_mevis + 'JPEGImages'
video_mevis_expression_file = data_root_mevis + 'meta_expressions.json'
video_mevis_mask_file = data_root_mevis + 'mask_dict.json'

data_root_refytvos = './data/video_datas/rvos/'
video_refytvos_image_folder = data_root_refytvos + 'train/JPEGImages/'
video_refytvos_expression_file = data_root_refytvos + 'meta_expressions/train/meta_expressions.json'
video_refytvos_mask_file = data_root_refytvos + 'mask_dict.pkl'

video_revos_dataset = dict(
    type=VideoReVOSDataset,
    image_folder=video_revos_image_folder,
    expression_file=video_revos_expression_file,
    mask_file=video_revos_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=10,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

video_mevis_dataset = dict(
    type=VideoMeVISDataset,
    image_folder=video_mevis_image_folder,
    expression_file=video_mevis_expression_file,
    mask_file=video_mevis_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

video_refytvos_dataset = dict(
    type=VideoRefYoutubeVOSDataset,
    image_folder=video_refytvos_image_folder,
    expression_file=video_refytvos_expression_file,
    mask_file=video_refytvos_mask_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

################### Video chat
data_root_video_chatunivi = VIDEO_DATAS + 'video_vlm/video_chat/'
video_chatunivi_image_folder = data_root_video_chatunivi + 'Activity_Videos/'
video_chatunivi_json_file = data_root_video_chatunivi + 'video_chat.json'

video_qa_dataset = dict(
    type=VideoChatUniViDataset,
    image_folder=video_chatunivi_image_folder,
    json_file=video_chatunivi_json_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
)

################## image chat
llava_vqa_dataset = dict(
    type=LLaVADataset,
    tokenizer=tokenizer,
    data_path='data/llava_data/LLaVA-Instruct-150K/llava_v1_5_mix665k.json',
    prompt_template=prompt_template,
    special_tokens=special_tokens,
    image_folder='data/llava_data/llava_images/',
)

################## image res
refcoco_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcoco',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(unc).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
refcoco_plus_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcoco+',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(unc).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)
refcocog_segm_dataset = dict(
    type=ReferSegmDataset,
    tokenizer=tokenizer,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    data_root='data/ref_seg/refcocog',
    data_prefix=dict(img_path='coco2014/train2014/'),
    ann_file='instances.json',
    split_file='refs(umd).p',
    prompt_template=prompt_template,
    num_classes_per_sample=5,
    max_length=max_length,
)

# image gcg datas
glamm_data_root = './data/glamm_data/'

refcocog_image_path = glamm_data_root + 'images/coco2014/train2014/'
refcocog_ann_file = glamm_data_root + 'annotations/RefCOCOg_GCG_train.json'

grandf_image_path = glamm_data_root + 'images/grandf/train/'
grandf_ann_file = glamm_data_root + 'annotations/GranDf_HA_GCG_train.json'

flickr_image_path = glamm_data_root + 'images/flickr30k/Flickr30K/'
flickr_ann_file = glamm_data_root + 'annotations/flickr_mergedGT_GCG_train.json'

psg_image_path = glamm_data_root + 'images/coco2017/'
psg_ann_file = glamm_data_root + 'annotations/OpenPsgGCG_train.json'

glamm_refcocog_dataset = dict(
    type=RefCOCOgGCGDataset,
    image_folder=refcocog_image_path,
    data_path=refcocog_ann_file,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

glamm_grandf_dataset = dict(
    type=GranDfGCGDataset,
    data_path=grandf_ann_file,
    image_folder=grandf_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=10,
)

glamm_psg_dataset = dict(
    type=OpenPsgGCGDataset,
    data_path=psg_ann_file,
    image_folder=psg_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

glamm_flickr_dataset = dict(
    type=FlickrGCGDataset,
    data_path=flickr_ann_file,
    image_folder=flickr_image_path,
    tokenizer=tokenizer,
    max_length=max_length,
    special_tokens=special_tokens,
    template_map_fn=dict(type=template_map_fn_factory, template=prompt_template),
    extra_image_processor=extra_image_processor,
    lazy=True,
    repeats=1,
)

# sam2 data
data_sam2_folder = VIDEO_DATAS + 'segmentation_datasets/sam_v_full/'
data_sam2_expression_file = './whole_pesudo_cap_v3/sam_v_final_v3.json'

video_sam2_dataset = dict(
    type=VideoSAM2Dataset,
    sam2_folder=data_sam2_folder,
    expression_file=data_sam2_expression_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=4,
    special_tokens=special_tokens,
    extra_image_processor=extra_image_processor,
    sampled_frames=5,
    select_number=5,
)

# osprey
data_osprey_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_conversation.json'
data_osprey_image_folders = [
    IMG_DATAS + 'coco/train2014/',
    IMG_DATAS + 'coco/val2014/',
    IMG_DATAS + 'coco/train2017/',
    IMG_DATAS + 'coco/val2017/',
]

image_osprey_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_detail_description_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_detail_description.json'
image_osprey_description_dataset = dict(
    type=OspreyDescriptionDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_detail_description_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_short_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_short_form.json'
image_osprey_short_dataset = dict(
    type=OspreyShortDescriptionDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_short_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_part_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_part_level.json'
image_osprey_part_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_part_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

data_osprey_positive_neg_file = VIDEO_DATAS + 'osprey-724k/Osprey-724K/osprey_lvis_positive_negative.json'
image_osprey_positive_neg_dataset = dict(
    type=OspreyDataset,
    image_folder=data_osprey_image_folders,
    data_path=data_osprey_positive_neg_file,
    tokenizer=tokenizer,
    template_map_fn=dict(
        type=template_map_fn_factory, template=prompt_template),
    max_length=max_length,
    lazy=True,
    repeats=1,
    special_tokens=special_tokens,
)

train_dataset = dict(
    type=ConcatDataset, datasets=[
        # sem seg
        # semantic_seg_ade20k_dataset,
        # ref seg
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        refcoco_segm_dataset, refcoco_plus_segm_dataset, refcocog_segm_dataset,
        # image qa
        llava_vqa_dataset,
        # video res
        video_mevis_dataset, video_revos_dataset, video_refytvos_dataset,
        # video chat
        video_qa_dataset,
        # sam2 pseudo-labelled data
        video_sam2_dataset,
        # gcg data
        glamm_psg_dataset,
        glamm_grandf_dataset,
        glamm_flickr_dataset,
        glamm_refcocog_dataset,
        # visual prompt
        image_osprey_dataset, image_osprey_description_dataset,
        image_osprey_part_dataset, image_osprey_short_dataset,
        image_osprey_positive_neg_dataset,
    ]
)
train_dataloader = dict(
    batch_size=batch_size,
    num_workers=dataloader_num_workers,
    dataset=train_dataset,
    sampler=dict(
        type=LengthGroupedSampler,
        length_property='modality_length',
        per_device_batch_size=batch_size * accumulative_counts),
    collate_fn=dict(type=video_lisa_collate_fn)
)

#######################################################################
#                    PART 4  Scheduler & Optimizer                    #
#######################################################################
# optimizer
optim_wrapper = dict(
    type=AmpOptimWrapper,
    optimizer=dict(
        type=optim_type, lr=lr, betas=betas, weight_decay=weight_decay),
    clip_grad=dict(max_norm=max_norm, error_if_nonfinite=False),
    accumulative_counts=accumulative_counts,
    loss_scale='dynamic',
    dtype='bfloat16'
)

# learning policy
# More information: https://github.com/open-mmlab/mmengine/blob/main/docs/en/tutorials/param_scheduler.md  # noqa: E501
param_scheduler = [
    dict(
        type=LinearLR,
        start_factor=1e-5,
        by_epoch=True,
        begin=0,
        end=warmup_ratio * max_epochs,
        convert_to_iter_based=True),
    dict(
        type=CosineAnnealingLR,
        eta_min=0.0,
        by_epoch=True,
        begin=warmup_ratio * max_epochs,
        end=max_epochs,
        convert_to_iter_based=True)
]

# train, val, test setting
train_cfg = dict(type=TrainLoop, max_epochs=max_epochs)

#######################################################################
#                           PART 5  Runtime                           #
#######################################################################
# Log the dialogue periodically during the training process, optional
custom_hooks = [
    # dict(type=DatasetInfoHook, tokenizer=tokenizer),
]

# configure default hooks
default_hooks = dict(
    # record the time of every iteration.
|
506 |
+
timer=dict(type=IterTimerHook),
|
507 |
+
# print log every 10 iterations.
|
508 |
+
logger=dict(type=LoggerHook, log_metric_by_epoch=False, interval=10),
|
509 |
+
# enable the parameter scheduler.
|
510 |
+
param_scheduler=dict(type=ParamSchedulerHook),
|
511 |
+
# save checkpoint per `save_steps`.
|
512 |
+
checkpoint=dict(
|
513 |
+
type=CheckpointHook,
|
514 |
+
save_optimizer=False,
|
515 |
+
by_epoch=False,
|
516 |
+
interval=save_steps,
|
517 |
+
max_keep_ckpts=save_total_limit),
|
518 |
+
# set sampler seed in distributed evrionment.
|
519 |
+
sampler_seed=dict(type=DistSamplerSeedHook),
|
520 |
+
)
|
521 |
+
|
522 |
+
# configure environment
|
523 |
+
env_cfg = dict(
|
524 |
+
# whether to enable cudnn benchmark
|
525 |
+
cudnn_benchmark=False,
|
526 |
+
# set multi process parameters
|
527 |
+
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
528 |
+
# set distributed parameters
|
529 |
+
dist_cfg=dict(backend='nccl'),
|
530 |
+
)
|
531 |
+
|
532 |
+
# set visualizer
|
533 |
+
visualizer = None
|
534 |
+
|
535 |
+
# set log level
|
536 |
+
log_level = 'INFO'
|
537 |
+
|
538 |
+
# load from which checkpoint
|
539 |
+
load_from = None
|
540 |
+
|
541 |
+
# whether to resume training from the loaded checkpoint
|
542 |
+
resume = False
|
543 |
+
|
544 |
+
# Defaults to use random seed and disable `deterministic`
|
545 |
+
randomness = dict(seed=None, deterministic=False)
|
546 |
+
|
547 |
+
# set log processor
|
548 |
+
log_processor = dict(by_epoch=False)
|
projects/llava_sam2/datasets/ChatUniVi_Dataset.py
ADDED
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from typing import Literal
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from datasets import Dataset as HFDataset
|
7 |
+
from datasets import DatasetDict, load_from_disk
|
8 |
+
from mmengine import print_log
|
9 |
+
from PIL import Image
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from xtuner.registry import BUILDER
|
14 |
+
from xtuner.dataset.huggingface import build_origin_dataset
|
15 |
+
import copy
|
16 |
+
from .encode_fn import video_lisa_encode_fn
|
17 |
+
import json
|
18 |
+
import cv2
|
19 |
+
import torchvision.transforms as T
|
20 |
+
from torchvision.transforms.functional import InterpolationMode
|
21 |
+
from decord import VideoReader, cpu
|
22 |
+
|
23 |
+
|
24 |
+
def _get_rawvideo_dec(video_path, select_frames=5):
|
25 |
+
|
26 |
+
if os.path.exists(video_path):
|
27 |
+
vreader = VideoReader(video_path, ctx=cpu(0))
|
28 |
+
elif os.path.exists(video_path.replace('mkv', 'mp4')):
|
29 |
+
vreader = VideoReader(video_path.replace('mkv', 'mp4'), ctx=cpu(0))
|
30 |
+
else:
|
31 |
+
print(video_path)
|
32 |
+
raise FileNotFoundError
|
33 |
+
|
34 |
+
fps = vreader.get_avg_fps()
|
35 |
+
f_start = 0
|
36 |
+
f_end = len(vreader) - 1
|
37 |
+
num_frames = f_end - f_start + 1
|
38 |
+
assert num_frames > 0, f'num_frames: {num_frames}, f_start: {f_start}, f_end: {f_end}, fps: {fps}, video_path: {video_path}'
|
39 |
+
# T x 3 x H x W
|
40 |
+
if num_frames <= select_frames:
|
41 |
+
sample_pos = range(f_start, f_end + 1)
|
42 |
+
else:
|
43 |
+
split_point = np.linspace(0, num_frames, num=select_frames+1, dtype=int)
|
44 |
+
sample_pos = [np.random.randint(split_point[i], split_point[i+1]) for i in range(select_frames)]
|
45 |
+
patch_images = [Image.fromarray(f) for f in vreader.get_batch(sample_pos).asnumpy()]
|
46 |
+
return patch_images
|
47 |
+
|
48 |
+
|
49 |
+
class VideoChatUniViDataset(Dataset):
|
50 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
51 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
52 |
+
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
|
53 |
+
IMG_START_TOKEN = '<img>'
|
54 |
+
IMG_END_TOKEN = '</img>'
|
55 |
+
|
56 |
+
FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
|
57 |
+
FAST_IMG_START_TOKEN = '<fast_img>'
|
58 |
+
FAST_IMG_END_TOKEN = '</fast_img>'
|
59 |
+
|
60 |
+
def __init__(self,
|
61 |
+
image_folder,
|
62 |
+
json_file,
|
63 |
+
extra_image_processor=None,
|
64 |
+
tokenizer=None,
|
65 |
+
sampled_frames=10,
|
66 |
+
offline_processed_text_folder=None,
|
67 |
+
template_map_fn=None,
|
68 |
+
max_length=2048,
|
69 |
+
lazy=True,
|
70 |
+
repeats=1,
|
71 |
+
special_tokens=None,
|
72 |
+
use_fast=False,
|
73 |
+
n_fast_images=50,
|
74 |
+
fast_pool_size=4,
|
75 |
+
arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
|
76 |
+
preprocessor=None,
|
77 |
+
):
|
78 |
+
assert lazy is True
|
79 |
+
self.tokenizer = BUILDER.build(tokenizer)
|
80 |
+
self.sampled_frames = sampled_frames
|
81 |
+
assert offline_processed_text_folder or (json_file and tokenizer)
|
82 |
+
self.lazy = lazy
|
83 |
+
|
84 |
+
self.max_length = max_length
|
85 |
+
|
86 |
+
self.template_map_fn = template_map_fn
|
87 |
+
if isinstance(self.template_map_fn, dict) and self.lazy:
|
88 |
+
_type = self.template_map_fn['type']
|
89 |
+
del self.template_map_fn['type']
|
90 |
+
self.template_map_fn = _type(**self.template_map_fn)
|
91 |
+
|
92 |
+
if offline_processed_text_folder and json_file:
|
93 |
+
print_log(
|
94 |
+
'Both `offline_processed_text_folder` and '
|
95 |
+
'`data_path` are set, and we load dataset from'
|
96 |
+
'`offline_processed_text_folder` '
|
97 |
+
f'({offline_processed_text_folder})',
|
98 |
+
logger='current',
|
99 |
+
level=logging.WARNING)
|
100 |
+
|
101 |
+
if offline_processed_text_folder is not None:
|
102 |
+
raise NotImplementedError
|
103 |
+
else:
|
104 |
+
json_datas = self.json_file_preprocess(json_file)
|
105 |
+
self.json_datas = json_datas
|
106 |
+
json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
|
107 |
+
if self.lazy:
|
108 |
+
self.text_data = build_origin_dataset(json_data, 'train')
|
109 |
+
else:
|
110 |
+
raise NotImplementedError
|
111 |
+
|
112 |
+
self.image_folder = image_folder
|
113 |
+
if extra_image_processor is not None:
|
114 |
+
self.extra_image_processor = BUILDER.build(extra_image_processor)
|
115 |
+
|
116 |
+
self.arch_type = arch_type
|
117 |
+
if self.arch_type == 'qwen':
|
118 |
+
self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
|
119 |
+
self.IMG_START_TOKEN = '<|vision_start|>'
|
120 |
+
self.IMG_END_TOKEN = '<|vision_end|>'
|
121 |
+
elif self.arch_type == 'llava':
|
122 |
+
self.IMG_CONTEXT_TOKEN = '<image>'
|
123 |
+
self.IMG_START_TOKEN = ''
|
124 |
+
self.IMG_END_TOKEN = ''
|
125 |
+
self.repeats = repeats
|
126 |
+
|
127 |
+
self._system = ''
|
128 |
+
|
129 |
+
self.downsample_ratio = 0.5
|
130 |
+
if self.arch_type == 'llava':
|
131 |
+
self.downsample_ratio = 1
|
132 |
+
self.image_size = 448
|
133 |
+
if self.arch_type == 'llava':
|
134 |
+
self.image_size = 336
|
135 |
+
patch_size = 14
|
136 |
+
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
|
137 |
+
if self.arch_type == 'qwen':
|
138 |
+
self.patch_token = 1
|
139 |
+
|
140 |
+
if preprocessor is None:
|
141 |
+
self.transformer = T.Compose([
|
142 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
143 |
+
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
|
144 |
+
T.ToTensor(),
|
145 |
+
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
|
146 |
+
])
|
147 |
+
self.preprocessor = None
|
148 |
+
else:
|
149 |
+
self.transformer = None
|
150 |
+
self.preprocessor = BUILDER.build(preprocessor)
|
151 |
+
|
152 |
+
self.arch_type = arch_type
|
153 |
+
|
154 |
+
if special_tokens is not None:
|
155 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
156 |
+
|
157 |
+
self.use_fast = use_fast
|
158 |
+
self.n_fast_images = n_fast_images
|
159 |
+
self.fast_pool_size = fast_pool_size
|
160 |
+
|
161 |
+
# for visualization debug
|
162 |
+
self.save_folder = './work_dirs/video_debug/'
|
163 |
+
self.cur_number = 0
|
164 |
+
|
165 |
+
print("Video Chat dataset, include {} items.".format(len(self.text_data)))
|
166 |
+
|
167 |
+
def __len__(self):
|
168 |
+
return len(self.text_data) * self.repeats
|
169 |
+
|
170 |
+
@property
|
171 |
+
def modality_length(self):
|
172 |
+
length_list = []
|
173 |
+
for data_dict in self.text_data:
|
174 |
+
cur_len = 10000
|
175 |
+
length_list.append(cur_len)
|
176 |
+
return length_list
|
177 |
+
|
178 |
+
def real_len(self):
|
179 |
+
return len(self.text_data)
|
180 |
+
|
181 |
+
def json_file_preprocess(self, json_file):
|
182 |
+
# prepare expression annotation files
|
183 |
+
with open(json_file, 'r') as f:
|
184 |
+
json_datas = json.load(f)
|
185 |
+
return json_datas
|
186 |
+
|
187 |
+
def dataset_map_fn(self, data_dict, select_k=5):
|
188 |
+
assert 'video' in data_dict
|
189 |
+
# video
|
190 |
+
video_file = data_dict['video']
|
191 |
+
video_file = os.path.join(self.image_folder, video_file)
|
192 |
+
images = _get_rawvideo_dec(video_file, select_frames=select_k)
|
193 |
+
if self.use_fast:
|
194 |
+
fast_images = _get_rawvideo_dec(video_file, select_frames=self.n_fast_images)
|
195 |
+
else:
|
196 |
+
fast_images = None
|
197 |
+
|
198 |
+
conversation = data_dict['conversations']
|
199 |
+
|
200 |
+
# prepare text
|
201 |
+
if self.use_fast:
|
202 |
+
text_dict = self.prepare_text(
|
203 |
+
select_k, conversation, num_image_tokens=self.patch_token,
|
204 |
+
n_fast_images=len(fast_images),
|
205 |
+
)
|
206 |
+
else:
|
207 |
+
text_dict = self.prepare_text(
|
208 |
+
select_k, conversation, num_image_tokens=self.patch_token,
|
209 |
+
)
|
210 |
+
|
211 |
+
|
212 |
+
ret = {'images': images, 'conversation': text_dict['conversation'], 'fast_images': fast_images}
|
213 |
+
return ret
|
214 |
+
|
215 |
+
def prepare_text(self, n_frames, conversation, num_image_tokens=256, n_fast_images=0):
|
216 |
+
|
217 |
+
if self.use_fast:
|
218 |
+
fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
|
219 |
+
f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
|
220 |
+
f'{self.FAST_IMG_END_TOKEN}' + '\n'
|
221 |
+
else:
|
222 |
+
fast_frame_token_str = ''
|
223 |
+
|
224 |
+
frame_token_str = f'{self.IMG_START_TOKEN}' \
|
225 |
+
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
|
226 |
+
f'{self.IMG_END_TOKEN}'
|
227 |
+
|
228 |
+
questions = []
|
229 |
+
answers = []
|
230 |
+
|
231 |
+
for conv in conversation:
|
232 |
+
if conv['from'] == 'human':
|
233 |
+
questions.append(conv['value'].replace('<image>', ''))
|
234 |
+
else:
|
235 |
+
answers.append(conv['value'])
|
236 |
+
assert len(questions) == len(answers)
|
237 |
+
|
238 |
+
qa_list = []
|
239 |
+
for i, (question, answer) in enumerate(zip(questions, answers)):
|
240 |
+
if i == 0:
|
241 |
+
frame_tokens = frame_token_str + '\n'
|
242 |
+
# frame_tokens = '=' + ' '
|
243 |
+
frame_tokens = frame_tokens * n_frames
|
244 |
+
frame_tokens = frame_tokens.strip()
|
245 |
+
frame_tokens = fast_frame_token_str + frame_tokens
|
246 |
+
qa_list.append(
|
247 |
+
{'from': 'human', 'value': frame_tokens + question}
|
248 |
+
)
|
249 |
+
else:
|
250 |
+
qa_list.append(
|
251 |
+
{'from': 'human', 'value': question}
|
252 |
+
)
|
253 |
+
qa_list.append(
|
254 |
+
{'from': 'gpt', 'value': answer}
|
255 |
+
)
|
256 |
+
|
257 |
+
input = ''
|
258 |
+
conversation = []
|
259 |
+
for msg in qa_list:
|
260 |
+
if msg['from'] == 'human':
|
261 |
+
input += msg['value']
|
262 |
+
elif msg['from'] == 'gpt':
|
263 |
+
conversation.append({'input': input, 'output': msg['value']})
|
264 |
+
input = ''
|
265 |
+
else:
|
266 |
+
raise NotImplementedError
|
267 |
+
|
268 |
+
# add system information
|
269 |
+
conversation[0].update({'system': self._system})
|
270 |
+
return {'conversation': conversation}
|
271 |
+
|
272 |
+
def __getitem__(self, index):
|
273 |
+
index = index % self.real_len()
|
274 |
+
selected_data_dict = copy.deepcopy(self.text_data[index])
|
275 |
+
data_dict = self.dataset_map_fn(selected_data_dict, select_k=self.sampled_frames)
|
276 |
+
|
277 |
+
|
278 |
+
assert 'images' in data_dict.keys()
|
279 |
+
if self.use_fast:
|
280 |
+
assert 'fast_images' in data_dict.keys()
|
281 |
+
pixel_values = []
|
282 |
+
num_video_tokens = None
|
283 |
+
num_frame_tokens = None
|
284 |
+
if data_dict.get('images', None) is not None:
|
285 |
+
frames_files = data_dict['images']
|
286 |
+
for frame_image in frames_files:
|
287 |
+
frame_image = frame_image.convert('RGB')
|
288 |
+
ori_width, ori_height = frame_image.size
|
289 |
+
|
290 |
+
if self.preprocessor is not None:
|
291 |
+
pass
|
292 |
+
else:
|
293 |
+
frame_image = self.transformer(frame_image)
|
294 |
+
pixel_values.append(frame_image)
|
295 |
+
|
296 |
+
if self.preprocessor is not None:
|
297 |
+
if self.arch_type == 'qwen':
|
298 |
+
_data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
|
299 |
+
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
|
300 |
+
_data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
|
301 |
+
num_frame_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
|
302 |
+
num_frames = _data_dict['image_grid_thw'].shape[0]
|
303 |
+
num_video_tokens = num_frame_tokens * num_frames
|
304 |
+
elif self.arch_type == 'llava':
|
305 |
+
_data_dict = self.preprocessor(pixel_values, do_resize=True,
|
306 |
+
size=(self.image_size, self.image_size))
|
307 |
+
_data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
|
308 |
+
_data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
|
309 |
+
else:
|
310 |
+
raise NotImplementedError
|
311 |
+
data_dict.update(_data_dict)
|
312 |
+
else:
|
313 |
+
pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
|
314 |
+
data_dict['pixel_values'] = pixel_values
|
315 |
+
else:
|
316 |
+
data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
|
317 |
+
data_dict['masks'] = None
|
318 |
+
|
319 |
+
if num_video_tokens is not None:
|
320 |
+
assert self.patch_token == 1
|
321 |
+
input_str = data_dict['conversation'][0]['input']
|
322 |
+
input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
|
323 |
+
assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
|
324 |
+
data_dict['conversation'][0]['input'] = input_str
|
325 |
+
|
326 |
+
result = self.template_map_fn(data_dict)
|
327 |
+
data_dict.update(result)
|
328 |
+
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
|
329 |
+
data_dict.update(result)
|
330 |
+
|
331 |
+
# for fast branch
|
332 |
+
if self.use_fast:
|
333 |
+
fast_pixel_values = []
|
334 |
+
frames_files = data_dict['fast_images']
|
335 |
+
for frame_image in frames_files:
|
336 |
+
frame_image = frame_image.convert('RGB')
|
337 |
+
ori_width, ori_height = frame_image.size
|
338 |
+
|
339 |
+
frame_image = self.transformer(frame_image)
|
340 |
+
fast_pixel_values.append(frame_image)
|
341 |
+
|
342 |
+
fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
|
343 |
+
data_dict['fast_pixel_values'] = fast_pixel_values
|
344 |
+
|
345 |
+
|
346 |
+
# # for debug
|
347 |
+
# self.visualization_debug(data_dict)
|
348 |
+
# if self.cur_number < 10:
|
349 |
+
# return self[random.randint(0, len(self))]
|
350 |
+
|
351 |
+
data_dict['type'] = 'video'
|
352 |
+
return data_dict
|
353 |
+
|
354 |
+
def visualization_debug(self, data_dict):
|
355 |
+
save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
|
356 |
+
if not os.path.exists(save_folder):
|
357 |
+
os.mkdir(save_folder)
|
358 |
+
self.cur_number += 1
|
359 |
+
|
360 |
+
# images
|
361 |
+
|
362 |
+
show_images = []
|
363 |
+
|
364 |
+
pixel_values = data_dict['pixel_values']
|
365 |
+
save_folder_image = os.path.join(save_folder, 'image')
|
366 |
+
if not os.path.exists(save_folder_image):
|
367 |
+
os.mkdir(save_folder_image)
|
368 |
+
for i_image, image_pixel_value in enumerate(pixel_values):
|
369 |
+
# print(image_pixel_value.shape)
|
370 |
+
image_pixel_value[0] = image_pixel_value[0] * 0.2686
|
371 |
+
image_pixel_value[1] = image_pixel_value[1] * 0.2613
|
372 |
+
image_pixel_value[2] = image_pixel_value[2] * 0.2757
|
373 |
+
image_pixel_value[0] = image_pixel_value[0] + 0.4814
|
374 |
+
image_pixel_value[1] = image_pixel_value[1] + 0.4578
|
375 |
+
image_pixel_value[2] = image_pixel_value[2] + 0.4082
|
376 |
+
image_pixel_value = image_pixel_value * 255
|
377 |
+
image_pixel_value = image_pixel_value.permute(1, 2, 0)
|
378 |
+
image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
|
379 |
+
# print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
|
380 |
+
# print(image_pixel_value.shape)
|
381 |
+
show_images.append(image_pixel_value)
|
382 |
+
cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
|
383 |
+
|
384 |
+
# text
|
385 |
+
input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
|
386 |
+
with open(os.path.join(save_folder, 'text.json'), 'w') as f:
|
387 |
+
json.dump([input_text], f)
|
388 |
+
|
389 |
+
return
|
projects/llava_sam2/datasets/GCG_Dataset.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from datasets import Dataset as HFDataset
|
6 |
+
from datasets import DatasetDict, load_from_disk
|
7 |
+
from PIL import Image
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
from pycocotools import mask
|
10 |
+
import numpy as np
|
11 |
+
import copy
|
12 |
+
|
13 |
+
from xtuner.registry import BUILDER
|
14 |
+
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
|
15 |
+
import torchvision.transforms as T
|
16 |
+
from xtuner.utils import DEFAULT_IMAGE_TOKEN
|
17 |
+
from torchvision.transforms.functional import InterpolationMode
|
18 |
+
from .encode_fn import video_lisa_encode_fn
|
19 |
+
from .utils import dynamic_preprocess
|
20 |
+
|
21 |
+
from .gcg_process import glamm_openpsg_map_fn, glamm_flickr_map_fn, glamm_granf_map_fn, glamm_refcocog_map_fn
|
22 |
+
|
23 |
+
class GCGDataset(Dataset):
|
24 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
25 |
+
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
|
26 |
+
IMG_START_TOKEN = '<img>'
|
27 |
+
IMG_END_TOKEN = '</img>'
|
28 |
+
|
29 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
30 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
31 |
+
def __init__(self,
|
32 |
+
image_folder,
|
33 |
+
data_path=None,
|
34 |
+
tokenizer=None,
|
35 |
+
max_length=8196,
|
36 |
+
special_tokens=None,
|
37 |
+
template_map_fn=None,
|
38 |
+
extra_image_processor=None,
|
39 |
+
lazy=True,
|
40 |
+
repeats=1,
|
41 |
+
single_image_mode=False,
|
42 |
+
):
|
43 |
+
super().__init__()
|
44 |
+
assert lazy
|
45 |
+
self.lazy = lazy
|
46 |
+
self.max_length = max_length
|
47 |
+
|
48 |
+
json_data = self.json_file_preprocess(data_path)
|
49 |
+
json_data = DatasetDict({'train': HFDataset.from_list(json_data)})
|
50 |
+
self.text_data = build_origin_dataset(json_data, 'train')
|
51 |
+
|
52 |
+
self.image_folder = image_folder
|
53 |
+
|
54 |
+
self.tokenizer = BUILDER.build(tokenizer)
|
55 |
+
if special_tokens is not None:
|
56 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
57 |
+
|
58 |
+
self.template_map_fn = template_map_fn
|
59 |
+
if isinstance(self.template_map_fn, dict) and self.lazy:
|
60 |
+
_type = self.template_map_fn['type']
|
61 |
+
del self.template_map_fn['type']
|
62 |
+
self.template_map_fn = _type(**self.template_map_fn)
|
63 |
+
|
64 |
+
if extra_image_processor is not None:
|
65 |
+
self.extra_image_processor = BUILDER.build(extra_image_processor)
|
66 |
+
|
67 |
+
self.repeats = repeats
|
68 |
+
|
69 |
+
self._system = ''
|
70 |
+
|
71 |
+
self.min_dynamic_patch = 1
|
72 |
+
self.max_dynamic_patch = 12
|
73 |
+
self.downsample_ratio = 0.5
|
74 |
+
self.image_size = 448
|
75 |
+
self.use_thumbnail = True
|
76 |
+
patch_size = 14
|
77 |
+
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
|
78 |
+
|
79 |
+
self.transformer = T.Compose([
|
80 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
81 |
+
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
|
82 |
+
T.ToTensor(),
|
83 |
+
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
|
84 |
+
])
|
85 |
+
|
86 |
+
if special_tokens is not None:
|
87 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
88 |
+
|
89 |
+
self.single_image_mode = single_image_mode
|
90 |
+
|
91 |
+
def json_file_preprocess(self, data_path):
|
92 |
+
with open(data_path, 'r') as f:
|
93 |
+
json_data = json.load(f)
|
94 |
+
return json_data
|
95 |
+
|
96 |
+
@property
|
97 |
+
def modality_length(self):
|
98 |
+
length_list = []
|
99 |
+
for data_dict in self.text_data:
|
100 |
+
if self.lazy:
|
101 |
+
cur_len = 100
|
102 |
+
else:
|
103 |
+
cur_len = len(data_dict['input_ids'])
|
104 |
+
if data_dict.get('image', None) is None:
|
105 |
+
cur_len = -cur_len
|
106 |
+
length_list.append(cur_len)
|
107 |
+
return length_list * self.repeats
|
108 |
+
|
109 |
+
def __len__(self):
|
110 |
+
return len(self.text_data) * self.repeats
|
111 |
+
|
112 |
+
def real_len(self):
|
113 |
+
return len(self.text_data)
|
114 |
+
|
115 |
+
def decode_mask(self, object_masks, ori_height, ori_width):
|
116 |
+
binary_masks = []
|
117 |
+
for object_mask in object_masks:
|
118 |
+
binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
|
119 |
+
for seg in object_mask:
|
120 |
+
rles = mask.frPyObjects([seg], ori_height, ori_width)
|
121 |
+
m = mask.decode(rles)
|
122 |
+
m = m.astype(np.uint8)
|
123 |
+
binary_mask += m.squeeze()
|
124 |
+
|
125 |
+
binary_masks.append(binary_mask)
|
126 |
+
if len(binary_masks) == 0:
|
127 |
+
return None
|
128 |
+
masks = np.stack(binary_masks, axis=0)
|
129 |
+
masks = torch.from_numpy(masks)
|
130 |
+
return masks
|
131 |
+
|
132 |
+
def dataset_map_fn(self, data_dict):
|
133 |
+
data_dict = glamm_refcocog_map_fn(data_dict)
|
134 |
+
return data_dict
|
135 |
+
|
136 |
+
def replace_image_str(self, data_dict, image_str):
|
137 |
+
data_dict['conversation'][0]['input'] = \
|
138 |
+
data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
|
139 |
+
return data_dict
|
140 |
+
|
141 |
+
def __getitem__(self, index):
|
142 |
+
|
143 |
+
index = index % self.real_len()
|
144 |
+
data_dict = copy.deepcopy(self.text_data[index])
|
145 |
+
|
146 |
+
# parse datasets
|
147 |
+
result = self.dataset_map_fn(data_dict)
|
148 |
+
data_dict.update(result)
|
149 |
+
|
150 |
+
# process image
|
151 |
+
image_file = data_dict['image']
|
152 |
+
image = Image.open(os.path.join(self.image_folder,
|
153 |
+
image_file)).convert('RGB')
|
154 |
+
ori_width, ori_height = image.size
|
155 |
+
if hasattr(self, 'extra_image_processor'):
|
156 |
+
g_image = np.array(image) # for grounding
|
157 |
+
g_image = self.extra_image_processor.apply_image(g_image)
|
158 |
+
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
|
159 |
+
data_dict['g_pixel_values'] = g_pixel_values
|
160 |
+
|
161 |
+
if self.single_image_mode:
|
162 |
+
images = [image]
|
163 |
+
else:
|
164 |
+
images = dynamic_preprocess(image, self.min_dynamic_patch,
|
165 |
+
self.max_dynamic_patch,
|
166 |
+
self.image_size, self.use_thumbnail)
|
167 |
+
pixel_values = [self.transformer(image) for image in images]
|
168 |
+
pixel_values = torch.stack(pixel_values)
|
169 |
+
data_dict['pixel_values'] = pixel_values
|
170 |
+
|
171 |
+
num_image_tokens = pixel_values.shape[0] * self.patch_token
|
172 |
+
image_token_str = f'{self.IMG_START_TOKEN}' \
|
173 |
+
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
|
174 |
+
f'{self.IMG_END_TOKEN}'
|
175 |
+
|
176 |
+
data_dict = self.replace_image_str(data_dict, image_token_str)
|
177 |
+
|
178 |
+
result = self.template_map_fn(data_dict)
|
179 |
+
data_dict.update(result)
|
180 |
+
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
|
181 |
+
with_image_token=True)
|
182 |
+
data_dict.update(result)
|
183 |
+
# process mask
|
184 |
+
data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)
|
185 |
+
|
186 |
+
if data_dict['masks'] is None:
|
187 |
+
return self.__getitem__(0)
|
188 |
+
|
189 |
+
return data_dict
|
190 |
+
|
191 |
+
class RefCOCOgGCGDataset(GCGDataset):
|
192 |
+
def __init__(self,
|
193 |
+
image_folder,
|
194 |
+
data_path=None,
|
195 |
+
tokenizer=None,
|
196 |
+
max_length=8196,
|
197 |
+
special_tokens=None,
|
198 |
+
template_map_fn=None,
|
199 |
+
extra_image_processor=None,
|
200 |
+
lazy=True,
|
201 |
+
repeats=1,
|
202 |
+
single_image_mode=False,
|
203 |
+
):
|
204 |
+
super().__init__(
|
205 |
+
image_folder=image_folder,
|
206 |
+
data_path=data_path,
|
207 |
+
tokenizer=tokenizer,
|
208 |
+
max_length=max_length,
|
209 |
+
special_tokens=special_tokens,
|
210 |
+
template_map_fn=template_map_fn,
|
211 |
+
extra_image_processor=extra_image_processor,
|
212 |
+
lazy=lazy,
|
213 |
+
repeats=repeats,
|
214 |
+
single_image_mode=single_image_mode,
|
215 |
+
)
|
216 |
+
|
217 |
+
def json_file_preprocess(self, data_path):
|
218 |
+
json_data = json.load(open(data_path))
|
219 |
+
|
220 |
+
# convert {id: dict} to dict(..., id=xx)
|
221 |
+
for idx in range(len(json_data)):
|
222 |
+
id = list(json_data[idx].keys())[0]
|
223 |
+
json_data[idx] = json_data[idx][id]
|
224 |
+
json_data[idx].update({'id': id})
|
225 |
+
return json_data
|
226 |
+
|
227 |
+
class GranDfGCGDataset(GCGDataset):
|
228 |
+
def __init__(self,
|
229 |
+
image_folder,
|
230 |
+
data_path=None,
|
231 |
+
tokenizer=None,
|
232 |
+
max_length=8196,
|
233 |
+
special_tokens=None,
|
234 |
+
template_map_fn=None,
|
235 |
+
extra_image_processor=None,
|
236 |
+
lazy=True,
|
237 |
+
repeats=1,
|
238 |
+
single_image_mode=False,
|
239 |
+
):
|
240 |
+
super().__init__(
|
241 |
+
image_folder=image_folder,
|
242 |
+
data_path=data_path,
|
243 |
+
tokenizer=tokenizer,
|
244 |
+
max_length=max_length,
|
245 |
+
special_tokens=special_tokens,
|
246 |
+
template_map_fn=template_map_fn,
|
247 |
+
extra_image_processor=extra_image_processor,
|
248 |
+
lazy=lazy,
|
249 |
+
repeats=repeats,
|
250 |
+
single_image_mode=single_image_mode,
|
251 |
+
)
|
252 |
+
|
253 |
+
def dataset_map_fn(self, data_dict):
|
254 |
+
data_dict = glamm_granf_map_fn(data_dict)
|
255 |
+
return data_dict
|
256 |
+
|
257 |
+
def decode_mask(self, object_masks, ori_height, ori_width):
|
258 |
+
binary_masks = []
|
259 |
+
for object_mask in object_masks:
|
260 |
+
binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
|
261 |
+
|
262 |
+
for rle in object_mask:
|
263 |
+
m = mask.decode(rle).astype(np.uint8)
|
264 |
+
binary_mask += m.squeeze()
|
265 |
+
|
266 |
+
binary_masks.append(binary_mask)
|
267 |
+
if len(binary_masks) == 0:
|
268 |
+
return None
|
269 |
+
masks = np.stack(binary_masks, axis=0)
|
270 |
+
masks = torch.from_numpy(masks)
|
271 |
+
return masks
|
272 |
+
|
273 |
+
class OpenPsgGCGDataset(GranDfGCGDataset):
|
274 |
+
def __init__(self,
|
275 |
+
image_folder,
|
276 |
+
data_path=None,
|
277 |
+
tokenizer=None,
|
278 |
+
max_length=8196,
|
279 |
+
special_tokens=None,
|
280 |
+
template_map_fn=None,
|
281 |
+
extra_image_processor=None,
|
282 |
+
lazy=True,
|
283 |
+
repeats=1,
|
284 |
+
single_image_mode=False,
|
285 |
+
):
|
286 |
+
super().__init__(
|
287 |
+
image_folder=image_folder,
|
288 |
+
data_path=data_path,
|
289 |
+
tokenizer=tokenizer,
|
290 |
+
max_length=max_length,
|
291 |
+
special_tokens=special_tokens,
|
292 |
+
template_map_fn=template_map_fn,
|
293 |
+
extra_image_processor=extra_image_processor,
|
294 |
+
lazy=lazy,
|
295 |
+
repeats=repeats,
|
296 |
+
single_image_mode=single_image_mode,
|
297 |
+
)
|
298 |
+
def dataset_map_fn(self, data_dict):
|
299 |
+
data_dict = glamm_openpsg_map_fn(data_dict)
|
300 |
+
return data_dict
|
301 |
+
|
302 |
+
|
303 |
+
class FlickrGCGDataset(GCGDataset):
|
304 |
+
def __init__(self,
|
305 |
+
image_folder,
|
306 |
+
data_path=None,
|
307 |
+
tokenizer=None,
|
308 |
+
max_length=8196,
|
309 |
+
special_tokens=None,
|
310 |
+
template_map_fn=None,
|
311 |
+
extra_image_processor=None,
|
312 |
+
lazy=True,
|
313 |
+
repeats=1,
|
314 |
+
single_image_mode=False,
|
315 |
+
):
|
316 |
+
super().__init__(
|
317 |
+
image_folder=image_folder,
|
318 |
+
data_path=data_path,
|
319 |
+
tokenizer=tokenizer,
|
320 |
+
max_length=max_length,
|
321 |
+
special_tokens=special_tokens,
|
322 |
+
template_map_fn=template_map_fn,
|
323 |
+
extra_image_processor=extra_image_processor,
|
324 |
+
lazy=lazy,
|
325 |
+
repeats=repeats,
|
326 |
+
single_image_mode=single_image_mode,
|
327 |
+
)
|
328 |
+
|
329 |
+
def dataset_map_fn(self, data_dict):
|
330 |
+
data_dict = glamm_flickr_map_fn(data_dict)
|
331 |
+
return data_dict
|
332 |
+
|
333 |
+
def json_file_preprocess(self, data_path):
|
334 |
+
def filter_images(data_infos, min_size):
|
335 |
+
return [i for i, info in enumerate(data_infos) if min(info['width'], info['height']) >= min_size]
|
336 |
+
|
337 |
+
# convert {id: dict} to dict(..., id=xx)
|
338 |
+
from pycocotools.coco import COCO
|
339 |
+
self.coco = COCO(data_path)
|
340 |
+
self.image_ids = self.coco.getImgIds()
|
341 |
+
data_infos = []
|
342 |
+
total_ann_ids = []
|
343 |
+
removed_img_count = 0
|
344 |
+
for img_id in self.image_ids:
|
345 |
+
info = self.coco.loadImgs([img_id])[0]
|
346 |
+
if len(info['caption'].split(' ')) < 3:
|
347 |
+
removed_img_count += 1
|
348 |
+
continue
|
349 |
+
info['filename'] = info['file_name'].split('_')[-1]
|
350 |
+
info['height'] = int(info['height'])
|
351 |
+
info['width'] = int(info['width'])
|
352 |
+
data_infos.append(info)
|
353 |
+
ann_ids = self.coco.getAnnIds(imgIds=[img_id])
|
354 |
+
total_ann_ids.extend(ann_ids)
|
355 |
+
assert len(set(total_ann_ids)) == len(total_ann_ids), f"Non-unique annotation IDs in '{data_path}'!"
|
356 |
+
print(f'Removed {removed_img_count} images.')
|
357 |
+
data_infos = [data_infos[i] for i in filter_images(data_infos, min_size=32)]
|
358 |
+
|
359 |
+
# obtain_annotations
|
360 |
+
for data_info in data_infos:
|
361 |
+
ann_ids = self.coco.getAnnIds(imgIds=data_info['id'])
|
362 |
+
ann_info = self.coco.loadAnns(ann_ids)
|
363 |
+
data_info.update({'ann_info': ann_info})
|
364 |
+
return data_infos
|
365 |
+
|
366 |
+
def decode_mask(self, object_masks, ori_height, ori_width):
|
367 |
+
binary_masks = []
|
368 |
+
for object_mask in object_masks:
|
369 |
+
binary_mask = mask.decode(object_mask).astype(np.uint8)
|
370 |
+
binary_masks.append(binary_mask)
|
371 |
+
if len(binary_masks) == 0:
|
372 |
+
return None
|
373 |
+
masks = np.stack(binary_masks, axis=0)
|
374 |
+
masks = torch.from_numpy(masks)
|
375 |
+
return masks
|
projects/llava_sam2/datasets/Grand_Dataset.py
ADDED
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from datasets import Dataset as HFDataset
|
7 |
+
from datasets import DatasetDict, load_from_disk
|
8 |
+
from PIL import Image
|
9 |
+
from torch.utils.data import Dataset
|
10 |
+
from pycocotools import mask
|
11 |
+
import numpy as np
|
12 |
+
import copy
|
13 |
+
|
14 |
+
from xtuner.registry import BUILDER
|
15 |
+
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
|
16 |
+
import torchvision.transforms as T
|
17 |
+
from xtuner.utils import DEFAULT_IMAGE_TOKEN
|
18 |
+
from torchvision.transforms.functional import InterpolationMode
|
19 |
+
from .encode_fn import video_lisa_encode_fn
|
20 |
+
from .utils import dynamic_preprocess
|
21 |
+
|
22 |
+
from .grand_process import glamm_grand_map_fn
|
23 |
+
|
24 |
+
class GranDDataset(Dataset):
|
25 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
26 |
+
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
|
27 |
+
IMG_START_TOKEN = '<img>'
|
28 |
+
IMG_END_TOKEN = '</img>'
|
29 |
+
|
30 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
31 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
32 |
+
def __init__(self,
|
33 |
+
image_folder,
|
34 |
+
json_folder=None,
|
35 |
+
tokenizer=None,
|
36 |
+
max_length=8196,
|
37 |
+
special_tokens=None,
|
38 |
+
template_map_fn=None,
|
39 |
+
extra_image_processor=None,
|
40 |
+
lazy=True,
|
41 |
+
repeats=1,
|
42 |
+
single_image_mode=False,
|
43 |
+
image_list_save_path='./work_dirs/grand_image.json',
|
44 |
+
json_list_save_path='./work_dirs/grand_jsons.json',
|
45 |
+
):
|
46 |
+
super().__init__()
|
47 |
+
assert lazy
|
48 |
+
self.lazy = lazy
|
49 |
+
self.max_length = max_length
|
50 |
+
|
51 |
+
self.image_list_save_path = image_list_save_path
|
52 |
+
self.json_list_save_path = json_list_save_path
|
53 |
+
|
54 |
+
json_files, image_path_dict = self.json_file_preprocess(image_folder, json_folder)
|
55 |
+
self.json_data = json_files
|
56 |
+
self.image_path_dict = image_path_dict
|
57 |
+
|
58 |
+
self.image_folder = image_folder
|
59 |
+
|
60 |
+
self.tokenizer = BUILDER.build(tokenizer)
|
61 |
+
if special_tokens is not None:
|
62 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
63 |
+
|
64 |
+
self.template_map_fn = template_map_fn
|
65 |
+
if isinstance(self.template_map_fn, dict) and self.lazy:
|
66 |
+
_type = self.template_map_fn['type']
|
67 |
+
del self.template_map_fn['type']
|
68 |
+
self.template_map_fn = _type(**self.template_map_fn)
|
69 |
+
|
70 |
+
if extra_image_processor is not None:
|
71 |
+
self.extra_image_processor = BUILDER.build(extra_image_processor)
|
72 |
+
|
73 |
+
self.repeats = repeats
|
74 |
+
|
75 |
+
self._system = ''
|
76 |
+
|
77 |
+
self.min_dynamic_patch = 1
|
78 |
+
self.max_dynamic_patch = 12
|
79 |
+
self.downsample_ratio = 0.5
|
80 |
+
self.image_size = 448
|
81 |
+
self.use_thumbnail = True
|
82 |
+
patch_size = 14
|
83 |
+
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
|
84 |
+
|
85 |
+
self.transformer = T.Compose([
|
86 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
87 |
+
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
|
88 |
+
T.ToTensor(),
|
89 |
+
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
|
90 |
+
])
|
91 |
+
|
92 |
+
if special_tokens is not None:
|
93 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
94 |
+
|
95 |
+
self.single_image_mode = single_image_mode
|
96 |
+
|
97 |
+
def json_file_preprocess(self, image_folder, json_folder):
|
98 |
+
|
99 |
+
# list jsons
|
100 |
+
print("Processing GRAND json files !!!")
|
101 |
+
if os.path.exists(self.json_list_save_path):
|
102 |
+
with open(self.json_list_save_path, 'r') as f:
|
103 |
+
json_files = json.load(f)
|
104 |
+
else:
|
105 |
+
json_files = os.listdir(json_folder)
|
106 |
+
_json_files = []
|
107 |
+
for _file in json_files:
|
108 |
+
if '.json' in _file:
|
109 |
+
_json_files.append(os.path.join(json_folder, _file))
|
110 |
+
json_files = _json_files
|
111 |
+
with open(self.json_list_save_path, 'w') as f:
|
112 |
+
json.dump(json_files, f)
|
113 |
+
print(f"Finished, {len(json_files)} json files !")
|
114 |
+
|
115 |
+
# list images
|
116 |
+
print("Processing GRAND image files !!!")
|
117 |
+
if os.path.exists(self.image_list_save_path):
|
118 |
+
with open(self.image_list_save_path, 'r') as f:
|
119 |
+
image_path_dict = json.load(f)
|
120 |
+
else:
|
121 |
+
sub_folders = os.listdir(image_folder)
|
122 |
+
_sub_folders = []
|
123 |
+
for folder_name in sub_folders:
|
124 |
+
if 'sa_00' in folder_name:
|
125 |
+
_sub_folders.append(folder_name)
|
126 |
+
sub_folders = _sub_folders
|
127 |
+
sub_folders = [os.path.join(image_folder, folder_name) for folder_name in sub_folders]
|
128 |
+
|
129 |
+
image_path_dict = {}
|
130 |
+
for sub_folder in sub_folders:
|
131 |
+
files = os.listdir(sub_folder)
|
132 |
+
for _file in files:
|
133 |
+
if '.jpg' in _file:
|
134 |
+
image_path_dict[_file] = os.path.join(sub_folder, _file)
|
135 |
+
|
136 |
+
with open(self.image_list_save_path, 'w') as f:
|
137 |
+
json.dump(image_path_dict, f)
|
138 |
+
print(f"Finished, {len(image_path_dict)} image files !")
|
139 |
+
|
140 |
+
return json_files, image_path_dict
|
141 |
+
|
142 |
+
@property
|
143 |
+
def modality_length(self):
|
144 |
+
length_list = [10000] * len(self.json_data)
|
145 |
+
return length_list * self.repeats
|
146 |
+
|
147 |
+
def __len__(self):
|
148 |
+
return len(self.json_data) * self.repeats
|
149 |
+
|
150 |
+
def real_len(self):
|
151 |
+
return len(self.json_data)
|
152 |
+
|
153 |
+
def decode_mask(self, object_masks, ori_height, ori_width):
|
154 |
+
binary_masks = []
|
155 |
+
for object_mask in object_masks:
|
156 |
+
binary_mask = np.zeros((ori_height, ori_width), dtype=np.uint8)
|
157 |
+
for seg in object_mask:
|
158 |
+
m = mask.decode(seg)
|
159 |
+
m = m.astype(np.uint8)
|
160 |
+
binary_mask += m.squeeze()
|
161 |
+
|
162 |
+
binary_masks.append(binary_mask)
|
163 |
+
if len(binary_masks) == 0:
|
164 |
+
return None
|
165 |
+
masks = np.stack(binary_masks, axis=0)
|
166 |
+
masks = torch.from_numpy(masks)
|
167 |
+
return masks
|
168 |
+
|
169 |
+
def dataset_map_fn(self, data_dict):
|
170 |
+
data_dict = glamm_grand_map_fn(data_dict)
|
171 |
+
return data_dict
|
172 |
+
|
173 |
+
def replace_image_str(self, data_dict, image_str):
|
174 |
+
data_dict['conversation'][0]['input'] = \
|
175 |
+
data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
|
176 |
+
return data_dict
|
177 |
+
|
178 |
+
def __getitem__(self, index):
|
179 |
+
|
180 |
+
index = index % self.real_len()
|
181 |
+
json_file_path = self.json_data[index]
|
182 |
+
with open(json_file_path, 'r') as f:
|
183 |
+
json_dict = json.load(f)
|
184 |
+
|
185 |
+
image_name = list(json_dict.keys())[0]
|
186 |
+
|
187 |
+
if image_name not in self.image_path_dict.keys():
|
188 |
+
return self.__getitem__(random.randint(0, len(self.json_data) - 1))
|
189 |
+
image_path = self.image_path_dict[image_name]
|
190 |
+
|
191 |
+
json_dict = json_dict[image_name]
|
192 |
+
# parse datasets
|
193 |
+
result = self.dataset_map_fn(json_dict)
|
194 |
+
json_dict.update(result)
|
195 |
+
data_dict = json_dict
|
196 |
+
|
197 |
+
data_dict['image'] = image_path
|
198 |
+
|
199 |
+
# process image
|
200 |
+
image_file = data_dict['image']
|
201 |
+
try:
|
202 |
+
image = Image.open(os.path.join(self.image_folder,
|
203 |
+
image_file)).convert('RGB')
|
204 |
+
except:
|
205 |
+
return self.__getitem__(random.randint(0, len(self.json_data) - 1))
|
206 |
+
ori_width, ori_height = image.size
|
207 |
+
if hasattr(self, 'extra_image_processor'):
|
208 |
+
g_image = np.array(image) # for grounding
|
209 |
+
g_image = self.extra_image_processor.apply_image(g_image)
|
210 |
+
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
|
211 |
+
data_dict['g_pixel_values'] = g_pixel_values
|
212 |
+
|
213 |
+
if self.single_image_mode:
|
214 |
+
images = [image]
|
215 |
+
else:
|
216 |
+
images = dynamic_preprocess(image, self.min_dynamic_patch,
|
217 |
+
self.max_dynamic_patch,
|
218 |
+
self.image_size, self.use_thumbnail)
|
219 |
+
pixel_values = [self.transformer(image) for image in images]
|
220 |
+
pixel_values = torch.stack(pixel_values)
|
221 |
+
data_dict['pixel_values'] = pixel_values
|
222 |
+
|
223 |
+
num_image_tokens = pixel_values.shape[0] * self.patch_token
|
224 |
+
image_token_str = f'{self.IMG_START_TOKEN}' \
|
225 |
+
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
|
226 |
+
f'{self.IMG_END_TOKEN}'
|
227 |
+
|
228 |
+
data_dict = self.replace_image_str(data_dict, image_token_str)
|
229 |
+
|
230 |
+
result = self.template_map_fn(data_dict)
|
231 |
+
data_dict.update(result)
|
232 |
+
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
|
233 |
+
with_image_token=True)
|
234 |
+
data_dict.update(result)
|
235 |
+
# process mask
|
236 |
+
data_dict['masks'] = self.decode_mask(data_dict['masks'], ori_height=ori_height, ori_width=ori_width)
|
237 |
+
|
238 |
+
if data_dict['masks'] is None:
|
239 |
+
return self.__getitem__(random.randint(0, len(self.json_data) - 1))
|
240 |
+
|
241 |
+
return data_dict
|
projects/llava_sam2/datasets/MeVIS_Dataset.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ReVOS_Dataset import VideoReVOSDataset
|
2 |
+
|
3 |
+
|
4 |
+
class VideoMeVISDataset(VideoReVOSDataset):
|
5 |
+
pass
|
projects/llava_sam2/datasets/Osprey_Dataset.py
ADDED
@@ -0,0 +1,463 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from datasets import Dataset as HFDataset
|
6 |
+
from datasets import DatasetDict, load_from_disk
|
7 |
+
from PIL import Image
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
from pycocotools import mask as maskUtils
|
10 |
+
import numpy as np
|
11 |
+
import copy
|
12 |
+
|
13 |
+
from xtuner.registry import BUILDER
|
14 |
+
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
|
15 |
+
import torchvision.transforms as T
|
16 |
+
from xtuner.utils import DEFAULT_IMAGE_TOKEN
|
17 |
+
from torchvision.transforms.functional import InterpolationMode
|
18 |
+
from .encode_fn import video_lisa_encode_fn
|
19 |
+
from .utils import dynamic_preprocess
|
20 |
+
|
21 |
+
import random
|
22 |
+
|
23 |
+
import torch.nn.functional as F
|
24 |
+
|
25 |
+
class OspreyDataset(Dataset):
|
26 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
27 |
+
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
|
28 |
+
IMG_START_TOKEN = '<img>'
|
29 |
+
IMG_END_TOKEN = '</img>'
|
30 |
+
|
31 |
+
LIMIT = ''
|
32 |
+
|
33 |
+
VP_START_TOKEN = '<vp>'
|
34 |
+
VP_END_TOKEN = '</vp>'
|
35 |
+
|
36 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
37 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
38 |
+
def __init__(self,
|
39 |
+
image_folder,
|
40 |
+
data_path=None,
|
41 |
+
tokenizer=None,
|
42 |
+
max_length=8196,
|
43 |
+
special_tokens=None,
|
44 |
+
template_map_fn=None,
|
45 |
+
extra_image_processor=None,
|
46 |
+
lazy=True,
|
47 |
+
repeats=1,
|
48 |
+
single_image_mode=False,
|
49 |
+
):
|
50 |
+
super().__init__()
|
51 |
+
assert lazy
|
52 |
+
self.lazy = lazy
|
53 |
+
self.max_length = max_length
|
54 |
+
|
55 |
+
json_data = self.json_file_preprocess(data_path)
|
56 |
+
self.text_data = json_data
|
57 |
+
|
58 |
+
self.image_folder = image_folder
|
59 |
+
|
60 |
+
self.tokenizer = BUILDER.build(tokenizer)
|
61 |
+
if special_tokens is not None:
|
62 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
63 |
+
|
64 |
+
self.template_map_fn = template_map_fn
|
65 |
+
if isinstance(self.template_map_fn, dict) and self.lazy:
|
66 |
+
_type = self.template_map_fn['type']
|
67 |
+
del self.template_map_fn['type']
|
68 |
+
self.template_map_fn = _type(**self.template_map_fn)
|
69 |
+
|
70 |
+
if extra_image_processor is not None:
|
71 |
+
self.extra_image_processor = BUILDER.build(extra_image_processor)
|
72 |
+
|
73 |
+
self.repeats = repeats
|
74 |
+
|
75 |
+
self._system = ''
|
76 |
+
|
77 |
+
self.min_dynamic_patch = 1
|
78 |
+
self.max_dynamic_patch = 12
|
79 |
+
self.downsample_ratio = 0.5
|
80 |
+
self.image_size = 448
|
81 |
+
self.use_thumbnail = True
|
82 |
+
patch_size = 14
|
83 |
+
self.patch_size = patch_size
|
84 |
+
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
|
85 |
+
|
86 |
+
self.transformer = T.Compose([
|
87 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
88 |
+
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
|
89 |
+
T.ToTensor(),
|
90 |
+
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
|
91 |
+
])
|
92 |
+
|
93 |
+
if special_tokens is not None:
|
94 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
95 |
+
|
96 |
+
self.single_image_mode = single_image_mode
|
97 |
+
|
98 |
+
def json_file_preprocess(self, data_path):
|
99 |
+
with open(data_path, 'r') as f:
|
100 |
+
json_data = json.load(f)
|
101 |
+
return json_data
|
102 |
+
|
103 |
+
@property
|
104 |
+
def modality_length(self):
|
105 |
+
length_list = []
|
106 |
+
for data_dict in self.text_data:
|
107 |
+
if self.lazy:
|
108 |
+
cur_len = 100
|
109 |
+
else:
|
110 |
+
cur_len = len(data_dict['input_ids'])
|
111 |
+
if data_dict.get('image', None) is None:
|
112 |
+
cur_len = -cur_len
|
113 |
+
length_list.append(cur_len)
|
114 |
+
return length_list * self.repeats
|
115 |
+
|
116 |
+
def __len__(self):
|
117 |
+
return len(self.text_data) * self.repeats
|
118 |
+
|
119 |
+
def real_len(self):
|
120 |
+
return len(self.text_data)
|
121 |
+
|
122 |
+
def annToMask(self, mask_ann, h, w):
|
123 |
+
if isinstance(mask_ann, list):
|
124 |
+
rles = maskUtils.frPyObjects(mask_ann, h, w)
|
125 |
+
rle = maskUtils.merge(rles)
|
126 |
+
elif isinstance(mask_ann['counts'], list):
|
127 |
+
# uncompressed RLE
|
128 |
+
rle = maskUtils.frPyObjects(mask_ann, h, w)
|
129 |
+
else:
|
130 |
+
# rle
|
131 |
+
rle = mask_ann
|
132 |
+
mask = maskUtils.decode(rle)
|
133 |
+
return mask
|
134 |
+
|
135 |
+
def decode_mask(self, object_masks, ori_height, ori_width):
|
136 |
+
binary_masks = []
|
137 |
+
for object_mask in object_masks:
|
138 |
+
binary_mask = self.annToMask(object_mask, ori_height, ori_width)
|
139 |
+
binary_masks.append(binary_mask)
|
140 |
+
if len(binary_masks) == 0:
|
141 |
+
return None
|
142 |
+
masks = np.stack(binary_masks, axis=0)
|
143 |
+
masks = torch.from_numpy(masks)
|
144 |
+
return masks
|
145 |
+
|
146 |
+
def _process_conversation(self, converations, n_regions, region_pixels):
|
147 |
+
start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
|
148 |
+
for i in range(n_regions):
|
149 |
+
start_region_str = start_region_str + \
|
150 |
+
f"region{i+1}" + self.VP_START_TOKEN + self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
|
151 |
+
if i == n_regions - 1:
|
152 |
+
start_region_str = start_region_str + '.\n'
|
153 |
+
else:
|
154 |
+
start_region_str = start_region_str + ', '
|
155 |
+
|
156 |
+
for i, item in enumerate(converations):
|
157 |
+
item['value'] = item['value'].replace('<', '').replace('>', '')
|
158 |
+
if item['from'] == 'human':
|
159 |
+
item['value'] = item['value'] + self.LIMIT
|
160 |
+
# first conv process
|
161 |
+
if i == 0:
|
162 |
+
assert item['from'] == "human"
|
163 |
+
item['value'] = start_region_str + item['value']
|
164 |
+
|
165 |
+
messages = converations
|
166 |
+
input = ''
|
167 |
+
|
168 |
+
conversation = []
|
169 |
+
while messages and messages[0]['from'] == 'gpt':
|
170 |
+
# Skip the first one if it is from gpt
|
171 |
+
messages = messages[1:]
|
172 |
+
for msg in messages:
|
173 |
+
if msg['from'] == 'human':
|
174 |
+
if DEFAULT_IMAGE_TOKEN in msg['value']:
|
175 |
+
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
|
176 |
+
'').strip()
|
177 |
+
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
|
178 |
+
msg['value'] = msg['value'].strip()
|
179 |
+
input += msg['value']
|
180 |
+
|
181 |
+
elif msg['from'] == 'gpt':
|
182 |
+
conversation.append({'input': input, 'output': msg['value']})
|
183 |
+
input = ''
|
184 |
+
else:
|
185 |
+
raise NotImplementedError
|
186 |
+
|
187 |
+
return conversation
|
188 |
+
|
189 |
+
def _get_region_infos(self, masks):
|
190 |
+
# masks tensor, (n_obj, h, w)
|
191 |
+
masks = F.interpolate(
|
192 |
+
masks.unsqueeze(0),
|
193 |
+
size=(int(self.image_size // self.patch_size * self.downsample_ratio),
|
194 |
+
int(self.image_size // self.patch_size * self.downsample_ratio)),
|
195 |
+
mode='nearest').squeeze(0)
|
196 |
+
region_pixels = []
|
197 |
+
for mask in masks:
|
198 |
+
region_pixels.append(mask.bool().to(torch.int64).sum())
|
199 |
+
return masks, region_pixels
|
200 |
+
|
201 |
+
def dataset_map_fn(self, data_dict):
|
202 |
+
file_name = data_dict['file_name'] # image file name
|
203 |
+
conversations = data_dict['conversations']
|
204 |
+
masks = [anno["segmentation"] for anno in data_dict["annotation"]]
|
205 |
+
height = data_dict['height']
|
206 |
+
width = data_dict['width']
|
207 |
+
_ret = {}
|
208 |
+
|
209 |
+
_ret['image'] = file_name
|
210 |
+
_ret['height'] = height
|
211 |
+
_ret['width'] = width
|
212 |
+
|
213 |
+
masks = self.decode_mask(masks, height, width)
|
214 |
+
masks, region_pixels = self._get_region_infos(masks)
|
215 |
+
|
216 |
+
if masks is None:
|
217 |
+
return None
|
218 |
+
|
219 |
+
conversations = self._process_conversation(conversations, len(masks), region_pixels)
|
220 |
+
_ret['conversation'] = conversations
|
221 |
+
_ret['prompt_masks'] = masks
|
222 |
+
return _ret
|
223 |
+
|
224 |
+
def replace_image_str(self, data_dict, image_str):
|
225 |
+
data_dict['conversation'][0]['input'] = \
|
226 |
+
data_dict['conversation'][0]['input'].replace(DEFAULT_IMAGE_TOKEN, image_str)
|
227 |
+
return data_dict
|
228 |
+
|
229 |
+
def __getitem__(self, index):
|
230 |
+
|
231 |
+
index = index % self.real_len()
|
232 |
+
data_dict = copy.deepcopy(self.text_data[index])
|
233 |
+
|
234 |
+
# parse datasets
|
235 |
+
result = self.dataset_map_fn(data_dict) # {'image', 'height', 'width', 'conversation', 'masks'}
|
236 |
+
if result is None or result['prompt_masks'] is None:
|
237 |
+
return self.__getitem__(0)
|
238 |
+
|
239 |
+
data_dict = result
|
240 |
+
|
241 |
+
# process image
|
242 |
+
image_file = data_dict['image']
|
243 |
+
if isinstance(self.image_folder, list):
|
244 |
+
for image_folder in self.image_folder:
|
245 |
+
image_path = os.path.join(image_folder, image_file)
|
246 |
+
if os.path.exists(image_path):
|
247 |
+
image = Image.open(image_path).convert('RGB')
|
248 |
+
break
|
249 |
+
else:
|
250 |
+
image = Image.open(os.path.join(self.image_folder,
|
251 |
+
image_file)).convert('RGB')
|
252 |
+
ori_width, ori_height = image.size
|
253 |
+
|
254 |
+
if self.single_image_mode:
|
255 |
+
images = [image]
|
256 |
+
else:
|
257 |
+
images = dynamic_preprocess(image, self.min_dynamic_patch,
|
258 |
+
self.max_dynamic_patch,
|
259 |
+
self.image_size, self.use_thumbnail)
|
260 |
+
vp_overall_mask = torch.Tensor([False] * (len(images) - 1) + [True])
|
261 |
+
data_dict['vp_overall_mask'] = vp_overall_mask
|
262 |
+
|
263 |
+
pixel_values = [self.transformer(image) for image in images]
|
264 |
+
pixel_values = torch.stack(pixel_values)
|
265 |
+
data_dict['pixel_values'] = pixel_values
|
266 |
+
|
267 |
+
num_image_tokens = pixel_values.shape[0] * self.patch_token
|
268 |
+
image_token_str = f'{self.IMG_START_TOKEN}' \
|
269 |
+
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
|
270 |
+
f'{self.IMG_END_TOKEN}'
|
271 |
+
|
272 |
+
data_dict = self.replace_image_str(data_dict, image_token_str)
|
273 |
+
|
274 |
+
result = self.template_map_fn(data_dict)
|
275 |
+
data_dict.update(result)
|
276 |
+
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length,
|
277 |
+
with_image_token=True)
|
278 |
+
data_dict.update(result)
|
279 |
+
# process mask
|
280 |
+
# data_dict['prompt_masks'] = data_dict['prompt_masks']
|
281 |
+
|
282 |
+
if data_dict['prompt_masks'] is None:
|
283 |
+
return self.__getitem__(0)
|
284 |
+
|
285 |
+
return data_dict
|
286 |
+
|
287 |
+
|
288 |
+
DETAILED_QUESTIONS = [
|
289 |
+
'Can you provide me with a detailed description of the region in the picture marked by <region>?',
|
290 |
+
"I'm curious about the region represented by <region> in the picture. Could you describe it in detail?",
|
291 |
+
'What can you tell me about the region indicated by <region> in the image?',
|
292 |
+
"I'd like to know more about the area in the photo labeled <region>. Can you give me a detailed description?",
|
293 |
+
'Could you describe the region shown as <region> in the picture in great detail?',
|
294 |
+
'What details can you give me about the region outlined by <region> in the photo?',
|
295 |
+
'Please provide me with a comprehensive description of the region marked with <region> in the image.',
|
296 |
+
'Can you give me a detailed account of the region labeled as <region> in the picture?',
|
297 |
+
"I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail?",
|
298 |
+
'What is the region outlined by <region> in the picture like? Could you give me a detailed description?',
|
299 |
+
'Can you provide me with a detailed description of the region in the picture marked by <region>, please?',
|
300 |
+
"I'm curious about the region represented by <region> in the picture. Could you describe it in detail, please?",
|
301 |
+
'What can you tell me about the region indicated by <region> in the image, exactly?',
|
302 |
+
"I'd like to know more about the area in the photo labeled <region>, please. Can you give me a detailed description?",
|
303 |
+
'Could you describe the region shown as <region> in the picture in great detail, please?',
|
304 |
+
'What details can you give me about the region outlined by <region> in the photo, please?',
|
305 |
+
'Please provide me with a comprehensive description of the region marked with <region> in the image, please.',
|
306 |
+
'Can you give me a detailed account of the region labeled as <region> in the picture, please?',
|
307 |
+
"I'm interested in learning more about the region represented by <region> in the photo. Can you describe it in detail, please?",
|
308 |
+
'What is the region outlined by <region> in the picture like, please? Could you give me a detailed description?',
|
309 |
+
'Please describe the region <region> in the image in detail.',
|
310 |
+
'Can you offer a thorough analysis of the region <region> in the image?',
|
311 |
+
'Could you elaborate on the region highlighted by <region> in the picture provided?',
|
312 |
+
'Please share more information about the zone emphasized with <region> in the photo.',
|
313 |
+
'What insights can you give about the area denoted by <region> in the image presented?',
|
314 |
+
'Can you share a comprehensive rundown of the region denoted by <region> in the presented image?',
|
315 |
+
"I'd like to know more about the region highlighted by <region> in the picture provided.",
|
316 |
+
'Work through the important details of the area <region> in the image.',
|
317 |
+
'Illustrate the area represented by <region> through a descriptive explanation.',
|
318 |
+
'Examine the region <region> closely and share its details.'
|
319 |
+
]
|
320 |
+
|
321 |
+
class OspreyDescriptionDataset(OspreyDataset):
|
322 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
323 |
+
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
|
324 |
+
IMG_START_TOKEN = '<img>'
|
325 |
+
IMG_END_TOKEN = '</img>'
|
326 |
+
|
327 |
+
VP_START_TOKEN = '<vp>'
|
328 |
+
VP_END_TOKEN = '</vp>'
|
329 |
+
|
330 |
+
LIMIT=''
|
331 |
+
|
332 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
333 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
334 |
+
def __init__(self,
|
335 |
+
image_folder,
|
336 |
+
data_path=None,
|
337 |
+
tokenizer=None,
|
338 |
+
max_length=8196,
|
339 |
+
special_tokens=None,
|
340 |
+
template_map_fn=None,
|
341 |
+
extra_image_processor=None,
|
342 |
+
lazy=True,
|
343 |
+
repeats=1,
|
344 |
+
single_image_mode=False,
|
345 |
+
):
|
346 |
+
super(OspreyDescriptionDataset, self).__init__(
|
347 |
+
image_folder=image_folder,
|
348 |
+
data_path=data_path,
|
349 |
+
tokenizer=tokenizer,
|
350 |
+
max_length=max_length,
|
351 |
+
special_tokens=special_tokens,
|
352 |
+
template_map_fn=template_map_fn,
|
353 |
+
extra_image_processor=extra_image_processor,
|
354 |
+
lazy=lazy,
|
355 |
+
repeats=repeats,
|
356 |
+
single_image_mode=single_image_mode,
|
357 |
+
)
|
358 |
+
|
359 |
+
def dataset_map_fn(self, data_dict):
|
360 |
+
file_name = data_dict['file_name'] # image file name
|
361 |
+
descriptions = data_dict['description']
|
362 |
+
masks = [anno["segmentation"] for anno in data_dict["annotation"]]
|
363 |
+
height = data_dict['height']
|
364 |
+
width = data_dict['width']
|
365 |
+
_ret = {}
|
366 |
+
|
367 |
+
_ret['image'] = file_name
|
368 |
+
_ret['height'] = height
|
369 |
+
_ret['width'] = width
|
370 |
+
|
371 |
+
masks = self.decode_mask(masks, height, width)
|
372 |
+
masks, region_pixels = self._get_region_infos(masks)
|
373 |
+
|
374 |
+
if masks is None:
|
375 |
+
return None
|
376 |
+
|
377 |
+
conversations = self._process_conversation(descriptions, len(masks), region_pixels)
|
378 |
+
_ret['conversation'] = conversations
|
379 |
+
_ret['prompt_masks'] = masks
|
380 |
+
return _ret
|
381 |
+
|
382 |
+
def _process_conversation(self, descriptions, n_regions, region_pixels):
|
383 |
+
start_region_str = '<image> There are {} part regions in the picture: '.format(n_regions)
|
384 |
+
for i in range(n_regions):
|
385 |
+
start_region_str = start_region_str + \
|
386 |
+
f"region{i+1}" + self.VP_START_TOKEN + self.IMG_CONTEXT_TOKEN * region_pixels[i] + self.VP_END_TOKEN
|
387 |
+
if i == n_regions - 1:
|
388 |
+
start_region_str = start_region_str + '.\n'
|
389 |
+
else:
|
390 |
+
start_region_str = start_region_str + ', '
|
391 |
+
|
392 |
+
conversations = []
|
393 |
+
for i, item in enumerate(descriptions):
|
394 |
+
question = random.choice(DETAILED_QUESTIONS).strip().replace('<region>', f"region{i+1}") + self.LIMIT
|
395 |
+
answer = item.replace('<', '').replace('>', '')
|
396 |
+
# first conv process
|
397 |
+
if i == 0:
|
398 |
+
question = start_region_str + question
|
399 |
+
conversations.append({'from': 'human', 'value': question})
|
400 |
+
conversations.append({'from': 'gpt', 'value': answer})
|
401 |
+
|
402 |
+
messages = conversations
|
403 |
+
input = ''
|
404 |
+
|
405 |
+
conversation = []
|
406 |
+
while messages and messages[0]['from'] == 'gpt':
|
407 |
+
# Skip the first one if it is from gpt
|
408 |
+
messages = messages[1:]
|
409 |
+
for msg in messages:
|
410 |
+
if msg['from'] == 'human':
|
411 |
+
if DEFAULT_IMAGE_TOKEN in msg['value']:
|
412 |
+
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
|
413 |
+
'').strip()
|
414 |
+
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
|
415 |
+
msg['value'] = msg['value'].strip()
|
416 |
+
input += msg['value']
|
417 |
+
|
418 |
+
elif msg['from'] == 'gpt':
|
419 |
+
conversation.append({'input': input, 'output': msg['value']})
|
420 |
+
input = ''
|
421 |
+
else:
|
422 |
+
raise NotImplementedError
|
423 |
+
return conversation
|
424 |
+
|
425 |
+
|
426 |
+
class OspreyShortDescriptionDataset(OspreyDataset):
|
427 |
+
os.environ['TOKENIZERS_PARALLELISM'] = 'true'
|
428 |
+
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
|
429 |
+
IMG_START_TOKEN = '<img>'
|
430 |
+
IMG_END_TOKEN = '</img>'
|
431 |
+
|
432 |
+
VP_START_TOKEN = '<vp>'
|
433 |
+
VP_END_TOKEN = '</vp>'
|
434 |
+
|
435 |
+
LIMIT = ' Answer the question using a single word or phrase.'
|
436 |
+
|
437 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
438 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
439 |
+
|
440 |
+
def __init__(self,
|
441 |
+
image_folder,
|
442 |
+
data_path=None,
|
443 |
+
tokenizer=None,
|
444 |
+
max_length=8196,
|
445 |
+
special_tokens=None,
|
446 |
+
template_map_fn=None,
|
447 |
+
extra_image_processor=None,
|
448 |
+
lazy=True,
|
449 |
+
repeats=1,
|
450 |
+
single_image_mode=False,
|
451 |
+
):
|
452 |
+
super(OspreyShortDescriptionDataset, self).__init__(
|
453 |
+
image_folder=image_folder,
|
454 |
+
data_path=data_path,
|
455 |
+
tokenizer=tokenizer,
|
456 |
+
max_length=max_length,
|
457 |
+
special_tokens=special_tokens,
|
458 |
+
template_map_fn=template_map_fn,
|
459 |
+
extra_image_processor=extra_image_processor,
|
460 |
+
lazy=lazy,
|
461 |
+
repeats=repeats,
|
462 |
+
single_image_mode=single_image_mode,
|
463 |
+
)
|
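Aside, as an illustration of the region-prompt format the Osprey-style datasets above build (this sketch is not part of the diff): each region is named and wrapped in <vp>...</vp>, with one <IMG_CONTEXT> slot reserved per surviving pixel of the downsampled prompt mask. The token strings and the example pixel counts below are assumptions for illustration only.

# Minimal sketch: mirrors how OspreyDataset._process_conversation assembles its prefix.
VP_START_TOKEN = '<vp>'
VP_END_TOKEN = '</vp>'
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'

def build_region_prefix(region_pixels):
    # One named <vp>...</vp> span per region, holding one context token per pooled mask pixel.
    n_regions = len(region_pixels)
    prefix = '<image> There are {} part regions in the picture: '.format(n_regions)
    for i, n_pix in enumerate(region_pixels):
        prefix += f'region{i + 1}' + VP_START_TOKEN + IMG_CONTEXT_TOKEN * n_pix + VP_END_TOKEN
        prefix += '.\n' if i == n_regions - 1 else ', '
    return prefix

# Example: two hypothetical regions covering 3 and 5 pooled patches after the 0.5x downsample.
print(build_region_prefix([3, 5]))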
projects/llava_sam2/datasets/ReSAM2_Dataset.py
ADDED
@@ -0,0 +1,489 @@
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
from datasets import Dataset as HFDataset
|
5 |
+
from datasets import DatasetDict, load_from_disk
|
6 |
+
from mmengine import print_log
|
7 |
+
from PIL import Image
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from xtuner.registry import BUILDER
|
12 |
+
from xtuner.dataset.huggingface import process_hf_dataset, build_origin_dataset
|
13 |
+
import copy
|
14 |
+
from .encode_fn import video_lisa_encode_fn
|
15 |
+
import json
|
16 |
+
import random
|
17 |
+
import pycocotools.mask as maskUtils
|
18 |
+
import cv2
|
19 |
+
import torchvision.transforms as T
|
20 |
+
from torchvision.transforms.functional import InterpolationMode
|
21 |
+
|
22 |
+
SEG_QUESTIONS = [
|
23 |
+
"Please segment the object according to the description: {class_name}",
|
24 |
+
]
|
25 |
+
|
26 |
+
SEG_QUESTIONS_SHORT = [
|
27 |
+
"Can you segment the {class_name} in this image?",
|
28 |
+
"Please segment {class_name} in this image.",
|
29 |
+
"What is {class_name} in this image? Please respond with segmentation mask.",
|
30 |
+
"What is {class_name} in this image? Please output segmentation mask.",
|
31 |
+
|
32 |
+
"Can you segment the {class_name} in this image",
|
33 |
+
"Please segment {class_name} in this image",
|
34 |
+
"What is {class_name} in this image? Please respond with segmentation mask",
|
35 |
+
"What is {class_name} in this image? Please output segmentation mask",
|
36 |
+
|
37 |
+
"Could you provide a segmentation mask for the {class_name} in this image?",
|
38 |
+
"Please identify and segment the {class_name} in this image.",
|
39 |
+
"Where is the {class_name} in this picture? Please respond with a segmentation mask.",
|
40 |
+
"Can you highlight the {class_name} in this image with a segmentation mask?",
|
41 |
+
|
42 |
+
"Could you provide a segmentation mask for the {class_name} in this image",
|
43 |
+
"Please identify and segment the {class_name} in this image",
|
44 |
+
"Where is the {class_name} in this picture? Please respond with a segmentation mask",
|
45 |
+
"Can you highlight the {class_name} in this image with a segmentation mask",
|
46 |
+
]
|
47 |
+
|
48 |
+
ANSWER_LIST = [
|
49 |
+
"It is [SEG].",
|
50 |
+
"Sure, [SEG].",
|
51 |
+
"Sure, it is [SEG].",
|
52 |
+
"Sure, the segmentation result is [SEG].",
|
53 |
+
"[SEG].",
|
54 |
+
]
|
55 |
+
|
56 |
+
class VideoSAM2Dataset(Dataset):
|
57 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
58 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
59 |
+
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
|
60 |
+
IMG_START_TOKEN = '<img>'
|
61 |
+
IMG_END_TOKEN = '</img>'
|
62 |
+
|
63 |
+
FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
|
64 |
+
FAST_IMG_START_TOKEN = '<fast_img>'
|
65 |
+
FAST_IMG_END_TOKEN = '</fast_img>'
|
66 |
+
|
67 |
+
def __init__(self,
|
68 |
+
sam2_folder,
|
69 |
+
expression_file,
|
70 |
+
extra_image_processor=None,
|
71 |
+
tokenizer=None,
|
72 |
+
select_number=5,
|
73 |
+
sampled_frames=5,
|
74 |
+
offline_processed_text_folder=None,
|
75 |
+
template_map_fn=None,
|
76 |
+
max_length=8196,
|
77 |
+
lazy=True,
|
78 |
+
repeats=1,
|
79 |
+
special_tokens=None,
|
80 |
+
use_fast=False,
|
81 |
+
n_fast_images=50,
|
82 |
+
fast_pool_size=4,
|
83 |
+
mode='long',
|
84 |
+
frame_contiguous_sample=False,
|
85 |
+
):
|
86 |
+
assert mode in ['long', 'long_short', 'short']
|
87 |
+
self.mode = mode
|
88 |
+
self.cur_mode = mode
|
89 |
+
assert lazy is True
|
90 |
+
self.tokenizer = BUILDER.build(tokenizer)
|
91 |
+
self.select_number = select_number
|
92 |
+
self.sampled_frames = sampled_frames
|
93 |
+
assert offline_processed_text_folder or (expression_file and tokenizer)
|
94 |
+
self.lazy = lazy
|
95 |
+
|
96 |
+
self.max_length = max_length
|
97 |
+
|
98 |
+
self.template_map_fn = template_map_fn
|
99 |
+
if isinstance(self.template_map_fn, dict) and self.lazy:
|
100 |
+
_type = self.template_map_fn['type']
|
101 |
+
del self.template_map_fn['type']
|
102 |
+
self.template_map_fn = _type(**self.template_map_fn)
|
103 |
+
|
104 |
+
if offline_processed_text_folder and expression_file:
|
105 |
+
print_log(
|
106 |
+
'Both `offline_processed_text_folder` and '
|
107 |
+
'`expression_file` are set, and we load dataset from '
|
108 |
+
'`offline_processed_text_folder` '
|
109 |
+
f'({offline_processed_text_folder})',
|
110 |
+
logger='current',
|
111 |
+
level=logging.WARNING)
|
112 |
+
|
113 |
+
if offline_processed_text_folder is not None:
|
114 |
+
raise NotImplementedError
|
115 |
+
else:
|
116 |
+
video_ids, anno_dict = self.json_file_preprocess(expression_file)
|
117 |
+
if self.lazy:
|
118 |
+
self.video_ids = video_ids
|
119 |
+
self.anno_dict = anno_dict
|
120 |
+
else:
|
121 |
+
raise NotImplementedError
|
122 |
+
|
123 |
+
self.sam2_folder = sam2_folder
|
124 |
+
if extra_image_processor is not None:
|
125 |
+
self.extra_image_processor = BUILDER.build(extra_image_processor)
|
126 |
+
self.down_ratio = 1
|
127 |
+
self.repeats = repeats
|
128 |
+
|
129 |
+
self._system = ''
|
130 |
+
|
131 |
+
self.downsample_ratio = 0.5
|
132 |
+
self.image_size = 448
|
133 |
+
patch_size = 14
|
134 |
+
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
|
135 |
+
|
136 |
+
self.transformer = T.Compose([
|
137 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
138 |
+
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
|
139 |
+
T.ToTensor(),
|
140 |
+
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
|
141 |
+
])
|
142 |
+
|
143 |
+
if special_tokens is not None:
|
144 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
145 |
+
|
146 |
+
self.use_fast = use_fast
|
147 |
+
self.n_fast_images = n_fast_images
|
148 |
+
self.fast_pool_size = fast_pool_size
|
149 |
+
|
150 |
+
self.frame_contiguous_sample = frame_contiguous_sample
|
151 |
+
|
152 |
+
# for visualization debug
|
153 |
+
self.save_folder = './work_dirs/video_debug/'
|
154 |
+
self.cur_number = 0
|
155 |
+
|
156 |
+
print("Video res dataset (ref-sam2), include {} items.".format(len(self.video_ids)))
|
157 |
+
|
158 |
+
def __len__(self):
|
159 |
+
return len(self.video_ids) * self.repeats
|
160 |
+
|
161 |
+
@property
|
162 |
+
def modality_length(self):
|
163 |
+
length_list = []
|
164 |
+
for data_dict in self.video_ids:
|
165 |
+
cur_len = 20000
|
166 |
+
length_list.append(cur_len)
|
167 |
+
return length_list
|
168 |
+
|
169 |
+
def real_len(self):
|
170 |
+
return len(self.video_ids)
|
171 |
+
|
172 |
+
def json_file_preprocess(self, expression_file):
|
173 |
+
# prepare expression annotation files
|
174 |
+
with open(expression_file, 'r') as f:
|
175 |
+
expression_datas = json.load(f)
|
176 |
+
|
177 |
+
video_ids = list(expression_datas.keys())
|
178 |
+
return video_ids, expression_datas
|
179 |
+
|
180 |
+
def dataset_map_fn(self, objects_expression_infos, n_frames, n_fast_frames=0):
|
181 |
+
# prepare text
|
182 |
+
if self.mode == 'long':
|
183 |
+
expressions = [object_info['formated'] for object_info in objects_expression_infos]
|
184 |
+
self.cur_mode = self.mode
|
185 |
+
elif self.mode == 'short':
|
186 |
+
expressions = [object_info['short_caps'][random.randint(0, len(object_info['short_caps'])-1)] for object_info in objects_expression_infos]
|
187 |
+
self.cur_mode = self.mode
|
188 |
+
else:
|
189 |
+
if random.random() < 0.5:
|
190 |
+
expressions = [object_info['formated'] for object_info in objects_expression_infos]
|
191 |
+
self.cur_mode = 'long'
|
192 |
+
else:
|
193 |
+
expressions = [object_info['short_caps'][random.randint(0, len(object_info['short_caps']) - 1)] for
|
194 |
+
object_info in objects_expression_infos]
|
195 |
+
self.cur_mode = 'short'
|
196 |
+
text_dict = self.prepare_text(n_frames, expressions, num_image_tokens=self.patch_token,
|
197 |
+
n_fast_frames=n_fast_frames)
|
198 |
+
ret = {'conversation': text_dict['conversation']}
|
199 |
+
return ret
|
200 |
+
|
201 |
+
def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_frames=0):
|
202 |
+
|
203 |
+
if self.use_fast:
|
204 |
+
fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
|
205 |
+
f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_frames * self.fast_pool_size * self.fast_pool_size}' \
|
206 |
+
f'{self.FAST_IMG_END_TOKEN}' + '\n'
|
207 |
+
else:
|
208 |
+
fast_frame_token_str = ''
|
209 |
+
|
210 |
+
frame_token_str = f'{self.IMG_START_TOKEN}' \
|
211 |
+
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
|
212 |
+
f'{self.IMG_END_TOKEN}'
|
213 |
+
|
214 |
+
questions = []
|
215 |
+
answers = []
|
216 |
+
for i, exp in enumerate(expressions):
|
217 |
+
if self.cur_mode == 'short':
|
218 |
+
question_template = random.choice(SEG_QUESTIONS_SHORT)
|
219 |
+
exp = exp.replace("A ", '')
|
220 |
+
else:
|
221 |
+
question_template = random.choice(SEG_QUESTIONS)
|
222 |
+
questions.append(question_template.format(class_name=exp))
|
223 |
+
answers.append(random.choice(ANSWER_LIST))
|
224 |
+
qa_list = []
|
225 |
+
for i, (question, answer) in enumerate(zip(questions, answers)):
|
226 |
+
if i == 0:
|
227 |
+
frame_tokens = frame_token_str + '\n'
|
228 |
+
# frame_tokens = '=' + ' '
|
229 |
+
frame_tokens = frame_tokens * n_frames
|
230 |
+
frame_tokens = frame_tokens.strip()
|
231 |
+
frame_tokens = fast_frame_token_str + frame_tokens
|
232 |
+
qa_list.append(
|
233 |
+
{'from': 'human', 'value': frame_tokens + question}
|
234 |
+
)
|
235 |
+
else:
|
236 |
+
qa_list.append(
|
237 |
+
{'from': 'human', 'value': question}
|
238 |
+
)
|
239 |
+
qa_list.append(
|
240 |
+
{'from': 'gpt', 'value': answer}
|
241 |
+
)
|
242 |
+
|
243 |
+
input = ''
|
244 |
+
conversation = []
|
245 |
+
for msg in qa_list:
|
246 |
+
if msg['from'] == 'human':
|
247 |
+
input += msg['value']
|
248 |
+
elif msg['from'] == 'gpt':
|
249 |
+
conversation.append({'input': input, 'output': msg['value']})
|
250 |
+
input = ''
|
251 |
+
else:
|
252 |
+
raise NotImplementedError
|
253 |
+
|
254 |
+
# add system information
|
255 |
+
conversation[0].update({'system': self._system})
|
256 |
+
return {'conversation': conversation}
|
257 |
+
|
258 |
+
def __getitem__(self, index):
|
259 |
+
index = index % self.real_len()
|
260 |
+
video_id = self.video_ids[index]
|
261 |
+
expression_dict = self.anno_dict[video_id]
|
262 |
+
object_ids = list(expression_dict['objects'].keys())
|
263 |
+
|
264 |
+
video_path = os.path.join(self.sam2_folder, expression_dict['video_path'])
|
265 |
+
anno_path = os.path.join(self.sam2_folder, expression_dict['anno_path'])
|
266 |
+
|
267 |
+
video_frames = get_video_frames(video_path)
|
268 |
+
|
269 |
+
if self.use_fast:
|
270 |
+
# sample fast branch
|
271 |
+
fast_interval = len(video_frames) / (self.n_fast_images + 1e-4)
|
272 |
+
sampled_fast_frame_idxs = [min(int(i * fast_interval), len(video_frames) - 1) for i in range(self.n_fast_images)]
|
273 |
+
fast_video_frames = [video_frames[_idx] for _idx in sampled_fast_frame_idxs]
|
274 |
+
else:
|
275 |
+
fast_video_frames = None
|
276 |
+
|
277 |
+
video_frames = video_frames[::4]
|
278 |
+
|
279 |
+
# mask annotation
|
280 |
+
with open(anno_path, 'r') as f:
|
281 |
+
mask_data = json.load(f)
|
282 |
+
masklents = decode_masklet(mask_data['masklet'])
|
283 |
+
|
284 |
+
n_frames = len(masklents)
|
285 |
+
n_objects = len(object_ids)
|
286 |
+
|
287 |
+
# sample object
|
288 |
+
if n_objects > self.select_number:
|
289 |
+
selected_indexes = np.random.choice(n_objects, self.select_number)
|
290 |
+
else:
|
291 |
+
selected_indexes = np.random.choice(n_objects, self.select_number, replace=True)
|
292 |
+
|
293 |
+
selected_object_ids = [object_ids[_idx] for _idx in selected_indexes]
|
294 |
+
objects_expression_infos = [expression_dict['objects'][_idx] for _idx in selected_object_ids]
|
295 |
+
_masklents = []
|
296 |
+
for _mask in masklents:
|
297 |
+
_mask_selected = []
|
298 |
+
for _idx in selected_object_ids:
|
299 |
+
_mask_selected.append(_mask[:, :, int(_idx)])
|
300 |
+
_mask_selected = np.stack(_mask_selected, axis=2)
|
301 |
+
_masklents.append(_mask_selected)
|
302 |
+
masklents = _masklents
|
303 |
+
|
304 |
+
# sample video frames
|
305 |
+
# prepare images, random select k frames
|
306 |
+
if n_frames > self.sampled_frames + 1:
|
307 |
+
if self.frame_contiguous_sample and random.random() < 0.5:
|
308 |
+
# do contiguous sample
|
309 |
+
selected_start_frame = np.random.choice(n_frames - self.sampled_frames, 1, replace=False)
|
310 |
+
selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(self.sampled_frames)]
|
311 |
+
else:
|
312 |
+
selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=False)
|
313 |
+
else:
|
314 |
+
selected_frame_indexes = np.random.choice(n_frames, self.sampled_frames, replace=True)
|
315 |
+
selected_frame_indexes.sort()
|
316 |
+
|
317 |
+
video_frames = [video_frames[_idx] for _idx in selected_frame_indexes]
|
318 |
+
masklents = [masklents[_idx] for _idx in selected_frame_indexes]
|
319 |
+
|
320 |
+
data_dict = self.dataset_map_fn(objects_expression_infos, len(video_frames), n_fast_frames=self.n_fast_images)
|
321 |
+
result = self.template_map_fn(data_dict)
|
322 |
+
data_dict.update(result)
|
323 |
+
result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length, with_image_token=True)
|
324 |
+
data_dict.update(result)
|
325 |
+
|
326 |
+
pixel_values = []
|
327 |
+
extra_pixel_values = []
|
328 |
+
for frame in video_frames:
|
329 |
+
frame = frame[:, :, ::-1]
|
330 |
+
frame_image = Image.fromarray(frame).convert('RGB')
|
331 |
+
ori_width, ori_height = frame_image.size
|
332 |
+
if self.extra_image_processor is not None:
|
333 |
+
g_image = np.array(frame_image) # for grounding
|
334 |
+
g_image = self.extra_image_processor.apply_image(g_image)
|
335 |
+
g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
|
336 |
+
extra_pixel_values.append(g_pixel_values)
|
337 |
+
|
338 |
+
frame_image = self.transformer(frame_image)
|
339 |
+
pixel_values.append(frame_image)
|
340 |
+
|
341 |
+
pixel_values = torch.stack(pixel_values, dim=0) # (n_f, 3, h, w)
|
342 |
+
data_dict['pixel_values'] = pixel_values
|
343 |
+
if self.extra_image_processor is not None:
|
344 |
+
data_dict['g_pixel_values'] = extra_pixel_values
|
345 |
+
|
346 |
+
# for fast branch
|
347 |
+
if self.use_fast:
|
348 |
+
fast_pixel_values = []
|
349 |
+
for frame_image in fast_video_frames:
|
350 |
+
frame = frame_image[:, :, ::-1]
|
351 |
+
frame_image = Image.fromarray(frame).convert('RGB')
|
352 |
+
ori_width, ori_height = frame_image.size
|
353 |
+
|
354 |
+
frame_image = self.transformer(frame_image)
|
355 |
+
fast_pixel_values.append(frame_image)
|
356 |
+
|
357 |
+
fast_pixel_values = torch.stack(fast_pixel_values, dim=0) # (n_f, 3, h, w)
|
358 |
+
data_dict['fast_pixel_values'] = fast_pixel_values
|
359 |
+
|
360 |
+
# process and get masks
|
361 |
+
masklents = np.stack(masklents, axis=0) # (n_frames, h, w, n_obj)
|
362 |
+
masklents = torch.from_numpy(masklents).permute(3, 0, 1, 2)
|
363 |
+
masklents = masklents.flatten(0, 1)
|
364 |
+
# print('sam2-mask_shape:', masklents.shape)
|
365 |
+
# print('sam2-pixel_values:', data_dict['pixel_values'].shape)
|
366 |
+
# print('sam2-g_pixel_values:', len(data_dict['g_pixel_values']), ', ', data_dict['g_pixel_values'][0].shape)
|
367 |
+
data_dict['masks'] = masklents
|
368 |
+
data_dict['type'] = 'video'
|
369 |
+
return data_dict
|
370 |
+
|
371 |
+
def visualization_debug(self, data_dict):
|
372 |
+
save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
|
373 |
+
if not os.path.exists(save_folder):
|
374 |
+
os.mkdir(save_folder)
|
375 |
+
self.cur_number += 1
|
376 |
+
|
377 |
+
# images
|
378 |
+
|
379 |
+
show_images = []
|
380 |
+
|
381 |
+
pixel_values = data_dict['pixel_values']
|
382 |
+
save_folder_image = os.path.join(save_folder, 'image')
|
383 |
+
if not os.path.exists(save_folder_image):
|
384 |
+
os.mkdir(save_folder_image)
|
385 |
+
for i_image, image_pixel_value in enumerate(pixel_values):
|
386 |
+
# print(image_pixel_value.shape)
|
387 |
+
image_pixel_value[0] = image_pixel_value[0] * 0.2686
|
388 |
+
image_pixel_value[1] = image_pixel_value[1] * 0.2613
|
389 |
+
image_pixel_value[2] = image_pixel_value[2] * 0.2757
|
390 |
+
image_pixel_value[0] = image_pixel_value[0] + 0.4814
|
391 |
+
image_pixel_value[1] = image_pixel_value[1] + 0.4578
|
392 |
+
image_pixel_value[2] = image_pixel_value[2] + 0.4082
|
393 |
+
image_pixel_value = image_pixel_value * 255
|
394 |
+
image_pixel_value = image_pixel_value.permute(1, 2, 0)
|
395 |
+
image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
|
396 |
+
# print(os.path.join(save_folder_image, '{}.jpg'.format(i_image)))
|
397 |
+
# print(image_pixel_value.shape)
|
398 |
+
show_images.append(image_pixel_value)
|
399 |
+
cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)
|
400 |
+
|
401 |
+
# text
|
402 |
+
input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
|
403 |
+
with open(os.path.join(save_folder, 'text.json'), 'w') as f:
|
404 |
+
json.dump([input_text], f)
|
405 |
+
|
406 |
+
# masks
|
407 |
+
save_folder_mask = os.path.join(save_folder, 'mask')
|
408 |
+
if not os.path.exists(save_folder_mask):
|
409 |
+
os.mkdir(save_folder_mask)
|
410 |
+
n_frames = len(pixel_values)
|
411 |
+
masks = data_dict['masks']
|
412 |
+
_, h, w = masks.shape
|
413 |
+
masks = masks.reshape(-1, n_frames, h, w)
|
414 |
+
for i_obj, obj_masks in enumerate(masks):
|
415 |
+
save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
|
416 |
+
if not os.path.exists(save_folder_mask_obj_folder):
|
417 |
+
os.mkdir(save_folder_mask_obj_folder)
|
418 |
+
for i_frame, f_mask in enumerate(obj_masks):
|
419 |
+
f_mask = f_mask.numpy()
|
420 |
+
f_mask = f_mask * 255
|
421 |
+
f_mask = np.stack([f_mask * 1, f_mask * 0, f_mask * 0], axis=2)
|
422 |
+
f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
|
423 |
+
f_mask = f_mask.astype(np.uint8)
|
424 |
+
cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
|
425 |
+
return
|
426 |
+
|
427 |
+
def get_video_frames(video_path):
|
428 |
+
cap = cv2.VideoCapture(video_path)
|
429 |
+
|
430 |
+
if not cap.isOpened():
|
431 |
+
print("Error: Cannot open video file.")
|
432 |
+
return
|
433 |
+
|
434 |
+
frames = []
|
435 |
+
|
436 |
+
frame_id = 0
|
437 |
+
while True:
|
438 |
+
ret, frame = cap.read()
|
439 |
+
|
440 |
+
if not ret:
|
441 |
+
break
|
442 |
+
|
443 |
+
frames.append(frame)
|
444 |
+
|
445 |
+
frame_id += 1
|
446 |
+
|
447 |
+
cap.release()
|
448 |
+
return frames
|
449 |
+
|
450 |
+
|
451 |
+
def images_to_video(frames, video_name, fps=6):
|
452 |
+
height, width, layers = frames[0].shape
|
453 |
+
|
454 |
+
fourcc = cv2.VideoWriter_fourcc(*'mp4v')
|
455 |
+
video = cv2.VideoWriter(video_name, fourcc, fps, (width, height))
|
456 |
+
|
457 |
+
for frame in frames:
|
458 |
+
video.write(frame)
|
459 |
+
|
460 |
+
# cv2.destroyAllWindows()
|
461 |
+
video.release()
|
462 |
+
return
|
463 |
+
|
464 |
+
def decode_masklet(masklet):
|
465 |
+
masks = []
|
466 |
+
for _rle in masklet:
|
467 |
+
mask = maskUtils.decode(_rle)
|
468 |
+
masks.append(mask)
|
469 |
+
return masks
|
470 |
+
|
471 |
+
def draw_mask(image, mask):
|
472 |
+
obj_mask = mask * 255
|
473 |
+
obj_mask = np.stack([obj_mask * 1, obj_mask * 0, obj_mask * 0], axis=2)
|
474 |
+
obj_mask = obj_mask * 0.5 + copy.deepcopy(image) * 0.5
|
475 |
+
obj_mask = obj_mask.astype(np.uint8)
|
476 |
+
return obj_mask
|
477 |
+
|
478 |
+
def add_mask2images(frames, masklets):
|
479 |
+
show_videos = []
|
480 |
+
for i_frames, (frame, masks) in enumerate(zip(frames, masklets)):
|
481 |
+
if i_frames == 0:
|
482 |
+
n_obj = masks.shape[-1]
|
483 |
+
for i_obj in range(n_obj):
|
484 |
+
show_videos.append([])
|
485 |
+
|
486 |
+
n_obj = masks.shape[-1]
|
487 |
+
for i_obj in range(n_obj):
|
488 |
+
show_videos[i_obj].append(draw_mask(copy.deepcopy(frame), masks[:, :, i_obj]))
|
489 |
+
return show_videos
|
projects/llava_sam2/datasets/ReVOS_Dataset.py
ADDED
@@ -0,0 +1,602 @@
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
from typing import Literal
|
4 |
+
|
5 |
+
import torch
|
6 |
+
from datasets import Dataset as HFDataset
|
7 |
+
from datasets import DatasetDict
|
8 |
+
from mmengine import print_log
|
9 |
+
from PIL import Image
|
10 |
+
from torch.utils.data import Dataset
|
11 |
+
import numpy as np
|
12 |
+
|
13 |
+
from xtuner.registry import BUILDER
|
14 |
+
from xtuner.dataset.huggingface import build_origin_dataset
|
15 |
+
import copy
|
16 |
+
|
17 |
+
from .encode_fn import video_lisa_encode_fn
|
18 |
+
import json
|
19 |
+
import random
|
20 |
+
import pycocotools.mask as maskUtils
|
21 |
+
import cv2
|
22 |
+
import torchvision.transforms as T
|
23 |
+
from torchvision.transforms.functional import InterpolationMode
|
24 |
+
|
25 |
+
SEG_QUESTIONS = [
|
26 |
+
"Can you segment the {class_name} in this image?",
|
27 |
+
"Please segment {class_name} in this image.",
|
28 |
+
"What is {class_name} in this image? Please respond with segmentation mask.",
|
29 |
+
"What is {class_name} in this image? Please output segmentation mask.",
|
30 |
+
|
31 |
+
"Can you segment the {class_name} in this image",
|
32 |
+
"Please segment {class_name} in this image",
|
33 |
+
"What is {class_name} in this image? Please respond with segmentation mask",
|
34 |
+
"What is {class_name} in this image? Please output segmentation mask",
|
35 |
+
|
36 |
+
"Could you provide a segmentation mask for the {class_name} in this image?",
|
37 |
+
"Please identify and segment the {class_name} in this image.",
|
38 |
+
"Where is the {class_name} in this picture? Please respond with a segmentation mask.",
|
39 |
+
"Can you highlight the {class_name} in this image with a segmentation mask?",
|
40 |
+
|
41 |
+
"Could you provide a segmentation mask for the {class_name} in this image",
|
42 |
+
"Please identify and segment the {class_name} in this image",
|
43 |
+
"Where is the {class_name} in this picture? Please respond with a segmentation mask",
|
44 |
+
"Can you highlight the {class_name} in this image with a segmentation mask",
|
45 |
+
]
|
46 |
+
|
47 |
+
ANSWER_LIST = [
|
48 |
+
"It is [SEG].",
|
49 |
+
"Sure, [SEG].",
|
50 |
+
"Sure, it is [SEG].",
|
51 |
+
"Sure, the segmentation result is [SEG].",
|
52 |
+
"[SEG].",
|
53 |
+
]
|
54 |
+
|
55 |
+
class VideoReVOSDataset(Dataset):
|
56 |
+
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
57 |
+
IMAGENET_STD = (0.229, 0.224, 0.225)
|
58 |
+
IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
|
59 |
+
IMG_START_TOKEN = '<img>'
|
60 |
+
IMG_END_TOKEN = '</img>'
|
61 |
+
|
62 |
+
FAST_IMG_CONTEXT_TOKEN = '<FAST_IMG_CONTEXT>'
|
63 |
+
FAST_IMG_START_TOKEN = '<fast_img>'
|
64 |
+
FAST_IMG_END_TOKEN = '</fast_img>'
|
65 |
+
|
66 |
+
def __init__(self,
|
67 |
+
image_folder,
|
68 |
+
expression_file,
|
69 |
+
mask_file,
|
70 |
+
extra_image_processor=None,
|
71 |
+
tokenizer=None,
|
72 |
+
select_number=5,
|
73 |
+
sampled_frames=10,
|
74 |
+
offline_processed_text_folder=None,
|
75 |
+
template_map_fn=None,
|
76 |
+
max_length=2048,
|
77 |
+
lazy=True,
|
78 |
+
repeats=1,
|
79 |
+
special_tokens=None,
|
80 |
+
frame_contiguous_sample=False,
|
81 |
+
use_fast=False,
|
82 |
+
arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
|
83 |
+
preprocessor=None,
|
84 |
+
# only work if use_fast = True
|
85 |
+
n_fast_images=50,
|
86 |
+
fast_pool_size=4,
|
87 |
+
fast_token_after_question=False,
|
88 |
+
):
|
89 |
+
assert lazy is True
|
90 |
+
self.tokenizer = BUILDER.build(tokenizer)
|
91 |
+
self.select_number = select_number
|
92 |
+
self.sampled_frames = sampled_frames
|
93 |
+
assert offline_processed_text_folder or (expression_file and tokenizer)
|
94 |
+
self.lazy = lazy
|
95 |
+
|
96 |
+
self.max_length = max_length
|
97 |
+
|
98 |
+
self.template_map_fn = template_map_fn
|
99 |
+
if isinstance(self.template_map_fn, dict) and self.lazy:
|
100 |
+
_type = self.template_map_fn['type']
|
101 |
+
del self.template_map_fn['type']
|
102 |
+
self.template_map_fn = _type(**self.template_map_fn)
|
103 |
+
|
104 |
+
if offline_processed_text_folder and expression_file:
|
105 |
+
print_log(
|
106 |
+
'Both `offline_processed_text_folder` and '
|
107 |
+
'`expression_file` are set, and we load dataset from '
|
108 |
+
'`offline_processed_text_folder` '
|
109 |
+
f'({offline_processed_text_folder})',
|
110 |
+
logger='current',
|
111 |
+
level=logging.WARNING)
|
112 |
+
|
113 |
+
self.arch_type = arch_type
|
114 |
+
if self.arch_type == 'qwen':
|
115 |
+
self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
|
116 |
+
self.IMG_START_TOKEN = '<|vision_start|>'
|
117 |
+
self.IMG_END_TOKEN = '<|vision_end|>'
|
118 |
+
elif self.arch_type == 'llava':
|
119 |
+
self.IMG_CONTEXT_TOKEN = '<image>'
|
120 |
+
self.IMG_START_TOKEN = ''
|
121 |
+
self.IMG_END_TOKEN = ''
|
122 |
+
|
123 |
+
|
124 |
+
if offline_processed_text_folder is not None:
|
125 |
+
raise NotImplementedError
|
126 |
+
else:
|
127 |
+
vid2metaid, metas, mask_dict = self.json_file_preprocess(expression_file, mask_file)
|
128 |
+
self.vid2metaid = vid2metaid
|
129 |
+
self.videos = list(self.vid2metaid.keys())
|
130 |
+
self.mask_dict = mask_dict
|
131 |
+
self.json_datas = metas
|
132 |
+
json_datas = metas
|
133 |
+
json_data = DatasetDict({'train': HFDataset.from_list(json_datas)})
|
134 |
+
if self.lazy:
|
135 |
+
self.text_data = build_origin_dataset(json_data, 'train')
|
136 |
+
else:
|
137 |
+
raise NotImplementedError
|
138 |
+
|
139 |
+
self.image_folder = image_folder
|
140 |
+
if extra_image_processor is not None:
|
141 |
+
self.extra_image_processor = BUILDER.build(extra_image_processor)
|
142 |
+
self.down_ratio = 1
|
143 |
+
self.repeats = repeats
|
144 |
+
|
145 |
+
self._system = ''
|
146 |
+
|
147 |
+
self.downsample_ratio = 0.5
|
148 |
+
if self.arch_type == 'llava':
|
149 |
+
self.downsample_ratio = 1
|
150 |
+
self.image_size = 448
|
151 |
+
if self.arch_type == 'llava':
|
152 |
+
self.image_size = 336
|
153 |
+
patch_size = 14
|
154 |
+
self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
|
155 |
+
if self.arch_type == 'qwen':
|
156 |
+
self.patch_token = 1
|
157 |
+
|
158 |
+
if preprocessor is None:
|
159 |
+
self.transformer = T.Compose([
|
160 |
+
T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
|
161 |
+
T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
|
162 |
+
T.ToTensor(),
|
163 |
+
T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
|
164 |
+
])
|
165 |
+
self.preprocessor = None
|
166 |
+
else:
|
167 |
+
self.transformer = None
|
168 |
+
self.preprocessor = BUILDER.build(preprocessor)
|
169 |
+
|
170 |
+
if special_tokens is not None:
|
171 |
+
self.tokenizer.add_tokens(special_tokens, special_tokens=True)
|
172 |
+
|
173 |
+
self.use_fast = use_fast
|
174 |
+
self.n_fast_images = n_fast_images
|
175 |
+
self.fast_pool_size = fast_pool_size
|
176 |
+
|
177 |
+
self.frame_contiguous_sample = frame_contiguous_sample
|
178 |
+
|
179 |
+
# for visualization debug
|
180 |
+
self.save_folder = './work_dirs/video_debug/'
|
181 |
+
self.cur_number = 0
|
182 |
+
|
183 |
+
# exist_thr
|
184 |
+
self.exist_thr = 8
|
185 |
+
self.fast_token_after_question = fast_token_after_question
|
186 |
+
if self.fast_token_after_question:
|
187 |
+
assert self.use_fast
|
188 |
+
|
189 |
+
print("Video res dataset, include {} items.".format(len(self.vid2metaid)))
|
190 |
+
|
191 |
+
def __len__(self):
|
192 |
+
return len(self.vid2metaid) * self.repeats
|
193 |
+
|
194 |
+
@property
|
195 |
+
def modality_length(self):
|
196 |
+
length_list = []
|
197 |
+
for data_dict in self.vid2metaid:
|
198 |
+
cur_len = 10000
|
199 |
+
length_list.append(cur_len)
|
200 |
+
return length_list
|
201 |
+
|
202 |
+
def real_len(self):
|
203 |
+
return len(self.vid2metaid)
|
204 |
+
|
205 |
+
def json_file_preprocess(self, expression_file, mask_file):
|
206 |
+
# prepare expression annotation files
|
207 |
+
with open(expression_file, 'r') as f:
|
208 |
+
expression_datas = json.load(f)['videos']
|
209 |
+
|
210 |
+
metas = []
|
211 |
+
anno_count = 0 # serves as the anno_id
|
212 |
+
vid2metaid = {}
|
213 |
+
for vid_name in expression_datas:
|
214 |
+
vid_express_data = expression_datas[vid_name]
|
215 |
+
|
216 |
+
vid_frames = sorted(vid_express_data['frames'])
|
217 |
+
vid_len = len(vid_frames)
|
218 |
+
|
219 |
+
exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
|
220 |
+
for exp_id in exp_id_list:
|
221 |
+
exp_dict = vid_express_data['expressions'][exp_id]
|
222 |
+
meta = {}
|
223 |
+
meta['video'] = vid_name
|
224 |
+
meta['exp'] = exp_dict['exp'] # str
|
225 |
+
meta['mask_anno_id'] = exp_dict['anno_id']
|
226 |
+
|
227 |
+
if 'obj_id' in exp_dict.keys():
|
228 |
+
meta['obj_id'] = exp_dict['obj_id']
|
229 |
+
else:
|
230 |
+
meta['obj_id'] = [0, ] # Ref-Youtube-VOS only has one object per expression
|
231 |
+
meta['anno_id'] = [str(anno_count), ]
|
232 |
+
anno_count += 1
|
233 |
+
meta['frames'] = vid_frames
|
234 |
+
meta['exp_id'] = exp_id
|
235 |
+
|
236 |
+
meta['length'] = vid_len
|
237 |
+
metas.append(meta)
|
238 |
+
if vid_name not in vid2metaid.keys():
|
239 |
+
vid2metaid[vid_name] = []
|
240 |
+
vid2metaid[vid_name].append(len(metas) - 1)
|
241 |
+
|
242 |
+
# process mask annotation files
|
243 |
+
with open(mask_file, 'rb') as f:
|
244 |
+
mask_dict = json.load(f)
|
245 |
+
|
246 |
+
return vid2metaid, metas, mask_dict
|
247 |
+
|
248 |
+
def create_img_to_refs_mapping(self, refs_train):
|
249 |
+
img2refs = {}
|
250 |
+
for ref in refs_train:
|
251 |
+
img2refs[ref["image_id"]] = img2refs.get(ref["image_id"], []) + [ref, ]
|
252 |
+
return img2refs
|
253 |
+
|
254 |
+
def decode_mask(self, video_masks, image_size):
|
255 |
+
ret_masks = []
|
256 |
+
for object_masks in video_masks:
|
257 |
+
# None object
|
258 |
+
if len(object_masks) == 0:
|
259 |
+
if len(ret_masks) != 0:
|
260 |
+
_object_masks = ret_masks[0] * 0
|
261 |
+
else:
|
262 |
+
_object_masks = np.zeros(
|
263 |
+
(self.sampled_frames, image_size[0], image_size[1]), dtype=np.uint8)
|
264 |
+
else:
|
265 |
+
_object_masks = []
|
266 |
+
for i_frame in range(len(object_masks[0])):
|
267 |
+
_mask = np.zeros(image_size, dtype=np.uint8)
|
268 |
+
for i_anno in range(len(object_masks)):
|
269 |
+
if object_masks[i_anno][i_frame] is None:
|
270 |
+
continue
|
271 |
+
m = maskUtils.decode(object_masks[i_anno][i_frame])
|
272 |
+
if m.ndim == 3:
|
273 |
+
m = m.sum(axis=2).astype(np.uint8)
|
274 |
+
else:
|
275 |
+
m = m.astype(np.uint8)
|
276 |
+
_mask = _mask | m
|
277 |
+
_object_masks.append(_mask)
|
278 |
+
_object_masks = np.stack(_object_masks, axis=0)
|
279 |
+
# if self.pad_image_to_square:
|
280 |
+
# _object_masks = expand2square_mask(_object_masks)
|
281 |
+
ret_masks.append(_object_masks)
|
282 |
+
_shape = ret_masks[0].shape
|
283 |
+
for item in ret_masks:
|
284 |
+
if item.shape != _shape:
|
285 |
+
print([_ret_mask.shape for _ret_mask in ret_masks])
|
286 |
+
return None
|
287 |
+
ret_masks = np.stack(ret_masks, axis=0) # (n_obj, n_frames, h, w)
|
288 |
+
|
289 |
+
ret_masks = torch.from_numpy(ret_masks)
|
290 |
+
# ret_masks = F.interpolate(ret_masks, size=(self.image_size // self.down_ratio,
|
291 |
+
# self.image_size // self.down_ratio), mode='nearest')
|
292 |
+
ret_masks = ret_masks.flatten(0, 1)
|
293 |
+
return ret_masks
|
294 |
+
|
295 |
+
def dataset_map_fn(self, data_dict, select_k=5):
|
296 |
+
images = []
|
297 |
+
|
298 |
+
len_frames = len(data_dict[0]['frames'])
|
299 |
+
for object_info in data_dict:
|
300 |
+
assert len_frames == len(object_info['frames'])
|
301 |
+
|
302 |
+
# prepare images, random select k frames
|
303 |
+
if len_frames > select_k + 1:
|
304 |
+
if self.frame_contiguous_sample and random.random() < 0.5:
|
305 |
+
# do contiguous sample
|
306 |
+
selected_start_frame = np.random.choice(len_frames - select_k, 1, replace=False)
|
307 |
+
selected_frame_indexes = [selected_start_frame[0] + _i for _i in range(select_k)]
|
308 |
+
else:
|
309 |
+
selected_frame_indexes = np.random.choice(len_frames, select_k, replace=False)
|
310 |
+
else:
|
311 |
+
selected_frame_indexes = np.random.choice(len_frames, select_k, replace=True)
|
312 |
+
selected_frame_indexes.sort()
|
313 |
+
|
314 |
+
if self.use_fast:
|
315 |
+
# sample fast branch
|
316 |
+
fast_interval = len_frames / (self.n_fast_images + 1e-4)
|
317 |
+
sampled_fast_frame_idxs = [min(int(i * fast_interval), len_frames - 1) for i in range(self.n_fast_images)]
|
318 |
+
fast_video_frames = []
|
319 |
+
for selected_frame_index in sampled_fast_frame_idxs:
|
320 |
+
frame_id = data_dict[0]['frames'][selected_frame_index]
|
321 |
+
fast_video_frames.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))
|
322 |
+
else:
|
323 |
+
fast_video_frames = None
|
324 |
+
sampled_fast_frame_idxs = None
|
325 |
+
|
326 |
+
for selected_frame_index in selected_frame_indexes:
|
327 |
+
frame_id = data_dict[0]['frames'][selected_frame_index]
|
328 |
+
images.append(os.path.join(data_dict[0]['video'], frame_id + '.jpg'))
|
329 |
+
|
330 |
+
# prepare text
|
331 |
+
expressions = [object_info['exp'] for object_info in data_dict]
|
332 |
+
if self.use_fast:
|
333 |
+
text_dict = self.prepare_text(select_k, expressions, num_image_tokens=self.patch_token,
|
334 |
+
n_fast_images=len(fast_video_frames),)
|
335 |
+
else:
|
336 |
+
text_dict = self.prepare_text(select_k, expressions, num_image_tokens=self.patch_token)
|
337 |
+
|
338 |
+
|
339 |
+
# prepare masks
|
340 |
+
video_masks = []
|
341 |
+
for object_info in data_dict:
|
342 |
+
anno_ids = object_info['mask_anno_id']
|
343 |
+
# print('anno_ids: ', anno_ids)
|
344 |
+
obj_masks = []
|
345 |
+
for anno_id in anno_ids:
|
346 |
+
anno_id = str(anno_id)
|
347 |
+
frames_masks = self.mask_dict[anno_id]
|
348 |
+
frames_masks_ = []
|
349 |
+
for frame_idx in selected_frame_indexes:
|
350 |
+
frames_masks_.append(copy.deepcopy(frames_masks[frame_idx]))
|
351 |
+
obj_masks.append(frames_masks_)
|
352 |
+
video_masks.append(obj_masks)
|
353 |
+
|
354 |
+
if self.use_fast:
|
355 |
+
fast_video_masks = []
|
356 |
+
assert sampled_fast_frame_idxs is not None
|
357 |
+
for object_info in data_dict:
|
358 |
+
anno_ids = object_info['mask_anno_id']
|
359 |
+
obj_masks = []
|
360 |
+
for anno_id in anno_ids:
|
361 |
+
anno_id = str(anno_id)
|
362 |
+
frames_masks = self.mask_dict[anno_id]
|
363 |
+
frames_masks_ = []
|
364 |
+
for frame_idx in sampled_fast_frame_idxs:
|
365 |
+
frames_masks_.append(copy.deepcopy(frames_masks[frame_idx]))
|
366 |
+
obj_masks.append(frames_masks_)
|
367 |
+
fast_video_masks.append(obj_masks)
|
368 |
+
else:
|
369 |
+
fast_video_masks = None
|
370 |
+
|
371 |
+
ret = {'images': images, 'video_masks': video_masks, 'conversation': text_dict['conversation'],
|
372 |
+
'fast_images': fast_video_frames, 'fast_video_masks': fast_video_masks}
|
373 |
+
return ret
|
374 |
+
|
375 |
+
def prepare_text(self, n_frames, expressions, num_image_tokens=256, n_fast_images=50):
|
376 |
+
|
377 |
+
if self.use_fast and not self.fast_token_after_question:
|
378 |
+
fast_frame_token_str = f'{self.FAST_IMG_START_TOKEN}' \
|
379 |
+
f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
|
380 |
+
f'{self.FAST_IMG_END_TOKEN}' + '\n'
|
381 |
+
else:
|
382 |
+
fast_frame_token_str = ''
|
383 |
+
|
384 |
+
frame_token_str = f'{self.IMG_START_TOKEN}' \
|
385 |
+
f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
|
386 |
+
f'{self.IMG_END_TOKEN}'
|
387 |
+
if self.fast_token_after_question:
|
388 |
+
assert self.use_fast
|
389 |
+
after_question_str = f'{self.FAST_IMG_START_TOKEN}' \
|
390 |
+
f'{self.FAST_IMG_CONTEXT_TOKEN * n_fast_images * self.fast_pool_size * self.fast_pool_size}' \
|
391 |
+
f'{self.FAST_IMG_END_TOKEN}'
|
392 |
+
else:
|
393 |
+
after_question_str = ''
|
394 |
+
|
395 |
+
questions = []
|
396 |
+
answers = []
|
397 |
+
for i, exp in enumerate(expressions):
|
398 |
+
# the exp is a question
|
399 |
+
if '?' in exp:
|
400 |
+
questions.append(exp)
|
401 |
+
else:
|
402 |
+
exp = exp.replace('.', '').strip()
|
403 |
+
question_template = random.choice(SEG_QUESTIONS)
|
404 |
+
questions.append(question_template.format(class_name=exp.lower()))
|
405 |
+
|
406 |
+
answers.append(random.choice(ANSWER_LIST))
|
407 |
+
qa_list = []
|
408 |
+
for i, (question, answer) in enumerate(zip(questions, answers)):
|
409 |
+
if i == 0:
|
410 |
+
frame_tokens = frame_token_str + '\n'
|
411 |
+
# frame_tokens = '=' + ' '
|
412 |
+
frame_tokens = frame_tokens * n_frames
|
413 |
+
frame_tokens = frame_tokens.strip()
|
414 |
+
frame_tokens = fast_frame_token_str + frame_tokens
|
415 |
+
qa_list.append(
|
416 |
+
                    {'from': 'human', 'value': frame_tokens + question + after_question_str}
                )
            else:
                qa_list.append(
                    {'from': 'human', 'value': question + after_question_str}
                )
            qa_list.append(
                {'from': 'gpt', 'value': answer}
            )

        input = ''
        conversation = []
        for msg in qa_list:
            if msg['from'] == 'human':
                input += msg['value']
            elif msg['from'] == 'gpt':
                conversation.append({'input': input, 'output': msg['value']})
                input = ''
            else:
                raise NotImplementedError

        # add system information
        conversation[0].update({'system': self._system})
        return {'conversation': conversation}

    def __getitem__(self, index):
        index = index % self.real_len()
        selected_video_objects = self.vid2metaid[self.videos[index]]
        video_objects_infos = [copy.deepcopy(self.text_data[idx]) for idx in selected_video_objects]

        if len(video_objects_infos) > self.select_number:
            selected_indexes = np.random.choice(len(video_objects_infos), self.select_number)
            video_objects_infos = [video_objects_infos[_idx] for _idx in selected_indexes]
        else:
            selected_indexes = np.random.choice(len(video_objects_infos), self.select_number, replace=True)
            video_objects_infos = [video_objects_infos[_idx] for _idx in selected_indexes]

        data_dict = self.dataset_map_fn(video_objects_infos, select_k=self.sampled_frames)

        assert 'images' in data_dict.keys()
        pixel_values = []
        extra_pixel_values = []
        num_video_tokens = None
        num_frame_tokens = None
        if data_dict.get('images', None) is not None:
            frames_files = data_dict['images']
            frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
            for frame_path in frames_files:
                frame_image = Image.open(frame_path).convert('RGB')
                ori_width, ori_height = frame_image.size
                if self.extra_image_processor is not None:
                    g_image = np.array(frame_image)  # for grounding
                    g_image = self.extra_image_processor.apply_image(g_image)
                    g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
                    extra_pixel_values.append(g_pixel_values)

                if self.preprocessor is not None:
                    pass
                else:
                    frame_image = self.transformer(frame_image)
                    pixel_values.append(frame_image)

            if self.preprocessor is not None:
                if self.arch_type == 'qwen':
                    _data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
                    num_frame_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
                    num_frames = _data_dict['image_grid_thw'].shape[0]
                    num_video_tokens = num_frame_tokens * num_frames
                elif self.arch_type == 'llava':
                    _data_dict = self.preprocessor(pixel_values, do_resize=True, size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                else:
                    raise NotImplementedError
                data_dict.update(_data_dict)
            else:
                pixel_values = torch.stack(pixel_values, dim=0)  # (n_f, 3, h, w)
                data_dict['pixel_values'] = pixel_values
            if self.extra_image_processor is not None:
                data_dict['g_pixel_values'] = extra_pixel_values

            # process and get masks
            masks = self.decode_mask(data_dict['video_masks'], image_size=(ori_height, ori_width))
            if masks is None:
                return self.__getitem__(random.randint(0, self.real_len()))
            data_dict['masks'] = masks
        else:
            data_dict['pixel_values'] = torch.zeros(0, 3, self.image_size, self.image_size)
            data_dict['masks'] = None

        if num_video_tokens is not None:
            assert self.patch_token == 1
            input_str = data_dict['conversation'][0]['input']
            input_str = input_str.replace(self.IMG_CONTEXT_TOKEN, self.IMG_CONTEXT_TOKEN * num_frame_tokens)
            assert input_str.count(self.IMG_CONTEXT_TOKEN) == num_video_tokens
            data_dict['conversation'][0]['input'] = input_str

        result = self.template_map_fn(data_dict)
        data_dict.update(result)
        result = video_lisa_encode_fn(data_dict, tokenizer=self.tokenizer, max_length=self.max_length)
        data_dict.update(result)

        # for fast branch
        if self.use_fast:
            fast_pixel_values = []
            frames_files = data_dict['fast_images']
            frames_files = [os.path.join(self.image_folder, frame_file) for frame_file in frames_files]
            for frame_path in frames_files:
                frame_image = Image.open(frame_path).convert('RGB')
                ori_width, ori_height = frame_image.size

                frame_image = self.transformer(frame_image)
                fast_pixel_values.append(frame_image)

            fast_pixel_values = torch.stack(fast_pixel_values, dim=0)  # (n_f, 3, h, w)
            data_dict['fast_pixel_values'] = fast_pixel_values

            # process and get masks
            masks = self.decode_mask(data_dict['fast_video_masks'], image_size=(ori_height, ori_width))

            if masks is None:
                return self.__getitem__(random.randint(0, self.real_len()))

            # per-frame, per-object visibility flag: does the mask cover at least exist_thr pixels?
            data_dict['fast_exists'] = masks.to(dtype=torch.int).sum(dim=(-2, -1)).ge(self.exist_thr).unsqueeze(-1)

            del data_dict['fast_video_masks']
        data_dict['type'] = 'video'
        return data_dict

    def visualization_debug(self, data_dict):
        save_folder = os.path.join(self.save_folder, 'sample_{}'.format(self.cur_number))
        if not os.path.exists(save_folder):
            os.mkdir(save_folder)
        self.cur_number += 1

        # images
        show_images = []

        pixel_values = data_dict['pixel_values']
        save_folder_image = os.path.join(save_folder, 'image')
        if not os.path.exists(save_folder_image):
            os.mkdir(save_folder_image)
        for i_image, image_pixel_value in enumerate(pixel_values):
            # undo normalization (x * std + mean) so the frame can be saved as an image
            image_pixel_value[0] = image_pixel_value[0] * 0.2686
            image_pixel_value[1] = image_pixel_value[1] * 0.2613
            image_pixel_value[2] = image_pixel_value[2] * 0.2757
            image_pixel_value[0] = image_pixel_value[0] + 0.4814
            image_pixel_value[1] = image_pixel_value[1] + 0.4578
            image_pixel_value[2] = image_pixel_value[2] + 0.4082
            image_pixel_value = image_pixel_value * 255
            image_pixel_value = image_pixel_value.permute(1, 2, 0)
            image_pixel_value = image_pixel_value.to(torch.uint8).numpy()
            show_images.append(image_pixel_value)
            cv2.imwrite(os.path.join(save_folder_image, '{}.jpg'.format(i_image)), image_pixel_value)

        # text
        input_text = self.tokenizer.decode(data_dict['input_ids'], skip_special_tokens=False)
        with open(os.path.join(save_folder, 'text.json'), 'w') as f:
            json.dump([input_text], f)

        # masks
        save_folder_mask = os.path.join(save_folder, 'mask')
        if not os.path.exists(save_folder_mask):
            os.mkdir(save_folder_mask)
        n_frames = len(pixel_values)
        masks = data_dict['masks']
        _, h, w = masks.shape
        masks = masks.reshape(-1, n_frames, h, w)
        for i_obj, obj_masks in enumerate(masks):
            save_folder_mask_obj_folder = os.path.join(save_folder_mask, 'obj_{}'.format(i_obj))
            if not os.path.exists(save_folder_mask_obj_folder):
                os.mkdir(save_folder_mask_obj_folder)
            for i_frame, f_mask in enumerate(obj_masks):
                f_mask = f_mask.numpy()
                f_mask = f_mask * 255
                f_mask = np.stack([f_mask * 1, f_mask * 0, f_mask * 0], axis=2)
                f_mask = show_images[i_frame] * 0.3 + 0.7 * f_mask
                f_mask = f_mask.astype(np.uint8)
                cv2.imwrite(os.path.join(save_folder_mask_obj_folder, '{}.png'.format(i_frame)), f_mask)
        return
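For reference, the `fast_exists` flag above reduces each object's per-frame mask to a single visibility bit by thresholding its pixel area. A minimal self-contained sketch of just that reduction; the tensor shapes and the `exist_thr` value are illustrative assumptions, not taken from the training config:

import torch

# masks: (num_objects * num_frames, H, W) binary masks, the assumed output layout of decode_mask
masks = torch.zeros(4, 8, 8)
masks[0, :2, :2] = 1          # 4 foreground pixels
masks[1, :4, :4] = 1          # 16 foreground pixels
exist_thr = 8                 # illustrative pixel-count threshold

# pixel count per mask -> boolean "object visible in this frame", kept with a trailing dim of 1
fast_exists = masks.to(dtype=torch.int).sum(dim=(-2, -1)).ge(exist_thr).unsqueeze(-1)
print(fast_exists.shape)                      # torch.Size([4, 1])
print(fast_exists.squeeze(-1).tolist())       # [False, True, False, False]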
projects/llava_sam2/datasets/RefCOCO_Dataset.py
ADDED
@@ -0,0 +1,338 @@
import copy
import random
import glob
import json
import logging
import os
from typing import Literal

import torch

from mmengine import print_log
from mmengine.config import Config, ConfigDict
from PIL import Image
from torch.utils.data import Dataset
import numpy as np
import torch.nn.functional as F
import torchvision.transforms as T
from torchvision.transforms.functional import InterpolationMode
from pycocotools.coco import COCO
from pycocotools import mask as mask_utils

from xtuner.registry import BUILDER
from xtuner.utils import IGNORE_INDEX
from xtuner.dataset.utils import encode_fn
from xtuner.dataset.map_fns import llava_map_fn

from projects.glamm.datasets.utils.utils import expand2square

from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN

from third_parts.mmdet.datasets.refcoco import RefCocoDataset

from .utils import dynamic_preprocess


class ReferSegmDataset(RefCocoDataset):
    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
    IMG_START_TOKEN = '<img>'
    IMG_END_TOKEN = '</img>'

    IMAGENET_MEAN = (0.485, 0.456, 0.406)
    IMAGENET_STD = (0.229, 0.224, 0.225)

    def __init__(self,
                 data_root,
                 ann_file=None,
                 split_file=None,
                 special_tokens=None,
                 prompt_template=None,
                 extra_image_processor=None,
                 data_prefix=dict(img_path='train2014/'),
                 tokenizer=None,
                 max_length=2048,
                 num_classes_per_sample=3,
                 single_image_mode=False,
                 arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
                 preprocessor=None,
                 **kwargs):
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            pipeline=None,
            ann_file=ann_file,
            split_file=split_file,
            **kwargs,
        )
        self.begin_str = f'{DEFAULT_IMAGE_TOKEN}\n'
        if extra_image_processor is not None:
            self.extra_image_processor = BUILDER.build(extra_image_processor)

        self.arch_type = arch_type
        if self.arch_type == 'qwen':
            self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
            self.IMG_START_TOKEN = '<|vision_start|>'
            self.IMG_END_TOKEN = '<|vision_end|>'
        elif self.arch_type == 'llava':
            self.IMG_CONTEXT_TOKEN = '<image>'
            self.IMG_START_TOKEN = ''
            self.IMG_END_TOKEN = ''

        self.tokenizer = BUILDER.build(tokenizer)
        if special_tokens is not None:
            self.tokenizer.add_tokens(special_tokens, special_tokens=True)

        self.image_folder = data_root
        self.template = prompt_template
        self.max_length = max_length
        if self.arch_type == 'intern_vl':
            # default InternVL system prompt (translated): 'You are InternVL, a multimodal large
            # model developed jointly by Shanghai AI Laboratory and SenseTime; a helpful and
            # harmless AI assistant.' It is disabled here and replaced with an empty system prompt.
            self._system = ''
            self.template['INSTRUCTION'] = '<|user|>\n{input}<|end|><|assistant|>\n'
        elif self.arch_type == 'qwen':
            self._system = ''
        elif self.arch_type == 'llava':
            self._system = ''

        self.num_classes_per_sample = num_classes_per_sample
        self.min_dynamic_patch = 1
        self.max_dynamic_patch = 12
        self.downsample_ratio = 0.5
        if self.arch_type == 'llava':
            self.downsample_ratio = 1
        self.image_size = 448
        if self.arch_type == 'llava':
            self.image_size = 336
        self.use_thumbnail = True
        patch_size = 14
        self.patch_token = int((self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))

        if preprocessor is None:
            self.transformer = T.Compose([
                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
                T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
                T.ToTensor(),
                T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
            ])
            self.preprocessor = None
        else:
            self.transformer = None
            self.preprocessor = BUILDER.build(preprocessor)
        self.arch_type = arch_type
        self.single_image_mode = single_image_mode
        self._max_refetch = 1000

        print("Image RES dataset, include {} items.".format(len(self)))

    @property
    def modality_length(self):
        length_list = []
        for idx in range(len(self)):
            length_list.append(100)
        return length_list

    def _parse_annotations(self, ann_info):
        image_path = ann_info['img_path']
        image = Image.open(image_path).convert('RGB')
        width, height = image.size

        masks, phrases = [], []
        instances, text = ann_info['instances'], ann_info['text']
        index = np.random.choice(range(len(instances)), self.num_classes_per_sample, replace=True)
        for idx in index:
            inst = instances[idx]
            phrase = text[idx].lower()
            if '.' == phrase[-1]:
                phrase = phrase[:-1]
            phrases.append(phrase)
            binary_mask = np.zeros((height, width), dtype=np.uint8)
            for seg in inst["mask"]:
                rles = mask_utils.frPyObjects([seg], height, width)
                m = mask_utils.decode(rles)
                m = m.astype(np.uint8)
                binary_mask += m.squeeze()
            masks.append(binary_mask)

        conversation = []
        for i, phrase in enumerate(phrases):
            question = random.choice(SEG_QUESTIONS).format(class_name=phrase)
            if i == 0:
                question = self.begin_str + question
            conversation.append({'from': 'human', 'value': question})
            conversation.append({'from': 'gpt', 'value': random.choice(ANSWER_LIST)})
        masks = torch.stack([torch.from_numpy(mask) for mask in masks], dim=0)

        ann_info.update({
            'masks': masks,
            'conversations': conversation,
            'image': image_path
        })
        return ann_info

    def prepare_data(self, index):
        data_dict = super().prepare_data(index)
        data_dict = self._parse_annotations(data_dict)
        if data_dict is None:
            return None

        out_data_dict = {}
        if 'masks' in data_dict:
            out_data_dict['masks'] = data_dict['masks']

        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            try:
                image = Image.open(image_file).convert('RGB')
            except Exception as e:
                print(f'Error: {e}', flush=True)
                print_log(f'Error: {e}', logger='current')
                return None
            if hasattr(self, 'extra_image_processor'):
                g_image = np.array(image)  # for grounding
                g_image = self.extra_image_processor.apply_image(g_image)
                g_pixel_values = torch.from_numpy(g_image).permute(2, 0, 1).contiguous()
                out_data_dict['g_pixel_values'] = g_pixel_values

            if self.single_image_mode:
                images = [image]
            else:
                images = dynamic_preprocess(image, self.min_dynamic_patch,
                                            self.max_dynamic_patch,
                                            self.image_size, self.use_thumbnail)
            if self.preprocessor is not None:
                if self.arch_type == 'qwen':
                    _data_dict = self.preprocessor(images, do_resize=True)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
                    num_image_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
                elif self.arch_type == 'llava':
                    _data_dict = self.preprocessor(images, do_resize=True, size=(self.image_size, self.image_size))
                    _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
                    num_image_tokens = _data_dict['pixel_values'].shape[0] * self.patch_token
                else:
                    raise NotImplementedError
                out_data_dict.update(_data_dict)
            else:
                pixel_values = [self.transformer(image) for image in images]
                pixel_values = torch.stack(pixel_values)
                out_data_dict['pixel_values'] = pixel_values

                num_image_tokens = pixel_values.shape[0] * self.patch_token
            image_token_str = f'{self.IMG_START_TOKEN}' \
                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
                              f'{self.IMG_END_TOKEN}'
            token_dict = self.get_inputid_labels(data_dict['conversations'], image_token_str)
            out_data_dict.update(token_dict)
        else:
            token_dict = self.get_inputid_labels(data_dict['conversations'], None)
            out_data_dict.update(token_dict)
            out_data_dict['pixel_values'] = torch.zeros(1, 3, self.image_size, self.image_size)
        return out_data_dict

    def get_inputid_labels(self, conversations, image_token_str) -> dict:
        input = ''
        out_conversation = []
        while conversations and conversations[0]['from'] == 'gpt':
            # Skip the first one if it is from gpt
            conversations = conversations[1:]
        for msg in conversations:
            if msg['from'] == 'human':
                if image_token_str is None and '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>', '')
                if '<image>' in msg['value']:
                    msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
                input += msg['value'].strip()
            elif msg['from'] == 'gpt':
                out_conversation.append({
                    'input': input,
                    'output': msg['value'].strip()
                })
                input = ''
            else:
                raise NotImplementedError

        input_ids, labels = [], []
        for i, single_turn_conversation in enumerate(out_conversation):
            input = single_turn_conversation.get('input', '')
            if input is None:
                input = ''
            input_text = self.template.INSTRUCTION.format(
                input=input, round=i + 1)

            if i == 0:
                if self._system != '' and self._system is not None:
                    system = self.template.SYSTEM.format(system=self._system)
                    input_text = system + input_text
                input_encode = self.tokenizer.encode(
                    input_text, add_special_tokens=True)
            else:
                input_encode = self.tokenizer.encode(
                    input_text, add_special_tokens=False)
            input_ids += input_encode
            labels += [IGNORE_INDEX] * len(input_encode)

            output_text = single_turn_conversation.get('output', '')
            if self.template.get('SUFFIX', None):
                output_text += self.template.SUFFIX
            output_encode = self.tokenizer.encode(
                output_text, add_special_tokens=False)
            input_ids += output_encode
            labels += copy.deepcopy(output_encode)

        if len(input_ids) > self.max_length:
            input_ids = input_ids[:self.max_length]
            labels = labels[:self.max_length]
        return {'input_ids': input_ids, 'labels': labels}

    def __getitem__(self, index):
        for _ in range(self._max_refetch + 1):
            data = self.prepare_data(index)
            # Broken images may cause the returned data to be None
            if data is None:
                index = self._rand_another()
                continue
            return data


if __name__ == '__main__':
    from transformers import CLIPImageProcessor, AutoTokenizer
    from third_parts.segment_anything.utils.transforms import ResizeLongestSide

    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'

    tokenizer = dict(
        type=AutoTokenizer.from_pretrained,
        pretrained_model_name_or_path=llm_name_or_path)
    image_processor = dict(
        type=CLIPImageProcessor.from_pretrained,
        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
    extra_image_processor = dict(
        type=ResizeLongestSide,
        target_length=1024,
    )
    from xtuner.utils.templates import PROMPT_TEMPLATE

    prompt_template = PROMPT_TEMPLATE.vicuna
    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn

    dataset = ReferSegmDataset(
        tokenizer=tokenizer,
        special_tokens=['[SEG]'],
        extra_image_processor=extra_image_processor,
        prompt_template=prompt_template,
        data_root='data/coco/',
        data_prefix=dict(img_path='train2014/'),
        ann_file='refcoco+/instances.json',
        split_file='refcoco+/refs(unc).p',
    )
    for i in range(1000):
        dataset[i]
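For reference, `_parse_annotations` above merges one or more COCO-style polygon segments into a single binary mask with pycocotools. A small standalone sketch of that conversion; the square polygon and the 32x32 canvas are invented inputs:

import numpy as np
from pycocotools import mask as mask_utils

height, width = 32, 32
# one polygon segment in COCO [x1, y1, x2, y2, ...] order (illustrative square)
seg = [4.0, 4.0, 20.0, 4.0, 20.0, 20.0, 4.0, 20.0]

binary_mask = np.zeros((height, width), dtype=np.uint8)
rles = mask_utils.frPyObjects([seg], height, width)  # polygon -> run-length encoding
m = mask_utils.decode(rles).astype(np.uint8)         # RLE -> (H, W, 1) array
binary_mask += m.squeeze()

print(binary_mask.sum())  # number of foreground pixels covered by the square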
projects/llava_sam2/datasets/RefYoutubeVOS_Dataset.py
ADDED
@@ -0,0 +1,47 @@
from .ReVOS_Dataset import VideoReVOSDataset
import json
import pickle

class VideoRefYoutubeVOSDataset(VideoReVOSDataset):

    def json_file_preprocess(self, expression_file, mask_file):
        # prepare expression annotation files
        with open(expression_file, 'r') as f:
            expression_datas = json.load(f)['videos']

        metas = []
        anno_count = 0  # serve as anno_id
        vid2metaid = {}
        for vid_name in expression_datas:
            vid_express_data = expression_datas[vid_name]

            vid_frames = sorted(vid_express_data['frames'])
            vid_len = len(vid_frames)

            exp_id_list = sorted(list(vid_express_data['expressions'].keys()))
            for exp_id in exp_id_list:
                exp_dict = vid_express_data['expressions'][exp_id]
                meta = {}
                meta['video'] = vid_name
                meta['exp'] = exp_dict['exp']  # str
                meta['mask_anno_id'] = [str(anno_count), ]

                if 'obj_id' in exp_dict.keys():
                    meta['obj_id'] = exp_dict['obj_id']
                else:
                    meta['obj_id'] = [0, ]  # Ref-Youtube-VOS only has one object per expression
                meta['anno_id'] = [str(anno_count), ]
                anno_count += 1
                meta['frames'] = vid_frames
                meta['exp_id'] = exp_id

                meta['length'] = vid_len
                metas.append(meta)
                if vid_name not in vid2metaid.keys():
                    vid2metaid[vid_name] = []
                vid2metaid[vid_name].append(len(metas) - 1)

        # process mask annotation files
        with open(mask_file, 'rb') as f:
            mask_dict = pickle.load(f)
        return vid2metaid, metas, mask_dict
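To show what `json_file_preprocess` produces, here is a toy `meta_expressions`-style dict run through the same grouping logic; the video name, frame ids, and expressions are invented for illustration:

# toy Ref-Youtube-VOS style expression data (invented example)
expression_datas = {
    'video_a': {
        'frames': ['00000', '00005'],
        'expressions': {'0': {'exp': 'the brown dog'}, '1': {'exp': 'a person in red'}},
    }
}

metas, vid2metaid, anno_count = [], {}, 0
for vid_name, vid_data in expression_datas.items():
    for exp_id in sorted(vid_data['expressions']):
        metas.append({'video': vid_name,
                      'exp': vid_data['expressions'][exp_id]['exp'],
                      'anno_id': [str(anno_count)],
                      'frames': sorted(vid_data['frames'])})
        anno_count += 1
        vid2metaid.setdefault(vid_name, []).append(len(metas) - 1)

print(vid2metaid)  # {'video_a': [0, 1]} -> each video maps to the indices of its expression metas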
projects/llava_sam2/datasets/__init__.py
ADDED
@@ -0,0 +1,15 @@
from .collect_fns import video_lisa_collate_fn
from .MeVIS_Dataset import VideoMeVISDataset
from .ReVOS_Dataset import VideoReVOSDataset
from .RefYoutubeVOS_Dataset import VideoRefYoutubeVOSDataset
from .encode_fn import video_lisa_encode_fn
from .RefCOCO_Dataset import ReferSegmDataset
from .ReSAM2_Dataset import VideoSAM2Dataset
from .vqa_dataset import LLaVADataset, InfinityMMDataset

from .GCG_Dataset import GranDfGCGDataset, FlickrGCGDataset, OpenPsgGCGDataset, RefCOCOgGCGDataset
from .Grand_Dataset import GranDDataset

from .Osprey_Dataset import OspreyDataset, OspreyDescriptionDataset, OspreyShortDescriptionDataset

from .ChatUniVi_Dataset import VideoChatUniViDataset
projects/llava_sam2/datasets/collect_fns.py
ADDED
@@ -0,0 +1,206 @@
from typing import Dict, Sequence

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from xtuner.parallel.sequence import (get_sequence_parallel_world_size,
                                      pad_for_sequence_parallel)
from xtuner.utils import DEFAULT_PAD_TOKEN_INDEX, IGNORE_INDEX


def video_lisa_collate_fn(instances: Sequence[Dict],
                          pad_index: int = DEFAULT_PAD_TOKEN_INDEX,
                          return_hf_format: bool = False,
                          use_varlen_attn: bool = False):
    seq_parallel_world_size = get_sequence_parallel_world_size()

    input_ids, labels = [], []
    has_image = any(inst.get('pixel_values') is not None for inst in instances)
    has_pe = any(inst.get('image_grid_thw', None) is not None for inst in instances)
    has_fast_image = any(inst.get('fast_pixel_values', None) is not None for inst in instances)
    has_grounding_image = any(inst.get('g_pixel_values') is not None for inst in instances)
    has_mask = any(inst.get('masks') is not None for inst in instances)
    has_bboxes = any(inst.get('bboxes') is not None for inst in instances)
    has_points = any(inst.get('points') is not None for inst in instances)
    has_fast_exists = any(inst.get('fast_exists') is not None for inst in instances)

    has_vp = any(inst.get('vp_overall_mask') is not None for inst in instances)
    has_prompt_mask = any(inst.get('prompt_masks') is not None for inst in instances)

    if use_varlen_attn:
        position_ids, cumulative_len = [], []
        assert len(instances) == 1, (
            f'If utilizing varlen attention, the batch size should be'
            f' set to 1, but got {len(instances)}')
        assert not has_image, 'Currently, it is not configured to ' \
            'accommodate the use of varlen Attention in multimodal training'

    if has_image:
        pixel_values = []
        frames_per_batch = []
        image_grid_thw = []
    if has_grounding_image:
        grounding_pixel_values = []
    if has_mask:
        object_masks = []
    if has_bboxes:
        object_bboxes = []
    if has_points:
        prompt_points = []
    if has_fast_image:
        fast_pixel_values = []
    if has_fast_exists:
        fast_exists = []
    if has_vp:
        vp_overall_mask = []
    else:
        vp_overall_mask = None

    if has_prompt_mask:
        prompt_masks = []
    else:
        prompt_masks = None

    for example in instances:
        input_ids.append(torch.LongTensor(example['input_ids']))
        labels.append(torch.LongTensor(example['labels']))
        if use_varlen_attn:
            cumulative_len.append(torch.IntTensor(example['cumulative_len']))
            position_ids.append(torch.LongTensor(example['position_ids']))

        if has_image:
            pixel_values.append(example['pixel_values'])
            if has_pe:
                image_grid_thw.append(example['image_grid_thw'])
            if has_vp:
                if 'vp_overall_mask' in example.keys() and example['vp_overall_mask'] is not None:
                    vp_overall_mask.append(example['vp_overall_mask'])
                else:
                    vp_overall_mask.append(torch.Tensor([False] * len(pixel_values[-1])))
        if has_fast_image:
            if 'fast_pixel_values' in example.keys() and example['fast_pixel_values'] is not None:
                fast_pixel_values.append(example['fast_pixel_values'])
        if has_fast_exists:
            if 'fast_exists' in example.keys() and example['fast_exists'] is not None:
                fast_exists.append(example['fast_exists'])
        if has_grounding_image and 'g_pixel_values' in example.keys():
            if isinstance(example['g_pixel_values'], list):
                grounding_pixel_values += example['g_pixel_values']
                frames_per_batch.append(len(example['g_pixel_values']))
            else:
                grounding_pixel_values.append(example['g_pixel_values'])
                frames_per_batch.append(1)

        if has_mask:
            if 'masks' in example.keys() and example['masks'] is not None:
                if isinstance(example['masks'], list):
                    if isinstance(example['masks'][0], np.ndarray):
                        _masks = np.stack(example['masks'], axis=0)
                        _masks = torch.from_numpy(_masks)
                        object_masks.append(_masks)
                    else:
                        object_masks.append(torch.stack(example['masks'], dim=0))
                else:
                    object_masks.append(example['masks'])
        if has_bboxes:
            if 'bboxes' in example.keys() and example['bboxes'] is not None:
                object_bboxes.append(example['bboxes'])
        if has_points:
            if 'points' in example.keys() and example['points'] is not None:
                prompt_points.append(example['points'])

        if has_prompt_mask:
            if 'prompt_masks' in example.keys():
                prompt_masks.append(example['prompt_masks'])

    ori_length = [len(ids) for ids in input_ids]
    if len(instances) > 1:
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=pad_index)
        labels = pad_sequence(
            labels, batch_first=True, padding_value=IGNORE_INDEX)
    else:
        input_ids = torch.stack(input_ids)
        labels = torch.stack(labels)

    if use_varlen_attn:
        assert input_ids.size(1) % seq_parallel_world_size == 0
        attention_mask = None
        position_ids = torch.stack(position_ids, dim=0)
    else:
        # Some tokenizers have the same eos token and pad token, so input_ids
        # cannot be masked directly based on the pad token id.
        attention_mask = torch.zeros_like(input_ids).bool()
        for i, length in enumerate(ori_length):
            attention_mask[i, :length] = True

        bs, seq_len = input_ids.shape
        position_ids = torch.arange(seq_len).unsqueeze(0).long().repeat(bs, 1)

    if seq_parallel_world_size > 1:
        input_ids = pad_for_sequence_parallel(input_ids, pad_index)
        labels = pad_for_sequence_parallel(labels, IGNORE_INDEX)
        position_ids = pad_for_sequence_parallel(position_ids, 0)
        if attention_mask is not None:
            attention_mask = pad_for_sequence_parallel(attention_mask, 0)

    if use_varlen_attn:
        max_seqlen = (
            cumulative_len[0][1:] -  # noqa: W504
            cumulative_len[0][:-1]).max().item()
        data_dict = {
            'input_ids': input_ids,
            'cumulative_len': cumulative_len,
            'position_ids': position_ids,
            'labels': labels,
            'max_seqlen': max_seqlen
        }
    else:
        data_dict = {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'position_ids': position_ids,
            'labels': labels
        }

    if has_image:
        if all(x.shape == pixel_values[0].shape for x in pixel_values):
            pixel_values = torch.stack(pixel_values, dim=0)
        data_dict['frames_per_batch'] = frames_per_batch
        data_dict['pixel_values'] = pixel_values
        if has_pe:
            data_dict['image_grid_thw'] = image_grid_thw

    if has_fast_image:
        if all(x.shape == fast_pixel_values[0].shape for x in fast_pixel_values):
            fast_pixel_values = torch.stack(fast_pixel_values, dim=0)
        data_dict['fast_pixel_values'] = fast_pixel_values

    if has_fast_exists:
        data_dict['fast_exists'] = fast_exists

    if has_vp:
        data_dict['vp_overall_mask'] = torch.cat(vp_overall_mask, dim=0)

    if has_prompt_mask:
        data_dict['prompt_masks'] = prompt_masks

    if has_grounding_image:
        data_dict['g_pixel_values'] = grounding_pixel_values

    if has_mask:
        data_dict['masks'] = object_masks

    if has_bboxes:
        data_dict['bboxes'] = object_bboxes

    if has_points:
        data_dict['points'] = prompt_points

    if return_hf_format:
        return data_dict
    else:
        return {'data': data_dict, 'data_samples': None}
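The core of `video_lisa_collate_fn` is right-padding variable-length `input_ids` and rebuilding the attention mask from the original lengths rather than from the pad token id. A minimal sketch of just that step; the token ids and pad index are arbitrary:

import torch
from torch.nn.utils.rnn import pad_sequence

IGNORE_INDEX = -100
pad_index = 0  # illustrative pad token id

input_ids = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])]
labels = [torch.LongTensor([5, 6, 7]), torch.LongTensor([8, 9])]
ori_length = [len(ids) for ids in input_ids]

input_ids = pad_sequence(input_ids, batch_first=True, padding_value=pad_index)
labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)

# the mask is derived from the true lengths, so it stays correct even if pad == eos
attention_mask = torch.zeros_like(input_ids).bool()
for i, length in enumerate(ori_length):
    attention_mask[i, :length] = True

print(attention_mask.tolist())  # [[True, True, True], [True, True, False]]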
projects/llava_sam2/datasets/encode_fn.py
ADDED
@@ -0,0 +1,144 @@
import copy
from xtuner.dataset.utils import get_bos_eos_token_ids
from xtuner.utils import DEFAULT_IMAGE_TOKEN, IGNORE_INDEX, IMAGE_TOKEN_INDEX

def video_lisa_encode_fn(
        example,
        tokenizer,
        max_length,
        input_ids_with_output=True,
        **kwargs
):
    """We only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
            {
                'input': '',
                'output': '### Human: Can you write xxx'
            }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            },
            {
                'input': 'Please expand on the second point.',
                'output': 'Here is an expanded explanation of the xxx'
            }
        ]
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    is_multi_turn_conversation = len(example['conversation']) > 1
    if is_multi_turn_conversation:
        assert input_ids_with_output

    input_ids, labels = [], []
    next_needs_bos_token = True
    for single_turn_conversation in example['conversation']:
        input = single_turn_conversation['input']
        input_encode = tokenizer.encode(input, add_special_tokens=False)
        if next_needs_bos_token:
            input_ids += bos_token_id
            labels += [IGNORE_INDEX] * len(bos_token_id)
        input_ids += input_encode
        labels += [IGNORE_INDEX] * len(input_encode)
        if input_ids_with_output:
            # Add output
            output_with_loss = single_turn_conversation.get(
                'output_with_loss', True)
            output = single_turn_conversation['output']
            output_encode = tokenizer.encode(output, add_special_tokens=False)
            input_ids += output_encode
            if output_with_loss:
                labels += copy.deepcopy(output_encode)
            else:
                labels += [IGNORE_INDEX] * len(output_encode)
            # Add EOS_TOKEN (with loss)
            if single_turn_conversation.get('need_eos_token', True):
                next_needs_bos_token = True
                input_ids += eos_token_id
                if output_with_loss:
                    labels += copy.deepcopy(eos_token_id)
                else:
                    labels += [IGNORE_INDEX] * len(eos_token_id)
            else:
                next_needs_bos_token = False
            # Add SEP (without loss)
            sep = single_turn_conversation.get('sep', '')
            if sep != '':
                sep_encode = tokenizer.encode(sep, add_special_tokens=False)
                input_ids += sep_encode
                labels += [IGNORE_INDEX] * len(sep_encode)

    if len(input_ids) > max_length:
        input_ids = input_ids[:max_length]
        labels = labels[:max_length]
    return {'input_ids': input_ids, 'labels': labels}


def video_lisa_encode_multi_conv_fn(
        example,
        tokenizer,
        max_length,
        input_ids_with_output=True
):
    """We only support the following three scenarios:

    1. Incremental pretraining dataset.
        example['conversation'] = [
            {
                'input': '',
                'output': '### Human: Can you write xxx'
            }
        ]

    2. Single-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            }
        ]

    3. Multi-turn conversation dataset.
        example['conversation'] = [
            {
                'input': 'Give three tips for staying healthy.',
                'output': '1.Eat a balanced diet xxx'
            },
            {
                'input': 'Please expand on the second point.',
                'output': 'Here is an expanded explanation of the xxx'
            }
        ]
    """
    bos_token_id, eos_token_id = get_bos_eos_token_ids(tokenizer)
    assert not input_ids_with_output
    input_id_list = []
    for conv in example['conversation']:
        input_ids = []
        next_needs_bos_token = True
        for single_turn_conversation in conv:
            input = single_turn_conversation['input']
            input_encode = tokenizer.encode(input, add_special_tokens=False)
            if next_needs_bos_token:
                input_ids += bos_token_id
            input_ids += input_encode

        if len(input_ids) > max_length:
            input_ids = input_ids[:max_length]

        input_id_list.append(input_ids)
    return {'input_ids': input_id_list}
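The effect of `video_lisa_encode_fn` is that only assistant tokens (and the EOS) carry loss, while prompt tokens are masked with `IGNORE_INDEX`. A tiny worked example of that masking; the `DummyTokenizer` below is purely hypothetical, used only so the snippet runs without loading a real model:

IGNORE_INDEX = -100

class DummyTokenizer:
    """Hypothetical whitespace 'tokenizer': one fake id per word."""
    def encode(self, text, add_special_tokens=False):
        return [hash(w) % 1000 for w in text.split()]

tok = DummyTokenizer()
bos_token_id, eos_token_id = [1], [2]

turn = {'input': 'USER: segment the dog ASSISTANT:', 'output': 'Sure, [SEG] .'}
input_encode = tok.encode(turn['input'])
output_encode = tok.encode(turn['output'])

input_ids = bos_token_id + input_encode + output_encode + eos_token_id
labels = ([IGNORE_INDEX] * (1 + len(input_encode))   # BOS + prompt tokens: no loss
          + output_encode + eos_token_id)            # answer tokens + EOS: loss

assert len(input_ids) == len(labels)
print(sum(l != IGNORE_INDEX for l in labels))  # 4 supervised tokens: 3 answer words + EOS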
projects/llava_sam2/datasets/gcg_process.py
ADDED
@@ -0,0 +1,297 @@
import numpy as np
import random
from xtuner.utils import DEFAULT_IMAGE_TOKEN

GCG_QUESTIONS = [
    DEFAULT_IMAGE_TOKEN + 'Could you please give me a brief description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    DEFAULT_IMAGE_TOKEN + 'Can you provide a brief description of this image? Please output with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Please briefly describe the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
    DEFAULT_IMAGE_TOKEN + 'Could you give a brief explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Could you give me a brief explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
    DEFAULT_IMAGE_TOKEN + 'Could you provide me with a brief analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
]

def refcocog_parse_annotations(example):
    # example {'id': str, 'refs': [{"sentence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str}
    annotations = {'labels': [], 'caption': [], 'masks': [], 'tokens_positive': [],
                   'file_name': example['img_file_name'], 'image': example['img_file_name']}

    orig_caption = example['caption'].strip('"').strip()
    annotations['caption'] = orig_caption.lower()

    for detail in example['refs']:
        phrase = detail['sentence']
        if phrase.lower() in annotations['caption']:
            annotations['labels'].append(phrase)
            index = annotations['caption'].find(phrase)
            end_index = index + len(phrase) if index != -1 else -1
            annotations['tokens_positive'].append([index, end_index])
            # still polygon or rle
            annotations['masks'].append(detail["segmentation"])

    # Sort tokens_positive and corresponding lists
    tokens_positive = annotations['tokens_positive']
    sorted_indices = sorted(range(len(tokens_positive)), key=lambda i: tokens_positive[i][0])
    annotations['tokens_positive'] = [tokens_positive[i] for i in sorted_indices]
    annotations['masks'] = [annotations['masks'][i] for i in sorted_indices]
    annotations['labels'] = [annotations['labels'][i] for i in sorted_indices]

    # Trimming overlapping intervals
    for i in range(len(tokens_positive)):
        for j in range(i + 1, len(tokens_positive)):
            # If there is overlap
            if tokens_positive[i][1] >= tokens_positive[j][0]:
                # Modify the end index of phrase i to be one less than the start index of phrase j
                tokens_positive[i][1] = tokens_positive[j][0] - 1
                # Modify the phrases to reflect the change in indices
                annotations['labels'][i] = orig_caption[tokens_positive[i][0]:tokens_positive[i][1] + 1]
                break  # Exit inner loop since i was modified

    return annotations

def refcocog_conversation(caption, tokens_positive):
    # insert <p> </p> and [SEG] into the caption and select a question
    question = random.choice(GCG_QUESTIONS).strip()

    # Prepare caption with tags
    def tag_caption(caption, tokens):
        for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
            caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
        return caption

    detailed_answer = tag_caption(caption, tokens_positive)

    conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
    return conversations

def refcocog_preprocess(example):
    data_labels = example['labels']
    masks = example['masks']
    caption = example['caption']
    tokens_positive = example['tokens_positive']

    # Function to sort elements based on the start index of each phrase
    def sort_by_start_index(items, order):
        return [items[i] for i in order]

    # Sort phrases based on their appearance in the sentence
    phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
    masks = sort_by_start_index(masks, phrase_order)
    data_labels = sort_by_start_index(data_labels, phrase_order)
    tokens_positive = sort_by_start_index(tokens_positive, phrase_order)

    conversations = refcocog_conversation(caption, tokens_positive)
    example['conversations'] = conversations
    example['labels'] = data_labels
    example['masks'] = masks
    example['tokens_positive'] = tokens_positive

    return example

def glamm_refcocog_map_fn(example):
    # example {'id': str, 'refs': [{"sentence", 'bbox', 'segmentation'},], 'img_file_name': str, 'caption': str}

    example = refcocog_parse_annotations(example)
    # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file

    example = refcocog_preprocess(example)

    # do llava preprocess
    messages = example['conversations']
    input = ''
    conversation = []
    while messages and messages[0]['from'] == 'gpt':
        # Skip the first one if it is from gpt
        messages = messages[1:]
    for msg in messages:
        if msg['from'] == 'human':
            if DEFAULT_IMAGE_TOKEN in msg['value']:
                msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
                msg['value'] = msg['value'].strip()
            input += msg['value']

        elif msg['from'] == 'gpt':
            conversation.append({'input': input, 'output': msg['value']})
            input = ''
        else:
            raise NotImplementedError
    example.update({'conversation': conversation})
    return example

def grandf_parse_annotations(example):
    image_path = example['file_name']
    annotations = {
        'labels': [], 'caption': [], 'masks': [],
        'tokens_positive': [], 'file_name': image_path,
        'image': image_path}
    annotations['caption'] = example['caption'].strip('"').strip()

    for word, grounding in example["groundings"].items():
        if grounding is None:
            continue
        annotations['labels'].append(word)
        annotations['tokens_positive'].append(grounding["token_positives"])
        annotations['masks'].append(grounding["rle_masks"])

    return annotations

def grandf_conversation(caption, tokens_positive):
    question = random.choice(GCG_QUESTIONS).strip()

    # Prepare caption with tags
    def tag_caption(caption, tokens):
        for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
            caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
        return caption

    detailed_answer = tag_caption(caption, tokens_positive)

    conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
    return conversations

def grandf_preprocess(example):
    data_labels = example['labels']
    masks = example['masks']
    caption = example['caption']
    tokens_positive = example['tokens_positive']

    # Function to sort elements based on the start index of each phrase
    def sort_by_start_index(items, order):
        return [items[i] for i in order]

    # Sort phrases based on their appearance in the sentence
    phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
    masks = sort_by_start_index(masks, phrase_order)
    data_labels = sort_by_start_index(data_labels, phrase_order)
    tokens_positive = sort_by_start_index(tokens_positive, phrase_order)

    conversations = grandf_conversation(caption, tokens_positive)
    example['conversations'] = conversations
    example['labels'] = data_labels
    example['masks'] = masks
    example['tokens_positive'] = tokens_positive
    return example

def glamm_granf_map_fn(example):
    # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
    #          "groundings": {ground_words: {'token_positives', 'rle_masks', }}}
    example = grandf_parse_annotations(example)
    # example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file

    example = grandf_preprocess(example)

    # do llava preprocess
    messages = example['conversations']
    input = ''
    conversation = []
    while messages and messages[0]['from'] == 'gpt':
        # Skip the first one if it is from gpt
        messages = messages[1:]
    for msg in messages:
        if msg['from'] == 'human':
            if DEFAULT_IMAGE_TOKEN in msg['value']:
                msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
                msg['value'] = msg['value'].strip()
            input += msg['value']

        elif msg['from'] == 'gpt':
            conversation.append({'input': input, 'output': msg['value']})
            input = ''
        else:
            raise NotImplementedError
    example.update({'conversation': conversation})
    return example

glamm_openpsg_map_fn = glamm_granf_map_fn

def flickr_parse_annotations(example):
    annotations = {'bboxes': [], 'labels': [], 'bboxes_ignore': [], 'caption': example['caption'], 'masks': [],
                   'tokens_positive': [], 'image': example['file_name']}
    ann_info = example["ann_info"]
    for ann in ann_info:
        if ann.get('ignore', False):
            continue
        x1, y1, w, h = ann['bbox']
        inter_w = max(0, min(x1 + w, example['width']) - max(x1, 0))
        inter_h = max(0, min(y1 + h, example['height']) - max(y1, 0))
        if inter_w * inter_h == 0 or ann['area'] <= 0 or w < 1 or h < 1:
            continue
        bbox = [x1, y1, x1 + w, y1 + h]
        annotations['bboxes'].append(bbox)
        tokens_positive = ann['tokens_positive']
        gt_label = [example['caption'][span[0]:span[1]] for span in tokens_positive]
        annotations['labels'].append(gt_label[0])
        annotations['tokens_positive'].append(tokens_positive[0])

        rle = ann['sam_mask']
        annotations['masks'].append(rle)

    # Convert bounding boxes to numpy arrays
    annotations['bboxes'] = np.array(annotations['bboxes'], dtype=np.float32) if annotations[
        'bboxes'] else np.zeros((0, 4), dtype=np.float32)
    annotations['bboxes_ignore'] = np.array(annotations['bboxes_ignore'], dtype=np.float32) if annotations[
        'bboxes_ignore'] else np.zeros((0, 4), dtype=np.float32)
    return annotations

def flickr_preprocess(example):
    data_labels = example['labels']
    masks = example['masks']
    caption = example['caption']
    tokens_positive = example['tokens_positive']

    # Function to sort elements based on the start index of each phrase
    def sort_by_start_index(items, order):
        return [items[i] for i in order]

    # Sort phrases based on their appearance in the sentence
    phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
    masks = sort_by_start_index(masks, phrase_order)
    data_labels = sort_by_start_index(data_labels, phrase_order)
    tokens_positive = sort_by_start_index(tokens_positive, phrase_order)

    conversations = grandf_conversation(caption, tokens_positive)
    example['conversations'] = conversations
    example['labels'] = data_labels
    example['masks'] = masks
    example['tokens_positive'] = tokens_positive
    return example

def glamm_flickr_map_fn(example):
    # example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
    #          "ann_info": [{'bbox', 'area', 'tokens_positive', 'sam_mask', ...}]}

    example = flickr_parse_annotations(example)

    example = flickr_preprocess(example)

    # do llava preprocess
    messages = example['conversations']
    input = ''
    conversation = []
    while messages and messages[0]['from'] == 'gpt':
        # Skip the first one if it is from gpt
        messages = messages[1:]
    for msg in messages:
        if msg['from'] == 'human':
            if DEFAULT_IMAGE_TOKEN in msg['value']:
                msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN, '').strip()
                msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
                msg['value'] = msg['value'].strip()
            input += msg['value']

        elif msg['from'] == 'gpt':
            conversation.append({'input': input, 'output': msg['value']})
            input = ''
        else:
            raise NotImplementedError
    example.update({'conversation': conversation})
    return example
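The `tag_caption` helpers above wrap each grounded phrase in `<p> ... </p>` followed by a `[SEG]` token, walking the spans right-to-left so that earlier character offsets stay valid. A worked example with an invented caption and spans:

def tag_caption(caption, tokens):
    # process spans from right to left so earlier character offsets are not shifted
    for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
        caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
    return caption

caption = "a dog chases a ball"
tokens_positive = [[2, 5], [15, 19]]  # character spans for "dog" and "ball" (illustrative)
print(tag_caption(caption, tokens_positive))
# a <p> dog </p> [SEG] chases a <p> ball </p> [SEG]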
projects/llava_sam2/datasets/grand_process.py
ADDED
@@ -0,0 +1,110 @@
1 |
+
import numpy as np
|
2 |
+
import random
|
3 |
+
from xtuner.utils import DEFAULT_IMAGE_TOKEN
|
4 |
+
|
5 |
+
GCG_QUESTIONS = [
|
6 |
+
DEFAULT_IMAGE_TOKEN + 'Could you please give me a brief description of the image? Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
|
7 |
+
DEFAULT_IMAGE_TOKEN + 'Can you provide a brief description of the this image? Please output with interleaved segmentation masks for the corresponding phrases.',
|
8 |
+
DEFAULT_IMAGE_TOKEN + 'Please briefly describe the contents of the image. Please respond with interleaved segmentation masks for the corresponding parts of the answer.',
|
9 |
+
DEFAULT_IMAGE_TOKEN + 'Could you give a brief explanation of what can be found within this picture? Please output with interleaved segmentation masks for the corresponding phrases.',
|
10 |
+
DEFAULT_IMAGE_TOKEN + 'Could you give me an brief explanation of this picture? Please respond with interleaved segmentation masks for the corresponding phrases.',
|
11 |
+
DEFAULT_IMAGE_TOKEN + 'Could you provide me with a briefly analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer.',
|
12 |
+
]
|
13 |
+
|
14 |
+
def grand_parse_annotations(example):
|
15 |
+
annotations = {
|
16 |
+
'caption': [], 'masks': [],
|
17 |
+
'tokens_positive': [], 'labels': []}
|
18 |
+
annotations['caption'] = example['dense_caption']['caption'].strip('"').strip()
|
19 |
+
object_infos = example['dense_caption']['details']
|
20 |
+
|
21 |
+
all_seg_objects_dict = {}
|
22 |
+
for seg_object_dict in example["objects"]:
|
23 |
+
all_seg_objects_dict[seg_object_dict['id']] = seg_object_dict
|
24 |
+
for seg_object_dict in example["floating_objects"]:
|
25 |
+
all_seg_objects_dict[seg_object_dict['id']] = seg_object_dict
|
26 |
+
|
27 |
+
for object_info in object_infos:
|
28 |
+
ids = object_info["ids"]
|
29 |
+
if object_info["tokens_positive"] is None:
|
30 |
+
continue
|
31 |
+
annotations['labels'].append(object_info["phrase"])
|
32 |
+
annotations['tokens_positive'].append(object_info["tokens_positive"])
|
33 |
+
_masks = []
|
34 |
+
for _id in ids:
|
35 |
+
_masks.append(all_seg_objects_dict[_id]['segmentation'])
|
36 |
+
annotations['masks'].append(_masks)
|
37 |
+
return annotations
|
38 |
+
|
39 |
+
def grand_conversation(caption, tokens_positive):
|
40 |
+
question = random.choice(GCG_QUESTIONS).strip()
|
41 |
+
|
42 |
+
# Prepare caption with tags
|
43 |
+
def tag_caption(caption, tokens):
|
44 |
+
for start, end in sorted(tokens, key=lambda x: x[0], reverse=True):
|
45 |
+
caption = f"{caption[:start]}<p> {caption[start:end]} </p> [SEG]{caption[end:]}"
|
46 |
+
return caption
|
47 |
+
|
48 |
+
detailed_answer = tag_caption(caption, tokens_positive)
|
49 |
+
|
50 |
+
conversations = [{'from': 'human', 'value': question}, {'from': 'gpt', 'value': detailed_answer}]
|
51 |
+
return conversations
|
52 |
+
|
53 |
+
def grand_preprocess(example):
|
54 |
+
data_labels = example['labels']
|
55 |
+
masks = example['masks']
|
56 |
+
caption = example['caption']
|
57 |
+
tokens_positive = example['tokens_positive']
|
58 |
+
|
59 |
+
# Function to sort elements based on the start index of each phrase
|
60 |
+
def sort_by_start_index(items, order):
|
61 |
+
return [items[i] for i in order]
|
62 |
+
|
63 |
+
# Sort phrases based on their appearance in the sentence
|
64 |
+
phrase_order = sorted(range(len(tokens_positive)), key=lambda x: tokens_positive[x][0])
|
65 |
+
masks = sort_by_start_index(masks, phrase_order)
|
66 |
+
data_labels = sort_by_start_index(data_labels, phrase_order)
|
67 |
+
tokens_positive = sort_by_start_index(tokens_positive, phrase_order)
|
68 |
+
|
69 |
+
conversations = grand_conversation(caption, tokens_positive)
|
70 |
+
example['conversations'] = conversations
|
71 |
+
example['labels'] = data_labels
|
72 |
+
example['masks'] = masks
|
73 |
+
example['tokens_positive'] = tokens_positive
|
74 |
+
return example
|
75 |
+
|
76 |
+
def glamm_grand_map_fn(example):
|
77 |
+
# example {'file_name': str, "height": int, "width": int, "image_id": str, caption: "str",
|
78 |
+
# "groundings": {ground_words: {'token_positives', 'rle_masks', }}}
|
79 |
+
example = grand_parse_annotations(example)
|
80 |
+
# example 'labels': [], 'caption': str, 'masks': [], 'tokens_positive': [], 'file_name': image_file
|
81 |
+
|
82 |
+
example = grand_preprocess(example)
|
83 |
+
|
84 |
+
# do llava preprocess
|
85 |
+
messages = example['conversations']
|
86 |
+
input = ''
|
87 |
+
conversation = []
|
88 |
+
while messages and messages[0]['from'] == 'gpt':
|
89 |
+
# Skip the first one if it is from gpt
|
90 |
+
messages = messages[1:]
|
91 |
+
for msg in messages:
|
92 |
+
if msg['from'] == 'human':
|
93 |
+
if DEFAULT_IMAGE_TOKEN in msg['value']:
|
94 |
+
msg['value'] = msg['value'].replace(DEFAULT_IMAGE_TOKEN,
|
95 |
+
'').strip()
|
96 |
+
msg['value'] = DEFAULT_IMAGE_TOKEN + '\n' + msg['value']
|
97 |
+
msg['value'] = msg['value'].strip()
|
98 |
+
input += msg['value']
|
99 |
+
|
100 |
+
elif msg['from'] == 'gpt':
|
101 |
+
conversation.append({'input': input, 'output': msg['value']})
|
102 |
+
input = ''
|
103 |
+
else:
|
104 |
+
raise NotImplementedError
|
105 |
+
example.update({'conversation': conversation})
|
106 |
+
return example
|
107 |
+
|
108 |
+
|
109 |
+
|
110 |
+
|
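For orientation, a minimal sketch of how `glamm_grand_map_fn` above can be exercised on a single GranD-style record. The record below is entirely made up (caption, ids, and the RLE payload are placeholders), and running it assumes xtuner is installed so that `DEFAULT_IMAGE_TOKEN` resolves:

from projects.llava_sam2.datasets.grand_process import glamm_grand_map_fn

# Hypothetical GranD-style record with one grounded phrase ("A dog" covers caption chars 0..5).
toy_example = {
    'dense_caption': {
        'caption': 'A dog on the grass.',
        'details': [{'ids': [0], 'phrase': 'A dog', 'tokens_positive': [0, 5]}],
    },
    'objects': [{'id': 0, 'segmentation': {'size': [10, 10], 'counts': 'placeholder RLE'}}],
    'floating_objects': [],
}

out = glamm_grand_map_fn(toy_example)
# out['conversation'] is a list of {'input': ..., 'output': ...} turns; the answer wraps
# the grounded phrase as "<p> A dog </p> [SEG]" and out['masks'] keeps the matching masks.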
projects/llava_sam2/datasets/utils.py
ADDED
@@ -0,0 +1,58 @@
+
+def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height,
+                              image_size):
+    best_ratio_diff = float('inf')
+    best_ratio = (1, 1)
+    area = width * height
+    for ratio in target_ratios:
+        target_aspect_ratio = ratio[0] / ratio[1]
+        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
+        if ratio_diff < best_ratio_diff:
+            best_ratio_diff = ratio_diff
+            best_ratio = ratio
+        elif ratio_diff == best_ratio_diff:
+            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
+                best_ratio = ratio
+    return best_ratio
+
+def dynamic_preprocess(image,
+                       min_num=1,
+                       max_num=6,
+                       image_size=448,
+                       use_thumbnail=False):
+    orig_width, orig_height = image.size
+    aspect_ratio = orig_width / orig_height
+
+    # calculate the existing image aspect ratio
+    target_ratios = {(i, j)
+                     for n in range(min_num, max_num + 1)
+                     for i in range(1, n + 1) for j in range(1, n + 1)
+                     if i * j <= max_num and i * j >= min_num}
+    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+    # find the closest aspect ratio to the target
+    target_aspect_ratio = find_closest_aspect_ratio(aspect_ratio,
+                                                    target_ratios, orig_width,
+                                                    orig_height, image_size)
+
+    # calculate the target width and height
+    target_width = image_size * target_aspect_ratio[0]
+    target_height = image_size * target_aspect_ratio[1]
+    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
+
+    # resize the image
+    resized_img = image.resize((target_width, target_height))
+    processed_images = []
+    for i in range(blocks):
+        box = ((i % (target_width // image_size)) * image_size,
+               (i // (target_width // image_size)) * image_size,
+               ((i % (target_width // image_size)) + 1) * image_size,
+               ((i // (target_width // image_size)) + 1) * image_size)
+        # split the image
+        split_img = resized_img.crop(box)
+        processed_images.append(split_img)
+    assert len(processed_images) == blocks
+    if use_thumbnail and len(processed_images) != 1:
+        thumbnail_img = image.resize((image_size, image_size))
+        processed_images.append(thumbnail_img)
+    return processed_images
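A rough usage sketch for `dynamic_preprocess` above (the image path is a placeholder): it turns one PIL image into a list of image_size x image_size tiles matched to the closest supported aspect ratio, plus an optional whole-image thumbnail.

from PIL import Image
from projects.llava_sam2.datasets.utils import dynamic_preprocess

img = Image.open('example.jpg').convert('RGB')  # placeholder path; any RGB image works
tiles = dynamic_preprocess(img, min_num=1, max_num=6, image_size=448, use_thumbnail=True)
# Each element is a 448x448 PIL crop; when more than one tile is produced,
# the last element is the 448x448 thumbnail of the full image.
print(len(tiles), tiles[0].size)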
projects/llava_sam2/datasets/vqa_dataset.py
ADDED
@@ -0,0 +1,509 @@
+import copy
+import random
+import glob
+import json
+import logging
+import os
+from typing import Literal
+
+import torch
+
+from mmengine import print_log
+from mmengine.config import Config, ConfigDict
+from PIL import Image
+from torch.utils.data import Dataset
+import numpy as np
+import torch.nn.functional as F
+import torchvision.transforms as T
+from torchvision.transforms.functional import InterpolationMode
+from pycocotools.coco import COCO
+from pycocotools import mask as mask_utils
+
+from xtuner.registry import BUILDER
+from xtuner.utils import IGNORE_INDEX
+from xtuner.dataset.utils import encode_fn
+from xtuner.dataset.map_fns import llava_map_fn
+
+from projects.glamm.datasets.utils.utils import expand2square
+
+from projects.glamm.datasets.utils.utils import SEG_QUESTIONS, ANSWER_LIST
+from projects.glamm.utils import DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+
+from .utils import dynamic_preprocess
+
+
+class InfinityMMDataset(Dataset):
+    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+    IMG_START_TOKEN = '<img>'
+    IMG_END_TOKEN = '</img>'
+
+    IMAGENET_MEAN = (0.485, 0.456, 0.406)
+    IMAGENET_STD = (0.229, 0.224, 0.225)
+
+    def __init__(self,
+                 tokenizer,
+                 data_path,
+                 prompt_template,
+                 special_tokens=None,
+                 max_length=8192,
+                 offline_save_path='./work_dirs/infinityMM.json',
+                 ):
+        self.offline_save_path = offline_save_path
+        self.tokenizer = BUILDER.build(tokenizer)
+        if special_tokens is not None:
+            self.tokenizer.add_tokens(special_tokens, special_tokens=True)
+        self._system = ''
+
+        self.template = prompt_template
+        self.max_length = max_length
+
+        self.min_dynamic_patch = 1
+        self.max_dynamic_patch = 12
+        self.downsample_ratio = 0.5
+        self.image_size = 448
+        self.use_thumbnail = True
+        patch_size = 14
+        self.patch_token = int(
+            (self.image_size // patch_size) ** 2 * (self.downsample_ratio ** 2))
+
+        self.transformer = T.Compose([
+            T.Lambda(lambda img: img.convert('RGB')
+                     if img.mode != 'RGB' else img),
+            T.Resize((self.image_size, self.image_size),
+                     interpolation=InterpolationMode.BICUBIC),
+            T.ToTensor(),
+            T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
+        ])
+
+        self.data = self._load_annotations(data_path)
+        self._max_refetch = 1000
+
+    def _load_annotations(self, data_path):
+        if os.path.exists(self.offline_save_path):
+            with open(self.offline_save_path, 'r') as f:
+                ret = json.load(f)
+            print(f"Load InfinityMM file list from {self.offline_save_path}, {len(ret)} items !!!")
+            return ret
+        sub_folders = []
+        for sub_folder in os.listdir(data_path):
+            if '.' not in sub_folder:
+                # a folder
+                if "LVIS_111k" in sub_folder:
+                    # special case, have subsub folder
+                    subsub_folders = os.listdir(os.path.join(data_path, sub_folder))
+                    for subsub_folder in subsub_folders:
+                        sub_folders.append(os.path.join(data_path, sub_folder, subsub_folder))
+                else:
+                    sub_folders.append(os.path.join(data_path, sub_folder))
+
+        all_jsons = []
+        for sub_folder in sub_folders:
+            print(f"Processing {sub_folder} !!!")
+            _files = os.listdir(sub_folder)
+            _num = 0
+            for _file in _files:
+                if '.json' in _file:
+                    _json_path = os.path.join(sub_folder, _file)
+                    _num += 1
+                    all_jsons.append(os.path.join(sub_folder, _file))
+            print(f"Finished {sub_folder} has {_num} items.")
+
+        with open(self.offline_save_path, 'w') as f:
+            json.dump(all_jsons, f)
+
+        return all_jsons
+
+    def __getitem__(self, index):
+        for _ in range(self._max_refetch + 1):
+            data = self.prepare_data(index)
+            # Broken images may cause the returned data to be None
+            if data is None:
+                index = self._rand_another()
+                continue
+            return data
+
+    def __len__(self):
+        return len(self.data)
+
+    @property
+    def modality_length(self):
+        self.group_length = []
+        for data_dict in self.data:
+            self.group_length.append(100)
+        return self.group_length
+
+    @property
+    def length(self):
+        group_length = np.array(self.group_length)
+        group_length = np.abs(group_length).tolist()
+        return group_length
+
+    def prepare_data(self, index):
+        data_path = self.data[index]
+
+        with open(data_path, 'r') as f:
+            data_dict = json.load(f)
+        if 'image' in data_dict.keys():
+            data_dict['image'] = data_path.replace('.json', '.jpg')
+
+        if data_dict is None:
+            return None
+
+        out_data_dict = {}
+
+        if data_dict.get('image', None) is not None:
+            image_file = data_dict['image']
+            try:
+                image = Image.open(image_file).convert('RGB')
+            except Exception as e:
+                print(f'Error: {e}', flush=True)
+                print_log(f'Error: {e}', logger='current')
+                return None
+
+            images = dynamic_preprocess(image, self.min_dynamic_patch,
+                                        self.max_dynamic_patch,
+                                        self.image_size, self.use_thumbnail)
+            pixel_values = [self.transformer(image) for image in images]
+            pixel_values = torch.stack(pixel_values)
+            out_data_dict['pixel_values'] = pixel_values
+
+            num_image_tokens = pixel_values.shape[0] * self.patch_token
+            image_token_str = f'{self.IMG_START_TOKEN}' \
+                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
+                              f'{self.IMG_END_TOKEN}'
+            token_dict = self.get_inputid_labels(
+                data_dict['conversations'], image_token_str)
+            out_data_dict.update(token_dict)
+        else:
+            token_dict = self.get_inputid_labels(
+                data_dict['conversations'], None)
+            out_data_dict.update(token_dict)
+            out_data_dict['pixel_values'] = torch.zeros(
+                1, 3, self.image_size, self.image_size)
+        return out_data_dict
+
+    def _rand_another(self) -> int:
+        return np.random.randint(0, len(self.data))
+
+    def get_inputid_labels(self, conversations, image_token_str) -> dict:
+        input = ''
+        out_conversation = []
+        while conversations and conversations[0]['from'] == 'gpt':
+            # Skip the first one if it is from gpt
+            conversations = conversations[1:]
+        for i, msg in enumerate(conversations):
+            if msg['from'] == 'human':
+
+                # change to 1 image
+                if '<image>' in msg['value']:
+                    msg['value'] = msg['value'].replace('<image>\n', '').replace('<image>', '')
+                    if i == 0:
+                        msg['value'] = "<image>\n" + msg['value']
+
+                if image_token_str is None and '<image>' in msg['value']:
+                    msg['value'] = msg['value'].replace('<image>', '')
+                if '<image>' in msg['value']:
+                    msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
+                input += msg['value'].strip()
+            elif msg['from'] == 'gpt':
+                out_conversation.append({
+                    'input': input,
+                    'output': msg['value'].strip()
+                })
+                input = ''
+            else:
+                raise NotImplementedError
+
+        input_ids, labels = [], []
+        for i, single_turn_conversation in enumerate(out_conversation):
+            input = single_turn_conversation.get('input', '')
+            if input is None:
+                input = ''
+            input_text = self.template.INSTRUCTION.format(
+                input=input, round=i + 1)
+
+            if i == 0:
+                if self._system != '' and self._system is not None:
+                    system = self.template.SYSTEM.format(system=self._system)
+                    input_text = system + input_text
+                input_encode = self.tokenizer.encode(
+                    input_text, add_special_tokens=True)
+            else:
+                input_encode = self.tokenizer.encode(
+                    input_text, add_special_tokens=False)
+            input_ids += input_encode
+            labels += [IGNORE_INDEX] * len(input_encode)
+
+            output_text = single_turn_conversation.get('output', '')
+            if self.template.get('SUFFIX', None):
+                output_text += self.template.SUFFIX
+            output_encode = self.tokenizer.encode(
+                output_text, add_special_tokens=False)
+            input_ids += output_encode
+            labels += copy.deepcopy(output_encode)
+
+        if len(input_ids) > self.max_length:
+            input_ids = input_ids[:self.max_length]
+            labels = labels[:self.max_length]
+            print_log(
+                f'Warning: input_ids length({len(input_ids)}) '
+                f'is longer than max_length, cut to {self.max_length}',
+                logger='current')
+        return {'input_ids': input_ids, 'labels': labels}
+
+
+class LLaVADataset(Dataset):
+    os.environ['TOKENIZERS_PARALLELISM'] = 'true'
+    IMG_CONTEXT_TOKEN = '<IMG_CONTEXT>'
+    IMG_START_TOKEN = '<img>'
+    IMG_END_TOKEN = '</img>'
+
+    IMAGENET_MEAN = (0.485, 0.456, 0.406)
+    IMAGENET_STD = (0.229, 0.224, 0.225)
+
+    def __init__(self,
+                 tokenizer,
+                 data_path,
+                 prompt_template,
+                 special_tokens=None,
+                 image_folder=None,
+                 max_length=8192,
+                 arch_type: Literal['intern_vl', 'qwen'] = 'intern_vl',
+                 preprocessor=None,
+                 skip_pure_text=False,
+                 ):
+
+        self.tokenizer = BUILDER.build(tokenizer)
+        if special_tokens is not None:
+            self.tokenizer.add_tokens(special_tokens, special_tokens=True)
+
+        self.image_folder = image_folder
+        self.template = prompt_template
+        self.max_length = max_length
+
+        self._system = ''
+
+        self.arch_type = arch_type
+        self.min_dynamic_patch = 1
+        self.max_dynamic_patch = 12
+        self.downsample_ratio = 0.5
+        if self.arch_type == 'llava':
+            self.downsample_ratio = 1
+        self.image_size = 448
+        if self.arch_type == 'llava':
+            self.image_size = 336
+        self.use_thumbnail = True
+        patch_size = 14
+        self.patch_token = int(
+            (self.image_size // patch_size)**2 * (self.downsample_ratio**2))
+
+        if self.arch_type == 'qwen':
+            self.IMG_CONTEXT_TOKEN = '<|image_pad|>'
+            self.IMG_START_TOKEN = '<|vision_start|>'
+            self.IMG_END_TOKEN = '<|vision_end|>'
+        elif self.arch_type == 'llava':
+            self.IMG_CONTEXT_TOKEN = '<image>'
+            self.IMG_START_TOKEN = ''
+            self.IMG_END_TOKEN = ''
+
+        if preprocessor is None:
+            self.transformer = T.Compose([
+                T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
+                T.Resize((self.image_size, self.image_size), interpolation=InterpolationMode.BICUBIC),
+                T.ToTensor(),
+                T.Normalize(mean=self.IMAGENET_MEAN, std=self.IMAGENET_STD)
+            ])
+            self.preprocessor = None
+        else:
+            self.transformer = None
+            self.preprocessor = BUILDER.build(preprocessor)
+
+        self.data = self._load_annotations(data_path, image_folder)
+        self._max_refetch = 1000
+
+        self.skip_pure_text = skip_pure_text
+
+    def _load_annotations(self, data_path, image_folder=None):
+        data = json.load(open(data_path))
+        return data
+
+    def __getitem__(self, index):
+        for _ in range(self._max_refetch + 1):
+            data = self.prepare_data(index)
+            # Broken images may cause the returned data to be None
+            if data is None:
+                index = self._rand_another()
+                continue
+            return data
+
+    def __len__(self):
+        return len(self.data)
+
+    @property
+    def modality_length(self):
+        self.group_length = []
+        for data_dict in self.data:
+            self.group_length.append(100)
+        return self.group_length
+
+    @property
+    def length(self):
+        group_length = np.array(self.group_length)
+        group_length = np.abs(group_length).tolist()
+        return group_length
+
+    def prepare_data(self, index):
+        data_dict: dict = self.data[index]
+
+        if data_dict is None:
+            return None
+
+        out_data_dict = {}
+
+        if self.skip_pure_text and data_dict.get('image', None) is None:
+            return None
+
+        if data_dict.get('image', None) is not None:
+            image_file = os.path.join(self.image_folder, data_dict['image'])
+            try:
+                image = Image.open(image_file).convert('RGB')
+            except Exception as e:
+                print(f'Error: {e}', flush=True)
+                print_log(f'Error: {e}', logger='current')
+                return None
+            if self.preprocessor is not None:
+                # images = dynamic_preprocess(image, self.min_dynamic_patch,
+                #                             self.max_dynamic_patch,
+                #                             self.image_size, self.use_thumbnail)
+                images = [image]
+                if self.arch_type == 'qwen':
+                    _data_dict = self.preprocessor(images, do_resize=True)
+                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
+                    _data_dict['image_grid_thw'] = torch.tensor(_data_dict['image_grid_thw'], dtype=torch.int)
+                    num_image_tokens = int(_data_dict['image_grid_thw'][0].prod() * (self.downsample_ratio ** 2))
+                elif self.arch_type == 'llava':
+                    _data_dict = self.preprocessor(images, do_resize=True, size=(self.image_size, self.image_size))
+                    _data_dict['pixel_values'] = np.stack(_data_dict['pixel_values'], axis=0)
+                    _data_dict['pixel_values'] = torch.tensor(_data_dict['pixel_values'], dtype=torch.float)
+                    num_image_tokens = _data_dict['pixel_values'].shape[0] * self.patch_token
+                else:
+                    raise NotImplementedError
+                out_data_dict.update(_data_dict)
+            else:
+                images = dynamic_preprocess(image, self.min_dynamic_patch,
+                                            self.max_dynamic_patch,
+                                            self.image_size, self.use_thumbnail)
+                pixel_values = [self.transformer(image) for image in images]
+                pixel_values = torch.stack(pixel_values)
+                out_data_dict['pixel_values'] = pixel_values
+
+                num_image_tokens = pixel_values.shape[0] * self.patch_token
+            image_token_str = f'{self.IMG_START_TOKEN}' \
+                              f'{self.IMG_CONTEXT_TOKEN * num_image_tokens}' \
+                              f'{self.IMG_END_TOKEN}'
+            token_dict = self.get_inputid_labels(
+                data_dict['conversations'], image_token_str)
+            out_data_dict.update(token_dict)
+        else:
+            token_dict = self.get_inputid_labels(
+                data_dict['conversations'], None)
+            out_data_dict.update(token_dict)
+            out_data_dict['pixel_values'] = torch.zeros(
+                1, 3, self.image_size, self.image_size)
+        return out_data_dict
+
+    def _rand_another(self) -> int:
+        return np.random.randint(0, len(self.data))
+
+    def get_inputid_labels(self, conversations, image_token_str) -> dict:
+        input = ''
+        out_conversation = []
+        while conversations and conversations[0]['from'] == 'gpt':
+            # Skip the first one if it is from gpt
+            conversations = conversations[1:]
+        for msg in conversations:
+            if msg['from'] == 'human':
+                if image_token_str is None and '<image>' in msg['value']:
+                    msg['value'] = msg['value'].replace('<image>', '')
+                if '<image>' in msg['value']:
+                    msg['value'] = msg['value'].replace('<image>', image_token_str).strip()
+                input += msg['value'].strip()
+            elif msg['from'] == 'gpt':
+                out_conversation.append({
+                    'input': input,
+                    'output': msg['value'].strip()
+                })
+                input = ''
+            else:
+                raise NotImplementedError
+
+        input_ids, labels = [], []
+        for i, single_turn_conversation in enumerate(out_conversation):
+            input = single_turn_conversation.get('input', '')
+            if input is None:
+                input = ''
+            input_text = self.template.INSTRUCTION.format(
+                input=input, round=i + 1)
+
+            if i == 0:
+                if self._system != '' and self._system is not None:
+                    system = self.template.SYSTEM.format(system=self._system)
+                    input_text = system + input_text
+                input_encode = self.tokenizer.encode(
+                    input_text, add_special_tokens=True)
+            else:
+                input_encode = self.tokenizer.encode(
+                    input_text, add_special_tokens=False)
+            input_ids += input_encode
+            labels += [IGNORE_INDEX] * len(input_encode)
+
+            output_text = single_turn_conversation.get('output', '')
+            if self.template.get('SUFFIX', None):
+                output_text += self.template.SUFFIX
+            output_encode = self.tokenizer.encode(
+                output_text, add_special_tokens=False)
+            input_ids += output_encode
+            labels += copy.deepcopy(output_encode)
+
+        if len(input_ids) > self.max_length:
+            input_ids = input_ids[:self.max_length]
+            labels = labels[:self.max_length]
+            print_log(
+                f'Warning: input_ids length({len(input_ids)}) '
+                f'is longer than max_length, cut to {self.max_length}',
+                logger='current')
+        return {'input_ids': input_ids, 'labels': labels}
+
+
+if __name__ == '__main__':
+    from transformers import CLIPImageProcessor, AutoTokenizer
+    from third_parts.segment_anything.utils.transforms import ResizeLongestSide
+    pretrained_model = 'MBZUAI/GLaMM-GranD-Pretrained'
+    llm_name_or_path = 'lmsys/vicuna-7b-v1.5'
+
+    tokenizer = dict(
+        type=AutoTokenizer.from_pretrained,
+        pretrained_model_name_or_path=llm_name_or_path)
+    image_processor = dict(
+        type=CLIPImageProcessor.from_pretrained,
+        pretrained_model_name_or_path='openai/clip-vit-large-patch14-336')
+    extra_image_processor = dict(
+        type=ResizeLongestSide,
+        target_length=1024,
+    )
+    from xtuner.utils.templates import PROMPT_TEMPLATE
+    prompt_template = PROMPT_TEMPLATE.vicuna
+    from xtuner.dataset.map_fns import llava_map_fn, template_map_fn_factory, template_map_fn
+    from projects.glamm.datasets.collate_fns.glamm_collate_fn import glamm_collate_fn
+
+    dataset = LLaVADataset(
+        tokenizer=tokenizer,
+        data_path='data/llava_data/LLaVA-Instruct-150K/llava_instruct_150k.json',
+        prompt_template=prompt_template,
+        special_tokens=['[SEG]'],
+        image_folder='data/coco/train2017/',
+    )
+    for i in range(1000):
+        dataset[i]
projects/llava_sam2/deepspeed_zero2_sam2.json
ADDED
@@ -0,0 +1,24 @@
+{
+  "gradient_accumulation_steps": "auto",
+  "train_micro_batch_size_per_gpu": "auto",
+  "gradient_clipping": "auto",
+  "zero_allow_untested_optimizer": true,
+  "zero_force_ds_cpu_optimizer": false,
+  "zero_optimization": {
+    "stage": 2,
+    "overlap_comm": true,
+    "allgather_bucket_size": 5368709120,
+    "reduce_bucket_size": 5368709120,
+    "reduce_scatter": true,
+    "sub_group_size": 1e9,
+    "contiguous_gradients": true,
+    "allgather_partitions": true
+  },
+  "fp16": {
+    "enabled": false,
+    "initial_scale_power": 16
+  },
+  "bf16": {
+    "enabled": true
+  }
+}
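A small sanity-check sketch for this config (path relative to the repo root): it loads the JSON and confirms the ZeRO stage-2 and bf16 settings before a training launch.

import json

with open('projects/llava_sam2/deepspeed_zero2_sam2.json') as f:
    ds_cfg = json.load(f)

# ZeRO stage 2 with bf16 enabled and fp16 disabled, as declared above.
assert ds_cfg['zero_optimization']['stage'] == 2
assert ds_cfg['bf16']['enabled'] and not ds_cfg['fp16']['enabled']
print('ZeRO stage:', ds_cfg['zero_optimization']['stage'])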
projects/llava_sam2/gradio/app.py
ADDED
@@ -0,0 +1,151 @@
+import gradio as gr
+import sys
+from projects.llava_sam2.gradio.app_utils import\
+    process_markdown, show_mask_pred, description, preprocess_video,\
+    show_mask_pred_video, image2video_and_save
+
+import torch
+from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
+                          BitsAndBytesConfig, CLIPImageProcessor,
+                          CLIPVisionModel, GenerationConfig)
+import argparse
+import os
+
+TORCH_DTYPE_MAP = dict(
+    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')
+
+def parse_args(args):
+    parser = argparse.ArgumentParser(description="Sa2VA Demo")
+    parser.add_argument('hf_path', help='Sa2VA hf path.')
+    return parser.parse_args(args)
+
+def inference(image, video, follow_up, input_str):
+    input_image = image
+    if image is not None and (video is not None and os.path.exists(video)):
+        return image, video, "Error: Please only input a image or a video !!!"
+    if image is None and (video is None or not os.path.exists(video)) and not follow_up:
+        return image, video, "Error: Please input a image or a video !!!"
+
+    if not follow_up:
+        # reset
+        print('Log: History responses have been removed!')
+        global_infos.n_turn = 0
+        global_infos.inputs = ''
+        text = input_str
+
+        image = input_image
+        global_infos.image_for_show = image
+        global_infos.image = image
+        video = video
+        global_infos.video = video
+
+        if image is not None:
+            global_infos.input_type = "image"
+        else:
+            global_infos.input_type = "video"
+
+    else:
+        text = input_str
+        image = global_infos.image
+        video = global_infos.video
+
+    input_type = global_infos.input_type
+    if input_type == "video":
+        video = preprocess_video(video, global_infos.inputs+input_str)
+
+    past_text = global_infos.inputs
+
+    if past_text == "" and "<image>" not in text:
+        text = "<image>" + text
+    if input_type == "image":
+        input_dict = {
+            'image': image,
+            'text': text,
+            'past_text': past_text,
+            'mask_prompts': None,
+            'tokenizer': tokenizer,
+        }
+    else:
+        input_dict = {
+            'video': video,
+            'text': text,
+            'past_text': past_text,
+            'mask_prompts': None,
+            'tokenizer': tokenizer,
+        }
+
+    return_dict = sa2va_model.predict_forward(**input_dict)
+    global_infos.inputs = return_dict["past_text"]
+    print(return_dict['past_text'])
+    if 'prediction_masks' in return_dict.keys() and return_dict['prediction_masks'] and len(
+            return_dict['prediction_masks']) != 0:
+        if input_type == "image":
+            image_mask_show, selected_colors = show_mask_pred(global_infos.image_for_show, return_dict['prediction_masks'],)
+            video_mask_show = global_infos.video
+        else:
+            image_mask_show = None
+            video_mask_show, selected_colors = show_mask_pred_video(video, return_dict['prediction_masks'],)
+            video_mask_show = image2video_and_save(video_mask_show, save_path="./ret_video.mp4")
+    else:
+        image_mask_show = global_infos.image_for_show
+        video_mask_show = global_infos.video
+        selected_colors = []
+
+    predict = return_dict['prediction'].strip()
+    global_infos.n_turn += 1
+
+    predict = process_markdown(predict, selected_colors)
+    return image_mask_show, video_mask_show, predict
+
+def init_models(args):
+    model_path = args.hf_path
+    model = AutoModel.from_pretrained(
+        model_path,
+        torch_dtype=torch.bfloat16,
+        low_cpu_mem_usage=True,
+        use_flash_attn=True,
+        trust_remote_code=True,
+    ).eval().cuda()
+
+    tokenizer = AutoTokenizer.from_pretrained(
+        model_path,
+        trust_remote_code=True,
+    )
+    return model, tokenizer
+
+class global_infos:
+    inputs = ''
+    n_turn = 0
+    image_width = 0
+    image_height = 0
+
+    image_for_show = None
+    image = None
+    video = None
+
+    input_type = "image"  # "image" or "video"
+
+if __name__ == "__main__":
+    # get parse args and set models
+    args = parse_args(sys.argv[1:])
+
+    sa2va_model, tokenizer = \
+        init_models(args)
+
+    demo = gr.Interface(
+        inference,
+        inputs=[
+            gr.Image(type="pil", label="Upload Image", height=360),
+            gr.Video(sources=["upload", "webcam"], label="Upload mp4 video", height=360),
+            gr.Checkbox(label="Follow up Question"),
+            gr.Textbox(lines=1, placeholder=None, label="Text Instruction"),],
+        outputs=[
+            gr.Image(type="pil", label="Output Image"),
+            gr.Video(label="Output Video", show_download_button=True, format='mp4'),
+            gr.Markdown()],
+        theme=gr.themes.Soft(), allow_flagging="auto", description=description,
+        title='Sa2VA'
+    )
+
+    demo.queue()
+    demo.launch(share=True)
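Outside Gradio, the same model calls used by `inference` above can be exercised directly. A minimal sketch, assuming the checkpoint id and image path below (both placeholders) and that the checkpoint exposes `predict_forward` through its remote code, as the app relies on:

import torch
from PIL import Image
from transformers import AutoModel, AutoTokenizer

hf_path = 'ByteDance/Sa2VA-4B'  # placeholder checkpoint id; pass your own Sa2VA HF path
model = AutoModel.from_pretrained(hf_path, torch_dtype=torch.bfloat16,
                                  low_cpu_mem_usage=True, use_flash_attn=True,
                                  trust_remote_code=True).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(hf_path, trust_remote_code=True)

image = Image.open('example.jpg').convert('RGB')  # placeholder image
result = model.predict_forward(image=image,
                               text='<image>Please segment the person.',
                               past_text='', mask_prompts=None, tokenizer=tokenizer)
print(result['prediction'])  # segmentation masks, if any, are in result['prediction_masks']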
projects/llava_sam2/gradio/app_utils.py
ADDED
@@ -0,0 +1,293 @@
+import numpy as np
+from PIL import Image
+import cv2
+
+markdown_default = """
+<link href="https://fonts.googleapis.com/css2?family=Montserrat:wght@400;700&display=swap" rel="stylesheet">
+<style>
+.highlighted-text {
+    font-family: 'Montserrat', sans-serif;
+    font-weight: 600;
+    font-size: 14px;
+    color: rgb(255, 255, 239);
+    background-color: rgb(225, 231, 254);
+    border-radius: 7px;
+    padding: 5px 7px;
+    display: inline-block;
+}
+.regular-text {
+    font-family: 'Montserrat', sans-serif;
+    font-weight: 400;
+    font-size: 14px;
+}
+.highlighted-response {
+    font-family: 'Montserrat', sans-serif;
+    font-weight: 600;
+    font-size: 14px;
+    border-radius: 6px;
+    padding: 3px 4px;
+    display: inline-block;
+}
+</style>
+<span class="highlighted-text" style='color:rgb(107, 100, 239)'>Sa2VA</span>
+"""
+
+description = """
+**Usage** : <br>
+ (1) For **Grounded Caption Generation** Interleaved Segmentation, input prompt like: *"Could you provide me with a detailed analysis of this photo? Please output with interleaved segmentation masks for the corresponding parts of the answer."* <br>
+ (2) For **Segmentation Output**, input prompt like: *"Can you please segment xxx in the given image"* <br>
+ (3) For **Image Captioning** VQA, input prompt like: *"Could you please give me a detailed description of the image?"* <br>
+ (4) For **Image Conversation**, input arbitrary text instruction. <br>
+"""
+
+ONE_THIRD = 1.0/3.0
+ONE_SIXTH = 1.0/6.0
+TWO_THIRD = 2.0/3.0
+
+def desaturate(rgb, factor=0.65):
+    """
+    Desaturate an RGB color by a given factor.
+
+    :param rgb: A tuple of (r, g, b) where each value is in [0, 255].
+    :param factor: The factor by which to reduce the saturation.
+                   0 means completely desaturated, 1 means original color.
+    :return: A tuple of desaturated (r, g, b) values in [0, 255].
+    """
+    r, g, b = [x / 255.0 for x in rgb]
+    h, l, s = rgb_to_hls(r, g, b)
+    l = factor
+    new_r, new_g, new_b = hls_to_rgb(h, l, s)
+    return (int(new_r * 255), int(new_g * 255), int(new_b * 255))
+
+def rgb_to_hls(r, g, b):
+    maxc = max(r, g, b)
+    minc = min(r, g, b)
+    sumc = (maxc+minc)
+    rangec = (maxc-minc)
+    l = sumc/2.0
+    if minc == maxc:
+        return 0.0, l, 0.0
+    if l <= 0.5:
+        s = rangec / sumc
+    else:
+        s = rangec / (2.0-sumc)
+    rc = (maxc-r) / rangec
+    gc = (maxc-g) / rangec
+    bc = (maxc-b) / rangec
+    if r == maxc:
+        h = bc-gc
+    elif g == maxc:
+        h = 2.0+rc-bc
+    else:
+        h = 4.0+gc-rc
+    h = (h/6.0) % 1.0
+    return h, l, s
+
+def hls_to_rgb(h, l, s):
+    if s == 0.0:
+        return l, l, l
+    if l <= 0.5:
+        m2 = l * (1.0+s)
+    else:
+        m2 = l+s-(l*s)
+    m1 = 2.0*l - m2
+    return (_v(m1, m2, h+ONE_THIRD), _v(m1, m2, h), _v(m1, m2, h-ONE_THIRD))
+
+def _v(m1, m2, hue):
+    hue = hue % 1.0
+    if hue < ONE_SIXTH:
+        return m1 + (m2-m1)*hue*6.0
+    if hue < 0.5:
+        return m2
+    if hue < TWO_THIRD:
+        return m1 + (m2-m1)*(TWO_THIRD-hue)*6.0
+    return m1
+
+def process_markdown(output_str, colors):
+    output_str = output_str.replace("\n", "").replace("  ", " ").replace("<s>", "")\
+        .replace("<|im_end|>", '').replace("<|end|>", "")
+    output_str = output_str.split("ASSISTANT: ")[-1]
+
+    # markdown_out = output_str.replace('[SEG]', '')
+    markdown_out = output_str
+    markdown_out = markdown_out.replace(
+        "<p>", "<span class='highlighted-response' style='background-color:rgb[COLOR]'>"
+    )
+    markdown_out = markdown_out.replace("</p>", "</span>")
+
+    for color in colors:
+        markdown_out = markdown_out.replace("[COLOR]", str(desaturate(tuple(color))), 1)
+
+    markdown_out = f"""
+    {markdown_out}
+    """
+    markdown_out = markdown_default + "<p><span class='regular-text'>" + markdown_out
+    return markdown_out
+
+def show_mask_pred(image, masks):
+    masks = [mask[:1] for mask in masks]
+    masks = np.concatenate(masks, axis=0)  # (n, h, w)
+
+    selected_colors = []
+
+    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
+              (255, 255, 0), (255, 0, 255), (0, 255, 255),
+              (128, 128, 255), [255, 192, 203],  # Pink
+              [165, 42, 42],  # Brown
+              [255, 165, 0],  # Orange
+              [128, 0, 128],  # Purple
+              [0, 0, 128],  # Navy
+              [128, 0, 0],  # Maroon
+              [128, 128, 0],  # Olive
+              [70, 130, 180],  # Steel Blue
+              [173, 216, 230],  # Light Blue
+              [255, 192, 0],  # Gold
+              [255, 165, 165],  # Light Salmon
+              [255, 20, 147],  # Deep Pink
+              ]
+
+    _mask_image = np.zeros((masks.shape[1], masks.shape[2], 3), dtype=np.uint8)
+
+    for i, mask in enumerate(masks):
+        color = colors[i % len(colors)]
+        selected_colors.append(color)
+        _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0]
+        _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1]
+        _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2]
+
+
+    image = np.array(image)
+    image = image * 0.5 + _mask_image * 0.5
+    image = image.astype(np.uint8)
+    return image, selected_colors
+
+def show_mask_pred_video(video, masks):
+    ret_video = []
+    selected_colors = []
+    colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255),
+              (255, 255, 0), (255, 0, 255), (0, 255, 255),
+              (128, 128, 255), [255, 192, 203],  # Pink
+              [165, 42, 42],  # Brown
+              [255, 165, 0],  # Orange
+              [128, 0, 128],  # Purple
+              [0, 0, 128],  # Navy
+              [128, 0, 0],  # Maroon
+              [128, 128, 0],  # Olive
+              [70, 130, 180],  # Steel Blue
+              [173, 216, 230],  # Light Blue
+              [255, 192, 0],  # Gold
+              [255, 165, 165],  # Light Salmon
+              [255, 20, 147],  # Deep Pink
+              ]
+    for i_frame in range(len(video)):
+        frame_masks = [mask[i_frame:i_frame+1] for mask in masks]
+        frame_masks = np.concatenate(frame_masks, axis=0)
+        _mask_image = np.zeros((frame_masks.shape[1], frame_masks.shape[2], 3), dtype=np.uint8)
+
+        for i, mask in enumerate(frame_masks):
+            if i_frame == 0:
+                color = colors[i % len(colors)]
+                selected_colors.append(color)
+            else:
+                color = selected_colors[i]
+            _mask_image[:, :, 0] = _mask_image[:, :, 0] + mask.astype(np.uint8) * color[0]
+            _mask_image[:, :, 1] = _mask_image[:, :, 1] + mask.astype(np.uint8) * color[1]
+            _mask_image[:, :, 2] = _mask_image[:, :, 2] + mask.astype(np.uint8) * color[2]
+
+        image = np.array(video[i_frame])
+        image = image * 0.5 + _mask_image * 0.5
+        image = image.astype(np.uint8)
+        ret_video.append(image)
+    return ret_video, selected_colors
+
+def parse_visual_prompts(points):
+    ret = {'points': [], 'boxes': []}
+    for item in points:
+        if item[2] == 1.0:
+            ret['points'].append([item[0], item[1]])
+        elif item[2] == 2.0 or item[2] == 3.0:
+            ret['boxes'].append([item[0], item[1], item[3], item[4]])
+        else:
+            raise NotImplementedError
+    return ret
+
+def get_video_frames(video_path):
+    cap = cv2.VideoCapture(video_path)
+
+    if not cap.isOpened():
+        print("Error: Cannot open video file.")
+        return
+
+    frames = []
+
+    frame_id = 0
+    while True:
+        ret, frame = cap.read()
+
+        if not ret:
+            break
+
+        frames.append(frame)
+
+        frame_id += 1
+
+    cap.release()
+    return frames
+
+def get_frames_from_video(video_path, n_frames=5, sample_type="uniform"):
+    frames = get_video_frames(video_path)
+    if sample_type == "uniform":
+        stride = len(frames) / (n_frames + 1e-4)
+        ret = []
+        for i in range(n_frames):
+            idx = int(i * stride)
+            frame = frames[idx]
+            frame = frame[:, :, ::-1]
+            frame_image = Image.fromarray(frame).convert('RGB')
+            ret.append(frame_image)
+    else:
+        ret = []
+        for frame in frames[:500]:
+            frame = frame[:, :, ::-1]
+            frame_image = Image.fromarray(frame).convert('RGB')
+            ret.append(frame_image)
+    return ret
+
+def preprocess_video(video_path, text):
+    if "Segment" in text or "segment" in text:
+        sample_type = 'begin'
+    else:
+        sample_type = 'uniform'
+    return get_frames_from_video(video_path, sample_type=sample_type)
+
+def image2video_and_save(frames, save_path):
+    success = frames_to_video(frames, save_path)
+    return save_path
+
+
+def frames_to_video(
+    frames,
+    output_path: str,
+    fps: int = 24,
+) -> bool:
+    try:
+        frames = [frame[:, :, ::-1] for frame in frames]
+        # Use provided frame size or get from first frame
+        height, width = frames[0].shape[:2]
+
+        # Initialize video writer
+        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+        out = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+        # Process each frame
+        for frame in frames:
+            out.write(frame)
+
+        # Release video writer
+        out.release()
+        print(f"Video saved successfully to {output_path}")
+        return True
+
+    except Exception as e:
+        print(f"Error converting frames to video: {str(e)}")
+        return False
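A small sketch of the mask-overlay helper above on synthetic data (all inputs are made up): `show_mask_pred` expects a list of per-object masks shaped (n_frames_or_1, H, W) and returns the blended overlay plus the colors it assigned.

import numpy as np
from PIL import Image
from projects.llava_sam2.gradio.app_utils import show_mask_pred

# One fake 64x64 boolean mask with a filled square, shaped (1, H, W) as expected.
mask = np.zeros((1, 64, 64), dtype=bool)
mask[0, 16:48, 16:48] = True
canvas = Image.new('RGB', (64, 64), (255, 255, 255))

overlay, colors = show_mask_pred(canvas, [mask])
print(overlay.shape, colors[0])  # (64, 64, 3) and the first assigned color (255, 0, 0)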
projects/llava_sam2/models/__init__.py
ADDED
@@ -0,0 +1,3 @@
+from .llava_sam2 import VideoLLaVASAMModel, VideoLLaVASAMModel_zero3
+from .sam2 import SAM2
+from .sam2_train import SAM2TrainRunner
projects/llava_sam2/models/extension/__init__.py
ADDED
@@ -0,0 +1 @@
+from .sam2_base import SAM2Base
projects/llava_sam2/models/extension/sam2_base.py
ADDED
@@ -0,0 +1,281 @@
+import torch
+import torch.nn.functional as F
+
+from third_parts.sam2.modeling.sam2_base import SAM2Base as _SAM2Base
+from third_parts.sam2.modeling.sam2_base import NO_OBJ_SCORE
+
+
+class SAM2Base(_SAM2Base):
+
+    def track_step(
+        self,
+        frame_idx,
+        is_init_cond_frame,
+        current_vision_feats,
+        current_vision_pos_embeds,
+        feat_sizes,
+        point_inputs,
+        mask_inputs,
+        output_dict,
+        num_frames,
+        track_in_reverse=False,  # tracking in reverse time order (for demo usage)
+        # Whether to run the memory encoder on the predicted masks. Sometimes we might want
+        # to skip the memory encoder with `run_mem_encoder=False`. For example,
+        # in demo we might call `track_step` multiple times for each user click,
+        # and only encode the memory when the user finalizes their clicks. And in ablation
+        # settings like SAM training on static images, we don't need the memory encoder.
+        run_mem_encoder=True,
+        # The previously predicted SAM mask logits (which can be fed together with new clicks in demo).
+        prev_sam_mask_logits=None,
+        ## Extension: LLM prompt
+        language_embd=None,
+    ):
+        current_out = {"point_inputs": point_inputs, "mask_inputs": mask_inputs}
+        # High-resolution feature maps for the SAM head, reshape (HW)BC => BCHW
+        if len(current_vision_feats) > 1:
+            high_res_features = [
+                x.permute(1, 2, 0).view(x.size(1), x.size(2), *s)
+                for x, s in zip(current_vision_feats[:-1], feat_sizes[:-1])
+            ]
+        else:
+            high_res_features = None
+        if mask_inputs is not None and self.use_mask_input_as_output_without_sam:
+            # When use_mask_input_as_output_without_sam=True, we directly output the mask input
+            # (see it as a GT mask) without using a SAM prompt encoder + mask decoder.
+            pix_feat = current_vision_feats[-1].permute(1, 2, 0)
+            pix_feat = pix_feat.view(-1, self.hidden_dim, *feat_sizes[-1])
+            sam_outputs = self._use_mask_as_output(
+                pix_feat, high_res_features, mask_inputs
+            )
+        else:
+            # fused the visual feature with previous memory features in the memory bank
+            pix_feat_with_mem = self._prepare_memory_conditioned_features(
+                frame_idx=frame_idx,
+                is_init_cond_frame=is_init_cond_frame,
+                current_vision_feats=current_vision_feats[-1:],
+                current_vision_pos_embeds=current_vision_pos_embeds[-1:],
+                feat_sizes=feat_sizes[-1:],
+                output_dict=output_dict,
+                num_frames=num_frames,
+                track_in_reverse=track_in_reverse,
+            )
+            # apply SAM-style segmentation head
+            # here we might feed previously predicted low-res SAM mask logits into the SAM mask decoder,
+            # e.g. in demo where such logits come from earlier interaction instead of correction sampling
+            # (in this case, any `mask_inputs` shouldn't reach here as they are sent to _use_mask_as_output instead)
+            if prev_sam_mask_logits is not None:
+                assert point_inputs is not None and mask_inputs is None
+                mask_inputs = prev_sam_mask_logits
+            multimask_output = self._use_multimask(is_init_cond_frame, point_inputs)
+            sam_outputs = self._forward_sam_heads(
+                backbone_features=pix_feat_with_mem,
+                point_inputs=point_inputs,
+                mask_inputs=mask_inputs,
+                high_res_features=high_res_features,
+                multimask_output=multimask_output,
+                # Inject language Embed if possible
+                language_embd=language_embd,
+            )
+        (
+            _,
+            _,
+            _,
+            low_res_masks,
+            high_res_masks,
+            obj_ptr,
+            _,
+        ) = sam_outputs
+
+        current_out["pred_masks"] = low_res_masks
+        current_out["pred_masks_high_res"] = high_res_masks
+        current_out["obj_ptr"] = obj_ptr
+
+        # Finally run the memory encoder on the predicted mask to encode
+        # it into a new memory feature (that can be used in future frames)
+        if run_mem_encoder and self.num_maskmem > 0:
+            high_res_masks_for_mem_enc = high_res_masks
+            maskmem_features, maskmem_pos_enc = self._encode_new_memory(
+                current_vision_feats=current_vision_feats,
+                feat_sizes=feat_sizes,
+                pred_masks_high_res=high_res_masks_for_mem_enc,
+                is_mask_from_pts=(point_inputs is not None),
+            )
+            current_out["maskmem_features"] = maskmem_features
+            current_out["maskmem_pos_enc"] = maskmem_pos_enc
+        else:
+            current_out["maskmem_features"] = None
+            current_out["maskmem_pos_enc"] = None
+
+        return current_out
+
+
+    def _forward_sam_heads(
+        self,
+        backbone_features,
+        point_inputs=None,
+        mask_inputs=None,
+        high_res_features=None,
+        multimask_output=False,
+        ## Extension: LLM prompt
+        language_embd=None,
+    ):
+        """
+        Forward SAM prompt encoders and mask heads.
+
+        Inputs:
+        - backbone_features: image features of [B, C, H, W] shape
+        - point_inputs: a dictionary with "point_coords" and "point_labels", where
+          1) "point_coords" has [B, P, 2] shape and float32 dtype and contains the
+             absolute pixel-unit coordinate in (x, y) format of the P input points
+          2) "point_labels" has shape [B, P] and int32 dtype, where 1 means
+             positive clicks, 0 means negative clicks, and -1 means padding
+        - mask_inputs: a mask of [B, 1, H*16, W*16] shape, float or bool, with the
+          same spatial size as the image.
+        - high_res_features: either 1) None or 2) or a list of length 2 containing
+          two feature maps of [B, C, 4*H, 4*W] and [B, C, 2*H, 2*W] shapes respectively,
+          which will be used as high-resolution feature maps for SAM decoder.
+        - multimask_output: if it's True, we output 3 candidate masks and their 3
+          corresponding IoU estimates, and if it's False, we output only 1 mask and
+          its corresponding IoU estimate.
+
+        Outputs:
+        - low_res_multimasks: [B, M, H*4, W*4] shape (where M = 3 if
+          `multimask_output=True` and M = 1 if `multimask_output=False`), the SAM
+          output mask logits (before sigmoid) for the low-resolution masks, with 4x
+          the resolution (1/4 stride) of the input backbone_features.
+        - high_res_multimasks: [B, M, H*16, W*16] shape (where M = 3
+          if `multimask_output=True` and M = 1 if `multimask_output=False`),
+          upsampled from the low-resolution masks, with shape size as the image
+          (stride is 1 pixel).
+        - ious, [B, M] shape, where (where M = 3 if `multimask_output=True` and M = 1
+          if `multimask_output=False`), the estimated IoU of each output mask.
+        - low_res_masks: [B, 1, H*4, W*4] shape, the best mask in `low_res_multimasks`.
+          If `multimask_output=True`, it's the mask with the highest IoU estimate.
+          If `multimask_output=False`, it's the same as `low_res_multimasks`.
+        - high_res_masks: [B, 1, H*16, W*16] shape, the best mask in `high_res_multimasks`.
+          If `multimask_output=True`, it's the mask with the highest IoU estimate.
+          If `multimask_output=False`, it's the same as `high_res_multimasks`.
+        - obj_ptr: [B, C] shape, the object pointer vector for the output mask, extracted
+          based on the output token from the SAM mask decoder.
+        """
+        B = backbone_features.size(0)
+        device = backbone_features.device
+        assert backbone_features.size(1) == self.sam_prompt_embed_dim
+        assert backbone_features.size(2) == self.sam_image_embedding_size
+        assert backbone_features.size(3) == self.sam_image_embedding_size
+
+        # a) Handle point prompts
+        if point_inputs is not None:
+            sam_point_coords = point_inputs["point_coords"]
+            sam_point_labels = point_inputs["point_labels"]
+            assert sam_point_coords.size(0) == B and sam_point_labels.size(0) == B
+        else:
+            # If no points are provide, pad with an empty point (with label -1)
+            sam_point_coords = torch.zeros(B, 1, 2, device=device)
+            sam_point_labels = -torch.ones(B, 1, dtype=torch.int32, device=device)
+
+        # b) Handle mask prompts
+        if mask_inputs is not None:
+            # If mask_inputs is provided, downsize it into low-res mask input if needed
+            # and feed it as a dense mask prompt into the SAM mask encoder
+            assert len(mask_inputs.shape) == 4 and mask_inputs.shape[:2] == (B, 1)
+            if mask_inputs.shape[-2:] != self.sam_prompt_encoder.mask_input_size:
+                sam_mask_prompt = F.interpolate(
+                    mask_inputs.float(),
+                    size=self.sam_prompt_encoder.mask_input_size,
+                    align_corners=False,
+                    mode="bilinear",
+                    antialias=True,  # use antialias for downsampling
+                )
+            else:
+                sam_mask_prompt = mask_inputs
+        else:
+            # Otherwise, simply feed None (and SAM's prompt encoder will add
+            # a learned `no_mask_embed` to indicate no mask input in this case).
+            sam_mask_prompt = None
+
+        sparse_embeddings, dense_embeddings = self.sam_prompt_encoder(
+            points=(sam_point_coords, sam_point_labels),
+            boxes=None,
+            masks=sam_mask_prompt,
+        )
+
+        ## Extension: LLM prompt
+        if language_embd is not None:
+            # B N C
+            assert sparse_embeddings.size(0) == language_embd.size(0)
+            assert sparse_embeddings.size(2) == language_embd.size(2)
+            sparse_embeddings = torch.cat([sparse_embeddings, language_embd], dim=1)
+
+        (
+            low_res_multimasks,
+            ious,
+            sam_output_tokens,
+            object_score_logits,
+        ) = self.sam_mask_decoder(
+            image_embeddings=backbone_features,
+            image_pe=self.sam_prompt_encoder.get_dense_pe(),
+            sparse_prompt_embeddings=sparse_embeddings,
+            dense_prompt_embeddings=dense_embeddings,
+            multimask_output=multimask_output,
+            repeat_image=False,  # the image is already batched
+            high_res_features=high_res_features,
+        )
+        if self.pred_obj_scores:
+            is_obj_appearing = object_score_logits > 0
+
+            # Mask used for spatial memories is always a *hard* choice between obj and no obj,
+            # consistent with the actual mask prediction
+            # print('Do torch.where !!!')
+            # low_res_multimasks = torch.where(
+            #     is_obj_appearing[:, None, None],
+            #     low_res_multimasks,
+            #     NO_OBJ_SCORE,
+            # )
+
+        # convert masks from possibly bfloat16 (or float16) to float32
+        # (older PyTorch versions before 2.1 don't support `interpolate` on bf16)
+        low_res_multimasks = low_res_multimasks.float()
+        high_res_multimasks = F.interpolate(
+            low_res_multimasks,
+            size=(self.image_size, self.image_size),
+            mode="bilinear",
+            align_corners=False,
+        )
+
+        sam_output_token = sam_output_tokens[:, 0]
+        if multimask_output:
+            # take the best mask prediction (with the highest IoU estimation)
+            best_iou_inds = torch.argmax(ious, dim=-1)
+            batch_inds = torch.arange(B, device=device)
+            low_res_masks = low_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            high_res_masks = high_res_multimasks[batch_inds, best_iou_inds].unsqueeze(1)
+            if sam_output_tokens.size(1) > 1:
+                sam_output_token = sam_output_tokens[batch_inds, best_iou_inds]
+        else:
+            low_res_masks, high_res_masks = low_res_multimasks, high_res_multimasks
+
+        # Extract object pointer from the SAM output token (with occlusion handling)
+        obj_ptr = self.obj_ptr_proj(sam_output_token)
+        if self.pred_obj_scores:
+            # Allow *soft* no obj ptr, unlike for masks
+            if self.soft_no_obj_ptr:
+                # Only hard possible with gt
+                assert not self.teacher_force_obj_scores_for_mem
|
265 |
+
lambda_is_obj_appearing = object_score_logits.sigmoid()
|
266 |
+
else:
|
267 |
+
lambda_is_obj_appearing = is_obj_appearing.float()
|
268 |
+
|
269 |
+
if self.fixed_no_obj_ptr:
|
270 |
+
obj_ptr = lambda_is_obj_appearing * obj_ptr
|
271 |
+
obj_ptr = obj_ptr + (1 - lambda_is_obj_appearing) * self.no_obj_ptr
|
272 |
+
|
273 |
+
return (
|
274 |
+
low_res_multimasks,
|
275 |
+
high_res_multimasks,
|
276 |
+
ious,
|
277 |
+
low_res_masks,
|
278 |
+
high_res_masks,
|
279 |
+
obj_ptr,
|
280 |
+
object_score_logits,
|
281 |
+
)
|