SimCLR
SimCLR class
dlf.models.simclr.SimCLR(
    feature_extractor,
    num_projection_dimension=128,
    linear_evaluation=None,
    weight_decay=1e-06,
    use_batch_norm=True,
    model_weights=None,
    summary=False,
    optimizer=None,
    loss=None,
    **kwargs
)
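An implementation of SimCLR (A Simple Framework for Contrastive Learning of Visual Representations), a self-supervised contrastive learning method.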
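In SimCLR, the contrastive loss is not computed directly on the feature extractor's output. It is computed on the output of a small MLP projection head placed on top of that output. The NumPy sketch below illustrates such a head, with an output size of num_projection_dimension and optional batch normalization (use_batch_norm). It is not dlf's code. The names projection_head and batch_norm are hypothetical, the weights are random rather than trained, and batch norm is assumed to follow the hidden layer and to have no learned scale or shift.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    # Normalize each feature over the batch using batch statistics
    # (no learned scale/shift in this sketch).
    mean = x.mean(axis=0, keepdims=True)
    var = x.var(axis=0, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def projection_head(h, w1, w2, use_batch_norm=True):
    """Hypothetical two-layer MLP projection head mapping features h to z."""
    x = h @ w1                     # hidden layer
    if use_batch_norm:
        x = batch_norm(x)          # assumed position of batch norm
    x = np.maximum(x, 0.0)         # ReLU
    return x @ w2                  # output has num_projection_dimension columns


rng = np.random.default_rng(0)
batch, feat_dim, hidden_dim, num_projection_dimension = 8, 2048, 2048, 128
h = rng.standard_normal((batch, feat_dim))  # stand-in for ResNet50 features
w1 = rng.standard_normal((feat_dim, hidden_dim)) * 0.01
w2 = rng.standard_normal((hidden_dim, num_projection_dimension)) * 0.01
z = projection_head(h, w1, w2)
print(z.shape)  # (8, 128)
```

After pretraining, the SimCLR paper uses the representation before the projection head for downstream tasks such as linear evaluation.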
Aliases
- SimCLR
- sim_clr
- simclr
Arguments
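- feature_extractor: dict. Configuration of the feature extractor (backbone) network, for example model_name and input_shape as in the YAML configuration below.
- num_projection_dimension: int. Output dimensionality of the projection head used for the unsupervised contrastive objective. Defaults to 128.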
- linear_evaluation: dict, optional. Linear evaluation options. If None, linear evaluation is disabled. Defaults to None. Supported keys:
  - num_classes: int. Number of classes for linear evaluation.
  - linear_train_at_step: int. Training step at which linear evaluation starts.
  - freeze_base_network: bool. If True, the feature extractor is not trainable during linear evaluation. Defaults to True.
- weight_decay: float. Amount of weight decay to use. Defaults to 1e-6.
- use_batch_norm: bool. If True, batch normalization is used in the projection head. Defaults to True.
- model_weights: str, optional. Path to pretrained model weights. Defaults to None.
- summary: bool, optional. If True, a summary of the model is printed to stdout. Defaults to False.
- optimizer: list of dict, optional. Optimizer configurations used for training. In the YAML configuration below, the first entry is used for unsupervised SimCLR training and the second for linear evaluation. Defaults to None.
- loss: list of dict, optional. List of loss objects to build for this model. Defaults to None.
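The linear_evaluation rules above are: None disables linear evaluation, freeze_base_network defaults to True, and linear evaluation starts at linear_train_at_step. The sketch below shows one way to resolve those options. It is a sketch under stated assumptions, not dlf's implementation. The helpers resolve_linear_evaluation and linear_eval_active are hypothetical. Treating num_classes and linear_train_at_step as required keys, and making the start step inclusive, are also assumptions.

```python
from typing import Optional

# Documented default for the only optional key of linear_evaluation.
LINEAR_EVAL_DEFAULTS = {"freeze_base_network": True}
# Assumption: these keys have no documented default, so treat them as required.
REQUIRED_KEYS = {"num_classes", "linear_train_at_step"}


def resolve_linear_evaluation(linear_evaluation: Optional[dict]) -> Optional[dict]:
    """Hypothetical helper: None means disabled, otherwise fill in defaults."""
    if linear_evaluation is None:
        return None  # linear evaluation disabled
    missing = REQUIRED_KEYS - set(linear_evaluation)
    if missing:
        raise ValueError(f"linear_evaluation is missing keys: {sorted(missing)}")
    # User-supplied values override the defaults.
    return {**LINEAR_EVAL_DEFAULTS, **linear_evaluation}


def linear_eval_active(options: Optional[dict], step: int) -> bool:
    """True once training reaches linear_train_at_step (assumed inclusive)."""
    return options is not None and step >= options["linear_train_at_step"]


opts = resolve_linear_evaluation({"num_classes": 62, "linear_train_at_step": 5000})
print(opts["freeze_base_network"])                                 # True
print(linear_eval_active(opts, 4999), linear_eval_active(opts, 5000))  # False True
```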
YAML Configuration
model:
  simclr:
    feature_extractor:
      model_name: ResNet50
      input_shape:
        - &width 224
        - &height 224
        - 3
    num_projection_dimension: 128
    linear_evaluation:
      num_classes: 62
      linear_train_at_step: 5000
      freeze_base_network: False
    num_classes: 62
    use_batch_norm: True
    linear_train_at_step: 5000
    weight_decay: 1.0e-6
    summary: True
    loss:
      NTXentLoss:
        batch_size: &batch_size 16
        temperature: 0.5
        use_cosine_similarity: True
      # Assumes categorical_labels is True in the reader; otherwise use SparseCategoricalCrossentropy
      CategoricalCrossentropy:
        from_logits: True
        reduction: "none"
    optimizer:
      - Adam: # learning rate for unsupervised SimCLR training
          learning_rate: 0.0001
      - SGD: # learning rate for linear evaluation
          learning_rate: 0.001
          momentum: 0.9
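The NTXentLoss entry above configures the normalized temperature-scaled cross-entropy (NT-Xent) loss from the SimCLR paper. The following NumPy sketch is written from the paper's definition, not from dlf's source. The function name nt_xent_loss is hypothetical, and mapping use_cosine_similarity=False to raw dot products is an assumption. Each of the N examples yields two views. For every one of the 2N projections, the other view of the same example is the positive, and the remaining 2N - 2 projections are negatives.

```python
import numpy as np


def nt_xent_loss(z_i, z_j, temperature=0.5, use_cosine_similarity=True):
    """Hypothetical NT-Xent sketch; z_i[k] and z_j[k] are two views of example k."""
    n = z_i.shape[0]
    z = np.concatenate([z_i, z_j], axis=0)  # (2N, d)
    if use_cosine_similarity:
        z = z / np.linalg.norm(z, axis=1, keepdims=True)
    # Assumption: without cosine similarity, raw dot products are used.
    sim = (z @ z.T) / temperature            # (2N, 2N) scaled similarities
    np.fill_diagonal(sim, -np.inf)           # exclude each sample from its own denominator
    # Row k's positive is the other view: k + N for the first half, k - N for the second.
    pos_idx = np.concatenate([np.arange(n, 2 * n), np.arange(0, n)])
    # Numerically stable log-sum-exp over each row.
    row_max = sim.max(axis=1, keepdims=True)
    log_denom = row_max[:, 0] + np.log(np.exp(sim - row_max).sum(axis=1))
    log_prob_pos = sim[np.arange(2 * n), pos_idx] - log_denom
    return -log_prob_pos.mean()              # average over all 2N anchors


rng = np.random.default_rng(0)
batch_size, dim = 16, 128                    # batch_size matches &batch_size above
z_i = rng.standard_normal((batch_size, dim))
z_aligned = z_i + 0.01 * rng.standard_normal((batch_size, dim))
z_random = rng.standard_normal((batch_size, dim))
print(nt_xent_loss(z_i, z_aligned) < nt_xent_loss(z_i, z_random))  # True
```

As a sanity check, if all projections are identical, every candidate is equally likely and the loss reduces to log(2N - 1). The temperature rescales similarities before the softmax, so smaller values sharpen the distinction between the positive and the negatives.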
References