Update BEN2.py
Browse files
BEN2.py
CHANGED
@@ -921,6 +921,8 @@ class BEN_Base(nn.Module):
|
|
921 |
if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
|
922 |
m.inplace = True
|
923 |
|
|
|
|
|
924 |
@torch.inference_mode()
|
925 |
@torch.autocast(device_type="cuda",dtype=torch.float16)
|
926 |
def forward(self, x):
|
@@ -1008,7 +1010,13 @@ class BEN_Base(nn.Module):
|
|
1008 |
# image = ImageOps.exif_transpose(image)
|
1009 |
if isinstance(image, Image.Image):
|
1010 |
image, h, w,original_image = rgb_loader_refiner(image)
|
1011 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
1012 |
with torch.no_grad():
|
1013 |
res = self.forward(img_tensor)
|
1014 |
|
@@ -1035,7 +1043,11 @@ class BEN_Base(nn.Module):
|
|
1035 |
foregrounds = []
|
1036 |
for batch in image:
|
1037 |
image, h, w,original_image = rgb_loader_refiner(batch)
|
1038 |
-
|
|
|
|
|
|
|
|
|
1039 |
|
1040 |
with torch.no_grad():
|
1041 |
res = self.forward(img_tensor)
|
@@ -1058,6 +1070,9 @@ class BEN_Base(nn.Module):
|
|
1058 |
|
1059 |
return foregrounds
|
1060 |
|
|
|
|
|
|
|
1061 |
def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
|
1062 |
|
1063 |
"""
|
@@ -1196,6 +1211,13 @@ img_transform = transforms.Compose([
|
|
1196 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
1197 |
])
|
1198 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1199 |
|
1200 |
|
1201 |
|
|
|
921 |
if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
|
922 |
m.inplace = True
|
923 |
|
924 +
925 +
|
926 |
@torch.inference_mode()
|
927 |
@torch.autocast(device_type="cuda",dtype=torch.float16)
|
928 |
def forward(self, x):
|
|
|
1010 |
# image = ImageOps.exif_transpose(image)
|
1011 |
if isinstance(image, Image.Image):
|
1012 |
image, h, w,original_image = rgb_loader_refiner(image)
|
1013 +             if torch.cuda.is_available():
1014 +
1015 +                 img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1016 +             else:
1017 +                 img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
1018 +
1019 +
|
1020 |
with torch.no_grad():
|
1021 |
res = self.forward(img_tensor)
|
1022 |
|
|
|
1043 |
foregrounds = []
|
1044 |
for batch in image:
|
1045 |
image, h, w,original_image = rgb_loader_refiner(batch)
|
1046 +             if torch.cuda.is_available():
1047 +
1048 +                 img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1049 +             else:
1050 +                 img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
|
1051 |
|
1052 |
with torch.no_grad():
|
1053 |
res = self.forward(img_tensor)
|
|
|
1070 |
|
1071 |
return foregrounds
|
1072 |
|
1073 +
1074 +
1075 +
|
1076 |
def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
|
1077 |
|
1078 |
"""
|
|
|
1211 |
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
1212 |
])
|
1213 |
|
1214 + img_transform32 = transforms.Compose([
1215 +     transforms.ToTensor(),
1216 +     transforms.ConvertImageDtype(torch.float32),
1217 +     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1218 + ])
1219 +
1220 +
|
1221 |
|
1222 |
|
1223 |
|