# Patch summary (leading commentary; everything from the first "diff --git" header on is the patch itself).
#
# Purpose: handle images that have no ground-truth targets in the FCOS loss, and add a test for that case.
#
# test/test_models_detection_negative_samples.py
#   * Adds test_forward_negative_sample_fcos: builds fcos_resnet50_fpn(num_classes=2, min_size=100,
#     max_size=100, pretrained_backbone=False), feeds it the output of self._make_empty_sample(), and
#     asserts that loss_dict["bbox_regression"] and loss_dict["bbox_ctrness"] both equal torch.tensor(0.0).
#
# torchvision/models/detection/fcos.py, compute_loss
#   * Hunk 1: when targets_per_image["labels"] is empty, the class targets become a zero tensor of length
#     len(matched_idxs_per_image) and the box targets a zero tensor of shape (len(matched_idxs_per_image), 4),
#     both created with new_zeros from the corresponding target tensor. Otherwise the targets are gathered
#     with matched_idxs_per_image.clip(min=0) as before. The box-target lookup moves into the same if/else,
#     and the "-1" background marking still runs afterwards for both branches.
#     NOTE(review): presumably the gather fails on an empty labels/boxes tensor - confirm against the issue.
#   * Hunk 2: the removed code called bbox_reg_targets.new_zeros(...) as a bare statement (result unused)
#     and then always computed the centerness targets. The added code assigns
#     gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1]) when
#     len(bbox_reg_targets) == 0 and only computes the sqrt(min/max * min/max) expression in the else branch.
#
# NOTE(review): line breaks and indentation inside the hunks appear to have been collapsed in this copy;
# the original layout must be restored before the patch can be applied - TODO confirm against the source change.
# NOTE(review): the "# backgroud" spelling sits on an unchanged context line; context lines have to match
# the target file, so it is intentionally left as-is here.
diff --git a/test/test_models_detection_negative_samples.py b/test/test_models_detection_negative_samples.py index d50a8ac7cf7..7d2953f7e64 100644 --- a/test/test_models_detection_negative_samples.py +++ b/test/test_models_detection_negative_samples.py @@ -143,6 +143,17 @@ def test_forward_negative_sample_retinanet(self): assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0)) + def test_forward_negative_sample_fcos(self): + model = torchvision.models.detection.fcos_resnet50_fpn( + num_classes=2, min_size=100, max_size=100, pretrained_backbone=False + ) + + images, targets = self._make_empty_sample() + loss_dict = model(images, targets) + + assert_equal(loss_dict["bbox_regression"], torch.tensor(0.0)) + assert_equal(loss_dict["bbox_ctrness"], torch.tensor(0.0)) + def test_forward_negative_sample_ssd(self): model = torchvision.models.detection.ssd300_vgg16(num_classes=2, pretrained_backbone=False) diff --git a/torchvision/models/detection/fcos.py b/torchvision/models/detection/fcos.py index 71a6306e7e1..91baf1d0b29 100644 --- a/torchvision/models/detection/fcos.py +++ b/torchvision/models/detection/fcos.py @@ -59,9 +59,13 @@ def compute_loss( all_gt_classes_targets = [] all_gt_boxes_targets = [] for targets_per_image, matched_idxs_per_image in zip(targets, matched_idxs): - gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)] + if len(targets_per_image["labels"]) == 0: + gt_classes_targets = targets_per_image["labels"].new_zeros((len(matched_idxs_per_image),)) + gt_boxes_targets = targets_per_image["boxes"].new_zeros((len(matched_idxs_per_image), 4)) + else: + gt_classes_targets = targets_per_image["labels"][matched_idxs_per_image.clip(min=0)] + gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)] gt_classes_targets[matched_idxs_per_image < 0] = -1 # backgroud - gt_boxes_targets = targets_per_image["boxes"][matched_idxs_per_image.clip(min=0)] all_gt_classes_targets.append(gt_classes_targets) 
all_gt_boxes_targets.append(gt_boxes_targets) @@ -95,13 +99,14 @@ def compute_loss( ] bbox_reg_targets = torch.stack(bbox_reg_targets, dim=0) if len(bbox_reg_targets) == 0: - bbox_reg_targets.new_zeros(len(bbox_reg_targets)) - left_right = bbox_reg_targets[:, :, [0, 2]] - top_bottom = bbox_reg_targets[:, :, [1, 3]] - gt_ctrness_targets = torch.sqrt( - (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) - * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) - ) + gt_ctrness_targets = bbox_reg_targets.new_zeros(bbox_reg_targets.size()[:-1]) + else: + left_right = bbox_reg_targets[:, :, [0, 2]] + top_bottom = bbox_reg_targets[:, :, [1, 3]] + gt_ctrness_targets = torch.sqrt( + (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) + * (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]) + ) pred_centerness = bbox_ctrness.squeeze(dim=2) loss_bbox_ctrness = nn.functional.binary_cross_entropy_with_logits( pred_centerness[foregroud_mask], gt_ctrness_targets[foregroud_mask], reduction="sum"