Commit 6d83f890 authored by Kai Chen's avatar Kai Chen
Browse files

bug fix for retinanet with 2 classes (fg/bg)

(cherry picked from commit d1cf5e59)
parent b6561a1a
......@@ -158,8 +158,7 @@ def anchor_target_single(flat_anchors,
def expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full(
(labels.size(0), label_channels), 0, dtype=torch.float32)
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
......
......@@ -10,8 +10,7 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return torch.sum(raw * weight)[None] / avg_factor
def weighted_cross_entropy(pred, label, weight, avg_factor=None,
reduce=True):
def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.cross_entropy(pred, label, reduction='none')
......@@ -36,6 +35,7 @@ def sigmoid_focal_loss(pred,
alpha=0.25,
reduction='elementwise_mean'):
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment