januarevan commited on
Commit
2b32bec
·
1 Parent(s): 2f19ff9
images/1.jpg ADDED
images/10.jpg ADDED
images/131.jpg ADDED
images/132.jpg ADDED
images/15.jpg ADDED
images/16.jpg ADDED
images/2.jpg ADDED
images/407.jpg ADDED
images/5.jpg ADDED
images/56.jpg ADDED
images/57.jpg ADDED
images/581.jpg ADDED
images/630.jpg ADDED
images/8.jpg ADDED
images/9.jpg ADDED
main.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Form, Depends, Request, File, UploadFile
2
+ from fastapi.encoders import jsonable_encoder
3
+ from fastapi.responses import JSONResponse
4
+ from fastapi.middleware.cors import CORSMiddleware
5
+
6
+ import segmentation_models_pytorch as smp
7
+ import torch
8
+ import numpy as np
9
+ import cv2
10
+ import os
11
+ from torch.utils.data import Dataset, DataLoader
12
+ from PIL import Image
13
+ from io import BytesIO
14
+
15
+ DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
16
+
17
+ model = smp.PAN(encoder_name="resnext50_32x4d", in_channels=3, classes=1)
18
+ model.to(DEVICE).load_state_dict(torch.load("/model/pan_resnext50_32x4d_adam_lr001_batch16_epoch_50.ckpt", map_location=DEVICE))
19
+
20
+ app = FastAPI()
21
+
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["*"], # Replace with the list of allowed origins for production
25
+ allow_credentials=True,
26
+ allow_methods=["*"],
27
+ allow_headers=["*"],
28
+ )
29
+
30
+ image_dataset = []
31
+ for file in os.listdir("/images"):
32
+ image_dataset.append(cv2.resize(cv2.imread('images/' + file), (160, 544)))
33
+
34
+ @app.get("/")
35
+ async def root():
36
+ return {"message": "Hello World"}
37
+
38
+ class CustomDataset(Dataset):
39
+ def __init__(self, data, transform=None):
40
+ self.data = data
41
+ self.transform = transform
42
+
43
+ def __len__(self):
44
+ return len(self.data)
45
+
46
+ def __getitem__(self, idx):
47
+ sample = {
48
+ 'image': self.data[idx],
49
+ }
50
+
51
+ if self.transform:
52
+ sample = self.transform(sample)
53
+
54
+ return sample
55
+
56
+ @app.post("/segmentation")
57
+ async def segmentation(file: UploadFile = File(...)):
58
+ contents = await file.read()
59
+ image = Image.open(BytesIO(contents))
60
+ open_cv_image = np.array(image)
61
+ open_cv_image = cv2.resize(open_cv_image, (160, 544))
62
+
63
+ image_dataset.insert(0, open_cv_image)
64
+ dataset = CustomDataset(image_dataset)
65
+ dataloader = DataLoader(dataset, batch_size=1, shusffle=False, num_workers=0)
66
+
67
+ try:
68
+ for batch in dataloader:
69
+ temp_image = batch['image'].to(DEVICE)
70
+ output = model(temp_image)
71
+ output[0] = (output[0] > 0.5)
72
+ output = output[0].squeeze().cpu().numpy()
73
+
74
+ except Exception as e:
75
+ print(e)
76
+ return JSONResponse(status_code=500, content={"error": str(e)})
77
+ else:
78
+ return JSONResponse(status_code=200, content={"result": 'good'})
79
+
80
+ @app.post("/predict")
81
+ async def predict():
82
+ return None
83
+
84
+
model/pan_resnext50_32x4d_adam_lr001_batch16_epoch_50.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:29161b9a0f7f98baf457ed6cf502e4d52f05f121bc1b8c2b00d496f558ddba3e
3
+ size 190563613
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ segmentation-models-pytorch
2
+ # pandas
3
+ numpy
4
+ torch
5
+ opencv-python
6
+ fastapi==0.103.2