def iou(self, a, b):
    """Return the pairwise intersection-over-union of two sets of boxes.

    Uses ``self.intersection(a, b)`` and ``self.area(...)``, both defined
    elsewhere on the owning class. The area of ``a`` gets a new last axis
    and the area of ``b`` a new middle axis, so both broadcast to the shape
    of the intersection tensor. That shape is presumably (batch, len_a, len_b);
    TODO confirm.

    A tiny epsilon (1e-12) is added to the denominator so degenerate
    (zero-area) boxes do not cause a division by zero.
    """
    overlap = self.intersection(a, b)
    # |A ∪ B| = |A| + |B| - |A ∩ B|
    union = (
        self.area(a).unsqueeze(2).expand_as(overlap)
        + self.area(b).unsqueeze(1).expand_as(overlap)
        - overlap
    )
    return overlap / (union + 1e-12)
class PiecewiseLin(nn.Module):
    """Learnable monotone piecewise-linear function with ``n`` equal segments.

    The slope of each segment is taken from ``|weight|``, normalised so the
    slopes sum to one. The function is therefore non-decreasing and maps
    1 to 1. ``weight[0]`` starts at 0, and ``abs`` has a zero gradient there,
    so the value at 0 stays 0.
    """

    def __init__(self, n):
        super().__init__()
        self.n = n
        initial = torch.ones(n + 1)
        # the first weight is always 0 (and gets a 0 gradient through abs)
        initial[0] = 0
        self.weight = nn.Parameter(initial)

    def forward(self, x):
        """Apply the function element-wise to ``x``.

        Assumes the entries of ``x`` lie in [0, 1]; TODO confirm with callers.
        Values at or below -1/n would produce a negative gather index.

        Returns a tensor with the same shape as ``x``.
        """
        segments = self.n + 1
        # non-negative slopes summing to one -> monotone and f(1) = 1
        slopes = self.weight.abs()
        slopes = slopes / slopes.sum()
        # (n + 1, 1, ..., 1): broadcasts against x with a new leading axis
        slopes = slopes.view([segments] + [1] * x.dim())
        full_shape = (segments,) + tuple(x.size())
        # prefix sums give the function value at every segment boundary
        boundaries = slopes.cumsum(dim=0).expand(full_shape)
        slopes = slopes.expand_as(boundaries)

        scaled = self.n * x.unsqueeze(0)
        # integer segment index (no gradient) and offset inside that segment
        seg = Variable(scaled.long().data)
        offset = scaled.frac()

        last = self.n
        # everything contributed by the segments fully to the left of x
        below = boundaries.gather(0, seg.clamp(max=last))
        # partial contribution of the segment x falls into
        within = offset * slopes.gather(0, (seg + 1).clamp(max=last))
        return (below + within).squeeze(0)
def to_one_hot(self, scores):
    """Encode non-negative scalar scores as soft one-hot rows over 0..self.objects.

    A score s is split between two neighbouring integer bins. Bin floor(s)
    receives 1 - frac(s) and bin floor(s) + 1 receives frac(s). For example,
    with ``self.objects = 3``: 0 -> [1 0 0 0], 2.75 -> [0 0 0.25 0.75].

    Assumes ``scores`` has shape (batch, 1), because ``scatter_`` along dim 1
    needs a 2-D index; TODO confirm against callers. Under that assumption the
    result has shape (batch, self.objects + 1).
    """
    top = self.objects
    # sanity clamp into the representable range; should never actually bite
    scores = scores.clamp(min=0, max=top)
    lower = scores.long().data
    weight_up = scores.frac()

    rows = lower.size(0)
    width = top + 1
    # one-hot encodings for the score rounded down / rounded up
    down = scores.data.new(rows, width).fill_(0)
    up = scores.data.new(rows, width).fill_(0)
    down.scatter_(dim=1, index=lower.clamp(max=top), value=1)
    up.scatter_(dim=1, index=(lower + 1).clamp(max=top), value=1)

    # linear interpolation between the two by the fractional part
    return (1 - weight_up) * Variable(down) + weight_up * Variable(up)