
Animal pose estimation using Vision Transformers and HRNet (VHR)


Welcome to our module!

AJA-pose

AJA-pose helps you train, validate, and test an animal pose estimation model. See our Google Colab notebook for a worked example. Evaluated with PCK@0.05 on Protocol 1, the model reaches a mean accuracy of 93.073% across the 23 keypoints.
We recommend at least an NVIDIA V100 GPU for faster inference, although a T4 GPU will also work.


Getting Started

git clone https://github.com/Antony-gitau/AJA-pose.git
cd AJA-pose
pip install -e .

Getting our model and the dataset. Both are publicly hosted and can be downloaded as follows.

import urllib.request

# Get the dataset
url = "https://storage.googleapis.com/figures-gp/animal-kingdom/dataset.zip"
destination = "dataset.zip"

urllib.request.urlretrieve(url, destination)

# Unzip the file (the leading "!" runs a shell command in Colab/Jupyter notebooks)
!unzip dataset.zip

# The model file (note: saved locally under a different filename than the remote one)
url = "https://storage.googleapis.com/figures-gp/animal-kingdom/all_animals_no_pretrain_106.pth"
destination = "all_animals_no_pretrain_60.pth"

urllib.request.urlretrieve(url, destination)
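
Outside a notebook, the "!unzip" line above is not valid Python. The following is a minimal sketch of the same download-and-extract steps using only the standard library. The helper names fetch and extract are hypothetical and are not part of aja_pose.

```python
import urllib.request
import zipfile
from pathlib import Path


def fetch(url, destination):
    """Download url to destination, skipping the download if the file already exists."""
    destination = Path(destination)
    if not destination.exists():
        urllib.request.urlretrieve(url, str(destination))
    return destination


def extract(zip_path, out_dir="."):
    """Extract a zip archive into out_dir and return the names of its members."""
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(out_dir)
        return archive.namelist()


# Example usage with the URLs from this README (commented out to avoid large downloads):
# fetch("https://storage.googleapis.com/figures-gp/animal-kingdom/dataset.zip", "dataset.zip")
# extract("dataset.zip")
```

Skipping files that already exist avoids downloading the dataset again when a notebook cell is rerun.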

Test our model
Testing requires the path to the test images directory and the test.json annotation file in MPII format.

from aja_pose import Model

# Paths to the images directory and the annotations in MPII JSON format
images_directory = '' # Path to the images directory
mpii_json = '' # Path to the test.json file
model_file = 'all_animals_no_pretrain_60.pth' # Path to the model file

# Initialize the class
model = Model()
# Test the model on Protocol 1
model.test(images_directory, protocol='P1', model=model_file)
# Test the model on Protocol 2
model.test(images_directory, protocol='P2', model=model_file)
# Test the model on birds class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='bird')
# Test the model on reptiles class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='reptile')
# Test the model on mammals class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='mammal')
# Test the model on fish class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='fish')
# Test the model on amphibian class Protocol 3
model.test(images_directory, protocol='P3', model=model_file, animal_class='amphibian')
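
Before testing, it can help to confirm that test.json has the expected shape. The sketch below assumes the HRNet-style MPII layout: a JSON list of records, each with image, center, scale, joints, and joints_vis fields. Whether aja_pose requires exactly these fields is an assumption, and check_mpii_annotations is a hypothetical helper, not part of the package. The default of 23 keypoints comes from the description above.

```python
import json

# Assumed HRNet-style MPII fields; the exact schema used by aja_pose may differ.
REQUIRED_KEYS = {"image", "center", "scale", "joints", "joints_vis"}


def check_mpii_annotations(json_path, num_keypoints=23):
    """Return a list of human-readable problems found in an MPII-style JSON file."""
    with open(json_path) as f:
        records = json.load(f)

    problems = []
    for i, rec in enumerate(records):
        missing = REQUIRED_KEYS - set(rec)
        if missing:
            problems.append(f"record {i}: missing {sorted(missing)}")
            continue
        # Each record should carry one coordinate pair and one visibility flag per keypoint.
        if len(rec["joints"]) != num_keypoints or len(rec["joints_vis"]) != num_keypoints:
            problems.append(f"record {i}: expected {num_keypoints} keypoints")
    return problems
```

An empty list means every record has the assumed fields and keypoint count.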

You can also train your own model from scratch or continue training from our pretrained weights.

# train a VHR model
train_json = '' # Labels for the train set (train.json)
valid_json = '' # Labels for the validation set (test.json)
model_file = '' # A PyTorch model file to start training from
model.train(images_directory, train_json, valid_json, pretrained=model_file)

# Train a model on a particular class, e.g. amphibian
model.train(images_directory, protocol='P3', animal_class='amphibian', model=model_file)

Results

A sanity check on our model.

[Figure: ground-truth keypoints alongside the model's predictions]

Performance

The performance of our model (PCK@0.05, in percent) on the different animal classes is shown below.

Animal Class  Samples  Head    Shoulder  Elbow   Wrist   Hip     Knee    Ankle   Mouth   Tail    Mean
Birds         1705     95.756  93.637    89.774  88.179  98.975  97.582  94.326  98.447  95.112  95.164
Reptiles      1209     91.538  85.291    84.662  85.587  90.457  88.097  85.239  96.723  83.925  89.553
Mammals       1496     90.641  89.269    88.509  89.927  90.263  88.655  89.535  93.622  82.161  90.038
Fish          918      96.468  96.249    98.643  96.058  98.403  96.743  95.775  97.564  98.256  96.467
Amphibian     1279     98.128  94.342    97.948  98.508  95.491  94.957  94.319  98.702  99.568  95.493

The model's performance on Protocol 1 and Protocol 2 is shown below.

Protocol  Samples  Head    Shoulder  Elbow   Wrist   Hip     Knee    Ankle   Mouth   Tail    Mean
P1        6620     94.230  91.054    90.806  90.920  94.414  93.233  92.094  96.867  92.346  93.073
P2        2883     88.683  75.815    80.223  81.136  85.568  83.840  82.028  94.799  72.506  83.711
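
The scores above are PCK@0.05 values. As a rough illustration of the metric, the sketch below counts a keypoint as correct when its prediction lies within 0.05 times a per-sample reference size of the ground truth. The choice of reference size (for example, the longer side of the animal's bounding box) and the function name pck are assumptions made for illustration; this is not necessarily how aja_pose computes its scores.

```python
import numpy as np


def pck(pred, gt, visible, ref_size, threshold=0.05):
    """Fraction of visible keypoints predicted within threshold * ref_size of ground truth.

    pred, gt : arrays of shape (N, K, 2) with x, y coordinates
    visible  : boolean array of shape (N, K); only visible keypoints are scored
    ref_size : array of shape (N,), the assumed per-sample normalization length
    """
    pred = np.asarray(pred, dtype=float)
    gt = np.asarray(gt, dtype=float)
    visible = np.asarray(visible, dtype=bool)
    ref_size = np.asarray(ref_size, dtype=float)

    dist = np.linalg.norm(pred - gt, axis=-1)        # Euclidean error per keypoint, shape (N, K)
    correct = dist <= threshold * ref_size[:, None]  # compare against each sample's tolerance
    if not visible.any():
        return float("nan")
    return float(correct[visible].mean())
```

Multiplying the result by 100 gives a percentage comparable in form to the table entries.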
