DEEP LEARNING BEST PRACTICES: CHECKPOINTING YOUR DEEP LEARNING MODEL TRAINING
This article covers one of many best practices in deep learning: creating checkpoints while training your model. We will look at what needs to be saved when creating checkpoints, why checkpoints are needed (especially on NUS HPC systems), methods of creating them, how to create them in various deep learning frameworks (Keras, TensorFlow, PyTorch), and their benefits.
What Needs to be Saved
A neural network is described by its architecture (the layers and how they are configured), its training settings and, once it has been trained, the learned weights and biases of its layers. The architecture and training settings are called hyperparameters; the learned weights and biases are called parameters. Hyperparameters are user defined, while parameters are learned (obtained by training the neural network on data).
Hyperparameters
Hyperparameters define how long the neural network is trained (number of epochs), the batch size, the learning rate, the optimisation function and the configuration of the network (layers, hidden units per layer, activation functions, etc.). During the training phase of a deep learning project, you might want to tune these hyperparameters to get the best performance out of the model.
Parameters
The training phase of a deep learning project is highly repetitive and monotonous, but it produces the most important output of the project: the trained model parameters (weights and biases). These parameters exist in memory (RAM or GPU memory), not on non-volatile storage, and they will not be saved unless you explicitly program your deep learning code to save them.
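To make the distinction concrete, here is a minimal, framework-independent sketch of writing both hyperparameters and parameters to disk. Plain Python lists stand in for weight tensors, and the names `save_snapshot`, `load_snapshot` and `snapshot.pkl` are illustrative assumptions, not part of any framework.

```python
import os
import pickle
import tempfile

# Hyperparameters: chosen by the user before training.
hyperparameters = {"epochs": 20, "batch_size": 128, "learning_rate": 0.01}

# Parameters: learned during training. Plain lists stand in for weight tensors.
parameters = {
    "dense_1/weights": [[0.12, -0.40], [0.33, 0.08]],
    "dense_1/biases": [0.0, 0.1],
}


def save_snapshot(path, hyperparameters, parameters):
    """Write both dictionaries to non-volatile storage in a single file."""
    with open(path, "wb") as f:
        pickle.dump({"hyperparameters": hyperparameters,
                     "parameters": parameters}, f)


def load_snapshot(path):
    """Read the snapshot back from disk."""
    with open(path, "rb") as f:
        return pickle.load(f)


snapshot_path = os.path.join(tempfile.mkdtemp(), "snapshot.pkl")
save_snapshot(snapshot_path, hyperparameters, parameters)

# Simulate the process ending: the in-memory copy is gone, the file is not.
del parameters
restored = load_snapshot(snapshot_path)
print(restored["parameters"]["dense_1/biases"])
```

Real frameworks use their own file formats for this, as the code snippets later in the article show.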
Why Checkpoints Are Needed
Checkpointing is the practice of saving a snapshot of your model parameters (weights) during training, typically after every epoch. It is like saving your progress in a game: you can resume by loading the save file at any time. In the same way, you can load the saved weights of the model and resume training, or run inference later.
The training phase for complex models is usually long (hours, days or even weeks). On NUS HPC systems, the GPU queue for deep learning has a default walltime limit of 24 hours and a maximum limit of 48 hours for job execution. Deep learning training jobs for complex models and large datasets might take longer to execute than the queue walltime limits allow.
Therefore, to avoid losing your training progress, it is advisable to checkpoint your model’s parameters (weights) either at every epoch, or only at those epochs where the weights are the best seen so far.
Having the most up-to-date or best weights saved on non-volatile storage is good practice because it keeps a copy of your progress at a given epoch in case you want to tune your hyperparameters from that point. It also allows you to resume training from any epoch that has a checkpoint. If the job or process terminates prematurely, you can resume training by loading the weights from the last saved checkpoint or from any other checkpoint.
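The resume-after-termination idea can be sketched without any deep learning framework. In the toy loop below, the `stop_after` argument simulates a job being killed by the walltime limit, and a second call picks up from the last saved epoch. All names (`save_checkpoint`, `load_checkpoint`, `train`, `last.json`) are hypothetical. Writing to a temporary file before renaming is an extra precaution assumed here, not something the article prescribes.

```python
import json
import os
import tempfile


def save_checkpoint(path, epoch, weights):
    """Write to a temporary file, then rename, so a job killed mid-write
    does not leave a truncated checkpoint behind."""
    tmp_path = path + ".tmp"
    with open(tmp_path, "w") as f:
        json.dump({"epoch": epoch, "weights": weights}, f)
    os.replace(tmp_path, path)


def load_checkpoint(path):
    """Return (next_epoch, weights), or a fresh start if no checkpoint exists."""
    if not os.path.exists(path):
        return 0, [0.0, 0.0]
    with open(path) as f:
        state = json.load(f)
    return state["epoch"], state["weights"]


def train(path, num_epochs, stop_after=None):
    """Toy training loop; stop_after simulates the job being terminated."""
    start_epoch, weights = load_checkpoint(path)
    for epoch in range(start_epoch, num_epochs):
        if stop_after is not None and epoch == stop_after:
            return epoch, weights  # simulated walltime limit
        weights = [w + 1.0 for w in weights]  # stand-in for one epoch of training
        save_checkpoint(path, epoch + 1, weights)
    return num_epochs, weights


ckpt = os.path.join(tempfile.mkdtemp(), "last.json")
train(ckpt, num_epochs=10, stop_after=4)                 # first job: 4 epochs, then killed
final_epoch, final_weights = train(ckpt, num_epochs=10)  # second job resumes at epoch 4
print(final_epoch, final_weights)
```

The second call trains only the remaining six epochs, so the final weights are the same as an uninterrupted ten-epoch run would produce.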
Methods of Creating Checkpoints
Saving weights at every epoch can take up a lot of storage space if your model is highly complex and has many learnable parameters (e.g. VGG16). To reduce the storage used for checkpointing, you can save only the best weights seen so far. Keras has this option built in; in other libraries and frameworks you have to implement it yourself.
The model is evaluated after each epoch, and the weights are saved only if they give the highest accuracy or the lowest loss so far (the metric is chosen by the user). If the weights at a given epoch do not improve on the best accuracy or loss, they are not saved, but training still continues from that state.
Using this method, only the best weights are saved, keeping storage usage to a minimum.
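The best-only decision comes down to a comparison against the best value seen so far. The sketch below uses made-up validation losses and a hypothetical helper `is_improvement`. A real training loop would write the weights to disk at the point marked in the comment.

```python
def is_improvement(current, best, mode):
    """Return True if `current` beats `best` in the given direction."""
    if best is None:          # nothing saved yet, so the first result always counts
        return True
    if mode == "max":         # e.g. validation accuracy
        return current > best
    if mode == "min":         # e.g. validation loss
        return current < best
    raise ValueError("mode must be 'min' or 'max'")


val_losses = [0.90, 0.70, 0.75, 0.60, 0.65]  # made-up values, one per epoch
best_loss = None
saved_epochs = []

for epoch, val_loss in enumerate(val_losses, start=1):
    if is_improvement(val_loss, best_loss, mode="min"):
        best_loss = val_loss
        saved_epochs.append(epoch)  # a real loop would save the weights here

print(saved_epochs)  # [1, 2, 4]
```

Epochs 3 and 5 are skipped because their loss is worse than the best value recorded before them.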
How to Create Checkpoints (Code Snippets)
Keras
Keras provides built-in checkpointing, so unlike with the other frameworks/libraries there is no need to implement it yourself. It is provided as a callback called ModelCheckpoint that is passed to the fit method for training.
Save Checkpoint
If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity (e.g. validation loss).
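You can set how often checkpoints are created with the period parameter (the number of epochs between checkpoints). Recent tf.keras releases deprecate period in favour of save_freq.

    from keras.callbacks import ModelCheckpoint

    # Save the full model whenever val_loss improves, checking every epoch
    checkpointer = ModelCheckpoint(filepath='/tmp/weights.hdf5',
                                   monitor='val_loss',
                                   verbose=1,
                                   save_best_only=True,
                                   save_weights_only=False,
                                   mode='auto',
                                   period=1)

    # model, x_train, y_train, x_test and y_test are assumed to be defined already
    model.fit(x_train, y_train,
              batch_size=128, epochs=20, verbose=0,
              validation_data=(x_test, y_test),
              callbacks=[checkpointer])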
• filepath: string, path to save the model file
• monitor: quantity to monitor
• verbose: verbosity mode, 0 or 1
• save_best_only: if save_best_only=True, the latest best model according to the quantity monitored will not be overwritten
• mode: one of {auto, min, max}. If save_best_only=True, the decision to overwrite the current save file is made based on either the maximization or the minimization of the monitored quantity. For val_acc, this should be max, for val_loss this should be min, etc. In auto mode, the direction is automatically inferred from the name of the monitored quantity
• save_weights_only: if True, then only the model’s weights will be saved (model.save_weights(filepath)), else the full model is saved (model.save(filepath))
• period: Interval (number of epochs) between checkpoints
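With a fixed filepath such as '/tmp/weights.hdf5', every checkpoint is written to the same file. The Keras documentation notes that filepath may contain named formatting options, such as {epoch:02d} and {val_loss:.2f}, which are filled in from the epoch number and the logged metrics, so each checkpoint gets its own file. The sketch below uses plain Python string formatting with made-up values to show what such a template expands to. It does not call Keras itself.

```python
# Template with named fields, in the style accepted by ModelCheckpoint's filepath.
filepath_template = "weights.{epoch:02d}-{val_loss:.2f}.hdf5"

# Hypothetical (epoch, validation loss) pairs.
example_logs = [(3, 0.4567), (12, 0.3)]

filenames = [filepath_template.format(epoch=epoch, val_loss=val_loss)
             for epoch, val_loss in example_logs]

for name in filenames:
    print(name)
# weights.03-0.46.hdf5
# weights.12-0.30.hdf5
```

Keeping one file per checkpoint costs more storage, which is the trade-off discussed under Methods of Creating Checkpoints.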
Loading Saved Weights/Checkpoint
Loading saved weights is very straightforward as well. A single load_weights() call on a defined model loads the saved layer weights into it.

    import keras
    from keras.optimizers import SGD

    # Assuming loaded_model is your model definition before compilation
    # Load weights
    loaded_model.load_weights("final_weights.h5")

    # Compile model (lr is the learning rate you choose)
    opt = SGD(lr=lr)
    loaded_model.compile(loss=keras.losses.categorical_crossentropy,
                         optimizer=opt,
                         metrics=['accuracy'])

Keras makes it quick and easy to set up checkpointing for your model. It can save best-only weights according to a monitored quantity, for example a higher validation accuracy or a lower validation loss. You can set the epoch interval between checkpoints and choose whether to save the weights only or the full model. It takes one line of code to set up, plus one argument to the fit function to enable it during training.
Tensorflow
In TensorFlow, you can use Keras through tf.keras, in which case checkpointing works the same way as in Keras. If you use the TensorFlow Low Level API, Eager Execution (without tf.keras for model definition) or custom Estimators, you have to implement checkpointing yourself in the training loop.
With tf.Session (Low Level API) Training Method
    import logging
    import os

    import tensorflow as tf

    # train_model_spec, eval_model_spec, params, model_dir, restore_from, train_sess,
    # evaluate_sess, train_writer, eval_writer and save_dict_to_json are defined
    # elsewhere in the CS230 example code.

    # Initialize tf.Saver instances to save weights during training
    last_saver = tf.train.Saver()  # will keep last 5 epochs
    best_saver = tf.train.Saver(max_to_keep=1)  # only keep 1 best checkpoint (best on eval)
    begin_at_epoch = 0

    with tf.Session() as sess:
        # Initialize model variables
        sess.run(train_model_spec['variable_init_op'])

        # Reload weights from directory if specified
        if restore_from is not None:
            logging.info("Restoring parameters from {}".format(restore_from))
            if os.path.isdir(restore_from):
                restore_from = tf.train.latest_checkpoint(restore_from)
                begin_at_epoch = int(restore_from.split('-')[-1])
            last_saver.restore(sess, restore_from)

        best_eval_acc = 0.0
        for epoch in range(begin_at_epoch, begin_at_epoch + params.num_epochs):
            # Run one epoch
            logging.info("Epoch {}/{}".format(epoch + 1, begin_at_epoch + params.num_epochs))
            # Compute number of batches in one epoch (one full pass over the training set)
            num_steps = (params.train_size + params.batch_size - 1) // params.batch_size
            train_sess(sess, train_model_spec, num_steps, train_writer, params)

            # Save weights
            last_save_path = os.path.join(model_dir, 'last_weights', 'after-epoch')
            last_saver.save(sess, last_save_path, global_step=epoch + 1)

            # Evaluate for one epoch on validation set
            num_steps = (params.eval_size + params.batch_size - 1) // params.batch_size
            metrics = evaluate_sess(sess, eval_model_spec, num_steps, eval_writer)

            # If best_eval, best_save_path
            eval_acc = metrics['accuracy']
            if eval_acc >= best_eval_acc:
                # Store new best accuracy
                best_eval_acc = eval_acc
                # Save weights
                best_save_path = os.path.join(model_dir, 'best_weights', 'after-epoch')
                best_save_path = best_saver.save(sess, best_save_path, global_step=epoch + 1)
                logging.info("- Found new best accuracy, saving in {}".format(best_save_path))
                # Save best eval metrics in a json file in the model directory
                best_json_path = os.path.join(model_dir, "metrics_eval_best_weights.json")
                save_dict_to_json(metrics, best_json_path)

            # Save latest eval metrics in a json file in the model directory
            last_json_path = os.path.join(model_dir, "metrics_eval_last_weights.json")
            save_dict_to_json(metrics, last_json_path)
From CS230 Code Example
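The snippet above recovers the starting epoch with int(restore_from.split('-')[-1]). This works because tf.train.Saver.save appends "-<global_step>" to the path it is given, so a checkpoint saved after epoch 7 has a prefix ending in "after-epoch-7". The following sketch isolates that parsing step. The function name `epoch_from_checkpoint_path` and the example directory are illustrative assumptions.

```python
import os


def epoch_from_checkpoint_path(checkpoint_path):
    """Extract the trailing global step from a path such as '.../after-epoch-7'."""
    return int(checkpoint_path.split("-")[-1])


# Hypothetical checkpoint prefix, shaped like the ones written by the snippet above.
example_path = os.path.join("experiments", "base_model", "last_weights", "after-epoch-7")
begin_at_epoch = epoch_from_checkpoint_path(example_path)
print(begin_at_epoch)  # 7
```

The parsing relies on the checkpoint name ending in the step number. If the naming scheme changes, this function has to change with it.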
With Tensorflow Estimators
If you are using models from Official Tensorflow Model Implementations GitHub repository (https://github.com/tensorflow/models/), the models have built in checkpointing.
Estimators automatically write the following to disk:
• checkpoints, which are versions of the model created during training
• event files, which contain information that TensorBoard uses to create visualisations
To specify the top-level directory in which the Estimator stores its information, assign a value to the optional model_dir argument of any Estimator’s constructor. If you are running a premade training script directly, the directory can be set with --model_dir or --train_dir, depending on the script.
For coding it directly:

    import tensorflow as tf

    # my_feature_columns is assumed to be defined already
    classifier = tf.estimator.DNNClassifier(
        feature_columns=my_feature_columns,
        hidden_units=[10, 10],
        n_classes=3,
        model_dir='models/iris')

DNNClassifier is only an example here; the constructor of any other Estimator accepts model_dir in the same way.
TF Estimators Checkpoint Frequency
By default, the Estimator saves checkpoints in the model_dir according to the following schedule:
• Writes a checkpoint every 10 minutes (600 seconds)
• Writes a checkpoint when the train method starts (first iteration) and completes (final iteration)
• Retains only the 5 most recent checkpoints in the directory
You can change the default schedule by taking the following steps:
- Create a tf.estimator.RunConfig object that defines the desired schedule
- When instantiating the Estimator, pass that RunConfig object to the Estimator’s config argument
The code snippet below changes the checkpoint interval to 20 minutes and retains the 10 most recent checkpoints.

    import tensorflow as tf

    my_checkpointing_config = tf.estimator.RunConfig(
        save_checkpoints_secs=20 * 60,  # Save checkpoints every 20 minutes.
        keep_checkpoint_max=10,         # Retain the 10 most recent checkpoints.
    )

    # my_feature_columns is assumed to be defined already
    classifier = tf.estimator.DNNClassifier(
        feature_columns=my_feature_columns,
        hidden_units=[10, 10],
        n_classes=3,
        model_dir='models/iris',
        config=my_checkpointing_config)
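To illustrate what a retention limit like keep_checkpoint_max does, here is a small framework-independent sketch of a "keep the N most recent" policy. It is not how TensorFlow implements it internally. The class `CheckpointRotation` and the file naming are assumptions made for illustration.

```python
import os
import tempfile
from collections import deque


class CheckpointRotation:
    """Keep at most `keep_max` checkpoint files, deleting the oldest first."""

    def __init__(self, directory, keep_max):
        self.directory = directory
        self.keep_max = keep_max
        self.kept = deque()  # paths of retained checkpoints, oldest on the left

    def save(self, step, payload):
        path = os.path.join(self.directory, "model.ckpt-{}".format(step))
        with open(path, "w") as f:
            f.write(payload)  # stand-in for serialised weights
        self.kept.append(path)
        # Drop the oldest files once the limit is exceeded.
        while len(self.kept) > self.keep_max:
            os.remove(self.kept.popleft())
        return path


directory = tempfile.mkdtemp()
rotation = CheckpointRotation(directory, keep_max=3)
for step in range(1, 6):
    rotation.save(step, "weights at step {}".format(step))

remaining = sorted(os.listdir(directory))
print(remaining)  # ['model.ckpt-3', 'model.ckpt-4', 'model.ckpt-5']
```

A larger limit gives more points to resume from at the cost of more storage.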
Resuming from a TF Estimator Checkpoint
The first time you call an Estimator’s train method, TensorFlow saves a checkpoint to the model_dir. Each subsequent call to the Estimator’s train, evaluate, or predict method causes the following:
- The Estimator builds the model’s graph by running the model_fn().
- The Estimator initialises the weights of the new model from the data stored in the most recent checkpoint.
Once checkpoints exist, TensorFlow rebuilds the model each time you call train(), evaluate(), or predict().
Pytorch
Implementing checkpointing in PyTorch is similar to doing so with the TensorFlow Low Level API: you have to save the weights yourself in the training loop, or in the train-evaluate loop if you want to save best-only weights based on validation. Below are functions that help with saving and loading checkpoints.
Save Checkpoint
    import os
    import shutil

    import torch

    def save_checkpoint(state, is_best, checkpoint):
        """Saves model and training parameters at checkpoint + 'last.pth.tar'. If is_best==True, also saves
        checkpoint + 'best.pth.tar'

        Args:
            state: (dict) contains model's state_dict, may contain other keys such as epoch, optimizer state_dict
            is_best: (bool) True if it is the best model seen till now
            checkpoint: (string) folder where parameters are to be saved
        """
        filepath = os.path.join(checkpoint, 'last.pth.tar')
        if not os.path.exists(checkpoint):
            print("Checkpoint Directory does not exist! Making directory {}".format(checkpoint))
            os.mkdir(checkpoint)
        else:
            print("Checkpoint Directory exists! ")
        # Always save the latest state; keep a separate copy if it is the best so far
        torch.save(state, filepath)
        if is_best:
            shutil.copyfile(filepath, os.path.join(checkpoint, 'best.pth.tar'))
Load Checkpoint
    import os

    import torch

    def load_checkpoint(checkpoint, model, optimizer=None):
        """Loads model parameters (state_dict) from the checkpoint file. If optimizer is provided, loads state_dict of
        optimizer assuming it is present in checkpoint.

        Args:
            checkpoint: (string) filename which needs to be loaded
            model: (torch.nn.Module) model for which the parameters are loaded
            optimizer: (torch.optim) optional: resume optimizer from checkpoint
        """
        if not os.path.exists(checkpoint):
            # Raising a plain string is a TypeError in Python 3, so raise a real exception
            raise FileNotFoundError("File doesn't exist {}".format(checkpoint))
        checkpoint = torch.load(checkpoint)
        model.load_state_dict(checkpoint['state_dict'])

        if optimizer:
            optimizer.load_state_dict(checkpoint['optim_dict'])

        return checkpoint

In the train-evaluate loop, you will have to implement a few things:
- Whether to resume from a checkpoint
- Training and evaluating every epoch, and tracking the best validation accuracy (or loss)
- Saving the weights when an epoch produces the best validation accuracy (or loss), or simply saving every epoch
A small snippet for determining whether to restore from a checkpoint:

    # args holds the parsed command-line arguments; utils contains load_checkpoint
    if args.restore_file is not None:
        restore_path = os.path.join(args.model_dir, args.restore_file + '.pth.tar')
        logging.info("Restoring parameters from {}".format(restore_path))
        utils.load_checkpoint(restore_path, model, optimizer)

In the train-evaluate loop, you train for one full pass over the training set and then evaluate on the validation data. The evaluation gives you the validation accuracy, and you keep track of the best value seen so far. The code below saves the latest weights at every epoch and keeps an additional copy marked as best whenever the validation accuracy is the best so far.

    # model, optimizer, loss_fn, the dataloaders, metrics, params and model_dir are defined elsewhere
    best_val_acc = 0.0  # track best validation accuracy (or loss) outside of loop

    for epoch in range(params.num_epochs):
        # Run one epoch
        logging.info("Epoch {}/{}".format(epoch + 1, params.num_epochs))

        # Train for one epoch (one full pass over the training set)
        train(model, optimizer, loss_fn, train_dataloader, metrics, params)

        # Evaluate for one epoch on validation set
        val_metrics = evaluate(model, loss_fn, val_dataloader, metrics, params)

        # Get validation accuracy and track best validation accuracy
        val_acc = val_metrics['accuracy']
        is_best = val_acc >= best_val_acc

        # Save latest weights; also copied as best if validation accuracy is best
        utils.save_checkpoint({'epoch': epoch + 1,
                               'state_dict': model.state_dict(),
                               'optim_dict': optimizer.state_dict()},
                              is_best=is_best,
                              checkpoint=model_dir)

        # If best_eval, best_save_path
        if is_best:
            logging.info("- Found new best accuracy")
            best_val_acc = val_acc

            # Save best val metrics in a json file in the model directory
            best_json_path = os.path.join(model_dir, "metrics_val_best_weights.json")
            utils.save_dict_to_json(val_metrics, best_json_path)
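Both the TensorFlow and PyTorch snippets call a save_dict_to_json helper that the article does not show. One plausible minimal version is sketched below, assuming the metrics dictionary maps names to scalar values. The example metric values are made up.

```python
import json
import os
import tempfile


def save_dict_to_json(metrics, json_path):
    """Save a dict of metric name -> number as a JSON file."""
    # float() converts NumPy or tensor scalars, which json cannot serialise directly.
    serialisable = {name: float(value) for name, value in metrics.items()}
    with open(json_path, "w") as f:
        json.dump(serialisable, f, indent=4)


# Usage with made-up validation metrics.
json_path = os.path.join(tempfile.mkdtemp(), "metrics_val_best_weights.json")
save_dict_to_json({"accuracy": 0.91, "loss": 0.27}, json_path)

with open(json_path) as f:
    loaded_metrics = json.load(f)
print(loaded_metrics)
```

Storing the metrics next to the checkpoint makes it easy to see later which validation result a saved set of weights corresponds to.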
Benefits of Checkpointing
Just as people save the document they are working on or the game they are currently playing, checkpointing saves you from the disappointment, anger, sadness and other negative emotions that come with losing valuable training progress. This matters most when training a complex model on a large dataset, which can take weeks: losing progress partway through might mean losing several days’ worth of work.
You can resume training from any point in time in the training phase as long as a checkpoint exists for that point in time. You can use a lower learning rate, freeze some layers or even change your optimiser when resuming training from a checkpoint.
In short, use or implement checkpointing for training your deep learning models to prevent disappointment.
Resources
[1] Parameters vs Hyperparameters https://www.coursera.org/lecture/neural-networks-deep-learning/parameters-vs-hyperparameters-TBvb5
[2] Keras ModelCheckpoint https://keras.io/callbacks/#example-model-checkpoints
[3] Tensorflow/Keras ModelCheckpoint https://www.tensorflow.org/tutorials/keras/save_and_restore_models
[4] Tensorflow Estimators Checkpoint https://www.tensorflow.org/guide/checkpoints
[5] Tensorflow Low Level API Saved Model https://www.tensorflow.org/guide/saved_model
[6] Tensorflow Eager Execution Checkpoint https://www.tensorflow.org/api_docs/python/tf/contrib/eager/save_network_checkpoint
[7] Saving and Loading Models (Coding TensorFlow) https://www.youtube.com/watch?v=HxtBIwfy0kM
[8] Pytorch Checkpoint https://pytorch.org/docs/stable/checkpoint.html
[9] Pytorch Save https://pytorch.org/docs/stable/torch.html?highlight=save#torch.save