detectron2.solver
- detectron2.solver.build_lr_scheduler(cfg: CfgNode, optimizer: Optimizer) -> _LRScheduler
Build a LR scheduler from config.
- detectron2.solver.build_optimizer(cfg: CfgNode, model: Module) -> Optimizer
Build an optimizer from config.
- detectron2.solver.get_default_optimizer_params(model: Module, base_lr: Optional[float] = None, weight_decay: Optional[float] = None, weight_decay_norm: Optional[float] = None, bias_lr_factor: Optional[float] = 1.0, weight_decay_bias: Optional[float] = None, lr_factor_func: Optional[Callable] = None, overrides: Optional[Dict[str, Dict[str, float]]] = None) -> List[Dict[str, Any]]
Get the default param list for an optimizer, with support for a few types of overrides. If no overrides are needed, this is equivalent to model.parameters().
- Parameters:
base_lr – lr for every group by default. Can be omitted to use the one in optimizer.
weight_decay – weight decay for every group by default. Can be omitted to use the one in optimizer.
weight_decay_norm – override weight decay for params in normalization layers
bias_lr_factor – multiplier of lr for bias parameters.
weight_decay_bias – override weight decay for bias parameters.
lr_factor_func – function to calculate the lr decay rate by mapping parameter names to their corresponding lr decay rate. Note that setting this option also requires setting base_lr.
overrides – if not None, provides values for optimizer hyperparameters (LR, weight decay) for module parameters with a given name; e.g. {"embedding": {"lr": 0.01, "weight_decay": 0.1}} will set the LR and weight decay values for all module parameters named embedding.
For common detection models, weight_decay_norm is the only option that needs to be set. bias_lr_factor and weight_decay_bias are legacy settings from Detectron1 that have not been found useful.
Example:
torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0), lr=0.01, weight_decay=1e-4, momentum=0.9)
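The way the defaults, the normalization-layer override, and the name-based overrides combine can be sketched in plain Python. This is a simplified illustration and not detectron2's implementation: `resolve_hyperparams` is a hypothetical helper, parameters are represented as `(full_name, in_norm_layer)` tuples instead of a real `Module`, and it assumes `overrides` are matched against the last component of the parameter name and applied last.

```python
from typing import Any, Dict, List, Optional, Tuple


def resolve_hyperparams(
    params: List[Tuple[str, bool]],  # (full parameter name, belongs to a norm layer?)
    base_lr: Optional[float] = None,
    weight_decay: Optional[float] = None,
    weight_decay_norm: Optional[float] = None,
    overrides: Optional[Dict[str, Dict[str, float]]] = None,
) -> List[Dict[str, Any]]:
    """Hypothetical helper: compute per-parameter optimizer hyperparameters."""
    overrides = overrides or {}
    groups = []
    for full_name, in_norm_layer in params:
        hyperparams: Dict[str, Any] = {}
        # Defaults are only set when given; otherwise the optimizer's own apply.
        if base_lr is not None:
            hyperparams["lr"] = base_lr
        if weight_decay is not None:
            hyperparams["weight_decay"] = weight_decay
        # Normalization layers can get their own weight decay.
        if in_norm_layer and weight_decay_norm is not None:
            hyperparams["weight_decay"] = weight_decay_norm
        # Assumption: overrides are keyed by the parameter's own name
        # (e.g. "weight", "bias", "embedding") and take precedence.
        short_name = full_name.rsplit(".", 1)[-1]
        hyperparams.update(overrides.get(short_name, {}))
        groups.append({"name": full_name, **hyperparams})
    return groups


groups = resolve_hyperparams(
    [
        ("backbone.conv1.weight", False),
        ("backbone.bn1.weight", True),
        ("head.embedding", False),
    ],
    weight_decay_norm=0.0,
    overrides={"embedding": {"lr": 0.01, "weight_decay": 0.1}},
)
```

In this sketch the convolution weight keeps the optimizer's defaults, the normalization weight gets `weight_decay=0.0`, and the parameter named `embedding` gets both overridden values.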
- class detectron2.solver.LRMultiplier(optimizer: Optimizer, multiplier: ParamScheduler, max_iter: int, last_iter: int = -1)
Bases: _LRScheduler
An LRScheduler which uses fvcore ParamScheduler to multiply the learning rate of each param in the optimizer. Every step, the learning rate of each parameter becomes its initial value multiplied by the output of the given ParamScheduler.
The absolute learning rate value of each parameter can be different. This scheduler can be used as long as the relative scale among them does not change during training.
Examples:
    LRMultiplier(
        opt,
        WarmupParamScheduler(
            MultiStepParamScheduler(
                [1, 0.1, 0.01],
                milestones=[60000, 80000],
                num_updates=90000,
            ),
            0.001,
            100 / 90000,
        ),
        max_iter=90000,
    )
- __init__(optimizer: Optimizer, multiplier: ParamScheduler, max_iter: int, last_iter: int = -1)
- Parameters:
optimizer, last_iter – See torch.optim.lr_scheduler.LRScheduler. last_iter is the same as last_epoch.
multiplier – a fvcore ParamScheduler that defines the multiplier on every LR of the optimizer
max_iter – the total number of training iterations
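The rule that each learning rate equals its initial value times the multiplier can be sketched without any dependencies. This is a minimal illustration and not the class itself: `multiplied_lrs` and `halve_late` are hypothetical names, and the sketch assumes the multiplier is queried with the training progress `cur_iter / max_iter`.

```python
from typing import Callable, List


def multiplied_lrs(
    base_lrs: List[float],
    multiplier: Callable[[float], float],
    cur_iter: int,
    max_iter: int,
) -> List[float]:
    """Scale each group's initial LR by the multiplier at the current progress."""
    where = cur_iter / max_iter  # training progress as a fraction
    factor = multiplier(where)
    return [base_lr * factor for base_lr in base_lrs]


def halve_late(where: float) -> float:
    """Hypothetical multiplier: 1.0 for the first half of training, 0.5 after."""
    return 1.0 if where < 0.5 else 0.5


base_lrs = [0.02, 0.002]  # two param groups with different absolute LRs
early = multiplied_lrs(base_lrs, halve_late, cur_iter=0, max_iter=90000)
late = multiplied_lrs(base_lrs, halve_late, cur_iter=80000, max_iter=90000)
```

Because every group is scaled by the same factor, the ratio between the groups' learning rates stays fixed throughout training, which is the condition under which this scheduler is applicable.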
- detectron2.solver.LRScheduler
alias of _LRScheduler
- class detectron2.solver.WarmupParamScheduler(scheduler: ParamScheduler, warmup_factor: float, warmup_length: float, warmup_method: str = 'linear', rescale_interval: bool = False)
Add an initial warmup stage to another scheduler.
- __init__(scheduler: ParamScheduler, warmup_factor: float, warmup_length: float, warmup_method: str = 'linear', rescale_interval: bool = False)
- Parameters:
scheduler – warmup will be added at the beginning of this scheduler
warmup_factor – the factor w.r.t. the initial value of scheduler, e.g. 0.001
warmup_length – the relative length (in [0, 1]) of warmup steps w.r.t. the entire training, e.g. 0.01
warmup_method – one of “linear” or “constant”
rescale_interval – whether we will rescale the interval of the scheduler after warmup
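How these parameters interact can be sketched as a plain function wrapper. This is an illustration under stated assumptions, not the library's code: `with_warmup` and `multi_step` are hypothetical names; linear warmup is assumed to interpolate from `warmup_factor * scheduler(0)` to the value reached when warmup ends; and `rescale_interval=True` is assumed to replay the wrapped scheduler over the remaining fraction of training.

```python
from typing import Callable

Schedule = Callable[[float], float]  # maps training progress in [0, 1) to a value


def with_warmup(
    scheduler: Schedule,
    warmup_factor: float,
    warmup_length: float,
    warmup_method: str = "linear",
    rescale_interval: bool = False,
) -> Schedule:
    """Hypothetical helper: prepend a warmup stage to `scheduler`."""
    if warmup_method not in ("linear", "constant"):
        raise ValueError(f"Unknown warmup method: {warmup_method}")
    start = warmup_factor * scheduler(0.0)
    # Value to reach at the end of warmup (assumption, see lead-in).
    end = scheduler(0.0) if rescale_interval else scheduler(warmup_length)

    def schedule(where: float) -> float:
        if where < warmup_length:
            if warmup_method == "constant":
                return start
            alpha = where / warmup_length  # progress within the warmup stage
            return start + (end - start) * alpha
        if rescale_interval:
            # Stretch the wrapped schedule over the post-warmup interval.
            return scheduler((where - warmup_length) / (1.0 - warmup_length))
        return scheduler(where)

    return schedule


def multi_step(where: float) -> float:
    """Values [1, 0.1, 0.01] with milestones at 60000 and 80000 of 90000 updates."""
    cur_iter = where * 90000
    if cur_iter < 60000:
        return 1.0
    if cur_iter < 80000:
        return 0.1
    return 0.01


sched = with_warmup(multi_step, warmup_factor=0.001, warmup_length=100 / 90000)
const = with_warmup(multi_step, 0.001, 0.01, warmup_method="constant")
```

With these settings `sched` starts at 0.001, rises linearly to 1.0 over the first 100 of 90000 iterations, and then follows the multi-step values; `const` holds 0.001 for the first 1% of training before switching to the wrapped schedule.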