TensorFlow 2 is coming.

If you are used to coding in TensorFlow 1.x, things are about to change. Coding in TensorFlow 2.0 is imperative, free from sessions, and built around Keras as the official high-level API. This tutorial explains the basics of image classification with TensorFlow 2. We'll cover:

  • Data Pipeline
  • Model Pipeline
  • Multiple-GPU
  • Callbacks

You can find the code of this tutorial at this Repo.

Data Pipeline

Machine learning solutions typically start with a data pipeline which consists of three main steps:

  • Load data from storage
  • Provide an interface for feeding data into the training pipeline
  • Perform miscellaneous tasks such as preprocessing, shuffling and batching

Load Data

For image classification, it is common to read the images and labels into data arrays (numpy ndarrays). The 0th dimension of these arrays is equal to the total number of samples. Customized data usually needs a customized loading function. In this tutorial, we leverage Keras's load_data function to read the popular CIFAR10 dataset:

(x,y), (x_test, y_test) = keras.datasets.cifar10.load_data()

We can verify the type and shape of these data arrays:

print(type(x), type(y))
print(x.shape, y.shape)

(<type 'numpy.ndarray'>, <type 'numpy.ndarray'>)
((50000, 32, 32, 3), (50000, 1))


Although it is possible to directly feed numpy ndarrays to the training loop, doing so makes it difficult to incorporate data augmentation, which is randomized on the fly. What is needed here is an interface that can handle both the data and the preprocessing steps applied to the data.

TensorFlow provides a very sophisticated Dataset API for this purpose. A TensorFlow Dataset essentially provides two things:

  • A collection of elements (nested structures of tensors)
  • A "logical plan" of transformations that act on those elements, where we can apply the necessary preprocessing jobs.
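
One way to picture the "logical plan" is as a chain of lazy transformations that is re-executed every time the dataset is iterated. The sketch below is a plain-Python analogy built on generators, not how tf.data is actually implemented. The LazyDataset class and the random_jitter function are hypothetical names that only mirror the Dataset vocabulary.

```python
import random


class LazyDataset:
    """Toy stand-in for a dataset: source elements plus a recorded list of transformations."""

    def __init__(self, elements, transforms=()):
        self._elements = list(elements)
        self._transforms = tuple(transforms)

    def map(self, fn):
        # Nothing is computed here; the function is only appended to the plan.
        return LazyDataset(self._elements, self._transforms + (fn,))

    def __iter__(self):
        # The plan runs on every pass, so random transformations differ between epochs.
        for element in self._elements:
            for fn in self._transforms:
                element = fn(element)
            yield element


def random_jitter(pair):
    """Hypothetical 'augmentation': add small random noise to the feature, keep the label."""
    x, y = pair
    return x + random.uniform(-0.5, 0.5), y


dataset = LazyDataset([(0.0, 0), (10.0, 1), (20.0, 0)]).map(random_jitter)

random.seed(0)
epoch_1 = list(dataset)
epoch_2 = list(dataset)
print(epoch_1)
print(epoch_2)
```

The two printed epochs contain the same labels but different feature values, which is the behavior we want from on-the-fly augmentation.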

A TensorFlow dataset can be created directly from the data arrays. We can use take(1) to fetch the first element of the dataset, which is a tuple that contains the image tensor and the label tensor:

train_dataset = tf.data.Dataset.from_tensor_slices((x,y))
for image, label in train_dataset.take(1):
    print(image.shape, label.shape)

(TensorShape([32, 32, 3]), TensorShape([1]))

These are the first 20 images in the dataset:



Thanks to TensorFlow Dataset's ability to handle transformations, we can now add the preprocessing jobs.

Let's first add data augmentation: we pad four black pixels onto each border of the image, then randomly crop a 32x32 region from the padded image, and finally apply a random horizontal flip. TensorFlow Dataset uses the map function to apply the augmentation to each element.

def augmentation(x, y):
    x = tf.image.resize_with_crop_or_pad(
        x, HEIGHT + 8, WIDTH + 8)
    x = tf.image.random_crop(x, [HEIGHT, WIDTH, NUM_CHANNELS])
    x = tf.image.random_flip_left_right(x)
    return x, y

train_dataset = train_dataset.map(augmentation)
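
To make the three augmentation steps concrete, here is a rough NumPy equivalent applied to one synthetic image. It is only a sketch of the pad, crop and flip logic described above, not the TensorFlow implementation. The augment_numpy function and the PAD constant are names introduced for illustration.

```python
import numpy as np

HEIGHT, WIDTH, NUM_CHANNELS = 32, 32, 3
PAD = 4  # four black pixels on every side, as described above


def augment_numpy(image, rng):
    """Pad with zeros, take a random HEIGHT x WIDTH crop, then flip left-right half the time."""
    padded = np.pad(image, ((PAD, PAD), (PAD, PAD), (0, 0)), mode="constant")
    top = rng.integers(0, 2 * PAD + 1)   # valid offsets are 0..8 inclusive
    left = rng.integers(0, 2 * PAD + 1)
    crop = padded[top:top + HEIGHT, left:left + WIDTH, :]
    if rng.random() < 0.5:
        crop = crop[:, ::-1, :]  # reverse the width axis
    return crop


rng = np.random.default_rng(0)
image = np.full((HEIGHT, WIDTH, NUM_CHANNELS), 255, dtype=np.uint8)  # synthetic all-white image
augmented = augment_numpy(image, rng)
print(augmented.shape, augmented.dtype)
```

The output keeps the original 32x32x3 shape, and any black pixels in it come from the padding that slid into view.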

These are the first 20 images after the augmentation:


Data augmentation is frequently used to "inflate" the training data and improve the generalization performance. Data augmentation should only be applied to the training set because the randomized nature of the data augmentation will make the inference (and thus your validation score) non-deterministic.

Next, we randomly shuffle the dataset. TensorFlow Dataset has a shuffle method, so all we need to do is append it to the Dataset object:

train_dataset = train_dataset.map(augmentation).shuffle(50000)

Note that for perfect shuffling, the buffer size should be greater than or equal to the full size of the dataset (50000 in this case). Below are 20 images from the dataset after shuffling. They are not the same images as the first 20 stored in the original dataset:
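
To see why the buffer size matters, the following pure-Python sketch imitates a fixed-size shuffle buffer: it fills the buffer, then repeatedly emits a random buffered element and replaces it with the next incoming one. This is a simplified model of the documented behavior under our own assumptions, not TensorFlow's code, and buffered_shuffle is a hypothetical helper.

```python
import random


def buffered_shuffle(elements, buffer_size, seed=0):
    """Yield elements in an order that a shuffle buffer of the given size could produce."""
    rng = random.Random(seed)
    buffer = []
    for element in elements:
        if len(buffer) < buffer_size:
            buffer.append(element)  # still filling the buffer
            continue
        index = rng.randrange(buffer_size)
        yield buffer[index]         # emit a random buffered element
        buffer[index] = element     # and replace it with the incoming one
    rng.shuffle(buffer)             # source exhausted: drain the rest in random order
    yield from buffer


data = list(range(20))
small = list(buffered_shuffle(data, buffer_size=3))
full = list(buffered_shuffle(data, buffer_size=len(data)))
print(small)
print(full)
```

With a buffer of 3, the element emitted at position i can only come from the first i + 3 inputs, so the output stays close to the original order. Only a buffer as large as the dataset lets any element land anywhere.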


It is also common practice to normalize the data, for example, by linearly scaling the image to have zero mean and unit variance. This can be achieved by mapping a customized normalize function across the dataset.

def normalize(x, y):
    x = tf.image.per_image_standardization(x)
    return x, y
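
For intuition, per-image standardization can be sketched in NumPy as below. We assume the formula documented for tf.image.per_image_standardization: subtract the image mean and divide by the standard deviation, with a lower bound of 1/sqrt(N) on the divisor, where N is the number of values in the image. The function name per_image_standardization_numpy is our own.

```python
import numpy as np


def per_image_standardization_numpy(image):
    """Scale one image to zero mean and unit variance, guarding against division by zero."""
    image = image.astype(np.float32)
    mean = image.mean()
    # The 1/sqrt(N) floor keeps a uniform image (stddev 0) from dividing by zero.
    adjusted_stddev = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - mean) / adjusted_stddev


rng = np.random.default_rng(0)
image = rng.integers(0, 256, size=(32, 32, 3), dtype=np.uint8)  # synthetic image
standardized = per_image_standardization_numpy(image)
print(standardized.mean(), standardized.std())
```

The printed mean is approximately 0 and the standard deviation approximately 1, up to floating point error.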

Last but not least, we need to batch the data, and set drop_remainder to True in case the number of samples in the dataset is not evenly divisible by the batch_size.

train_dataset = train_dataset.map(augmentation).map(normalize).shuffle(50000).batch(128, drop_remainder=True)
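
The effect of drop_remainder is easy to check with a little arithmetic. The sketch below uses a hypothetical batch_sizes helper to list the batch sizes for one pass over the 50000 training samples with a batch size of 128.

```python
def batch_sizes(num_samples, batch_size, drop_remainder):
    """Return the list of batch sizes produced for one pass over the data."""
    full_batches, leftover = divmod(num_samples, batch_size)
    sizes = [batch_size] * full_batches
    if leftover and not drop_remainder:
        sizes.append(leftover)  # the final, smaller batch
    return sizes


dropped = batch_sizes(50000, 128, drop_remainder=True)
kept = batch_sizes(50000, 128, drop_remainder=False)
print(len(dropped), len(kept), kept[-1])  # 390 391 80
```

With drop_remainder=True every batch has exactly 128 samples, at the cost of skipping 80 samples per epoch.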

Now we have a complete data pipeline. Next, we will define the model and create a training pipeline.

Model Pipeline

Define a Model

TensorFlow 2 uses Keras as its high-level API. Keras has two ways to define a model: Sequential and Functional.

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Flatten, Dense

# Sequential API
model = Sequential([
    Conv2D(32, (3, 3), activation='relu', input_shape=(32, 32, 3)),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(10, activation='softmax')
])

# Functional API
inputs = Input(shape=(32, 32, 3))
x = Conv2D(32, (3, 3), activation='relu')(inputs)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Flatten()(x)
x = Dense(10, activation='softmax')(x)
model = Model(inputs=inputs, outputs=x)

The two code segments above define the same model. The main difference is that the Sequential API requires its first layer to be provided with input_shape, while the Functional API requires its first layer to be a tf.keras.layers.Input layer and needs a call to the tf.keras.models.Model constructor at the end.

The Sequential API requires less typing, but the Functional API is more flexible: it allows a model to be non-sequential, for example to include the skip connections in ResNet. This tutorial adapts TensorFlow's official Keras implementation of ResNet, which uses the Functional API.

input_shape = (32, 32, 3)
img_input = Input(shape=input_shape)
model = resnet.resnet56(img_input=img_input, classes=10)

A Keras model needs to be compiled before being trained. The compilation of the model essentially defines three things: the loss function, the optimizer and the metrics for evaluation:

model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer=keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
    metrics=['sparse_categorical_accuracy'])

Notice we use sparse_categorical_crossentropy and sparse_categorical_accuracy here because each label is represented by a single integer (index of the class). One should use categorical_crossentropy and categorical_accuracy if a one-hot vector represents each label.
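
The two loss variants compute the same quantity and differ only in how the label is encoded. The pure-Python sketch below illustrates this on a made-up three-class prediction. The function names echo the Keras loss names but are our own minimal versions, not the library code.

```python
import math


def categorical_crossentropy(one_hot, probs):
    """Cross-entropy when the label is a one-hot vector."""
    return -sum(t * math.log(p) for t, p in zip(one_hot, probs))


def sparse_categorical_crossentropy(class_index, probs):
    """Cross-entropy when the label is a single integer class index."""
    return -math.log(probs[class_index])


probs = [0.1, 0.7, 0.2]          # made-up softmax output for a 3-class problem
label_index = 1                  # sparse label
label_one_hot = [0.0, 1.0, 0.0]  # the same label, one-hot encoded

sparse_loss = sparse_categorical_crossentropy(label_index, probs)
dense_loss = categorical_crossentropy(label_one_hot, probs)
print(sparse_loss, dense_loss)
```

Both values equal -log(0.7), so the sparse form simply avoids building a one-hot vector for each label.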

Train and Evaluation

Keras uses the fit API to train a model. Optionally, one can test the model on a validation dataset every validation_freq training epochs. Note that we use the test dataset for validation here only because CIFAR10 does not natively provide a validation set. In practice, validation should be conducted on data held out from the training set.

model.fit(train_dataset, epochs=60, validation_data=test_dataset, validation_freq=1)

Notice in this example, the fit function takes TensorFlow Dataset objects (train_dataset and test_dataset). As previously mentioned, it can also take numpy ndarrays as the input. The downside of using arrays is the lack of flexibility to apply transformations on the dataset.

model.fit(x, y, batch_size=128, epochs=5, shuffle=True, validation_data=(x_test, y_test))
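
As noted earlier, validation data should ideally be held out from the training set. A minimal sketch of such a split is shown below. The hold-out size of 5000 and the seed are arbitrary choices of ours, not values from this tutorial.

```python
import random

NUM_TRAIN_SAMPLES = 50000
NUM_VALIDATION_SAMPLES = 5000  # hypothetical hold-out size, 10% of the training set

indices = list(range(NUM_TRAIN_SAMPLES))
random.Random(22).shuffle(indices)  # fixed seed so the split is reproducible

val_indices = indices[:NUM_VALIDATION_SAMPLES]
train_indices = indices[NUM_VALIDATION_SAMPLES:]
print(len(train_indices), len(val_indices))  # 45000 5000
```

With the arrays loaded earlier, x[val_indices] and y[val_indices] would form the validation set, and x[train_indices] and y[train_indices] the reduced training set.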

To evaluate the model, call the evaluate method with the test dataset:

model.evaluate(test_dataset)

Save and Restore

Keras models have native support for saving and restoring model definitions and weights: all you need to do is call the save and load_model APIs. If your model has batch normalization layers, their moving statistics are saved as well:

model.save('model.h5')
new_model = keras.models.load_model('model.h5')  # Gives the same accuracy as model

However, there is one caveat: models created by sub-classing cannot be saved by model.save(). This is because sub-classing defines the model's topology as Python code (rather than as a static graph of layers), which means the topology cannot be inspected or serialized. As a result, the following methods and attributes are not available for subclassed models:

  • model.inputs and model.outputs
  • model.to_yaml() and model.to_json()
  • model.get_config() and model.save()

Multiple-GPU

So far, we have shown how to use TensorFlow's Dataset API to create a data pipeline, and how to use the Keras API to define the model and conduct the training and evaluation. The next step is to make the code run with multiple GPUs.

In fact, TensorFlow 2 has made it very easy to convert your single-GPU implementation to run with multiple GPUs. All you need to do is define a distribution strategy and create the model under the strategy's scope:

mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
    model = resnet.resnet56(img_input=img_input, classes=NUM_CLASSES)
    model.compile(
        optimizer=keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy'])

We use MirroredStrategy here, which supports synchronous distributed training on multiple GPUs on one machine. By default, it uses NVIDIA NCCL as the multi-GPU all-reduce implementation.
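
The idea behind synchronous training with all-reduce can be illustrated without any GPUs. In the toy sketch below, each "replica" computes a gradient on its own shard of the batch, and an averaging step stands in for the all-reduce. This is a conceptual illustration under our own simplifications, not what MirroredStrategy or NCCL do internally. The model (y = w * x), the data and all function names are made up.

```python
def gradient(w, batch):
    """Gradient of the mean squared error of y = w * x with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def all_reduce_mean(values):
    """Stand-in for the all-reduce step: every replica receives the same average."""
    return sum(values) / len(values)


global_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]  # made-up (x, y) pairs
NUM_REPLICAS = 2
per_replica = len(global_batch) // NUM_REPLICAS
shards = [global_batch[i * per_replica:(i + 1) * per_replica] for i in range(NUM_REPLICAS)]

w = 0.0
local_gradients = [gradient(w, shard) for shard in shards]  # computed independently per replica
synced_gradient = all_reduce_mean(local_gradients)
single_device_gradient = gradient(w, global_batch)
print(synced_gradient, single_device_gradient)
```

Because the shards are the same size, averaging the per-replica gradients gives the same value as computing the gradient over the whole batch on one device. This is why every replica can apply an identical update and keep its copy of the weights in sync.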

Note that you'll want to scale the batch size with the data pipeline's batch method based on the number of GPUs that you're using.

train_dataset = train_dataset.map(preprocess).shuffle(50000).batch(BS_PER_GPU * NUM_GPUS)
test_dataset = test_dataset.map(preprocess).batch(BS_PER_GPU * NUM_GPUS)
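
To get a feel for how the global batch size changes the training loop, the small sketch below computes the global batch size and the number of steps per epoch for a few GPU counts. The global_batch_plan helper is hypothetical, and it assumes incomplete batches are dropped.

```python
NUM_TRAIN_SAMPLES = 50000
BS_PER_GPU = 128


def global_batch_plan(num_samples, bs_per_gpu, num_gpus):
    """Return (global batch size, full steps per epoch) for a given number of GPUs."""
    global_batch_size = bs_per_gpu * num_gpus
    steps_per_epoch = num_samples // global_batch_size  # assumes drop_remainder=True
    return global_batch_size, steps_per_epoch


for num_gpus in (1, 2, 4, 8):
    print(num_gpus, global_batch_plan(NUM_TRAIN_SAMPLES, BS_PER_GPU, num_gpus))
```

With two GPUs the global batch holds 256 samples and an epoch takes 195 steps, roughly half as many as on a single GPU.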

Callbacks

Often we need to perform custom operations during training. For example, you might want to log statistics during training for debugging or optimization purposes, implement a learning rate schedule to improve the efficiency of training, or save visual snapshots of filter banks as they converge. In TensorFlow 2, you can use the callback feature to implement customized events during training.

TensorBoard

TensorBoard is mainly used to log and visualize information during training. It is handy for examining the performance of the model. TensorBoard support is provided via the tensorflow.keras.callbacks.TensorBoard callback:

import datetime

from tensorflow.keras.callbacks import TensorBoard

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    update_freq='batch',
    histogram_freq=1)

model.fit(..., callbacks=[tensorboard_callback])

In the above example, we first create a TensorBoard callback that records data for each training step (via update_freq='batch'), then attach this callback to the fit function. TensorFlow will generate tfevents files, which can be visualized with TensorBoard. For example, this is the visualization of classification accuracy during training (blue is the training accuracy, red is the validation accuracy):


Learning Rate Schedule

Often, we would like to have fine control over the learning rate as training progresses. A custom learning rate schedule can be implemented as a callback. Here, we create a customized schedule function that decreases the learning rate using a step function (at the 30th and 45th epochs). This schedule is wrapped in a keras.callbacks.LearningRateScheduler and attached to the fit function.

from tensorflow.keras.callbacks import LearningRateScheduler

BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(0.1, 30), (0.01, 45)]

def schedule(epoch):
    initial_learning_rate = BASE_LEARNING_RATE * BS_PER_GPU / 128
    learning_rate = initial_learning_rate
    for mult, start_epoch in LR_SCHEDULE:
        if epoch >= start_epoch:
            learning_rate = initial_learning_rate * mult
        else:
            break
    tf.summary.scalar('learning rate', data=learning_rate, step=epoch)
    return learning_rate

lr_schedule_callback = LearningRateScheduler(schedule)

model.fit(..., callbacks=[..., lr_schedule_callback])
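
To check the step logic without running a training job, the schedule can be evaluated on its own. The sketch below repeats the same logic as the schedule function above with the TensorBoard logging removed. The name schedule_value and the chosen epochs are ours, and BS_PER_GPU is assumed to be 128.

```python
BASE_LEARNING_RATE = 0.1
BS_PER_GPU = 128  # assumed per-GPU batch size
LR_SCHEDULE = [(0.1, 30), (0.01, 45)]  # (multiplier, first epoch at which it applies)


def schedule_value(epoch):
    """Same step logic as schedule(), without the tf.summary logging."""
    initial_learning_rate = BASE_LEARNING_RATE * BS_PER_GPU / 128
    learning_rate = initial_learning_rate
    for mult, start_epoch in LR_SCHEDULE:
        if epoch >= start_epoch:
            learning_rate = initial_learning_rate * mult
        else:
            break
    return learning_rate


for epoch in (0, 29, 30, 44, 45, 59):
    print(epoch, round(schedule_value(epoch), 6))
```

The learning rate stays at 0.1 through epoch 29, drops to 0.01 at epoch 30, and drops again to 0.001 at epoch 45.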

These are the statistics of the customized learning rate during a 60-epoch training run:


This tutorial explains the basics of TensorFlow 2.0 with image classification as an example. We covered:

  • Data pipeline with TensorFlow 2's dataset API
  • Training, evaluating, saving and restoring models with Keras (TensorFlow 2's official high-level API)
  • Multiple-GPU with distributed strategy
  • Customized training with callbacks

Below is the full code of this tutorial. You can also reproduce our tutorials on TensorFlow 2.0 using this Tensorflow 2.0 Tutorial repo.

import datetime

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.callbacks import TensorBoard, LearningRateScheduler

import resnet

NUM_GPUS = 2
BS_PER_GPU = 128
NUM_EPOCHS = 60             # 60 epochs, as in the fit call earlier

HEIGHT = 32                 # CIFAR10 images are 32x32 RGB
WIDTH = 32
NUM_CHANNELS = 3
NUM_CLASSES = 10            # CIFAR10 has 10 classes
NUM_TRAIN_SAMPLES = 50000   # size of the CIFAR10 training set

BASE_LEARNING_RATE = 0.1
LR_SCHEDULE = [(0.1, 30), (0.01, 45)]


def preprocess(x, y):
    x = tf.image.per_image_standardization(x)
    return x, y


def augmentation(x, y):
    x = tf.image.resize_with_crop_or_pad(
        x, HEIGHT + 8, WIDTH + 8)
    x = tf.image.random_crop(x, [HEIGHT, WIDTH, NUM_CHANNELS])
    x = tf.image.random_flip_left_right(x)
    return x, y


def schedule(epoch):
    initial_learning_rate = BASE_LEARNING_RATE * BS_PER_GPU / 128
    learning_rate = initial_learning_rate
    for mult, start_epoch in LR_SCHEDULE:
        if epoch >= start_epoch:
            learning_rate = initial_learning_rate * mult
        else:
            break
    tf.summary.scalar('learning rate', data=learning_rate, step=epoch)
    return learning_rate


(x, y), (x_test, y_test) = keras.datasets.cifar10.load_data()

train_dataset = tf.data.Dataset.from_tensor_slices((x, y))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

tf.random.set_seed(22)
train_dataset = train_dataset.map(augmentation).map(preprocess).shuffle(NUM_TRAIN_SAMPLES).batch(BS_PER_GPU * NUM_GPUS, drop_remainder=True)
test_dataset = test_dataset.map(preprocess).batch(BS_PER_GPU * NUM_GPUS, drop_remainder=True)

input_shape = (32, 32, 3)
img_input = tf.keras.layers.Input(shape=input_shape)
opt = keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

if NUM_GPUS == 1:
    model = resnet.resnet56(img_input=img_input, classes=NUM_CLASSES)
    model.compile(
        optimizer=opt,
        loss='sparse_categorical_crossentropy',
        metrics=['sparse_categorical_accuracy'])
else:
    mirrored_strategy = tf.distribute.MirroredStrategy()
    with mirrored_strategy.scope():
        model = resnet.resnet56(img_input=img_input, classes=NUM_CLASSES)
        model.compile(
            optimizer=opt,
            loss='sparse_categorical_crossentropy',
            metrics=['sparse_categorical_accuracy'])

log_dir = "logs/fit/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(log_dir + "/metrics")
file_writer.set_as_default()  # gives tf.summary.scalar() in schedule() a writer to log to
tensorboard_callback = TensorBoard(
    log_dir=log_dir,
    update_freq='batch',
    histogram_freq=1)

lr_schedule_callback = LearningRateScheduler(schedule)

model.fit(train_dataset,
          epochs=NUM_EPOCHS,
          validation_data=test_dataset,
          validation_freq=1,
          callbacks=[tensorboard_callback, lr_schedule_callback])
model.evaluate(test_dataset)

model.save('model.h5')
new_model = keras.models.load_model('model.h5')
new_model.evaluate(test_dataset)