detectron2.config
Related tutorials: Yacs Configs, Extend Detectron2’s Defaults.
- class detectron2.config.CfgNode(init_dict=None, key_list=None, new_allowed=False)
Bases: fvcore.common.config.CfgNode
The same as fvcore.common.config.CfgNode, but differs in the following ways:
Uses unsafe YAML loading by default. Note that this may lead to arbitrary code execution: you must not load a config file from an untrusted source without first manually inspecting its contents.
Supports config versioning. When merging an old config, it converts the old config automatically.
- classmethod load_yaml_with_base(filename: str, allow_unsafe: bool = False) → Dict[str, Any]
Just like yaml.load(open(filename)), but inherits attributes from the config file named by its _BASE_ key.
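To make the inheritance rule concrete, here is a small stdlib-only sketch. It is not detectron2's implementation: it assumes the YAML files have already been parsed into dicts held in a hypothetical `files` mapping (the real method reads files from disk), looks up `_BASE_` by name in that mapping, and merges the child's keys recursively over its base so that child values win.

```python
import copy
from typing import Any, Dict

BASE_KEY = "_BASE_"


def merge_a_into_b(a: Dict[str, Any], b: Dict[str, Any]) -> None:
    """Recursively merge dict `a` into dict `b`; values from `a` win."""
    for key, value in a.items():
        if isinstance(value, dict) and isinstance(b.get(key), dict):
            merge_a_into_b(value, b[key])
        else:
            b[key] = value


def load_with_base(name: str, files: Dict[str, Dict[str, Any]]) -> Dict[str, Any]:
    """Hypothetical helper: `files` maps file names to already-parsed configs."""
    cfg = copy.deepcopy(files[name])  # never mutate the "file" contents
    base_name = cfg.pop(BASE_KEY, None)
    if base_name is None:
        return cfg
    base_cfg = load_with_base(base_name, files)  # a base may have its own base
    merge_a_into_b(cfg, base_cfg)
    return base_cfg


files = {
    "Base-RCNN.yaml": {"MODEL": {"MASK_ON": False, "DEVICE": "cuda"}, "SOLVER": {"BASE_LR": 0.02}},
    "mask_rcnn.yaml": {"_BASE_": "Base-RCNN.yaml", "MODEL": {"MASK_ON": True}},
}
cfg = load_with_base("mask_rcnn.yaml", files)
print(cfg)
```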
- merge_from_list(cfg_list: List[str]) → Callable[[], None]
- Parameters:
cfg_list (list) – list of configs to merge from.
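In merge_from_list, cfg_list is a flat list of alternating full keys and values, as typically passed on the command line (for example `["MODEL.WEIGHTS", "model.pth", "SOLVER.BASE_LR", "0.01"]`). The following sketch, written against plain nested dicts rather than CfgNode, shows that interpretation: values are parsed as Python literals when possible and kept as strings otherwise, and unknown keys are rejected. The type check that yacs performs against the existing value is omitted here.

```python
from ast import literal_eval
from typing import Any, Dict, List


def decode_value(text: str) -> Any:
    """Parse a Python literal if possible; otherwise keep the raw string."""
    try:
        return literal_eval(text)
    except (ValueError, SyntaxError):
        return text


def merge_from_list(cfg: Dict[str, Any], cfg_list: List[str]) -> None:
    """Apply alternating ["FULL.KEY", "value", ...] pairs to a nested dict."""
    if len(cfg_list) % 2 != 0:
        raise ValueError("cfg_list must contain an even number of items")
    for full_key, raw in zip(cfg_list[0::2], cfg_list[1::2]):
        *parents, leaf = full_key.split(".")
        node = cfg
        for part in parents:
            node = node[part]  # KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"Non-existent config key: {full_key}")
        node[leaf] = decode_value(raw)


cfg = {"MODEL": {"WEIGHTS": "", "ROI_HEADS": {"NUM_CLASSES": 80}}, "SOLVER": {"STEPS": (30000,)}}
merge_from_list(cfg, [
    "MODEL.WEIGHTS", "model_final.pth",        # not a literal, stays a string
    "MODEL.ROI_HEADS.NUM_CLASSES", "3",        # parsed as int
    "SOLVER.STEPS", "(20000, 25000)",          # parsed as tuple
])
```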
- merge_from_other_cfg(cfg_other: CfgNode) → Callable[[], None]
- Parameters:
cfg_other (CfgNode) – configs to merge from.
- dump(*args, **kwargs)
- Returns:
str – a yaml string representation of the config
- detectron2.config.get_cfg() → CfgNode
Get a copy of the default config.
- Returns:
a detectron2 CfgNode instance.
- detectron2.config.set_global_cfg(cfg: CfgNode) → None
Let the global config point to the given cfg.
Assuming the given cfg has the key "KEY", after calling set_global_cfg(cfg) the key can be accessed by:
from detectron2.config import global_cfg
print(global_cfg.KEY)
By using a hacky global config, you can access these configs anywhere, without having to pass the config object or the values deep into the code. This is a hacky feature introduced for quick prototyping / research exploration.
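The pattern relies on updating the global object in place: code that ran `from detectron2.config import global_cfg` keeps a reference to one object, so rebinding the name inside the module would not be visible to it. A sketch of that pattern with a hypothetical `AttrDict` standing in for CfgNode (an illustration, not the library source):

```python
from typing import Any


class AttrDict(dict):
    """Minimal dict with attribute access, standing in for CfgNode."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc


# A single module-level object; other modules import *this object*.
global_cfg = AttrDict()


def set_global_cfg(cfg: dict) -> None:
    """Mutate the shared object in place instead of rebinding the name."""
    global_cfg.clear()
    global_cfg.update(cfg)


alias = global_cfg  # what `from module import global_cfg` would hold
set_global_cfg({"KEY": 42})
print(alias.KEY)  # 42
```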
- detectron2.config.configurable(init_func=None, *, from_config=None)
Decorate a function or a class’s __init__ method so that it can be called with a CfgNode object, using a from_config() function that translates the CfgNode into arguments.
Examples:
from detectron2.config import configurable

# In the calls below, `cfg` is assumed to be a CfgNode that defines the keys A and B.

# Usage 1: Decorator on __init__:
class A:
    @configurable
    def __init__(self, a, b=2, c=3):
        pass

    @classmethod
    def from_config(cls, cfg):  # 'cfg' must be the first argument
        # Returns kwargs to be passed to __init__
        return {"a": cfg.A, "b": cfg.B}

a1 = A(a=1, b=2)  # regular construction
a2 = A(cfg)  # construct with a cfg
a3 = A(cfg, b=3, c=4)  # construct with extra overwrite

# Usage 2: Decorator on any function. Needs an extra from_config argument:
@configurable(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
def a_func(a, b=2, c=3):
    pass

a1 = a_func(a=1, b=2)  # regular call
a2 = a_func(cfg)  # call with a cfg
a3 = a_func(cfg, b=3, c=4)  # call with extra overwrite
- Parameters:
init_func (callable) – a class’s __init__ method in usage 1. The class must have a from_config classmethod which takes cfg as the first argument.
from_config (callable) – the from_config function in usage 2. It must take cfg as its first argument.
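To show what the decorator does at call time, here is a simplified, hypothetical re-implementation of usage 2. It assumes a `SimpleNamespace` stands in for a CfgNode and treats a first positional argument of that type as "called with a cfg"; the real decorator also supports usage 1 and inspects function signatures, which this sketch does not.

```python
import functools
from types import SimpleNamespace
from typing import Any, Callable, Dict


def configurable_sketch(from_config: Callable[[Any], Dict[str, Any]]):
    """Hypothetical, simplified take on the decorator's "usage 2" form."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            if args and isinstance(args[0], SimpleNamespace):  # called with a cfg
                explicit = from_config(args[0])
                explicit.update(kwargs)  # extra keyword arguments take precedence
                return func(*args[1:], **explicit)
            return func(*args, **kwargs)  # regular call, untouched

        return wrapper

    return decorator


@configurable_sketch(from_config=lambda cfg: {"a": cfg.A, "b": cfg.B})
def a_func(a, b=2, c=3):
    return (a, b, c)


cfg = SimpleNamespace(A=10, B=20)  # stand-in for a CfgNode
print(a_func(a=1, b=2))       # (1, 2, 3)
print(a_func(cfg))            # (10, 20, 3)
print(a_func(cfg, b=3, c=4))  # (10, 3, 4)
```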
- detectron2.config.instantiate(cfg)
Recursively instantiate objects defined in dictionaries by “_target_” and arguments.
- Parameters:
cfg – a dict-like object with “_target_” that defines the caller, and other keys that define the arguments
- Returns:
object instantiated by cfg
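A stdlib-only sketch of the recursion, assuming plain dicts and lists instead of omegaconf objects. `_target_` may be a callable or a dotted import path, which this sketch resolves with `importlib` (an assumption about how such paths can be located). Nested dicts that carry their own `_target_` are built first and passed in as arguments; dicts without `_target_` are returned unchanged. The name `instantiate_sketch` is hypothetical.

```python
import importlib
from collections.abc import Mapping
from dataclasses import dataclass
from typing import Any


def _locate(path: str) -> Any:
    """Resolve a dotted path such as "fractions.Fraction" (simplified)."""
    module_name, _, attr = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr)


def instantiate_sketch(cfg: Any) -> Any:
    if isinstance(cfg, list):
        return [instantiate_sketch(x) for x in cfg]
    if isinstance(cfg, Mapping) and "_target_" in cfg:
        # Build the arguments first, so nested "_target_" dicts become objects.
        kwargs = {k: instantiate_sketch(v) for k, v in cfg.items() if k != "_target_"}
        target = cfg["_target_"]
        if isinstance(target, str):
            target = _locate(target)
        return target(**kwargs)
    return cfg  # plain values, and dicts without "_target_", pass through


@dataclass
class Conv:  # hypothetical building block
    in_channels: int
    out_channels: int


@dataclass
class Block:  # hypothetical container
    conv: Conv
    name: str = "block"


block = instantiate_sketch({
    "_target_": Block,
    "conv": {"_target_": Conv, "in_channels": 32, "out_channels": 64},
    "name": "stem",
})
frac = instantiate_sketch({"_target_": "fractions.Fraction", "numerator": 3, "denominator": 4})
```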
- class detectron2.config.LazyCall(target)
Bases: object
Wrap a callable so that when it’s called, the call is not executed; instead, a dict that describes the call is returned.
A LazyCall object has to be called with keyword arguments only. Positional arguments are not yet supported.
Examples:
from torch import nn
from detectron2.config import instantiate, LazyCall

layer_cfg = LazyCall(nn.Conv2d)(in_channels=32, out_channels=32)
layer_cfg.out_channels = 64  # can edit it afterwards
layer = instantiate(layer_cfg)
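The key idea is that calling a LazyCall only records the target and its keyword arguments. The sketch below models this with a hypothetical `LazyCallSketch` class, a `LazyDict` that allows attribute access (standing in for omegaconf's DictConfig), and a non-recursive `build` helper standing in for instantiate(). Because `__call__` accepts only `**kwargs`, positional arguments fail with a TypeError, matching the restriction stated above.

```python
from typing import Any, Callable


class LazyDict(dict):
    """Dict with attribute access, standing in for an omegaconf DictConfig."""

    def __getattr__(self, name: str) -> Any:
        try:
            return self[name]
        except KeyError as exc:
            raise AttributeError(name) from exc

    def __setattr__(self, name: str, value: Any) -> None:
        self[name] = value


class LazyCallSketch:
    def __init__(self, target: Callable[..., Any]):
        if not callable(target):
            raise TypeError(f"target must be callable, got {target!r}")
        self._target = target

    def __call__(self, **kwargs: Any) -> LazyDict:
        # Keyword-only by construction: LazyCallSketch(f)(1) raises TypeError.
        return LazyDict(kwargs, _target_=self._target)


def build(cfg: LazyDict) -> Any:
    """Non-recursive stand-in for instantiate()."""
    args = {k: v for k, v in cfg.items() if k != "_target_"}
    return cfg["_target_"](**args)


def conv(in_channels: int, out_channels: int) -> str:  # hypothetical layer factory
    return f"conv({in_channels}->{out_channels})"


layer_cfg = LazyCallSketch(conv)(in_channels=32, out_channels=32)
layer_cfg.out_channels = 64  # edit the description before anything is built
print(build(layer_cfg))  # conv(32->64)
```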
- class detectron2.config.LazyConfig
Bases: object
Provide methods to save, load, and override an omegaconf config object which may contain definitions of lazily-constructed objects.
- static apply_overrides(cfg, overrides: List[str])
Override the contents of cfg in place.
- Parameters:
cfg – an omegaconf config object
overrides – list of strings in the format of “a=b” to override configs. See https://hydra.cc/docs/next/advanced/override_grammar/basic/ for syntax.
- Returns:
the cfg object
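A rough sketch of the in-place semantics on plain nested dicts, assuming overrides of the simple form `dotted.key=value`. Values are parsed as Python literals with a fallback to strings; this is a simplification, since the real method follows the Hydra override grammar linked above (quoting, lists, and other syntax are not handled here). The name `apply_overrides_sketch` is hypothetical.

```python
from ast import literal_eval
from typing import Any, Dict, List


def parse_value(text: str) -> Any:
    try:
        return literal_eval(text)
    except (ValueError, SyntaxError):
        return text  # e.g. paths and names stay strings


def apply_overrides_sketch(cfg: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep:
            raise ValueError(f"override must look like 'a.b=value', got {item!r}")
        *parents, leaf = key.strip().split(".")
        node = cfg
        for part in parents:
            node = node.setdefault(part, {})  # missing sections are created
            if not isinstance(node, dict):
                raise KeyError(f"cannot update {key}: {part} is not a config")
        node[leaf] = parse_value(raw.strip())
    return cfg  # the same object, modified in place


cfg = {"train": {"max_iter": 90000, "output_dir": "./output"}, "optimizer": {"lr": 0.02}}
out = apply_overrides_sketch(cfg, ["train.max_iter=1000", "optimizer.lr=0.01", "train.output_dir=/tmp/run1"])
```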
- static load(filename: str, keys: Union[None, str, Tuple[str, ...]] = None)
Load a config file.
- Parameters:
filename – absolute path or relative path w.r.t. the current working directory
keys – keys to load and return. If not given, return all keys (whose values are config objects) in a dict.
- static load_rel(filename: str, keys: Union[None, str, Tuple[str, ...]] = None)
Similar to load(), but loads the path relative to the caller’s source file.
This has the same functionality as a relative import, except that this method accepts filename as a string, so more characters are allowed in the filename.
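One way to resolve a path relative to the caller's source file is to inspect the calling frame. The helper below is hypothetical and only resolves the path; it does not load anything, and it makes no claim about how load_rel itself is written.

```python
import inspect
import os


def resolve_relative_to_caller(filename: str) -> str:
    """Resolve `filename` against the directory of the calling source file."""
    caller_file = inspect.stack(context=0)[1].filename  # frame 1 = our caller
    caller_dir = os.path.dirname(os.path.abspath(caller_file))
    return os.path.normpath(os.path.join(caller_dir, filename))


path = resolve_relative_to_caller("../common/train.py")
print(path)
```

Unlike a relative import, the argument is an ordinary string, so a name such as `../common/train.py` is accepted as-is.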
- static save(cfg, filename: str)
Save a config object to a YAML file. Note that when the config dictionary contains complex objects (e.g. lambdas), it can’t be saved to YAML. In that case, an error is printed and the method attempts to save it to a pkl file instead.
- Parameters:
cfg – an omegaconf config object
filename – yaml file name to save the config file
- static to_py(cfg, prefix: str = 'cfg.')
Try to convert a config object into Python-like pseudo code.
Note that perfect conversion is not always possible, so the returned result is mainly meant to be human-readable, not executed.
- Parameters:
cfg – an omegaconf config object
prefix – root name for the resulting code (default: “cfg.”)
- Returns:
str of formatted Python code
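A toy version of the idea for plain nested dicts, assuming a hypothetical `to_py_sketch` helper: every leaf becomes a `prefix + dotted.key = repr(value)` line. The real method handles more cases (for example, lazily-constructed objects); as noted above, values whose repr is not valid Python (e.g. lambdas) make such output readable but not executable.

```python
from typing import Any, Dict, List


def to_py_sketch(cfg: Dict[str, Any], prefix: str = "cfg.") -> str:
    """Flatten a nested dict into `prefix.key = value` assignment lines."""
    lines: List[str] = []

    def walk(node: Dict[str, Any], path: str) -> None:
        for key, value in node.items():
            if isinstance(value, dict):
                walk(value, f"{path}{key}.")
            else:
                lines.append(f"{path}{key} = {value!r}")

    walk(cfg, prefix)
    return "\n".join(lines)


out = to_py_sketch({"model": {"device": "cuda", "pixel_std": [1.0, 1.0, 1.0]}, "train": {"max_iter": 90000}})
print(out)
```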
Yaml Config References
# -----------------------------------------------------------------------------
# Convention about Training / Test specific parameters
# -----------------------------------------------------------------------------
# Whenever an argument can be either used for training or for testing, the
# corresponding name will be post-fixed by a _TRAIN for a training parameter,
# or _TEST for a test-specific parameter.
# For example, the number of images during training will be
# IMAGES_PER_BATCH_TRAIN, while the number of images for testing will be
# IMAGES_PER_BATCH_TEST

# -----------------------------------------------------------------------------
# Config definition
# -----------------------------------------------------------------------------

_C = CN()

# The version number, to upgrade from old configs to new ones if any
# changes happen. It's recommended to keep a VERSION in your config file.
_C.VERSION = 2

_C.MODEL = CN()
_C.MODEL.LOAD_PROPOSALS = False
_C.MODEL.MASK_ON = False
_C.MODEL.KEYPOINT_ON = False
_C.MODEL.DEVICE = "cuda"
_C.MODEL.META_ARCHITECTURE = "GeneralizedRCNN"

# Path (a file path, or URL like detectron2://.., https://..) to a checkpoint file
# to be loaded to the model. You can find available models in the model zoo.
_C.MODEL.WEIGHTS = ""

# Values to be used for image normalization (BGR order, since INPUT.FORMAT defaults to BGR).
# To train on images with a different number of channels, just set different mean & std.
# Default values are the mean pixel value from ImageNet: [103.53, 116.28, 123.675]
_C.MODEL.PIXEL_MEAN = [103.530, 116.280, 123.675]
# When using pre-trained models in Detectron1 or any MSRA models,
# std has been absorbed into its conv1 weights, so the std needs to be set to 1.
# Otherwise, you can use [57.375, 57.120, 58.395] (ImageNet std)
_C.MODEL.PIXEL_STD = [1.0, 1.0, 1.0]


# -----------------------------------------------------------------------------
# INPUT
# -----------------------------------------------------------------------------
_C.INPUT = CN()
# By default, {MIN,MAX}_SIZE options are used in transforms.ResizeShortestEdge.
# Please refer to ResizeShortestEdge for detailed definition.
# Size of the smallest side of the image during training
_C.INPUT.MIN_SIZE_TRAIN = (800,)
# Sample size of smallest side by choice or random selection from range given by
# INPUT.MIN_SIZE_TRAIN
_C.INPUT.MIN_SIZE_TRAIN_SAMPLING = "choice"
# Maximum size of the side of the image during training
_C.INPUT.MAX_SIZE_TRAIN = 1333
# Size of the smallest side of the image during testing. Set to zero to disable resize in testing.
_C.INPUT.MIN_SIZE_TEST = 800
# Maximum size of the side of the image during testing
_C.INPUT.MAX_SIZE_TEST = 1333
# Mode for flipping images used in data augmentation during training
# choose one of ["horizontal", "vertical", "none"]
_C.INPUT.RANDOM_FLIP = "horizontal"

# `True` if cropping is used for data augmentation during training
_C.INPUT.CROP = CN({"ENABLED": False})
# Cropping type. See documentation of `detectron2.data.transforms.RandomCrop` for explanation.
_C.INPUT.CROP.TYPE = "relative_range"
# Size of crop in range (0, 1] if CROP.TYPE is "relative" or "relative_range" and in number of
# pixels if CROP.TYPE is "absolute"
_C.INPUT.CROP.SIZE = [0.9, 0.9]


# Whether the model needs RGB, YUV, HSV etc.
# Should be one of the modes defined here, as we use PIL to read the image:
# https://pillow.readthedocs.io/en/stable/handbook/concepts.html#concept-modes
# with BGR being the one exception. One can set image format to BGR, we will
# internally use RGB for conversion and flip the channels over
_C.INPUT.FORMAT = "BGR"
# The ground truth mask format that the model will use.
# Mask R-CNN supports either "polygon" or "bitmask" as ground truth.
_C.INPUT.MASK_FORMAT = "polygon"  # alternative: "bitmask"


# -----------------------------------------------------------------------------
# Dataset
# -----------------------------------------------------------------------------
_C.DATASETS = CN()
# List of the dataset names for training. Must be registered in DatasetCatalog
# Samples from these datasets will be merged and used as one dataset.
_C.DATASETS.TRAIN = ()
# List of the pre-computed proposal files for training, which must be consistent
# with datasets listed in DATASETS.TRAIN.
_C.DATASETS.PROPOSAL_FILES_TRAIN = ()
# Number of top scoring precomputed proposals to keep for training
_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TRAIN = 2000
# List of the dataset names for testing. Must be registered in DatasetCatalog
_C.DATASETS.TEST = ()
# List of the pre-computed proposal files for test, which must be consistent
# with datasets listed in DATASETS.TEST.
_C.DATASETS.PROPOSAL_FILES_TEST = ()
# Number of top scoring precomputed proposals to keep for test
_C.DATASETS.PRECOMPUTED_PROPOSAL_TOPK_TEST = 1000

# -----------------------------------------------------------------------------
# DataLoader
# -----------------------------------------------------------------------------
_C.DATALOADER = CN()
# Number of data loading threads
_C.DATALOADER.NUM_WORKERS = 4
# If True, each batch should contain only images for which the aspect ratio
# is compatible. This groups portrait images together, and landscape images
# are not batched with portrait images.
_C.DATALOADER.ASPECT_RATIO_GROUPING = True
# Options: TrainingSampler, RepeatFactorTrainingSampler
_C.DATALOADER.SAMPLER_TRAIN = "TrainingSampler"
# Repeat threshold for RepeatFactorTrainingSampler
_C.DATALOADER.REPEAT_THRESHOLD = 0.0
# If True, take the square root when computing the repeat factor
_C.DATALOADER.REPEAT_SQRT = True
# If True, when working on datasets that have instance annotations, the
# training dataloader will filter out images without associated annotations
_C.DATALOADER.FILTER_EMPTY_ANNOTATIONS = True

# ---------------------------------------------------------------------------- #
# Backbone options
# ---------------------------------------------------------------------------- #
_C.MODEL.BACKBONE = CN()

_C.MODEL.BACKBONE.NAME = "build_resnet_backbone"
# Freeze the first several stages so they are not trained.
# There are 5 stages in ResNet. The first is a convolution, and the following
# stages are each group of residual blocks.
_C.MODEL.BACKBONE.FREEZE_AT = 2


# ---------------------------------------------------------------------------- #
# FPN options
# ---------------------------------------------------------------------------- #
_C.MODEL.FPN = CN()
# Names of the input feature maps to be used by FPN
# They must have contiguous power of 2 strides
# e.g., ["res2", "res3", "res4", "res5"]
_C.MODEL.FPN.IN_FEATURES = []
_C.MODEL.FPN.OUT_CHANNELS = 256

# Options: "" (no norm), "GN"
_C.MODEL.FPN.NORM = ""

# Types for fusing the FPN top-down and lateral features. Can be either "sum" or "avg"
_C.MODEL.FPN.FUSE_TYPE = "sum"


# ---------------------------------------------------------------------------- #
# Proposal generator options
# ---------------------------------------------------------------------------- #
_C.MODEL.PROPOSAL_GENERATOR = CN()
# Current proposal generators include "RPN", "RRPN" and "PrecomputedProposals"
_C.MODEL.PROPOSAL_GENERATOR.NAME = "RPN"
# Proposal height and width both need to be greater than MIN_SIZE
# (at the scale used during training or inference)
_C.MODEL.PROPOSAL_GENERATOR.MIN_SIZE = 0


# ---------------------------------------------------------------------------- #
# Anchor generator options
# ---------------------------------------------------------------------------- #
_C.MODEL.ANCHOR_GENERATOR = CN()
# The generator can be any name in the ANCHOR_GENERATOR registry
_C.MODEL.ANCHOR_GENERATOR.NAME = "DefaultAnchorGenerator"
# Anchor sizes (i.e. sqrt of area) in absolute pixels w.r.t. the network input.
# Format: list[list[float]]. SIZES[i] specifies the list of sizes to use for
# IN_FEATURES[i]; len(SIZES) must be equal to len(IN_FEATURES) or 1.
# When len(SIZES) == 1, SIZES[0] is used for all IN_FEATURES.
_C.MODEL.ANCHOR_GENERATOR.SIZES = [[32, 64, 128, 256, 512]]
# Anchor aspect ratios. For each area given in `SIZES`, anchors with different aspect
# ratios are generated by an anchor generator.
# Format: list[list[float]]. ASPECT_RATIOS[i] specifies the list of aspect ratios (H/W)
# to use for IN_FEATURES[i]; len(ASPECT_RATIOS) == len(IN_FEATURES) must be true,
# or len(ASPECT_RATIOS) == 1 is true and aspect ratio list ASPECT_RATIOS[0] is used
# for all IN_FEATURES.
_C.MODEL.ANCHOR_GENERATOR.ASPECT_RATIOS = [[0.5, 1.0, 2.0]]
# Anchor angles.
# list[list[float]], the angle in degrees, for each input feature map.
# ANGLES[i] specifies the list of angles for IN_FEATURES[i].
_C.MODEL.ANCHOR_GENERATOR.ANGLES = [[-90, 0, 90]]
# Relative offset between the center of the first anchor and the top-left corner of the image
# Value has to be in [0, 1). 0.5 is recommended, which means half stride.
# The value is not expected to affect model accuracy.
_C.MODEL.ANCHOR_GENERATOR.OFFSET = 0.0

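As a worked example of how SIZES and ASPECT_RATIOS combine, the sketch below builds the zero-centered anchors for a single feature map location: each size is treated as the square root of the anchor area and each aspect ratio as H/W, so w = sqrt(size² / ratio) and h = ratio * w. With the default values above this gives 5 × 3 = 15 anchors per location. The function name is hypothetical, and placement on the feature grid (including OFFSET) is left out.

```python
import math
from typing import List, Sequence, Tuple


def cell_anchors(sizes: Sequence[float], aspect_ratios: Sequence[float]) -> List[Tuple[float, float, float, float]]:
    """Zero-centered (x0, y0, x1, y1) anchors for one feature map location."""
    anchors = []
    for size in sizes:
        area = size ** 2.0
        for ratio in aspect_ratios:  # ratio = H / W
            w = math.sqrt(area / ratio)
            h = ratio * w  # so that w * h == area
            anchors.append((-w / 2.0, -h / 2.0, w / 2.0, h / 2.0))
    return anchors


anchors = cell_anchors([32, 64, 128, 256, 512], [0.5, 1.0, 2.0])
print(len(anchors))  # 15
```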
# ---------------------------------------------------------------------------- #
# RPN options
# ---------------------------------------------------------------------------- #
_C.MODEL.RPN = CN()
_C.MODEL.RPN.HEAD_NAME = "StandardRPNHead"  # used by RPN_HEAD_REGISTRY

# Names of the input feature maps to be used by RPN
# e.g., ["p2", "p3", "p4", "p5", "p6"] for FPN
_C.MODEL.RPN.IN_FEATURES = ["res4"]
# Remove RPN anchors that go outside the image by BOUNDARY_THRESH pixels
# Set to -1 or a large value, e.g. 100000, to disable pruning anchors
_C.MODEL.RPN.BOUNDARY_THRESH = -1
# IOU overlap ratios [BG_IOU_THRESHOLD, FG_IOU_THRESHOLD]
# Minimum overlap required between an anchor and ground-truth box for the
# (anchor, gt box) pair to be a positive example (IoU >= FG_IOU_THRESHOLD
# ==> positive RPN example: 1)
# Maximum overlap allowed between an anchor and ground-truth box for the
# (anchor, gt box) pair to be a negative example (IoU < BG_IOU_THRESHOLD
# ==> negative RPN example: 0)
# Anchors with overlap in between (BG_IOU_THRESHOLD <= IoU < FG_IOU_THRESHOLD)
# are ignored (-1)
_C.MODEL.RPN.IOU_THRESHOLDS = [0.3, 0.7]
_C.MODEL.RPN.IOU_LABELS = [0, -1, 1]
# Number of regions per image used to train RPN
_C.MODEL.RPN.BATCH_SIZE_PER_IMAGE = 256
# Target fraction of foreground (positive) examples per RPN minibatch
_C.MODEL.RPN.POSITIVE_FRACTION = 0.5
# Options are: "smooth_l1", "giou", "diou", "ciou"
_C.MODEL.RPN.BBOX_REG_LOSS_TYPE = "smooth_l1"
_C.MODEL.RPN.BBOX_REG_LOSS_WEIGHT = 1.0
# Weights on (dx, dy, dw, dh) for normalizing RPN anchor regression targets
_C.MODEL.RPN.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
_C.MODEL.RPN.SMOOTH_L1_BETA = 0.0
_C.MODEL.RPN.LOSS_WEIGHT = 1.0
# Number of top scoring RPN proposals to keep before applying NMS
# When FPN is used, this is *per FPN level* (not total)
_C.MODEL.RPN.PRE_NMS_TOPK_TRAIN = 12000
_C.MODEL.RPN.PRE_NMS_TOPK_TEST = 6000
# Number of top scoring RPN proposals to keep after applying NMS
# When FPN is used, this limit is applied per level and then again to the union
# of proposals from all levels
# NOTE: When FPN is used, the meaning of this config is different from Detectron1.
# It means per-batch topk in Detectron1, but per-image topk here.
# See the "find_top_rpn_proposals" function for details.
_C.MODEL.RPN.POST_NMS_TOPK_TRAIN = 2000
_C.MODEL.RPN.POST_NMS_TOPK_TEST = 1000
# NMS threshold used on RPN proposals
_C.MODEL.RPN.NMS_THRESH = 0.7
# Set this to -1 to use the same number of output channels as input channels.
_C.MODEL.RPN.CONV_DIMS = [-1]

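The IOU_THRESHOLDS / IOU_LABELS pair can be read as a partition of the IoU range: the thresholds split it into len(thresholds) + 1 intervals, and each interval gets the label at the same position (lower bounds inclusive). A hypothetical sketch using bisect, applied to each anchor's best IoU with any ground-truth box; detectron2's matcher can additionally promote each ground-truth box's best-matching anchors to positives, and that rule is omitted here.

```python
from bisect import bisect_right
from typing import List, Sequence


def label_by_iou(max_ious: Sequence[float], thresholds: Sequence[float], labels: Sequence[int]) -> List[int]:
    """Map each anchor's best IoU to a label; len(labels) == len(thresholds) + 1."""
    if len(labels) != len(thresholds) + 1:
        raise ValueError("need exactly one more label than thresholds")
    # bisect_right puts a value equal to a threshold into the upper interval,
    # matching the "IoU >= threshold" convention in the comments above.
    return [labels[bisect_right(thresholds, iou)] for iou in max_ious]


print(label_by_iou([0.1, 0.3, 0.5, 0.7, 0.9], [0.3, 0.7], [0, -1, 1]))  # [0, -1, -1, 1, 1]
```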
# ---------------------------------------------------------------------------- #
# ROI HEADS options
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_HEADS = CN()
_C.MODEL.ROI_HEADS.NAME = "Res5ROIHeads"
# Number of foreground classes
_C.MODEL.ROI_HEADS.NUM_CLASSES = 80
# Names of the input feature maps to be used by ROI heads
# Currently all heads (box, mask, ...) use the same input feature map list
# e.g., ["p2", "p3", "p4", "p5"] is commonly used for FPN
_C.MODEL.ROI_HEADS.IN_FEATURES = ["res4"]
# IOU overlap ratios [IOU_THRESHOLD]
# Overlap threshold for an RoI to be considered background (if < IOU_THRESHOLD)
# Overlap threshold for an RoI to be considered foreground (if >= IOU_THRESHOLD)
_C.MODEL.ROI_HEADS.IOU_THRESHOLDS = [0.5]
_C.MODEL.ROI_HEADS.IOU_LABELS = [0, 1]
# RoI minibatch size *per image* (number of regions of interest [ROIs]) during training
# Total number of RoIs per training minibatch =
#   ROI_HEADS.BATCH_SIZE_PER_IMAGE * SOLVER.IMS_PER_BATCH
# E.g., a common configuration is: 512 * 16 = 8192
_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25

# Only used in test mode

# Minimum score threshold (assuming scores in a [0, 1] range); a value chosen to
# balance obtaining high recall with not having too many low precision
# detections that will slow down inference post processing steps (like NMS)
# A default threshold of 0.0 increases AP by ~0.2-0.3 but significantly slows down
# inference.
_C.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.05
# Overlap threshold used for non-maximum suppression (suppress boxes with
# IoU >= this threshold)
_C.MODEL.ROI_HEADS.NMS_THRESH_TEST = 0.5
# If True, augment proposals with ground-truth boxes before sampling proposals to
# train ROI heads.
_C.MODEL.ROI_HEADS.PROPOSAL_APPEND_GT = True

# ---------------------------------------------------------------------------- #
# Box Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_BOX_HEAD = CN()
# C4 models don't use the head name option
# Options for non-C4 models: FastRCNNConvFCHead
_C.MODEL.ROI_BOX_HEAD.NAME = ""
# Options are: "smooth_l1", "giou", "diou", "ciou"
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_TYPE = "smooth_l1"
# The final scaling coefficient on the box regression loss, used to balance the magnitude of its
# gradients with other losses in the model. See also `MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT`.
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_LOSS_WEIGHT = 1.0
# Default weights on (dx, dy, dw, dh) for normalizing bbox regression targets
# These are empirically chosen to approximately lead to unit variance targets
_C.MODEL.ROI_BOX_HEAD.BBOX_REG_WEIGHTS = (10.0, 10.0, 5.0, 5.0)
# The transition point from L1 to L2 loss. Set to 0.0 to make the loss simply L1.
_C.MODEL.ROI_BOX_HEAD.SMOOTH_L1_BETA = 0.0
_C.MODEL.ROI_BOX_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_BOX_HEAD.POOLER_SAMPLING_RATIO = 0
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_BOX_HEAD.POOLER_TYPE = "ROIAlignV2"

_C.MODEL.ROI_BOX_HEAD.NUM_FC = 0
# Hidden layer dimension for FC layers in the RoI box head
_C.MODEL.ROI_BOX_HEAD.FC_DIM = 1024
_C.MODEL.ROI_BOX_HEAD.NUM_CONV = 0
# Channel dimension for Conv layers in the RoI box head
_C.MODEL.ROI_BOX_HEAD.CONV_DIM = 256
# Normalization method for the convolution layers.
# Options: "" (no norm), "GN", "SyncBN".
_C.MODEL.ROI_BOX_HEAD.NORM = ""
# Whether to use class-agnostic bbox regression
_C.MODEL.ROI_BOX_HEAD.CLS_AGNOSTIC_BBOX_REG = False
# If true, RoI heads use bounding boxes predicted by the box head rather than proposal boxes.
_C.MODEL.ROI_BOX_HEAD.TRAIN_ON_PRED_BOXES = False

# Federated loss can be used to improve the training of LVIS
_C.MODEL.ROI_BOX_HEAD.USE_FED_LOSS = False
# Sigmoid cross entropy is used with federated loss
_C.MODEL.ROI_BOX_HEAD.USE_SIGMOID_CE = False
# The power value applied to image_count when calculating frequency weight
_C.MODEL.ROI_BOX_HEAD.FED_LOSS_FREQ_WEIGHT_POWER = 0.5
# Number of classes to keep in total
_C.MODEL.ROI_BOX_HEAD.FED_LOSS_NUM_CLASSES = 50

# ---------------------------------------------------------------------------- #
# Cascaded Box Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_BOX_CASCADE_HEAD = CN()
# The number of cascade stages is implicitly defined by the length of the following two configs.
_C.MODEL.ROI_BOX_CASCADE_HEAD.BBOX_REG_WEIGHTS = (
    (10.0, 10.0, 5.0, 5.0),
    (20.0, 20.0, 10.0, 10.0),
    (30.0, 30.0, 15.0, 15.0),
)
_C.MODEL.ROI_BOX_CASCADE_HEAD.IOUS = (0.5, 0.6, 0.7)


# ---------------------------------------------------------------------------- #
# Mask Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_MASK_HEAD = CN()
_C.MODEL.ROI_MASK_HEAD.NAME = "MaskRCNNConvUpsampleHead"
_C.MODEL.ROI_MASK_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_MASK_HEAD.POOLER_SAMPLING_RATIO = 0
_C.MODEL.ROI_MASK_HEAD.NUM_CONV = 0  # The number of convs in the mask head
_C.MODEL.ROI_MASK_HEAD.CONV_DIM = 256
# Normalization method for the convolution layers.
# Options: "" (no norm), "GN", "SyncBN".
_C.MODEL.ROI_MASK_HEAD.NORM = ""
# Whether to use class-agnostic mask prediction
_C.MODEL.ROI_MASK_HEAD.CLS_AGNOSTIC_MASK = False
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_MASK_HEAD.POOLER_TYPE = "ROIAlignV2"


# ---------------------------------------------------------------------------- #
# Keypoint Head
# ---------------------------------------------------------------------------- #
_C.MODEL.ROI_KEYPOINT_HEAD = CN()
_C.MODEL.ROI_KEYPOINT_HEAD.NAME = "KRCNNConvDeconvUpsampleHead"
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_RESOLUTION = 14
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_SAMPLING_RATIO = 0
_C.MODEL.ROI_KEYPOINT_HEAD.CONV_DIMS = tuple(512 for _ in range(8))
_C.MODEL.ROI_KEYPOINT_HEAD.NUM_KEYPOINTS = 17  # 17 is the number of keypoints in COCO.

# Images with too few (or no) keypoints are excluded from training.
_C.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE = 1
# Normalize by the total number of visible keypoints in the minibatch if True.
# Otherwise, normalize by the total number of keypoints that could ever exist
# in the minibatch.
# The keypoint softmax loss is only calculated on visible keypoints.
# Since the number of visible keypoints can vary significantly between
# minibatches, this has the effect of up-weighting the importance of
# minibatches with few visible keypoints. (Imagine the extreme case of
# only one visible keypoint versus N: in the case of N, each one
# contributes 1/N to the gradient compared to the single keypoint
# determining the gradient direction). Instead, we can normalize the
# loss by the total number of keypoints, if it were the case that all
# keypoints were visible in a full minibatch. (Returning to the example,
# this means that the one visible keypoint contributes as much as each
# of the N keypoints.)
_C.MODEL.ROI_KEYPOINT_HEAD.NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS = True
# Multi-task loss weight to use for keypoints
# Recommended values:
#   - use 1.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is True
#   - use 4.0 if NORMALIZE_LOSS_BY_VISIBLE_KEYPOINTS is False
_C.MODEL.ROI_KEYPOINT_HEAD.LOSS_WEIGHT = 1.0
# Type of pooling operation applied to the incoming feature map for each RoI
_C.MODEL.ROI_KEYPOINT_HEAD.POOLER_TYPE = "ROIAlignV2"

# ---------------------------------------------------------------------------- #
# Semantic Segmentation Head
# ---------------------------------------------------------------------------- #
_C.MODEL.SEM_SEG_HEAD = CN()
_C.MODEL.SEM_SEG_HEAD.NAME = "SemSegFPNHead"
_C.MODEL.SEM_SEG_HEAD.IN_FEATURES = ["p2", "p3", "p4", "p5"]
# Label in the semantic segmentation ground truth that is ignored, i.e., no loss is calculated for
# the corresponding pixel.
_C.MODEL.SEM_SEG_HEAD.IGNORE_VALUE = 255
# Number of classes in the semantic segmentation head
_C.MODEL.SEM_SEG_HEAD.NUM_CLASSES = 54
# Number of channels in the 3x3 convs inside semantic-FPN heads.
_C.MODEL.SEM_SEG_HEAD.CONVS_DIM = 128
# Outputs from semantic-FPN heads are up-scaled to the COMMON_STRIDE stride.
_C.MODEL.SEM_SEG_HEAD.COMMON_STRIDE = 4
# Normalization method for the convolution layers. Options: "" (no norm), "GN".
_C.MODEL.SEM_SEG_HEAD.NORM = "GN"
_C.MODEL.SEM_SEG_HEAD.LOSS_WEIGHT = 1.0

_C.MODEL.PANOPTIC_FPN = CN()
# Scaling of all losses from instance detection / segmentation head.
_C.MODEL.PANOPTIC_FPN.INSTANCE_LOSS_WEIGHT = 1.0

# options when combining instance & semantic segmentation outputs
_C.MODEL.PANOPTIC_FPN.COMBINE = CN({"ENABLED": True})  # "COMBINE.ENABLED" is deprecated & not used
_C.MODEL.PANOPTIC_FPN.COMBINE.OVERLAP_THRESH = 0.5
_C.MODEL.PANOPTIC_FPN.COMBINE.STUFF_AREA_LIMIT = 4096
_C.MODEL.PANOPTIC_FPN.COMBINE.INSTANCES_CONFIDENCE_THRESH = 0.5


# ---------------------------------------------------------------------------- #
# RetinaNet Head
# ---------------------------------------------------------------------------- #
_C.MODEL.RETINANET = CN()

# This is the number of foreground classes.
_C.MODEL.RETINANET.NUM_CLASSES = 80

_C.MODEL.RETINANET.IN_FEATURES = ["p3", "p4", "p5", "p6", "p7"]

# Convolutions to use in the cls and bbox tower
# NOTE: this doesn't include the last conv for logits
_C.MODEL.RETINANET.NUM_CONVS = 4

# IoU overlap ratio [bg, fg] for labeling anchors.
# Anchors with < bg are labeled negative (0)
# Anchors with >= bg and < fg are ignored (-1)
# Anchors with >= fg are labeled positive (1)
_C.MODEL.RETINANET.IOU_THRESHOLDS = [0.4, 0.5]
_C.MODEL.RETINANET.IOU_LABELS = [0, -1, 1]

# Prior prob for rare case (i.e. foreground) at the beginning of training.
# This is used to set the bias for the logits layer of the classifier subnet.
# This improves training stability in the case of heavy class imbalance.
_C.MODEL.RETINANET.PRIOR_PROB = 0.01

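PRIOR_PROB is turned into a bias by solving sigmoid(b) = p for b, so that at initialization every anchor predicts foreground with probability p: b = -log((1 - p) / p). For the default p = 0.01 this is about -4.595. A small check of the arithmetic (the helper names are hypothetical):

```python
import math


def prior_bias(prior_prob: float) -> float:
    """Bias b such that sigmoid(b) == prior_prob for the classification logits."""
    return -math.log((1.0 - prior_prob) / prior_prob)


def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))


b = prior_bias(0.01)
print(round(b, 4))  # -4.5951
```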
449# Inference cls score threshold, only anchors with score > INFERENCE_TH are
450# considered for inference (to improve speed)
451_C.MODEL.RETINANET.SCORE_THRESH_TEST = 0.05
452# Select topk candidates before NMS
453_C.MODEL.RETINANET.TOPK_CANDIDATES_TEST = 1000
454_C.MODEL.RETINANET.NMS_THRESH_TEST = 0.5
455
456# Weights on (dx, dy, dw, dh) for normalizing Retinanet anchor regression targets
457_C.MODEL.RETINANET.BBOX_REG_WEIGHTS = (1.0, 1.0, 1.0, 1.0)
458
459# Loss parameters
460_C.MODEL.RETINANET.FOCAL_LOSS_GAMMA = 2.0
461_C.MODEL.RETINANET.FOCAL_LOSS_ALPHA = 0.25
462_C.MODEL.RETINANET.SMOOTH_L1_LOSS_BETA = 0.1
463# Options are: "smooth_l1", "giou", "diou", "ciou"
464_C.MODEL.RETINANET.BBOX_REG_LOSS_TYPE = "smooth_l1"
465
466# One of BN, SyncBN, FrozenBN, GN
467# Only supports GN until unshared norm is implemented
468_C.MODEL.RETINANET.NORM = ""
469
470
471# ---------------------------------------------------------------------------- #
472# ResNe[X]t options (ResNets = {ResNet, ResNeXt}
473# Note that parts of a resnet may be used for both the backbone and the head
474# These options apply to both
475# ---------------------------------------------------------------------------- #
476_C.MODEL.RESNETS = CN()
477
478_C.MODEL.RESNETS.DEPTH = 50
479_C.MODEL.RESNETS.OUT_FEATURES = ["res4"] # res4 for C4 backbone, res2..5 for FPN backbone
480
481# Number of groups to use; 1 ==> ResNet; > 1 ==> ResNeXt
482_C.MODEL.RESNETS.NUM_GROUPS = 1
483
484# Options: FrozenBN, GN, "SyncBN", "BN"
485_C.MODEL.RESNETS.NORM = "FrozenBN"
486
487# Baseline width of each group.
488# Scaling this parameters will scale the width of all bottleneck layers.
489_C.MODEL.RESNETS.WIDTH_PER_GROUP = 64
490
491# Place the stride 2 conv on the 1x1 filter
492# Use True only for the original MSRA ResNet; use False for C2 and Torch models
493_C.MODEL.RESNETS.STRIDE_IN_1X1 = True
494
495# Apply dilation in stage "res5"
496_C.MODEL.RESNETS.RES5_DILATION = 1
497
498# Output width of res2. Scaling this parameters will scale the width of all 1x1 convs in ResNet
499# For R18 and R34, this needs to be set to 64
500_C.MODEL.RESNETS.RES2_OUT_CHANNELS = 256
501_C.MODEL.RESNETS.STEM_OUT_CHANNELS = 64
502
503# Apply Deformable Convolution in stages
504# Specify if apply deform_conv on Res2, Res3, Res4, Res5
505_C.MODEL.RESNETS.DEFORM_ON_PER_STAGE = [False, False, False, False]
506# Use True to use modulated deform_conv (DeformableV2, https://arxiv.org/abs/1811.11168);
507# Use False for DeformableV1.
508_C.MODEL.RESNETS.DEFORM_MODULATED = False
509# Number of groups in deformable conv.
510_C.MODEL.RESNETS.DEFORM_NUM_GROUPS = 1
511
512
513# ---------------------------------------------------------------------------- #
514# Solver
515# ---------------------------------------------------------------------------- #
516_C.SOLVER = CN()
517
518# Options: WarmupMultiStepLR, WarmupCosineLR.
519# See detectron2/solver/build.py for definition.
520_C.SOLVER.LR_SCHEDULER_NAME = "WarmupMultiStepLR"
521
522_C.SOLVER.MAX_ITER = 40000
523
524_C.SOLVER.BASE_LR = 0.001
525# The end lr, only used by WarmupCosineLR
526_C.SOLVER.BASE_LR_END = 0.0

_C.SOLVER.MOMENTUM = 0.9

_C.SOLVER.NESTEROV = False

_C.SOLVER.WEIGHT_DECAY = 0.0001
# The weight decay that's applied to parameters of normalization layers
# (typically the affine transformation)
_C.SOLVER.WEIGHT_DECAY_NORM = 0.0

_C.SOLVER.GAMMA = 0.1
# The iteration numbers at which the learning rate is multiplied by GAMMA.
_C.SOLVER.STEPS = (30000,)
# Number of decays in the WarmupStepWithFixedGammaLR schedule
_C.SOLVER.NUM_DECAYS = 3

_C.SOLVER.WARMUP_FACTOR = 1.0 / 1000
_C.SOLVER.WARMUP_ITERS = 1000
_C.SOLVER.WARMUP_METHOD = "linear"
# Whether to rescale the interval for the learning schedule after warmup
_C.SOLVER.RESCALE_INTERVAL = False
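# Illustrative sketch (hypothetical helper, not part of detectron2):
# the WarmupMultiStepLR learning rate with "linear" warmup. It assumes that
# the warmup factor ramps linearly from WARMUP_FACTOR to 1 over WARMUP_ITERS,
# and that the lr is multiplied by GAMMA once for every milestone in STEPS
# that has been reached. RESCALE_INTERVAL is ignored here.
import bisect


def _example_multistep_lr(
    it, base_lr=0.001, gamma=0.1, steps=(30000,),
    warmup_factor=1.0 / 1000, warmup_iters=1000,
):
    """Learning rate at iteration `it` under multi-step decay plus linear warmup."""
    if it < warmup_iters:
        alpha = it / warmup_iters
        factor = warmup_factor * (1 - alpha) + alpha  # linear ramp up to 1.0
    else:
        factor = 1.0
    num_decays = bisect.bisect_right(steps, it)  # milestones already reached
    return base_lr * factor * gamma ** num_decays


# Usage: the sketch gives 1e-6 at iteration 0 (BASE_LR * WARMUP_FACTOR) and
# BASE_LR at iteration 1000. From iteration 30000 onward it gives BASE_LR * GAMMA.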

# Save a checkpoint every CHECKPOINT_PERIOD iterations
_C.SOLVER.CHECKPOINT_PERIOD = 5000

# Number of images per batch across all machines. This is also the number
# of training images per step (i.e. per iteration). If we use 16 GPUs
# and IMS_PER_BATCH = 32, each GPU will see 2 images per batch.
# May be adjusted automatically if REFERENCE_WORLD_SIZE is set.
_C.SOLVER.IMS_PER_BATCH = 16
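# Illustrative sketch (hypothetical helper, not part of detectron2):
# the per-GPU arithmetic described above. IMS_PER_BATCH is the total batch
# across all workers, so it is assumed to divide evenly by the number of GPUs.
def _example_images_per_gpu(ims_per_batch, num_gpus):
    """Images each GPU processes per iteration, e.g. 32 over 16 GPUs gives 2."""
    if ims_per_batch % num_gpus != 0:
        raise ValueError(
            f"IMS_PER_BATCH={ims_per_batch} is not divisible by {num_gpus} GPUs"
        )
    return ims_per_batch // num_gpus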

# The reference number of workers (GPUs) this config is meant to train with.
# It has no effect when set to 0.
# With a non-zero value, it will be used by DefaultTrainer to compute a desired
# per-worker batch size, and then scale the other related configs (total batch size,
# learning rate, etc.) to match the per-worker batch size.
# See the documentation of `DefaultTrainer.auto_scale_workers` for details.
_C.SOLVER.REFERENCE_WORLD_SIZE = 0
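# Illustrative sketch (hypothetical helper, not part of detectron2):
# a simplified version of the linear scaling rule that REFERENCE_WORLD_SIZE is
# meant to trigger. The per-worker batch size stays fixed. The total batch size
# and learning rate are multiplied by new/old world size, and iteration-based
# schedule entries are divided by the same factor. It works on a plain dict and
# covers only a subset of keys. DefaultTrainer.auto_scale_workers is the
# authoritative version.
def _example_auto_scale(solver, num_workers):
    """Return a scaled copy of a dict holding the SOLVER keys used below."""
    old_world_size = solver["REFERENCE_WORLD_SIZE"]
    if old_world_size == 0 or old_world_size == num_workers:
        return dict(solver)  # 0 disables scaling; same world size needs no change
    assert solver["IMS_PER_BATCH"] % old_world_size == 0, "batch must split evenly"
    scale = num_workers / old_world_size
    out = dict(solver)
    out["IMS_PER_BATCH"] = int(round(solver["IMS_PER_BATCH"] * scale))
    out["BASE_LR"] = solver["BASE_LR"] * scale
    out["MAX_ITER"] = int(round(solver["MAX_ITER"] / scale))
    out["WARMUP_ITERS"] = int(round(solver["WARMUP_ITERS"] / scale))
    out["STEPS"] = tuple(int(round(s / scale)) for s in solver["STEPS"])
    out["REFERENCE_WORLD_SIZE"] = num_workers
    return out


# Usage: a config tuned for 8 GPUs with IMS_PER_BATCH=16 and run on 16 GPUs
# ends up with IMS_PER_BATCH=32, twice the learning rate, and half as many
# iterations, so each GPU still sees 2 images per step.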

# Detectron v1 (and previous detection code) used a 2x higher LR and 0 WD for
# biases. This is not useful (at least for recent models). You should avoid
# changing these; they exist only to reproduce Detectron v1 training if
# desired.
_C.SOLVER.BIAS_LR_FACTOR = 1.0
_C.SOLVER.WEIGHT_DECAY_BIAS = None  # None means use WEIGHT_DECAY
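# Illustrative sketch (hypothetical helper, not part of detectron2):
# how the bias and normalization overrides above are assumed to resolve into
# per-parameter hyperparameters. Biases get BASE_LR * BIAS_LR_FACTOR and
# WEIGHT_DECAY_BIAS, falling back to WEIGHT_DECAY when that is None.
# Normalization-layer parameters use WEIGHT_DECAY_NORM. The param_kind labels
# are hypothetical. Parameters that are both norm and bias are not
# distinguished in this simplified version.
def _example_param_hparams(param_kind, base_lr=0.001, weight_decay=0.0001,
                           weight_decay_norm=0.0, bias_lr_factor=1.0,
                           weight_decay_bias=None):
    """Return (lr, weight_decay) for param_kind in {"weight", "bias", "norm"}."""
    if param_kind == "norm":
        return base_lr, weight_decay_norm
    if param_kind == "bias":
        wd = weight_decay if weight_decay_bias is None else weight_decay_bias
        return base_lr * bias_lr_factor, wd
    return base_lr, weight_decay


# Usage: the Detectron v1 behavior corresponds to bias_lr_factor=2.0 and
# weight_decay_bias=0.0.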

# Gradient clipping
_C.SOLVER.CLIP_GRADIENTS = CN({"ENABLED": False})
# Type of gradient clipping, currently 2 values are supported:
# - "value": the absolute values of elements of each gradient are clipped
# - "norm": the norm of the gradient for each parameter is clipped, thus
#   affecting all elements in the parameter
_C.SOLVER.CLIP_GRADIENTS.CLIP_TYPE = "value"
# Maximum absolute value used for clipping gradients ("value"), or the
# maximum gradient norm ("norm")
_C.SOLVER.CLIP_GRADIENTS.CLIP_VALUE = 1.0
# Floating point number p for L-p norm to be used with the "norm"
# gradient clipping type; for L-inf, please specify .inf
_C.SOLVER.CLIP_GRADIENTS.NORM_TYPE = 2.0
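# Illustrative sketch (hypothetical helpers, not part of detectron2):
# the two CLIP_TYPE modes applied to one parameter's gradient, written as a
# flat Python list. "value" clamps every element to [-CLIP_VALUE, CLIP_VALUE].
# "norm" rescales the whole gradient so that its L-p norm (p = NORM_TYPE) is at
# most CLIP_VALUE. In real training, the tensor equivalents are
# torch.nn.utils.clip_grad_value_ and torch.nn.utils.clip_grad_norm_, which
# modify gradients in place.
import math


def _example_clip_by_value(grad, clip_value=1.0):
    """Clamp each element independently."""
    return [max(-clip_value, min(clip_value, g)) for g in grad]


def _example_clip_by_norm(grad, max_norm=1.0, norm_type=2.0):
    """Scale all elements by one factor so the L-p norm is at most max_norm."""
    if math.isinf(norm_type):
        norm = max(abs(g) for g in grad)  # L-inf norm: largest magnitude
    else:
        norm = sum(abs(g) ** norm_type for g in grad) ** (1.0 / norm_type)
    if norm <= max_norm:
        return list(grad)
    scale = max_norm / norm  # the same factor shrinks every element
    return [g * scale for g in grad]


# Usage: with the defaults, [3.0, -0.5, -2.0] becomes [1.0, -0.5, -1.0] under
# "value". Under "norm", [3.0, 4.0] (L2 norm 5) is scaled down to roughly [0.6, 0.8],
# which preserves its direction.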

# Enable automatic mixed precision for training
# Note that this does not change the model's inference behavior.
# To use AMP in inference, run inference under autocast()
_C.SOLVER.AMP = CN({"ENABLED": False})

# ---------------------------------------------------------------------------- #
# Specific test options
# ---------------------------------------------------------------------------- #
_C.TEST = CN()
# For end-to-end tests to verify the expected accuracy.
# Each item is [task, metric, value, tolerance]
# e.g.: [['bbox', 'AP', 38.5, 0.2]]
_C.TEST.EXPECTED_RESULTS = []
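# Illustrative sketch (hypothetical helper, not part of detectron2):
# checking EXPECTED_RESULTS entries of the form [task, metric, value, tolerance]
# against evaluation output shaped like {task: {metric: value}}. This is a
# simplified stand-in. detectron2 performs its own check in its evaluation
# utilities.
import math


def _example_check_expected(results, expected_results):
    """Return True if every expected metric is finite and within tolerance."""
    ok = True
    for task, metric, expected, tolerance in expected_results:
        actual = results[task][metric]
        if not math.isfinite(actual) or abs(actual - expected) > tolerance:
            ok = False  # keep looping so every entry is examined
    return ok


# Usage: with results {"bbox": {"AP": 38.6}}, the entry ['bbox', 'AP', 38.5, 0.2]
# passes because |38.6 - 38.5| <= 0.2.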
# The period (in terms of steps) to evaluate the model during training.
# Set to 0 to disable.
_C.TEST.EVAL_PERIOD = 0
# The sigmas used to calculate keypoint OKS. See http://cocodataset.org/#keypoints-eval
# When empty, it will use the defaults in COCO.
# Otherwise it should be a list[float] with the same length as ROI_KEYPOINT_HEAD.NUM_KEYPOINTS.
_C.TEST.KEYPOINT_OKS_SIGMAS = []
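# Illustrative sketch (hypothetical helper, not part of detectron2):
# Object Keypoint Similarity as defined by the COCO keypoint evaluation:
#   OKS = mean over labeled keypoints i of exp(-d_i^2 / (2 * s^2 * k_i^2)),
# where d_i is the distance between the predicted and ground-truth keypoint i,
# s^2 is the object area, and k_i = 2 * sigma_i. The formula needs one sigma
# per keypoint, which is why KEYPOINT_OKS_SIGMAS must match NUM_KEYPOINTS.
import math


def _example_oks(pred, gt, visible, area, sigmas):
    """pred/gt: lists of (x, y); visible: list of bools; area: object area."""
    assert len(pred) == len(gt) == len(visible) == len(sigmas)
    terms = []
    for (px, py), (gx, gy), vis, sigma in zip(pred, gt, visible, sigmas):
        if not vis:
            continue  # unlabeled keypoints do not contribute
        d2 = (px - gx) ** 2 + (py - gy) ** 2
        k2 = (2 * sigma) ** 2
        terms.append(math.exp(-d2 / (2 * area * k2)))
    return sum(terms) / len(terms) if terms else 0.0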
# Maximum number of detections to return per image during inference (100 is
# based on the limit established for the COCO dataset).
_C.TEST.DETECTIONS_PER_IMAGE = 100

_C.TEST.AUG = CN({"ENABLED": False})
_C.TEST.AUG.MIN_SIZES = (400, 500, 600, 700, 800, 900, 1000, 1100, 1200)
_C.TEST.AUG.MAX_SIZE = 4000
_C.TEST.AUG.FLIP = True
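# Illustrative sketch (hypothetical helpers, not part of detectron2):
# the test-time augmentation grid that TEST.AUG is assumed to describe. Each
# size in MIN_SIZES is used on its own and, when FLIP is True, also with a
# horizontal flip. Each resize scales the shorter side to the target size,
# then shrinks further if the longer side would exceed MAX_SIZE, mirroring a
# typical shortest-edge resize.
def _example_resize_shape(h, w, short_size, max_size):
    """Return (new_h, new_w) after shortest-edge resizing with a size cap."""
    scale = short_size / min(h, w)
    new_h, new_w = h * scale, w * scale
    if max(new_h, new_w) > max_size:
        cap = max_size / max(new_h, new_w)  # keep the longer side <= max_size
        new_h, new_w = new_h * cap, new_w * cap
    return int(new_h + 0.5), int(new_w + 0.5)


def _example_tta_grid(min_sizes=(400, 500, 600, 700, 800, 900, 1000, 1100, 1200),
                      flip=True):
    """List of (min_size, flipped) combinations that would be evaluated."""
    flips = (False, True) if flip else (False,)
    return [(size, f) for size in min_sizes for f in flips]


# Usage: with the defaults above, the grid has 9 sizes x 2 flip states = 18
# augmented views per image.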

_C.TEST.PRECISE_BN = CN({"ENABLED": False})
_C.TEST.PRECISE_BN.NUM_ITER = 200

# ---------------------------------------------------------------------------- #
# Misc options
# ---------------------------------------------------------------------------- #
# Directory where output files are written
_C.OUTPUT_DIR = "./output"
# Set seed to negative to fully randomize everything.
# Set seed to positive to use a fixed seed. Note that a fixed seed increases
# reproducibility but does not guarantee fully deterministic behavior.
# Disabling all parallelism further increases reproducibility.
_C.SEED = -1
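# Illustrative sketch (hypothetical helpers, not part of detectron2):
# one way to honor the SEED semantics above. A negative value means "draw a
# fresh random seed", and a non-negative value is used as given. Offsetting the
# seed by the process rank so that workers do not share identical random
# streams is an assumption of this sketch.
import random


def _example_resolve_seed(seed, rank=0):
    """Return the concrete integer seed that a worker with this rank would use."""
    if seed < 0:
        return random.SystemRandom().randrange(2 ** 31)  # differs from run to run
    return seed + rank


def _example_seed_python_rng(seed, rank=0):
    """Seed Python's `random` module; a real setup would also seed numpy/torch."""
    concrete = _example_resolve_seed(seed, rank)
    random.seed(concrete)
    return concrete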
# Benchmark different cudnn algorithms.
# If input images have very different sizes, this option can incur a large overhead
# for about 10k iterations. It usually hurts total time, but can benefit certain models.
# If input images have the same or similar sizes, benchmark is often helpful.
_C.CUDNN_BENCHMARK = False
# The period (in terms of steps) for minibatch visualization at train time.
# Set to 0 to disable.
_C.VIS_PERIOD = 0
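# Illustrative sketch (hypothetical helper, not part of detectron2):
# the "period" convention shared by VIS_PERIOD and TEST.EVAL_PERIOD. An action
# fires every `period` steps, and a period of 0 disables it. Counting
# iterations from 1 is an assumption of this sketch.
def _example_should_fire(iteration, period):
    """True if a periodic action should run at this (1-based) iteration."""
    return period > 0 and iteration % period == 0


# Usage: with period=20, the action fires at iterations 20, 40, 60, ...;
# with period=0, it never fires.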

# The global config is for quick hacks.
# You can set its values on the command line or in config files,
# and access them with:
#
# from detectron2.config import global_cfg
# print(global_cfg.HACK)
#
# Do not commit any configs into it.
_C.GLOBAL = CN()
_C.GLOBAL.HACK = 1.0
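# Illustrative sketch (hypothetical helper, not part of detectron2):
# roughly what a command-line override such as `GLOBAL.HACK 2.0` does. The
# dotted key is walked through the nested config, and the string value is
# parsed as a Python literal. This plain-dict version only illustrates the
# idea. The real entry point is CfgNode.merge_from_list, which also checks
# that the new value's type is compatible with the old one.
import ast


def _example_apply_overrides(cfg, overrides):
    """Apply a flat [key1, value1, key2, value2, ...] list to a nested dict."""
    assert len(overrides) % 2 == 0, "overrides must be key/value pairs"
    for key, raw in zip(overrides[0::2], overrides[1::2]):
        *parents, leaf = key.split(".")
        node = cfg
        for name in parents:
            node = node[name]  # raises KeyError for an unknown section
        if leaf not in node:
            raise KeyError(f"Non-existent config key: {key}")
        try:
            value = ast.literal_eval(raw)  # "2.0" -> 2.0, "(1, 2)" -> (1, 2)
        except (ValueError, SyntaxError):
            value = raw  # fall back to the raw string, e.g. for file paths
        node[leaf] = value
    return cfg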