model

ModelWrapper class

dlf.core.model.ModelWrapper(
    model,
    preprocessing=None,
    optimizer=None,
    loss=None,
    model_weights=None,
    additional_models=[],
    is_finetune=False,
)

Base class which wraps all necessary objects and methods to make training feasible

Args

  • model: tf.keras.models.Model. The model which should be wrapped
  • preprocessing: callable, optional. A preprocessing function which is executed on the input before it is fed into the model. Defaults to None.
  • optimizer: tf.keras.optimizers.Optimizer, optional. An optimizer used for training. Defaults to None.
  • loss: dict[str,dict[str, Any]], optional. A loss function which should be built and used for training. Defaults to None.
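
Example

A minimal sketch of wrapping a Keras image classifier; the exact schema of the loss configuration dict is framework-specific, so it is left as None here.

import tensorflow as tf
from dlf.core.model import ModelWrapper

# Any tf.keras model can be wrapped; MobileNetV2 with random weights is used purely as an example.
keras_model = tf.keras.applications.MobileNetV2(weights=None, classes=10)

wrapper = ModelWrapper(
    model=keras_model,
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    # loss expects a dict[str, dict[str, Any]] describing the losses to build;
    # the concrete keys depend on the losses registered in dlf and are not shown here.
    loss=None,
)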

init_checkpoint method

ModelWrapper.init_checkpoint(skip_optimizer=False)

Initialize a checkpoint object for saving and restoring
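
Example

A short sketch, assuming the checkpoint object has to be created once before saving or restoring:

# Build the checkpoint object; skipping the optimizer state can be useful for inference-only use.
wrapper.init_checkpoint(skip_optimizer=True)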


num_iterations method

ModelWrapper.num_iterations()

Get the current number of iterations performed by the optimizer

Returns

int. Number of iterations
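
Example

A usage sketch, assuming the wrapper was constructed with an optimizer:

current_step = wrapper.num_iterations()
print(f"optimizer has performed {current_step} iterations")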


preprocess_input_wrapper method

ModelWrapper.preprocess_input_wrapper(x)

Model-specific preprocessing function

This method normalizes an input image to the range -1 to 1 and is executed before the image is fed into a model

Args

  • x: tf.Tensor. Image which should be normalized

Returns

  • tf.Tensor: normalized image
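
Example

A sketch of normalizing a decoded image; that the input holds raw pixel values in [0, 255] is an assumption, as the expected input range is not documented here, and the file name is a placeholder.

image = tf.io.decode_jpeg(tf.io.read_file("example.jpg"), channels=3)
image = tf.cast(image, tf.float32)                     # raw pixel values in [0, 255]
normalized = wrapper.preprocess_input_wrapper(image)   # values scaled to [-1, 1]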

restore_model method

ModelWrapper.restore_model(path, step=None)

Restores a checkpoint at a given path

Arguments

  • path: str. Path where the checkpoints to restore are stored
  • step: int, optional. Specifies the checkpoint at step X which should be restored. Defaults to None.
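
Example

A sketch of restoring either the latest checkpoint or a specific step; the directory name is a placeholder for wherever save_model wrote its checkpoints.

# Restore the latest checkpoint found under the given path.
wrapper.restore_model("checkpoints/")

# Restore the checkpoint written at a specific step instead.
wrapper.restore_model("checkpoints/", step=1000)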

save_model method

ModelWrapper.save_model(path, step)

Method which handles a request to save the model

Args

  • path: str. Path where the checkpoints should be stored
  • step: int. Current step
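
Example

A minimal sketch that saves a checkpoint at the optimizer's current iteration count; the target directory is a placeholder.

wrapper.save_model("checkpoints/", step=wrapper.num_iterations())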

tape_step method

ModelWrapper.tape_step(tape, loss_values)

Executed during training to backpropagate loss values

Args

  • tape: tf.GradientTape. Gradient tape which recorded the forward pass
  • loss_values: dict[str, List[tf.Tensor]]. Dictionary containing the name of a loss and the corresponding values

Raises

  • ValueError: If no optimizer is available

Returns

  • List[tf.Tensor]: List of gradients
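
Example

A sketch of a manual gradient step; that the wrapped Keras model is reachable as wrapper.model is an assumption, and x_batch, y_batch and the cross-entropy loss are placeholders.

with tf.GradientTape() as tape:
    logits = wrapper.model(x_batch, training=True)  # attribute name assumed
    loss_value = tf.reduce_mean(
        tf.keras.losses.sparse_categorical_crossentropy(y_batch, logits, from_logits=True)
    )

# Loss values are grouped by name, each mapping to a list of tensors.
gradients = wrapper.tape_step(tape, {"cross_entropy": [loss_value]})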

training_step method

ModelWrapper.training_step(record)

Executed during training for each batch of data

This method is executed during training for each batch which is fed into the network to train it.

Args

  • record: dict. Training-related data which contains the keys 'x_batch' and 'y_batch'

Raises

  • ValueError: If the ModelWrapper contains no loss function

Returns

tuple(dict[str,Any], tf.Tensor): A dictionary with all assigned losses, and the predicted logits
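
Example

A sketch of a simple training loop, assuming a dataset that yields dicts with the keys 'x_batch' and 'y_batch' as described above.

for record in dataset:  # e.g. a tf.data.Dataset of {'x_batch': ..., 'y_batch': ...} dicts
    losses, logits = wrapper.training_step(record)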


validation_step method

ModelWrapper.validation_step(record)

Executed during validation

See training_step

Args

  • record: dict. Training-related data which contains the keys 'x_batch' and 'y_batch'

Raises

  • ValueError: If the ModelWrapper contains no loss function

Returns

tuple(dict[str,Any], tf.Tensor): A dictionary with all assigned losses, and the predicted logits
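
Example

Analogous to the training loop above, a validation sketch under the same assumptions about the record structure.

for record in validation_dataset:
    losses, logits = wrapper.validation_step(record)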