Unverified Commit 681a0aae authored by ZhangShilong's avatar ZhangShilong Committed by GitHub
Browse files

[Enhancement]: fix docstr of anchor head (#4883)

parent 9e3fd33c
......@@ -503,10 +503,12 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
"""Transform network output for a batch into bbox predictions.
Args:
cls_scores (list[Tensor]): Box scores for each scale level
Has shape (N, num_anchors * num_classes, H, W)
bbox_preds (list[Tensor]): Box energies / deltas for each scale
level with shape (N, num_anchors * 4, H, W)
cls_scores (list[Tensor]): Box scores for each level in the
feature pyramid, has shape
(N, num_anchors * num_classes, H, W).
bbox_preds (list[Tensor]): Box energies / deltas for each
level in the feature pyramid, has shape
(N, num_anchors * 4, H, W).
img_metas (list[dict]): Meta information of each image, e.g.,
image size, scaling factor, etc.
cfg (mmcv.Config | None): Test / postprocessing configuration,
......@@ -558,8 +560,8 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
mlvl_anchors = self.anchor_generator.grid_anchors(
featmap_sizes, device=device)
cls_score_list = [cls_scores[i].detach() for i in range(num_levels)]
bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)]
mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
if torch.onnx.is_in_onnx_export():
assert len(
......@@ -577,19 +579,19 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
if with_nms:
# some heads don't support with_nms argument
result_list = self._get_bboxes(cls_score_list, bbox_pred_list,
result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds,
mlvl_anchors, img_shapes,
scale_factors, cfg, rescale)
else:
result_list = self._get_bboxes(cls_score_list, bbox_pred_list,
result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds,
mlvl_anchors, img_shapes,
scale_factors, cfg, rescale,
with_nms)
return result_list
def _get_bboxes(self,
cls_score_list,
bbox_pred_list,
mlvl_cls_scores,
mlvl_bbox_preds,
mlvl_anchors,
img_shapes,
scale_factors,
......@@ -599,14 +601,17 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
"""Transform outputs for a batch item into bbox predictions.
Args:
cls_score_list (list[Tensor]): Box scores for a single scale level
Has shape (N, num_anchors * num_classes, H, W).
bbox_pred_list (list[Tensor]): Box energies / deltas for a single
scale level with shape (N, num_anchors * 4, H, W).
mlvl_anchors (list[Tensor]): Box reference for a single scale level
with shape (num_total_anchors, 4).
img_shapes (list[tuple[int]]): Shape of the batch input image,
list[(height, width, 3)].
mlvl_cls_scores (list[Tensor]): Each element in the list is
the scores of bboxes of single level in the feature pyramid,
has shape (N, num_anchors * num_classes, H, W).
mlvl_bbox_preds (list[Tensor]): Each element in the list is the
bboxes predictions of single level in the feature pyramid,
has shape (N, num_anchors * 4, H, W).
mlvl_anchors (list[Tensor]): Each element in the list is
the anchors of single level in feature pyramid, has shape
(num_anchors, 4).
img_shapes (list[tuple[int]]): Each tuple in the list represent
the shape(height, width, 3) of single image in the batch.
scale_factors (list[ndarray]): Scale factor of the batch
image arange as list[(w_scale, h_scale, w_scale, h_scale)].
cfg (mmcv.Config): Test / postprocessing configuration,
......@@ -625,18 +630,20 @@ class AnchorHead(BaseDenseHead, BBoxTestMixin):
box.
"""
cfg = self.test_cfg if cfg is None else cfg
assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
batch_size = cls_score_list[0].shape[0]
assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(
mlvl_anchors)
batch_size = mlvl_cls_scores[0].shape[0]
# convert to tensor to keep tracing
nms_pre_tensor = torch.tensor(
cfg.get('nms_pre', -1),
device=cls_score_list[0].device,
device=mlvl_cls_scores[0].device,
dtype=torch.long)
mlvl_bboxes = []
mlvl_scores = []
for cls_score, bbox_pred, anchors in zip(cls_score_list,
bbox_pred_list, mlvl_anchors):
for cls_score, bbox_pred, anchors in zip(mlvl_cls_scores,
mlvl_bbox_preds,
mlvl_anchors):
assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
cls_score = cls_score.permute(0, 2, 3,
1).reshape(batch_size, -1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment