
ElasticDL Model Contribution

To submit an ElasticDL job, a user needs to provide a model file, such as the one used in this example.

This model file contains a model built with the TensorFlow Keras API, plus the other components required by ElasticDL: dataset_fn, loss, optimizer, and eval_metrics_fn.

Model File Components


model

model is a Keras model built using either the TensorFlow Keras functional API or model subclassing.

The following example builds a model with the functional API. It has one input with shape (28, 28) and one output with shape (10,):

import tensorflow as tf

inputs = tf.keras.Input(shape=(28, 28), name='image')
x = tf.keras.layers.Reshape((28, 28, 1))(inputs)
x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
x = tf.keras.layers.Dropout(0.25)(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(10)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')

The same model written with model subclassing:

class MnistModel(tf.keras.Model):
    def __init__(self):
        super(MnistModel, self).__init__(name='mnist_model')
        self._reshape = tf.keras.layers.Reshape((28, 28, 1))
        self._conv1 = tf.keras.layers.Conv2D(
            32, kernel_size=(3, 3), activation='relu')
        self._conv2 = tf.keras.layers.Conv2D(
            64, kernel_size=(3, 3), activation='relu')
        self._batch_norm = tf.keras.layers.BatchNormalization()
        self._maxpooling = tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2))
        self._dropout = tf.keras.layers.Dropout(0.25)
        self._flatten = tf.keras.layers.Flatten()
        self._dense = tf.keras.layers.Dense(10)

    def call(self, inputs, training=False):
        x = self._reshape(inputs)
        x = self._conv1(x)
        x = self._conv2(x)
        x = self._batch_norm(x, training=training)
        x = self._maxpooling(x)
        # Dropout is a no-op when training=False, so no explicit branch is needed.
        x = self._dropout(x, training=training)
        x = self._flatten(x)
        x = self._dense(x)
        return x

model = MnistModel()
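
To see how the layers above get from a (28, 28) input to a (10,) output, the shape arithmetic can be traced by hand. The sketch below is plain Python rather than TensorFlow or ElasticDL code, and its helper names are illustrative. It assumes the Keras defaults of 'valid' padding and stride 1 for Conv2D, and a stride equal to the pool size for MaxPooling2D.

```python
def conv2d_valid(size, kernel, stride=1):
    """Spatial output size of a convolution with 'valid' padding."""
    return (size - kernel) // stride + 1


def max_pool_valid(size, pool):
    """Spatial output size of max pooling with stride equal to the pool size."""
    return (size - pool) // pool + 1


def trace_shapes(height=28, width=28):
    """Return the per-layer output shapes (batch dimension omitted)."""
    shapes = {"reshape": (height, width, 1)}
    h, w = conv2d_valid(height, 3), conv2d_valid(width, 3)
    shapes["conv1"] = (h, w, 32)
    h, w = conv2d_valid(h, 3), conv2d_valid(w, 3)
    shapes["conv2"] = (h, w, 64)
    # BatchNormalization and Dropout leave the shape unchanged.
    h, w = max_pool_valid(h, 2), max_pool_valid(w, 2)
    shapes["max_pool"] = (h, w, 64)
    shapes["flatten"] = (h * w * 64,)
    shapes["dense"] = (10,)
    return shapes


for layer, shape in trace_shapes().items():
    print(layer, shape)
```

Under these assumptions the Dense layer receives 12 * 12 * 64 = 9216 features per example and maps them to 10 logits.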


dataset_fn(dataset, mode)

dataset_fn is a function that takes a RecordIO dataset as input, pre-processes the data as needed, and returns a dataset containing model_inputs and labels as a pair.


  • dataset: a RecordIO dataset generated by ElasticDL. ElasticDL creates the dataset by iterating over the records in a RecordIO file.
  • mode: one of the values defined in elasticdl.python.common.constants.Mode, representing the phases of a job such as training, evaluation, and prediction. For example, if mode == Mode.PREDICTION, _parse_data() does not need to return labels.

Output: a dataset in which each element is a tuple (model_inputs, labels).

model_inputs is a dictionary of tensors that is fed to the model as input. labels is passed as the labels argument of loss.


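import tensorflow as tf
from elasticdl.python.common.constants import Mode


def dataset_fn(dataset, mode):
    def _parse_data(record):
        # Prediction does not need labels, so only the image is parsed.
        if mode == Mode.PREDICTION:
            feature_description = {
                "image": tf.io.FixedLenFeature([28, 28], tf.float32)
            }
        else:
            feature_description = {
                "image": tf.io.FixedLenFeature([28, 28], tf.float32),
                "label": tf.io.FixedLenFeature([1], tf.int64),
            }
        r = tf.io.parse_single_example(record, feature_description)
        # Scale pixel values from [0, 255] to [0, 1].
        features = {
            "image": tf.math.divide(tf.cast(r["image"], tf.float32), 255.0)
        }
        if mode == Mode.PREDICTION:
            return features
        else:
            return features, tf.cast(r["label"], tf.int32)

    dataset = dataset.map(_parse_data)

    # Shuffle only when labels are present (training and evaluation).
    if mode != Mode.PREDICTION:
        dataset = dataset.shuffle(buffer_size=1024)
    return dataset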
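
The return contract of dataset_fn depends on mode: (model_inputs, labels) pairs for training and evaluation, and model_inputs alone for prediction. The sketch below mimics that contract on plain Python dictionaries so it runs without TensorFlow or ElasticDL. The Mode enum here is a local stand-in, not elasticdl.python.common.constants.Mode, and parse_record is a hypothetical helper.

```python
from enum import Enum


class Mode(Enum):
    """Local stand-in for the ElasticDL Mode constants (an assumption)."""
    TRAINING = "training"
    EVALUATION = "evaluation"
    PREDICTION = "prediction"


def parse_record(record, mode):
    """Scale pixel values to [0, 1] and attach the label unless predicting."""
    features = {"image": [[p / 255.0 for p in row] for row in record["image"]]}
    if mode == Mode.PREDICTION:
        return features
    return features, int(record["label"][0])


# A toy 2x2 "image" with a one-element label, mirroring the [1]-shaped label.
record = {"image": [[0, 255], [51, 102]], "label": [7]}
print(parse_record(record, Mode.TRAINING))
print(parse_record(record, Mode.PREDICTION))
```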


loss(labels, predictions)

loss is the loss function used in ElasticDL training.



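def loss(labels, predictions):
    # Flatten labels from shape [N, 1] to [N]; tensors have no flatten() method.
    labels = tf.reshape(labels, [-1])
    return tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits=predictions, labels=labels
        )
    )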
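
For intuition about what a sparse softmax cross-entropy loss computes, the sketch below evaluates it by hand in plain Python: for each example it takes the log-sum-exp of the logits minus the logit of the true class, then averages over the batch. It illustrates the formula only and is not ElasticDL code; sparse_softmax_cross_entropy is a hypothetical helper name.

```python
import math


def sparse_softmax_cross_entropy(logits, labels):
    """Mean of -log(softmax(logits)[label]) over the batch."""
    losses = []
    for row, label in zip(logits, labels):
        # Subtract the max logit before exponentiating for numerical stability.
        m = max(row)
        log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
        losses.append(log_sum_exp - row[label])
    return sum(losses) / len(losses)


# Two examples, three classes; labels are class indices, not one-hot vectors.
logits = [[2.0, 0.0, 0.0], [0.0, 0.0, 0.0]]
labels = [0, 2]
print(sparse_softmax_cross_entropy(logits, labels))
```

The second example has uniform logits, so its per-example loss is log(3), the loss of a uniform guess over three classes.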



optimizer(lr=0.1)

optimizer is a function that returns a TensorFlow optimizer, for example tf.optimizers.SGD.


def optimizer(lr=0.1):
    return tf.optimizers.SGD(lr)
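
As a reminder of what the lr argument controls, plain SGD without momentum updates each weight by subtracting lr times its gradient. The sketch below applies that rule in plain Python; sgd_step is a hypothetical helper, not part of TensorFlow or ElasticDL.

```python
def sgd_step(weights, grads, lr=0.1):
    """Return new weights after one plain SGD step: w - lr * g."""
    return [w - lr * g for w, g in zip(weights, grads)]


# Minimize f(w) = w**2, whose gradient is 2 * w, starting from w = 1.0.
w = [1.0]
for _ in range(3):
    w = sgd_step(w, [2 * w[0]])
print(w)
```

With lr=0.1 each step multiplies w by 0.8, so a larger lr moves faster toward the minimum but can overshoot it.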



eval_metrics_fn()

eval_metrics_fn is a function that returns a dictionary. Each key is the name of an evaluation metric, and each value is a function that computes that metric from labels and predictions using the TensorFlow API.


def eval_metrics_fn():
    return {
        # Compare the arg-max class of each prediction with its label.
        "accuracy": lambda labels, predictions: tf.equal(
            tf.argmax(predictions, 1, output_type=tf.int32),
            tf.cast(tf.reshape(labels, [-1]), tf.int32),
        )
    }
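
The accuracy entry compares the arg-max class of each prediction with its label. A plain-Python sketch of the same comparison follows, with the per-example results averaged into a single accuracy figure. The averaging step and the helper names are assumptions for illustration, since the text above does not say how ElasticDL aggregates the per-example results.

```python
def argmax(values):
    """Index of the largest value (the first one wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def per_example_correct(labels, predictions):
    """True where the predicted class equals the label."""
    # Labels arrive with shape [N, 1], so take the single entry of each row.
    flat_labels = [int(row[0]) for row in labels]
    return [argmax(p) == y for p, y in zip(predictions, flat_labels)]


predictions = [[0.1, 2.0, -1.0], [3.0, 0.2, 0.1], [0.0, 0.5, 0.4]]
labels = [[1], [0], [2]]
correct = per_example_correct(labels, predictions)
accuracy = sum(correct) / len(correct)
print(correct, accuracy)
```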

Model Building Examples