From 0d728c4b849dfa5857c16d050f236f14360d3735 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 10 May 2022 01:05:00 +0530 Subject: [PATCH 01/11] Try to converge implementations --- torchvision/ops/boxes.py | 28 +++++----------------------- 1 file changed, 5 insertions(+), 23 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 3b994879ecf..060cbd776a5 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -296,7 +296,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values for every element in boxes1 and boxes2 """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): + if not torch.jit.is_scripting() and not torch.jit.is_tracing() and called_itself: _log_api_usage_once(generalized_box_iou) inter, union = _box_inter_union(boxes1, boxes2) @@ -327,25 +327,8 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(complete_box_iou) - boxes1 = _upcast(boxes1) - boxes2 = _upcast(boxes2) - - inter, union = _box_inter_union(boxes1, boxes2) - iou = inter / union - - lti = torch.min(boxes1[:, None, :2], boxes2[:, None, :2]) - rbi = torch.max(boxes1[:, None, 2:], boxes2[:, None, 2:]) - - whi = (rbi - lti).clamp(min=0) # [N,M,2] - diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps - - # centers of boxes - x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2 - y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2 - x_g = (boxes2[:, 0] + boxes2[:, 2]) / 2 - y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 - # The distance between boxes' centers squared. - centers_distance_squared = (x_p - x_g) ** 2 + (y_p - y_g) ** 2 + diou = distance_box_iou(boxes1, boxes2, eps) + iou = box_iou(boxes1, boxes2) w_pred = boxes1[:, 2] - boxes1[:, 0] h_pred = boxes1[:, 3] - boxes1[:, 1] @@ -356,7 +339,7 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso v = (4 / (torch.pi ** 2)) * torch.pow((torch.atan(w_gt / h_gt) - torch.atan(w_pred / h_pred)), 2) with torch.no_grad(): alpha = v / (1 - iou + v + eps) - return iou - (centers_distance_squared / diagonal_distance_squared) - alpha * v + return diou - alpha * v def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tensor: @@ -381,8 +364,7 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso boxes1 = _upcast(boxes1) boxes2 = _upcast(boxes2) - inter, union = _box_inter_union(boxes1, boxes2) - iou = inter / union + iou = box_iou(boxes1, boxes2) lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) From 475f656ef31486411acb6245232f2092405d7b14 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 10 May 2022 01:12:30 +0530 Subject: [PATCH 02/11] Uplift upcast --- torchvision/ops/_utils.py | 8 ++++++++ torchvision/ops/boxes.py | 9 +-------- torchvision/ops/ciou_loss.py | 2 +- torchvision/ops/giou_loss.py | 11 ++--------- 4 files changed, 12 insertions(+), 18 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 8a02490ab13..41c05226163 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -67,3 +67,11 @@ def split_normalization_params( else: other_params.extend(p for p in module.parameters() if p.requires_grad) return norm_params, other_params + + +def _upcast(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by 
upcasting to the equivalent higher type + if t.is_floating_point(): + return t if t.dtype in (torch.float32, torch.float64) else t.float() + else: + return t if t.dtype in (torch.int32, torch.int64) else t.int() diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index 060cbd776a5..dd8f8a2f793 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -7,6 +7,7 @@ from ..utils import _log_api_usage_once from ._box_convert import _box_cxcywh_to_xyxy, _box_xyxy_to_cxcywh, _box_xywh_to_xyxy, _box_xyxy_to_xywh +from ._utils import _upcast def nms(boxes: Tensor, scores: Tensor, iou_threshold: float) -> Tensor: @@ -215,14 +216,6 @@ def box_convert(boxes: Tensor, in_fmt: str, out_fmt: str) -> Tensor: return boxes -def _upcast(t: Tensor) -> Tensor: - # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.is_floating_point(): - return t if t.dtype in (torch.float32, torch.float64) else t.float() - else: - return t if t.dtype in (torch.int32, torch.int64) else t.int() - - def box_area(boxes: Tensor) -> Tensor: """ Computes the area of a set of bounding boxes, which are specified by their diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index d53e2d6af2a..45f68560917 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -1,7 +1,7 @@ import torch from ..utils import _log_api_usage_once -from .giou_loss import _upcast +from ._utils import _upcast def complete_box_iou_loss( diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index 4d6f946f5e8..a0587b71b85 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,18 +1,11 @@ import torch -from torch import Tensor from ..utils import _log_api_usage_once - - -def _upcast(t: Tensor) -> Tensor: - # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type - if t.dtype not in (torch.float32, torch.float64): - return t.float() - return t +from ._utils import _upcast def generalized_box_iou_loss( - boxes1: torch.Tensor, + boxes1: Tensor, boxes2: torch.Tensor, reduction: str = "none", eps: float = 1e-7, From e28511dafc62808adef93d8c703888dca4c7e2da Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Tue, 10 May 2022 01:16:26 +0530 Subject: [PATCH 03/11] Fix bugs --- torchvision/ops/boxes.py | 2 +- torchvision/ops/diou_loss.py | 5 +++-- torchvision/ops/giou_loss.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index dd8f8a2f793..f262e5aef8c 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -289,7 +289,7 @@ def generalized_box_iou(boxes1: Tensor, boxes2: Tensor) -> Tensor: Tensor[N, M]: the NxM matrix containing the pairwise generalized IoU values for every element in boxes1 and boxes2 """ - if not torch.jit.is_scripting() and not torch.jit.is_tracing() and called_itself: + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(generalized_box_iou) inter, union = _box_inter_union(boxes1, boxes2) diff --git a/torchvision/ops/diou_loss.py b/torchvision/ops/diou_loss.py index ea7ead19344..d8e285790ef 100644 --- a/torchvision/ops/diou_loss.py +++ b/torchvision/ops/diou_loss.py @@ -11,6 +11,9 @@ def distance_box_iou_loss( eps: float = 1e-7, ) -> torch.Tensor: """ + Original Implementation from + https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py + Gradient-friendly IoU loss with an additional penalty that is 
non-zero when the distance between boxes' centers isn't zero. Indeed, for two exactly overlapping boxes, the distance IoU is the same as the IoU loss. @@ -37,8 +40,6 @@ def distance_box_iou_loss( https://arxiv.org/abs/1911.08287 """ - # Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(distance_box_iou_loss) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index a0587b71b85..efeb79cd1a7 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -5,7 +5,7 @@ def generalized_box_iou_loss( - boxes1: Tensor, + boxes1: torch.Tensor, boxes2: torch.Tensor, reduction: str = "none", eps: float = 1e-7, From 77f8f7ae2c9008744a7662f4b5335e0319f3bf87 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 11 May 2022 13:27:46 +0530 Subject: [PATCH 04/11] Refactor losses --- torchvision/ops/__init__.py | 2 +- torchvision/ops/_utils.py | 22 ++++++++++++++++ torchvision/ops/ciou_loss.py | 49 +++++++++++------------------------- torchvision/ops/diou_loss.py | 15 +++-------- torchvision/ops/giou_loss.py | 16 ++++-------- 5 files changed, 45 insertions(+), 59 deletions(-) diff --git a/torchvision/ops/__init__.py b/torchvision/ops/__init__.py index cd711578a6c..d3f27ef1657 100644 --- a/torchvision/ops/__init__.py +++ b/torchvision/ops/__init__.py @@ -5,13 +5,13 @@ remove_small_boxes, clip_boxes_to_image, box_area, + box_convert, box_iou, generalized_box_iou, distance_box_iou, complete_box_iou, masks_to_boxes, ) -from .boxes import box_convert from .ciou_loss import complete_box_iou_loss from .deform_conv import deform_conv2d, DeformConv2d from .diou_loss import distance_box_iou_loss diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 41c05226163..592906471e1 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -75,3 +75,25 @@ def _upcast(t: Tensor) -> Tensor: return t if t.dtype in (torch.float32, torch.float64) else t.float() else: return t if t.dtype in (torch.int32, torch.int64) else t.int() + + +def _loss_inter_union( + boxes1: torch.Tensor, + boxes2: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + + x1, y1, x2, y2 = boxes1.unbind(dim=-1) + x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) + + # Intersection keypoints + xkis1 = torch.max(x1, x1g) + ykis1 = torch.max(y1, y1g) + xkis2 = torch.min(x2, x2g) + ykis2 = torch.min(y2, y2g) + + intsctk = torch.zeros_like(x1) + mask = (ykis2 > ykis1) & (xkis2 > xkis1) + intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) + unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + + return intsctk, unionk diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index 45f68560917..d62218982a1 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -1,7 +1,8 @@ import torch from ..utils import _log_api_usage_once -from ._utils import _upcast +from ._utils import _loss_inter_union +from .diou_loss import distance_box_iou_loss def complete_box_iou_loss( @@ -12,6 +13,9 @@ def complete_box_iou_loss( ) -> torch.Tensor: """ + # Original Implementation from + https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py + Gradient-friendly IoU loss with an additional penalty that is non-zero when the boxes do not overlap overlap area, This loss function considers important geometrical factors such as overlap area, normalized central point distance 
and aspect ratio. @@ -30,50 +34,25 @@ def complete_box_iou_loss( ``'sum'``: The output will be summed. Default: ``'none'`` eps : (float): small number to prevent division by zero. Default: 1e-7 - Reference: + Returns: + Tensor: Loss tensor with the reduction option applied. - Complete Intersection over Union Loss (Zhaohui Zheng et. al) - https://arxiv.org/abs/1911.08287 + Reference: + Zhaohui Zheng et. al: Complete Intersection over Union Loss: + https://arxiv.org/abs/1911.08287 """ - # Original Implementation : https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py - if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(complete_box_iou_loss) - boxes1 = _upcast(boxes1) - boxes2 = _upcast(boxes2) + diou_loss = distance_box_iou_loss(boxes1, boxes2, reduction="none", eps=eps) + intsct, union = _loss_inter_union(boxes1, boxes2) + iou = intsct / (union + eps) x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) - # Intersection keypoints - xkis1 = torch.max(x1, x1g) - ykis1 = torch.max(y1, y1g) - xkis2 = torch.min(x2, x2g) - ykis2 = torch.min(y2, y2g) - - intsct = torch.zeros_like(x1) - mask = (ykis2 > ykis1) & (xkis2 > xkis1) - intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) - union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps - iou = intsct / union - - # smallest enclosing box - xc1 = torch.min(x1, x1g) - yc1 = torch.min(y1, y1g) - xc2 = torch.max(x2, x2g) - yc2 = torch.max(y2, y2g) - diag_len = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps - - # centers of boxes - x_p = (x2 + x1) / 2 - y_p = (y2 + y1) / 2 - x_g = (x1g + x2g) / 2 - y_g = (y1g + y2g) / 2 - distance = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) - # width and height of boxes w_pred = x2 - x1 h_pred = y2 - y1 @@ -83,7 +62,7 @@ def complete_box_iou_loss( with torch.no_grad(): alpha = v / (1 - iou + v + eps) - loss = 1 - iou + (distance / diag_len) + alpha * v + loss = diou_loss + alpha * v if reduction == "mean": loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() elif reduction == "sum": diff --git a/torchvision/ops/diou_loss.py b/torchvision/ops/diou_loss.py index d8e285790ef..81793a1c6ec 100644 --- a/torchvision/ops/diou_loss.py +++ b/torchvision/ops/diou_loss.py @@ -1,7 +1,7 @@ import torch from ..utils import _log_api_usage_once -from .boxes import _upcast +from ._utils import _upcast, _loss_inter_union def distance_box_iou_loss( @@ -49,17 +49,8 @@ def distance_box_iou_loss( x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) - # Intersection keypoints - xkis1 = torch.max(x1, x1g) - ykis1 = torch.max(y1, y1g) - xkis2 = torch.min(x2, x2g) - ykis2 = torch.min(y2, y2g) - - intsct = torch.zeros_like(x1) - mask = (ykis2 > ykis1) & (xkis2 > xkis1) - intsct[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) - union = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsct + eps - iou = intsct / union + intsct, union = _loss_inter_union(boxes1, boxes2) + iou = intsct / (union + eps) # smallest enclosing box xc1 = torch.min(x1, x1g) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index efeb79cd1a7..c24baaf729d 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,7 +1,7 @@ import torch from ..utils import _log_api_usage_once -from ._utils import _upcast +from ._utils import _upcast, _loss_inter_union def generalized_box_iou_loss( @@ -31,6 +31,9 @@ def generalized_box_iou_loss( ``'sum'``: The 
output will be summed. Default: ``'none'`` eps (float): small number to prevent division by zero. Default: 1e-7 + Returns: + Tensor: Loss tensor with the reduction option applied. + Reference: Hamid Rezatofighi et. al: Generalized Intersection over Union: A Metric and A Loss for Bounding Box Regression: @@ -44,16 +47,7 @@ def generalized_box_iou_loss( x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) - # Intersection keypoints - xkis1 = torch.max(x1, x1g) - ykis1 = torch.max(y1, y1g) - xkis2 = torch.min(x2, x2g) - ykis2 = torch.min(y2, y2g) - - intsctk = torch.zeros_like(x1) - mask = (ykis2 > ykis1) & (xkis2 > xkis1) - intsctk[mask] = (xkis2[mask] - xkis1[mask]) * (ykis2[mask] - ykis1[mask]) - unionk = (x2 - x1) * (y2 - y1) + (x2g - x1g) * (y2g - y1g) - intsctk + intsctk, unionk = _loss_inter_union(boxes1, boxes2) iouk = intsctk / (unionk + eps) # smallest enclosing box From 4d558917514b91c552ea6d141e91802a3930e4ef Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Wed, 11 May 2022 15:25:24 +0530 Subject: [PATCH 05/11] Refactor losses --- torchvision/ops/boxes.py | 17 ++++++++++------- torchvision/ops/ciou_loss.py | 16 ++++++++-------- torchvision/ops/diou_loss.py | 34 ++++++++++++++++++++++------------ torchvision/ops/giou_loss.py | 6 +++--- 4 files changed, 43 insertions(+), 30 deletions(-) diff --git a/torchvision/ops/boxes.py b/torchvision/ops/boxes.py index f262e5aef8c..72c95442b78 100644 --- a/torchvision/ops/boxes.py +++ b/torchvision/ops/boxes.py @@ -320,8 +320,10 @@ def complete_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(complete_box_iou) - diou = distance_box_iou(boxes1, boxes2, eps) - iou = box_iou(boxes1, boxes2) + boxes1 = _upcast(boxes1) + boxes2 = _upcast(boxes2) + + diou, iou = _box_diou_iou(boxes1, boxes2, eps) w_pred = boxes1[:, 2] - boxes1[:, 0] h_pred = boxes1[:, 3] - boxes1[:, 1] @@ -356,15 +358,17 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso boxes1 = _upcast(boxes1) boxes2 = _upcast(boxes2) + diou, _ = _box_diou_iou(boxes1, boxes2) + return diou - iou = box_iou(boxes1, boxes2) +def _box_diou_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tuple[Tensor, Tensor]: + + iou = box_iou(boxes1, boxes2) lti = torch.min(boxes1[:, None, :2], boxes2[:, :2]) rbi = torch.max(boxes1[:, None, 2:], boxes2[:, 2:]) - whi = _upcast(rbi - lti).clamp(min=0) # [N,M,2] diagonal_distance_squared = (whi[:, :, 0] ** 2) + (whi[:, :, 1] ** 2) + eps - # centers of boxes x_p = (boxes1[:, 0] + boxes1[:, 2]) / 2 y_p = (boxes1[:, 1] + boxes1[:, 3]) / 2 @@ -372,10 +376,9 @@ def distance_box_iou(boxes1: Tensor, boxes2: Tensor, eps: float = 1e-7) -> Tenso y_g = (boxes2[:, 1] + boxes2[:, 3]) / 2 # The distance between boxes' centers squared. centers_distance_squared = (_upcast(x_p - x_g) ** 2) + (_upcast(y_p - y_g) ** 2) - # The distance IoU is the IoU penalized by a normalized # distance between boxes' centers squared. 
- return iou - (centers_distance_squared / diagonal_distance_squared) + return iou - (centers_distance_squared / diagonal_distance_squared), iou def masks_to_boxes(masks: torch.Tensor) -> torch.Tensor: diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index d62218982a1..74f9755335a 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -1,8 +1,8 @@ import torch from ..utils import _log_api_usage_once -from ._utils import _loss_inter_union -from .diou_loss import distance_box_iou_loss +from ._utils import _upcast +from .diou_loss import _diou_iou_loss def complete_box_iou_loss( @@ -13,9 +13,6 @@ def complete_box_iou_loss( ) -> torch.Tensor: """ - # Original Implementation from - https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py - Gradient-friendly IoU loss with an additional penalty that is non-zero when the boxes do not overlap overlap area, This loss function considers important geometrical factors such as overlap area, normalized central point distance and aspect ratio. @@ -43,12 +40,15 @@ def complete_box_iou_loss( """ + # Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py + if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(complete_box_iou_loss) - diou_loss = distance_box_iou_loss(boxes1, boxes2, reduction="none", eps=eps) - intsct, union = _loss_inter_union(boxes1, boxes2) - iou = intsct / (union + eps) + boxes1 = _upcast(boxes1) + boxes2 = _upcast(boxes2) + + diou_loss, iou = _diou_iou_loss(boxes1, boxes2) x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) diff --git a/torchvision/ops/diou_loss.py b/torchvision/ops/diou_loss.py index 81793a1c6ec..91ebe1ea6e0 100644 --- a/torchvision/ops/diou_loss.py +++ b/torchvision/ops/diou_loss.py @@ -1,19 +1,22 @@ +from typing import Tuple + import torch from ..utils import _log_api_usage_once from ._utils import _upcast, _loss_inter_union +# Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py + + def distance_box_iou_loss( boxes1: torch.Tensor, boxes2: torch.Tensor, reduction: str = "none", eps: float = 1e-7, ) -> torch.Tensor: - """ - Original Implementation from - https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py + """ Gradient-friendly IoU loss with an additional penalty that is non-zero when the distance between boxes' centers isn't zero. Indeed, for two exactly overlapping boxes, the distance IoU is the same as the IoU loss. 
@@ -46,12 +49,25 @@ def distance_box_iou_loss( boxes1 = _upcast(boxes1) boxes2 = _upcast(boxes2) + loss, _ = _diou_iou_loss(boxes1, boxes2, eps) + + if reduction == "mean": + loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() + elif reduction == "sum": + loss = loss.sum() + return loss + + +def _diou_iou_loss( + boxes1: torch.Tensor, + boxes2: torch.Tensor, + eps: float = 1e-7, +) -> Tuple[torch.Tensor, torch.Tensor]: + x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) - intsct, union = _loss_inter_union(boxes1, boxes2) iou = intsct / (union + eps) - # smallest enclosing box xc1 = torch.min(x1, x1g) yc1 = torch.min(y1, y1g) @@ -59,7 +75,6 @@ def distance_box_iou_loss( yc2 = torch.max(y2, y2g) # The diagonal distance of the smallest enclosing box squared diagonal_distance_squared = ((xc2 - xc1) ** 2) + ((yc2 - yc1) ** 2) + eps - # centers of boxes x_p = (x2 + x1) / 2 y_p = (y2 + y1) / 2 @@ -67,12 +82,7 @@ def distance_box_iou_loss( y_g = (y1g + y2g) / 2 # The distance between boxes' centers squared. centers_distance_squared = ((x_p - x_g) ** 2) + ((y_p - y_g) ** 2) - # The distance IoU is the IoU penalized by a normalized # distance between boxes' centers squared. loss = 1 - iou + (centers_distance_squared / diagonal_distance_squared) - if reduction == "mean": - loss = loss.mean() if loss.numel() > 0 else 0.0 * loss.sum() - elif reduction == "sum": - loss = loss.sum() - return loss + return loss, iou diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index c24baaf729d..e9144f662bf 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -3,6 +3,8 @@ from ..utils import _log_api_usage_once from ._utils import _upcast, _loss_inter_union +# Original implementation from https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py + def generalized_box_iou_loss( boxes1: torch.Tensor, @@ -10,10 +12,8 @@ def generalized_box_iou_loss( reduction: str = "none", eps: float = 1e-7, ) -> torch.Tensor: - """ - Original implementation from - https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py + """ Gradient-friendly IoU loss with an additional penalty that is non-zero when the boxes do not overlap and scales with the size of their smallest enclosing box. This loss is symmetric, so the boxes1 and boxes2 arguments are interchangeable. 
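A minimal standalone sketch, not part of the patch series, of where the 0.8125 expectation used by the DIoU/CIoU tests in the next patch comes from, assuming a torchvision build that already ships distance_box_iou_loss (the op this series refactors); the boxes and the expected value are the same ones used in the tests below.

import torch
from torchvision import ops

# box1 is a 2x2 box centred at (0, 0); box2 is a 1x1 box centred at (0.5, 0.5).
box1 = torch.tensor([[-1.0, -1.0, 1.0, 1.0]])
box2 = torch.tensor([[0.0, 0.0, 1.0, 1.0]])

iou = 1.0 / 4.0                        # intersection area 1, union area 4
center_dist_sq = 0.5 ** 2 + 0.5 ** 2   # squared distance between the two centres
diag_sq = 2.0 ** 2 + 2.0 ** 2          # squared diagonal of the enclosing box [-1, -1, 1, 1]
expected = 1 - iou + center_dist_sq / diag_sq  # 1 - 0.25 + 0.5 / 8 = 0.8125

loss = ops.distance_box_iou_loss(box1, box2, reduction="none")
torch.testing.assert_close(loss, torch.tensor([expected]))

# Both boxes are square, so the aspect-ratio term v is zero and
# complete_box_iou_loss returns the same 0.8125 for this pair.
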
From 8fd0e30c62c513c9e0a8ff9df51e2b1a3bf10790 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Thu, 12 May 2022 21:22:07 +0530 Subject: [PATCH 06/11] take the losses out --- test/test_losses.py | 166 ++++++++++++++++++++++++++++++++++++++++++++ test/test_ops.py | 160 +----------------------------------------- 2 files changed, 167 insertions(+), 159 deletions(-) create mode 100644 test/test_losses.py diff --git a/test/test_losses.py b/test/test_losses.py new file mode 100644 index 00000000000..f532879d4d5 --- /dev/null +++ b/test/test_losses.py @@ -0,0 +1,166 @@ +import pytest +import torch +from common_utils import cpu_and_gpu +from torchvision import ops + + +class TestGeneralizedBoxIouLoss: + # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_giou_loss(self, dtype, device) -> None: + box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) + box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) + box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) + box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) + + box1s = torch.stack([box2, box2], dim=0) + box2s = torch.stack([box3, box4], dim=0) + + def assert_giou_loss(box1, box2, expected_loss, reduction="none"): + tol = 1e-3 if dtype is torch.half else 1e-5 + computed_loss = ops.generalized_box_iou_loss(box1, box2, reduction=reduction) + expected_loss = torch.tensor(expected_loss, device=device) + torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) + + # Identical boxes should have loss of 0 + assert_giou_loss(box1, box1, 0.0) + + # quarter size box inside other box = IoU of 0.25 + assert_giou_loss(box1, box2, 0.75) + + # Two side by side boxes, area=union + # IoU=0 and GIoU=0 (loss 1.0) + assert_giou_loss(box2, box3, 1.0) + + # Two diagonally adjacent boxes, area=2*union + # IoU=0 and GIoU=-0.5 (loss 1.5) + assert_giou_loss(box2, box4, 1.5) + + # Test batched loss and reductions + assert_giou_loss(box1s, box2s, 2.5, reduction="sum") + assert_giou_loss(box1s, box2s, 1.25, reduction="mean") + + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_empty_inputs(self, dtype, device) -> None: + box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() + box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() + + loss = ops.generalized_box_iou_loss(box1, box2, reduction="mean") + loss.backward() + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) + assert box1.grad is not None, "box1.grad should not be None after backward is called" + assert box2.grad is not None, "box2.grad should not be None after backward is called" + + loss = ops.generalized_box_iou_loss(box1, box2, reduction="none") + assert loss.numel() == 0, "giou_loss for two empty box should be empty" + + +class TestCIOULoss: + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + def test_ciou_loss(self, dtype, device): + box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) + box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) + box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) + box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) + + box1s = torch.stack([box2, box2], dim=0) + box2s = torch.stack([box3, 
box4], dim=0) + + def assert_ciou_loss(box1, box2, expected_output, reduction="none"): + + output = ops.complete_box_iou_loss(box1, box2, reduction=reduction) + # TODO: When passing the dtype, the torch.half test doesn't pass... + expected_output = torch.tensor(expected_output, device=device) + tol = 1e-5 if dtype != torch.half else 1e-3 + torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) + + assert_ciou_loss(box1, box1, 0.0) + + assert_ciou_loss(box1, box2, 0.8125) + + assert_ciou_loss(box1, box3, 1.1923) + + assert_ciou_loss(box1, box4, 1.2500) + + assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean") + assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum") + + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_empty_inputs(self, dtype, device) -> None: + box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() + box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() + + loss = ops.complete_box_iou_loss(box1, box2, reduction="mean") + loss.backward() + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) + assert box1.grad is not None, "box1.grad should not be None after backward is called" + assert box2.grad is not None, "box2.grad should not be None after backward is called" + + loss = ops.complete_box_iou_loss(box1, box2, reduction="none") + assert loss.numel() == 0, "ciou_loss for two empty box should be empty" + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) +def test_distance_iou_loss(dtype, device): + box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) + box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) + box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) + box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) + + box1s = torch.stack( + [box2, box2], + dim=0, + ) + box2s = torch.stack( + [box3, box4], + dim=0, + ) + + def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"): + output = ops.distance_box_iou_loss(box1, box2, reduction=reduction) + # TODO: When passing the dtype, the torch.half fails as usual. 
+ expected_output = torch.tensor(expected_output, device=device) + tol = 1e-5 if dtype != torch.half else 1e-3 + torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) + + assert_distance_iou_loss(box1, box1, 0.0) + + assert_distance_iou_loss(box1, box2, 0.8125) + + assert_distance_iou_loss(box1, box3, 1.1923) + + assert_distance_iou_loss(box1, box4, 1.2500) + + assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean") + assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum") + + +@pytest.mark.parametrize("device", cpu_and_gpu()) +@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) +def test_empty_distance_iou_inputs(dtype, device) -> None: + box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + + loss = ops.distance_box_iou_loss(box1, box2, reduction="mean") + loss.backward() + + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) + assert box1.grad is not None, "box1.grad should not be None after backward is called" + assert box2.grad is not None, "box2.grad should not be None after backward is called" + + loss = ops.distance_box_iou_loss(box1, box2, reduction="none") + assert loss.numel() == 0, "diou_loss for two empty box should be empty" + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/test_ops.py b/test/test_ops.py index 96cfb630e8d..818018baba0 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1021,7 +1021,7 @@ def test_convert_boxes_to_roi_format(self, box_sequence): assert_equal(ref_tensor, ops._utils.convert_boxes_to_roi_format(box_sequence)) -class TestBox: +class TestBoxConvert: def test_bbox_same(self): box_tensor = torch.tensor( [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float @@ -1295,60 +1295,6 @@ def test_distance_iou_jit(self): self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) -def test_distance_iou_loss(dtype, device): - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack( - [box2, box2], - dim=0, - ) - box2s = torch.stack( - [box3, box4], - dim=0, - ) - - def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"): - output = ops.distance_box_iou_loss(box1, box2, reduction=reduction) - # TODO: When passing the dtype, the torch.half fails as usual. 
- expected_output = torch.tensor(expected_output, device=device) - tol = 1e-5 if dtype != torch.half else 1e-3 - torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) - - assert_distance_iou_loss(box1, box1, 0.0) - - assert_distance_iou_loss(box1, box2, 0.8125) - - assert_distance_iou_loss(box1, box3, 1.1923) - - assert_distance_iou_loss(box1, box4, 1.2500) - - assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean") - assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum") - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) -def test_empty_distance_iou_inputs(dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() - - loss = ops.distance_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.distance_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "diou_loss for two empty box should be empty" - - class TestCompleteBoxIou(BoxTestBase): def _target_fn(self) -> Tuple[bool, Callable]: return (True, ops.complete_box_iou) @@ -1697,109 +1643,5 @@ def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None: torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) -class TestGeneralizedBoxIouLoss: - # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_giou_loss(self, dtype, device) -> None: - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - def assert_giou_loss(box1, box2, expected_loss, reduction="none"): - tol = 1e-3 if dtype is torch.half else 1e-5 - computed_loss = ops.generalized_box_iou_loss(box1, box2, reduction=reduction) - expected_loss = torch.tensor(expected_loss, device=device) - torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) - - # Identical boxes should have loss of 0 - assert_giou_loss(box1, box1, 0.0) - - # quarter size box inside other box = IoU of 0.25 - assert_giou_loss(box1, box2, 0.75) - - # Two side by side boxes, area=union - # IoU=0 and GIoU=0 (loss 1.0) - assert_giou_loss(box2, box3, 1.0) - - # Two diagonally adjacent boxes, area=2*union - # IoU=0 and GIoU=-0.5 (loss 1.5) - assert_giou_loss(box2, box4, 1.5) - - # Test batched loss and reductions - assert_giou_loss(box1s, box2s, 2.5, reduction="sum") - assert_giou_loss(box1s, box2s, 1.25, reduction="mean") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() - - loss = 
ops.generalized_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.generalized_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "giou_loss for two empty box should be empty" - - -class TestCIOULoss: - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - def test_ciou_loss(self, dtype, device): - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - def assert_ciou_loss(box1, box2, expected_output, reduction="none"): - - output = ops.complete_box_iou_loss(box1, box2, reduction=reduction) - # TODO: When passing the dtype, the torch.half test doesn't pass... - expected_output = torch.tensor(expected_output, device=device) - tol = 1e-5 if dtype != torch.half else 1e-3 - torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) - - assert_ciou_loss(box1, box1, 0.0) - - assert_ciou_loss(box1, box2, 0.8125) - - assert_ciou_loss(box1, box3, 1.1923) - - assert_ciou_loss(box1, box4, 1.2500) - - assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean") - assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum") - - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() - - loss = ops.complete_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.complete_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "ciou_loss for two empty box should be empty" - - if __name__ == "__main__": pytest.main([__file__]) From 6aea76e9ea80f3bd07f3eb56711a1ad9d7c21dea Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Fri, 13 May 2022 14:20:39 +0530 Subject: [PATCH 07/11] Replace with other util --- torchvision/ops/_utils.py | 7 +++++++ torchvision/ops/ciou_loss.py | 6 +++--- torchvision/ops/diou_loss.py | 6 +++--- torchvision/ops/giou_loss.py | 6 +++--- 4 files changed, 16 insertions(+), 9 deletions(-) diff --git a/torchvision/ops/_utils.py b/torchvision/ops/_utils.py index 592906471e1..a6ca557a98b 100644 --- a/torchvision/ops/_utils.py +++ b/torchvision/ops/_utils.py @@ -77,6 +77,13 @@ def _upcast(t: Tensor) -> Tensor: return t if t.dtype in (torch.int32, torch.int64) else t.int() +def _upcast_non_float(t: Tensor) -> Tensor: + # Protects from numerical overflows in multiplications by upcasting to the equivalent higher type + if t.dtype not in (torch.float32, torch.float64): + return t.float() + return t + + def _loss_inter_union( boxes1: torch.Tensor, boxes2: 
torch.Tensor, diff --git a/torchvision/ops/ciou_loss.py b/torchvision/ops/ciou_loss.py index 74f9755335a..1f271fb0a1d 100644 --- a/torchvision/ops/ciou_loss.py +++ b/torchvision/ops/ciou_loss.py @@ -1,7 +1,7 @@ import torch from ..utils import _log_api_usage_once -from ._utils import _upcast +from ._utils import _upcast_non_float from .diou_loss import _diou_iou_loss @@ -45,8 +45,8 @@ def complete_box_iou_loss( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(complete_box_iou_loss) - boxes1 = _upcast(boxes1) - boxes2 = _upcast(boxes2) + boxes1 = _upcast_non_float(boxes1) + boxes2 = _upcast_non_float(boxes2) diou_loss, iou = _diou_iou_loss(boxes1, boxes2) diff --git a/torchvision/ops/diou_loss.py b/torchvision/ops/diou_loss.py index 91ebe1ea6e0..4b38d58a28f 100644 --- a/torchvision/ops/diou_loss.py +++ b/torchvision/ops/diou_loss.py @@ -3,7 +3,7 @@ import torch from ..utils import _log_api_usage_once -from ._utils import _upcast, _loss_inter_union +from ._utils import _loss_inter_union, _upcast_non_float # Original Implementation from https://github.com/facebookresearch/detectron2/blob/main/detectron2/layers/losses.py @@ -46,8 +46,8 @@ def distance_box_iou_loss( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(distance_box_iou_loss) - boxes1 = _upcast(boxes1) - boxes2 = _upcast(boxes2) + boxes1 = _upcast_non_float(boxes1) + boxes2 = _upcast_non_float(boxes2) loss, _ = _diou_iou_loss(boxes1, boxes2, eps) diff --git a/torchvision/ops/giou_loss.py b/torchvision/ops/giou_loss.py index e9144f662bf..efb9cd1f992 100644 --- a/torchvision/ops/giou_loss.py +++ b/torchvision/ops/giou_loss.py @@ -1,7 +1,7 @@ import torch from ..utils import _log_api_usage_once -from ._utils import _upcast, _loss_inter_union +from ._utils import _upcast_non_float, _loss_inter_union # Original implementation from https://github.com/facebookresearch/fvcore/blob/bfff2ef/fvcore/nn/giou_loss.py @@ -42,8 +42,8 @@ def generalized_box_iou_loss( if not torch.jit.is_scripting() and not torch.jit.is_tracing(): _log_api_usage_once(generalized_box_iou_loss) - boxes1 = _upcast(boxes1) - boxes2 = _upcast(boxes2) + boxes1 = _upcast_non_float(boxes1) + boxes2 = _upcast_non_float(boxes2) x1, y1, x2, y2 = boxes1.unbind(dim=-1) x1g, y1g, x2g, y2g = boxes2.unbind(dim=-1) From 5fdd7a81990181026b77425f81cb8f2b9cbcd3be Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Fri, 13 May 2022 17:31:17 +0530 Subject: [PATCH 08/11] Simplify loss tests --- test/test_losses.py | 293 +++++++++++++++++++++++++++----------------- test/test_ops.py | 119 ------------------ 2 files changed, 178 insertions(+), 234 deletions(-) diff --git a/test/test_losses.py b/test/test_losses.py index f532879d4d5..c2d7f9452ef 100644 --- a/test/test_losses.py +++ b/test/test_losses.py @@ -1,165 +1,228 @@ import pytest import torch +import torch.nn.functional as F from common_utils import cpu_and_gpu from torchvision import ops +def get_boxes(dtype, device): + box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) + box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) + box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) + box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) + + box1s = torch.stack([box2, box2], dim=0) + box2s = torch.stack([box3, box4], dim=0) + + return box1, box2, box3, box4, box1s, box2s + + +def assert_iou_loss(iou_fn, box1, box2, expected_loss, dtype, device, reduction="none"): + tol = 1e-3 if dtype is torch.half else 1e-5 + computed_loss = iou_fn(box1, 
box2, reduction=reduction) + expected_loss = torch.tensor(expected_loss, device=device) + torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) + + +def assert_empty_loss(iou_fn, dtype, device): + box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + loss = iou_fn(box1, box2, reduction="mean") + loss.backward() + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) + assert box1.grad is not None, "box1.grad should not be None after backward is called" + assert box2.grad is not None, "box2.grad should not be None after backward is called" + loss = iou_fn(box1, box2, reduction="none") + assert loss.numel() == 0, f"{str(iou_fn)} for two empty box should be empty" + + class TestGeneralizedBoxIouLoss: # We refer to original test: https://github.com/facebookresearch/fvcore/blob/main/tests/test_giou_loss.py @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_giou_loss(self, dtype, device) -> None: - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) + def test_giou_loss(self, dtype, device): - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - def assert_giou_loss(box1, box2, expected_loss, reduction="none"): - tol = 1e-3 if dtype is torch.half else 1e-5 - computed_loss = ops.generalized_box_iou_loss(box1, box2, reduction=reduction) - expected_loss = torch.tensor(expected_loss, device=device) - torch.testing.assert_close(computed_loss, expected_loss, rtol=tol, atol=tol) + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) # Identical boxes should have loss of 0 - assert_giou_loss(box1, box1, 0.0) + assert_iou_loss(ops.generalized_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) # quarter size box inside other box = IoU of 0.25 - assert_giou_loss(box1, box2, 0.75) + assert_iou_loss(ops.generalized_box_iou_loss, box1, box2, 0.75, dtype=dtype, device=device) # Two side by side boxes, area=union # IoU=0 and GIoU=0 (loss 1.0) - assert_giou_loss(box2, box3, 1.0) + assert_iou_loss(ops.generalized_box_iou_loss, box2, box3, 1.0, dtype=dtype, device=device) # Two diagonally adjacent boxes, area=2*union # IoU=0 and GIoU=-0.5 (loss 1.5) - assert_giou_loss(box2, box4, 1.5) + assert_iou_loss(ops.generalized_box_iou_loss, box2, box4, 1.5, dtype=dtype, device=device) # Test batched loss and reductions - assert_giou_loss(box1s, box2s, 2.5, reduction="sum") - assert_giou_loss(box1s, box2s, 1.25, reduction="mean") + assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 2.5, dtype=dtype, device=device, reduction="sum") + assert_iou_loss(ops.generalized_box_iou_loss, box1s, box2s, 1.25, dtype=dtype, device=device, reduction="mean") @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() - - loss = ops.generalized_box_iou_loss(box1, box2, reduction="mean") - loss.backward() - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, 
torch.tensor(0.0), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.generalized_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "giou_loss for two empty box should be empty" + def test_empty_inputs(self, dtype, device): + assert_empty_loss(ops.generalized_box_iou_loss, dtype, device) class TestCIOULoss: @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) @pytest.mark.parametrize("device", cpu_and_gpu()) def test_ciou_loss(self, dtype, device): - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack([box2, box2], dim=0) - box2s = torch.stack([box3, box4], dim=0) - - def assert_ciou_loss(box1, box2, expected_output, reduction="none"): - - output = ops.complete_box_iou_loss(box1, box2, reduction=reduction) - # TODO: When passing the dtype, the torch.half test doesn't pass... - expected_output = torch.tensor(expected_output, device=device) - tol = 1e-5 if dtype != torch.half else 1e-3 - torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) - assert_ciou_loss(box1, box1, 0.0) + assert_iou_loss(ops.complete_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) + assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") + assert_iou_loss(ops.complete_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") - assert_ciou_loss(box1, box2, 0.8125) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_empty_inputs(self, dtype, device): + assert_empty_loss(ops.complete_box_iou_loss, dtype, device) - assert_ciou_loss(box1, box3, 1.1923) - assert_ciou_loss(box1, box4, 1.2500) +class TestDIouLoss: + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + def test_distance_iou_loss(self, dtype, device): + box1, box2, box3, box4, box1s, box2s = get_boxes(dtype, device) - assert_ciou_loss(box1s, box2s, 1.2250, reduction="mean") - assert_ciou_loss(box1s, box2s, 2.4500, reduction="sum") + assert_iou_loss(ops.distance_box_iou_loss, box1, box1, 0.0, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box2, 0.8125, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box3, 1.1923, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1, box4, 1.2500, dtype=dtype, device=device) + assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 1.2250, dtype=dtype, device=device, reduction="mean") + assert_iou_loss(ops.distance_box_iou_loss, box1s, box2s, 2.4500, dtype=dtype, device=device, reduction="sum") @pytest.mark.parametrize("device", cpu_and_gpu()) @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - def test_empty_inputs(self, dtype, device) -> 
None: - box1 = torch.randn([0, 4], dtype=dtype).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype).requires_grad_() - - loss = ops.complete_box_iou_loss(box1, box2, reduction="mean") - loss.backward() + def test_empty_distance_iou_inputs(self, dtype, device): + assert_empty_loss(ops.distance_box_iou_loss, dtype, device) + + +class TestFocalLoss: + def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs): + def logit(p): + return torch.log(p / (1 - p)) + + def generate_tensor_with_range_type(shape, range_type, **kwargs): + if range_type != "random_binary": + low, high = { + "small": (0.0, 0.2), + "big": (0.8, 1.0), + "zeros": (0.0, 0.0), + "ones": (1.0, 1.0), + "random": (0.0, 1.0), + }[range_type] + return torch.testing.make_tensor(shape, low=low, high=high, **kwargs) + else: + return torch.randint(0, 2, shape, **kwargs) + + # This function will return inputs and targets with shape: (shape[0]*9, shape[1]) + inputs = [] + targets = [] + for input_range_type, target_range_type in [ + ("small", "zeros"), + ("small", "ones"), + ("small", "random_binary"), + ("big", "zeros"), + ("big", "ones"), + ("big", "random_binary"), + ("random", "zeros"), + ("random", "ones"), + ("random", "random_binary"), + ]: + inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs))) + targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs)) + + return torch.cat(inputs), torch.cat(targets) + + @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) + @pytest.mark.parametrize("gamma", [0, 2]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [0, 1]) + def test_correct_ratio(self, alpha, gamma, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is not fully supported on cpu") + # For testing the ratio with manual calculation, we require the reduction to be "none" + reduction = "none" + torch.random.manual_seed(seed) + inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) + focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) + + assert torch.all( + focal_loss <= ce_loss + ), "focal loss must be less or equal to cross entropy loss with same input" + + loss_ratio = (focal_loss / ce_loss).squeeze() + prob = torch.sigmoid(inputs) + p_t = prob * targets + (1 - prob) * (1 - targets) + correct_ratio = (1.0 - p_t) ** gamma + if alpha >= 0: + alpha_t = alpha * targets + (1 - alpha) * (1 - targets) + correct_ratio = correct_ratio * alpha_t tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" - - loss = ops.complete_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "ciou_loss for two empty box should be empty" - - -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) -def test_distance_iou_loss(dtype, device): - box1 = torch.tensor([-1, -1, 1, 1], dtype=dtype, device=device) - box2 = torch.tensor([0, 0, 1, 1], dtype=dtype, device=device) - box3 = torch.tensor([0, 1, 1, 2], dtype=dtype, device=device) - box4 = 
torch.tensor([1, 1, 2, 2], dtype=dtype, device=device) - - box1s = torch.stack( - [box2, box2], - dim=0, - ) - box2s = torch.stack( - [box3, box4], - dim=0, - ) - - def assert_distance_iou_loss(box1, box2, expected_output, reduction="none"): - output = ops.distance_box_iou_loss(box1, box2, reduction=reduction) - # TODO: When passing the dtype, the torch.half fails as usual. - expected_output = torch.tensor(expected_output, device=device) - tol = 1e-5 if dtype != torch.half else 1e-3 - torch.testing.assert_close(output, expected_output, rtol=tol, atol=tol) - - assert_distance_iou_loss(box1, box1, 0.0) - - assert_distance_iou_loss(box1, box2, 0.8125) - - assert_distance_iou_loss(box1, box3, 1.1923) - - assert_distance_iou_loss(box1, box4, 1.2500) - - assert_distance_iou_loss(box1s, box2s, 1.2250, reduction="mean") - assert_distance_iou_loss(box1s, box2s, 2.4500, reduction="sum") + torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol) + @pytest.mark.parametrize("reduction", ["mean", "sum"]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [2, 3]) + def test_equal_ce_loss(self, reduction, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is not fully supported on cpu") + # focal loss should be equal ce_loss if alpha=-1 and gamma=0 + alpha = -1 + gamma = 0 + torch.random.manual_seed(seed) + inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) + inputs_fl = inputs.clone().requires_grad_() + targets_fl = targets.clone() + inputs_ce = inputs.clone().requires_grad_() + targets_ce = targets.clone() + focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) + ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) -@pytest.mark.parametrize("device", cpu_and_gpu()) -@pytest.mark.parametrize("dtype", [torch.float32, torch.half]) -def test_empty_distance_iou_inputs(dtype, device) -> None: - box1 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() - box2 = torch.randn([0, 4], dtype=dtype, device=device).requires_grad_() + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol) - loss = ops.distance_box_iou_loss(box1, box2, reduction="mean") - loss.backward() + focal_loss.backward() + ce_loss.backward() + torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol) - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(loss, torch.tensor(0.0, device=device), rtol=tol, atol=tol) - assert box1.grad is not None, "box1.grad should not be None after backward is called" - assert box2.grad is not None, "box2.grad should not be None after backward is called" + @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) + @pytest.mark.parametrize("gamma", [0, 2]) + @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) + @pytest.mark.parametrize("device", cpu_and_gpu()) + @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) + @pytest.mark.parametrize("seed", [4, 5]) + def test_jit(self, alpha, gamma, reduction, device, dtype, seed): + if device == "cpu" and dtype is torch.half: + pytest.skip("Currently torch.half is not fully supported on cpu") + script_fn = torch.jit.script(ops.sigmoid_focal_loss) + torch.random.manual_seed(seed) + inputs, targets = 
self._generate_diverse_input_target_pair(dtype=dtype, device=device) + focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + if device == "cpu": + scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) + else: + with torch.jit.fuser("fuser2"): + # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 + # We may remove this condition once the bug is resolved + scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - loss = ops.distance_box_iou_loss(box1, box2, reduction="none") - assert loss.numel() == 0, "diou_loss for two empty box should be empty" + tol = 1e-3 if dtype is torch.half else 1e-5 + torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) if __name__ == "__main__": diff --git a/test/test_ops.py b/test/test_ops.py index 818018baba0..66b8246ab63 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -9,7 +9,6 @@ import pytest import torch import torch.fx -import torch.nn.functional as F from common_utils import assert_equal, cpu_and_gpu, needs_cuda from PIL import Image from torch import nn, Tensor @@ -1525,123 +1524,5 @@ def test_is_leaf_node(self, dim, p, block_size, inplace): assert len(graph_node_names[0]) == 1 + op_obj.n_inputs -class TestFocalLoss: - def _generate_diverse_input_target_pair(self, shape=(5, 2), **kwargs): - def logit(p: Tensor) -> Tensor: - return torch.log(p / (1 - p)) - - def generate_tensor_with_range_type(shape, range_type, **kwargs): - if range_type != "random_binary": - low, high = { - "small": (0.0, 0.2), - "big": (0.8, 1.0), - "zeros": (0.0, 0.0), - "ones": (1.0, 1.0), - "random": (0.0, 1.0), - }[range_type] - return torch.testing.make_tensor(shape, low=low, high=high, **kwargs) - else: - return torch.randint(0, 2, shape, **kwargs) - - # This function will return inputs and targets with shape: (shape[0]*9, shape[1]) - inputs = [] - targets = [] - for input_range_type, target_range_type in [ - ("small", "zeros"), - ("small", "ones"), - ("small", "random_binary"), - ("big", "zeros"), - ("big", "ones"), - ("big", "random_binary"), - ("random", "zeros"), - ("random", "ones"), - ("random", "random_binary"), - ]: - inputs.append(logit(generate_tensor_with_range_type(shape, input_range_type, **kwargs))) - targets.append(generate_tensor_with_range_type(shape, target_range_type, **kwargs)) - - return torch.cat(inputs), torch.cat(targets) - - @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) - @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [0, 1]) - def test_correct_ratio(self, alpha, gamma, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - # For testing the ratio with manual calculation, we require the reduction to be "none" - reduction = "none" - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - ce_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction=reduction) - - assert torch.all( - focal_loss <= ce_loss - ), "focal loss must be less or equal to cross entropy loss with same input" - - loss_ratio = (focal_loss / ce_loss).squeeze() - prob = 
torch.sigmoid(inputs) - p_t = prob * targets + (1 - prob) * (1 - targets) - correct_ratio = (1.0 - p_t) ** gamma - if alpha >= 0: - alpha_t = alpha * targets + (1 - alpha) * (1 - targets) - correct_ratio = correct_ratio * alpha_t - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(correct_ratio, loss_ratio, rtol=tol, atol=tol) - - @pytest.mark.parametrize("reduction", ["mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [2, 3]) - def test_equal_ce_loss(self, reduction, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - # focal loss should be equal ce_loss if alpha=-1 and gamma=0 - alpha = -1 - gamma = 0 - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - inputs_fl = inputs.clone().requires_grad_() - targets_fl = targets.clone() - inputs_ce = inputs.clone().requires_grad_() - targets_ce = targets.clone() - focal_loss = ops.sigmoid_focal_loss(inputs_fl, targets_fl, gamma=gamma, alpha=alpha, reduction=reduction) - ce_loss = F.binary_cross_entropy_with_logits(inputs_ce, targets_ce, reduction=reduction) - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, ce_loss, rtol=tol, atol=tol) - - focal_loss.backward() - ce_loss.backward() - torch.testing.assert_close(inputs_fl.grad, inputs_ce.grad, rtol=tol, atol=tol) - - @pytest.mark.parametrize("alpha", [-1.0, 0.0, 0.58, 1.0]) - @pytest.mark.parametrize("gamma", [0, 2]) - @pytest.mark.parametrize("reduction", ["none", "mean", "sum"]) - @pytest.mark.parametrize("device", cpu_and_gpu()) - @pytest.mark.parametrize("dtype", [torch.float32, torch.half]) - @pytest.mark.parametrize("seed", [4, 5]) - def test_jit(self, alpha, gamma, reduction, device, dtype, seed) -> None: - if device == "cpu" and dtype is torch.half: - pytest.skip("Currently torch.half is not fully supported on cpu") - script_fn = torch.jit.script(ops.sigmoid_focal_loss) - torch.random.manual_seed(seed) - inputs, targets = self._generate_diverse_input_target_pair(dtype=dtype, device=device) - focal_loss = ops.sigmoid_focal_loss(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - if device == "cpu": - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - else: - with torch.jit.fuser("fuser2"): - # Use fuser2 to prevent a bug on fuser: https://github.com/pytorch/pytorch/issues/75476 - # We may remove this condition once the bug is resolved - scripted_focal_loss = script_fn(inputs, targets, gamma=gamma, alpha=alpha, reduction=reduction) - - tol = 1e-3 if dtype is torch.half else 1e-5 - torch.testing.assert_close(focal_loss, scripted_focal_loss, rtol=tol, atol=tol) - - if __name__ == "__main__": pytest.main([__file__]) From 4175be39797636f7ebf6905cd990a0d2e5c0535b Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 16 May 2022 17:55:06 +0530 Subject: [PATCH 09/11] Rewrite to simplify? 
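
Replace the abstract BoxTestBase helper (which looked up the op under test via
_target_fn) with a small IouTestBase whose static helpers take the op as an
explicit argument, and hoist the shared box fixtures to module-level functions.
A rough sketch of the resulting pattern inside this test module (not a
standalone script; TestMyNewIoU and ops.my_new_iou are placeholders for any
future pairwise IoU op):

    class TestMyNewIoU(IouTestBase):
        def test_iou(self, test_input, dtypes, tolerance, expected):
            # shared value check, parametrized over dtypes
            self._run_test(ops.my_new_iou, test_input, dtypes, tolerance, expected)

        def test_iou_jit(self):
            # eager vs. torch.jit.script consistency
            self._run_jit_test(ops.my_new_iou, _generate_int_input())
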
--- test/test_ops.py | 226 +++++++++++++++++++---------------------------- 1 file changed, 93 insertions(+), 133 deletions(-) diff --git a/test/test_ops.py b/test/test_ops.py index 66b8246ab63..78d3cf8a489 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -1050,7 +1050,7 @@ def test_bbox_xyxy_xywh(self): assert_equal(box_xyxy, box_tensor) def test_bbox_xyxy_cxcywh(self): - # Simple test convert boxes to xywh and back. Make sure they are same. + # Simple test convert boxes to cxcywh and back. Make sure they are same. # box_tensor is in x1 y1 x2 y2 format. box_tensor = torch.tensor( [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 30, 35], [23, 35, 93, 95]], dtype=torch.float @@ -1072,7 +1072,6 @@ def test_bbox_xywh_cxcywh(self): [[0, 0, 100, 100], [0, 0, 0, 0], [10, 15, 20, 20], [23, 35, 70, 60]], dtype=torch.float ) - # This is wrong exp_cxcywh = torch.tensor( [[50, 50, 100, 100], [0, 0, 0, 0], [20, 25, 20, 20], [58, 65, 70, 60]], dtype=torch.float ) @@ -1112,20 +1111,55 @@ def test_bbox_convert_jit(self): torch.testing.assert_close(scripted_cxcywh, box_cxcywh, rtol=0.0, atol=TOLERANCE) -class BoxTestBase(ABC): - @abstractmethod - def _target_fn(self) -> Tuple[bool, Callable]: - pass +def area_check(box, expected, tolerance=1e-4): + out = ops.box_area(box) + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) + + +class TestBoxArea: + @pytest.mark.parametrize("dtype", [torch.int8, torch.int16, torch.int32, torch.int64]) + def test_int_boxes(self, dtype): + # Check for int boxes + box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=dtype) + expected = torch.tensor([10000, 0]) + area_check(box_tensor, expected) + + # Check for float32 and float64 boxes + @pytest.mark.parametrize("dtype", [torch.float32, torch.float64]) + def test_float_boxes(self, dtype): + box_tensor = torch.tensor( + [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ], + dtype=dtype, + ) + expected = torch.tensor([604723.0806, 600965.4666, 592761.0085], dtype=torch.float64) + area_check(box_tensor, expected, tolerance=0.05) + + def test_float16_box(self): + # Check for float16 box + box_tensor = torch.tensor( + [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], + dtype=torch.float16, + ) + expected = torch.tensor([605113.875, 600495.1875, 592247.25]) + area_check(box_tensor, expected) + + def test_box_area_jit(self): + box_tensor = torch.tensor([[0, 0, 100, 100], [0, 0, 0, 0]], dtype=torch.float) + expected = ops.box_area(box_tensor) + scripted_fn = torch.jit.script(ops.box_area) + scripted_area = scripted_fn(box_tensor) + torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) - def _perform_box_operation(self, box: Tensor, run_as_script: bool = False) -> Tensor: - is_binary_fn = self._target_fn()[0] - target_fn = self._target_fn()[1] - box_operation = torch.jit.script(target_fn) if run_as_script else target_fn - return box_operation(box, box) if is_binary_fn else box_operation(box) - def _run_test(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: +class IouTestBase: + @staticmethod + def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List): def assert_close(box: Tensor, expected: Tensor, tolerance): - out = self._perform_box_operation(box) + out = target_fn(box, box) torch.testing.assert_close(out, expected, rtol=0.0, 
check_dtype=False, atol=tolerance) for dtype in dtypes: @@ -1133,74 +1167,39 @@ def assert_close(box: Tensor, expected: Tensor, tolerance): expected_box = torch.tensor(expected) assert_close(actual_box, expected_box, tolerance) - def _run_jit_test(self, test_input: List) -> None: + @staticmethod + def _run_jit_test(target_fn: Callable, test_input: List): box_tensor = torch.tensor(test_input, dtype=torch.float) - expected = self._perform_box_operation(box_tensor, True) - scripted_area = self._perform_box_operation(box_tensor, True) - torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) - - -class TestBoxArea(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (False, ops.box_area) + expected = target_fn(box_tensor, box_tensor) + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor) + torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3) - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 0, 0]] - def _generate_int_expected() -> List[int]: - return [10000, 0] - - def _generate_float_input(index: int) -> List[List[float]]: - return [ - [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ], - [[285.25, 185.625, 1194.0, 851.5], [285.25, 188.75, 1192.0, 851.0], [279.25, 198.0, 1189.0, 849.0]], - ][index] +def _generate_int_input(): + return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - def _generate_float_expected(index: int) -> List[float]: - return [[604723.0806, 600965.4666, 592761.0085], [605113.875, 600495.1875, 592247.25]][index] - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), - [torch.int8, torch.int16, torch.int32, torch.int64], - 1e-4, - _generate_int_expected(), - ), - pytest.param(_generate_float_input(0), [torch.float32, torch.float64], 0.05, _generate_float_expected(0)), - pytest.param(_generate_float_input(1), [torch.float16], 1e-4, _generate_float_expected(1)), - ], - ) - def test_box_area(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) +def _generate_float_input(): + return [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] - def test_box_area_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 0, 0]]) - -class TestBoxIou(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (True, ops.box_iou) - - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - - def _generate_int_expected() -> List[List[float]]: +class TestBoxIou(IouTestBase): + def _generate_int_expected(): return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - def _generate_float_input() -> List[List[float]]: + def _generate_float_input(): return [ [285.3538, 185.5758, 1193.5110, 851.4551], [285.1472, 188.7374, 1192.4984, 851.0669], [279.2440, 197.9812, 1189.4746, 849.2019], ] - def _generate_float_expected() -> List[List[float]]: + def _generate_float_expected(): return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( @@ -1210,34 +1209,21 @@ def _generate_float_expected() -> List[List[float]]: _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() ), 
pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-4, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), ], ) - def test_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) - - def test_iou_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected) + def test_iou_jit(self): + self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) -class TestGenBoxIou(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (True, ops.generalized_box_iou) - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - - def _generate_int_expected() -> List[List[float]]: +class TestGenBoxIou(IouTestBase): + def _generate_int_expected(): return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] - def _generate_float_input() -> List[List[float]]: - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - def _generate_float_expected() -> List[List[float]]: + def _generate_float_expected(): return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( @@ -1247,33 +1233,20 @@ def _generate_float_expected() -> List[List[float]]: _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() ), pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), ], ) - def test_gen_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) - - def test_giou_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected) -class TestDistanceBoxIoU(BoxTestBase): - def _target_fn(self): - return (True, ops.distance_box_iou) + def test_iou_jit(self): + self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - def _generate_int_input(): - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] +class TestDistanceBoxIoU(IouTestBase): def _generate_int_expected(): return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - def _generate_float_input(): - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - def _generate_float_expected(): return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @@ -1284,34 +1257,21 @@ def _generate_float_expected(): _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() ), pytest.param(_generate_float_input(), 
[torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), ], ) - def test_distance_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(test_input, dtypes, tolerance, expected) - - def test_distance_iou_jit(self): - self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected) -class TestCompleteBoxIou(BoxTestBase): - def _target_fn(self) -> Tuple[bool, Callable]: - return (True, ops.complete_box_iou) + def test_iou_jit(self): + self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - def _generate_int_input() -> List[List[int]]: - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - def _generate_int_expected() -> List[List[float]]: +class TestCompleteBoxIou(IouTestBase): + def _generate_int_expected(): return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - def _generate_float_input() -> List[List[float]]: - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - def _generate_float_expected() -> List[List[float]]: + def _generate_float_expected(): return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] @pytest.mark.parametrize( @@ -1320,15 +1280,15 @@ def _generate_float_expected() -> List[List[float]]: pytest.param( _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() ), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 0.001, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), ], ) - def test_complete_iou(self, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List) -> None: - self._run_test(test_input, dtypes, tolerance, expected) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected) - def test_ciou_jit(self) -> None: - self._run_jit_test([[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + def test_iou_jit(self): + self._run_jit_test(ops.complete_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) class TestMasksToBoxes: From 6599ec07210fb2f795d1278d8433e0c16ebd19e9 Mon Sep 17 00:00:00 2001 From: Aditya Oke Date: Mon, 16 May 2022 18:06:00 +0530 Subject: [PATCH 10/11] Clean for a good diff to review --- test/test_ops.py | 138 +------------------------------- torchvision/ops/test_ious.py | 147 +++++++++++++++++++++++++++++++++++ 2 files changed, 148 insertions(+), 137 deletions(-) create mode 100644 torchvision/ops/test_ious.py diff --git a/test/test_ops.py b/test/test_ops.py index 78d3cf8a489..df5d397713c 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from functools import lru_cache from itertools import product -from typing import Callable, List, Tuple +from typing import Tuple import numpy as np 
import pytest @@ -1155,142 +1155,6 @@ def test_box_area_jit(self): torch.testing.assert_close(scripted_area, expected, rtol=0.0, atol=1e-3) -class IouTestBase: - @staticmethod - def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List): - def assert_close(box: Tensor, expected: Tensor, tolerance): - out = target_fn(box, box) - torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) - - for dtype in dtypes: - actual_box = torch.tensor(test_input, dtype=dtype) - expected_box = torch.tensor(expected) - assert_close(actual_box, expected_box, tolerance) - - @staticmethod - def _run_jit_test(target_fn: Callable, test_input: List): - box_tensor = torch.tensor(test_input, dtype=torch.float) - expected = target_fn(box_tensor, box_tensor) - scripted_fn = torch.jit.script(target_fn) - scripted_out = scripted_fn(box_tensor, box_tensor) - torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3) - - -def _generate_int_input(): - return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] - - -def _generate_float_input(): - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - -class TestBoxIou(IouTestBase): - def _generate_int_expected(): - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_input(): - return [ - [285.3538, 185.5758, 1193.5110, 851.4551], - [285.1472, 188.7374, 1192.4984, 851.0669], - [279.2440, 197.9812, 1189.4746, 849.2019], - ] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), - ], - ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - - -class TestGenBoxIou(IouTestBase): - def _generate_int_expected(): - return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), - ], - ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - - -class TestDistanceBoxIoU(IouTestBase): - def _generate_int_expected(): - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], 
[0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), - ], - ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - - -class TestCompleteBoxIou(IouTestBase): - def _generate_int_expected(): - return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] - - def _generate_float_expected(): - return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] - - @pytest.mark.parametrize( - "test_input, dtypes, tolerance, expected", - [ - pytest.param( - _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() - ), - pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), - pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), - ], - ) - def test_iou(self, test_input, dtypes, tolerance, expected): - self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected) - - def test_iou_jit(self): - self._run_jit_test(ops.complete_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) - - class TestMasksToBoxes: def test_masks_box(self): def masks_box_check(masks, expected, tolerance=1e-4): diff --git a/torchvision/ops/test_ious.py b/torchvision/ops/test_ious.py new file mode 100644 index 00000000000..4e87d64b477 --- /dev/null +++ b/torchvision/ops/test_ious.py @@ -0,0 +1,147 @@ +from typing import List, Callable + +import pytest +import torch +import torch.fx +from torch import Tensor +from torchvision import ops + + +class IouTestBase: + @staticmethod + def _run_test(target_fn: Callable, test_input: List, dtypes: List[torch.dtype], tolerance: float, expected: List): + def assert_close(box: Tensor, expected: Tensor, tolerance): + out = target_fn(box, box) + torch.testing.assert_close(out, expected, rtol=0.0, check_dtype=False, atol=tolerance) + + for dtype in dtypes: + actual_box = torch.tensor(test_input, dtype=dtype) + expected_box = torch.tensor(expected) + assert_close(actual_box, expected_box, tolerance) + + @staticmethod + def _run_jit_test(target_fn: Callable, test_input: List): + box_tensor = torch.tensor(test_input, dtype=torch.float) + expected = target_fn(box_tensor, box_tensor) + scripted_fn = torch.jit.script(target_fn) + scripted_out = scripted_fn(box_tensor, box_tensor) + torch.testing.assert_close(scripted_out, expected, rtol=0.0, atol=1e-3) + + +def _generate_int_input(): + return [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]] + + +def _generate_float_input(): + return [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + + +class TestBoxIou(IouTestBase): + def _generate_int_expected(): + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + + def _generate_float_input(): + return [ + [285.3538, 185.5758, 1193.5110, 851.4551], + [285.1472, 188.7374, 1192.4984, 851.0669], + [279.2440, 197.9812, 1189.4746, 849.2019], + ] + + def 
_generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +class TestGenBoxIou(IouTestBase): + def _generate_int_expected(): + return [[1.0, 0.25, -0.7778], [0.25, 1.0, -0.8611], [-0.7778, -0.8611, 1.0]] + + def _generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.generalized_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.generalized_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +class TestDistanceBoxIoU(IouTestBase): + def _generate_int_expected(): + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + + def _generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.distance_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.distance_box_iou, [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]]) + + +class TestCompleteBoxIou(IouTestBase): + def _generate_int_expected(): + return [[1.0, 0.25, 0.0], [0.25, 1.0, 0.0], [0.0, 0.0, 1.0]] + + def _generate_float_expected(): + return [[1.0, 0.9933, 0.9673], [0.9933, 1.0, 0.9737], [0.9673, 0.9737, 1.0]] + + @pytest.mark.parametrize( + "test_input, dtypes, tolerance, expected", + [ + pytest.param( + _generate_int_input(), [torch.int16, torch.int32, torch.int64], 1e-4, _generate_int_expected() + ), + pytest.param(_generate_float_input(), [torch.float16], 0.002, _generate_float_expected()), + pytest.param(_generate_float_input(), [torch.float32, torch.float64], 1e-3, _generate_float_expected()), + ], + ) + def test_iou(self, test_input, dtypes, tolerance, expected): + self._run_test(ops.complete_box_iou, test_input, dtypes, tolerance, expected) + + def test_iou_jit(self): + self._run_jit_test(ops.complete_box_iou, [[0, 0, 100, 100], [0, 0, 50, 
50], [200, 200, 300, 300]])
+
+
+if __name__ == "__main__":
+    pytest.main([__file__])

From 9b6bfb17a970d41ed68d3724528167d91293ee3c Mon Sep 17 00:00:00 2001
From: Aditya Oke
Date: Mon, 16 May 2022 18:07:34 +0530
Subject: [PATCH 11/11] oops

---
 {torchvision/ops => test}/test_ious.py | 0
 1 file changed, 0 insertions(+), 0 deletions(-)
 rename {torchvision/ops => test}/test_ious.py (100%)

diff --git a/torchvision/ops/test_ious.py b/test/test_ious.py
similarity index 100%
rename from torchvision/ops/test_ious.py
rename to test/test_ious.py
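
For reviewers: a minimal sketch of the pairwise matrices the relocated
test/test_ious.py asserts against, assuming a torchvision build that already
contains the distance/complete IoU ops touched earlier in this series (the box
values are the integer fixtures used in the tests):

    import torch
    from torchvision import ops

    # Boxes in (x1, y1, x2, y2) format, as expected by all four ops.
    boxes = torch.tensor(
        [[0, 0, 100, 100], [0, 0, 50, 50], [200, 200, 300, 300]], dtype=torch.float
    )

    # Each call returns an NxM matrix of pairwise scores.
    print(ops.box_iou(boxes, boxes))              # plain IoU
    print(ops.generalized_box_iou(boxes, boxes))  # GIoU: negative for disjoint pairs
    print(ops.distance_box_iou(boxes, boxes))     # DIoU: IoU minus a center-distance penalty
    print(ops.complete_box_iou(boxes, boxes))     # CIoU: adds an aspect-ratio consistency term

Only the diagonal (a box compared against itself) is exactly 1.0 for all four variants.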