
ElasticDL Training Checkpoints Design

This document describes the design of saving checkpoints for ElasticDL training.


A checkpoint stores the state of a training process. It includes not only variables but also other states such as global_step. In ElasticDL, we need to save all model parameters to checkpoints during training for the following reasons:

  • We can export the model with the best evaluation performance during training from checkpoints. Training for too many epochs or steps can lead to over-fitting on the training dataset. Early stopping is widely used to avoid over-fitting: it stops training when the evaluation performance stops improving or even gets worse. So the final parameters may not be the best ones. With checkpoints, we can choose the checkpoint with the best evaluation metrics and export it for serving.
  • We can restore the model parameters from a checkpoint to continue training or to predict. Fault tolerance in ElasticDL cannot recover from every unexpected error; fatal failures such as a power outage or an OS fault can still stop the job. If no checkpoint has been completely saved when that happens, all training progress is lost. Even without unforeseen errors, we might want to resume training from a particular state for a new experiment, or try different things from a given state.

In the following sections, we will describe the design of how to save checkpoints and restore model parameters from a checkpoint in ElasticDL.

Design Components

This design contains two key parts: export and restore.

Export Model Parameters to a Checkpoint

When to save a checkpoint during training

In ElasticDL, we save all model parameters to a checkpoint directory every checkpoint_steps steps. Some frameworks, e.g. Estimator, support saving checkpoints either every N steps or every N seconds.

With ParameterServerStrategy in ElasticDL, communication with the PS is asynchronous, so we cannot guarantee that all PS instances are at the same iteration step at any given moment. Therefore, we don't support saving checkpoints every N seconds.
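A minimal sketch of the step-based trigger, assuming each PS instance checks its own model version after applying an update and that the version increases by one per update. `should_save_checkpoint` is a hypothetical helper, not ElasticDL's API. Because the decision depends only on the version and not on wall-clock time, every PS instance saves at the same versions even when they advance at different speeds.

```python
def should_save_checkpoint(version: int, checkpoint_steps: int) -> bool:
    """Return True when `version` is a positive multiple of `checkpoint_steps`.

    Hypothetical helper: a checkpoint_steps of 0 (or less) disables checkpointing.
    """
    if checkpoint_steps <= 0:
        return False
    return version > 0 and version % checkpoint_steps == 0


# Whatever its speed, each PS instance saves at versions 100, 200, 300, ...
saved = [v for v in range(1, 301) if should_save_checkpoint(v, 100)]
print(saved)  # [100, 200, 300]
```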

Where to save a checkpoint

To save checkpoints, we need to set checkpoint_dir in the arguments. Each checkpoint is saved to a subdirectory built from checkpoint_dir and the iteration step, which we call the model version in ElasticDL, e.g. {checkpoint_dir}/version_{version}/.

import os

# Each checkpoint version gets its own subdirectory: {checkpoint_dir}/version_{version}
checkpoint_folder_name = "version_%s" % version
checkpoint_dir = os.path.join(args.checkpoint_dir, checkpoint_folder_name)

ElasticDL is a Kubernetes-native framework, and the local directories of a pod are lost when the pod exits. So checkpoint_dir should be a persistent directory that survives pod exits. Currently, we can use hostPath or PersistentVolumes to provide a persistent directory for checkpoints in ElasticDL.

How to save model parameters to a checkpoint directory

With AllReduceStrategy in ElasticDL, each worker holds all parameters, so we can assign any one worker to save them to a checkpoint directory. With ParameterServerStrategy, the model parameters are partitioned across all PS instances, and each PS instance needs to save its parameter shard to the same checkpoint directory at the same step. Saving checkpoints under AllReduceStrategy can therefore be seen as the special case of ParameterServerStrategy with a single PS instance, and we can use the same approach for both strategies.
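The unified save path can be sketched as follows. This is an illustration under stated assumptions: `save_shard` is a hypothetical helper, pickle stands in for the Model protobuf that ElasticDL actually writes, and the file name follows this design's variables-{ps_id}-of-{ps_num}.chkpt layout. Under AllReduceStrategy, one designated worker calls it with ps_id=0 and ps_num=1; under ParameterServerStrategy, every PS instance calls it with its own index.

```python
import os
import pickle
import tempfile


def save_shard(checkpoint_root, version, ps_id, ps_num, params):
    """Write one shard of model parameters for a checkpoint version.

    `params` maps variable names to values. pickle is a stand-in for the
    Model protobuf (assumption for illustration).
    """
    version_dir = os.path.join(checkpoint_root, "version_%s" % version)
    # Several PS instances may create the same directory concurrently.
    os.makedirs(version_dir, exist_ok=True)
    path = os.path.join(version_dir, "variables-%d-of-%d.chkpt" % (ps_id, ps_num))
    with open(path, "wb") as f:
        pickle.dump(params, f)
    return path


root = tempfile.mkdtemp()

# AllReduceStrategy: one worker holds every parameter and acts as the only "PS".
save_shard(root, 100, 0, 1, {"dense/kernel": [0.1, 0.2], "dense/bias": [0.0]})

# ParameterServerStrategy: each of two PS instances writes its own shard.
save_shard(root, 200, 0, 2, {"dense/kernel": [0.1, 0.2]})
save_shard(root, 200, 1, 2, {"dense/bias": [0.0]})

print(sorted(os.listdir(os.path.join(root, "version_200"))))
# ['variables-0-of-2.chkpt', 'variables-1-of-2.chkpt']
```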

The parameters on each PS instance may include non-embedding variables and a portion of the embedding vectors of each embedding table. We create a Tensor protobuf for each variable, holding its values, or for the embedding vectors of a table, holding both values and indices. To save an embedding table, we also create an EmbeddingTableInfo to store meta information such as initializer and dim. We then create a Model protobuf containing those Tensors and EmbeddingTableInfos, and save the Model to a file named variables-{ps_id}-of-{ps_num}.chkpt, where ps_id is the index of the PS instance and ps_num is the total number of PS instances. Each checkpoint version has its own subdirectory version_{version} holding these files. Once the iteration versions of all PS instances have exceeded {version}, the version_{version} subdirectory contains the files for the entire set of model parameters.
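The packing step and the completeness condition can be sketched in plain Python. The dataclasses below are hypothetical stand-ins for the Tensor, EmbeddingTableInfo and Model protobuf messages, and their field names are assumptions; `build_shard_model` and `is_version_complete` are illustrative helpers, not ElasticDL functions.

```python
import os
import tempfile
from dataclasses import dataclass, field
from typing import List, Optional


# Stand-ins for the protobuf messages; field names are assumptions.
@dataclass
class Tensor:
    name: str
    values: list
    indices: Optional[List[int]] = None  # set only for embedding vectors


@dataclass
class EmbeddingTableInfo:
    name: str
    dim: int
    initializer: str


@dataclass
class Model:
    version: int
    params: List[Tensor] = field(default_factory=list)
    embedding_table_infos: List[EmbeddingTableInfo] = field(default_factory=list)


def build_shard_model(version, dense_vars, embedding_tables, table_infos):
    """Pack one PS instance's state into a Model.

    dense_vars: name -> values of a non-embedding variable.
    embedding_tables: table name -> {embedding id -> vector} held by this PS.
    table_infos: table name -> EmbeddingTableInfo (dim, initializer).
    """
    model = Model(version=version)
    for name, values in dense_vars.items():
        model.params.append(Tensor(name=name, values=values))
    for name, vectors in embedding_tables.items():
        ids = sorted(vectors)
        model.params.append(
            Tensor(name=name, values=[vectors[i] for i in ids], indices=ids)
        )
        model.embedding_table_infos.append(table_infos[name])
    return model


def is_version_complete(version_dir, ps_num):
    """True once every PS instance has written its shard for this version."""
    expected = {"variables-%d-of-%d.chkpt" % (i, ps_num) for i in range(ps_num)}
    return expected.issubset(os.listdir(version_dir))


model = build_shard_model(
    version=100,
    dense_vars={"dense/bias": [0.0]},
    embedding_tables={"user_emb": {7: [0.1, 0.2], 3: [0.3, 0.4]}},
    table_infos={"user_emb": EmbeddingTableInfo("user_emb", 2, "uniform")},
)
print([(p.name, p.indices) for p in model.params])
# [('dense/bias', None), ('user_emb', [3, 7])]

version_dir = tempfile.mkdtemp()
open(os.path.join(version_dir, "variables-0-of-2.chkpt"), "w").close()
print(is_version_complete(version_dir, 2))  # False: PS 1 has not saved yet
```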

How many recent checkpoints to keep

We can set keep_checkpoint_max in ElasticDL to limit the number of recent checkpoints that are kept. After saving a checkpoint, each PS instance checks how many checkpoints it has saved. If the number exceeds keep_checkpoint_max, it removes its own shard file from the version_{version} subdirectory with the smallest version, and then removes that version_{version} subdirectory if it is empty.
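The retention rule can be sketched per PS instance as below. `ShardCheckpointKeeper` and `write_shard` are hypothetical helpers, and treating keep_checkpoint_max <= 0 as "keep everything" is an assumption. Each instance deletes only its own shard file; the shared version_{version} directory disappears once the last instance has removed its shard, because os.rmdir only succeeds on an empty directory.

```python
import os
import tempfile
from collections import deque


class ShardCheckpointKeeper:
    """Track checkpoint versions saved by one PS instance and prune old ones."""

    def __init__(self, checkpoint_root, ps_id, ps_num, keep_checkpoint_max):
        self.checkpoint_root = checkpoint_root
        self.shard_name = "variables-%d-of-%d.chkpt" % (ps_id, ps_num)
        self.keep_checkpoint_max = keep_checkpoint_max
        self.saved_versions = deque()  # oldest version on the left

    def on_checkpoint_saved(self, version):
        self.saved_versions.append(version)
        # keep_checkpoint_max <= 0 is treated as "keep everything" (assumption).
        while 0 < self.keep_checkpoint_max < len(self.saved_versions):
            self._remove(self.saved_versions.popleft())

    def _remove(self, version):
        version_dir = os.path.join(self.checkpoint_root, "version_%s" % version)
        shard_path = os.path.join(version_dir, self.shard_name)
        if os.path.exists(shard_path):
            os.remove(shard_path)
        try:
            os.rmdir(version_dir)  # succeeds only if the directory is now empty
        except OSError:
            pass  # other PS instances still have shards here


def write_shard(root, version, ps_id, ps_num):
    """Hypothetical stand-in for saving a real shard: creates an empty file."""
    version_dir = os.path.join(root, "version_%s" % version)
    os.makedirs(version_dir, exist_ok=True)
    name = "variables-%d-of-%d.chkpt" % (ps_id, ps_num)
    open(os.path.join(version_dir, name), "w").close()


root = tempfile.mkdtemp()
keepers = [ShardCheckpointKeeper(root, i, 2, keep_checkpoint_max=2) for i in range(2)]
for version in (100, 200, 300):
    for ps_id, keeper in enumerate(keepers):
        write_shard(root, version, ps_id, 2)
        keeper.on_checkpoint_saved(version)

print(sorted(os.listdir(root)))  # ['version_200', 'version_300']
```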

Restore Model Parameters from a Checkpoint

The number of PS instances may differ when we restore model parameters to train or predict, so we need to repartition the variables and embedding vectors in the checkpoint files across the new PS instances. Each new PS instance traverses all protobuf files in the checkpoint directory and keeps a variable if the hash of its name (hash_utils.string_to_id) equals the instance's index (ps_id), and an embedding vector if the hash of its ID (hash_utils.int_to_id) equals ps_id. The pseudo-code is:

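import os

# checkpoint_dir, ps_id, ps_num (the new number of PS instances), hash_utils
# and load_model_file are assumed to be in scope.
non_embedding_vars = {}
embedding_table_vectors = {}
for pb_file in os.listdir(checkpoint_dir):
    model_pb = load_model_file(os.path.join(checkpoint_dir, pb_file))
    for param in model_pb.params:
        if param.indices:
            # Embedding vectors: keep only those this PS instance now owns.
            ids, vectors = embedding_table_vectors.setdefault(param.name, ([], []))
            for embedding_id, vector in zip(param.indices, param.values):
                if hash_utils.int_to_id(embedding_id, ps_num) == ps_id:
                    ids.append(embedding_id)
                    vectors.append(vector)
        elif hash_utils.string_to_id(param.name, ps_num) == ps_id:
            # Non-embedding variable owned by this PS instance.
            non_embedding_vars[param.name] = param.values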
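A runnable sketch of the same repartitioning, under stated assumptions: the old checkpoint files are modeled as in-memory dicts instead of protobuf files, and `string_to_id` / `int_to_id` are hypothetical stand-ins for the hash_utils functions (SHA-256 modulo ps_num for names, plain modulo for IDs). It shows a checkpoint written by 2 PS instances being restored onto 3, with every parameter landing on exactly one new instance.

```python
import hashlib


# Hypothetical stand-ins for hash_utils: map a name or an ID to one of ps_num instances.
def string_to_id(name, ps_num):
    return int(hashlib.sha256(name.encode("utf-8")).hexdigest(), 16) % ps_num


def int_to_id(embedding_id, ps_num):
    return embedding_id % ps_num


def repartition(old_shards, ps_id, ps_num):
    """Collect the parameters that new PS instance `ps_id` owns.

    old_shards: one dict per old checkpoint file with keys
      "dense": {name: values}
      "embedding": {table name: {embedding id: vector}}
    """
    non_embedding_vars = {}
    embedding_table_vectors = {}
    for shard in old_shards:
        for name, values in shard["dense"].items():
            if string_to_id(name, ps_num) == ps_id:
                non_embedding_vars[name] = values
        for table, vectors in shard["embedding"].items():
            owned = embedding_table_vectors.setdefault(table, {})
            for embedding_id, vector in vectors.items():
                if int_to_id(embedding_id, ps_num) == ps_id:
                    owned[embedding_id] = vector
    return non_embedding_vars, embedding_table_vectors


# A checkpoint written by 2 PS instances, restored onto 3.
old_shards = [
    {"dense": {"dense/kernel": [0.1]}, "embedding": {"emb": {0: [1.0], 2: [3.0], 4: [5.0]}}},
    {"dense": {"dense/bias": [0.0]}, "embedding": {"emb": {1: [2.0], 3: [4.0], 5: [6.0]}}},
]
new_ps = [repartition(old_shards, i, 3) for i in range(3)]
print([sorted(tables["emb"]) for _, tables in new_ps])  # [[0, 3], [1, 4], [2, 5]]
```

Because each new instance applies the same deterministic hash, no coordination between instances is needed during restore; the cost is that every instance reads every checkpoint file.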