ElasticDL Model Building

To submit an ElasticDL job, a user needs to provide a model file, such as the one used in this example.

This model file contains a model built with the TensorFlow Keras API, plus the other components required by ElasticDL: dataset_fn, loss, optimizer, and eval_metrics_fn.

Model File Components

model

model is a Keras model built using either the TensorFlow Keras functional API or model subclassing.

The following example shows a model built with the functional API. It has one input of shape (28, 28) and one output of shape (10,):

import tensorflow as tf

inputs = tf.keras.Input(shape=(28, 28), name='image')
x = tf.keras.layers.Reshape((28, 28, 1))(inputs)
x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
x = tf.keras.layers.Dropout(0.25)(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(10)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')

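As a sanity check on the functional model above, the pure-Python sketch below traces by hand how the (28, 28) input shrinks before reaching the final Dense(10) layer, and counts parameters per layer. It assumes the Keras defaults padding='valid' and strides=1 for Conv2D, and strides equal to pool_size for MaxPooling2D; the helper names conv2d_valid and trace_mnist_model are hypothetical and not part of ElasticDL or Keras.

```python
# Hypothetical helpers; shapes assume Keras defaults (padding='valid', stride 1,
# MaxPooling2D strides equal to pool_size).

def conv2d_valid(h, w, c_in, filters, k):
    """Output shape and parameter count of a stride-1, 'valid' k x k Conv2D with bias."""
    out_shape = (h - k + 1, w - k + 1, filters)
    params = k * k * c_in * filters + filters  # kernel weights + one bias per filter
    return out_shape, params


def trace_mnist_model():
    """Return (layer, output_shape, params) for each layer of the MNIST model."""
    layers = [("reshape", (28, 28, 1), 0)]
    shape = (28, 28, 1)

    shape, p = conv2d_valid(*shape, filters=32, k=3)
    layers.append(("conv2d_32", shape, p))
    shape, p = conv2d_valid(*shape, filters=64, k=3)
    layers.append(("conv2d_64", shape, p))

    # BatchNormalization keeps gamma, beta, moving mean and moving variance per channel.
    layers.append(("batch_norm", shape, 4 * shape[2]))

    shape = (shape[0] // 2, shape[1] // 2, shape[2])  # 2 x 2 max pooling
    layers.append(("max_pooling", shape, 0))
    layers.append(("dropout", shape, 0))  # dropout never changes the shape

    flat = shape[0] * shape[1] * shape[2]
    layers.append(("flatten", (flat,), 0))
    layers.append(("dense_10", (10,), flat * 10 + 10))
    return layers


for name, shape, params in trace_mnist_model():
    print(f"{name:12s} {str(shape):15s} {params}")
print("total:", sum(p for _, _, p in trace_mnist_model()))
```

Under these assumptions the model has 111,242 parameters in total, of which the final Dense(10) layer alone accounts for 92,170; model.summary() should report the same total.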
Another example using model subclassing:

class MnistModel(tf.keras.Model):
    def __init__(self):
        super(MnistModel, self).__init__(name='mnist_model')
        self._reshape = tf.keras.layers.Reshape((28, 28, 1))
        self._conv1 = tf.keras.layers.Conv2D(
            32, kernel_size=(3, 3), activation='relu')
        self._conv2 = tf.keras.layers.Conv2D(
            64, kernel_size=(3, 3), activation='relu')
        self._batch_norm = tf.keras.layers.BatchNormalization()
        self._maxpooling = tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2))
        self._dropout = tf.keras.layers.Dropout(0.25)
        self._flatten = tf.keras.layers.Flatten()
        self._dense = tf.keras.layers.Dense(10)

    def call(self, inputs, training=False):
        x = self._reshape(inputs)
        x = self._conv1(x)
        x = self._conv2(x)
        x = self._batch_norm(x, training=training)
        x = self._maxpooling(x)
        if training:
            x = self._dropout(x, training=training)
        x = self._flatten(x)
        x = self._dense(x)
        return x

model = MnistModel()

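The subclassed model passes its training argument on to BatchNormalization and Dropout, because both behave differently during training and inference. The NumPy sketch below illustrates the dropout case using the inverted-dropout formulation that Keras's Dropout layer follows: in training mode each element is zeroed with probability rate and the survivors are scaled by 1 / (1 - rate); in inference mode the input passes through unchanged. The function inverted_dropout is a hypothetical illustration, not Keras or ElasticDL code.

```python
import numpy as np


def inverted_dropout(x, rate, training, rng):
    """Sketch of how a Dropout layer treats its input.

    Training: zero each element with probability `rate` and scale the kept
    elements by 1 / (1 - rate), so the expected value is unchanged.
    Inference: return the input untouched.
    """
    if not training:
        return x
    keep = rng.random(x.shape) >= rate  # True for elements that are kept
    return np.where(keep, x / (1.0 - rate), 0.0)


rng = np.random.default_rng(0)
x = np.ones((4, 1000))
print(inverted_dropout(x, 0.25, training=False, rng=rng).mean())  # exactly 1.0
print(inverted_dropout(x, 0.25, training=True, rng=rng).mean())   # close to 1.0
```

Because the layer is already an identity in inference mode, the `if training:` guard around self._dropout in call() acts as a safeguard rather than a requirement.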

dataset_fn(dataset, mode)

dataset_fn is a function that takes a RecordIO dataset as input, preprocesses the data as needed, and returns a dataset of (model_inputs, labels) pairs.

Input:

  • dataset: a RecordIO dataset generated by ElasticDL. ElasticDL creates the dataset by iterating over the records in a RecordIO file.
  • mode: any value defined in elasticdl.python.common.constants.Mode, representing the current phase: training, evaluation, or prediction. For example, if mode == Mode.PREDICTION, _parse_data() does not need to return labels.

Output: a dataset in which each element is a tuple (model_inputs, labels).

model_inputs is a dictionary of tensors that is fed to the model as input. labels is passed as an argument to loss.


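import tensorflow as tf
from elasticdl.python.common.constants import Mode


def dataset_fn(dataset, mode):
    def _parse_data(record):
        if mode == Mode.PREDICTION:
            # Prediction records carry no label.
            feature_description = {
                "image": tf.io.FixedLenFeature([28, 28], tf.float32)
            }
        else:
            feature_description = {
                "image": tf.io.FixedLenFeature([28, 28], tf.float32),
                "label": tf.io.FixedLenFeature([1], tf.int64),
            }
        r = tf.io.parse_single_example(record, feature_description)
        # Scale pixel values from [0, 255] to [0, 1].
        features = {
            "image": tf.math.divide(tf.cast(r["image"], tf.float32), 255.0)
        }
        if mode == Mode.PREDICTION:
            return features
        else:
            return features, tf.cast(r["label"], tf.int32)

    dataset = dataset.map(_parse_data)

    # Shuffle for training and evaluation only; prediction keeps the input order.
    if mode != Mode.PREDICTION:
        dataset = dataset.shuffle(buffer_size=1024)
    return dataset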

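dataset_fn shuffles with dataset.shuffle(buffer_size=1024) except in prediction mode, where outputs must stay aligned with the input records. A buffer-based shuffle only mixes records within a sliding window, so it is a local shuffle rather than a full permutation. The pure-Python sketch below illustrates the idea behind tf.data's buffered shuffle (fill a buffer, emit a random element, refill its slot with the next input); buffered_shuffle is a hypothetical helper, not tf.data's actual implementation.

```python
import random


def buffered_shuffle(items, buffer_size, seed=None):
    """Yield `items` in a partially shuffled order using a fixed-size buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)  # still filling the buffer
            continue
        i = rng.randrange(buffer_size)
        yield buffer[i]          # emit a random buffered element...
        buffer[i] = item         # ...and put the next input in its slot
    # Input exhausted: drain the remaining elements in random order.
    rng.shuffle(buffer)
    yield from buffer


print(list(buffered_shuffle(range(10), buffer_size=1)))  # order unchanged
print(list(buffered_shuffle(range(10), buffer_size=4, seed=0)))
```

With buffer_size=1 the order is preserved; with a buffer of size B, an element can move at most B - 1 positions earlier than its input position, so a larger buffer gives a more thorough shuffle at the cost of memory.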

loss(output, labels)

loss is the loss function used in ElasticDL training.



def loss(output, labels):
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=output, labels=tf.reshape(labels, [-1])  # [batch, 1] -> [batch]
        )
    )

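Labels produced by dataset_fn have shape [batch_size, 1], while sparse softmax cross-entropy expects one integer class index per example, which is why the labels are flattened before computing the loss. The NumPy sketch below spells out what tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(...)) computes for a batch of logits; the function name is hypothetical and the code is an illustration, not the TensorFlow kernel.

```python
import numpy as np


def sparse_softmax_cross_entropy(logits, labels):
    """Mean cross-entropy between integer labels and unnormalized logits.

    logits: float array of shape [batch_size, num_classes]
    labels: int array of shape [batch_size] or [batch_size, 1]
    """
    labels = np.asarray(labels).reshape(-1)
    # Subtract the row maximum before exponentiating, for numerical stability.
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    # Negative log-probability of the true class for each example.
    per_example = -log_probs[np.arange(len(labels)), labels]
    return per_example.mean()


# Uniform logits over 10 classes give a loss of ln(10), about 2.3026.
print(sparse_softmax_cross_entropy(np.zeros((3, 10)), [[1], [4], [9]]))
```

Because the softmax is applied inside the loss, the model's final Dense(10) layer outputs raw logits and needs no activation.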

optimizer(lr=0.1)

optimizer is a function that returns a TensorFlow Keras optimizer instance, such as tf.optimizers.SGD.


def optimizer(lr=0.1):
    # Plain stochastic gradient descent with learning rate `lr`.
    return tf.optimizers.SGD(learning_rate=lr)

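The optimizer turns gradients into variable updates. With its default momentum of 0, tf.optimizers.SGD applies the rule w ← w − lr · g to each trainable variable. The pure-Python sketch below shows that rule on a one-variable problem; sgd_step is a hypothetical helper, and the sketch leaves out how ElasticDL distributes gradients across workers.

```python
def sgd_step(weights, grads, lr=0.1):
    """One plain SGD update: w <- w - lr * g (no momentum, no weight decay)."""
    return [w - lr * g for w, g in zip(weights, grads)]


# Minimize f(w) = w ** 2, whose gradient is 2 * w, starting from w = 1.0.
w = [1.0]
for step in range(3):
    w = sgd_step(w, [2 * w[0]], lr=0.1)
    print(step, w[0])
```

Each step multiplies w by (1 - 2 · 0.1) = 0.8, so after three steps w is approximately 0.512.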

eval_metrics_fn(predictions, labels)

eval_metrics_fn is a function that returns a dictionary mapping each evaluation metric's name to its value, computed from the predictions and labels with TensorFlow operations.


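def eval_metrics_fn(predictions, labels):
    # Compare the predicted class (index of the largest logit) with the label.
    predicted = tf.argmax(input=predictions, axis=1, output_type=tf.int32)
    labels = tf.cast(tf.reshape(labels, [-1]), tf.int32)
    return {
        "accuracy": tf.reduce_mean(
            tf.cast(tf.equal(predicted, labels), tf.float32))
    }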

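The "accuracy" metric is the fraction of examples whose highest-scoring class equals the label. Since softmax is monotonic, taking the argmax over raw logits gives the same prediction as taking it over probabilities, so the model output can be used directly. The NumPy sketch below mirrors that computation; accuracy is a hypothetical helper, not an ElasticDL API.

```python
import numpy as np


def accuracy(predictions, labels):
    """Fraction of examples whose highest-scoring class equals the label.

    predictions: array of shape [batch_size, num_classes] (logits or probabilities)
    labels: int array of shape [batch_size] or [batch_size, 1]
    """
    predicted = np.argmax(predictions, axis=1)
    return float(np.mean(predicted == np.asarray(labels).reshape(-1)))


predictions = np.array([[0.1, 2.0, -1.0],   # predicts class 1
                        [3.0, 0.5, 0.2],    # predicts class 0
                        [0.0, 0.1, 0.9]])   # predicts class 2
labels = np.array([[1], [2], [2]])
print(accuracy(predictions, labels))  # 2 of 3 correct
```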

prepare_data_for_a_single_file(filename)

prepare_data_for_a_single_file reads a single file, applies any user-defined logic to prepare the data (e.g., I/O from the user's file system, feature engineering), and returns the serialized data. The function can be used to process data for training, evaluation, and prediction. The only difference between prediction data and training/evaluation data is that the label in prediction data should be empty. Users should determine whether a data file contains a label (e.g., from the format of its filename) and implement the preparation logic accordingly.


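import numpy as np
import PIL.Image  # assumes the image files can be read with Pillow
import tensorflow as tf

def prepare_data_for_a_single_file(filename):
    """An image classification dataset in which images belonging to the same
    category are located in the same directory, named after the integer label."""
    label = int(filename.split('/')[-2])
    image = PIL.Image.open(filename)
    numpy_image = np.array(image)
    example_dict = {
        "image": tf.train.Feature(
            float_list=tf.train.FloatList(value=numpy_image.astype(np.float32).flatten())),
        "label": tf.train.Feature(
            int64_list=tf.train.Int64List(value=[label])),
    }
    example = tf.train.Example(
        features=tf.train.Features(feature=example_dict))
    return example.SerializeToString()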

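prepare_data_for_a_single_file is expected to decide from the file name whether a label is present. A minimal sketch, assuming a hypothetical layout in which labeled files live under <root>/<label>/ and prediction files under <root>/predict/; the function name label_from_path and the "predict" directory name are illustrative and not part of ElasticDL.

```python
import os


def label_from_path(filename, predict_dir_name="predict"):
    """Return the integer label encoded in `filename`, or None if it has none.

    Hypothetical layout, for illustration only:
      <root>/<label>/<image file>    -> labeled training/evaluation data
      <root>/predict/<image file>    -> unlabeled prediction data
    """
    parent = os.path.basename(os.path.dirname(filename))
    if parent == predict_dir_name:
        return None  # prediction data: leave the label empty
    return int(parent)


print(label_from_path("mnist/train/7/00042.png"))  # 7
print(label_from_path("mnist/predict/00042.png"))  # None
```

When the helper returns None, the "label" entry would simply be left out of example_dict, matching the dataset_fn above, which parses only "image" in prediction mode.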
Model Building Examples

  • MNIST model using Keras functional API
  • MNIST model using Keras model subclassing
  • CIFAR10 model using Keras functional API
  • CIFAR10 model using Keras model subclassing