mmenendezg
commited on
Commit
•
c5c5181
1
Parent(s):
913eca1
Add files for the gradio app
Browse files- __init__.py +1 -0
- app.py +49 -0
- config/fluorescent_mobilevit_hps.yaml +3 -0
- data/.gitkeep +0 -0
- data/__init__.py +1 -0
- data/__pycache__/__init__.cpython-311.pyc +0 -0
- data/__pycache__/data_preprocessing.cpython-311.pyc +0 -0
- data/data_preprocessing.py +155 -0
- models/.gitkeep +0 -0
- models/__init__.py +1 -0
- models/__pycache__/__init__.cpython-311.pyc +0 -0
- models/__pycache__/hyperparameters_tuning.cpython-311.pyc +0 -0
- models/__pycache__/mobilevit.cpython-311.pyc +0 -0
- models/__pycache__/train_model.cpython-311.pyc +0 -0
- models/mobilevit.py +120 -0
- tools/.gitkeep +0 -0
- tools/__init__.py +1 -0
- tools/__pycache__/__init__.cpython-311.pyc +0 -0
- tools/__pycache__/hyperparameters_tuning.cpython-311.pyc +0 -0
- tools/__pycache__/predict.cpython-311.pyc +0 -0
- tools/__pycache__/train_model.cpython-311.pyc +0 -0
- tools/hyperparameters_tuning.py +104 -0
- tools/predict.py +47 -0
- tools/train_model.py +70 -0
__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import data, models, tools, visualization
|
app.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import gradio as gr
|
3 |
+
from tools.predict import single_prediction
|
4 |
+
|
5 |
+
KAGGLE_NOTEBOOK = "[![Static Badge](https://img.shields.io/badge/Open_Notebook_in_Kaggle-gray?logo=kaggle&logoColor=white&labelColor=20BEFF)](https://www.kaggle.com/code/mmenendezg/mobilevit-fluorescent-neuronal-cells/notebook)"
|
6 |
+
GITHUB_REPOSITORY = "[![Static Badge](https://img.shields.io/badge/Git_Repository-gray?logo=github&logoColor=white&labelColor=181717)](https://github.com/mmenendezg/mobilevit-fluorescent-cells)"
|
7 |
+
HF_SPACE = "[![Open in Spaces](https://huggingface.co/datasets/huggingface/badges/resolve/main/open-in-hf-spaces-md-dark.svg)](https://huggingface.co/spaces/mmenendezg/mobilevit-fluorescent-neuronal-cells)"
|
8 |
+
|
9 |
+
# Gradio interface
|
10 |
+
demo = gr.Blocks()
|
11 |
+
with demo:
|
12 |
+
gr.Markdown(
|
13 |
+
f"""
|
14 |
+
# Fluorescent Neuronal Cells Segmentation
|
15 |
+
|
16 |
+
This model extracts a segmentation mask of the neuronal cells on an image.
|
17 |
+
|
18 |
+
{KAGGLE_NOTEBOOK}
|
19 |
+
|
20 |
+
{GITHUB_REPOSITORY}
|
21 |
+
|
22 |
+
{HF_SPACE}
|
23 |
+
"""
|
24 |
+
)
|
25 |
+
with gr.Tab("Image Segmentation"):
|
26 |
+
with gr.Row():
|
27 |
+
with gr.Column():
|
28 |
+
uploaded_image = gr.Image(
|
29 |
+
label="Neuronal Cells Image",
|
30 |
+
sources=["upload", "clipboard"],
|
31 |
+
type="pil",
|
32 |
+
height=550,
|
33 |
+
)
|
34 |
+
with gr.Column():
|
35 |
+
mask_image = gr.Image(label="Segmented Neurons", height=550)
|
36 |
+
with gr.Row():
|
37 |
+
classify_btn = gr.Button("Segment the image", variant="primary")
|
38 |
+
clear_btn = gr.ClearButton(components=[uploaded_image, mask_image])
|
39 |
+
classify_btn.click(
|
40 |
+
fn=single_prediction, inputs=uploaded_image, outputs=[mask_image]
|
41 |
+
)
|
42 |
+
gr.Examples(
|
43 |
+
examples=[
|
44 |
+
os.path.join(os.path.dirname(__file__), "examples/example_1.png"),
|
45 |
+
os.path.join(os.path.dirname(__file__), "examples/example_2.png"),
|
46 |
+
],
|
47 |
+
inputs=uploaded_image,
|
48 |
+
)
|
49 |
+
demo.launch(show_error=True)
|
config/fluorescent_mobilevit_hps.yaml
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
learning_rate: 0.0005480015685663855
|
2 |
+
weight_decay: 1.544480236681167e-05
|
3 |
+
batch_size: 2
|
data/.gitkeep
ADDED
File without changes
|
data/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import data_preprocessing
|
data/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (250 Bytes). View file
|
|
data/__pycache__/data_preprocessing.cpython-311.pyc
ADDED
Binary file (8.85 kB). View file
|
|
data/data_preprocessing.py
ADDED
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from PIL import Image
|
4 |
+
import cv2
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import DataLoader
|
7 |
+
import albumentations as A
|
8 |
+
import pytorch_lightning as pl
|
9 |
+
from transformers import AutoImageProcessor
|
10 |
+
from datasets import Dataset, DatasetDict
|
11 |
+
|
12 |
+
# Checkpoint of the model used in the projec
|
13 |
+
MODEL_CHECKPOINT = "apple/deeplabv3-mobilevit-xx-small"
|
14 |
+
# Size of the image used to train the model
|
15 |
+
IMG_SIZE = [256, 256]
|
16 |
+
|
17 |
+
|
18 |
+
class FluorescentNeuronalDataModule(pl.LightningDataModule):
|
19 |
+
def __init__(self, batch_size, data_dir, dataset_size=1.0):
|
20 |
+
super().__init__()
|
21 |
+
self.data_dir = data_dir
|
22 |
+
self.batch_size = batch_size
|
23 |
+
self.image_processor = AutoImageProcessor.from_pretrained(
|
24 |
+
MODEL_CHECKPOINT, do_reduce_labels=False
|
25 |
+
)
|
26 |
+
self.image_resizer = A.Compose(
|
27 |
+
[
|
28 |
+
A.Resize(
|
29 |
+
width=IMG_SIZE[0],
|
30 |
+
height=IMG_SIZE[1],
|
31 |
+
interpolation=cv2.INTER_NEAREST,
|
32 |
+
)
|
33 |
+
]
|
34 |
+
)
|
35 |
+
self.image_augmentator = A.Compose(
|
36 |
+
[
|
37 |
+
A.HorizontalFlip(p=0.6),
|
38 |
+
A.VerticalFlip(p=0.6),
|
39 |
+
A.RandomBrightnessContrast(p=0.6),
|
40 |
+
A.RandomGamma(p=0.6),
|
41 |
+
A.HueSaturationValue(p=0.6),
|
42 |
+
]
|
43 |
+
)
|
44 |
+
|
45 |
+
# Percentage of the dataset
|
46 |
+
self.dataset_size = dataset_size
|
47 |
+
|
48 |
+
def _create_dataset(self):
|
49 |
+
images_path = os.path.join(self.data_dir, "all_images", "images")
|
50 |
+
masks_path = os.path.join(self.data_dir, "all_masks", "masks")
|
51 |
+
list_images = os.listdir(images_path)
|
52 |
+
|
53 |
+
# Determine the size of the dataset
|
54 |
+
if self.dataset_size < 1.0:
|
55 |
+
n_images = int(len(list_images) * self.dataset_size)
|
56 |
+
list_images = list_images[:n_images]
|
57 |
+
|
58 |
+
images = []
|
59 |
+
masks = []
|
60 |
+
for image_filename in list_images:
|
61 |
+
image_path = os.path.join(images_path, image_filename)
|
62 |
+
mask_path = os.path.join(masks_path, image_filename)
|
63 |
+
|
64 |
+
image = np.array(Image.open(image_path).convert("RGB"), dtype=np.uint8)
|
65 |
+
mask = np.array(Image.open(mask_path).convert("L"), dtype=np.uint8)
|
66 |
+
mask = (mask / 255).astype(np.uint8)
|
67 |
+
|
68 |
+
images.append(image)
|
69 |
+
masks.append(mask)
|
70 |
+
|
71 |
+
dataset = Dataset.from_dict({"image": images, "mask": masks})
|
72 |
+
|
73 |
+
# Split the dataset into train, val, and test sets
|
74 |
+
dataset = dataset.train_test_split(test_size=0.1)
|
75 |
+
train_val = dataset["train"]
|
76 |
+
test_ds = dataset["test"]
|
77 |
+
del dataset
|
78 |
+
|
79 |
+
train_val = train_val.train_test_split(test_size=0.2)
|
80 |
+
train_ds = train_val["train"]
|
81 |
+
valid_ds = train_val["test"]
|
82 |
+
del train_val
|
83 |
+
|
84 |
+
dataset = DatasetDict(
|
85 |
+
{"train": train_ds, "validation": valid_ds, "test": test_ds}
|
86 |
+
)
|
87 |
+
del train_ds, valid_ds, test_ds
|
88 |
+
return dataset
|
89 |
+
|
90 |
+
def _transform_train_data(self, batch):
|
91 |
+
# Preprocess the images
|
92 |
+
images, masks = [], []
|
93 |
+
for i, m in zip(batch["image"], batch["mask"]):
|
94 |
+
img = np.asarray(i, dtype=np.uint8)
|
95 |
+
mask = np.asarray(m, dtype=np.uint8)
|
96 |
+
# First resize the images and masks
|
97 |
+
resized_outputs = self.image_resizer(image=img, mask=mask)
|
98 |
+
images.append(resized_outputs["image"])
|
99 |
+
masks.append(resized_outputs["mask"])
|
100 |
+
|
101 |
+
# Then augment the images
|
102 |
+
augmented_outputs = self.image_augmentator(
|
103 |
+
image=resized_outputs["image"], mask=resized_outputs["mask"]
|
104 |
+
)
|
105 |
+
images.append(augmented_outputs["image"])
|
106 |
+
masks.append(augmented_outputs["mask"])
|
107 |
+
|
108 |
+
inputs = self.image_processor(
|
109 |
+
images=images,
|
110 |
+
return_tensors="pt",
|
111 |
+
)
|
112 |
+
inputs["labels"] = torch.tensor(masks, dtype=torch.long)
|
113 |
+
return inputs
|
114 |
+
|
115 |
+
def _transform_data(self, batch):
|
116 |
+
# Preprocess the images
|
117 |
+
images, masks = [], []
|
118 |
+
for i, m in zip(batch["image"], batch["mask"]):
|
119 |
+
img = np.asarray(i, dtype=np.uint8)
|
120 |
+
mask = np.asarray(m, dtype=np.uint8)
|
121 |
+
# Resize the images and masks
|
122 |
+
resized_outputs = self.image_resizer(image=img, mask=mask)
|
123 |
+
images.append(resized_outputs["image"])
|
124 |
+
masks.append(resized_outputs["mask"])
|
125 |
+
|
126 |
+
inputs = self.image_processor(
|
127 |
+
images=images,
|
128 |
+
return_tensors="pt",
|
129 |
+
)
|
130 |
+
inputs["labels"] = inputs["labels"] = torch.tensor(masks, dtype=torch.long)
|
131 |
+
return inputs
|
132 |
+
|
133 |
+
def setup(self, stage=None):
|
134 |
+
dataset = self._create_dataset()
|
135 |
+
train_ds = dataset["train"]
|
136 |
+
valid_ds = dataset["validation"]
|
137 |
+
test_ds = dataset["test"]
|
138 |
+
|
139 |
+
if stage is None or stage == "fit":
|
140 |
+
self.train_ds = train_ds.with_transform(self._transform_train_data)
|
141 |
+
self.valid_ds = valid_ds.with_transform(self._transform_data)
|
142 |
+
if stage is None or stage == "test" or stage == "predict":
|
143 |
+
self.test_ds = test_ds.with_transform(self._transform_data)
|
144 |
+
|
145 |
+
def train_dataloader(self):
|
146 |
+
return DataLoader(self.train_ds, batch_size=self.batch_size, shuffle=True)
|
147 |
+
|
148 |
+
def val_dataloader(self):
|
149 |
+
return DataLoader(self.valid_ds, batch_size=self.batch_size)
|
150 |
+
|
151 |
+
def test_dataloader(self):
|
152 |
+
return DataLoader(self.test_ds, batch_size=self.batch_size)
|
153 |
+
|
154 |
+
def predict_dataloader(self):
|
155 |
+
return DataLoader(self.test_ds, batch_size=self.batch_size)
|
models/.gitkeep
ADDED
File without changes
|
models/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import mobilevit
|
models/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (243 Bytes). View file
|
|
models/__pycache__/hyperparameters_tuning.cpython-311.pyc
ADDED
Binary file (5.68 kB). View file
|
|
models/__pycache__/mobilevit.cpython-311.pyc
ADDED
Binary file (7.37 kB). View file
|
|
models/__pycache__/train_model.cpython-311.pyc
ADDED
Binary file (3.21 kB). View file
|
|
models/mobilevit.py
ADDED
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from transformers import MobileViTForSemanticSegmentation
|
6 |
+
import evaluate
|
7 |
+
|
8 |
+
MODEL_CHECKPOINT = "mmenendezg/mobilevit-fluorescent-neuronal-cells"
|
9 |
+
CLASSES = {0: "Background", 1: "Neuron"}
|
10 |
+
|
11 |
+
|
12 |
+
class MobileVIT(pl.LightningModule):
|
13 |
+
def __init__(self, learning_rate=None, weight_decay=None):
|
14 |
+
super().__init__()
|
15 |
+
self.id2label = CLASSES
|
16 |
+
self.label2id = {v: k for k, v in self.id2label.items()}
|
17 |
+
self.num_classes = len(self.id2label.keys())
|
18 |
+
self.model = MobileViTForSemanticSegmentation.from_pretrained(
|
19 |
+
MODEL_CHECKPOINT,
|
20 |
+
num_labels=self.num_classes,
|
21 |
+
id2label=self.id2label,
|
22 |
+
label2id=self.label2id,
|
23 |
+
ignore_mismatched_sizes=True,
|
24 |
+
)
|
25 |
+
self.metric = evaluate.load("mean_iou")
|
26 |
+
self.learning_rate = learning_rate
|
27 |
+
self.weight_decay = weight_decay
|
28 |
+
|
29 |
+
def forward(self, pixel_values, labels):
|
30 |
+
return self.model(pixel_values=pixel_values, labels=labels)
|
31 |
+
|
32 |
+
def common_step(self, batch, batch_idx):
|
33 |
+
pixel_values = batch["pixel_values"]
|
34 |
+
labels = batch["labels"]
|
35 |
+
|
36 |
+
outputs = self.model(pixel_values=pixel_values, labels=labels)
|
37 |
+
|
38 |
+
loss = outputs.loss
|
39 |
+
logits = outputs.logits
|
40 |
+
return loss, logits
|
41 |
+
|
42 |
+
def compute_metric(self, logits, labels):
|
43 |
+
logits_tensor = nn.functional.interpolate(
|
44 |
+
logits,
|
45 |
+
size=labels.shape[-2:],
|
46 |
+
mode="bilinear",
|
47 |
+
align_corners=False,
|
48 |
+
).argmax(dim=1)
|
49 |
+
pred_labels = logits_tensor.detach().cpu().numpy()
|
50 |
+
metrics = self.metric.compute(
|
51 |
+
predictions=pred_labels,
|
52 |
+
references=labels,
|
53 |
+
num_labels=self.num_classes,
|
54 |
+
ignore_index=255,
|
55 |
+
reduce_labels=False,
|
56 |
+
)
|
57 |
+
|
58 |
+
return metrics
|
59 |
+
|
60 |
+
def training_step(self, batch, batch_idx):
|
61 |
+
labels = batch["labels"]
|
62 |
+
|
63 |
+
# Calculate and log the loss
|
64 |
+
loss, logits = self.common_step(batch, batch_idx)
|
65 |
+
self.log("train_loss", loss)
|
66 |
+
|
67 |
+
# Calculate and log the metrics
|
68 |
+
metrics = self.compute_metric(logits, labels)
|
69 |
+
metrics = {key: np.float32(value) for key, value in metrics.items()}
|
70 |
+
|
71 |
+
self.log("train_mean_iou", metrics["mean_iou"])
|
72 |
+
self.log("train_mean_accuracy", metrics["mean_accuracy"])
|
73 |
+
self.log("train_overall_accuracy", metrics["overall_accuracy"])
|
74 |
+
|
75 |
+
return loss
|
76 |
+
|
77 |
+
def validation_step(self, batch, batch_idx):
|
78 |
+
labels = batch["labels"]
|
79 |
+
|
80 |
+
# Calculate and log the loss
|
81 |
+
loss, logits = self.common_step(batch, batch_idx)
|
82 |
+
self.log("val_loss", loss)
|
83 |
+
|
84 |
+
# Calculate and log the metrics
|
85 |
+
metrics = self.compute_metric(logits, labels)
|
86 |
+
metrics = {key: np.float32(value) for key, value in metrics.items()}
|
87 |
+
self.log("val_mean_iou", metrics["mean_iou"])
|
88 |
+
self.log("val_mean_accuracy", metrics["mean_accuracy"])
|
89 |
+
self.log("val_overall_accuracy", metrics["overall_accuracy"])
|
90 |
+
|
91 |
+
return loss
|
92 |
+
|
93 |
+
def test_step(self, batch, batch_idx):
|
94 |
+
labels = batch["labels"]
|
95 |
+
|
96 |
+
# Calculate and log the loss
|
97 |
+
loss, logits = self.common_step(batch, batch_idx)
|
98 |
+
self.log("test_loss", loss)
|
99 |
+
|
100 |
+
# Calculate and log the metrics
|
101 |
+
metrics = self.compute_metric(logits, labels)
|
102 |
+
metrics = {key: np.float32(value) for key, value in metrics.items()}
|
103 |
+
# for k, v in metrics.items():
|
104 |
+
# self.log(f"val_{k}", v.item())
|
105 |
+
self.log("test_mean_iou", metrics["mean_iou"])
|
106 |
+
self.log("test_mean_accuracy", metrics["mean_accuracy"])
|
107 |
+
self.log("test_overall_accuracy", metrics["overall_accuracy"])
|
108 |
+
|
109 |
+
return loss
|
110 |
+
|
111 |
+
def configure_optimizers(self):
|
112 |
+
param_dicts = [
|
113 |
+
{
|
114 |
+
"params": [p for n, p in self.named_parameters()],
|
115 |
+
"lr": self.learning_rate,
|
116 |
+
}
|
117 |
+
]
|
118 |
+
return torch.optim.AdamW(
|
119 |
+
param_dicts, lr=self.learning_rate, weight_decay=self.weight_decay
|
120 |
+
)
|
tools/.gitkeep
ADDED
File without changes
|
tools/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from . import hyperparameters_tuning, train_model
|
tools/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (283 Bytes). View file
|
|
tools/__pycache__/hyperparameters_tuning.cpython-311.pyc
ADDED
Binary file (5.67 kB). View file
|
|
tools/__pycache__/predict.cpython-311.pyc
ADDED
Binary file (2.43 kB). View file
|
|
tools/__pycache__/train_model.cpython-311.pyc
ADDED
Binary file (3.26 kB). View file
|
|
tools/hyperparameters_tuning.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from pathlib import Path
|
3 |
+
import yaml
|
4 |
+
import torch
|
5 |
+
import optuna
|
6 |
+
import pytorch_lightning as pl
|
7 |
+
import click
|
8 |
+
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
|
9 |
+
|
10 |
+
from models.mobilevit import MobileVIT
|
11 |
+
from data.data_preprocessing import FluorescentNeuronalDataModule
|
12 |
+
|
13 |
+
|
14 |
+
MODEL_CHECKPOINT = "apple/deeplabv3-mobilevit-xx-small"
|
15 |
+
|
16 |
+
# Define the accelerator
|
17 |
+
if torch.backends.mps.is_available():
|
18 |
+
DEVICE = torch.device("mps:0")
|
19 |
+
ACCELERATOR = "mps"
|
20 |
+
elif torch.cuda.is_available():
|
21 |
+
DEVICE = torch.device("cuda")
|
22 |
+
ACCELERATOR = "gpu"
|
23 |
+
else:
|
24 |
+
DEVICE = torch.device("cpu")
|
25 |
+
ACCELERATOR = "cpu"
|
26 |
+
|
27 |
+
RAW_DATA_PATH = "./data/raw/"
|
28 |
+
DEFAULT_CONFIG_FILE = "./config/fluorescent_mobilevit_hps.yaml"
|
29 |
+
|
30 |
+
CLASSES = {0: "Background", 1: "Neuron"}
|
31 |
+
|
32 |
+
IMG_SIZE = [256, 256]
|
33 |
+
|
34 |
+
|
35 |
+
@click.command()
|
36 |
+
@click.option(
|
37 |
+
"--data_dir",
|
38 |
+
type=click.Path(exists=True, file_okay=True, path_type=Path),
|
39 |
+
default=RAW_DATA_PATH,
|
40 |
+
)
|
41 |
+
@click.option(
|
42 |
+
"--config_file",
|
43 |
+
type=click.Path(exists=True, file_okay=True, path_type=Path),
|
44 |
+
default=DEFAULT_CONFIG_FILE,
|
45 |
+
)
|
46 |
+
@click.option("--dataset_size", type=click.FLOAT, default=0.25)
|
47 |
+
@click.option("--force-tune/--no-force-tune", default=False)
|
48 |
+
def get_best_params(data_dir, config_file, dataset_size, force_tune) -> dict:
|
49 |
+
def objective(trial: optuna.Trial, dataset_size=dataset_size) -> float:
|
50 |
+
# Suggest values of the hyperparameters for the trials
|
51 |
+
learning_rate = trial.suggest_float("learning_rate", 1e-6, 1e-3, log=True)
|
52 |
+
weight_decay = trial.suggest_float("weight_decay", 1e-6, 1e-3, log=True)
|
53 |
+
batch_size = trial.suggest_int("batch_size", 2, 4, log=True)
|
54 |
+
|
55 |
+
# Define the callbacks of the model
|
56 |
+
early_stopping_cb = EarlyStopping(monitor="val_loss", patience=2)
|
57 |
+
|
58 |
+
# Create the model
|
59 |
+
model = MobileVIT(learning_rate=learning_rate, weight_decay=weight_decay)
|
60 |
+
|
61 |
+
# Instantiate the data module
|
62 |
+
data_module = FluorescentNeuronalDataModule(
|
63 |
+
batch_size=batch_size, dataset_size=dataset_size, data_dir=data_dir
|
64 |
+
)
|
65 |
+
data_module.setup()
|
66 |
+
|
67 |
+
# Train the model
|
68 |
+
trainer = pl.Trainer(
|
69 |
+
devices=1,
|
70 |
+
accelerator=ACCELERATOR,
|
71 |
+
precision="16-mixed",
|
72 |
+
max_epochs=5,
|
73 |
+
log_every_n_steps=5,
|
74 |
+
callbacks=[early_stopping_cb],
|
75 |
+
)
|
76 |
+
trainer.fit(
|
77 |
+
model,
|
78 |
+
train_dataloaders=data_module.train_dataloader(),
|
79 |
+
val_dataloaders=data_module.val_dataloader(),
|
80 |
+
)
|
81 |
+
return trainer.callback_metrics["val_loss"].item()
|
82 |
+
|
83 |
+
if os.path.exists(config_file) and force_tune:
|
84 |
+
os.remove(config_file)
|
85 |
+
pruner = optuna.pruners.MedianPruner()
|
86 |
+
study = optuna.create_study(direction="maximize", pruner=pruner)
|
87 |
+
|
88 |
+
study.optimize(objective, n_trials=25)
|
89 |
+
best_params = study.best_params
|
90 |
+
with open(config_file, "w") as file:
|
91 |
+
yaml.dump(best_params, file)
|
92 |
+
elif os.path.exists(config_file):
|
93 |
+
with open(config_file, "r") as file:
|
94 |
+
best_params = yaml.safe_load(file)
|
95 |
+
else:
|
96 |
+
pruner = optuna.pruners.MedianPruner()
|
97 |
+
study = optuna.create_study(direction="minimize", pruner=pruner)
|
98 |
+
|
99 |
+
study.optimize(objective, n_trials=25)
|
100 |
+
best_params = study.best_params
|
101 |
+
with open(config_file, "w") as file:
|
102 |
+
yaml.dump(best_params, file)
|
103 |
+
|
104 |
+
click.echo(f"The best parameters are:\n{best_params}")
|
tools/predict.py
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from PIL import Image
|
4 |
+
from transformers import AutoImageProcessor
|
5 |
+
|
6 |
+
from models.mobilevit import MobileVIT
|
7 |
+
|
8 |
+
# Checkpoint of the model used in the projec
|
9 |
+
MODEL_CHECKPOINT = "mmenendezg/mobilevit-fluorescent-neuronal-cells"
|
10 |
+
|
11 |
+
# Define the accelerator
|
12 |
+
if torch.backends.mps.is_available():
|
13 |
+
DEVICE = torch.device("mps:0")
|
14 |
+
ACCELERATOR = "mps"
|
15 |
+
elif torch.cuda.is_available():
|
16 |
+
DEVICE = torch.device("cuda")
|
17 |
+
ACCELERATOR = "gpu"
|
18 |
+
else:
|
19 |
+
DEVICE = torch.device("cpu")
|
20 |
+
ACCELERATOR = "cpu"
|
21 |
+
|
22 |
+
|
23 |
+
def single_prediction(image):
|
24 |
+
# Instantiate the model from the checkpoint and using the hparams file
|
25 |
+
mobilevit_model = MobileVIT()
|
26 |
+
mobilevit_model.to(DEVICE)
|
27 |
+
# Instantiate the image_processor
|
28 |
+
image_processor = AutoImageProcessor.from_pretrained(
|
29 |
+
MODEL_CHECKPOINT, do_reduce_labels=False
|
30 |
+
)
|
31 |
+
# Load the image
|
32 |
+
image = image.convert("RGB")
|
33 |
+
# Convert the image to numpy array
|
34 |
+
np_image = np.asarray(image, dtype=np.uint8)
|
35 |
+
# Preprocess the image and move the image to the GPU Device
|
36 |
+
processed_image = image_processor(images=np_image, return_tensors="pt")
|
37 |
+
processed_image.to(DEVICE)
|
38 |
+
# Make the prediction and resize the predicted mask
|
39 |
+
logits = mobilevit_model.model(pixel_values=processed_image["pixel_values"])
|
40 |
+
post_processed_image = image_processor.post_process_semantic_segmentation(
|
41 |
+
outputs=logits, target_sizes=[(np_image.shape[0], np_image.shape[1])]
|
42 |
+
)
|
43 |
+
# Process the mask
|
44 |
+
mask = post_processed_image[0].data.cpu().numpy().astype(np.uint8) * 255
|
45 |
+
mask = Image.fromarray(mask)
|
46 |
+
|
47 |
+
return mask
|
tools/train_model.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import yaml
|
2 |
+
from pathlib import Path
|
3 |
+
import click
|
4 |
+
import torch
|
5 |
+
import pytorch_lightning as pl
|
6 |
+
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
|
7 |
+
from pytorch_lightning.loggers import TensorBoardLogger
|
8 |
+
|
9 |
+
from models.mobilevit import MobileVIT
|
10 |
+
from data.data_preprocessing import FluorescentNeuronalDataModule
|
11 |
+
|
12 |
+
CONFIG_FILE = "config/fluorescent_mobilevit_hps.yaml"
|
13 |
+
DATA_DIR = "data/raw/"
|
14 |
+
LOGS_DIR = "reports/logs/FluorescentMobileVIT"
|
15 |
+
MODEL_DIR = "models/FluorescentMobileVIT"
|
16 |
+
|
17 |
+
# Define the accelerator
|
18 |
+
if torch.backends.mps.is_available():
|
19 |
+
DEVICE = torch.device("mps:0")
|
20 |
+
ACCELERATOR = "mps"
|
21 |
+
elif torch.cuda.is_available():
|
22 |
+
DEVICE = torch.device("cuda")
|
23 |
+
ACCELERATOR = "gpu"
|
24 |
+
else:
|
25 |
+
DEVICE = torch.device("cpu")
|
26 |
+
ACCELERATOR = "cpu"
|
27 |
+
|
28 |
+
|
29 |
+
@click.command()
|
30 |
+
@click.option(
|
31 |
+
"--data_dir",
|
32 |
+
type=click.Path(exists=True, file_okay=True, path_type=Path),
|
33 |
+
default=DATA_DIR,
|
34 |
+
)
|
35 |
+
@click.option(
|
36 |
+
"--config_file",
|
37 |
+
type=click.Path(exists=True, file_okay=True, path_type=Path),
|
38 |
+
default=CONFIG_FILE,
|
39 |
+
)
|
40 |
+
def train_model(data_dir, config_file):
|
41 |
+
# Load the best parameters
|
42 |
+
with open(config_file, "r") as file:
|
43 |
+
best_params = yaml.safe_load(file)
|
44 |
+
# Instantiate the model
|
45 |
+
model = MobileVIT(
|
46 |
+
learning_rate=best_params["learning_rate"],
|
47 |
+
weight_decay=best_params["weight_decay"],
|
48 |
+
)
|
49 |
+
# Define the callbacks of the model
|
50 |
+
model_checkpoint_cb = ModelCheckpoint(
|
51 |
+
save_top_k=1, dirpath=MODEL_DIR, monitor="val_loss"
|
52 |
+
)
|
53 |
+
logger = TensorBoardLogger(save_dir=LOGS_DIR)
|
54 |
+
|
55 |
+
# Create the trainer with its parameters
|
56 |
+
trainer = pl.Trainer(
|
57 |
+
logger=logger,
|
58 |
+
devices=1,
|
59 |
+
accelerator=ACCELERATOR,
|
60 |
+
precision=16,
|
61 |
+
max_epochs=100,
|
62 |
+
log_every_n_steps=20,
|
63 |
+
callbacks=[model_checkpoint_cb],
|
64 |
+
)
|
65 |
+
data_module = FluorescentNeuronalDataModule(
|
66 |
+
data_dir=data_dir, batch_size=best_params["batch_size"]
|
67 |
+
)
|
68 |
+
trainer.fit(model=model, datamodule=data_module)
|
69 |
+
trainer.test(model=model, datamodule=data_module)
|
70 |
+
click.echo("\n\n==========The Training has Finished!==========")
|