eyupipler commited on
Commit
26f5e14
1 Parent(s): a948a43

Added Vbai-1.2 Dementia

Browse files
Main Models/Vbai-1.2 Dementia/README.md ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Vbai-1.2 Dementia (11178564 parametre)
2
+
3
+ ## "Vbai-1.2 Dementia" modeli, bir önceki modele göre daha fazla veriyle eğitilmiş olup üzerinde ince ayar yapılmış versiyonudur.
4
+
5
+ ## -----------------------------------------------------------------------------------
6
+
7
+ # Vbai-1.2 Dementia (11178564 parameters)
8
+
9
+ ## The "Vbai-1.2 Dementia" model is a fine-tuned version of the previous model, trained with more data.
10
+
11
+ [![Vbai-1.2](https://img.youtube.com/vi/qUkId3S9W94/0.jpg)](https://youtu.be/wDfsFwusGQU)
12
+
13
+
14
+ # Kullanım / Usage
15
+
16
+ ```python
17
+ import torch
18
+ import torch.nn as nn
19
+ from torchvision import transforms, models
20
+ from PIL import Image
21
+ import matplotlib.pyplot as plt
22
+ import os
23
+ from torchsummary import summary
24
+
25
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
26
+
27
+ model = models.resnet18(pretrained=False)
28
+ num_ftrs = model.fc.in_features
29
+ model.fc = nn.Linear(num_ftrs, 4)
30
+ model.load_state_dict(torch.load('Vbai-1.2 Dementia/path'))
31
+ model = model.to(device)
32
+ model.eval()
33
+ summary(model, (3, 224, 224))
34
+
35
+ transform = transforms.Compose([
36
+ transforms.Resize((224, 224)),
37
+ transforms.ToTensor(),
38
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
39
+ ])
40
+
41
+ class_names = ['No Dementia', 'Mild Dementia', 'Avarage Dementia', 'Very Mild Dementia']
42
+
43
+ def predict(image_path, model, transform):
44
+ image = Image.open(image_path).convert('RGB')
45
+ image = transform(image).unsqueeze(0).to(device)
46
+ model.eval()
47
+ with torch.no_grad():
48
+ outputs = model(image)
49
+ probs = torch.nn.functional.softmax(outputs, dim=1)
50
+ _, preds = torch.max(outputs, 1)
51
+ return preds.item(), probs[0][preds.item()].item()
52
+
53
+ def show_image_with_prediction(image_path, prediction, confidence, class_names):
54
+ image = Image.open(image_path)
55
+ plt.imshow(image)
56
+ plt.title(f"Prediction: {class_names[prediction]} (%{confidence * 100:.2f})")
57
+ plt.axis('off')
58
+ plt.show()
59
+
60
+ test_image_path = 'image-path'
61
+ prediction, confidence = predict(test_image_path, model, transform)
62
+ print(f'Prediction: {class_names[prediction]} (%{confidence * 100})')
63
+
64
+ show_image_with_prediction(test_image_path, prediction, confidence, class_names)
65
+ ```
66
+
67
+ # Uygulama / As App
68
+
69
+ ```python
70
+ import sys
71
+ import torch
72
+ import torch.nn as nn
73
+ from torchvision import transforms, models
74
+ from PIL import Image
75
+ import matplotlib.pyplot as plt
76
+ from PyQt5.QtWidgets import QApplication, QWidget, QPushButton, QLabel, QFileDialog, QVBoxLayout, QMessageBox
77
+ from PyQt5.QtGui import QPixmap, QIcon
78
+ from PyQt5.QtCore import Qt
79
+
80
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
81
+
82
+ transform = transforms.Compose([
83
+ transforms.Resize((224, 224)),
84
+ transforms.ToTensor(),
85
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
86
+ ])
87
+
88
+ class_names = ['No Dementia', 'Mild Dementia', 'Avarage Dementia', 'Very Mild Dementia']
89
+
90
+
91
+ class DementiaApp(QWidget):
92
+ def __init__(self):
93
+ super().__init__()
94
+ self.initUI()
95
+ self.model = None
96
+ self.image_path = None
97
+
98
+ def initUI(self):
99
+ self.setWindowTitle('Prediction App by Neurazum')
100
+ self.setWindowIcon(QIcon('C:/Users/eyupi/PycharmProjects/Neurazum/NeurAI/Assets/neurazumicon.ico'))
101
+ self.setGeometry(2500, 300, 400, 200)
102
+
103
+ self.loadModelButton = QPushButton('Upload Model', self)
104
+ self.loadModelButton.clicked.connect(self.loadModel)
105
+
106
+ self.loadImageButton = QPushButton('Upload Image', self)
107
+ self.loadImageButton.clicked.connect(self.loadImage)
108
+
109
+ self.predictButton = QPushButton('Make a Prediction', self)
110
+ self.predictButton.clicked.connect(self.predict)
111
+ self.predictButton.setEnabled(False)
112
+
113
+ self.resultLabel = QLabel('', self)
114
+ self.resultLabel.setAlignment(Qt.AlignCenter)
115
+
116
+ self.imageLabel = QLabel('', self)
117
+ self.imageLabel.setAlignment(Qt.AlignCenter)
118
+
119
+ layout = QVBoxLayout()
120
+ layout.addWidget(self.loadModelButton)
121
+ layout.addWidget(self.loadImageButton)
122
+ layout.addWidget(self.imageLabel)
123
+ layout.addWidget(self.predictButton)
124
+ layout.addWidget(self.resultLabel)
125
+
126
+ self.setLayout(layout)
127
+
128
+ def loadModel(self):
129
+ options = QFileDialog.Options()
130
+ fileName, _ = QFileDialog.getOpenFileName(self, "Choose Model Path", "",
131
+ "PyTorch Model Files (*.pt);;All Files (*)", options=options)
132
+ if fileName:
133
+ self.model = models.resnet18(pretrained=False)
134
+ num_ftrs = self.model.fc.in_features
135
+ self.model.fc = nn.Linear(num_ftrs, 4)
136
+ self.model.load_state_dict(torch.load(fileName, map_location=device))
137
+ self.model = self.model.to(device)
138
+ self.model.eval()
139
+ self.predictButton.setEnabled(True)
140
+ QMessageBox.information(self, "Model Uploaded", "Model successfully uploaded!")
141
+
142
+ def loadImage(self):
143
+ options = QFileDialog.Options()
144
+ fileName, _ = QFileDialog.getOpenFileName(self, "Choose Image File", "",
145
+ "Image Files (*.jpg *.jpeg *.png);;All Files (*)", options=options)
146
+ if fileName:
147
+ self.image_path = fileName
148
+ pixmap = QPixmap(self.image_path)
149
+ self.imageLabel.setPixmap(pixmap.scaled(224, 224, Qt.KeepAspectRatio))
150
+
151
+ def predict(self):
152
+ if self.model and self.image_path:
153
+ prediction, confidence = self.predictImage(self.image_path, self.model, transform)
154
+ self.resultLabel.setText(f'Prediction: {class_names[prediction]} (%{confidence * 100:.2f})')
155
+ else:
156
+ QMessageBox.warning(self, "Missing Information", "Model and picture must be uploaded.")
157
+
158
+ def predictImage(self, image_path, model, transform):
159
+ image = Image.open(image_path).convert('RGB')
160
+ image = transform(image).unsqueeze(0).to(device)
161
+ model.eval()
162
+ with torch.no_grad():
163
+ outputs = model(image)
164
+ probs = torch.nn.functional.softmax(outputs, dim=1)
165
+ _, preds = torch.max(outputs, 1)
166
+ return preds.item(), probs[0][preds.item()].item()
167
+
168
+
169
+ if __name__ == '__main__':
170
+ app = QApplication(sys.argv)
171
+ ex = DementiaApp()
172
+ ex.show()
173
+ sys.exit(app.exec_())
174
+ ```
175
+
176
+ # Python Sürümü / Python Version
177
+
178
+ ### 3.9 <=> 3.13
179
+
180
+ # Modüller / Modules
181
+
182
+ ```bash
183
+ matplotlib==3.8.0
184
+ Pillow==10.0.1
185
+ torch==2.3.1
186
+ torchsummary==1.5.1
187
+ torchvision==0.18.1
188
+ ```
Main Models/Vbai-1.2 Dementia/Vbai-1.2 Dementia.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e97a1a006ab50b98a641b44a9a3d36eaafc97648bd07dad01ad3b1b8b8c08980
3
+ size 44792618