import numpy as np
import math


def _tile_starts(length, patch, stride):
    """Return start offsets of ``patch``-sized windows covering ``[0, length)``.

    Windows advance by ``stride``; the final window is pulled back so that it
    ends exactly at ``length``. Each offset appears once, so no window is
    processed twice. Requires ``length >= patch`` and ``stride > 0``.
    """
    starts = list(range(0, length - patch, stride))
    starts.append(length - patch)
    return starts


def _infer_scale(out_shape, patch_h, patch_w):
    """Return the integer upscaling factor implied by one tile's output shape.

    The spatial axes are taken as the third- and second-to-last axes, matching
    how the output tiles are sliced in ``tiling_inference``.
    """
    out_h, out_w = out_shape[-3], out_shape[-2]
    scale = out_h // patch_h
    if scale < 1 or out_h != scale * patch_h or out_w != scale * patch_w:
        raise ValueError(
            f"model output spatial size {(out_h, out_w)} is not an integer "
            f"multiple of the patch size {(patch_h, patch_w)}"
        )
    return scale


def tiling_inference(session, lr, overlapping=8, patch_size=(56, 56)):
    """Run a super-resolution model over ``lr`` tile by tile and stitch the result.

    The image is split into ``patch_size`` windows that overlap their
    neighbours by ``overlapping`` pixels. Each window is passed through the
    model. From each output, half of the overlap is discarded on every side
    that borders another tile (sides on the image border are kept whole), so
    tile seams fall inside regions the model saw with context. Every model
    input has exactly ``patch_size`` spatial size. Images smaller than that
    are edge-padded up to it, and the padding is cropped from the result.

    Parameters:
    - session: an ONNX Runtime session object that contains the
      super-resolution model; each tile is fed to its first input and its
      first output is taken as the upscaled tile
    - lr: the low-resolution image, indexed as (batch, height, width, channels)
    - overlapping: the number of pixels to overlap between adjacent patches;
      must satisfy ``0 <= overlapping < min(patch_size)``
    - patch_size: a tuple of (height, width) that specifies the size of each patch

    Returns:
    - a float64 numpy array of shape (batch, scale*height, scale*width,
      channels), where batch, channels and the integer ``scale`` are taken from
      the model output (scale is 2 for a 2x model). For an image with zero
      height or width, a zero array of shape (1, 2*height, 2*width, 3) is
      returned without calling the model.

    Raises:
    - ValueError: if ``overlapping`` is out of range, or if the model output's
      spatial size is not an integer multiple of ``patch_size``.
    """
    patch_h, patch_w = patch_size
    if not 0 <= overlapping < min(patch_h, patch_w):
        raise ValueError(
            f"overlapping must be in [0, {min(patch_h, patch_w)}), got {overlapping}"
        )

    _, h, w, _ = lr.shape
    if h == 0 or w == 0:
        # Nothing to run the model on; keep the default 2x, 3-channel layout.
        return np.zeros((1, 2 * h, 2 * w, 3))

    # Every tile must be exactly patch_size. If the image is smaller than a
    # patch, the last-tile start (size - patch) would be negative, and Python
    # slicing would silently wrap around. Edge-pad up to one patch instead.
    pad_h = max(patch_h - h, 0)
    pad_w = max(patch_w - w, 0)
    if pad_h or pad_w:
        lr = np.pad(lr, ((0, 0), (0, pad_h), (0, pad_w), (0, 0)), mode="edge")
    full_h, full_w = h + pad_h, w + pad_w

    input_name = session.get_inputs()[0].name
    half = overlapping // 2
    row_starts = _tile_starts(full_h, patch_h, patch_h - overlapping)
    col_starts = _tile_starts(full_w, patch_w, patch_w - overlapping)

    sr = None
    scale = 1
    for y0 in row_starts:
        # Tiles touching the image border keep their full extent on that side;
        # interior sides drop half the overlap so that neighbours meet.
        top = 0 if y0 == 0 else half
        bottom = patch_h if y0 + patch_h >= full_h else patch_h - half
        for x0 in col_starts:
            left = 0 if x0 == 0 else half
            right = patch_w if x0 + patch_w >= full_w else patch_w - half

            # A slice of lr is a strided view; hand the session a contiguous
            # copy rather than relying on the binding to handle strides.
            tile = np.ascontiguousarray(lr[..., y0:y0 + patch_h, x0:x0 + patch_w, :])
            out = session.run(None, {input_name: tile})[0]

            if sr is None:
                # Output geometry (scale, batch, channels) comes from the model.
                scale = _infer_scale(out.shape, patch_h, patch_w)
                sr = np.zeros(
                    (out.shape[0], scale * full_h, scale * full_w, out.shape[-1])
                )

            sr[..., scale * (y0 + top):scale * (y0 + bottom),
               scale * (x0 + left):scale * (x0 + right), :] = \
                out[..., scale * top:scale * bottom, scale * left:scale * right, :]

    if pad_h or pad_w:
        # Drop the upscaled padding.
        sr = sr[:, :scale * h, :scale * w, :].copy()
    return sr