Upload 41 files
- .gitattributes +3 -0
- .gitignore +160 -0
- LICENSE +28 -0
- README.md +77 -12
- configs/config.py +68 -0
- dataset/data_helper.py +98 -0
- dataset/data_module.py +73 -0
- evalcap/__init__.py +0 -0
- evalcap/bleu/LICENSE +23 -0
- evalcap/bleu/__init__.py +1 -0
- evalcap/bleu/bleu.py +50 -0
- evalcap/bleu/bleu_scorer.py +264 -0
- evalcap/cider/__init__.py +1 -0
- evalcap/cider/cider.py +57 -0
- evalcap/cider/cider_scorer.py +193 -0
- evalcap/meteor/__init__.py +1 -0
- evalcap/meteor/meteor-1.5.jar +3 -0
- evalcap/meteor/meteor.py +130 -0
- evalcap/meteor/test_meteor.py +10 -0
- evalcap/rouge/__init__.py +1 -0
- evalcap/rouge/rouge.py +105 -0
- evalcap/tokenizer/__init__.py +1 -0
- evalcap/tokenizer/ptbtokenizer.py +68 -0
- evalcap/tokenizer/stanford-corenlp-3.4.1.jar +3 -0
- images/align.png +3 -0
- lightning_tools/callbacks.py +30 -0
- lightning_tools/optim.py +59 -0
- models/R2GenGPT.py +379 -0
- requirements.txt +7 -4
- scripts/1-1.shallow_run_iuxray.sh +30 -0
- scripts/1-2.shallow_test_iuxray.sh +28 -0
- scripts/2-1.delta_run_iuxray.sh +30 -0
- scripts/2-2.delta_test_iuxray.sh +28 -0
- scripts/3-1.deep_run_iuxray.sh +30 -0
- scripts/3-2.deep_test_iuxray.sh +28 -0
- scripts/4-1.shallow_run.sh +39 -0
- scripts/4-2.shallow_test.sh +28 -0
- scripts/5-1.delta_run.sh +39 -0
- scripts/5-2.delta_test.sh +30 -0
- scripts/6-1.deep_run.sh +38 -0
- scripts/6-2.deep_test.sh +28 -0
- train.py +51 -0
.gitattributes
CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+evalcap/meteor/meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text
+evalcap/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
+images/align.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,160 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
LICENSE
ADDED
@@ -0,0 +1,28 @@
BSD 3-Clause License

Copyright (c) 2023, zhanyuwang

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution.

3. Neither the name of the copyright holder nor the names of its
   contributors may be used to endorse or promote products derived from
   this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md
CHANGED
@@ -1,12 +1,77 @@
# R2GenGPT: Radiology Report Generation with Frozen LLMs

## Introduction
![overview](https://github.com/wang-zhanyu/R2GenGPT/blob/main/images/align.png)

## Getting Started
### Installation

**1. Prepare the code and the environment**

Git clone our repository and install the requirements.

```bash
git clone https://github.com/wang-zhanyu/R2GenGPT.git
cd R2GenGPT
pip install -r requirements.txt
```

**2. Prepare the training dataset**

IU-Xray: download the dataset from [here](https://drive.google.com/file/d/1c0BXEuDy8Cmm2jfN0YYGkQxFZd2ZIoLg/view).

MIMIC-CXR: you can download our preprocessed annotation file from [here](https://drive.google.com/file/d/14689ztodTtrQJYs--ihB_hgsPMMNHX-H/view?usp=sharing) and download the images from the [official website](https://physionet.org/content/mimic-cxr-jpg/2.0.0/).

After downloading the data, place it in the ./data folder.

### Training

For shallow alignment:

```bash
bash scripts/4-1.shallow_run.sh
```

For delta alignment:

```bash
bash scripts/5-1.delta_run.sh
```

For deep alignment:

```bash
bash scripts/6-1.deep_run.sh
```

### Testing (for MIMIC-CXR)

You can download our pretrained delta checkpoints from [here](https://drive.google.com/drive/folders/1ywEITWfYIAAYy0VY1IZ24Ec_GoNmkqIY?usp=sharing).

For shallow alignment:

```bash
bash scripts/4-2.shallow_test.sh
```

For delta alignment:

```bash
bash scripts/5-2.delta_test.sh
```

For deep alignment:

```bash
bash scripts/6-2.deep_test.sh
```

## Acknowledgement

+ [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4): some code in this repo is based on MiniGPT-4.
+ [Llama2](https://github.com/facebookresearch/llama): the language ability of Llama-2 with only 7B parameters is impressive.

## License
This repository is released under the [BSD 3-Clause License](LICENSE).
configs/config.py
ADDED
@@ -0,0 +1,68 @@
import argparse

parser = argparse.ArgumentParser(description="hyper-parameters for R2GenGPT")
# ========================= Dataset Configs ==========================
parser.add_argument('--test', action='store_true', help="only run test set")
parser.add_argument('--validate', action='store_true', help="only run validation set")
parser.add_argument('--dataset', type=str, default='mimic_cxr', help="iu-xray or mimic-cxr")
parser.add_argument('--annotation', type=str, default=r'./data/mimic_cxr/annotation.json', help="annotation file of the dataset")
parser.add_argument('--base_dir', type=str, default=r'./data/mimic_cxr/images', help="base dir to help find images")
parser.add_argument('--batch_size', default=6, type=int, help="training batch size per worker")
parser.add_argument('--val_batch_size', default=16, type=int, help="validation batch size per worker")
parser.add_argument('--test_batch_size', default=16, type=int, help="test batch size per worker")
parser.add_argument('--prefetch_factor', default=4, type=int, help="dataloader prefetch factor per worker")
parser.add_argument('--num_workers', default=8, type=int, help="number of CPU workers for the dataloaders")

# ========================= Model Settings ============================
parser.add_argument('--vision_model', default='microsoft/swin-base-patch4-window7-224', type=str, help="vision model to use")
parser.add_argument('--llama_model', default='meta-llama/Llama-2-7b-chat-hf', type=str, help="LLM model to use")
parser.add_argument('--freeze_vm', default=True, type=lambda x: (str(x).lower() == 'true'), help='freeze vision model')
parser.add_argument('--llm_use_lora', default=False, type=lambda x: (str(x).lower() == 'true'), help="whether to use LoRA for the LLM")
parser.add_argument('--llm_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
parser.add_argument('--llm_alpha', default=16, type=int, help='Scaling factor.')
parser.add_argument('--vis_use_lora', default=False, type=lambda x: (str(x).lower() == 'true'), help="whether to use LoRA for the vision model")
parser.add_argument('--vis_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
parser.add_argument('--vis_alpha', default=16, type=int, help='Scaling factor.')
parser.add_argument('--lora_dropout', default=0.1, type=float, help='LoRA dropout')
parser.add_argument('--global_only', default=False, type=lambda x: (str(x).lower() == 'true'), help='use global embedding only')
parser.add_argument('--low_resource', default=False, type=bool)
parser.add_argument('--end_sym', default='</s>', type=str)

# ======================== SavedModel Configs ===========================
parser.add_argument('--savedmodel_path', type=str, default='save/mimic/v1')
parser.add_argument('--ckpt_file', type=str, default=None, help='the checkpoint file to load')
parser.add_argument('--delta_file', type=str, default=None, help='the delta file to load')
parser.add_argument('--weights', type=list, default=[0.5, 0.5])
parser.add_argument('--scorer_types', type=list, default=['Bleu_4', 'CIDEr'])

# ========================= Learning Configs ==========================
parser.add_argument('--learning_rate', default=1e-4, type=float, help='initial learning rate')
parser.add_argument('--gradient_clip_val', default=None, type=int, help='gradient clip value')

# ========================= Decoding Settings ==========================
parser.add_argument('--beam_size', type=int, default=3)
parser.add_argument('--do_sample', type=bool, default=False)
parser.add_argument('--no_repeat_ngram_size', type=int, default=2)
parser.add_argument('--num_beam_groups', type=int, default=1)
parser.add_argument('--min_new_tokens', type=int, default=80)
parser.add_argument('--max_new_tokens', type=int, default=120)
parser.add_argument('--max_length', type=int, default=100)
parser.add_argument('--repetition_penalty', type=float, default=2.0)
parser.add_argument('--length_penalty', type=float, default=2.0)
parser.add_argument('--diversity_penalty', type=float, default=0)
parser.add_argument('--temperature', type=float, default=0)

# ====================== Pytorch Lightning ===========================
parser.add_argument('--devices', type=int, default=2, help='how many gpus to use')
parser.add_argument('--num_nodes', type=int, default=1, help='Number of GPU nodes for distributed training.')
parser.add_argument('--accelerator', type=str, default="gpu", choices=["cpu", "gpu", "tpu", "ipu", "hpu", "mps"], help='accelerator types')
parser.add_argument('--strategy', type=str, default="ddp", help='default ddp for multi-gpus')
parser.add_argument('--precision', type=str, default='bf16-mixed', help="'16', '32', or 'bf16-mixed'; used for PyTorch AMP autocast")
parser.add_argument('--limit_val_batches', type=float, default=1.0, help='How much of validation dataset to check (float = fraction, int = num_batches).')
parser.add_argument('--limit_test_batches', type=float, default=1.0, help='How much of test dataset to check (float = fraction, int = num_batches).')
parser.add_argument('--limit_train_batches', type=float, default=1.0, help='How much of training dataset to check (float = fraction, int = num_batches).')
parser.add_argument('--max_epochs', type=int, default=3, help='Stop training once this number of epochs is reached')
parser.add_argument('--every_n_train_steps', type=int, default=0, help='Save a checkpoint every N training steps')
parser.add_argument('--val_check_interval', type=float, default=1.0, help='How often to check the validation set')
parser.add_argument('--accumulate_grad_batches', type=int, default=1, help='Accumulates gradients over k batches before stepping the optimizer')
parser.add_argument("--num_sanity_val_steps", type=int, default=2, help='Sanity check runs n validation batches before starting the training routine')
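configs/config.py defines one module-level argparse parser that the rest of the code imports. A minimal usage sketch (hedged: the exact call pattern in train.py is not shown in this diff):

```python
# Hedged sketch, not taken verbatim from the repo: consume the shared parser.
from configs.config import parser

args = parser.parse_args()   # e.g. python train.py --dataset iu_xray --batch_size 8
print(args.llama_model, args.beam_size, args.savedmodel_path)
```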
dataset/data_helper.py
ADDED
@@ -0,0 +1,98 @@
import os
import json
import re
import numpy as np
from PIL import Image
import torch.utils.data as data
from transformers import BertTokenizer, AutoImageProcessor


class FieldParser:
    def __init__(
            self,
            args
        ):
        super().__init__()
        self.args = args
        self.dataset = args.dataset
        self.vit_feature_extractor = AutoImageProcessor.from_pretrained(args.vision_model)

    def _parse_image(self, img):
        pixel_values = self.vit_feature_extractor(img, return_tensors="pt").pixel_values
        return pixel_values[0]

    # from https://github.com/cuhksz-nlp/R2Gen/blob/main/modules/tokenizers.py
    def clean_report(self, report):
        # clean IU-Xray reports
        if self.dataset == "iu_xray":
            report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
                .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
                .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
                .strip().lower().split('. ')
            sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
                                            replace('\\', '').replace("'", '').strip().lower())
            tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
            report = ' . '.join(tokens) + ' .'
        # clean MIMIC-CXR reports
        else:
            report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
                .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace('  ', ' ') \
                .replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ').replace('  ', ' ') \
                .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
                .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
                .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
                .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ').replace(':', ' :') \
                .strip().lower().split('. ')
            sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+()\[\]{}]', '', t.replace('"', '').replace('/', '')
                                            .replace('\\', '').replace("'", '').strip().lower())
            tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
            report = ' . '.join(tokens) + ' .'
        # report = ' '.join(report.split()[:self.args.max_txt_len])
        return report

    def parse(self, features):
        to_return = {'id': features['id']}
        report = features.get("report", "")
        report = self.clean_report(report)
        to_return['input_text'] = report
        # chest x-ray images
        images = []
        for image_path in features['image_path']:
            with Image.open(os.path.join(self.args.base_dir, image_path)) as pil:
                array = np.array(pil, dtype=np.uint8)
                if array.shape[-1] != 3 or len(array.shape) != 3:
                    array = np.array(pil.convert("RGB"), dtype=np.uint8)
                image = self._parse_image(array)
                images.append(image)
        to_return["image"] = images
        return to_return

    def transform_with_parse(self, inputs):
        return self.parse(inputs)


class ParseDataset(data.Dataset):
    def __init__(self, args, split='train'):
        self.args = args
        self.meta = json.load(open(args.annotation, 'r'))
        self.meta = self.meta[split]
        self.parser = FieldParser(args)

    def __len__(self):
        return len(self.meta)

    def __getitem__(self, index):
        return self.parser.transform_with_parse(self.meta[index])


def create_datasets(args):
    train_dataset = ParseDataset(args, 'train')
    dev_dataset = ParseDataset(args, 'val')
    test_dataset = ParseDataset(args, 'test')
    return train_dataset, dev_dataset, test_dataset
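A hedged sketch of how these helpers fit together: `create_datasets` builds the three splits, and each item is a dict with `id`, `input_text` (the cleaned report) and `image` (a list of preprocessed pixel tensors, one per view). The paths below are assumptions and must point at a real annotation file and image folder:

```python
# Hedged example: build the splits and inspect one parsed sample.
from configs.config import parser
from dataset.data_helper import create_datasets

args = parser.parse_args([])          # defaults; override --annotation/--base_dir as needed
train_ds, val_ds, test_ds = create_datasets(args)
sample = train_ds[0]
print(sample['id'], sample['input_text'][:80], len(sample['image']))
```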
dataset/data_module.py
ADDED
@@ -0,0 +1,73 @@
from lightning.pytorch import LightningDataModule
from torch.utils.data import DataLoader
from dataset.data_helper import create_datasets


class DataModule(LightningDataModule):

    def __init__(
            self,
            args
        ):
        super().__init__()
        self.args = args

    def prepare_data(self):
        """
        Use this method to do things that might write to disk or that need to be done only from a single process in distributed settings.

        download, tokenize, etc.
        :return:
        """

    def setup(self, stage: str):
        """
        There are also data operations you might want to perform on every GPU. Use setup to do things like:

        count number of classes, build vocabulary, perform train/val/test splits,
        apply transforms (defined explicitly in your datamodule or assigned in init), etc.
        :param stage:
        :return:
        """
        train_dataset, dev_dataset, test_dataset = create_datasets(self.args)
        self.dataset = {
            "train": train_dataset, "validation": dev_dataset, "test": test_dataset
        }

    def train_dataloader(self):
        """
        Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in setup.
        :return:
        """
        loader = DataLoader(self.dataset["train"], batch_size=self.args.batch_size, drop_last=True, pin_memory=True,
                            num_workers=self.args.num_workers, prefetch_factor=self.args.prefetch_factor)
        return loader

    def val_dataloader(self):
        """
        Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in setup.
        :return:
        """
        loader = DataLoader(self.dataset["validation"], batch_size=self.args.val_batch_size, drop_last=False, pin_memory=True,
                            num_workers=self.args.num_workers, prefetch_factor=self.args.prefetch_factor)
        return loader

    def test_dataloader(self):
        loader = DataLoader(self.dataset["test"], batch_size=self.args.test_batch_size, drop_last=False, pin_memory=False,
                            num_workers=self.args.num_workers, prefetch_factor=self.args.prefetch_factor)
        return loader
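For context, a minimal sketch of wiring this DataModule into a Lightning Trainer with the flags from configs/config.py. The model construction is omitted because models/R2GenGPT.py is not reproduced in this section; treat this as an assumption about intended usage, not the repo's actual train.py:

```python
# Hedged sketch (assumes the lightning >= 2.x API imported above).
from lightning.pytorch import Trainer
from configs.config import parser
from dataset.data_module import DataModule

args = parser.parse_args([])
dm = DataModule(args)
trainer = Trainer(devices=args.devices, accelerator=args.accelerator,
                  strategy=args.strategy, precision=args.precision,
                  max_epochs=args.max_epochs)
# trainer.fit(model, datamodule=dm)   # model construction omitted; see models/R2GenGPT.py
```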
evalcap/__init__.py
ADDED
File without changes
evalcap/bleu/LICENSE
ADDED
@@ -0,0 +1,23 @@
Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.

Python 2 to Python 3 porting note:
Python 2 dict: iteritems()
Python 3 dict: items()
evalcap/bleu/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'tylin'
evalcap/bleu/bleu.py
ADDED
@@ -0,0 +1,50 @@
#!/usr/bin/env python
#
# File Name : bleu.py
#
# Description : Wrapper for BLEU scorer.
#
# Creation Date : 06-01-2015
# Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
# Authors : Hao Fang <[email protected]> and Tsung-Yi Lin <[email protected]>
import os
dir_path = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(dir_path)
from bleu_scorer import BleuScorer


class Bleu:
    def __init__(self, n=4):
        # by default, compute BLEU scores up to 4-grams
        self._n = n
        self._hypo_for_image = {}
        self.ref_for_image = {}

    def compute_score(self, gts, res, verbose=0):

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        bleu_scorer = BleuScorer(n=self._n)
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) >= 1)

            bleu_scorer += (hypo[0], ref)

        # score, scores = bleu_scorer.compute_score(option='shortest')
        score, scores = bleu_scorer.compute_score(option='closest', verbose=verbose)
        # score, scores = bleu_scorer.compute_score(option='average', verbose=1)

        # return (bleu, bleu_info)
        return score, scores

    def method(self):
        return "Bleu"
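The wrapper expects the pycocoevalcap-style input format: two dicts with identical keys, where each reference value is a list of strings and each hypothesis value is a single-element list. A small self-contained example (the image id and sentences are made up, and the import path assumes the repo root is on sys.path):

```python
# Illustrative input format for Bleu.compute_score.
from evalcap.bleu.bleu import Bleu

gts = {'img1': ['the lungs are clear .', 'no acute disease .']}   # reference reports
res = {'img1': ['the lungs are clear .']}                         # one hypothesis per id
score, scores = Bleu(n=4).compute_score(gts, res)
print(score)   # [BLEU-1, BLEU-2, BLEU-3, BLEU-4]
```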
evalcap/bleu/bleu_scorer.py
ADDED
@@ -0,0 +1,264 @@
#!/usr/bin/env python

# bleu_scorer.py
# David Chiang <[email protected]>

# Copyright (c) 2004-2006 University of Maryland. All rights
# reserved. Do not redistribute without permission from the
# author. Not for commercial use.

# Modified by:
# Hao Fang <[email protected]>
# Tsung-Yi Lin <[email protected]>

'''Provides:
cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
'''

import copy
import sys, math, re
from collections import defaultdict

def precook(s, n=4, out=False):
    """Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well."""
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return (len(words), counts)

def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.'''

    reflen = []
    maxcounts = {}
    for ref in refs:
        rl, counts = precook(ref, n)
        reflen.append(rl)
        for (ngram,count) in counts.items():
            maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    # Calculate effective reference sentence length.
    if eff == "shortest":
        reflen = min(reflen)
    elif eff == "average":
        reflen = float(sum(reflen))/len(reflen)

    ## lhuang: N.B.: leave reflen computation to the very end!!

    ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)

    return (reflen, maxcounts)

def cook_test(test, crefs, eff=None, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.'''
    reflen, refmaxcounts = crefs[0], crefs[1]

    testlen, counts = precook(test, n, True)

    result = {}

    # Calculate effective reference sentence length.

    if eff == "closest":
        result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
    else: ## i.e., "average" or "shortest" or None
        result["reflen"] = reflen

    result["testlen"] = testlen

    result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]

    result['correct'] = [0]*n
    for (ngram, count) in counts.items():
        result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)

    return result

class BleuScorer(object):
    """Bleu scorer.
    """

    __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
    # special_reflen is used in oracle (proportional effective ref len for a node).

    def copy(self):
        ''' copy the refs.'''
        new = BleuScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        new._score = None
        return new

    def __init__(self, test=None, refs=None, n=4, special_reflen=None):
        ''' singular instance '''

        self.n = n
        self.crefs = []
        self.ctest = []
        self.cook_append(test, refs)
        self.special_reflen = special_reflen

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                cooked_test = cook_test(test, self.crefs[-1])
                self.ctest.append(cooked_test) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

        self._score = None ## need to recompute

    def ratio(self, option=None):
        self.compute_score(option=option)
        return self._ratio

    def score_ratio(self, option=None):
        '''return (bleu, len_ratio) pair'''
        return (self.fscore(option=option), self.ratio(option=option))

    def score_ratio_str(self, option=None):
        return "%.4f (%.2f)" % self.score_ratio(option)

    def reflen(self, option=None):
        self.compute_score(option=option)
        return self._reflen

    def testlen(self, option=None):
        self.compute_score(option=option)
        return self._testlen

    def retest(self, new_test):
        if type(new_test) is str:
            new_test = [new_test]
        assert len(new_test) == len(self.crefs), new_test
        self.ctest = []
        for t, rs in zip(new_test, self.crefs):
            self.ctest.append(cook_test(t, rs))
        self._score = None

        return self

    def rescore(self, new_test):
        ''' replace test(s) with new test(s), and returns the new score.'''

        return self.retest(new_test).compute_score()

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new BleuScorer instances
            self.cook_append(other[0], other[1])
        else:
            assert self.compatible(other), "incompatible BLEUs."
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)
            self._score = None ## need to recompute

        return self

    def compatible(self, other):
        return isinstance(other, BleuScorer) and self.n == other.n

    def single_reflen(self, option="average"):
        return self._single_reflen(self.crefs[0][0], option)

    def _single_reflen(self, reflens, option=None, testlen=None):

        if option == "shortest":
            reflen = min(reflens)
        elif option == "average":
            reflen = float(sum(reflens))/len(reflens)
        elif option == "closest":
            reflen = min((abs(l-testlen), l) for l in reflens)[1]
        else:
            assert False, "unsupported reflen option %s" % option

        return reflen

    def recompute_score(self, option=None, verbose=0):
        self._score = None
        return self.compute_score(option, verbose)

    def compute_score(self, option=None, verbose=0):
        n = self.n
        small = 1e-9
        tiny = 1e-15 ## so that if guess is 0 still return 0
        bleu_list = [[] for _ in range(n)]

        if self._score is not None:
            return self._score

        if option is None:
            option = "average" if len(self.crefs) == 1 else "closest"

        self._testlen = 0
        self._reflen = 0
        totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}

        # for each sentence
        for comps in self.ctest:
            testlen = comps['testlen']
            self._testlen += testlen

            if self.special_reflen is None: ## need computation
                reflen = self._single_reflen(comps['reflen'], option, testlen)
            else:
                reflen = self.special_reflen

            self._reflen += reflen

            for key in ['guess','correct']:
                for k in range(n):
                    totalcomps[key][k] += comps[key][k]

            # append per image bleu score
            bleu = 1.
            for k in range(n):
                bleu *= (float(comps['correct'][k]) + tiny) \
                        /(float(comps['guess'][k]) + small)
                bleu_list[k].append(bleu ** (1./(k+1)))
            ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
            if ratio < 1:
                for k in range(n):
                    bleu_list[k][-1] *= math.exp(1 - 1/ratio)

            if verbose > 1:
                print(comps, reflen)

        totalcomps['reflen'] = self._reflen
        totalcomps['testlen'] = self._testlen

        bleus = []
        bleu = 1.
        for k in range(n):
            bleu *= float(totalcomps['correct'][k] + tiny) \
                    / (totalcomps['guess'][k] + small)
            bleus.append(bleu ** (1./(k+1)))
        ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
        if ratio < 1:
            for k in range(n):
                bleus[k] *= math.exp(1 - 1/ratio)

        if verbose > 0:
            print(totalcomps)
            print("ratio:", ratio)

        self._score = bleus
        return self._score, bleu_list
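To make the cooking step concrete, here is a tiny illustration of `precook`, which turns a sentence into the n-gram count table that `cook_refs`/`cook_test` build on (the example sentence is arbitrary):

```python
# Tiny illustration of precook(): word count plus n-gram counts used by BleuScorer.
from evalcap.bleu.bleu_scorer import precook

length, counts = precook('the lungs are clear', n=2)
print(length)                    # 4 words
print(counts[('the',)])          # unigram count -> 1
print(counts[('lungs', 'are')])  # bigram count  -> 1
```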
evalcap/cider/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'tylin'
evalcap/cider/cider.py
ADDED
@@ -0,0 +1,57 @@
# Filename: cider.py
#
# Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
#              by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
#
# Creation Date: Sun Feb 8 14:16:54 2015
#
# Authors: Ramakrishna Vedantam <[email protected]> and Tsung-Yi Lin <[email protected]>
import os
dir_path = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(dir_path)
from cider_scorer import CiderScorer
import pdb

class Cider:
    """
    Main Class to compute the CIDEr metric
    """
    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        # set cider to sum over 1 to 4-grams
        self._n = n
        # set the standard deviation parameter for gaussian penalty
        self._sigma = sigma

    def compute_score(self, gts, res):
        """
        Main function to compute CIDEr score
        :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
               ref_for_image (dict)  : dictionary with key <image> and value <tokenized reference sentence>
        :return: cider (float) : computed CIDEr score for the corpus
        """

        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)

        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

            cider_scorer += (hypo[0], ref)

        (score, scores) = cider_scorer.compute_score()

        return score, scores

    def method(self):
        return "CIDEr"
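configs/config.py defaults `--scorer_types` to `['Bleu_4', 'CIDEr']` with equal `--weights`, which suggests the two metrics are blended somewhere in the model code. The actual reward computation lives in models/R2GenGPT.py (not shown in this diff); the following is only a hedged sketch of what such a blend could look like:

```python
# Hedged sketch of a weighted Bleu_4/CIDEr blend; the repo's real code may differ.
from evalcap.bleu.bleu import Bleu
from evalcap.cider.cider import Cider

def mixed_score(gts, res, weights=(0.5, 0.5)):
    bleu, _ = Bleu(4).compute_score(gts, res)    # bleu is [B1, B2, B3, B4]
    cider, _ = Cider().compute_score(gts, res)
    return weights[0] * bleu[3] + weights[1] * cider
```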
evalcap/cider/cider_scorer.py
ADDED
@@ -0,0 +1,193 @@
#!/usr/bin/env python
# Tsung-Yi Lin <[email protected]>
# Ramakrishna Vedantam <[email protected]>

import copy
from collections import defaultdict
import numpy as np
import pdb
import math

def precook(s, n=4, out=False):
    """
    Takes a string as input and returns an object that can be given to
    either cook_refs or cook_test. This is optional: cook_refs and cook_test
    can take string arguments as well.
    :param s: string : sentence to be converted into ngrams
    :param n: int    : number of ngrams for which representation is calculated
    :return: term frequency vector for occuring ngrams
    """
    words = s.split()
    counts = defaultdict(int)
    for k in range(1,n+1):
        for i in range(len(words)-k+1):
            ngram = tuple(words[i:i+k])
            counts[ngram] += 1
    return counts

def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
    '''Takes a list of reference sentences for a single segment
    and returns an object that encapsulates everything that BLEU
    needs to know about them.
    :param refs: list of string : reference sentences for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (list of dict)
    '''
    return [precook(ref, n) for ref in refs]

def cook_test(test, n=4):
    '''Takes a test sentence and returns an object that
    encapsulates everything that BLEU needs to know about it.
    :param test: list of string : hypothesis sentence for some image
    :param n: int : number of ngrams for which (ngram) representation is calculated
    :return: result (dict)
    '''
    return precook(test, n, True)

class CiderScorer(object):
    """CIDEr scorer.
    """

    def copy(self):
        ''' copy the refs.'''
        new = CiderScorer(n=self.n)
        new.ctest = copy.copy(self.ctest)
        new.crefs = copy.copy(self.crefs)
        return new

    def __init__(self, test=None, refs=None, n=4, sigma=6.0):
        ''' singular instance '''
        self.n = n
        self.sigma = sigma
        self.crefs = []
        self.ctest = []
        self.document_frequency = defaultdict(float)
        self.cook_append(test, refs)
        self.ref_len = None

    def cook_append(self, test, refs):
        '''called by constructor and __iadd__ to avoid creating new instances.'''

        if refs is not None:
            self.crefs.append(cook_refs(refs))
            if test is not None:
                self.ctest.append(cook_test(test)) ## N.B.: -1
            else:
                self.ctest.append(None) # lens of crefs and ctest have to match

    def size(self):
        assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
        return len(self.crefs)

    def __iadd__(self, other):
        '''add an instance (e.g., from another sentence).'''

        if type(other) is tuple:
            ## avoid creating new CiderScorer instances
            self.cook_append(other[0], other[1])
        else:
            self.ctest.extend(other.ctest)
            self.crefs.extend(other.crefs)

        return self

    def compute_doc_freq(self):
        '''
        Compute term frequency for reference data.
        This will be used to compute idf (inverse document frequency) later.
        The term frequency is stored in the object.
        :return: None
        '''
        for refs in self.crefs:
            # refs, k ref captions of one image
            for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
                self.document_frequency[ngram] += 1
            # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)

    def compute_cider(self):
        def counts2vec(cnts):
            """
            Function maps counts of ngram to vector of tfidf weights.
            The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
            The n-th entry of array denotes length of n-grams.
            :param cnts:
            :return: vec (array of dict), norm (array of float), length (int)
            """
            vec = [defaultdict(float) for _ in range(self.n)]
            length = 0
            norm = [0.0 for _ in range(self.n)]
            for (ngram, term_freq) in cnts.items():
                # give word count 1 if it doesn't appear in reference corpus
                df = np.log(max(1.0, self.document_frequency[ngram]))
                # ngram index
                n = len(ngram)-1
                # tf (term_freq) * idf (precomputed idf) for n-grams
                vec[n][ngram] = float(term_freq)*(self.ref_len - df)
                # compute norm for the vector. the norm will be used for computing similarity
                norm[n] += pow(vec[n][ngram], 2)

                if n == 1:
                    length += term_freq
            norm = [np.sqrt(n) for n in norm]
            return vec, norm, length

        def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
            '''
            Compute the cosine similarity of two vectors.
            :param vec_hyp: array of dictionary for vector corresponding to hypothesis
            :param vec_ref: array of dictionary for vector corresponding to reference
            :param norm_hyp: array of float for vector corresponding to hypothesis
            :param norm_ref: array of float for vector corresponding to reference
            :param length_hyp: int containing length of hypothesis
            :param length_ref: int containing length of reference
            :return: array of score for each n-grams cosine similarity
            '''
            delta = float(length_hyp - length_ref)
            # measure cosine similarity
            val = np.array([0.0 for _ in range(self.n)])
            for n in range(self.n):
                # ngram
                for (ngram,count) in vec_hyp[n].items():
                    # vrama91 : added clipping
                    val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]

                if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
                    val[n] /= (norm_hyp[n]*norm_ref[n])

                assert(not math.isnan(val[n]))
                # vrama91: added a length based gaussian penalty
                val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
            return val

        # compute log reference length
        self.ref_len = np.log(float(len(self.crefs)))
        if len(self.crefs) == 1:
            self.ref_len = 1
        scores = []
        for test, refs in zip(self.ctest, self.crefs):
            # compute vector for test captions
            vec, norm, length = counts2vec(test)
            # compute vector for ref captions
            score = np.array([0.0 for _ in range(self.n)])
            for ref in refs:
                vec_ref, norm_ref, length_ref = counts2vec(ref)
                score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
            # change by vrama91 - mean of ngram scores, instead of sum
            score_avg = np.mean(score)
            # divide by number of references
            score_avg /= len(refs)
            # multiply score by 10
            score_avg *= 10.0
            # append score of an image to the score list
            scores.append(score_avg)
        return scores

    def compute_score(self, option=None, verbose=0):
        # compute idf
        self.compute_doc_freq()
        # assert to check document frequency
        assert(len(self.ctest) >= max(self.document_frequency.values()))
        # compute cider score
        score = self.compute_cider()
        # debug
        # print score
        return np.mean(np.array(score)), np.array(score)
evalcap/meteor/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'tylin'
evalcap/meteor/meteor-1.5.jar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:1e57b4c72c0830ebe68558f1c799a624e96cbc1b6045c9f6330e26dcff6eafc2
size 6318693
evalcap/meteor/meteor.py
ADDED
@@ -0,0 +1,130 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# Python wrapper for METEOR implementation, by Xinlei Chen
|
4 |
+
# Acknowledge Michael Denkowski for the generous discussion and help
|
5 |
+
from __future__ import division
|
6 |
+
|
7 |
+
import atexit
|
8 |
+
import logging
|
9 |
+
import os
|
10 |
+
import re
|
11 |
+
import subprocess
|
12 |
+
import sys
|
13 |
+
import threading
|
14 |
+
|
15 |
+
import psutil
|
16 |
+
|
17 |
+
# Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
|
18 |
+
METEOR_JAR = 'meteor-1.5.jar'
|
19 |
+
|
20 |
+
|
21 |
+
def enc(s):
|
22 |
+
return s.encode('utf-8')
|
23 |
+
|
24 |
+
|
25 |
+
def dec(s):
|
26 |
+
return s.decode('utf-8')
|
27 |
+
|
28 |
+
|
29 |
+
class Meteor:

    def __init__(self):
        # Used to guarantee thread safety
        self.lock = threading.Lock()

        mem = '1G'
        mem_available_G = psutil.virtual_memory().available / 1E9
        if mem_available_G < 2:
            logging.warning("There is less than 2GB of available memory.\n"
                            "Will try with limiting Meteor to 1GB of memory but this might cause issues.\n"
                            "If you have problems using Meteor, "
                            "then you can try to lower the `mem` variable in meteor.py")
            mem = '1G'

        meteor_cmd = ['java', '-jar', '-Xmx{}'.format(mem), METEOR_JAR,
                      '-', '-', '-stdio', '-l', 'en', '-norm']
        env = os.environ.copy()
        env['LC_ALL'] = "C"
        self.meteor_p = subprocess.Popen(meteor_cmd,
                                         cwd=os.path.dirname(os.path.abspath(__file__)),
                                         env=env,
                                         stdin=subprocess.PIPE,
                                         stdout=subprocess.PIPE,
                                         stderr=subprocess.PIPE)

        atexit.register(self.close)

    def close(self):
        with self.lock:
            if self.meteor_p:
                self.meteor_p.kill()
                self.meteor_p.wait()
                self.meteor_p = None
            # if the user calls close() manually, remove the
            # reference from atexit so the object can be garbage-collected.
            if atexit is not None and atexit.unregister is not None:
                atexit.unregister(self.close)

    def compute_score(self, gts, res):
        assert (gts.keys() == res.keys())
        imgIds = gts.keys()
        scores = []

        eval_line = 'EVAL'
        with self.lock:
            for i in imgIds:
                assert (len(res[i]) == 1)
                stat = self._stat(res[i][0], gts[i])
                eval_line += ' ||| {}'.format(stat)

            self.meteor_p.stdin.write(enc('{}\n'.format(eval_line)))
            self.meteor_p.stdin.flush()
            for i in range(0, len(imgIds)):
                v = self.meteor_p.stdout.readline()
                try:
                    scores.append(float(dec(v.strip())))
                except:
                    sys.stderr.write("Error handling value: {}\n".format(v))
                    sys.stderr.write("Decoded value: {}\n".format(dec(v.strip())))
                    sys.stderr.write("eval_line: {}\n".format(eval_line))
                    # You can try uncommenting the next code line to show stderr from the Meteor JAR.
                    # If the Meteor JAR is not writing to stderr, then the line will just hang.
                    # sys.stderr.write("Error from Meteor:\n{}".format(self.meteor_p.stderr.read()))
                    raise
            score = float(dec(self.meteor_p.stdout.readline()).strip())
        self.close()
        return score, scores

    def method(self):
        return "METEOR"

    def _stat(self, hypothesis_str, reference_list):
        # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
        hypothesis_str = hypothesis_str.replace('|||', '')
        score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
        score_line = re.sub(r'\s+', ' ', score_line)
        self.meteor_p.stdin.write(enc(score_line))
        self.meteor_p.stdin.write(enc('\n'))
        self.meteor_p.stdin.flush()
        return dec(self.meteor_p.stdout.readline()).strip()

    def _score(self, hypothesis_str, reference_list):
        with self.lock:
            # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
            hypothesis_str = hypothesis_str.replace('|||', '').replace('  ', ' ')
            score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
            self.meteor_p.stdin.write(enc('{}\n'.format(score_line)))
            self.meteor_p.stdin.flush()
            stats = dec(self.meteor_p.stdout.readline()).strip()
            eval_line = 'EVAL ||| {}'.format(stats)
            # EVAL ||| stats
            self.meteor_p.stdin.write(enc('{}\n'.format(eval_line)))
            self.meteor_p.stdin.flush()
            score = float(dec(self.meteor_p.stdout.readline()).strip())
            # bug fix: there are two values returned by the jar file, one average, and one all, so read it twice
            # thanks to Andrej for pointing this out
            score = float(dec(self.meteor_p.stdout.readline()).strip())
        return score

    def __del__(self):
        self.close()
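A minimal usage sketch for Meteor.compute_score, assuming Java and the bundled meteor-1.5.jar are available next to meteor.py; the ids and sentences below are made up for illustration.

    # Hypothetical ids/sentences: both dicts map the same ids to lists of
    # tokenized sentences, and each candidate list must hold exactly one string.
    from evalcap.meteor.meteor import Meteor

    gts = {'img1': ['the heart size is normal'], 'img2': ['no acute disease']}
    res = {'img1': ['heart size is normal'], 'img2': ['no acute cardiopulmonary disease']}

    meteor_scorer = Meteor()
    average, per_image = meteor_scorer.compute_score(gts, res)
    print(average, per_image)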
evalcap/meteor/test_meteor.py
ADDED
@@ -0,0 +1,10 @@
import meteor

hypo = ['this is the model generated sentence1 which seems good enough']
ref = ['this is one reference sentence for sentence1',
       'this is a reference sentence for sentence2 which was generated by your model']

m = meteor.Meteor()

score = m._score(hypo[0], ref)
print(score)
evalcap/rouge/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'vrama91'
evalcap/rouge/rouge.py
ADDED
@@ -0,0 +1,105 @@
#!/usr/bin/env python
#
# File Name : rouge.py
#
# Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
#
# Creation Date : 2015-01-07 06:03
# Author : Ramakrishna Vedantam <[email protected]>

import numpy as np
import pdb

def my_lcs(string, sub):
    """
    Calculates longest common subsequence for a pair of tokenized strings
    :param string : list of str : tokens from a string split using whitespace
    :param sub : list of str : shorter string, also split using whitespace
    :returns: length (int): length of the longest common subsequence between the two strings

    Note: my_lcs only gives the length of the longest common subsequence, not the actual LCS
    """
    if len(string) < len(sub):
        sub, string = string, sub

    lengths = [[0 for i in range(0, len(sub) + 1)] for j in range(0, len(string) + 1)]

    for j in range(1, len(sub) + 1):
        for i in range(1, len(string) + 1):
            if string[i - 1] == sub[j - 1]:
                lengths[i][j] = lengths[i - 1][j - 1] + 1
            else:
                lengths[i][j] = max(lengths[i - 1][j], lengths[i][j - 1])

    return lengths[len(string)][len(sub)]

class Rouge():
    '''
    Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set

    '''
    def __init__(self):
        # vrama91: updated the value below based on discussion with Hovy
        self.beta = 1.2

    def calc_score(self, candidate, refs):
        """
        Compute ROUGE-L score given one candidate and references for an image
        :param candidate: list of str : single candidate sentence to be evaluated
        :param refs: list of str : COCO reference sentences for the particular image to be evaluated
        :returns score: float (ROUGE-L score for the candidate evaluated against references)
        """
        # assert(len(candidate)==1)
        # assert(len(refs)>0)
        prec = []
        rec = []

        # split into tokens
        token_c = candidate[0].split(" ")

        for reference in refs:
            # split into tokens
            token_r = reference.split(" ")
            # compute the longest common subsequence
            lcs = my_lcs(token_r, token_c)
            prec.append(lcs / float(len(token_c)))
            rec.append(lcs / float(len(token_r)))

        prec_max = max(prec)
        rec_max = max(rec)

        if prec_max != 0 and rec_max != 0:
            score = ((1 + self.beta**2) * prec_max * rec_max) / float(rec_max + self.beta**2 * prec_max)
        else:
            score = 0.0
        return score

    def compute_score(self, gts, res):
        """
        Computes ROUGE-L score given a set of reference and candidate sentences for the dataset
        Invoked by evaluate_captions.py
        :param gts: dict : reference sentences with "image name" keys and lists of tokenized sentences as values
        :param res: dict : candidate / test sentences with "image name" keys and single-element lists of tokenized sentences as values
        :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
        """
        assert(gts.keys() == res.keys())
        imgIds = gts.keys()

        score = []
        for id in imgIds:
            hypo = res[id]
            ref = gts[id]

            score.append(self.calc_score(hypo, ref))

            # Sanity check.
            assert(type(hypo) is list)
            assert(len(hypo) == 1)
            assert(type(ref) is list)
            assert(len(ref) > 0)

        average_score = np.mean(np.array(score))
        return average_score, np.array(score)

    def method(self):
        return "Rouge"
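A minimal usage sketch for Rouge.compute_score; the ids and sentences are made up. Candidate values must be single-element lists, while reference lists may hold one or more sentences per id.

    from evalcap.rouge.rouge import Rouge

    # Hypothetical ids/sentences in the {id: [sentence]} format expected by the scorer.
    res = {'img1': ['the lungs are clear'], 'img2': ['no pleural effusion is seen']}
    gts = {'img1': ['lungs are clear bilaterally'], 'img2': ['there is no pleural effusion']}

    rouge_scorer = Rouge()
    average, per_image = rouge_scorer.compute_score(gts, res)
    print(average, per_image.tolist())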
evalcap/tokenizer/__init__.py
ADDED
@@ -0,0 +1 @@
__author__ = 'hfang'
evalcap/tokenizer/ptbtokenizer.py
ADDED
@@ -0,0 +1,68 @@
#!/usr/bin/env python
#
# File Name : ptbtokenizer.py
#
# Description : Do the PTB Tokenization and remove punctuations.
#
# Creation Date : 29-12-2014
# Last Modified : Thu Mar 19 09:53:35 2015
# Authors : Hao Fang <[email protected]> and Tsung-Yi Lin <[email protected]>

import os
import sys
import subprocess
import tempfile
import itertools

# path to the stanford corenlp jar
STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'

# punctuations to be removed from the sentences
PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-",
                ".", "?", "!", ",", ":", "-", "--", "...", ";"]

class PTBTokenizer:
    """Python wrapper of Stanford PTBTokenizer"""

    def tokenize(self, captions_for_image):
        cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR,
               'edu.stanford.nlp.process.PTBTokenizer',
               '-preserveLines', '-lowerCase']

        # ======================================================
        # prepare data for PTB Tokenizer
        # ======================================================
        final_tokenized_captions_for_image = {}
        image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
        sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])

        # ======================================================
        # save sentences to temporary file
        # ======================================================
        path_to_jar_dirname = os.path.dirname(os.path.abspath(__file__))
        tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname, mode='w', encoding='utf-8')
        tmp_file.write(sentences)
        tmp_file.close()

        # ======================================================
        # tokenize sentence
        # ======================================================
        cmd.append(os.path.basename(tmp_file.name))
        p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname,
                                       stdout=subprocess.PIPE)
        token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
        lines = token_lines.decode().split('\n')
        # remove temp file
        os.remove(tmp_file.name)

        # ======================================================
        # create dictionary for tokenized captions
        # ======================================================
        for k, line in zip(image_id, lines):
            if not k in final_tokenized_captions_for_image:
                final_tokenized_captions_for_image[k] = []
            tokenized_caption = ' '.join([w for w in line.rstrip().split(' ')
                                          if w not in PUNCTUATIONS])
            final_tokenized_captions_for_image[k].append(tokenized_caption)

        return final_tokenized_captions_for_image
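A minimal usage sketch for PTBTokenizer.tokenize, assuming Java and the bundled stanford-corenlp-3.4.1.jar sit next to ptbtokenizer.py; the ids and captions are made up.

    from evalcap.tokenizer.ptbtokenizer import PTBTokenizer

    # Hypothetical input: each image id maps to a list of {'caption': ...} dicts.
    captions = {
        'img1': [{'caption': 'The heart is normal in size.'}],
        'img2': [{'caption': 'No focal consolidation, effusion, or pneumothorax.'}],
    }

    tokenizer = PTBTokenizer()
    tokenized = tokenizer.tokenize(captions)
    # Output maps the same ids to lowercased, punctuation-stripped strings,
    # e.g. tokenized['img1'] -> ['the heart is normal in size']
    print(tokenized)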
evalcap/tokenizer/stanford-corenlp-3.4.1.jar
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2fcb91bb7a111f93d71e264f4ee0e3afd19ba0dde6d21b38605088df9e940399
size 5921410
images/align.png
ADDED
Git LFS Details
lightning_tools/callbacks.py
ADDED
@@ -0,0 +1,30 @@
import os
from lightning.pytorch.loggers import CSVLogger
from lightning.pytorch import loggers as pl_loggers
from lightning.pytorch.callbacks import LearningRateMonitor
from lightning.pytorch.callbacks import ModelCheckpoint


def add_callbacks(args):
    log_dir = args.savedmodel_path
    os.makedirs(log_dir, exist_ok=True)

    # --------- Add Callbacks
    checkpoint_callback = ModelCheckpoint(
        dirpath=os.path.join(log_dir, "checkpoints"),
        filename="{epoch}-{step}",
        save_top_k=-1,
        every_n_train_steps=args.every_n_train_steps,
        save_last=False,
        save_weights_only=False
    )

    lr_monitor_callback = LearningRateMonitor(logging_interval='step')
    tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join(log_dir, "logs"), name="tensorboard")
    csv_logger = CSVLogger(save_dir=os.path.join(log_dir, "logs"), name="csvlog")

    to_returns = {
        "callbacks": [checkpoint_callback, lr_monitor_callback],
        "loggers": [csv_logger, tb_logger]
    }
    return to_returns
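A small sketch of the argparse fields that add_callbacks reads and of what it returns; the values below are hypothetical, and train.py later in this upload wires the result into the Trainer in the same way.

    import argparse
    from lightning_tools.callbacks import add_callbacks

    # Only these two fields are accessed by add_callbacks (hypothetical values).
    args = argparse.Namespace(savedmodel_path='./save/demo', every_n_train_steps=1000)
    hooks = add_callbacks(args)
    # hooks["callbacks"] -> [ModelCheckpoint, LearningRateMonitor]
    # hooks["loggers"]   -> [CSVLogger, TensorBoardLogger]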
lightning_tools/optim.py
ADDED
@@ -0,0 +1,59 @@
from transformers import AdamW
import functools
from torch.optim.lr_scheduler import LambdaLR


def lr_lambda(current_step, num_warmup_steps, num_training_steps):
    if current_step < num_warmup_steps:
        return float(current_step) / float(max(1, num_warmup_steps))
    return max(
        0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
    )


def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (:obj:`int`):
            The number of steps for the warmup phase.
        num_training_steps (:obj:`int`):
            The total number of training steps.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    return LambdaLR(optimizer, functools.partial(lr_lambda, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps), last_epoch)


def config_optimizer(parameters, init_lr, warmup_steps, max_steps, name='lr'):
    """
    The original BERT optimizer does not apply weight decay to bias and LayerNorm parameters.
    Args:
        parameters:
        init_lr:
        warmup_steps:
        max_steps:
        name:

    Returns:
        A tuple of (optimizer, scheduler dict) suitable for a LightningModule.
    """
    optimizer = AdamW(
        parameters, lr=init_lr, eps=1e-8, correct_bias=False
    )

    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
    )
    scheduler = {'scheduler': scheduler, 'name': name, 'interval': 'step', 'frequency': 1}

    return optimizer, scheduler
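A minimal sketch of wiring config_optimizer into a LightningModule, with hypothetical hyper-parameter values and a toy module; the returned scheduler dict already requests per-step learning-rate updates.

    import torch.nn as nn
    from lightning_tools.optim import config_optimizer

    # Toy module and made-up hyper-parameters, only to show the call signature.
    model = nn.Linear(16, 4)
    optimizer, scheduler = config_optimizer(
        model.parameters(), init_lr=1e-4, warmup_steps=500, max_steps=10000
    )
    # Inside a LightningModule.configure_optimizers one could then:
    # return {"optimizer": optimizer, "lr_scheduler": scheduler}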
models/R2GenGPT.py
ADDED
@@ -0,0 +1,379 @@
import os
import json
import torch
import torch.nn as nn
import lightning.pytorch as pl
from transformers import LlamaForCausalLM, LlamaTokenizer
from evalcap.bleu.bleu import Bleu
from evalcap.rouge.rouge import Rouge
from evalcap.cider.cider import Cider
from evalcap.meteor.meteor import Meteor
from transformers import SwinModel
from lightning_tools.optim import config_optimizer
from peft import get_peft_model, LoraConfig, TaskType
import pdb


class R2GenGPT(pl.LightningModule):
    """
    R2GenGPT model.
    """
    def __init__(self, args):
        super().__init__()
        self.args = args
        self.save_hyperparameters(args)

        print(f'Loading vision encoder:{args.vision_model}')
        self.visual_encoder = SwinModel.from_pretrained(args.vision_model)
        if args.vis_use_lora:
            peft_config_visual = LoraConfig(
                r=args.vis_r,
                lora_alpha=args.vis_alpha,
                target_modules=["query", "value"],
                lora_dropout=args.lora_dropout,
                bias="none",
                modules_to_save=["classifier"],
            )
            self.visual_encoder = get_peft_model(self.visual_encoder, peft_config_visual)
            self.visual_encoder.print_trainable_parameters()
            print('Loading vision encoder with LoRA -- Done')
        elif args.freeze_vm:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            print(f'Loading Frozen vision encoder:{args.vision_model} -- Done')
        else:
            print(f'Loading Trainable vision encoder:{args.vision_model} -- Done')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained(args.llama_model, use_fast=False)
        self.llama_tokenizer.pad_token_id = 0
        if args.low_resource:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                args.llama_model,
                torch_dtype=torch.float16,
                load_in_8bit=True,
                device_map="auto"
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                args.llama_model,
                torch_dtype=torch.float16,
            )

        if args.llm_use_lora:
            self.embed_tokens = self.llama_model.get_input_embeddings()
            peft_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.llm_r, lora_alpha=args.llm_alpha, lora_dropout=args.lora_dropout
            )
            self.llama_model = get_peft_model(self.llama_model, peft_config)
            self.llama_model.print_trainable_parameters()
            print('Loading LLAMA LoRA Done')
        else:
            self.embed_tokens = self.llama_model.get_input_embeddings()
            for name, param in self.llama_model.named_parameters():
                param.requires_grad = False
            print('Loading LLAMA Done')

        self.llama_proj = nn.Linear(self.visual_encoder.num_features, self.llama_model.config.hidden_size)
        self.layer_norm = nn.LayerNorm(self.llama_model.config.hidden_size)
        self.end_sym = args.end_sym
        self.prompt = 'Generate a comprehensive and detailed diagnosis report for this chest xray image.'
        self.val_step_outputs = []
        self.test_step_outputs = []
        self.val_score = 0.0

        if args.delta_file is not None:
            state_dict = torch.load(args.delta_file, map_location=torch.device(f'cuda:{torch.cuda.current_device()}'))['model']
            self.load_state_dict(state_dict=state_dict, strict=False)
            print(f'Load checkpoint from {args.delta_file}')


    def score(self, ref, hypo):
        """
        ref, dictionary of reference sentences (id, sentence)
        hypo, dictionary of hypothesis sentences (id, sentence)
        score, dictionary of scores
        """
        scorers = [
            (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
            (Rouge(), "ROUGE_L"),
            (Meteor(), "METEOR"),
            (Cider(), "CIDEr")
        ]
        final_scores = {}
        for scorer, method in scorers:
            score, scores = scorer.compute_score(ref, hypo)
            if type(score) == list:
                for m, s in zip(method, score):
                    final_scores[m] = s
            else:
                final_scores[method] = score
        return final_scores


    def encode_img(self, images):
        image_embeds = []
        for image in images:
            device = image.device
            if self.hparams.global_only:
                image_embed = self.visual_encoder(image)['pooler_output'].unsqueeze(1).to(device)
            else:
                image_embed = self.visual_encoder(image)['last_hidden_state'].to(device)
            image_embeds.append(image_embed)

        image_embeds = torch.stack(image_embeds).mean(0)
        inputs_llama = self.llama_proj(image_embeds)
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
        return inputs_llama, atts_llama


    def prompt_wrap(self, img_embeds, atts_img):
        prompt = f'Human: <Img><ImageHere></Img> {self.prompt} \nAssistant:'
        batch_size = img_embeds.shape[0]
        p_before, p_after = prompt.split('<ImageHere>')
        p_before_tokens = self.llama_tokenizer(
            p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_after_tokens = self.llama_tokenizer(
            p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
        p_before_embeds = self.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        p_after_embeds = self.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
        wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
        wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
        return wrapped_img_embeds, wrapped_atts_img


    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)

        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        self.llama_tokenizer.padding_side = "right"
        text = [t + self.end_sym for t in samples["input_text"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.hparams.max_length,
            add_special_tokens=False
        ).to(image[0].device)

        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == 0, -100
        )

        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
                       dtype=torch.long).to(image[0].device).fill_(-100)  # plus one for bos
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss
        return {"loss": loss}

    def training_step(self, batch, batch_idx):
        result = self(batch)
        self.log_dict(result, prog_bar=True)
        return result

    def save_checkpoint(self, eval_res):
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
        param_grad_dic = {
            k: v.requires_grad for (k, v) in self.named_parameters() if v.requires_grad
        }
        state_dict = self.state_dict()
        for k in list(state_dict.keys()):
            if k not in param_grad_dic.keys():
                del state_dict[k]
        save_obj = {
            "model": state_dict,
            "config": self.hparams,
            "epoch": current_epoch,
            "step": global_step
        }
        os.makedirs(os.path.join(self.hparams.savedmodel_path, 'checkpoints'), exist_ok=True)
        save_to = os.path.join(
            self.hparams.savedmodel_path, 'checkpoints',
            "checkpoint_epoch{}_step{}_bleu{:3f}_cider{:3f}.pth".format(current_epoch, global_step, eval_res['Bleu_4'], eval_res['CIDEr']),
        )
        self.print("Saving checkpoint at step {} to {}.".format(global_step, save_to))
        torch.save(save_obj, save_to)

    def validation_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.hparams.max_length,
            add_special_tokens=False
        )

        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            num_beams=self.hparams.beam_size,
            do_sample=self.hparams.do_sample,
            min_new_tokens=self.hparams.min_new_tokens,
            max_new_tokens=self.hparams.max_new_tokens,
            repetition_penalty=self.hparams.repetition_penalty,
            length_penalty=self.hparams.length_penalty,
            temperature=self.hparams.temperature,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.val_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref

    def decode(self, output_token):
        if output_token[0] == 0:  # the model might output an unknown token <unk> at the beginning; remove it
            output_token = output_token[1:]
        if output_token[0] == 1:  # some users find that there is a start token <s> at the beginning; remove it
            output_token = output_token[1:]
        output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
        output_text = output_text.split('</s>')[0].strip()
        output_text = output_text.replace('<unk>', '')
        return output_text

    def on_validation_epoch_end(self):
        ref, hypo, ids = [], [], []
        for i in self.val_step_outputs:
            ref.extend(i['ref'])
            hypo.extend(i['hypo'])
            ids.extend(i['id'])

        ref = {k: [v] for k, v in zip(ids, ref)}
        hypo = {k: [v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref, hypo=hypo)
        self.log_dict(eval_res, sync_dist=True, logger=True)

        result_folder = os.path.join(self.hparams.savedmodel_path, 'result')
        os.makedirs(result_folder, exist_ok=True)
        current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
        json.dump(hypo, open(os.path.join(result_folder, f"result_{current_epoch}_{global_step}" + '.json'), 'w'))
        json.dump(ref, open(os.path.join(result_folder, 'refs.json'), 'w'))
        self.print(eval_res)

        val_score = 0
        for score_type, weight in zip(self.hparams.scorer_types, self.hparams.weights):
            val_score += eval_res[score_type] * weight

        if self.trainer.local_rank == 0:
            if val_score > self.val_score:
                self.save_checkpoint(eval_res)
                self.val_score = val_score
        self.val_step_outputs.clear()


    def test_step(self, samples, batch_idx):
        self.llama_tokenizer.padding_side = "right"
        to_regress_tokens = self.llama_tokenizer(
            samples['input_text'],
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.hparams.max_length,
            add_special_tokens=False
        )

        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        img_embeds = self.layer_norm(img_embeds)
        img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=atts_img.dtype,
                         device=atts_img.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img], dim=1)

        outputs = self.llama_model.generate(
            inputs_embeds=inputs_embeds,
            num_beams=self.hparams.beam_size,
            do_sample=self.hparams.do_sample,
            min_new_tokens=self.hparams.min_new_tokens,
            max_new_tokens=self.hparams.max_new_tokens,
            repetition_penalty=self.hparams.repetition_penalty,
            length_penalty=self.hparams.length_penalty,
            temperature=self.hparams.temperature,
        )
        hypo = [self.decode(i) for i in outputs]
        ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
        self.test_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
        return hypo, ref


    def on_test_epoch_end(self):
        """
        This function is called at the end of the test epoch.
        It is recommended to test on a single device to ensure each sample/batch gets evaluated exactly once. This is helpful to make sure benchmarking for research papers is done the right way. Otherwise, in a multi-device setting, samples could be duplicated when DistributedSampler is used, e.g. with strategy="ddp": it replicates some samples on some devices to make sure all devices have the same batch size in case of uneven inputs.
        """
        ref, hypo, ids = [], [], []
        for i in self.test_step_outputs:
            ref.extend(i['ref'])
            hypo.extend(i['hypo'])
            ids.extend(i['id'])

        ref = {k: [v] for k, v in zip(ids, ref)}
        hypo = {k: [v] for k, v in zip(ids, hypo)}
        eval_res = self.score(ref=ref, hypo=hypo)

        result_folder = os.path.join(self.hparams.savedmodel_path, 'result')
        os.makedirs(result_folder, exist_ok=True)
        json.dump(hypo, open(os.path.join(result_folder, f"test_result.json"), 'w'))
        json.dump(ref, open(os.path.join(result_folder, 'test_refs.json'), 'w'))
        self.print(f"Test result of {self.hparams.delta_file}: {eval_res}")

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=self.hparams.max_epochs, eta_min=1e-6)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items

    def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
        optimizer.zero_grad()
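A short sketch of the {id: [sentence]} format that R2GenGPT.score consumes; the same shape is built in on_validation_epoch_end and on_test_epoch_end. The ids and sentences are made up, and only the BLEU and ROUGE scorers are exercised here to keep the sketch free of the Java-based METEOR dependency.

    from evalcap.bleu.bleu import Bleu
    from evalcap.rouge.rouge import Rouge

    # Hypothetical report id and texts; each value is a single-element list.
    ref = {'cxr_001': ['the lungs are clear without focal consolidation']}
    hypo = {'cxr_001': ['lungs are clear no focal consolidation is seen']}

    bleu, _ = Bleu(4).compute_score(ref, hypo)     # [Bleu_1, Bleu_2, Bleu_3, Bleu_4]
    rouge_l, _ = Rouge().compute_score(ref, hypo)  # scalar ROUGE-L
    print(bleu, rouge_l)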
requirements.txt
CHANGED
@@ -1,5 +1,8 @@
-streamlit
 torch
-
-
-
+peft
+tensorboardX
+transformers==4.30.2
+lightning==2.0.5
+Pillow
+numpy
+gradio
scripts/1-1.shallow_run_iuxray.sh
ADDED
@@ -0,0 +1,30 @@
#!/bin/bash

dataset="iu_xray"
annotation="data/iu_xray/annotation.json"
base_dir="./data/iu_xray/images"

version="v1_shallow"
savepath="./save/$dataset/$version"

python -u train.py \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --batch_size 8 \
    --val_batch_size 12 \
    --freeze_vm True \
    --vis_use_lora False \
    --savedmodel_path ${savepath} \
    --max_length 60 \
    --min_new_tokens 40 \
    --max_new_tokens 100 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 8 \
    --devices 2 \
    --max_epochs 15 \
    --limit_val_batches 1.0 \
    --val_check_interval 1.0 \
    --num_sanity_val_steps 0 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/1-2.shallow_test_iuxray.sh
ADDED
@@ -0,0 +1,28 @@
#!/bin/bash

dataset="iu_xray"
annotation="data/iu_xray/annotation.json"
base_dir="./data/iu_xray/images"
delta_file="/apdcephfs/share_733425/vinnylywang/zhanyuwang/Code/R2GenGPT/save/iu_xray/v1_shallow/checkpoints/checkpoint_epoch11_step1548_bleu0.155866_cider0.450477.pth"

version="v1_shallow"
savepath="./save/$dataset/$version"

python -u train.py \
    --test \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --delta_file ${delta_file} \
    --test_batch_size 16 \
    --freeze_vm True \
    --vis_use_lora False \
    --savedmodel_path ${savepath} \
    --max_length 60 \
    --min_new_tokens 40 \
    --max_new_tokens 100 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 8 \
    --devices 1 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/2-1.delta_run_iuxray.sh
ADDED
@@ -0,0 +1,30 @@
#!/bin/bash

dataset="iu_xray"
annotation="data/iu_xray/annotation.json"
base_dir="./data/iu_xray/images"

version="v1_delta"
savepath="./save/$dataset/$version"

python -u train.py \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --batch_size 8 \
    --val_batch_size 12 \
    --freeze_vm True \
    --vis_use_lora True \
    --savedmodel_path ${savepath} \
    --max_length 60 \
    --min_new_tokens 40 \
    --max_new_tokens 100 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 8 \
    --devices 2 \
    --max_epochs 15 \
    --limit_val_batches 1.0 \
    --val_check_interval 1.0 \
    --num_sanity_val_steps 2 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/2-2.delta_test_iuxray.sh
ADDED
@@ -0,0 +1,28 @@
#!/bin/bash

dataset="iu_xray"
annotation="data/iu_xray/annotation.json"
base_dir="./data/iu_xray/images"
delta_file="/apdcephfs/share_733425/vinnylywang/zhanyuwang/Code/R2GenGPT/save/iu_xray/v1_delta/checkpoints/checkpoint_epoch13_step1806_bleu0.161532_cider0.530213.pth"

version="v1_delta"
savepath="./save/$dataset/$version"

python -u train.py \
    --test \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --delta_file ${delta_file} \
    --test_batch_size 16 \
    --freeze_vm True \
    --vis_use_lora True \
    --savedmodel_path ${savepath} \
    --max_length 60 \
    --min_new_tokens 40 \
    --max_new_tokens 100 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 8 \
    --devices 1 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/3-1.deep_run_iuxray.sh
ADDED
@@ -0,0 +1,30 @@
#!/bin/bash

dataset="iu_xray"
annotation="data/iu_xray/annotation.json"
base_dir="./data/iu_xray/images"

version="v1_deep"
savepath="./save/$dataset/$version"

python -u train.py \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --batch_size 8 \
    --val_batch_size 12 \
    --freeze_vm False \
    --vis_use_lora False \
    --savedmodel_path ${savepath} \
    --max_length 60 \
    --min_new_tokens 40 \
    --max_new_tokens 100 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 8 \
    --devices 2 \
    --max_epochs 15 \
    --limit_val_batches 1.0 \
    --val_check_interval 1.0 \
    --num_sanity_val_steps 2 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/3-2.deep_test_iuxray.sh
ADDED
@@ -0,0 +1,28 @@
#!/bin/bash

dataset="iu_xray"
annotation="data/iu_xray/annotation.json"
base_dir="./data/iu_xray/images"
delta_file="/apdcephfs/share_733425/vinnylywang/zhanyuwang/Code/R2GenGPT/save/iu_xray/v1_deep/checkpoints/checkpoint_epoch12_step1677_bleu0.185560_cider0.678231.pth"

version="v1_deep"
savepath="./save/$dataset/$version"

python -u train.py \
    --test \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --delta_file ${delta_file} \
    --test_batch_size 16 \
    --freeze_vm False \
    --vis_use_lora False \
    --savedmodel_path ${savepath} \
    --max_length 60 \
    --min_new_tokens 40 \
    --max_new_tokens 100 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 8 \
    --devices 1 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/4-1.shallow_run.sh
ADDED
@@ -0,0 +1,39 @@
#!/bin/bash

dataset="mimic_cxr"
annotation="./data/mimic_cxr/annotation.json"
base_dir="./data/mimic_cxr/images"

version="v1_shallow"
savepath="./save/$dataset/$version"

if [ ! -d "$savepath" ]; then
  mkdir -p "$savepath"
  echo "Folder '$savepath' created."
else
  echo "Folder '$savepath' already exists."
fi

python -u train.py \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --batch_size 8 \
    --val_batch_size 12 \
    --freeze_vm True \
    --vis_use_lora False \
    --savedmodel_path ${savepath} \
    --learning_rate 1e-4 \
    --gradient_clip_val 1 \
    --max_length 100 \
    --min_new_tokens 80 \
    --max_new_tokens 120 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 8 \
    --devices 4 \
    --max_epochs 5 \
    --limit_val_batches 0.5 \
    --val_check_interval 0.5 \
    --num_sanity_val_steps 2 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/4-2.shallow_test.sh
ADDED
@@ -0,0 +1,28 @@
#!/bin/bash

dataset="mimic_cxr"
annotation="data/mimic_cxr/my_mimic_anno.json"
base_dir="./data/mimic_cxr/images"
delta_file="path/to/pretrained/delta_file"

version="v1_shallow"
savepath="./save/$dataset/$version"

python -u train.py \
    --test \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --delta_file ${delta_file} \
    --test_batch_size 16 \
    --freeze_vm True \
    --vis_use_lora False \
    --savedmodel_path ${savepath} \
    --max_length 100 \
    --min_new_tokens 80 \
    --max_new_tokens 120 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 12 \
    --devices 1 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/5-1.delta_run.sh
ADDED
@@ -0,0 +1,39 @@
#!/bin/bash

dataset="mimic_cxr"
annotation="data/mimic_cxr/my_mimic_anno.json"
base_dir="./data/mimic_cxr/images"

version="v1_delta"
savepath="./save/$dataset/$version"

if [ ! -d "$savepath" ]; then
  mkdir -p "$savepath"
  echo "Folder '$savepath' created."
else
  echo "Folder '$savepath' already exists."
fi

python -u train.py \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --batch_size 8 \
    --val_batch_size 16 \
    --freeze_vm True \
    --vis_use_lora True \
    --vis_r 16 \
    --vis_alpha 16 \
    --savedmodel_path ${savepath} \
    --max_length 100 \
    --min_new_tokens 80 \
    --max_new_tokens 120 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 16 \
    --devices 4 \
    --max_epochs 5 \
    --limit_val_batches 0.5 \
    --val_check_interval 0.5 \
    --num_sanity_val_steps 2 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/5-2.delta_test.sh
ADDED
@@ -0,0 +1,30 @@
#!/bin/bash

dataset="mimic_cxr"
annotation="data/mimic_cxr/my_mimic_anno.json"
base_dir="./data/mimic_cxr/images"
delta_file="path/to/pretrained/delta_file"

version="v1_delta"
savepath="./save/$dataset/$version"

python -u train.py \
    --test \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --delta_file ${delta_file} \
    --max_length 100 \
    --min_new_tokens 80 \
    --max_new_tokens 120 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --test_batch_size 16 \
    --freeze_vm True \
    --vis_use_lora True \
    --vis_r 16 \
    --vis_alpha 16 \
    --savedmodel_path ${savepath} \
    --num_workers 12 \
    --devices 1 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/6-1.deep_run.sh
ADDED
@@ -0,0 +1,38 @@
#!/bin/bash

dataset="mimic_cxr"
annotation="data/mimic_cxr/my_mimic_anno.json"
base_dir="./data/mimic_cxr/images"

version="v1_deep"
savepath="./save/$dataset/$version"

if [ ! -d "$savepath" ]; then
  mkdir -p "$savepath"
  echo "Folder '$savepath' created."
else
  echo "Folder '$savepath' already exists."
fi

python -u train.py \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --batch_size 6 \
    --val_batch_size 12 \
    --freeze_vm False \
    --vis_use_lora False \
    --llm_use_lora False \
    --savedmodel_path ${savepath} \
    --max_length 100 \
    --min_new_tokens 80 \
    --max_new_tokens 120 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --num_workers 12 \
    --devices 4 \
    --max_epochs 5 \
    --limit_val_batches 0.5 \
    --val_check_interval 0.5 \
    --num_sanity_val_steps 2 \
    2>&1 |tee -a ${savepath}/log.txt
scripts/6-2.deep_test.sh
ADDED
@@ -0,0 +1,28 @@
#!/bin/bash

dataset="mimic_cxr"
annotation="data/mimic_cxr/my_mimic_anno.json"
base_dir="./data/mimic_cxr/images"
delta_file="path/to/pretrained/delta_file"

version="v1_deep"
savepath="./save/$dataset/$version"

python -u train.py \
    --test \
    --dataset ${dataset} \
    --annotation ${annotation} \
    --base_dir ${base_dir} \
    --delta_file ${delta_file} \
    --test_batch_size 16 \
    --max_length 100 \
    --min_new_tokens 80 \
    --max_new_tokens 120 \
    --repetition_penalty 2.0 \
    --length_penalty 2.0 \
    --freeze_vm False \
    --vis_use_lora False \
    --savedmodel_path ${savepath} \
    --num_workers 12 \
    --devices 1 \
    2>&1 |tee -a ${savepath}/log.txt
train.py
ADDED
@@ -0,0 +1,51 @@
import os
from pprint import pprint
from configs.config import parser
from dataset.data_module import DataModule
from lightning_tools.callbacks import add_callbacks
from models.R2GenGPT import R2GenGPT
from lightning.pytorch import seed_everything
import lightning.pytorch as pl


def train(args):
    dm = DataModule(args)
    callbacks = add_callbacks(args)

    trainer = pl.Trainer(
        devices=args.devices,
        num_nodes=args.num_nodes,
        strategy=args.strategy,
        accelerator=args.accelerator,
        precision=args.precision,
        val_check_interval=args.val_check_interval,
        limit_val_batches=args.limit_val_batches,
        max_epochs=args.max_epochs,
        num_sanity_val_steps=args.num_sanity_val_steps,
        accumulate_grad_batches=args.accumulate_grad_batches,
        callbacks=callbacks["callbacks"],
        logger=callbacks["loggers"]
    )

    if args.ckpt_file is not None:
        model = R2GenGPT.load_from_checkpoint(args.ckpt_file, strict=False)
    else:
        model = R2GenGPT(args)

    if args.test:
        trainer.test(model, datamodule=dm)
    elif args.validate:
        trainer.validate(model, datamodule=dm)
    else:
        trainer.fit(model, datamodule=dm)

def main():
    args = parser.parse_args()
    os.makedirs(args.savedmodel_path, exist_ok=True)
    pprint(vars(args))
    seed_everything(42, workers=True)
    train(args)


if __name__ == '__main__':
    main()