shaktibiplab commited on
Commit
06eb4c6
·
verified ·
1 Parent(s): a07f3ca

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -31
app.py CHANGED
@@ -1,44 +1,30 @@
1
- import numpy as np
2
- from PIL import Image
3
- from tkinter import Tk, filedialog
4
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
 
5
 
6
- # Load the model and feature extractor from Hugging Face
7
  MODEL_NAME = "shaktibiplab/Animal-Classification"
8
  model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
9
  extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
10
 
11
- # Function to load and preprocess the image
12
- def load_and_preprocess_image(image_path):
13
- image = Image.open(image_path).convert("RGB")
14
- return image
15
-
16
- # Function to classify the image
17
- def classify_image(image_path):
18
- image = load_and_preprocess_image(image_path)
19
  inputs = extractor(images=image, return_tensors="pt")
20
  outputs = model(**inputs)
21
  logits = outputs.logits
22
  predicted_class_idx = logits.argmax(-1).item()
23
  return model.config.id2label[predicted_class_idx]
24
 
25
- # Main program with file upload dialog
 
 
 
 
 
 
 
 
 
26
  if __name__ == "__main__":
27
- root = Tk()
28
- root.withdraw() # Hide the main tkinter window
29
- print("Please select an image file.")
30
-
31
- # Open a file dialog to select the image
32
- image_path = filedialog.askopenfilename(
33
- title="Select an Image",
34
- filetypes=[("Image Files", "*.jpg *.jpeg *.png *.bmp *.tiff")]
35
- )
36
-
37
- if image_path:
38
- try:
39
- predicted_class = classify_image(image_path)
40
- print(f"Predicted Class: {predicted_class}")
41
- except Exception as e:
42
- print(f"Error: {e}")
43
- else:
44
- print("No file selected.")
 
1
+ import gradio as gr
 
 
2
  from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
+ from PIL import Image
4
 
5
+ # Load the model and feature extractor
6
  MODEL_NAME = "shaktibiplab/Animal-Classification"
7
  model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
8
  extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
9
 
10
+ # Function to classify uploaded image
11
+ def classify_image(image):
12
+ image = Image.open(image).convert("RGB")
 
 
 
 
 
13
  inputs = extractor(images=image, return_tensors="pt")
14
  outputs = model(**inputs)
15
  logits = outputs.logits
16
  predicted_class_idx = logits.argmax(-1).item()
17
  return model.config.id2label[predicted_class_idx]
18
 
19
+ # Create Gradio Interface
20
+ interface = gr.Interface(
21
+ fn=classify_image,
22
+ inputs=gr.Image(type="file"),
23
+ outputs="text",
24
+ title="Animal Classifier",
25
+ description="Upload an image of an animal, and the model will classify it into one of the trained categories."
26
+ )
27
+
28
+ # Launch the application
29
  if __name__ == "__main__":
30
+ interface.launch()