Model Overview
The SegFormer model was proposed in SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers by Enze Xie, Wenhai Wang, Zhiding Yu, Anima Anandkumar, Jose M. Alvarez, and Ping Luo. It pairs a hierarchical Transformer encoder, which produces multiscale features, with a lightweight all-MLP decode head, and reports strong results on semantic segmentation benchmarks such as ADE20K and Cityscapes.
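To make the "all-MLP decode head" concrete, here is a minimal NumPy sketch following the paper's description. Each of the four encoder stages (strides 4, 8, 16, 32) is projected to a common channel width by a per-pixel linear layer. The results are upsampled to 1/4 of the input resolution, concatenated, fused by another linear layer, and mapped to per-pixel class scores. This is an illustration, not the KerasCV implementation. The function names, the random weights, the channel widths, nearest-neighbour upsampling (the paper uses bilinear), and the ReLU after fusion are all simplifying assumptions.

```python
import numpy as np


def upsample_nearest(x, factor):
    """Nearest-neighbour upsampling of an (H, W, C) array by an integer factor."""
    return x.repeat(factor, axis=0).repeat(factor, axis=1)


def all_mlp_decode_head(features, proj_weights, fuse_weight, cls_weight):
    """Hypothetical, simplified all-MLP decode head.

    features: four (H_i, W_i, C_i) arrays from encoder stages at strides 4, 8, 16, 32.
    proj_weights: four (C_i, D) matrices that unify the channel width to D.
    fuse_weight: (4 * D, D) matrix that fuses the concatenated features.
    cls_weight: (D, num_classes) matrix that produces per-pixel class scores.
    """
    target_h = features[0].shape[0]  # stride-4 resolution
    unified = []
    for feat, w in zip(features, proj_weights):
        x = feat @ w  # per-pixel linear layer: (H_i, W_i, C_i) -> (H_i, W_i, D)
        factor = target_h // feat.shape[0]
        unified.append(upsample_nearest(x, factor))
    fused = np.concatenate(unified, axis=-1) @ fuse_weight  # (H/4, W/4, D)
    fused = np.maximum(fused, 0.0)  # ReLU (assumed activation)
    return fused @ cls_weight  # (H/4, W/4, num_classes)


rng = np.random.default_rng(0)
channels = [32, 64, 160, 256]  # illustrative per-stage channel widths
sizes = [56, 28, 14, 7]        # a 224x224 input at strides 4, 8, 16, 32
D, num_classes = 64, 2
features = [rng.standard_normal((s, s, c)) for s, c in zip(sizes, channels)]
proj = [rng.standard_normal((c, D)) for c in channels]
fuse = rng.standard_normal((4 * D, D))
cls_w = rng.standard_normal((D, num_classes))
logits = all_mlp_decode_head(features, proj, fuse, cls_w)
print(logits.shape)  # (56, 56, 2)
```

Because every decoder layer acts on each pixel independently, the head stays small. The spatial context comes from the multiscale encoder features rather than from heavy convolutional decoding.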
Weights are released under the MIT License. Keras model code is released under the Apache 2.0 License.
Installation
Keras and KerasCV can be installed with:
pip install -U -q keras-cv
pip install -U -q "keras>=3"
JAX, TensorFlow, and PyTorch come preinstalled in Kaggle Notebooks. For instructions on installing them in another environment, see the Keras Getting Started page.
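Keras 3 runs on JAX, TensorFlow, or PyTorch and selects the backend from the `KERAS_BACKEND` environment variable. That variable is read when `keras` is first imported, so it has to be set beforehand. Here is a minimal sketch, assuming JAX is the backend you want:

```python
import os

# Keras 3 reads KERAS_BACKEND at first import, so set it before importing keras.
# Valid values are "jax", "tensorflow" and "torch".
os.environ["KERAS_BACKEND"] = "jax"

# Afterwards, import the libraries as usual:
# import keras
# import keras_cv
```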
Presets
The following model checkpoints are provided by the Keras team. Code examples are available below.
Preset name | Parameters | Description |
---|---|---|
segformer_b0_imagenet | 3.72M | SegFormer model with a pretrained MiTB0 backbone. |
segformer_b0 | 3.72M | SegFormer model with MiTB0 backbone. |
segformer_b1 | 13.68M | SegFormer model with MiTB1 backbone. |
segformer_b2 | 24.73M | SegFormer model with MiTB2 backbone. |
segformer_b3 | 44.60M | SegFormer model with MiTB3 backbone. |
segformer_b4 | 61.37M | SegFormer model with MiTB4 backbone. |
segformer_b5 | 81.97M | SegFormer model with MiTB5 backbone. |
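If you are choosing a preset under a size constraint, the table's parameter counts can drive the choice directly. This is a small sketch with a hypothetical helper, `largest_preset_within`, that returns the largest preset fitting a budget given in millions of parameters. It only considers the un-suffixed presets, because `segformer_b0_imagenet` has the same size as `segformer_b0`.

```python
# Parameter counts (millions) copied from the preset table above.
PRESET_PARAMS_M = {
    "segformer_b0": 3.72,
    "segformer_b1": 13.68,
    "segformer_b2": 24.73,
    "segformer_b3": 44.60,
    "segformer_b4": 61.37,
    "segformer_b5": 81.97,
}


def largest_preset_within(budget_m):
    """Hypothetical helper: largest preset with at most budget_m million params, else None."""
    fitting = {name: p for name, p in PRESET_PARAMS_M.items() if p <= budget_m}
    if not fitting:
        return None
    return max(fitting, key=fitting.get)


print(largest_preset_within(50))  # segformer_b3
```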
Example code
import numpy as np
import keras_cv

images = np.ones(shape=(1, 224, 224, 3))
model = keras_cv.models.SegFormer.from_preset(
    "segformer_b0", num_classes=2
)
# Run a forward pass (inference)
outputs = model(images)
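The forward pass returns per-pixel class scores. This sketch assumes the output has shape `(batch, height, width, num_classes)`, which matches the labels in these examples having the input's spatial size. Under that assumption, a segmentation mask is the argmax over the last axis. A stand-in NumPy array replaces the real model output so the snippet runs without downloading weights. With a real model, you could first convert the output with `keras.ops.convert_to_numpy`.

```python
import numpy as np


def scores_to_mask(scores):
    """Convert (batch, H, W, num_classes) scores into an integer (batch, H, W) mask."""
    return np.argmax(scores, axis=-1)


# Stand-in for keras.ops.convert_to_numpy(model(images)): 1 image, 2x2 pixels, 2 classes.
scores = np.array([[[[0.9, 0.1], [0.2, 0.8]],
                    [[0.4, 0.6], [0.7, 0.3]]]])
mask = scores_to_mask(scores)
print(mask.tolist())  # [[[0, 1], [1, 0]]]
```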
Example Usage
Using the class with a backbone:
import keras
import keras_cv
import numpy as np
images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_cv.models.MiTBackbone.from_preset("segformer_b3_cityscapes_1024")
model = keras_cv.models.segmentation.SegFormer(
num_classes=1, backbone=backbone,
)
# Run a forward pass (inference)
model(images)
# Train model
model.compile(
optimizer="adam",
loss=keras.losses.BinaryCrossentropy(from_logits=False),
metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)
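With `from_logits=False`, `BinaryCrossentropy` treats the model's outputs as probabilities in [0, 1] rather than raw logits. The loss averages −[y·log(p) + (1 − y)·log(1 − p)] over all pixels. Here is a NumPy sketch of that formula, not Keras's own implementation. The clipping epsilon of 1e-7, used to keep `log()` finite, is an assumption of the sketch.

```python
import numpy as np


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy for probabilities y_pred and 0/1 targets y_true."""
    p = np.clip(y_pred, eps, 1.0 - eps)  # avoid log(0)
    return float(np.mean(-(y_true * np.log(p) + (1.0 - y_true) * np.log(1.0 - p))))


labels = np.zeros((1, 96, 96, 1))     # all-background mask, as in the example
preds = np.full((1, 96, 96, 1), 0.5)  # an uninformative prediction
print(round(binary_crossentropy(labels, preds), 4))  # 0.6931 (= ln 2)
```

An uninformative 0.5 prediction costs ln 2 per pixel, a useful baseline for reading the training loss.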
Example Usage with Hugging Face URI
Using the class with a backbone:
import keras
import keras_cv
import numpy as np
images = np.ones(shape=(1, 96, 96, 3))
labels = np.zeros(shape=(1, 96, 96, 1))
backbone = keras_cv.models.MiTBackbone.from_preset("hf://keras/segformer_b3_cityscapes_1024")
model = keras_cv.models.segmentation.SegFormer(
num_classes=1, backbone=backbone,
)
# Run a forward pass (inference)
model(images)
# Train model
model.compile(
optimizer="adam",
loss=keras.losses.BinaryCrossentropy(from_logits=False),
metrics=["accuracy"],
)
model.fit(images, labels, epochs=3)