"""The accompanying inference app.

Streamlit script that lets the user pick an image, runs it through a
segmentation model restored from a checkpoint, and prints the raw prediction.
"""

import numpy as np
import pytorch_lightning as pl
# NOTE(review): presumably meant to be `segmentation_models_pytorch` (its
# `Unet` accepts the keyword arguments used below). Confirm that a module
# named `smp` is importable here; otherwise switch to
# `import segmentation_models_pytorch as smp`.
import smp
import streamlit as st
import timm
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
from torch import nn

# --- Model configuration ----------------------------------------------------
# These are unfilled placeholders; the model cannot be built until they are
# set. They are defined before the derived constants below because
# NUM_CLASSES / IDS_TO_CLASSES_DICT read CLASSES at import time.
BACKBONE = ""
IN_CHANNELS = ""
CLASSES = ""
# TODO: path to weights?
WEIGHTS = ""

# --- App configuration ------------------------------------------------------
PATHS = ['1.tiff', '2.tiff']
NUM_CLASSES = len(CLASSES)
IDS_TO_CLASSES_DICT = dict(enumerate(CLASSES))
MODEL_NAME = "se_resne"
MODEL_PATH = "model.ckpt"
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"
TRANSFORM = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)


class VesuviusModel(nn.Module):
    """U-Net segmentation network configured from the module-level constants."""

    def __init__(self, weight=None):
        """Build the underlying ``smp.Unet``.

        Args:
            weight: Currently unused; kept so existing callers that pass it
                keep working. Encoder weights are taken from ``WEIGHTS``.
        """
        super().__init__()
        self.encoder = smp.Unet(
            encoder_name=BACKBONE,
            encoder_weights=WEIGHTS,
            in_channels=IN_CHANNELS,
            # CLASSES is iterated as a sequence elsewhere in this file, so the
            # network is given the class count rather than the sequence.
            classes=NUM_CLASSES,
            activation=None,
        )

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        """Run ``image`` through the network.

        Returns:
            The network output with its last dimension removed when that
            dimension has size 1 (``squeeze(-1)``); unchanged otherwise.
        """
        output = self.encoder(image)
        return output.squeeze(-1)


def load_weights_into_model(model_name: str, model_path: str) -> nn.Module:
    """Create a ``VesuviusModel`` and restore its weights from a checkpoint.

    Args:
        model_name: Currently unused; the architecture is configured by the
            module-level constants. Kept for interface compatibility.
        model_path: Path to a checkpoint whose top-level object contains a
            ``"state_dict"`` entry.

    Returns:
        The model with the checkpoint's state dict loaded. It is not moved to
        a device or switched to eval mode here.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
    """
    model = VesuviusModel()
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    # NOTE(review): if the checkpoint was saved from a wrapper module, its keys
    # may carry a prefix that does not match this model - confirm.
    model.load_state_dict(checkpoint["state_dict"])
    return model


# NOTE(review): Streamlit re-executes this script on every interaction, so the
# model is rebuilt and reloaded each time; consider caching the loader.
model = load_weights_into_model(MODEL_NAME, MODEL_PATH)
model.to(DEVICE)
model.eval()

img_path = st.selectbox('Select an image to segment', PATHS)
st.write('You have selected:', img_path)

img = Image.open(img_path)
st.image(img, caption='Selected image to segment')

np_img = np.array(img)
# Keep only the first three channels (TRANSFORM normalises exactly three) and
# add a leading batch dimension.
# NOTE(review): this indexing assumes the image array is 3-D - confirm the
# inputs are never single-channel.
input_batch = TRANSFORM(np_img[:, :, :3]).unsqueeze(0).to(DEVICE)

with st.spinner("Segmenting the image in progress..."), torch.no_grad():
    # TODO: Finish...
    prediction = model(input_batch).cpu()
    print(prediction)