
tf2_make_image_classifier

The make_image_classifier Python library can be used to train various TensorFlow 2 image classification models that are available from TensorFlow Hub.

Code for the Docker images and additional Python code is available from here:

https://github.com/waikato-datamining/tensorflow/tree/master/image_classification2

Prerequisites

Make sure you have the directory structure created as outlined in the Prerequisites.

Data

In this example, we will use the 102 flowers dataset, which consists of 102 different categories (roughly, species) of flowers. More precisely, we will download a version of the dataset in which the flowers are already split into one sub-directory per category, and we will use only a subset of these categories to speed up the training process.

Download the dataset from the following URL into the data directory and extract it:

https://datasets.cms.waikato.ac.nz/ufdl/data/102flowers/102flowers-subdir.zip

Once extracted, you can delete all sub-directories apart from:

  • alpine_sea_holly
  • anthurium
  • artichoke

Rename the subdir directory to 3flowers and move it into the data folder of the directory structure outlined in the Prerequisites.
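
The pruning and renaming steps can also be scripted. The following is a minimal Python sketch, assuming the archive was extracted to data/subdir (inferred from the instructions above, not verified against the archive contents). The function name prepare_subset is hypothetical and not part of any library mentioned here.

```python
import shutil
from pathlib import Path

# Categories kept for this example (taken from the list above).
KEEP = {"alpine_sea_holly", "anthurium", "artichoke"}


def prepare_subset(extracted: Path, target: Path, keep=KEEP) -> list:
    """Delete every category sub-directory of `extracted` that is not in
    `keep`, then rename `extracted` to `target`.

    Returns the sorted names of the sub-directories that remain.
    """
    for child in extracted.iterdir():
        if child.is_dir() and child.name not in keep:
            shutil.rmtree(child)
    extracted.rename(target)
    return sorted(p.name for p in target.iterdir() if p.is_dir())
```

Run from the top of the directory structure, a call such as prepare_subset(Path("data/subdir"), Path("data/3flowers")) should leave only the three flower categories in data/3flowers. Note that the deletion is irreversible, so double-check the paths first.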

Training

For training, we will use the following Docker image:

waikatodatamining/tf_image_classification2:2.9.1_cuda11.1

If you only have a CPU machine available, then use this one instead:

waikatodatamining/tf_image_classification2:2.9.1_cpu

The training script is called make_image_classifier, for which we can invoke the help screen as follows:

docker run --rm -t waikatodatamining/tf_image_classification2:2.9.1_cuda11.1 make_image_classifier --helpfull   # GPU
docker run --rm -t waikatodatamining/tf_image_classification2:2.9.1_cpu make_image_classifier --helpfull        # CPU

It is good practice to create a separate sub-directory for each training run, with a directory name that hints at which dataset and model were used. For our first training run, which will mainly use default parameters, we will create the following directory in the output folder:

3flowers-tf2-default
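
Because the training containers run as the current user (-u $(id -u):$(id -g)), it can help to create the run directory and the cache directory before starting Docker: a bind-mount source directory that does not exist yet is generally created by Docker itself and owned by root, which the unprivileged container user may then be unable to write to. The following is a minimal Python sketch, assuming it is given the top of the directory structure as base; the helper name make_run_dirs is hypothetical.

```python
from pathlib import Path


def make_run_dirs(base: Path, run_name: str) -> Path:
    """Create <base>/output/<run_name> and <base>/cache if they are
    missing, and return the run directory."""
    run_dir = base / "output" / run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    # The cache directory is the one mounted at /tmp/tfhub_modules.
    (base / "cache").mkdir(exist_ok=True)
    return run_dir
```

For the run described here, this would be called as make_run_dirs(Path.cwd(), "3flowers-tf2-default").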

The following command will train an EfficientNetV2 B0 model for 10 epochs:

GPU:

docker run --rm \
  -u $(id -u):$(id -g) -e USER=$USER \
  --gpus=all \
  -v `pwd`:/workspace \
  -v `pwd`/cache:/tmp/tfhub_modules \
  -t waikatodatamining/tf_image_classification2:2.9.1_cuda11.1 \
  make_image_classifier \
  --image_dir /workspace/data/3flowers \
  --image_size 224 \
  --saved_model_dir /workspace/output/3flowers-tf2-default \
  --labels_output_file /workspace/output/3flowers-tf2-default/labels.txt \
  --tflite_output_file /workspace/output/3flowers-tf2-default/model.tflite \
  --tfhub_module https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b0/feature_vector/2 \
  --train_epochs 10

CPU:

docker run --rm \
  -u $(id -u):$(id -g) -e USER=$USER \
  -v `pwd`:/workspace \
  -v `pwd`/cache:/tmp/tfhub_modules \
  -t waikatodatamining/tf_image_classification2:2.9.1_cpu \
  make_image_classifier \
  --image_dir /workspace/data/3flowers \
  --image_size 224 \
  --saved_model_dir /workspace/output/3flowers-tf2-default \
  --labels_output_file /workspace/output/3flowers-tf2-default/labels.txt \
  --tflite_output_file /workspace/output/3flowers-tf2-default/model.tflite \
  --tfhub_module https://tfhub.dev/google/imagenet/efficientnet_v2_imagenet21k_ft1k_b0/feature_vector/2 \
  --train_epochs 10

Predicting

For making predictions for a single image, you can use the script label_image.

Since we want to batch predict multiple images, we will use the script predict_poll instead:

GPU:

docker run --rm \
  -u $(id -u):$(id -g) -e USER=$USER \
  -v `pwd`:/workspace \
  -t waikatodatamining/tf_image_classification2:2.9.1_cuda11.1 \
  predict_poll \
  --model /workspace/output/3flowers-tf2-default/model.tflite \
  --labels /workspace/output/3flowers-tf2-default/labels.txt \
  --input_mean 0 \
  --input_std 255 \
  --prediction_in /workspace/predictions/in \
  --prediction_out /workspace/predictions/out

CPU:

docker run --rm \
  -u $(id -u):$(id -g) -e USER=$USER \
  -v `pwd`:/workspace \
  -t waikatodatamining/tf_image_classification2:2.9.1_cpu \
  predict_poll \
  --model /workspace/output/3flowers-tf2-default/model.tflite \
  --labels /workspace/output/3flowers-tf2-default/labels.txt \
  --input_mean 0 \
  --input_std 255 \
  --prediction_in /workspace/predictions/in \
  --prediction_out /workspace/predictions/out
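
The --input_mean and --input_std options appear to follow the convention of the TensorFlow Lite label_image example, in which each pixel value is transformed as (value - mean) / std before it is fed to the model. This is an assumption about predict_poll, not something confirmed here. Under that assumption, a mean of 0 and a standard deviation of 255 map 8-bit pixel values from the range [0, 255] to [0, 1], as the following sketch illustrates (the function normalize is purely illustrative):

```python
def normalize(pixels, input_mean=0.0, input_std=255.0):
    """Scale raw pixel values as (value - mean) / std.

    Assumed to mirror what --input_mean/--input_std do; not taken
    from the predict_poll source code.
    """
    return [(p - input_mean) / input_std for p in pixels]


print(normalize([0, 51, 255]))  # [0.0, 0.2, 1.0]
```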

For example, for the image image_02048.jpg from the anthurium class, we will get a JSON file similar to this one:

{
  "anthurium": 0.8697358965873718,
  "alpine_sea_holly": 0.06686040759086609,
  "artichoke": 0.0634036660194397
}
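
To turn such a file into a single predicted class, the label with the highest score can be selected. The following is a minimal Python sketch, assuming the output is a flat JSON object that maps labels to scores, as in the example above. The helper name top_label is hypothetical, and reading the actual files from predictions/out is left out because their naming is not described here.

```python
import json

# Example output copied from above.
PREDICTION = """
{
  "anthurium": 0.8697358965873718,
  "alpine_sea_holly": 0.06686040759086609,
  "artichoke": 0.0634036660194397
}
"""


def top_label(json_text: str):
    """Return the (label, score) pair with the highest score."""
    scores = json.loads(json_text)
    label = max(scores, key=scores.get)
    return label, scores[label]


label, score = top_label(PREDICTION)
print(f"{label}: {score:.3f}")  # anthurium: 0.870
```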

Notes