model
ModelWrapper class
dlf.core.model.ModelWrapper(
model,
preprocessing=None,
optimizer=None,
loss=None,
model_weights=None,
additional_models=[],
is_finetune=False,
)
Base class that wraps a model together with all objects and methods needed to train it.
Args
- model: tf.keras.models.Model. The model to wrap.
- preprocessing: callable, optional. A preprocessing function to apply to the input data. Defaults to None.
- optimizer: tf.keras.optimizers.Optimizer, optional. The optimizer used for training. Defaults to None.
- loss: dict[str,dict[str, Any]], optional. Configuration from which the loss functions used for training are built. Defaults to None.
init_checkpoint method
ModelWrapper.init_checkpoint(skip_optimizer=False)
Initializes a checkpoint object used for saving and restoring the model.
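The skip_optimizer flag is not documented further. The sketch below shows one plausible reading, assuming that skip_optimizer=True simply leaves the optimizer out of the objects tracked by the checkpoint. The helper checkpoint_objects is hypothetical and not part of dlf; only tf.train.Checkpoint, which takes the objects to track as keyword arguments, is a real TensorFlow API.

```python
def checkpoint_objects(model, optimizer=None, skip_optimizer=False):
    """Collect the objects a checkpoint should track (hypothetical helper).

    Assumption: skip_optimizer=True leaves the optimizer state out, so only
    the model weights are saved or restored.
    """
    objects = {"model": model}
    if optimizer is not None and not skip_optimizer:
        objects["optimizer"] = optimizer
    return objects


# With TensorFlow, the result is passed as keyword arguments, e.g.:
#   ckpt = tf.train.Checkpoint(**checkpoint_objects(model, optimizer))
```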
num_iterations method
ModelWrapper.num_iterations()
Returns the number of iterations the optimizer has performed so far.
Returns
int. Number of iterations
preprocess_input_wrapper method
ModelWrapper.preprocess_input_wrapper(x)
Model-specific preprocessing function.
This method normalizes an input image to the range [-1, 1] and is executed before the image is fed into the model.
Args
- x: tf.Tensor. Image which should be normalized
Returns
- tf.Tensor: normalized image
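A minimal sketch of the normalization, assuming 8-bit pixel values in [0, 255] (the input range is not stated above). It applies the same x / 127.5 - 1 scaling that Keras uses in its "tf" preprocessing mode; the name preprocess_input is chosen for illustration only.

```python
def preprocess_input(x):
    """Scale pixel values from [0, 255] to [-1, 1].

    Works with anything that supports / and - (floats, NumPy arrays,
    tf.Tensor). Integer tensors such as tf.uint8 images should be cast to a
    float dtype first, e.g. tf.cast(x, tf.float32).
    """
    return x / 127.5 - 1.0


print(preprocess_input(0.0), preprocess_input(255.0))  # -1.0 1.0
```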
restore_model method
ModelWrapper.restore_model(path, step=None)
Restores a checkpoint from the given path.
Args
- path: str. Directory in which the checkpoints to restore are stored.
- step: int, optional. Step of the checkpoint to restore. Defaults to None.
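How step is mapped to a file is not documented. Assuming the checkpoints are written by tf.train.CheckpointManager with its default checkpoint name "ckpt" (files such as ckpt-<step>.index), the prefix to restore could be resolved as sketched below. resolve_checkpoint is a hypothetical helper, and falling back to the newest checkpoint when step is None is also an assumption.

```python
import os
import re

# Index file written for every checkpoint, e.g. "ckpt-1200.index".
_CKPT_INDEX = re.compile(r"ckpt-(\d+)\.index")


def resolve_checkpoint(path, step=None):
    """Return the checkpoint prefix in `path` for `step` (newest if None)."""
    matches = (_CKPT_INDEX.fullmatch(name) for name in os.listdir(path))
    steps = sorted(int(m.group(1)) for m in matches if m)  # numeric order
    if not steps:
        raise FileNotFoundError(f"no checkpoints found in {path}")
    if step is None:
        step = steps[-1]
    elif step not in steps:
        raise FileNotFoundError(f"no checkpoint for step {step} in {path}")
    return os.path.join(path, f"ckpt-{step}")
```

The returned prefix is what tf.train.Checkpoint.restore expects, e.g. tf.train.Checkpoint(model=model).restore(resolve_checkpoint(path, step)).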
save_model method
ModelWrapper.save_model(path, step)
Saves the model as a checkpoint.
Args
- path: str. Directory in which the checkpoints are stored.
- step: int. Current training step.
tape_step method
ModelWrapper.tape_step(tape, loss_values)
Executed during training to backpropagate the loss values.
Args
- tape: tf.GradientTape. Tape used to compute the gradients.
- loss_values: dict[str, List[tf.Tensor]]. Dictionary mapping each loss name to its values.
Raises
- ValueError: If no optimizer is available
Returns
- List[tf.Tensor]: List of gradients
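A minimal sketch of the backpropagation step as a standalone function, not dlf's implementation. It assumes the trainable variables and the optimizer are passed in explicitly, whereas the method itself takes them from the wrapper. GradientTape.gradient accepts a list of targets and sums their gradients, so passing every loss value at once yields the gradient of the total loss.

```python
def tape_step(tape, loss_values, trainable_variables, optimizer):
    """Backpropagate all loss values (hypothetical standalone sketch).

    loss_values maps each loss name to a list of loss tensors.
    """
    if optimizer is None:
        raise ValueError("No optimizer available")
    # Flatten {"name": [t1, t2], ...} into a single list of targets.
    targets = [t for values in loss_values.values() for t in values]
    # GradientTape.gradient sums the gradients of a list of targets, so this
    # is the gradient of the total loss w.r.t. each trainable variable.
    grads = tape.gradient(targets, trainable_variables)
    optimizer.apply_gradients(zip(grads, trainable_variables))
    return grads
```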
training_step method
ModelWrapper.training_step(record)
Executed during training for each batch of data.
The batch is fed into the network, and the network is trained on it.
Args
- record: dict. Training data containing the keys 'x_batch' and 'y_batch'.
Raises
- ValueError: If the ModelWrapper contains no loss function
Returns
tuple(dict[str,Any], tf.Tensor): A dictionary with all assigned losses and the predicted logits
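The training flow can be sketched as a standalone function. This is a hypothetical version, not dlf's implementation: losses is assumed to map loss names to callables of the form fn(y_true, y_pred), and tape_factory is an injected parameter so the sketch can be exercised without TensorFlow; in real use it would be tf.GradientTape.

```python
def training_step(record, model, losses, optimizer, tape_factory):
    """One training step on a batch (hypothetical standalone sketch).

    losses maps loss names to callables fn(y_true, y_pred); tape_factory is
    tf.GradientTape in real use and is injected here only for illustration.
    """
    if not losses:
        raise ValueError("ModelWrapper contains no loss function")
    x, y = record["x_batch"], record["y_batch"]
    # Record the forward pass and the loss computation on the tape.
    with tape_factory() as tape:
        logits = model(x, training=True)
        loss_values = {name: [fn(y, logits)] for name, fn in losses.items()}
    # Gradients of all losses (summed) w.r.t. the trainable weights.
    targets = [t for values in loss_values.values() for t in values]
    grads = tape.gradient(targets, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))
    return loss_values, logits
```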
validation_step method
ModelWrapper.validation_step(record)
Executed during validation for each batch of data.
See training_step.
Args
- record: dict. Validation data containing the keys 'x_batch' and 'y_batch'.
Raises
- ValueError: If the ModelWrapper contains no loss function
Returns
tuple(dict[str,Any], tf.Tensor): A dictionary with all assigned losses and the predicted logits
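A hypothetical standalone sketch of a validation step, not dlf's implementation. It assumes losses maps loss names to callables fn(y_true, y_pred) and that, unlike training, no gradients are computed and no weights are updated. The model is called with training=False so that Keras layers such as dropout and batch normalization run in inference mode.

```python
def validation_step(record, model, losses):
    """One validation step on a batch (hypothetical standalone sketch).

    losses maps loss names to callables fn(y_true, y_pred). No gradients are
    computed and the weights are not updated.
    """
    if not losses:
        raise ValueError("ModelWrapper contains no loss function")
    x, y = record["x_batch"], record["y_batch"]
    # training=False runs layers such as dropout and batch normalization
    # in inference mode.
    logits = model(x, training=False)
    loss_values = {name: [fn(y, logits)] for name, fn in losses.items()}
    return loss_values, logits
```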