# Hugging Face Space (status when captured: Sleeping)
import streamlit as st
import torch
import torch.nn as nn
import timm
import torchvision.transforms as transforms
import pytorch_lightning as pl
from PIL import Image
import numpy as np
from torch import nn
# The `smp.Unet(encoder_name=..., encoder_weights=..., in_channels=...,
# classes=..., activation=...)` API used by this app belongs to the
# segmentation_models_pytorch package, which is conventionally aliased `smp`.
import segmentation_models_pytorch as smp

# ---------------------------------------------------------------------------
# Configuration for the accompanying inference app.
# ---------------------------------------------------------------------------
PATHS = ['1.tiff', '2.tiff']  # images offered in the selection box
# NOTE(review): "se_resne" looks truncated (possibly an smp encoder name);
# it is only forwarded to VesuviusModel(weight=...), which does not use it.
# Confirm the intended value.
MODEL_NAME = "se_resne"
MODEL_PATH = "model.ckpt"  # checkpoint holding the trained weights
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
# HxWx3 uint8 image -> 3xHxW float tensor in [0, 1] -> per-channel
# (x - 0.5) / 0.5, i.e. values in [-1, 1].
TRANSFORM = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Architecture settings. They must match the training run that produced
# MODEL_PATH, otherwise strict load_state_dict() will reject the checkpoint.
BACKBONE = ""  # TODO: smp encoder name used at training time
IN_CHANNELS = 3  # the app feeds 3-channel images (see TRANSFORM above)
CLASSES = []  # TODO: class names; index i becomes IDS_TO_CLASSES_DICT[i]
# Pretrained *encoder* weights for smp (e.g. "imagenet"). None skips that
# download: the complete trained weights are loaded from MODEL_PATH anyway.
WEIGHTS = None

# Derived values. They must follow the CLASSES definition; computing them
# before it raised NameError at import time.
NUM_CLASSES = len(CLASSES)
IDS_TO_CLASSES_DICT = dict(enumerate(CLASSES))
class VesuviusModel(nn.Module):
    """U-Net segmentation network built with segmentation_models_pytorch.

    The complete ``smp.Unet`` (encoder and decoder) is stored as
    ``self.encoder``. The attribute name is kept because state_dict keys are
    derived from it; renaming it would break loading existing checkpoints.

    Args:
        weight: Accepted so existing callers keep working
            (``load_weights_into_model`` passes the model name here), but it
            is not used. The architecture comes from the module-level
            BACKBONE / WEIGHTS / IN_CHANNELS / NUM_CLASSES settings.
    """

    def __init__(self, weight=None):
        super().__init__()
        self.encoder = smp.Unet(
            encoder_name=BACKBONE,
            encoder_weights=WEIGHTS,
            in_channels=IN_CHANNELS,
            # smp expects the number of output channels (an int); CLASSES is
            # the collection of class names, so pass its length instead.
            classes=NUM_CLASSES,
            activation=None,  # no output activation: forward returns raw logits
        )

    def forward(self, image):
        """Run the network on a batch of images and return its raw output."""
        output = self.encoder(image)
        # NOTE(review): smp.Unet presumably returns (batch, classes, H, W).
        # squeeze(-1) only drops the last (width) axis, and only when it has
        # size 1. If the intent was to drop a single class channel, squeeze(1)
        # may be what was meant. Confirm before relying on the output shape.
        output = output.squeeze(-1)
        return output
def load_weights_into_model(model_name: str, model_path: str) -> nn.Module:
    """Build a VesuviusModel and load trained weights into it.

    Args:
        model_name: Forwarded to ``VesuviusModel`` as its ``weight`` argument.
        model_path: Path to a checkpoint readable by ``torch.load``. This is
            either a dict with a ``"state_dict"`` entry (the layout of
            PyTorch Lightning ``.ckpt`` files) or a bare state dict.

    Returns:
        The model with the checkpoint weights loaded. Checkpoint tensors are
        mapped to DEVICE when read, but the module itself is not moved;
        callers are expected to call ``model.to(DEVICE)``.

    Raises:
        RuntimeError: from ``load_state_dict`` when checkpoint keys or shapes
            do not match the model (strict loading).
    """
    model = VesuviusModel(model_name)
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    else:
        # Plain `torch.save(model.state_dict(), path)` checkpoint.
        state_dict = checkpoint
    # NOTE(review): Lightning prefixes keys with the LightningModule's
    # attribute names. If the training module wrapped VesuviusModel
    # (e.g. as `self.model`), the prefix must be stripped before loading. Confirm.
    model.load_state_dict(state_dict)
    return model
# --- Streamlit page ----------------------------------------------------------
# Streamlit re-runs this whole script on every widget interaction, so the
# checkpoint below is re-read on each rerun.
# NOTE(review): consider caching the loaded model (e.g. st.cache_resource on
# Streamlit versions that provide it); confirm the deployed version first.
model = load_weights_into_model(MODEL_NAME, MODEL_PATH)
model.to(DEVICE)
model.eval()  # inference behavior for dropout / batch-norm layers

img_path = st.selectbox('Select an image to segment', PATHS)
st.write('You have selected:', img_path)
img = Image.open(img_path)
st.image(img, caption='Selected image to segment')

# TRANSFORM normalizes exactly three channels. Converting through PIL replaces
# the former `np.array(img)[:, :, :3]` slice, which raised IndexError for
# single-channel images (np.array returns a 2-D array for them).
# convert("RGB") handles grayscale, palette and RGBA inputs alike; for RGBA
# it drops alpha, just as the slice did.
np_img = np.array(img.convert("RGB"))
input_batch = TRANSFORM(np_img).unsqueeze(0).to(DEVICE)  # (1, 3, H, W)

with st.spinner("Segmenting the image in progress..."):
    with torch.no_grad():  # no autograd bookkeeping needed for inference
        # TODO: Finish...
        prediction = model(input_batch).cpu()

# NOTE(review): print() writes to the server log, not to the Streamlit page;
# the prediction is not yet shown to the user.
print(prediction)