Trash Image Classification using Vision Transformer (ViT)

This repository contains an implementation of an image classification model using a pre-trained Vision Transformer (ViT) model from Hugging Face. The model is fine-tuned to classify images into six categories: cardboard, glass, metal, paper, plastic, and trash.

Dataset

The dataset, garythung/trashnet, contains 5,054 images across six categories with the following distribution:

  • Cardboard: 806 images
  • Glass: 1002 images
  • Metal: 820 images
  • Paper: 1188 images
  • Plastic: 964 images
  • Trash: 274 images
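The classes are noticeably imbalanced: Trash has fewer than a quarter as many images as Paper. The short sketch below recomputes each class's share from the counts listed above. It is illustrative arithmetic only, not part of the training code.

```python
# Per-class image counts, copied from the distribution above
counts = {
    "cardboard": 806,
    "glass": 1002,
    "metal": 820,
    "paper": 1188,
    "plastic": 964,
    "trash": 274,
}

total = sum(counts.values())
shares = {label: n / total for label, n in counts.items()}

# Ratio between the largest and the smallest class
imbalance_ratio = max(counts.values()) / min(counts.values())

print(f"Total images: {total}")
for label, share in shares.items():
    print(f"{label:>9}: {share:6.2%}")
print(f"Largest/smallest class ratio: {imbalance_ratio:.2f}")
```

Only about 5.4% of the images are in the Trash class. Overall accuracy can therefore look high even when that class is often misclassified, so per-class metrics are worth checking as well.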

Model

We fine-tune the pre-trained Vision Transformer google/vit-base-patch16-224-in21k from Hugging Face on this dataset for image classification. Results after three epochs of fine-tuning are reported under Results.
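The checkpoint name encodes its input geometry: 224×224 images are split into 16×16 patches. Below is a minimal NumPy sketch of that patch split. It assumes a 3-channel image that has already been resized to 224×224. The real model also projects each patch linearly and adds a [CLS] token and position embeddings; those steps are omitted here.

```python
import numpy as np


def patchify(image: np.ndarray, patch_size: int = 16) -> np.ndarray:
    """Split an (H, W, C) image into flattened, non-overlapping patches.

    Returns an array of shape (num_patches, patch_size * patch_size * C),
    with patches ordered row by row, left to right.
    """
    h, w, c = image.shape
    if h % patch_size or w % patch_size:
        raise ValueError("image height and width must be multiples of patch_size")
    grid_h, grid_w = h // patch_size, w // patch_size
    patches = image.reshape(grid_h, patch_size, grid_w, patch_size, c)
    # Bring the two grid axes together, then flatten each patch
    patches = patches.transpose(0, 2, 1, 3, 4)
    return patches.reshape(grid_h * grid_w, patch_size * patch_size * c)


# A dummy 224x224 RGB image standing in for a preprocessed photo
image = np.random.rand(224, 224, 3).astype(np.float32)
patches = patchify(image)
print(patches.shape)  # (196, 768)
```

This gives (224 / 16)² = 196 patches of 16 · 16 · 3 = 768 values each. With the extra [CLS] token, the encoder processes a sequence of 197 tokens.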

The trained model is accessible on Hugging Face Hub at: tribber93/my-trash-classification
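To fine-tune the base checkpoint, its classification head must be sized for the six TrashNet classes. The sketch below shows one common way to do this: pass `num_labels`, `id2label`, and `label2id` to `AutoModelForImageClassification.from_pretrained`. The alphabetical label order and the helper names `build_label_maps` and `load_base_model_for_finetuning` are assumptions for illustration. The published model's actual mapping is whatever `model.config.id2label` reports.

```python
# Assumed label order (alphabetical); check model.config.id2label on the
# published checkpoint for the mapping actually used in training.
LABELS = ["cardboard", "glass", "metal", "paper", "plastic", "trash"]


def build_label_maps(labels):
    """Hypothetical helper: return (id2label, label2id) dictionaries."""
    id2label = {i: label for i, label in enumerate(labels)}
    label2id = {label: i for i, label in id2label.items()}
    return id2label, label2id


def load_base_model_for_finetuning(labels=LABELS):
    """Hypothetical helper: load the ViT backbone with a new 6-way classification head."""
    # Imported here so the label helpers above work without transformers installed
    from transformers import AutoModelForImageClassification

    id2label, label2id = build_label_maps(labels)
    return AutoModelForImageClassification.from_pretrained(
        "google/vit-base-patch16-224-in21k",
        num_labels=len(labels),
        id2label=id2label,
        label2id=label2id,
    )


id2label, label2id = build_label_maps(LABELS)
print(id2label[5], label2id["glass"])  # trash 1
```

Storing the mapping in the config is what lets inference code look up a readable name through `model.config.id2label`.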

Usage

To use the model for inference, follow these steps:

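import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, AutoModelForImageClassification

# Fetch an example image from the web and make sure it has 3 (RGB) channels
url = 'https://cdn.grid.id/crop/0x0:0x0/700x465/photo/grid/original/127308_kaleng-bekas.jpg'
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw).convert("RGB")

# Load the fine-tuned model and its matching preprocessor from the Hub
model_name = "tribber93/my-trash-classification"
model = AutoModelForImageClassification.from_pretrained(model_name)
processor = AutoImageProcessor.from_pretrained(model_name)

# Resize and normalize the image into a batch of one "pixel_values" tensor
inputs = processor(images=image, return_tensors="pt")

# Forward pass without gradient tracking (inference only)
with torch.no_grad():
    outputs = model(**inputs)

# The index of the highest logit is the predicted class
predicted_id = outputs.logits.argmax(dim=-1).item()
print("Predicted class:", model.config.id2label[predicted_id])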
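The snippet above prints only the single most likely class. To see how confident the model is, apply a softmax to the logits and list the top few classes. This is a minimal sketch: `top_k_predictions` is a hypothetical helper, and it assumes `logits` has shape (1, num_labels), as produced for a single image.

```python
import torch


def top_k_predictions(logits: torch.Tensor, id2label: dict, k: int = 3):
    """Return the k most likely (label, probability) pairs for one image."""
    probs = torch.softmax(logits, dim=-1)[0]  # drop the batch dimension
    k = min(k, probs.numel())
    values, indices = torch.topk(probs, k)
    return [(id2label[i], p) for p, i in zip(values.tolist(), indices.tolist())]


# Illustrative label mapping and made-up logits (not real model output)
id2label = {0: "cardboard", 1: "glass", 2: "metal", 3: "paper", 4: "plastic", 5: "trash"}
logits = torch.tensor([[0.1, 2.0, 0.3, 0.0, -1.0, 0.5]])
for label, prob in top_k_predictions(logits, id2label, k=3):
    print(f"{label}: {prob:.3f}")
```

With the real model, call `top_k_predictions(outputs.logits, model.config.id2label)`.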

Results

After training, the model achieved the following performance:

| Epoch | Training Loss | Validation Loss | Accuracy |
|-------|---------------|-----------------|----------|
| 1     | 3.3200        | 0.7011          | 86.25%   |
| 2     | 1.6611        | 0.4298          | 91.49%   |
| 3     | 1.4353        | 0.3563          | 94.26%   |
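The card does not say how the Accuracy column was computed. A common setup with the Hugging Face `Trainer` is a `compute_metrics` callback that compares the arg-max prediction with the true label on the validation set. The sketch below assumes that setup; it is not taken from the author's training script.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Accuracy callback in the shape the Trainer expects (assumed setup).

    `eval_pred` unpacks into (logits, labels) with shapes
    (num_examples, num_labels) and (num_examples,).
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}


# Toy check: 3 of 4 predictions are correct
logits = np.array([
    [2.0, 0.1, 0.0, 0.0, 0.0, 0.0],  # predicts 0
    [0.0, 1.5, 0.2, 0.0, 0.0, 0.0],  # predicts 1
    [0.0, 0.0, 0.0, 0.0, 0.0, 3.0],  # predicts 5
    [0.0, 0.0, 0.9, 0.1, 0.0, 0.0],  # predicts 2
])
labels = np.array([0, 1, 5, 3])
print(compute_metrics((logits, labels)))  # {'accuracy': 0.75}
```

In such a setup, the function would be passed as `Trainer(..., compute_metrics=compute_metrics)`.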
Model size: 85.8M parameters, stored as F32 tensors in Safetensors format.
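At 4 bytes per F32 value, 85.8M parameters take roughly 343 MB (about 327 MiB) for the weights alone. The sketch below counts the parameters of any PyTorch module and estimates weight memory from each tensor's element size. It is a generic helper, not a tool shipped with this model.

```python
import torch


def parameter_stats(model: torch.nn.Module):
    """Return (parameter count, weight size in bytes) for a PyTorch module."""
    num_params = sum(p.numel() for p in model.parameters())
    num_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return num_params, num_bytes


# Small stand-in module: a 768 -> 6 linear layer, shaped like a 6-class head on ViT-Base features
head = torch.nn.Linear(768, 6)
params, size = parameter_stats(head)
print(params, size)  # 4614 18456

# Applied to the downloaded model, e.g. parameter_stats(model), this should
# report a count close to the 85.8M shown above.
```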