﻿from tensorflow.keras import callbacks


class ParallelModelCheckpoint(callbacks.ModelCheckpoint):
    """``ModelCheckpoint`` variant that always checkpoints a fixed model.

    The model given to the constructor is stored as ``single_model`` and is
    the one handed to the parent callback in :meth:`set_model`, regardless
    of which model the training loop attaches the callback to.

    NOTE(review): presumably intended for multi-device training, where the
    model being fitted is a parallel wrapper and ``model`` is the underlying
    template model whose weights should be saved -- confirm with callers.

    Args:
        model: The model to checkpoint; kept as ``self.single_model``.
        filepath: Forwarded unchanged to ``ModelCheckpoint``.
        monitor: Forwarded unchanged to ``ModelCheckpoint``.
        verbose: Forwarded unchanged to ``ModelCheckpoint``.
        save_best_only: Forwarded unchanged to ``ModelCheckpoint``.
        save_weights_only: Forwarded unchanged to ``ModelCheckpoint``.
        mode: Forwarded unchanged to ``ModelCheckpoint``.
        save_freq: Forwarded unchanged to ``ModelCheckpoint``.
        **kwargs: Any additional keyword arguments accepted by the installed
            ``ModelCheckpoint`` version; forwarded unchanged.
    """

    def __init__(self, model, filepath, monitor='val_loss', verbose=0,
                 save_best_only=False, save_weights_only=False,
                 mode='auto', save_freq='epoch', **kwargs):
        self.single_model = model
        # Forward by keyword rather than by position so that the arguments
        # cannot be mis-bound if the parent signature differs between
        # Keras/TensorFlow releases.
        super(ParallelModelCheckpoint, self).__init__(
            filepath,
            monitor=monitor,
            verbose=verbose,
            save_best_only=save_best_only,
            save_weights_only=save_weights_only,
            mode=mode,
            save_freq=save_freq,
            **kwargs
        )

    def set_model(self, model):
        """Attach ``self.single_model`` to the callback.

        The ``model`` argument supplied by the caller is intentionally
        ignored; the model passed to the constructor is used instead.
        """
        super(ParallelModelCheckpoint, self).set_model(self.single_model)
