ElasticDL Model Contribution

To submit an ElasticDL job, a user needs to provide a model file, such as mnist_functional_api.py used in this example.

The model file contains a model built with the TensorFlow Keras API, plus the other components ElasticDL requires: dataset_fn, loss, optimizer, and eval_metrics_fn.

Model File Components

model

model is a Keras model built with either the TensorFlow Keras functional API or model subclassing.

The following example builds a model with the functional API. It has one input with shape (28, 28) and one output with shape (10,):

import tensorflow as tf

inputs = tf.keras.Input(shape=(28, 28), name='image')
x = tf.keras.layers.Reshape((28, 28, 1))(inputs)
x = tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation='relu')(x)
x = tf.keras.layers.BatchNormalization()(x)
x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)
x = tf.keras.layers.Dropout(0.25)(x)
x = tf.keras.layers.Flatten()(x)
outputs = tf.keras.layers.Dense(10)(x)

model = tf.keras.Model(inputs=inputs, outputs=outputs, name='mnist_model')

Another example using model subclassing:

class MnistModel(tf.keras.Model):
    def __init__(self):
        super(MnistModel, self).__init__(name='mnist_model')
        self._reshape = tf.keras.layers.Reshape((28, 28, 1))
        self._conv1 = tf.keras.layers.Conv2D(
            32, kernel_size=(3, 3), activation='relu')
        self._conv2 = tf.keras.layers.Conv2D(
            64, kernel_size=(3, 3), activation='relu')
        self._batch_norm = tf.keras.layers.BatchNormalization()
        self._maxpooling = tf.keras.layers.MaxPooling2D(
            pool_size=(2, 2))
        self._dropout = tf.keras.layers.Dropout(0.25)
        self._flatten = tf.keras.layers.Flatten()
        self._dense = tf.keras.layers.Dense(10)

    def call(self, inputs, training=False):
        x = self._reshape(inputs)
        x = self._conv1(x)
        x = self._conv2(x)
        x = self._batch_norm(x, training=training)
        x = self._maxpooling(x)
        # Dropout is already a no-op when training=False, so no explicit branch is needed.
        x = self._dropout(x, training=training)
        x = self._flatten(x)
        x = self._dense(x)
        return x

model = MnistModel()
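
Both model definitions above stack the same layers, so the output shape can be checked by hand. The following is a minimal sketch of that arithmetic in plain Python, assuming the Keras defaults that the examples do not override: "valid" padding, stride 1 for the convolutions, and a pooling stride equal to the pool size. The helper name valid_output_size is made up for this illustration.

```python
def valid_output_size(size, kernel, stride=1):
    """Length of one spatial dimension after a 'valid' convolution or pooling."""
    return (size - kernel) // stride + 1


height = width = 28
channels = 1  # after Reshape((28, 28, 1))

# Conv2D(32, kernel_size=(3, 3)): 28x28x1 becomes 26x26x32
height, width, channels = valid_output_size(height, 3), valid_output_size(width, 3), 32

# Conv2D(64, kernel_size=(3, 3)): 26x26x32 becomes 24x24x64
height, width, channels = valid_output_size(height, 3), valid_output_size(width, 3), 64

# BatchNormalization and Dropout leave the shape unchanged.

# MaxPooling2D(pool_size=(2, 2)): 24x24x64 becomes 12x12x64
height = valid_output_size(height, 2, stride=2)
width = valid_output_size(width, 2, stride=2)

# Flatten turns 12x12x64 into a vector; Dense(10) maps it to 10 logits.
flattened = height * width * channels
num_classes = 10
dense_params = flattened * num_classes + num_classes  # weights plus biases

print((height, width, channels), flattened, dense_params)
```

The final Dense(10) layer has no activation, so the model outputs raw logits rather than probabilities.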

dataset_fn

dataset_fn(dataset, mode)

dataset_fn is a function that takes a RecordIO dataset as input, pre-processes the data as needed, and returns a dataset containing model_inputs and labels as a pair.

Arguments:

  • dataset: a RecordIO dataset generated by ElasticDL. ElasticDL creates the dataset by iterating over the records in a RecordIO file.
  • mode: one of the values defined in elasticdl.python.common.constants.Mode, representing the current phase: training, evaluation, or prediction. For example, if mode == Mode.PREDICTION, _parse_data() does not need to return labels.

Output: a dataset in which each element is a tuple (model_inputs, labels). In prediction mode, each element contains only model_inputs.

model_inputs is a dictionary of tensors, which will be used as model input. labels will be used as an input argument in loss.

Example:

from elasticdl.python.common.constants import Mode


def dataset_fn(dataset, mode):
    def _parse_data(record):
        if mode == Mode.PREDICTION:
            feature_description = {
                "image": tf.io.FixedLenFeature([28, 28], tf.float32)
            }
        else:
            feature_description = {
                "image": tf.io.FixedLenFeature([28, 28], tf.float32),
                "label": tf.io.FixedLenFeature([1], tf.int64),
            }
        r = tf.io.parse_single_example(record, feature_description)
        features = {
            "image": tf.math.divide(tf.cast(r["image"], tf.float32), 255.0)
        }
        if mode == Mode.PREDICTION:
            return features
        else:
            return features, tf.cast(r["label"], tf.int32)

    dataset = dataset.map(_parse_data)

    if mode != Mode.PREDICTION:
        dataset = dataset.shuffle(buffer_size=1024)
    return dataset
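
The return contract of _parse_data can be illustrated without TensorFlow. The sketch below is a plain-Python analog, not ElasticDL code: the Mode enum is a hypothetical stand-in for elasticdl.python.common.constants.Mode, and parse_record and the 2x2 "image" are invented for the illustration.

```python
from enum import Enum


class Mode(Enum):
    # Hypothetical stand-in for elasticdl.python.common.constants.Mode.
    TRAINING = "training"
    EVALUATION = "evaluation"
    PREDICTION = "prediction"


def parse_record(record, mode):
    """Scale pixel values to [0, 1] and attach the label unless predicting."""
    features = {
        "image": [[pixel / 255.0 for pixel in row] for row in record["image"]]
    }
    if mode == Mode.PREDICTION:
        return features  # no labels are available or needed
    return features, int(record["label"])


# A tiny 2x2 "image" keeps the example readable; real records are 28x28.
record = {"image": [[0.0, 255.0], [51.0, 102.0]], "label": 7}

train_item = parse_record(record, Mode.TRAINING)
predict_item = parse_record(record, Mode.PREDICTION)

print(train_item)
print(predict_item)
```

In training and evaluation the element is a (features, label) pair. In prediction it is the features dictionary alone.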

loss

loss(labels, predictions)

loss is the loss function used in ElasticDL training.

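Arguments:

  • labels: the labels returned by dataset_fn.
  • predictions: the model output for the corresponding model_inputs.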

Example:

def loss(labels, predictions):
    return tf.reduce_mean(
        input_tensor=tf.nn.sparse_softmax_cross_entropy_with_logits(
            # tf.Tensor has no flatten() method; reshape [batch, 1] labels to [batch].
            logits=predictions, labels=tf.reshape(labels, [-1])
        )
    )
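
To make the arithmetic behind this loss concrete, here is a minimal plain-Python sketch of sparse softmax cross entropy averaged over a batch. It is an illustration of the formula, not the TensorFlow implementation. The function names and the toy three-class batch are assumptions made for the example.

```python
import math


def sparse_softmax_cross_entropy(logits, label):
    """Loss for one example: -log(softmax(logits)[label])."""
    largest = max(logits)  # subtract the max for numerical stability
    log_sum_exp = largest + math.log(sum(math.exp(z - largest) for z in logits))
    return log_sum_exp - logits[label]


def mean_loss(labels, predictions):
    # Labels arrive with shape [batch, 1]; flatten them to [batch].
    flat_labels = [label for row in labels for label in row]
    losses = [
        sparse_softmax_cross_entropy(logits, label)
        for logits, label in zip(predictions, flat_labels)
    ]
    return sum(losses) / len(losses)


predictions = [[0.0, 0.0, 0.0], [2.0, 0.0, 0.0]]  # logits for two examples
labels = [[1], [0]]

uniform_loss = sparse_softmax_cross_entropy(predictions[0], 1)
batch_loss = mean_loss(labels, predictions)
print(uniform_loss, batch_loss)
```

With all-equal logits over three classes the per-example loss is log(3). The second example is lower because its correct class has the largest logit.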

optimizer

optimizer()

optimizer is a function that returns a TensorFlow optimizer, such as tf.optimizers.SGD in the example below.

Example:

def optimizer(lr=0.1):
    return tf.optimizers.SGD(lr)
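
For reference, plain SGD without momentum updates each weight by subtracting the learning rate times the gradient. The sketch below shows that rule in plain Python. It is an illustration only: sgd_step is a made-up helper, and it says nothing about how ElasticDL applies gradients internally.

```python
def sgd_step(weights, grads, lr=0.1):
    """One plain SGD update without momentum: w <- w - lr * g."""
    return [w - lr * g for w, g in zip(weights, grads)]


weights = [1.0, -2.0]
grads = [0.5, -1.0]

updated = sgd_step(weights, grads, lr=0.1)
print(updated)
```

A positive gradient lowers the weight and a negative gradient raises it, each scaled by lr.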

eval_metrics_fn

eval_metrics_fn()

eval_metrics_fn is a function that returns a dictionary. Each key is the name of an evaluation metric, and each value is a function that computes the metric from labels and predictions using the TensorFlow API.

Example:

def eval_metrics_fn():
    return {
        "accuracy": lambda labels, predictions: tf.equal(
            tf.argmax(predictions, 1, output_type=tf.int32),
            tf.cast(tf.reshape(labels, [-1]), tf.int32),
        )
    }
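
The accuracy entry above yields one boolean per example: whether the highest-scoring class matches the label. The plain-Python sketch below mirrors that computation. Averaging the booleans into a single number is an assumption made here for illustration, since this page does not describe how ElasticDL aggregates metric outputs. The function name and toy data are invented.

```python
def accuracy_per_example(labels, predictions):
    """Return one boolean per example: argmax(prediction) == label."""
    flat_labels = [label for row in labels for label in row]  # [batch, 1] to [batch]
    predicted = [
        max(range(len(scores)), key=scores.__getitem__) for scores in predictions
    ]
    return [pred == label for pred, label in zip(predicted, flat_labels)]


labels = [[2], [0], [1]]
predictions = [
    [0.1, 0.2, 3.0],  # argmax 2, correct
    [1.5, 0.3, 0.2],  # argmax 0, correct
    [0.9, 0.1, 0.4],  # argmax 0, label is 1, wrong
]

matches = accuracy_per_example(labels, predictions)
accuracy = sum(matches) / len(matches)  # assumed aggregation: the mean
print(matches, accuracy)
```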

Model Building Examples