ElasticDL Training Checkpoints Design
This document describes the design of saving checkpoints for ElasticDL training.
Motivation
A checkpoint stores the state of a training process; it includes not only
model variables but also training states such as global_step. In ElasticDL,
we need to save all parameters of the model to checkpoints during training
for the following reasons:
- We can export the model with the best evaluation performance during training from checkpoints. Training for too many epochs or steps can lead to over-fitting on the training dataset. Early stopping is widely used to avoid over-fitting; it stops training when the evaluation performance stops improving or even gets worse. So, the final parameters may not be the best ones. With checkpoints, we can choose the checkpoint with the best evaluation metrics and export it for serving.
- We can restore the model parameters from a checkpoint to train or predict. Fault tolerance in ElasticDL cannot handle all unexpected failures, e.g. a power outage or a fatal OS fault. If such a failure happens before we have saved a single checkpoint, all training progress is lost. Even without an unforeseen error, we might want to resume training from a particular state for a new experiment, or try different things starting from a given state.
In the following sections, we will describe the design of how to save checkpoints and restore model parameters from a checkpoint in ElasticDL.
Design Components
This design contains two key parts: export and restore.
Export Model Parameters to a Checkpoint
When to save a checkpoint during training
In ElasticDL, we save all model parameters to a checkpoint directory every
checkpoint_steps iterations. Some frameworks, e.g. Estimator, support saving
checkpoints either every N steps or every N seconds. With
ParameterServerStrategy in ElasticDL, communication with the PS is
asynchronous, and we cannot guarantee that all PS instances are at the same
iteration step at any given time. So, we do not support saving checkpoints
every N seconds.
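As a minimal sketch, a PS instance could decide whether to save a checkpoint
at a given model version like this (the helper name need_to_checkpoint is an
assumption for illustration, not the actual ElasticDL API):

def need_to_checkpoint(version, checkpoint_steps):
    # Save a checkpoint whenever the model version reaches a multiple of
    # checkpoint_steps; a non-positive checkpoint_steps disables saving.
    return checkpoint_steps > 0 and version % checkpoint_steps == 0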
Where to save a checkpoint
We need to set checkpoint_dir in the command-line arguments if we want to save
checkpoints. The directory of a saved checkpoint is generated from
checkpoint_dir and the iteration step, which we call the model version in
ElasticDL, e.g. /{checkpoint_dir}/version_{version}/.
# The subdirectory for this checkpoint version under args.checkpoint_dir.
checkpoint_folder_name = "version_%s" % version
checkpoint_dir = os.path.join(
    args.checkpoint_dir, checkpoint_folder_name
)
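For example, with checkpoint_dir set to a hypothetical path /data/ckpts and
version 1000, the checkpoint files of this version go into
/data/ckpts/version_1000/.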
ElasticDL is a Kubernetes-native framework, and a local directory in a pod is
lost when the pod exits. So checkpoint_dir should be a persistent directory
that survives pod termination. Currently, we can utilize hostPath or
PersistentVolumes to get a persistent directory to save checkpoints in
ElasticDL.
How to save model parameters to a checkpoint directory
With AllReduceStrategy in ElasticDL, each worker holds all parameters, so we can assign any worker to save them to a checkpoint directory. With ParameterServerStrategy, the model parameters are partitioned across all PS instances, and each PS instance needs to save its parameter shard to the same checkpoint directory at the same step. Saving checkpoints for AllReduceStrategy can be seen as the special case of ParameterServerStrategy with a single PS instance. So, we can adopt the same approach to save checkpoints for both strategies.
The parameters on each PS instance may contain non-embedding variables and a
part of the embedding vectors of each embedding table. We create a Tensor
protobuf with values for each variable, or with values and indices for
embedding vectors. To save an embedding table, we also need to create an
EmbeddingTableInfo to save meta information such as initializer and dim.
Then, we create a Model protobuf including those Tensors and
EmbeddingTableInfos. Finally, we save the Model to a file named
variables-{ps_id}-of-{ps_num}.chkpt, where ps_id is the index of the PS
instance and ps_num is the total number of PS instances. The checkpoint file
is located in a subdirectory version_{version} for each checkpoint version.
After the iteration versions of all PS instances exceed the checkpoint
version, the version_{version} subdirectory contains the files of the entire
model parameters.
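The following is a minimal sketch of how a PS instance might write its
serialized Model protobuf into the versioned subdirectory. The helper name
save_parameters_to_checkpoint and the model_pb argument are assumptions for
illustration, not the actual ElasticDL API:

import os

def save_parameters_to_checkpoint(model_pb, checkpoint_dir, version, ps_id, ps_num):
    # Write this PS instance's parameter shard to
    # {checkpoint_dir}/version_{version}/variables-{ps_id}-of-{ps_num}.chkpt.
    version_dir = os.path.join(checkpoint_dir, "version_%s" % version)
    os.makedirs(version_dir, exist_ok=True)
    file_name = "variables-%s-of-%s.chkpt" % (ps_id, ps_num)
    with open(os.path.join(version_dir, file_name), "wb") as f:
        f.write(model_pb.SerializeToString())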
How many recent checkpoints to keep
We can set keep_checkpoint_max in ElasticDL to determine the maximum number of
recent checkpoints to keep. After saving a checkpoint, each PS instance checks
the number of checkpoints it has saved; if the number exceeds
keep_checkpoint_max, it removes its own shard file from the version_{version}
subdirectory with the smallest version. Then, it removes that
version_{version} subdirectory if it is empty.
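A minimal sketch of this cleanup on one PS instance follows. The helper name
clean_old_checkpoints and the saved_versions bookkeeping list are assumptions
for illustration:

import os

def clean_old_checkpoints(checkpoint_dir, saved_versions, keep_checkpoint_max, ps_id, ps_num):
    # saved_versions: versions this PS instance has saved, in ascending order.
    while keep_checkpoint_max > 0 and len(saved_versions) > keep_checkpoint_max:
        version = saved_versions.pop(0)  # the smallest (oldest) version
        version_dir = os.path.join(checkpoint_dir, "version_%s" % version)
        shard_file = os.path.join(
            version_dir, "variables-%s-of-%s.chkpt" % (ps_id, ps_num)
        )
        if os.path.exists(shard_file):
            os.remove(shard_file)
        # Remove the version_{version} subdirectory once it is empty.
        if os.path.isdir(version_dir) and not os.listdir(version_dir):
            os.rmdir(version_dir)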
Restore Model Parameters from a Checkpoint
The number of PS instances may vary when we restore model parameters to train
or predict. So, we need to repartition the variables and embedding vectors in
the checkpoint files onto the new PS instances. Each new PS instance traverses
all protobuf files in the checkpoint directory and keeps a variable or an
embedding vector if hash_utils.string_to_id(var.name) or
hash_utils.int_to_id(embedding_id) is equal to its own index ps_id. The
pseudo-code is:
non_embedding_vars = {}
embedding_table_vectors = {}
for pb_file in os.listdir(checkpoint_dir):
    model_pb = load_model_file(os.path.join(checkpoint_dir, pb_file))
    for param in model_pb.params:
        if param.indices:
            # Embedding vectors: keep only the ids hashed to this PS instance.
            ids, vectors = embedding_table_vectors.setdefault(param.name, ([], []))
            for embedding_id, vector in zip(param.indices, param.values):
                if hash_utils.int_to_id(embedding_id) == ps_id:
                    ids.append(embedding_id)
                    vectors.append(vector)
        else:
            # Non-embedding variable: keep it if its name hashes to this PS instance.
            if hash_utils.string_to_id(param.name) == ps_id:
                non_embedding_vars[param.name] = param.values
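After the traversal, each PS instance initializes its parameter shard from the
collected values. The sketch below assumes numpy-compatible values and
hypothetical ps_embedding_tables and ps_variables containers on the PS side;
it is an illustration, not the actual ElasticDL API:

import numpy as np

# Hypothetical PS-side containers keyed by parameter name.
for name, (ids, vectors) in embedding_table_vectors.items():
    # Stack the embedding vectors kept for this PS instance and store them
    # under their original embedding ids.
    ps_embedding_tables[name].set(np.array(ids), np.vstack(vectors))
for name, values in non_embedding_vars.items():
    # Restore each non-embedding variable owned by this PS instance.
    ps_variables[name].assign(values)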