shaktibiplab commited on
Commit
538d369
·
verified ·
1 Parent(s): 2e42840

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +62 -27
app.py CHANGED
@@ -1,31 +1,66 @@
1
  import gradio as gr
2
- from transformers import AutoModelForImageClassification, AutoFeatureExtractor
3
- from PIL import Image
 
 
4
 
5
- # Load the model and feature extractor
6
- MODEL_NAME = "shaktibiplab/Animal-Classification"
7
- model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
8
- extractor = AutoFeatureExtractor.from_pretrained(MODEL_NAME)
 
 
 
 
 
 
 
9
 
10
- # Function to classify uploaded image
11
  def classify_image(image):
12
- image = Image.open(image).convert("RGB")
13
- inputs = extractor(images=image, return_tensors="pt")
14
- outputs = model(**inputs)
15
- logits = outputs.logits
16
- predicted_class_idx = logits.argmax(-1).item()
17
- return model.config.id2label[predicted_class_idx]
18
-
19
- # Create Gradio Interface
20
- interface = gr.Interface(
21
- fn=classify_image,
22
- inputs=gr.Image(type="file"),
23
- outputs="text",
24
-
25
- title="Animal Classifier",
26
- description="Upload an image of an animal, and the model will classify it into one of the trained categories."
27
- )
28
-
29
- # Launch the application
30
- if __name__ == "__main__":
31
- interface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow.keras.models import load_model
5
+ from tensorflow.keras.preprocessing.image import load_img, img_to_array
6
 
7
+ # Load the trained model
8
+ MODEL_PATH = "best_model.weights.h5"
9
+ model = load_model(MODEL_PATH)
10
+
11
+ # Define the class names
12
+ class_names = [
13
+ "Bear", "Bird", "Cat", "Cow", "Deer",
14
+ "Dog", "Dolphin", "Elephant", "Giraffe",
15
+ "Horse", "Kangaroo", "Lion", "Panda",
16
+ "Tiger", "Zebra"
17
+ ]
18
 
 
19
  def classify_image(image):
20
+ img = image.resize((256, 256))
21
+ img_array = img_to_array(img) / 255.0
22
+ img_array = np.expand_dims(img_array, axis=0)
23
+ predictions = model.predict(img_array)
24
+ predicted_class = class_names[np.argmax(predictions)]
25
+ return f"Predicted Class: {predicted_class}"
26
+
27
+ def instruction():
28
+ return (
29
+ "**Important Note:**\n\n"
30
+ "This model is specifically trained to classify images into the following **15 animal categories**:\n\n"
31
+ "- Bear\n"
32
+ "- Bird\n"
33
+ "- Cat\n"
34
+ "- Cow\n"
35
+ "- Deer\n"
36
+ "- Dog\n"
37
+ "- Dolphin\n"
38
+ "- Elephant\n"
39
+ "- Giraffe\n"
40
+ "- Horse\n"
41
+ "- Kangaroo\n"
42
+ "- Lion\n"
43
+ "- Panda\n"
44
+ "- Tiger\n"
45
+ "- Zebra\n\n"
46
+ "**Usage Limitation:**\n\n"
47
+ "- The model will only recognize images containing these animals.\n"
48
+ "- Uploading an image of an animal not listed above or a non-animal image may result in inaccurate or undefined predictions.\n\n"
49
+ "Ensure the uploaded image is clear, contains a single animal, and resembles the categories listed for the best results."
50
+ )
51
+
52
+ # Gradio Interface
53
+ with gr.Blocks() as app:
54
+ gr.Markdown("# Animal Classifier")
55
+ gr.Markdown(instruction())
56
+
57
+ with gr.Row():
58
+ with gr.Column():
59
+ image_input = gr.Image(label="Upload an Image", type="pil")
60
+ predict_button = gr.Button("Classify Image")
61
+ with gr.Column():
62
+ result_output = gr.Textbox(label="Prediction Result", lines=3)
63
+
64
+ predict_button.click(classify_image, inputs=image_input, outputs=result_output)
65
+
66
+ app.launch()