tousin23 committed on
Commit 6551065 · verified · 1 Parent(s): fabd620

Upload 41 files

.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ evalcap/meteor/meteor-1.5.jar filter=lfs diff=lfs merge=lfs -text
37
+ evalcap/tokenizer/stanford-corenlp-3.4.1.jar filter=lfs diff=lfs merge=lfs -text
38
+ images/align.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,160 @@
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ #.idea/
LICENSE ADDED
@@ -0,0 +1,28 @@
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2023, zhanyuwang
4
+
5
+ Redistribution and use in source and binary forms, with or without
6
+ modification, are permitted provided that the following conditions are met:
7
+
8
+ 1. Redistributions of source code must retain the above copyright notice, this
9
+ list of conditions and the following disclaimer.
10
+
11
+ 2. Redistributions in binary form must reproduce the above copyright notice,
12
+ this list of conditions and the following disclaimer in the documentation
13
+ and/or other materials provided with the distribution.
14
+
15
+ 3. Neither the name of the copyright holder nor the names of its
16
+ contributors may be used to endorse or promote products derived from
17
+ this software without specific prior written permission.
18
+
19
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
20
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
21
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
23
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
24
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
25
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
26
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
27
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
28
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md CHANGED
@@ -1,12 +1,77 @@
1
- ---
2
- title: X RayDemo
3
- emoji: 🏃
4
- colorFrom: green
5
- colorTo: green
6
- sdk: streamlit
7
- sdk_version: 1.35.0
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # R2GenGPT: Radiology Report Generation with Frozen LLMs
2
+
3
+ ## Introduction
4
+ ![overview](https://github.com/wang-zhanyu/R2GenGPT/blob/main/images/align.png)
5
+
6
+ ## Getting Started
7
+ ### Installation
8
+
9
+ **1. Prepare the code and the environment**
10
+
11
+ Git clone our repository and install the requirements.
12
+
13
+ ```bash
14
+ git clone https://github.com/wang-zhanyu/R2GenGPT.git
15
+ cd R2GenGPT
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+
20
+ **2. Prepare the training dataset**
21
+
22
+ IU-xray: download the dataset from [here](https://drive.google.com/file/d/1c0BXEuDy8Cmm2jfN0YYGkQxFZd2ZIoLg/view)
23
+
24
+ MIMIC-CXR: you can download our preprocessed annotation file from [here](https://drive.google.com/file/d/14689ztodTtrQJYs--ihB_hgsPMMNHX-H/view?usp=sharing) and download the images from the [official website](https://physionet.org/content/mimic-cxr-jpg/2.0.0/)
25
+
26
+ After downloading the data, place it under the ./data folder so that it matches the --annotation and --base_dir defaults in configs/config.py (e.g. ./data/mimic_cxr/annotation.json and ./data/mimic_cxr/images for MIMIC-CXR).
27
+
28
+ ### Training
29
+
30
+ For shallow alignment
31
+
32
+ ```bash
33
+ bash scripts/4-1.shallow_run.sh
34
+ ```
35
+
36
+ For delta alignment
37
+
38
+ ```bash
39
+ bash scripts/5-1.delta_run.sh
40
+ ```
41
+
42
+ For deep alignment
43
+
44
+ ```bash
45
+ bash scripts/6-1.deep_run.sh
46
+ ```
47
+
48
+ ### Testing (For MIMIC-CXR)
49
+ You can download our pretrained Delta checkpoints from [here](https://drive.google.com/drive/folders/1ywEITWfYIAAYy0VY1IZ24Ec_GoNmkqIY?usp=sharing).
50
+
51
+ For shallow alignment
52
+
53
+ ```bash
54
+ bash scripts/4-2.shallow_test.sh
55
+ ```
56
+
57
+ For delta alignment
58
+
59
+ ```bash
60
+ bash scripts/5-2.delta_test.sh
61
+ ```
62
+
63
+ For deep alignment
64
+
65
+ ```bash
66
+ bash scripts/6-2.shallow_test.sh
67
+ ```
68
+
69
+
70
+ ## Acknowledgement
71
+
72
+ + [MiniGPT-4](https://github.com/Vision-CAIR/MiniGPT-4) Some of the code in this repo is based on MiniGPT-4.
73
+ + [Llama2](https://github.com/facebookresearch/llama) The language ability of Llama-2 with only 7B parameters is remarkable.
74
+
75
+
76
+ ## License
77
+ This repository is released under the [BSD 3-Clause License](LICENSE).
configs/config.py ADDED
@@ -0,0 +1,68 @@
1
+ import argparse
2
+
3
+ parser = argparse.ArgumentParser(description="hyper-parameter for R2GenGPT")
4
+ # ========================= Dataset Configs ==========================
5
+ parser.add_argument('--test', action='store_true', help="only run test set")
6
+ parser.add_argument('--validate', action='store_true', help="only run validation set")
7
+ parser.add_argument('--dataset', type=str, default='mimic_cxr', help="iu-xray or mimic-cxr")
8
+ parser.add_argument('--annotation', type=str, default=r'./data/mimic_cxr/annotation.json', help="annotation file of the dataset")
9
+ parser.add_argument('--base_dir', type=str, default=r'./data/mimic_cxr/images', help="base dir to help find images")
10
+ parser.add_argument('--batch_size', default=6, type=int, help="training batch size per worker")
11
+ parser.add_argument('--val_batch_size', default=16, type=int, help="validation batch size per worker")
12
+ parser.add_argument('--test_batch_size', default=16, type=int, help="test batch size per worker")
13
+ parser.add_argument('--prefetch_factor', default=4, type=int, help="number of batches prefetched by each dataloader worker")
14
+ parser.add_argument('--num_workers', default=8, type=int, help="number of CPU workers for the dataloaders")
15
+
16
+ # ========================= Model Settings ============================
17
+ parser.add_argument('--vision_model', default='microsoft/swin-base-patch4-window7-224', type=str, help="vision model to use")
18
+ parser.add_argument('--llama_model', default='meta-llama/Llama-2-7b-chat-hf', type=str, help="LLM model to use")
19
+ parser.add_argument('--freeze_vm', default=True, type=lambda x: (str(x).lower() == 'true'), help='freeze vision model')
20
+ parser.add_argument('--llm_use_lora', default=False, type=lambda x: (str(x).lower() == 'true'), help="whether use lora for LLM model")
21
+ parser.add_argument('--llm_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
22
+ parser.add_argument('--llm_alpha', default=16, type=int, help='Scaling factor.')
23
+ parser.add_argument('--vis_use_lora', default=False, type=lambda x: (str(x).lower() == 'true'), help="whether use lora for vision model")
24
+ parser.add_argument('--vis_r', default=16, type=int, help='The dimension used by the LoRA update matrices')
25
+ parser.add_argument('--vis_alpha', default=16, type=int, help='Scaling factor.')
26
+ parser.add_argument('--lora_dropout', default=0.1, type=float, help='lora dropout')
27
+ parser.add_argument('--global_only', default=False, type=lambda x: (str(x).lower() == 'true'), help='use global embedding only')
28
+ parser.add_argument('--low_resource', default=False, type=bool)
29
+ parser.add_argument('--end_sym', default='</s>', type=str)
30
+
31
+ # ======================== SavedModel Configs ===========================
32
+ parser.add_argument('--savedmodel_path', type=str, default='save/mimic/v1')
33
+ parser.add_argument('--ckpt_file', type=str, default=None, help='the checkpoint file to load')
34
+ parser.add_argument('--delta_file', type=str, default=None, help='the delta file to load')
35
+ parser.add_argument('--weights', type=list, default=[0.5, 0.5])
36
+ parser.add_argument('--scorer_types', type=list, default=['Bleu_4', 'CIDEr'])
37
+
38
+ # ========================= Learning Configs ==========================
39
+ parser.add_argument('--learning_rate', default=1e-4, type=float, help='initial learning rate')
40
+ parser.add_argument('--gradient_clip_val', default=None, type=int, help='gradient clip value')
41
+
42
+ # ========================= Decoding Settings ==========================
43
+ parser.add_argument('--beam_size', type=int, default=3)
44
+ parser.add_argument('--do_sample', type=bool, default=False)
45
+ parser.add_argument('--no_repeat_ngram_size', type=int, default=2)
46
+ parser.add_argument('--num_beam_groups', type=int, default=1)
47
+ parser.add_argument('--min_new_tokens', type=int, default=80)
48
+ parser.add_argument('--max_new_tokens', type=int, default=120)
49
+ parser.add_argument('--max_length', type=int, default=100)
50
+ parser.add_argument('--repetition_penalty', type=float, default=2.0)
51
+ parser.add_argument('--length_penalty', type=float, default=2.0)
52
+ parser.add_argument('--diversity_penalty', type=float, default=0)
53
+ parser.add_argument('--temperature', type=float, default=0)
54
+
55
+ # ====================== Pytorch Lightning ===========================
56
+ parser.add_argument('--devices', type=int, default=2, help='how many gpus to use')
57
+ parser.add_argument('--num_nodes', type=int, default=1, help='Number of GPU nodes for distributed training.')
58
+ parser.add_argument('--accelerator', type=str, default="gpu", choices=["cpu", "gpu", "tpu", "ipu", "hpu", "mps"], help='accelerator types')
59
+ parser.add_argument('--strategy', type=str, default="ddp", help='default ddp for multi-gpus')
60
+ parser.add_argument('--precision', type=str, default='bf16-mixed', help='16, 32, or bf16-mixed; used for PyTorch AMP autocast')
61
+ parser.add_argument('--limit_val_batches', type=float, default=1.0, help='How much of validation dataset to check (float = fraction, int = num_batches).')
62
+ parser.add_argument('--limit_test_batches', type=float, default=1.0, help='How much of test dataset to check (float = fraction, int = num_batches).')
63
+ parser.add_argument('--limit_train_batches', type=float, default=1.0, help='How much of training dataset to check (float = fraction, int = num_batches)')
64
+ parser.add_argument('--max_epochs', type=int, default=3, help='Stop training once this number of epochs is reached')
65
+ parser.add_argument('--every_n_train_steps', type=int, default=0, help='How many training steps to save a checkpoint')
66
+ parser.add_argument('--val_check_interval', type=float, default=1.0, help='How often to check the validation set')
67
+ parser.add_argument('--accumulate_grad_batches', type=int, default=1, help='Accumulates gradients over k batches before stepping the optimizer')
68
+ parser.add_argument("--num_sanity_val_steps", type=int, default=2, help='Sanity check runs n validation batches before starting the training routine')
dataset/data_helper.py ADDED
@@ -0,0 +1,98 @@
1
+
2
+ import os
3
+ import json
4
+ import re
5
+ import numpy as np
6
+ from PIL import Image
7
+ import torch.utils.data as data
8
+ from transformers import BertTokenizer, AutoImageProcessor
9
+
10
+
11
+ class FieldParser:
12
+ def __init__(
13
+ self,
14
+ args
15
+ ):
16
+ super().__init__()
17
+ self.args = args
18
+ self.dataset = args.dataset
19
+ self.vit_feature_extractor = AutoImageProcessor.from_pretrained(args.vision_model)
20
+
21
+
22
+ def _parse_image(self, img):
23
+ pixel_values = self.vit_feature_extractor(img, return_tensors="pt").pixel_values
24
+ return pixel_values[0]
25
+
26
+ # from https://github.com/cuhksz-nlp/R2Gen/blob/main/modules/tokenizers.py
27
+ def clean_report(self, report):
28
+ # clean Iu-xray reports
29
+ if self.dataset == "iu_xray":
30
+ report_cleaner = lambda t: t.replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '') \
31
+ .replace('. 2. ', '. ').replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ') \
32
+ .replace(' 2. ', '. ').replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ') \
33
+ .strip().lower().split('. ')
34
+ sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+():-\[\]{}]', '', t.replace('"', '').replace('/', '').
35
+ replace('\\', '').replace("'", '').strip().lower())
36
+ tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
37
+ report = ' . '.join(tokens) + ' .'
38
+ # clean MIMIC-CXR reports
39
+ else:
40
+ report_cleaner = lambda t: t.replace('\n', ' ').replace('__', '_').replace('__', '_').replace('__', '_') \
41
+ .replace('__', '_').replace('__', '_').replace('__', '_').replace('__', '_').replace(' ', ' ') \
42
+ .replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ').replace(' ', ' ') \
43
+ .replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.').replace('..', '.') \
44
+ .replace('..', '.').replace('..', '.').replace('..', '.').replace('1. ', '').replace('. 2. ', '. ') \
45
+ .replace('. 3. ', '. ').replace('. 4. ', '. ').replace('. 5. ', '. ').replace(' 2. ', '. ') \
46
+ .replace(' 3. ', '. ').replace(' 4. ', '. ').replace(' 5. ', '. ').replace(':', ' :') \
47
+ .strip().lower().split('. ')
48
+ sent_cleaner = lambda t: re.sub('[.,?;*!%^&_+()\[\]{}]', '', t.replace('"', '').replace('/', '')
49
+ .replace('\\', '').replace("'", '').strip().lower())
50
+ tokens = [sent_cleaner(sent) for sent in report_cleaner(report) if sent_cleaner(sent) != []]
51
+ report = ' . '.join(tokens) + ' .'
52
+ # report = ' '.join(report.split()[:self.args.max_txt_len])
53
+ return report
54
+
55
+
56
+ def parse(self, features):
57
+ to_return = {'id': features['id']}
58
+ report = features.get("report", "")
59
+ report = self.clean_report(report)
60
+ to_return['input_text'] = report
61
+ # chest x-ray images
62
+ images = []
63
+ for image_path in features['image_path']:
64
+ with Image.open(os.path.join(self.args.base_dir, image_path)) as pil:
65
+ array = np.array(pil, dtype=np.uint8)
66
+ if array.shape[-1] != 3 or len(array.shape) != 3:
67
+ array = np.array(pil.convert("RGB"), dtype=np.uint8)
68
+ image = self._parse_image(array)
69
+ images.append(image)
70
+ to_return["image"] = images
71
+ return to_return
72
+
73
+
74
+ def transform_with_parse(self, inputs):
75
+ return self.parse(inputs)
76
+
77
+
78
+ class ParseDataset(data.Dataset):
79
+ def __init__(self, args, split='train'):
80
+ self.args = args
81
+ self.meta = json.load(open(args.annotation, 'r'))
82
+ self.meta = self.meta[split]
83
+ self.parser = FieldParser(args)
84
+
85
+ def __len__(self):
86
+ return len(self.meta)
87
+
88
+ def __getitem__(self, index):
89
+ return self.parser.transform_with_parse(self.meta[index])
90
+
91
+
92
+ def create_datasets(args):
93
+ train_dataset = ParseDataset(args, 'train')
94
+ dev_dataset = ParseDataset(args, 'val')
95
+ test_dataset = ParseDataset(args, 'test')
96
+ return train_dataset, dev_dataset, test_dataset
97
+
98
+
dataset/data_module.py ADDED
@@ -0,0 +1,73 @@
1
+ from lightning.pytorch import LightningDataModule
2
+ from torch.utils.data import DataLoader
3
+ from dataset.data_helper import create_datasets
4
+
5
+
6
+
7
+ class DataModule(LightningDataModule):
8
+
9
+ def __init__(
10
+ self,
11
+ args
12
+ ):
13
+ super().__init__()
14
+ self.args = args
15
+
16
+ def prepare_data(self):
17
+ """
18
+ Use this method to do things that might write to disk or that need to be done only from a single process in distributed settings.
19
+
20
+ download
21
+
22
+ tokenize
23
+
24
+ etc…
25
+ :return:
26
+ """
27
+
28
+ def setup(self, stage: str):
29
+ """
30
+ There are also data operations you might want to perform on every GPU. Use setup to do things like:
31
+
32
+ count number of classes
33
+
34
+ build vocabulary
35
+
36
+ perform train/val/test splits
37
+
38
+ apply transforms (defined explicitly in your datamodule or assigned in init)
39
+
40
+ etc…
41
+ :param stage:
42
+ :return:
43
+ """
44
+ train_dataset, dev_dataset, test_dataset = create_datasets(self.args)
45
+ self.dataset = {
46
+ "train": train_dataset, "validation": dev_dataset, "test": test_dataset
47
+ }
48
+
49
+
50
+ def train_dataloader(self):
51
+ """
52
+ Use this method to generate the train dataloader. Usually you just wrap the dataset you defined in setup.
53
+ :return:
54
+ """
55
+ loader = DataLoader(self.dataset["train"], batch_size=self.args.batch_size, drop_last=True, pin_memory=True,
56
+ num_workers=self.args.num_workers, prefetch_factor=self.args.prefetch_factor)
57
+ return loader
58
+
59
+
60
+ def val_dataloader(self):
61
+ """
62
+ Use this method to generate the val dataloader. Usually you just wrap the dataset you defined in setup.
63
+ :return:
64
+ """
65
+ loader = DataLoader(self.dataset["validation"], batch_size=self.args.val_batch_size, drop_last=False, pin_memory=True,
66
+ num_workers=self.args.num_workers, prefetch_factor=self.args.prefetch_factor)
67
+ return loader
68
+
69
+
70
+ def test_dataloader(self):
71
+ loader = DataLoader(self.dataset["test"], batch_size=self.args.test_batch_size, drop_last=False, pin_memory=False,
72
+ num_workers=self.args.num_workers, prefetch_factor=self.args.prefetch_factor)
73
+ return loader
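A minimal sketch of exercising this DataModule outside of a Trainer run; it assumes the dataset has already been placed under ./data as described in the README, and that the defaults from configs/config.py apply.

```python
# Illustrative sketch only: build the datasets and look at one parsed sample.
from configs.config import parser
from dataset.data_module import DataModule

args = parser.parse_args([])
dm = DataModule(args)
dm.setup(stage="fit")                     # creates the train/validation/test ParseDataset splits
sample = dm.dataset["train"][0]           # ParseDataset.__getitem__ -> FieldParser.parse
print(sample["id"])
print(sample["input_text"][:100])         # cleaned report text
print(len(sample["image"]), sample["image"][0].shape)  # list of preprocessed image tensors
```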
evalcap/__init__.py ADDED
File without changes
evalcap/bleu/LICENSE ADDED
@@ -0,0 +1,23 @@
1
+ Copyright (c) 2015 Xinlei Chen, Hao Fang, Tsung-Yi Lin, and Ramakrishna Vedantam
2
+
3
+ Permission is hereby granted, free of charge, to any person obtaining a copy
4
+ of this software and associated documentation files (the "Software"), to deal
5
+ in the Software without restriction, including without limitation the rights
6
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
7
+ copies of the Software, and to permit persons to whom the Software is
8
+ furnished to do so, subject to the following conditions:
9
+
10
+ The above copyright notice and this permission notice shall be included in
11
+ all copies or substantial portions of the Software.
12
+
13
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
14
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
15
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
16
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
17
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
18
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
19
+ THE SOFTWARE.
20
+
21
+ Python 2 to Python 3 conversion notes:
22
+ Python 2 dict: iteritems()
23
+ Python 3 dict: items()
evalcap/bleu/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
evalcap/bleu/bleu.py ADDED
@@ -0,0 +1,50 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # File Name : bleu.py
4
+ #
5
+ # Description : Wrapper for BLEU scorer.
6
+ #
7
+ # Creation Date : 06-01-2015
8
+ # Last Modified : Thu 19 Mar 2015 09:13:28 PM PDT
9
+ # Authors : Hao Fang <[email protected]> and Tsung-Yi Lin <[email protected]>
10
+ import os
11
+ dir_path = os.path.dirname(os.path.abspath(__file__))
12
+ import sys
13
+ sys.path.append(dir_path)
14
+ from bleu_scorer import BleuScorer
15
+
16
+
17
+ class Bleu:
18
+ def __init__(self, n=4):
19
+ # by default, compute BLEU scores up to 4-grams
20
+ self._n = n
21
+ self._hypo_for_image = {}
22
+ self.ref_for_image = {}
23
+
24
+ def compute_score(self, gts, res, verbose=0):
25
+
26
+ assert(gts.keys() == res.keys())
27
+ imgIds = gts.keys()
28
+
29
+ bleu_scorer = BleuScorer(n=self._n)
30
+ for id in imgIds:
31
+ hypo = res[id]
32
+ ref = gts[id]
33
+
34
+ # Sanity check.
35
+ assert(type(hypo) is list)
36
+ assert(len(hypo) == 1)
37
+ assert(type(ref) is list)
38
+ assert(len(ref) >= 1)
39
+
40
+ bleu_scorer += (hypo[0], ref)
41
+
42
+ #score, scores = bleu_scorer.compute_score(option='shortest')
43
+ score, scores = bleu_scorer.compute_score(option='closest', verbose=verbose)
44
+ # score, scores = bleu_scorer.compute_score(option='average', verbose=1)
45
+
46
+ # return (bleu, bleu_info)
47
+ return score, scores
48
+
49
+ def method(self):
50
+ return "Bleu"
evalcap/bleu/bleu_scorer.py ADDED
@@ -0,0 +1,264 @@
1
+ #!/usr/bin/env python
2
+
3
+ # bleu_scorer.py
4
+ # David Chiang <[email protected]>
5
+
6
+ # Copyright (c) 2004-2006 University of Maryland. All rights
7
+ # reserved. Do not redistribute without permission from the
8
+ # author. Not for commercial use.
9
+
10
+ # Modified by:
11
+ # Hao Fang <[email protected]>
12
+ # Tsung-Yi Lin <[email protected]>
13
+
14
+ '''Provides:
15
+ cook_refs(refs, n=4): Transform a list of reference sentences as strings into a form usable by cook_test().
16
+ cook_test(test, refs, n=4): Transform a test sentence as a string (together with the cooked reference sentences) into a form usable by score_cooked().
17
+ '''
18
+
19
+ import copy
20
+ import sys, math, re
21
+ from collections import defaultdict
22
+
23
+ def precook(s, n=4, out=False):
24
+ """Takes a string as input and returns an object that can be given to
25
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
26
+ can take string arguments as well."""
27
+ words = s.split()
28
+ counts = defaultdict(int)
29
+ for k in range(1,n+1):
30
+ for i in range(len(words)-k+1):
31
+ ngram = tuple(words[i:i+k])
32
+ counts[ngram] += 1
33
+ return (len(words), counts)
34
+
35
+ def cook_refs(refs, eff=None, n=4): ## lhuang: oracle will call with "average"
36
+ '''Takes a list of reference sentences for a single segment
37
+ and returns an object that encapsulates everything that BLEU
38
+ needs to know about them.'''
39
+
40
+ reflen = []
41
+ maxcounts = {}
42
+ for ref in refs:
43
+ rl, counts = precook(ref, n)
44
+ reflen.append(rl)
45
+ for (ngram,count) in counts.items():
46
+ maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
47
+
48
+ # Calculate effective reference sentence length.
49
+ if eff == "shortest":
50
+ reflen = min(reflen)
51
+ elif eff == "average":
52
+ reflen = float(sum(reflen))/len(reflen)
53
+
54
+ ## lhuang: N.B.: leave reflen computation to the very end!!
55
+
56
+ ## lhuang: N.B.: in case of "closest", keep a list of reflens!! (bad design)
57
+
58
+ return (reflen, maxcounts)
59
+
60
+ def cook_test(test, crefs, eff=None, n=4):
61
+ '''Takes a test sentence and returns an object that
62
+ encapsulates everything that BLEU needs to know about it.'''
63
+ reflen, refmaxcounts = crefs[0], crefs[1]
64
+
65
+ testlen, counts = precook(test, n, True)
66
+
67
+ result = {}
68
+
69
+ # Calculate effective reference sentence length.
70
+
71
+ if eff == "closest":
72
+ result["reflen"] = min((abs(l-testlen), l) for l in reflen)[1]
73
+ else: ## i.e., "average" or "shortest" or None
74
+ result["reflen"] = reflen
75
+
76
+ result["testlen"] = testlen
77
+
78
+ result["guess"] = [max(0,testlen-k+1) for k in range(1,n+1)]
79
+
80
+ result['correct'] = [0]*n
81
+ for (ngram, count) in counts.items():
82
+ result["correct"][len(ngram)-1] += min(refmaxcounts.get(ngram,0), count)
83
+
84
+ return result
85
+
86
+ class BleuScorer(object):
87
+ """Bleu scorer.
88
+ """
89
+
90
+ __slots__ = "n", "crefs", "ctest", "_score", "_ratio", "_testlen", "_reflen", "special_reflen"
91
+ # special_reflen is used in oracle (proportional effective ref len for a node).
92
+
93
+ def copy(self):
94
+ ''' copy the refs.'''
95
+ new = BleuScorer(n=self.n)
96
+ new.ctest = copy.copy(self.ctest)
97
+ new.crefs = copy.copy(self.crefs)
98
+ new._score = None
99
+ return new
100
+
101
+ def __init__(self, test=None, refs=None, n=4, special_reflen=None):
102
+ ''' singular instance '''
103
+
104
+ self.n = n
105
+ self.crefs = []
106
+ self.ctest = []
107
+ self.cook_append(test, refs)
108
+ self.special_reflen = special_reflen
109
+
110
+ def cook_append(self, test, refs):
111
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
112
+
113
+ if refs is not None:
114
+ self.crefs.append(cook_refs(refs))
115
+ if test is not None:
116
+ cooked_test = cook_test(test, self.crefs[-1])
117
+ self.ctest.append(cooked_test) ## N.B.: -1
118
+ else:
119
+ self.ctest.append(None) # lens of crefs and ctest have to match
120
+
121
+ self._score = None ## need to recompute
122
+
123
+ def ratio(self, option=None):
124
+ self.compute_score(option=option)
125
+ return self._ratio
126
+
127
+ def score_ratio(self, option=None):
128
+ '''return (bleu, len_ratio) pair'''
129
+ return (self.fscore(option=option), self.ratio(option=option))
130
+
131
+ def score_ratio_str(self, option=None):
132
+ return "%.4f (%.2f)" % self.score_ratio(option)
133
+
134
+ def reflen(self, option=None):
135
+ self.compute_score(option=option)
136
+ return self._reflen
137
+
138
+ def testlen(self, option=None):
139
+ self.compute_score(option=option)
140
+ return self._testlen
141
+
142
+ def retest(self, new_test):
143
+ if type(new_test) is str:
144
+ new_test = [new_test]
145
+ assert len(new_test) == len(self.crefs), new_test
146
+ self.ctest = []
147
+ for t, rs in zip(new_test, self.crefs):
148
+ self.ctest.append(cook_test(t, rs))
149
+ self._score = None
150
+
151
+ return self
152
+
153
+ def rescore(self, new_test):
154
+ ''' replace test(s) with new test(s), and returns the new score.'''
155
+
156
+ return self.retest(new_test).compute_score()
157
+
158
+ def size(self):
159
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
160
+ return len(self.crefs)
161
+
162
+ def __iadd__(self, other):
163
+ '''add an instance (e.g., from another sentence).'''
164
+
165
+ if type(other) is tuple:
166
+ ## avoid creating new BleuScorer instances
167
+ self.cook_append(other[0], other[1])
168
+ else:
169
+ assert self.compatible(other), "incompatible BLEUs."
170
+ self.ctest.extend(other.ctest)
171
+ self.crefs.extend(other.crefs)
172
+ self._score = None ## need to recompute
173
+
174
+ return self
175
+
176
+ def compatible(self, other):
177
+ return isinstance(other, BleuScorer) and self.n == other.n
178
+
179
+ def single_reflen(self, option="average"):
180
+ return self._single_reflen(self.crefs[0][0], option)
181
+
182
+ def _single_reflen(self, reflens, option=None, testlen=None):
183
+
184
+ if option == "shortest":
185
+ reflen = min(reflens)
186
+ elif option == "average":
187
+ reflen = float(sum(reflens))/len(reflens)
188
+ elif option == "closest":
189
+ reflen = min((abs(l-testlen), l) for l in reflens)[1]
190
+ else:
191
+ assert False, "unsupported reflen option %s" % option
192
+
193
+ return reflen
194
+
195
+ def recompute_score(self, option=None, verbose=0):
196
+ self._score = None
197
+ return self.compute_score(option, verbose)
198
+
199
+ def compute_score(self, option=None, verbose=0):
200
+ n = self.n
201
+ small = 1e-9
202
+ tiny = 1e-15 ## so that if guess is 0 still return 0
203
+ bleu_list = [[] for _ in range(n)]
204
+
205
+ if self._score is not None:
206
+ return self._score
207
+
208
+ if option is None:
209
+ option = "average" if len(self.crefs) == 1 else "closest"
210
+
211
+ self._testlen = 0
212
+ self._reflen = 0
213
+ totalcomps = {'testlen':0, 'reflen':0, 'guess':[0]*n, 'correct':[0]*n}
214
+
215
+ # for each sentence
216
+ for comps in self.ctest:
217
+ testlen = comps['testlen']
218
+ self._testlen += testlen
219
+
220
+ if self.special_reflen is None: ## need computation
221
+ reflen = self._single_reflen(comps['reflen'], option, testlen)
222
+ else:
223
+ reflen = self.special_reflen
224
+
225
+ self._reflen += reflen
226
+
227
+ for key in ['guess','correct']:
228
+ for k in range(n):
229
+ totalcomps[key][k] += comps[key][k]
230
+
231
+ # append per image bleu score
232
+ bleu = 1.
233
+ for k in range(n):
234
+ bleu *= (float(comps['correct'][k]) + tiny) \
235
+ /(float(comps['guess'][k]) + small)
236
+ bleu_list[k].append(bleu ** (1./(k+1)))
237
+ ratio = (testlen + tiny) / (reflen + small) ## N.B.: avoid zero division
238
+ if ratio < 1:
239
+ for k in range(n):
240
+ bleu_list[k][-1] *= math.exp(1 - 1/ratio)
241
+
242
+ if verbose > 1:
243
+ print(comps, reflen)
244
+
245
+ totalcomps['reflen'] = self._reflen
246
+ totalcomps['testlen'] = self._testlen
247
+
248
+ bleus = []
249
+ bleu = 1.
250
+ for k in range(n):
251
+ bleu *= float(totalcomps['correct'][k] + tiny) \
252
+ / (totalcomps['guess'][k] + small)
253
+ bleus.append(bleu ** (1./(k+1)))
254
+ ratio = (self._testlen + tiny) / (self._reflen + small) ## N.B.: avoid zero division
255
+ if ratio < 1:
256
+ for k in range(n):
257
+ bleus[k] *= math.exp(1 - 1/ratio)
258
+
259
+ if verbose > 0:
260
+ print(totalcomps)
261
+ print("ratio:", ratio)
262
+
263
+ self._score = bleus
264
+ return self._score, bleu_list
evalcap/cider/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
evalcap/cider/cider.py ADDED
@@ -0,0 +1,57 @@
1
+ # Filename: cider.py
2
+ #
3
+ # Description: Describes the class to compute the CIDEr (Consensus-Based Image Description Evaluation) Metric
4
+ # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
5
+ #
6
+ # Creation Date: Sun Feb 8 14:16:54 2015
7
+ #
8
+ # Authors: Ramakrishna Vedantam <[email protected]> and Tsung-Yi Lin <[email protected]>
9
+ import os
10
+ dir_path = os.path.dirname(os.path.abspath(__file__))
11
+ import sys
12
+ sys.path.append(dir_path)
13
+ from cider_scorer import CiderScorer
14
+ import pdb
15
+
16
+ class Cider:
17
+ """
18
+ Main Class to compute the CIDEr metric
19
+
20
+ """
21
+ def __init__(self, test=None, refs=None, n=4, sigma=6.0):
22
+ # set cider to sum over 1 to 4-grams
23
+ self._n = n
24
+ # set the standard deviation parameter for gaussian penalty
25
+ self._sigma = sigma
26
+
27
+ def compute_score(self, gts, res):
28
+ """
29
+ Main function to compute CIDEr score
30
+ :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
31
+ ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence>
32
+ :return: cider (float) : computed CIDEr score for the corpus
33
+ """
34
+
35
+ assert(gts.keys() == res.keys())
36
+ imgIds = gts.keys()
37
+
38
+ cider_scorer = CiderScorer(n=self._n, sigma=self._sigma)
39
+
40
+ for id in imgIds:
41
+ hypo = res[id]
42
+ ref = gts[id]
43
+
44
+ # Sanity check.
45
+ assert(type(hypo) is list)
46
+ assert(len(hypo) == 1)
47
+ assert(type(ref) is list)
48
+ assert(len(ref) > 0)
49
+
50
+ cider_scorer += (hypo[0], ref)
51
+
52
+ (score, scores) = cider_scorer.compute_score()
53
+
54
+ return score, scores
55
+
56
+ def method(self):
57
+ return "CIDEr"
evalcap/cider/cider_scorer.py ADDED
@@ -0,0 +1,193 @@
1
+ #!/usr/bin/env python
2
+ # Tsung-Yi Lin <[email protected]>
3
+ # Ramakrishna Vedantam <[email protected]>
4
+
5
+ import copy
6
+ from collections import defaultdict
7
+ import numpy as np
8
+ import pdb
9
+ import math
10
+
11
+ def precook(s, n=4, out=False):
12
+ """
13
+ Takes a string as input and returns an object that can be given to
14
+ either cook_refs or cook_test. This is optional: cook_refs and cook_test
15
+ can take string arguments as well.
16
+ :param s: string : sentence to be converted into ngrams
17
+ :param n: int : number of ngrams for which representation is calculated
18
+ :return: term frequency vector for occurring ngrams
19
+ """
20
+ words = s.split()
21
+ counts = defaultdict(int)
22
+ for k in range(1,n+1):
23
+ for i in range(len(words)-k+1):
24
+ ngram = tuple(words[i:i+k])
25
+ counts[ngram] += 1
26
+ return counts
27
+
28
+ def cook_refs(refs, n=4): ## lhuang: oracle will call with "average"
29
+ '''Takes a list of reference sentences for a single segment
30
+ and returns an object that encapsulates everything that BLEU
31
+ needs to know about them.
32
+ :param refs: list of string : reference sentences for some image
33
+ :param n: int : number of ngrams for which (ngram) representation is calculated
34
+ :return: result (list of dict)
35
+ '''
36
+ return [precook(ref, n) for ref in refs]
37
+
38
+ def cook_test(test, n=4):
39
+ '''Takes a test sentence and returns an object that
40
+ encapsulates everything that BLEU needs to know about it.
41
+ :param test: list of string : hypothesis sentence for some image
42
+ :param n: int : number of ngrams for which (ngram) representation is calculated
43
+ :return: result (dict)
44
+ '''
45
+ return precook(test, n, True)
46
+
47
+ class CiderScorer(object):
48
+ """CIDEr scorer.
49
+ """
50
+
51
+ def copy(self):
52
+ ''' copy the refs.'''
53
+ new = CiderScorer(n=self.n)
54
+ new.ctest = copy.copy(self.ctest)
55
+ new.crefs = copy.copy(self.crefs)
56
+ return new
57
+
58
+ def __init__(self, test=None, refs=None, n=4, sigma=6.0):
59
+ ''' singular instance '''
60
+ self.n = n
61
+ self.sigma = sigma
62
+ self.crefs = []
63
+ self.ctest = []
64
+ self.document_frequency = defaultdict(float)
65
+ self.cook_append(test, refs)
66
+ self.ref_len = None
67
+
68
+ def cook_append(self, test, refs):
69
+ '''called by constructor and __iadd__ to avoid creating new instances.'''
70
+
71
+ if refs is not None:
72
+ self.crefs.append(cook_refs(refs))
73
+ if test is not None:
74
+ self.ctest.append(cook_test(test)) ## N.B.: -1
75
+ else:
76
+ self.ctest.append(None) # lens of crefs and ctest have to match
77
+
78
+ def size(self):
79
+ assert len(self.crefs) == len(self.ctest), "refs/test mismatch! %d<>%d" % (len(self.crefs), len(self.ctest))
80
+ return len(self.crefs)
81
+
82
+ def __iadd__(self, other):
83
+ '''add an instance (e.g., from another sentence).'''
84
+
85
+ if type(other) is tuple:
86
+ ## avoid creating new CiderScorer instances
87
+ self.cook_append(other[0], other[1])
88
+ else:
89
+ self.ctest.extend(other.ctest)
90
+ self.crefs.extend(other.crefs)
91
+
92
+ return self
93
+ def compute_doc_freq(self):
94
+ '''
95
+ Compute term frequency for reference data.
96
+ This will be used to compute idf (inverse document frequency later)
97
+ The term frequency is stored in the object
98
+ :return: None
99
+ '''
100
+ for refs in self.crefs:
101
+ # refs, k ref captions of one image
102
+ for ngram in set([ngram for ref in refs for (ngram,count) in ref.items()]):
103
+ self.document_frequency[ngram] += 1
104
+ # maxcounts[ngram] = max(maxcounts.get(ngram,0), count)
105
+
106
+ def compute_cider(self):
107
+ def counts2vec(cnts):
108
+ """
109
+ Function maps counts of ngram to vector of tfidf weights.
110
+ The function returns vec, an array of dictionary that store mapping of n-gram and tf-idf weights.
111
+ The n-th entry of array denotes length of n-grams.
112
+ :param cnts:
113
+ :return: vec (array of dict), norm (array of float), length (int)
114
+ """
115
+ vec = [defaultdict(float) for _ in range(self.n)]
116
+ length = 0
117
+ norm = [0.0 for _ in range(self.n)]
118
+ for (ngram, term_freq) in cnts.items():
119
+ # give word count 1 if it doesn't appear in reference corpus
120
+ df = np.log(max(1.0, self.document_frequency[ngram]))
121
+ # ngram index
122
+ n = len(ngram)-1
123
+ # tf (term_freq) * idf (precomputed idf) for n-grams
124
+ vec[n][ngram] = float(term_freq)*(self.ref_len - df)
125
+ # compute norm for the vector. the norm will be used for computing similarity
126
+ norm[n] += pow(vec[n][ngram], 2)
127
+
128
+ if n == 1:
129
+ length += term_freq
130
+ norm = [np.sqrt(n) for n in norm]
131
+ return vec, norm, length
132
+
133
+ def sim(vec_hyp, vec_ref, norm_hyp, norm_ref, length_hyp, length_ref):
134
+ '''
135
+ Compute the cosine similarity of two vectors.
136
+ :param vec_hyp: array of dictionary for vector corresponding to hypothesis
137
+ :param vec_ref: array of dictionary for vector corresponding to reference
138
+ :param norm_hyp: array of float for vector corresponding to hypothesis
139
+ :param norm_ref: array of float for vector corresponding to reference
140
+ :param length_hyp: int containing length of hypothesis
141
+ :param length_ref: int containing length of reference
142
+ :return: array of score for each n-grams cosine similarity
143
+ '''
144
+ delta = float(length_hyp - length_ref)
145
+ # measure cosine similarity
146
+ val = np.array([0.0 for _ in range(self.n)])
147
+ for n in range(self.n):
148
+ # ngram
149
+ for (ngram,count) in vec_hyp[n].items():
150
+ # vrama91 : added clipping
151
+ val[n] += min(vec_hyp[n][ngram], vec_ref[n][ngram]) * vec_ref[n][ngram]
152
+
153
+ if (norm_hyp[n] != 0) and (norm_ref[n] != 0):
154
+ val[n] /= (norm_hyp[n]*norm_ref[n])
155
+
156
+ assert(not math.isnan(val[n]))
157
+ # vrama91: added a length based gaussian penalty
158
+ val[n] *= np.e**(-(delta**2)/(2*self.sigma**2))
159
+ return val
160
+
161
+ # compute log reference length
162
+ self.ref_len = np.log(float(len(self.crefs)))
163
+ if len(self.crefs) == 1:
164
+ self.ref_len = 1
165
+ scores = []
166
+ for test, refs in zip(self.ctest, self.crefs):
167
+ # compute vector for test captions
168
+ vec, norm, length = counts2vec(test)
169
+ # compute vector for ref captions
170
+ score = np.array([0.0 for _ in range(self.n)])
171
+ for ref in refs:
172
+ vec_ref, norm_ref, length_ref = counts2vec(ref)
173
+ score += sim(vec, vec_ref, norm, norm_ref, length, length_ref)
174
+ # change by vrama91 - mean of ngram scores, instead of sum
175
+ score_avg = np.mean(score)
176
+ # divide by number of references
177
+ score_avg /= len(refs)
178
+ # multiply score by 10
179
+ score_avg *= 10.0
180
+ # append score of an image to the score list
181
+ scores.append(score_avg)
182
+ return scores
183
+
184
+ def compute_score(self, option=None, verbose=0):
185
+ # compute idf
186
+ self.compute_doc_freq()
187
+ # assert to check document frequency
188
+ assert(len(self.ctest) >= max(self.document_frequency.values()))
189
+ # compute cider score
190
+ score = self.compute_cider()
191
+ # debug
192
+ # print score
193
+ return np.mean(np.array(score)), np.array(score)
evalcap/meteor/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'tylin'
evalcap/meteor/meteor-1.5.jar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1e57b4c72c0830ebe68558f1c799a624e96cbc1b6045c9f6330e26dcff6eafc2
3
+ size 6318693
evalcap/meteor/meteor.py ADDED
@@ -0,0 +1,130 @@
1
+ #!/usr/bin/env python
2
+
3
+ # Python wrapper for METEOR implementation, by Xinlei Chen
4
+ # Acknowledge Michael Denkowski for the generous discussion and help
5
+ from __future__ import division
6
+
7
+ import atexit
8
+ import logging
9
+ import os
10
+ import re
11
+ import subprocess
12
+ import sys
13
+ import threading
14
+
15
+ import psutil
16
+
17
+ # Assumes meteor-1.5.jar is in the same directory as meteor.py. Change as needed.
18
+ METEOR_JAR = 'meteor-1.5.jar'
19
+
20
+
21
+ def enc(s):
22
+ return s.encode('utf-8')
23
+
24
+
25
+ def dec(s):
26
+ return s.decode('utf-8')
27
+
28
+
29
+ class Meteor:
30
+
31
+ def __init__(self):
32
+ # Used to guarantee thread safety
33
+ self.lock = threading.Lock()
34
+
35
+ mem = '1G'
36
+ mem_available_G = psutil.virtual_memory().available / 1E9
37
+ if mem_available_G < 2:
38
+ logging.warning("There is less than 2GB of available memory.\n"
39
+ "Will try with limiting Meteor to 1GB of memory but this might cause issues.\n"
40
+ "If you have problems using Meteor, "
41
+ "then you can try to lower the `mem` variable in meteor.py")
42
+ mem = '1G'
43
+
44
+ meteor_cmd = ['java', '-jar', '-Xmx{}'.format(mem), METEOR_JAR,
45
+ '-', '-', '-stdio', '-l', 'en', '-norm']
46
+ env = os.environ.copy()
47
+ env['LC_ALL'] = "C"
48
+ self.meteor_p = subprocess.Popen(meteor_cmd,
49
+ cwd=os.path.dirname(os.path.abspath(__file__)),
50
+ env=env,
51
+ stdin=subprocess.PIPE,
52
+ stdout=subprocess.PIPE,
53
+ stderr=subprocess.PIPE)
54
+
55
+ atexit.register(self.close)
56
+
57
+ def close(self):
58
+ with self.lock:
59
+ if self.meteor_p:
60
+ self.meteor_p.kill()
61
+ self.meteor_p.wait()
62
+ self.meteor_p = None
63
+ # if the user calls close() manually, remove the
64
+ # reference from atexit so the object can be garbage-collected.
65
+ if atexit is not None and atexit.unregister is not None:
66
+ atexit.unregister(self.close)
67
+
68
+ def compute_score(self, gts, res):
69
+ assert (gts.keys() == res.keys())
70
+ imgIds = gts.keys()
71
+ scores = []
72
+
73
+ eval_line = 'EVAL'
74
+ with self.lock:
75
+ for i in imgIds:
76
+ assert (len(res[i]) == 1)
77
+ stat = self._stat(res[i][0], gts[i])
78
+ eval_line += ' ||| {}'.format(stat)
79
+
80
+ self.meteor_p.stdin.write(enc('{}\n'.format(eval_line)))
81
+ self.meteor_p.stdin.flush()
82
+ for i in range(0, len(imgIds)):
83
+ v = self.meteor_p.stdout.readline()
84
+ try:
85
+ scores.append(float(dec(v.strip())))
86
+ except:
87
+ sys.stderr.write("Error handling value: {}\n".format(v))
88
+ sys.stderr.write("Decoded value: {}\n".format(dec(v.strip())))
89
+ sys.stderr.write("eval_line: {}\n".format(eval_line))
90
+ # You can try uncommenting the next code line to show stderr from the Meteor JAR.
91
+ # If the Meteor JAR is not writing to stderr, then the line will just hang.
92
+ # sys.stderr.write("Error from Meteor:\n{}".format(self.meteor_p.stderr.read()))
93
+ raise
94
+ score = float(dec(self.meteor_p.stdout.readline()).strip())
95
+ self.close()
96
+ return score, scores
97
+
98
+ def method(self):
99
+ return "METEOR"
100
+
101
+ def _stat(self, hypothesis_str, reference_list):
102
+ # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
103
+ hypothesis_str = hypothesis_str.replace('|||', '')
104
+ score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
105
+ score_line = re.sub(r'\s+', ' ', score_line)
106
+ self.meteor_p.stdin.write(enc(score_line))
107
+ self.meteor_p.stdin.write(enc('\n'))
108
+ self.meteor_p.stdin.flush()
109
+ return dec(self.meteor_p.stdout.readline()).strip()
110
+
111
+ def _score(self, hypothesis_str, reference_list):
112
+ with self.lock:
113
+ # SCORE ||| reference 1 words ||| reference n words ||| hypothesis words
114
+ hypothesis_str = hypothesis_str.replace('|||', '').replace(' ', ' ')
115
+ score_line = ' ||| '.join(('SCORE', ' ||| '.join(reference_list), hypothesis_str))
116
+ self.meteor_p.stdin.write(enc('{}\n'.format(score_line)))
117
+ self.meteor_p.stdin.flush()
118
+ stats = dec(self.meteor_p.stdout.readline()).strip()
119
+ eval_line = 'EVAL ||| {}'.format(stats)
120
+ # EVAL ||| stats
121
+ self.meteor_p.stdin.write(enc('{}\n'.format(eval_line)))
122
+ self.meteor_p.stdin.flush()
123
+ score = float(dec(self.meteor_p.stdout.readline()).strip())
124
+ # bug fix: there are two values returned by the jar file, one average, and one all, so do it twice
125
+ # thanks for Andrej for pointing this out
126
+ score = float(dec(self.meteor_p.stdout.readline()).strip())
127
+ return score
128
+
129
+ def __del__(self):
130
+ self.close()
evalcap/meteor/test_meteor.py ADDED
@@ -0,0 +1,10 @@
1
+ import meteor
2
+
3
+ hypo = ['this is the model generated sentence1 which seems good enough']
4
+ ref = ['this is one reference sentence for sentence1',
5
+ 'this is a reference sentence for sentence2 which was generated by your model']
6
+
7
+ m = meteor.Meteor()
8
+
9
+ score = m._score(hypo[0], ref)
10
+ print(score)
evalcap/rouge/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'vrama91'
evalcap/rouge/rouge.py ADDED
@@ -0,0 +1,105 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # File Name : rouge.py
4
+ #
5
+ # Description : Computes ROUGE-L metric as described by Lin and Hovy (2004)
6
+ #
7
+ # Creation Date : 2015-01-07 06:03
8
+ # Author : Ramakrishna Vedantam <[email protected]>
9
+
10
+ import numpy as np
11
+ import pdb
12
+
13
+ def my_lcs(string, sub):
14
+ """
15
+ Calculates longest common subsequence for a pair of tokenized strings
16
+ :param string : list of str : tokens from a string split using whitespace
17
+ :param sub : list of str : shorter string, also split using whitespace
18
+ :returns: length (list of int): length of the longest common subsequence between the two strings
19
+
20
+ Note: my_lcs only gives length of the longest common subsequence, not the actual LCS
21
+ """
22
+ if(len(string)< len(sub)):
23
+ sub, string = string, sub
24
+
25
+ lengths = [[0 for i in range(0,len(sub)+1)] for j in range(0,len(string)+1)]
26
+
27
+ for j in range(1,len(sub)+1):
28
+ for i in range(1,len(string)+1):
29
+ if(string[i-1] == sub[j-1]):
30
+ lengths[i][j] = lengths[i-1][j-1] + 1
31
+ else:
32
+ lengths[i][j] = max(lengths[i-1][j] , lengths[i][j-1])
33
+
34
+ return lengths[len(string)][len(sub)]
35
+
36
+ class Rouge():
37
+ '''
38
+ Class for computing ROUGE-L score for a set of candidate sentences for the MS COCO test set
39
+
40
+ '''
41
+ def __init__(self):
42
+ # vrama91: updated the value below based on discussion with Hovy
43
+ self.beta = 1.2
44
+
45
+ def calc_score(self, candidate, refs):
46
+ """
47
+ Compute ROUGE-L score given one candidate and references for an image
48
+ :param candidate: str : candidate sentence to be evaluated
49
+ :param refs: list of str : COCO reference sentences for the particular image to be evaluated
50
+ :returns score: int (ROUGE-L score for the candidate evaluated against references)
51
+ """
52
+ # assert(len(candidate)==1)
53
+ # assert(len(refs)>0)
54
+ prec = []
55
+ rec = []
56
+
57
+ # split into tokens
58
+ token_c = candidate[0].split(" ")
59
+
60
+ for reference in refs:
61
+ # split into tokens
62
+ token_r = reference.split(" ")
63
+ # compute the longest common subsequence
64
+ lcs = my_lcs(token_r, token_c)
65
+ prec.append(lcs/float(len(token_c)))
66
+ rec.append(lcs/float(len(token_r)))
67
+
68
+ prec_max = max(prec)
69
+ rec_max = max(rec)
70
+
71
+ if(prec_max!=0 and rec_max !=0):
72
+ score = ((1 + self.beta**2)*prec_max*rec_max)/float(rec_max + self.beta**2*prec_max)
73
+ else:
74
+ score = 0.0
75
+ return score
76
+
77
+ def compute_score(self, gts, res):
78
+ """
79
+ Computes Rouge-L score given a set of reference and candidate sentences for the dataset
80
+ Invoked by evaluate_captions.py
81
+ :param hypo_for_image: dict : candidate / test sentences with "image name" key and "tokenized sentences" as values
82
+ :param ref_for_image: dict : reference MS-COCO sentences with "image name" key and "tokenized sentences" as values
83
+ :returns: average_score: float (mean ROUGE-L score computed by averaging scores for all the images)
84
+ """
85
+ assert(gts.keys() == res.keys())
86
+ imgIds = gts.keys()
87
+
88
+ score = []
89
+ for id in imgIds:
90
+ hypo = res[id]
91
+ ref = gts[id]
92
+
93
+ score.append(self.calc_score(hypo, ref))
94
+
95
+ # Sanity check.
96
+ assert(type(hypo) is list)
97
+ assert(len(hypo) == 1)
98
+ assert(type(ref) is list)
99
+ assert(len(ref) > 0)
100
+
101
+ average_score = np.mean(np.array(score))
102
+ return average_score, np.array(score)
103
+
104
+ def method(self):
105
+ return "Rouge"
evalcap/tokenizer/__init__.py ADDED
@@ -0,0 +1 @@
1
+ __author__ = 'hfang'
evalcap/tokenizer/ptbtokenizer.py ADDED
@@ -0,0 +1,68 @@
1
+ #!/usr/bin/env python
2
+ #
3
+ # File Name : ptbtokenizer.py
4
+ #
5
+ # Description : Do the PTB Tokenization and remove punctuations.
6
+ #
7
+ # Creation Date : 29-12-2014
8
+ # Last Modified : Thu Mar 19 09:53:35 2015
9
+ # Authors : Hao Fang <[email protected]> and Tsung-Yi Lin <[email protected]>
10
+
11
+ import os
12
+ import sys
13
+ import subprocess
14
+ import tempfile
15
+ import itertools
16
+
17
+ # path to the stanford corenlp jar
18
+ STANFORD_CORENLP_3_4_1_JAR = 'stanford-corenlp-3.4.1.jar'
19
+
20
+ # punctuations to be removed from the sentences
21
+ PUNCTUATIONS = ["''", "'", "``", "`", "-LRB-", "-RRB-", "-LCB-", "-RCB-", \
22
+ ".", "?", "!", ",", ":", "-", "--", "...", ";"]
23
+
24
+ class PTBTokenizer:
25
+ """Python wrapper of Stanford PTBTokenizer"""
26
+
27
+ def tokenize(self, captions_for_image):
28
+ cmd = ['java', '-cp', STANFORD_CORENLP_3_4_1_JAR, \
29
+ 'edu.stanford.nlp.process.PTBTokenizer', \
30
+ '-preserveLines', '-lowerCase']
31
+
32
+ # ======================================================
33
+ # prepare data for PTB Tokenizer
34
+ # ======================================================
35
+ final_tokenized_captions_for_image = {}
36
+ image_id = [k for k, v in captions_for_image.items() for _ in range(len(v))]
37
+ sentences = '\n'.join([c['caption'].replace('\n', ' ') for k, v in captions_for_image.items() for c in v])
38
+
39
+ # ======================================================
40
+ # save sentences to temporary file
41
+ # ======================================================
42
+ path_to_jar_dirname=os.path.dirname(os.path.abspath(__file__))
43
+ tmp_file = tempfile.NamedTemporaryFile(delete=False, dir=path_to_jar_dirname, mode='w', encoding='utf-8')
44
+ tmp_file.write(sentences)
45
+ tmp_file.close()
46
+
47
+ # ======================================================
48
+ # tokenize sentence
49
+ # ======================================================
50
+ cmd.append(os.path.basename(tmp_file.name))
51
+ p_tokenizer = subprocess.Popen(cmd, cwd=path_to_jar_dirname, \
52
+ stdout=subprocess.PIPE)
53
+ token_lines = p_tokenizer.communicate(input=sentences.rstrip())[0]
54
+ lines = token_lines.decode().split('\n')
55
+ # remove temp file
56
+ os.remove(tmp_file.name)
57
+
58
+ # ======================================================
59
+ # create dictionary for tokenized captions
60
+ # ======================================================
61
+ for k, line in zip(image_id, lines):
62
+ if not k in final_tokenized_captions_for_image:
63
+ final_tokenized_captions_for_image[k] = []
64
+ tokenized_caption = ' '.join([w for w in line.rstrip().split(' ') \
65
+ if w not in PUNCTUATIONS])
66
+ final_tokenized_captions_for_image[k].append(tokenized_caption)
67
+
68
+ return final_tokenized_captions_for_image
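A small sketch of the input format `PTBTokenizer.tokenize` expects: a dict mapping each image id to a list of `{"caption": ...}` entries. It shells out to Java, so a JRE and the stanford-corenlp-3.4.1.jar added below must be available next to ptbtokenizer.py; the caption text here is a placeholder.

```python
# Illustrative sketch only: tokenize and lowercase captions, stripping punctuation.
from evalcap.tokenizer.ptbtokenizer import PTBTokenizer

captions_for_image = {"img1": [{"caption": "The heart size is normal."}]}
tokenized = PTBTokenizer().tokenize(captions_for_image)
print(tokenized)    # {'img1': ['the heart size is normal']}
```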
evalcap/tokenizer/stanford-corenlp-3.4.1.jar ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2fcb91bb7a111f93d71e264f4ee0e3afd19ba0dde6d21b38605088df9e940399
3
+ size 5921410
images/align.png ADDED

Git LFS Details

  • SHA256: 01abc0814362789759e8bdb363f06891343e0d98f2dd2d18b311ec2de5a51ba2
  • Pointer size: 132 Bytes
  • Size of remote file: 1.51 MB
lightning_tools/callbacks.py ADDED
@@ -0,0 +1,30 @@
1
+ import os
2
+ from lightning.pytorch.loggers import CSVLogger
3
+ from lightning.pytorch import loggers as pl_loggers
4
+ from lightning.pytorch.callbacks import LearningRateMonitor
5
+ from lightning.pytorch.callbacks import ModelCheckpoint
6
+
7
+
8
+ def add_callbacks(args):
9
+ log_dir = args.savedmodel_path
10
+ os.makedirs(log_dir, exist_ok=True)
11
+
12
+ # --------- Add Callbacks
13
+ checkpoint_callback = ModelCheckpoint(
14
+ dirpath=os.path.join(log_dir, "checkpoints"),
15
+ filename="{epoch}-{step}",
16
+ save_top_k=-1,
17
+ every_n_train_steps=args.every_n_train_steps,
18
+ save_last=False,
19
+ save_weights_only=False
20
+ )
21
+
22
+ lr_monitor_callback = LearningRateMonitor(logging_interval='step')
23
+ tb_logger = pl_loggers.TensorBoardLogger(save_dir=os.path.join(log_dir, "logs"), name="tensorboard")
24
+ csv_logger = CSVLogger(save_dir=os.path.join(log_dir, "logs"), name="csvlog")
25
+
26
+ to_returns = {
27
+ "callbacks": [checkpoint_callback, lr_monitor_callback],
28
+ "loggers": [csv_logger, tb_logger]
29
+ }
30
+ return to_returns
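add_callbacks only reads args.savedmodel_path and args.every_n_train_steps, so it can be exercised on its own before wiring it into training. A minimal sketch, using an argparse Namespace as a stand-in for the project's parsed arguments; the path and step count are illustrative:

from argparse import Namespace
import lightning.pytorch as pl
from lightning_tools.callbacks import add_callbacks

# hypothetical args; only the two attributes used by add_callbacks are set here
args = Namespace(savedmodel_path='./save/demo', every_n_train_steps=1000)

cbs = add_callbacks(args)
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=cbs["callbacks"],  # ModelCheckpoint + LearningRateMonitor
    logger=cbs["loggers"],       # CSVLogger + TensorBoardLogger
)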
lightning_tools/optim.py ADDED
@@ -0,0 +1,59 @@
1
+ from transformers import AdamW
2
+ import functools
3
+ from torch.optim.lr_scheduler import LambdaLR
4
+
5
+
6
+ def lr_lambda(current_step, num_warmup_steps, num_training_steps):
7
+ if current_step < num_warmup_steps:
8
+ return float(current_step) / float(max(1, num_warmup_steps))
9
+ return max(
10
+ 0.0, float(num_training_steps - current_step) / float(max(1, num_training_steps - num_warmup_steps))
11
+ )
12
+
13
+
14
+ def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
15
+ """
16
+ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
17
+ a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.
18
+
19
+ Args:
20
+ optimizer (:class:`~torch.optim.Optimizer`):
21
+ The optimizer for which to schedule the learning rate.
22
+ num_warmup_steps (:obj:`int`):
23
+ The number of steps for the warmup phase.
24
+ num_training_steps (:obj:`int`):
25
+ The total number of training steps.
26
+ last_epoch (:obj:`int`, `optional`, defaults to -1):
27
+ The index of the last epoch when resuming training.
28
+
29
+ Return:
30
+ :obj:`torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
31
+ """
32
+
33
+ return LambdaLR(optimizer, functools.partial(lr_lambda, num_warmup_steps=num_warmup_steps, num_training_steps=num_training_steps), last_epoch)
34
+
35
+
36
+ def config_optimizer(parameters, init_lr, warmup_steps, max_steps, name='lr'):
37
+ """
38
+ The original BERT optimizer does not apply weight decay to bias and LayerNorm parameters.
39
+ Args:
40
+ parameters:
41
+ init_lr:
42
+ warmup_steps:
43
+ max_steps:
44
+ name:
45
+ weight_decay:
46
+
47
+ Returns:
48
+
49
+ """
50
+ optimizer = AdamW(
51
+ parameters, lr=init_lr, eps=1e-8, correct_bias=False
52
+ )
53
+
54
+ scheduler = get_linear_schedule_with_warmup(
55
+ optimizer, num_warmup_steps=warmup_steps, num_training_steps=max_steps,
56
+ )
57
+ scheduler = {'scheduler': scheduler, 'name': name, 'interval': 'step', 'frequency': 1}
58
+
59
+ return optimizer, scheduler
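config_optimizer returns the (optimizer, scheduler-dict) pair in the shape expected by LightningModule.configure_optimizers, with the scheduler stepped every training step. A minimal sketch of wiring it into a LightningModule; the module and the hyperparameter values are illustrative, not taken from this repository:

import torch.nn as nn
import lightning.pytorch as pl
from lightning_tools.optim import config_optimizer


class TinyModule(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 2)

    def configure_optimizers(self):
        # illustrative schedule: 500 warmup steps, then linear decay to 0 over 10000 steps
        optimizer, scheduler = config_optimizer(
            self.parameters(), init_lr=1e-4, warmup_steps=500, max_steps=10000
        )
        return {"optimizer": optimizer, "lr_scheduler": scheduler}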
models/R2GenGPT.py ADDED
@@ -0,0 +1,379 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import torch.nn as nn
5
+ import lightning.pytorch as pl
6
+ from transformers import LlamaForCausalLM, LlamaTokenizer
7
+ from evalcap.bleu.bleu import Bleu
8
+ from evalcap.rouge.rouge import Rouge
9
+ from evalcap.cider.cider import Cider
10
+ from evalcap.meteor.meteor import Meteor
11
+ from transformers import SwinModel
12
+ from lightning_tools.optim import config_optimizer
13
+ from peft import get_peft_model, LoraConfig, TaskType
14
+ import pdb
15
+
16
+
17
+
18
+ class R2GenGPT(pl.LightningModule):
19
+ """
20
+ R2GenGPT model.
21
+ """
22
+ def __init__(self, args):
23
+ super().__init__()
24
+ self.args = args
25
+ self.save_hyperparameters(args)
26
+
27
+ print(f'Loading vision encoder:{args.vision_model}')
28
+ self.visual_encoder = SwinModel.from_pretrained(args.vision_model)
29
+ if args.vis_use_lora:
30
+ peft_config_visual = LoraConfig(
31
+ r=args.vis_r,
32
+ lora_alpha=args.vis_alpha,
33
+ target_modules=["query", "value"],
34
+ lora_dropout=args.lora_dropout,
35
+ bias="none",
36
+ modules_to_save=["classifier"],
37
+ )
38
+ self.visual_encoder = get_peft_model(self.visual_encoder, peft_config_visual)
39
+ self.visual_encoder.print_trainable_parameters()
40
+ print('Loading vision encoder with LoRA -- Done')
41
+ elif args.freeze_vm:
42
+ for name, param in self.visual_encoder.named_parameters():
43
+ param.requires_grad = False
44
+ print(f'Loading Frozen vision encoder:{args.vision_model} -- Done')
45
+ else:
46
+ print(f'Loading Trainable vision encoder:{args.vision_model} -- Done')
47
+
48
+ print('Loading LLAMA')
49
+ self.llama_tokenizer = LlamaTokenizer.from_pretrained(args.llama_model, use_fast=False)
50
+ self.llama_tokenizer.pad_token_id = 0
51
+ if args.low_resource:
52
+ self.llama_model = LlamaForCausalLM.from_pretrained(
53
+ args.llama_model,
54
+ torch_dtype=torch.float16,
55
+ load_in_8bit=True,
56
+ device_map="auto"
57
+ )
58
+ else:
59
+ self.llama_model = LlamaForCausalLM.from_pretrained(
60
+ args.llama_model,
61
+ torch_dtype=torch.float16,
62
+ )
63
+
64
+ if args.llm_use_lora:
65
+ self.embed_tokens = self.llama_model.get_input_embeddings()
66
+ peft_config = LoraConfig(
67
+ task_type=TaskType.CAUSAL_LM, inference_mode=False, r=args.llm_r, lora_alpha=args.llm_alpha, lora_dropout=args.lora_dropout
68
+ )
69
+ self.llama_model = get_peft_model(self.llama_model, peft_config)
70
+ self.llama_model.print_trainable_parameters()
71
+ print('Loading LLAMA LoRA Done')
72
+ else:
73
+ self.embed_tokens = self.llama_model.get_input_embeddings()
74
+ for name, param in self.llama_model.named_parameters():
75
+ param.requires_grad = False
76
+ print('Loading LLAMA Done')
77
+
78
+ self.llama_proj = nn.Linear(self.visual_encoder.num_features, self.llama_model.config.hidden_size)
79
+ self.layer_norm = nn.LayerNorm(self.llama_model.config.hidden_size)
80
+ self.end_sym = args.end_sym
81
+ self.prompt = 'Generate a comprehensive and detailed diagnosis report for this chest xray image.'
82
+ self.val_step_outputs = []
83
+ self.test_step_outputs = []
84
+ self.val_score = 0.0
85
+
86
+ if args.delta_file is not None:
87
+ state_dict = torch.load(args.delta_file, map_location=torch.device(f'cuda:{torch.cuda.current_device()}'))['model']
88
+ self.load_state_dict(state_dict=state_dict, strict=False)
89
+ print(f'Load checkpoint from {args.delta_file}')
90
+
91
+
92
+ def score(self, ref, hypo):
93
+ """
94
+ ref, dictionary of reference sentences (id, sentence)
95
+ hypo, dictionary of hypothesis sentences (id, sentence)
96
+ score, dictionary of scores
97
+ """
98
+ scorers = [
99
+ (Bleu(4), ["Bleu_1", "Bleu_2", "Bleu_3", "Bleu_4"]),
100
+ (Rouge(), "ROUGE_L"),
101
+ (Meteor(), "METEOR"),
102
+ (Cider(), "CIDEr")
103
+ ]
104
+ final_scores = {}
105
+ for scorer, method in scorers:
106
+ score, scores = scorer.compute_score(ref, hypo)
107
+ if type(score) == list:
108
+ for m, s in zip(method, score):
109
+ final_scores[m] = s
110
+ else:
111
+ final_scores[method] = score
112
+ return final_scores
113
+
114
+
115
+ def encode_img(self, images):
116
+ image_embeds = []
117
+ for image in images:
118
+ device = image.device
119
+ if self.hparams.global_only:
120
+ image_embed = self.visual_encoder(image)['pooler_output'].unsqueeze(1).to(device)
121
+ else:
122
+ image_embed = self.visual_encoder(image)['last_hidden_state'].to(device)
123
+ image_embeds.append(image_embed)
124
+
125
+ image_embeds = torch.stack(image_embeds).mean(0)
126
+ inputs_llama = self.llama_proj(image_embeds)
127
+ atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device)
128
+ return inputs_llama, atts_llama
129
+
130
+
131
+ def prompt_wrap(self, img_embeds, atts_img):
132
+ prompt=f'Human: <Img><ImageHere></Img> {self.prompt} \nAssistant:'
133
+ batch_size = img_embeds.shape[0]
134
+ p_before, p_after = prompt.split('<ImageHere>')
135
+ p_before_tokens = self.llama_tokenizer(
136
+ p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
137
+ p_after_tokens = self.llama_tokenizer(
138
+ p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
139
+ p_before_embeds = self.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
140
+ p_after_embeds = self.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
141
+ wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
142
+ wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
143
+ return wrapped_img_embeds, wrapped_atts_img
144
+
145
+
146
+ def forward(self, samples):
147
+ image = samples["image"]
148
+ img_embeds, atts_img = self.encode_img(image)
149
+ img_embeds = self.layer_norm(img_embeds)
150
+
151
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)
152
+
153
+ self.llama_tokenizer.padding_side = "right"
154
+ text = [t + self.end_sym for t in samples["input_text"]]
155
+
156
+ to_regress_tokens = self.llama_tokenizer(
157
+ text,
158
+ return_tensors="pt",
159
+ padding="max_length",
160
+ truncation=True,
161
+ max_length=self.hparams.max_length,
162
+ add_special_tokens=False
163
+ ).to(image[0].device)
164
+
165
+ targets = to_regress_tokens.input_ids.masked_fill(
166
+ to_regress_tokens.input_ids == 0, -100
167
+ )
168
+
169
+ empty_targets = (
170
+ torch.ones([atts_img.shape[0], atts_img.shape[1]+1],
171
+ dtype=torch.long).to(image[0].device).fill_(-100) # plus one for bos
172
+ )
173
+ targets = torch.cat([empty_targets, targets], dim=1)
174
+
175
+ batch_size = img_embeds.shape[0]
176
+ bos = torch.ones([batch_size, 1],
177
+ dtype=to_regress_tokens.input_ids.dtype,
178
+ device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
179
+ bos_embeds = self.embed_tokens(bos)
180
+ atts_bos = atts_img[:, :1]
181
+
182
+ to_regress_embeds = self.embed_tokens(to_regress_tokens.input_ids)
183
+ inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
184
+ attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)
185
+
186
+ outputs = self.llama_model(
187
+ inputs_embeds=inputs_embeds,
188
+ attention_mask=attention_mask,
189
+ return_dict=True,
190
+ labels=targets,
191
+ )
192
+ loss = outputs.loss
193
+ return {"loss": loss}
194
+
195
+ def training_step(self, batch, batch_idx):
196
+ result = self(batch)
197
+ self.log_dict(result, prog_bar=True)
198
+ return result
199
+
200
+ def save_checkpoint(self, eval_res):
201
+ current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
202
+ param_grad_dic = {
203
+ k: v.requires_grad for (k, v) in self.named_parameters() if v.requires_grad
204
+ }
205
+ state_dict = self.state_dict()
206
+ for k in list(state_dict.keys()):
207
+ if k not in param_grad_dic.keys():
208
+ del state_dict[k]
209
+ save_obj = {
210
+ "model": state_dict,
211
+ "config": self.hparams,
212
+ "epoch": current_epoch,
213
+ "step":global_step
214
+ }
215
+ os.makedirs(os.path.join(self.hparams.savedmodel_path, 'checkpoints'), exist_ok=True)
216
+ save_to = os.path.join(
217
+ self.hparams.savedmodel_path, 'checkpoints',
218
+ "checkpoint_epoch{}_step{}_bleu{:3f}_cider{:3f}.pth".format(current_epoch, global_step, eval_res['Bleu_4'], eval_res['CIDEr']),
219
+ )
220
+ self.print("Saving checkpoint at step {} to {}.".format(global_step, save_to))
221
+ torch.save(save_obj, save_to)
222
+
223
+ def validation_step(self, samples, batch_idx):
224
+ self.llama_tokenizer.padding_side = "right"
225
+ to_regress_tokens = self.llama_tokenizer(
226
+ samples['input_text'],
227
+ return_tensors="pt",
228
+ padding="max_length",
229
+ truncation=True,
230
+ max_length=self.hparams.max_length,
231
+ add_special_tokens=False
232
+ )
233
+
234
+ image = samples["image"]
235
+ img_embeds, atts_img = self.encode_img(image)
236
+ img_embeds = self.layer_norm(img_embeds)
237
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)
238
+
239
+ batch_size = img_embeds.shape[0]
240
+ bos = torch.ones([batch_size, 1],
241
+ dtype=atts_img.dtype,
242
+ device=atts_img.device) * self.llama_tokenizer.bos_token_id
243
+ bos_embeds = self.embed_tokens(bos)
244
+ atts_bos = atts_img[:, :1]
245
+
246
+ inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
247
+ attention_mask = torch.cat([atts_bos, atts_img], dim=1)
248
+
249
+ outputs = self.llama_model.generate(
250
+ inputs_embeds=inputs_embeds,
251
+ num_beams=self.hparams.beam_size,
252
+ do_sample=self.hparams.do_sample,
253
+ min_new_tokens=self.hparams.min_new_tokens,
254
+ max_new_tokens=self.hparams.max_new_tokens,
255
+ repetition_penalty=self.hparams.repetition_penalty,
256
+ length_penalty=self.hparams.length_penalty,
257
+ temperature=self.hparams.temperature,
258
+ )
259
+ hypo = [self.decode(i) for i in outputs]
260
+ ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
261
+ self.val_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
262
+ return hypo, ref
263
+
264
+ def decode(self, output_token):
265
+ if output_token[0] == 0: # the model might output an unknown token <unk> at the beginning; remove it
266
+ output_token = output_token[1:]
267
+ if output_token[0] == 1: # there may also be a start token <s> at the beginning; remove it
268
+ output_token = output_token[1:]
269
+ output_text = self.llama_tokenizer.decode(output_token, add_special_tokens=False)
270
+ output_text = output_text.split('</s>')[0].strip()
271
+ output_text = output_text.replace('<unk>', '')
272
+ return output_text
273
+
274
+ def on_validation_epoch_end(self):
275
+ ref, hypo, ids = [], [], []
276
+ for i in self.val_step_outputs:
277
+ ref.extend(i['ref'])
278
+ hypo.extend(i['hypo'])
279
+ ids.extend(i['id'])
280
+
281
+ ref = {k:[v] for k, v in zip(ids, ref)}
282
+ hypo = {k:[v] for k, v in zip(ids, hypo)}
283
+ eval_res = self.score(ref=ref,hypo=hypo)
284
+ self.log_dict(eval_res, sync_dist=True, logger=True)
285
+
286
+ result_folder = os.path.join(self.hparams.savedmodel_path, 'result')
287
+ os.makedirs(result_folder, exist_ok=True)
288
+ current_epoch, global_step = self.trainer.current_epoch, self.trainer.global_step
289
+ json.dump(hypo, open(os.path.join(result_folder, f"result_{current_epoch}_{global_step}" + '.json'), 'w'))
290
+ json.dump(ref, open(os.path.join(result_folder, 'refs.json'), 'w'))
291
+ self.print(eval_res)
292
+
293
+ val_score = 0
294
+ for score_type, weight in zip(self.hparams.scorer_types, self.hparams.weights):
295
+ val_score += eval_res[score_type] * weight
296
+
297
+ if self.trainer.local_rank == 0:
298
+ if val_score > self.val_score:
299
+ self.save_checkpoint(eval_res)
300
+ self.val_score = val_score
301
+ self.val_step_outputs.clear()
302
+
303
+
304
+ def test_step(self, samples, batch_idx):
305
+ self.llama_tokenizer.padding_side = "right"
306
+ to_regress_tokens = self.llama_tokenizer(
307
+ samples['input_text'],
308
+ return_tensors="pt",
309
+ padding="max_length",
310
+ truncation=True,
311
+ max_length=self.hparams.max_length,
312
+ add_special_tokens=False
313
+ )
314
+
315
+ image = samples["image"]
316
+ img_embeds, atts_img = self.encode_img(image)
317
+ img_embeds = self.layer_norm(img_embeds)
318
+ img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img)
319
+
320
+ batch_size = img_embeds.shape[0]
321
+ bos = torch.ones([batch_size, 1],
322
+ dtype=atts_img.dtype,
323
+ device=atts_img.device) * self.llama_tokenizer.bos_token_id
324
+ bos_embeds = self.embed_tokens(bos)
325
+ atts_bos = atts_img[:, :1]
326
+
327
+ inputs_embeds = torch.cat([bos_embeds, img_embeds], dim=1)
328
+ attention_mask = torch.cat([atts_bos, atts_img], dim=1)
329
+
330
+ outputs = self.llama_model.generate(
331
+ inputs_embeds=inputs_embeds,
332
+ num_beams=self.hparams.beam_size,
333
+ do_sample=self.hparams.do_sample,
334
+ min_new_tokens=self.hparams.min_new_tokens,
335
+ max_new_tokens=self.hparams.max_new_tokens,
336
+ repetition_penalty=self.hparams.repetition_penalty,
337
+ length_penalty=self.hparams.length_penalty,
338
+ temperature=self.hparams.temperature,
339
+ )
340
+ hypo = [self.decode(i) for i in outputs]
341
+ ref = [self.decode(i) for i in to_regress_tokens['input_ids']]
342
+ self.test_step_outputs.append({"hypo": hypo, "ref": ref, "id": samples["id"]})
343
+ return hypo, ref
344
+
345
+
346
+ def on_test_epoch_end(self):
347
+ """
348
+ This function is called at the end of the test epoch.
349
+ It is recommended to test on a single device so that each sample/batch is evaluated exactly once, which matters when benchmarking results for papers. In a multi-device setting (e.g. strategy="ddp"), DistributedSampler replicates some samples across devices to keep batch sizes equal when the inputs are uneven, so some samples may be evaluated more than once.
350
+ """
351
+ ref, hypo, ids = [], [], []
352
+ for i in self.test_step_outputs:
353
+ ref.extend(i['ref'])
354
+ hypo.extend(i['hypo'])
355
+ ids.extend(i['id'])
356
+
357
+ ref = {k:[v] for k, v in zip(ids, ref)}
358
+ hypo = {k:[v] for k, v in zip(ids, hypo)}
359
+ eval_res = self.score(ref=ref,hypo=hypo)
360
+
361
+ result_folder = os.path.join(self.hparams.savedmodel_path, 'result')
362
+ os.makedirs(result_folder, exist_ok=True)
363
+ json.dump(hypo, open(os.path.join(result_folder, f"test_result.json"), 'w'))
364
+ json.dump(ref, open(os.path.join(result_folder, 'test_refs.json'), 'w'))
365
+ self.print(f"Test result of {self.hparams.delta_file}: {eval_res}")
366
+
367
+ def configure_optimizers(self):
368
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.learning_rate)
369
+ scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=self.hparams.max_epochs, eta_min=1e-6)
370
+ return {"optimizer": optimizer, "lr_scheduler": scheduler}
371
+
372
+ def get_progress_bar_dict(self):
373
+ # don't show the version number
374
+ items = super().get_progress_bar_dict()
375
+ items.pop("v_num", None)
376
+ return items
377
+
378
+ def optimizer_zero_grad(self, epoch, batch_idx, optimizer):
379
+ optimizer.zero_grad()
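For reference, score() follows the coco-caption convention: ref and hypo each map a sample id to a single-element list of report strings, which is exactly the layout that on_validation_epoch_end and on_test_epoch_end assemble from the decoded outputs. A minimal sketch of that layout with invented study ids and texts (actually computing the metrics also needs the evalcap scorers available, and METEOR relies on the bundled meteor-1.5.jar, hence Java):

# hypothetical study ids and report texts in the {id: [text]} layout consumed by score()
ref = {
    'study_0001': ['the heart is normal in size . the lungs are clear .'],
    'study_0002': ['no acute cardiopulmonary abnormality .'],
}
hypo = {
    'study_0001': ['heart size is within normal limits . lungs are clear .'],
    'study_0002': ['no acute cardiopulmonary process .'],
}

# with an instantiated model, the metric dict would be obtained as:
# eval_res = model.score(ref=ref, hypo=hypo)
# e.g. {'Bleu_1': ..., 'Bleu_4': ..., 'ROUGE_L': ..., 'METEOR': ..., 'CIDEr': ...}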
requirements.txt CHANGED
@@ -1,5 +1,8 @@
1
- streamlit
2
  torch
3
- torchvision
4
- requests
5
- Pillow
 
 
 
 
 
 
1
  torch
2
+ peft
3
+ tensorboardX
4
+ transformers==4.30.2
5
+ lightning==2.0.5
6
+ Pillow
7
+ numpy
8
+ gradio
scripts/1-1.shallow_run_iuxray.sh ADDED
@@ -0,0 +1,30 @@
1
+ #!/bin/bash
2
+
3
+ dataset="iu_xray"
4
+ annotation="data/iu_xray/annotation.json"
5
+ base_dir="./data/iu_xray/images"
6
+
7
+ version="v1_shallow"
8
+ savepath="./save/$dataset/$version"
9
+
10
+ python -u train.py \
11
+ --dataset ${dataset} \
12
+ --annotation ${annotation} \
13
+ --base_dir ${base_dir} \
14
+ --batch_size 8 \
15
+ --val_batch_size 12 \
16
+ --freeze_vm True \
17
+ --vis_use_lora False \
18
+ --savedmodel_path ${savepath} \
19
+ --max_length 60 \
20
+ --min_new_tokens 40 \
21
+ --max_new_tokens 100 \
22
+ --repetition_penalty 2.0 \
23
+ --length_penalty 2.0 \
24
+ --num_workers 8 \
25
+ --devices 2 \
26
+ --max_epochs 15 \
27
+ --limit_val_batches 1.0 \
28
+ --val_check_interval 1.0 \
29
+ --num_sanity_val_steps 0 \
30
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/1-2.shallow_test_iuxray.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/bin/bash
2
+
3
+ dataset="iu_xray"
4
+ annotation="data/iu_xray/annotation.json"
5
+ base_dir="./data/iu_xray/images"
6
+ delta_file="/apdcephfs/share_733425/vinnylywang/zhanyuwang/Code/R2GenGPT/save/iu_xray/v1_shallow/checkpoints/checkpoint_epoch11_step1548_bleu0.155866_cider0.450477.pth"
7
+
8
+ version="v1_shallow"
9
+ savepath="./save/$dataset/$version"
10
+
11
+ python -u train.py \
12
+ --test \
13
+ --dataset ${dataset} \
14
+ --annotation ${annotation} \
15
+ --base_dir ${base_dir} \
16
+ --delta_file ${delta_file} \
17
+ --test_batch_size 16 \
18
+ --freeze_vm True \
19
+ --vis_use_lora False \
20
+ --savedmodel_path ${savepath} \
21
+ --max_length 60 \
22
+ --min_new_tokens 40 \
23
+ --max_new_tokens 100 \
24
+ --repetition_penalty 2.0 \
25
+ --length_penalty 2.0 \
26
+ --num_workers 8 \
27
+ --devices 1 \
28
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/2-1.delta_run_iuxray.sh ADDED
@@ -0,0 +1,30 @@
1
+ #!/bin/bash
2
+
3
+ dataset="iu_xray"
4
+ annotation="data/iu_xray/annotation.json"
5
+ base_dir="./data/iu_xray/images"
6
+
7
+ version="v1_delta"
8
+ savepath="./save/$dataset/$version"
9
+
10
+ python -u train.py \
11
+ --dataset ${dataset} \
12
+ --annotation ${annotation} \
13
+ --base_dir ${base_dir} \
14
+ --batch_size 8 \
15
+ --val_batch_size 12 \
16
+ --freeze_vm True \
17
+ --vis_use_lora True \
18
+ --savedmodel_path ${savepath} \
19
+ --max_length 60 \
20
+ --min_new_tokens 40 \
21
+ --max_new_tokens 100 \
22
+ --repetition_penalty 2.0 \
23
+ --length_penalty 2.0 \
24
+ --num_workers 8 \
25
+ --devices 2 \
26
+ --max_epochs 15 \
27
+ --limit_val_batches 1.0 \
28
+ --val_check_interval 1.0 \
29
+ --num_sanity_val_steps 2 \
30
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/2-2.delta_test_iuxray.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/bin/bash
2
+
3
+ dataset="iu_xray"
4
+ annotation="data/iu_xray/annotation.json"
5
+ base_dir="./data/iu_xray/images"
6
+ delta_file="/apdcephfs/share_733425/vinnylywang/zhanyuwang/Code/R2GenGPT/save/iu_xray/v1_delta/checkpoints/checkpoint_epoch13_step1806_bleu0.161532_cider0.530213.pth"
7
+
8
+ version="v1_delta"
9
+ savepath="./save/$dataset/$version"
10
+
11
+ python -u train.py \
12
+ --test \
13
+ --dataset ${dataset} \
14
+ --annotation ${annotation} \
15
+ --base_dir ${base_dir} \
16
+ --delta_file ${delta_file} \
17
+ --test_batch_size 16 \
18
+ --freeze_vm True \
19
+ --vis_use_lora True \
20
+ --savedmodel_path ${savepath} \
21
+ --max_length 60 \
22
+ --min_new_tokens 40 \
23
+ --max_new_tokens 100 \
24
+ --repetition_penalty 2.0 \
25
+ --length_penalty 2.0 \
26
+ --num_workers 8 \
27
+ --devices 1 \
28
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/3-1.deep_run_iuxray.sh ADDED
@@ -0,0 +1,30 @@
1
+ #!/bin/bash
2
+
3
+ dataset="iu_xray"
4
+ annotation="data/iu_xray/annotation.json"
5
+ base_dir="./data/iu_xray/images"
6
+
7
+ version="v1_deep"
8
+ savepath="./save/$dataset/$version"
9
+
10
+ python -u train.py \
11
+ --dataset ${dataset} \
12
+ --annotation ${annotation} \
13
+ --base_dir ${base_dir} \
14
+ --batch_size 8 \
15
+ --val_batch_size 12 \
16
+ --freeze_vm False \
17
+ --vis_use_lora False \
18
+ --savedmodel_path ${savepath} \
19
+ --max_length 60 \
20
+ --min_new_tokens 40 \
21
+ --max_new_tokens 100 \
22
+ --repetition_penalty 2.0 \
23
+ --length_penalty 2.0 \
24
+ --num_workers 8 \
25
+ --devices 2 \
26
+ --max_epochs 15 \
27
+ --limit_val_batches 1.0 \
28
+ --val_check_interval 1.0 \
29
+ --num_sanity_val_steps 2 \
30
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/3-2.deep_test_iuxray.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/bin/bash
2
+
3
+ dataset="iu_xray"
4
+ annotation="data/iu_xray/annotation.json"
5
+ base_dir="./data/iu_xray/images"
6
+ delta_file="/apdcephfs/share_733425/vinnylywang/zhanyuwang/Code/R2GenGPT/save/iu_xray/v1_deep/checkpoints/checkpoint_epoch12_step1677_bleu0.185560_cider0.678231.pth"
7
+
8
+ version="v1_deep"
9
+ savepath="./save/$dataset/$version"
10
+
11
+ python -u train.py \
12
+ --test \
13
+ --dataset ${dataset} \
14
+ --annotation ${annotation} \
15
+ --base_dir ${base_dir} \
16
+ --delta_file ${delta_file} \
17
+ --test_batch_size 16 \
18
+ --freeze_vm False \
19
+ --vis_use_lora False \
20
+ --savedmodel_path ${savepath} \
21
+ --max_length 60 \
22
+ --min_new_tokens 40 \
23
+ --max_new_tokens 100 \
24
+ --repetition_penalty 2.0 \
25
+ --length_penalty 2.0 \
26
+ --num_workers 8 \
27
+ --devices 1 \
28
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/4-1.shallow_run.sh ADDED
@@ -0,0 +1,39 @@
1
+ #!/bin/bash
2
+
3
+ dataset="mimic_cxr"
4
+ annotation="./data/mimic_cxr/annotation.json"
5
+ base_dir="./data/mimic_cxr/images"
6
+
7
+ version="v1_shallow"
8
+ savepath="./save/$dataset/$version"
9
+
10
+ if [ ! -d "$savepath" ]; then
11
+ mkdir -p "$savepath"
12
+ echo "Folder '$savepath' created."
13
+ else
14
+ echo "Folder '$savepath' already exists."
15
+ fi
16
+
17
+ python -u train.py \
18
+ --dataset ${dataset} \
19
+ --annotation ${annotation} \
20
+ --base_dir ${base_dir} \
21
+ --batch_size 8 \
22
+ --val_batch_size 12 \
23
+ --freeze_vm True \
24
+ --vis_use_lora False \
25
+ --savedmodel_path ${savepath} \
26
+ --learning_rate 1e-4 \
27
+ --gradient_clip_val 1 \
28
+ --max_length 100 \
29
+ --min_new_tokens 80 \
30
+ --max_new_tokens 120 \
31
+ --repetition_penalty 2.0 \
32
+ --length_penalty 2.0 \
33
+ --num_workers 8 \
34
+ --devices 4 \
35
+ --max_epochs 5 \
36
+ --limit_val_batches 0.5 \
37
+ --val_check_interval 0.5 \
38
+ --num_sanity_val_steps 2 \
39
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/4-2.shallow_test.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/bin/bash
2
+
3
+ dataset="mimic_cxr"
4
+ annotation="data/mimic_cxr/my_mimic_anno.json"
5
+ base_dir="./data/mimic_cxr/images"
6
+ delta_file="path/to/pretrained/delta_file"
7
+
8
+ version="v1_shallow"
9
+ savepath="./save/$dataset/$version"
10
+
11
+ python -u train.py \
12
+ --test \
13
+ --dataset ${dataset} \
14
+ --annotation ${annotation} \
15
+ --base_dir ${base_dir} \
16
+ --delta_file ${delta_file} \
17
+ --test_batch_size 16 \
18
+ --freeze_vm True \
19
+ --vis_use_lora False \
20
+ --savedmodel_path ${savepath} \
21
+ --max_length 100 \
22
+ --min_new_tokens 80 \
23
+ --max_new_tokens 120 \
24
+ --repetition_penalty 2.0 \
25
+ --length_penalty 2.0 \
26
+ --num_workers 12 \
27
+ --devices 1 \
28
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/5-1.delta_run.sh ADDED
@@ -0,0 +1,39 @@
1
+ #!/bin/bash
2
+
3
+ dataset="mimic_cxr"
4
+ annotation="data/mimic_cxr/my_mimic_anno.json"
5
+ base_dir="./data/mimic_cxr/images"
6
+
7
+ version="v1_delta"
8
+ savepath="./save/$dataset/$version"
9
+
10
+ if [ ! -d "$savepath" ]; then
11
+ mkdir -p "$savepath"
12
+ echo "Folder '$savepath' created."
13
+ else
14
+ echo "Folder '$savepath' already exists."
15
+ fi
16
+
17
+ python -u train.py \
18
+ --dataset ${dataset} \
19
+ --annotation ${annotation} \
20
+ --base_dir ${base_dir} \
21
+ --batch_size 8 \
22
+ --val_batch_size 16 \
23
+ --freeze_vm True \
24
+ --vis_use_lora True \
25
+ --vis_r 16 \
26
+ --vis_alpha 16 \
27
+ --savedmodel_path ${savepath} \
28
+ --max_length 100 \
29
+ --min_new_tokens 80 \
30
+ --max_new_tokens 120 \
31
+ --repetition_penalty 2.0 \
32
+ --length_penalty 2.0 \
33
+ --num_workers 16 \
34
+ --devices 4 \
35
+ --max_epochs 5 \
36
+ --limit_val_batches 0.5 \
37
+ --val_check_interval 0.5 \
38
+ --num_sanity_val_steps 2 \
39
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/5-2.delta_test.sh ADDED
@@ -0,0 +1,30 @@
1
+ #!/bin/bash
2
+
3
+ dataset="mimic_cxr"
4
+ annotation="data/mimic_cxr/my_mimic_anno.json"
5
+ base_dir="./data/mimic_cxr/images"
6
+ delta_file="path/to/pretrained/delta_file"
7
+
8
+ version="v1_delta"
9
+ savepath="./save/$dataset/$version"
10
+
11
+ python -u train.py \
12
+ --test \
13
+ --dataset ${dataset} \
14
+ --annotation ${annotation} \
15
+ --base_dir ${base_dir} \
16
+ --delta_file ${delta_file} \
17
+ --max_length 100 \
18
+ --min_new_tokens 80 \
19
+ --max_new_tokens 120 \
20
+ --repetition_penalty 2.0 \
21
+ --length_penalty 2.0 \
22
+ --test_batch_size 16 \
23
+ --freeze_vm True \
24
+ --vis_use_lora True \
25
+ --vis_r 16 \
26
+ --vis_alpha 16 \
27
+ --savedmodel_path ${savepath} \
28
+ --num_workers 12 \
29
+ --devices 1 \
30
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/6-1.deep_run.sh ADDED
@@ -0,0 +1,38 @@
1
+ #!/bin/bash
2
+
3
+ dataset="mimic_cxr"
4
+ annotation="data/mimic_cxr/my_mimic_anno.json"
5
+ base_dir="./data/mimic_cxr/images"
6
+
7
+ version="v1_deep"
8
+ savepath="./save/$dataset/$version"
9
+
10
+ if [ ! -d "$savepath" ]; then
11
+ mkdir -p "$savepath"
12
+ echo "Folder '$savepath' created."
13
+ else
14
+ echo "Folder '$savepath' already exists."
15
+ fi
16
+
17
+ python -u train.py \
18
+ --dataset ${dataset} \
19
+ --annotation ${annotation} \
20
+ --base_dir ${base_dir} \
21
+ --batch_size 6 \
22
+ --val_batch_size 12 \
23
+ --freeze_vm False \
24
+ --vis_use_lora False \
25
+ --llm_use_lora False \
26
+ --savedmodel_path ${savepath} \
27
+ --max_length 100 \
28
+ --min_new_tokens 80 \
29
+ --max_new_tokens 120 \
30
+ --repetition_penalty 2.0 \
31
+ --length_penalty 2.0 \
32
+ --num_workers 12 \
33
+ --devices 4 \
34
+ --max_epochs 5 \
35
+ --limit_val_batches 0.5 \
36
+ --val_check_interval 0.5 \
37
+ --num_sanity_val_steps 2 \
38
+ 2>&1 |tee -a ${savepath}/log.txt
scripts/6-2.deep_test.sh ADDED
@@ -0,0 +1,28 @@
1
+ #!/bin/bash
2
+
3
+ dataset="mimic_cxr"
4
+ annotation="data/mimic_cxr/my_mimic_anno.json"
5
+ base_dir="./data/mimic_cxr/images"
6
+ delta_file="path/to/pretrained/delta_file"
7
+
8
+ version="v1_deep"
9
+ savepath="./save/$dataset/$version"
10
+
11
+ python -u train.py \
12
+ --test \
13
+ --dataset ${dataset} \
14
+ --annotation ${annotation} \
15
+ --base_dir ${base_dir} \
16
+ --delta_file ${delta_file} \
17
+ --test_batch_size 16 \
18
+ --max_length 100 \
19
+ --min_new_tokens 80 \
20
+ --max_new_tokens 120 \
21
+ --repetition_penalty 2.0 \
22
+ --length_penalty 2.0 \
23
+ --freeze_vm False \
24
+ --vis_use_lora False \
25
+ --savedmodel_path ${savepath} \
26
+ --num_workers 12 \
27
+ --devices 1 \
28
+ 2>&1 |tee -a ${savepath}/log.txt
train.py ADDED
@@ -0,0 +1,51 @@
1
+ import os
2
+ from pprint import pprint
3
+ from configs.config import parser
4
+ from dataset.data_module import DataModule
5
+ from lightning_tools.callbacks import add_callbacks
6
+ from models.R2GenGPT import R2GenGPT
7
+ from lightning.pytorch import seed_everything
8
+ import lightning.pytorch as pl
9
+
10
+
11
+ def train(args):
12
+ dm = DataModule(args)
13
+ callbacks = add_callbacks(args)
14
+
15
+ trainer = pl.Trainer(
16
+ devices=args.devices,
17
+ num_nodes=args.num_nodes,
18
+ strategy=args.strategy,
19
+ accelerator=args.accelerator,
20
+ precision=args.precision,
21
+ val_check_interval = args.val_check_interval,
22
+ limit_val_batches = args.limit_val_batches,
23
+ max_epochs = args.max_epochs,
24
+ num_sanity_val_steps = args.num_sanity_val_steps,
25
+ accumulate_grad_batches=args.accumulate_grad_batches,
26
+ callbacks=callbacks["callbacks"],
27
+ logger=callbacks["loggers"]
28
+ )
29
+
30
+ if args.ckpt_file is not None:
31
+ model = R2GenGPT.load_from_checkpoint(args.ckpt_file, strict=False)
32
+ else:
33
+ model = R2GenGPT(args)
34
+
35
+ if args.test:
36
+ trainer.test(model, datamodule=dm)
37
+ elif args.validate:
38
+ trainer.validate(model, datamodule=dm)
39
+ else:
40
+ trainer.fit(model, datamodule=dm)
41
+
42
+ def main():
43
+ args = parser.parse_args()
44
+ os.makedirs(args.savedmodel_path, exist_ok=True)
45
+ pprint(vars(args))
46
+ seed_everything(42, workers=True)
47
+ train(args)
48
+
49
+
50
+ if __name__ == '__main__':
51
+ main()