MaxwellMeyer commited on
Commit
0004261
·
verified ·
1 Parent(s): d8358da

Update BEN2.py

Browse files
Files changed (1) hide show
  1. BEN2.py +24 -2
BEN2.py CHANGED
@@ -921,6 +921,8 @@ class BEN_Base(nn.Module):
921
  if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
922
  m.inplace = True
923
 
 
 
924
  @torch.inference_mode()
925
  @torch.autocast(device_type="cuda",dtype=torch.float16)
926
  def forward(self, x):
@@ -1008,7 +1010,13 @@ class BEN_Base(nn.Module):
1008
  # image = ImageOps.exif_transpose(image)
1009
  if isinstance(image, Image.Image):
1010
  image, h, w,original_image = rgb_loader_refiner(image)
1011
- img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
 
 
 
 
 
 
1012
  with torch.no_grad():
1013
  res = self.forward(img_tensor)
1014
 
@@ -1035,7 +1043,11 @@ class BEN_Base(nn.Module):
1035
  foregrounds = []
1036
  for batch in image:
1037
  image, h, w,original_image = rgb_loader_refiner(batch)
1038
- img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
 
 
 
 
1039
 
1040
  with torch.no_grad():
1041
  res = self.forward(img_tensor)
@@ -1058,6 +1070,9 @@ class BEN_Base(nn.Module):
1058
 
1059
  return foregrounds
1060
 
 
 
 
1061
  def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
1062
 
1063
  """
@@ -1196,6 +1211,13 @@ img_transform = transforms.Compose([
1196
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1197
  ])
1198
 
 
 
 
 
 
 
 
1199
 
1200
 
1201
 
 
921
  if isinstance(m, nn.GELU) or isinstance(m, nn.Dropout):
922
  m.inplace = True
923
 
924
+
925
+
926
  @torch.inference_mode()
927
  @torch.autocast(device_type="cuda",dtype=torch.float16)
928
  def forward(self, x):
 
1010
  # image = ImageOps.exif_transpose(image)
1011
  if isinstance(image, Image.Image):
1012
  image, h, w,original_image = rgb_loader_refiner(image)
1013
+ if torch.cuda.is_available():
1014
+
1015
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1016
+ else:
1017
+ img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
1018
+
1019
+
1020
  with torch.no_grad():
1021
  res = self.forward(img_tensor)
1022
 
 
1043
  foregrounds = []
1044
  for batch in image:
1045
  image, h, w,original_image = rgb_loader_refiner(batch)
1046
+ if torch.cuda.is_available():
1047
+
1048
+ img_tensor = img_transform(image).unsqueeze(0).to(next(self.parameters()).device)
1049
+ else:
1050
+ img_tensor = img_transform32(image).unsqueeze(0).to(next(self.parameters()).device)
1051
 
1052
  with torch.no_grad():
1053
  res = self.forward(img_tensor)
 
1070
 
1071
  return foregrounds
1072
 
1073
+
1074
+
1075
+
1076
  def segment_video(self, video_path, output_path="./", fps=0, refine_foreground=False, batch=1, print_frames_processed=True, webm = False, rgb_value= (0, 255, 0)):
1077
 
1078
  """
 
1211
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1212
  ])
1213
 
1214
+ img_transform32 = transforms.Compose([
1215
+ transforms.ToTensor(),
1216
+ transforms.ConvertImageDtype(torch.float32),
1217
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
1218
+ ])
1219
+
1220
+
1221
 
1222
 
1223