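# Demo: single-image monocular depth estimation with GLPDepth.
# Loads a trained checkpoint, runs inference on one image, and saves the
# predicted depth map as a colorized plot.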
from collections import OrderedDict

import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from models.model import GLPDepth

DEVICE = 'cpu'
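# Note: this demo pins everything to CPU; a GPU variant would typically pick
# the device dynamically, e.g. DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'.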

def load_mde_model(path):
    """Load a GLPDepth checkpoint, stripping any DataParallel key prefix."""
    model = GLPDepth(max_depth=700.0, is_train=False).to(DEVICE)
    model_weight = torch.load(path, map_location=torch.device('cpu'))
    model_weight = model_weight['model_state_dict']
    # Checkpoints saved from nn.DataParallel prefix every key with 'module.';
    # strip the prefix so the keys match a single-device model.
    if 'module' in next(iter(model_weight.items()))[0]:
        model_weight = OrderedDict((k[7:], v) for k, v in model_weight.items())
    model.load_state_dict(model_weight)
    model.eval()
    return model

model = load_mde_model('best_model.ckpt')

# Resize to the model's expected input size and scale pixels to [0, 1].
preprocess = transforms.Compose([
    transforms.Resize((512, 512)),
    transforms.ToTensor()
])

# Ensure a 3-channel image, then add a batch dimension for the model.
input_img = Image.open('demo_imgs/fake.jpg').convert('RGB')
torch_img = preprocess(input_img).unsqueeze(0).to(DEVICE)

# Inference: the model returns a dict with the predicted depth under 'pred_d'.
with torch.no_grad():
    output_patch = model(torch_img)

output_patch = output_patch['pred_d'].squeeze().cpu().numpy()
print(output_patch.shape)

# Visualize the predicted depth map and save it with a colorbar.
plt.imshow(output_patch, cmap='jet', vmin=0, vmax=np.max(output_patch))
plt.colorbar()
plt.savefig('test.png')
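
# A minimal alternative sketch, assuming only the colorized map (no axes or
# colorbar) is wanted; 'test_raw.png' and 'depth.npy' are hypothetical output paths:
# plt.imsave('test_raw.png', output_patch, cmap='jet', vmin=0, vmax=np.max(output_patch))
# np.save('depth.npy', output_patch)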