列表

详情


11. Triplet Loss 怎么生成那三个点

回答思路

原理

Triplet Loss是在谷歌的FaceNet论文中的提出来的,用于解决人脸识别相关的问题,原文为:《FaceNet: A Unified Embedding for Face Recognition and Clustering》,得到了广泛的应用。

Triplet三元组指的是anchor, negative, positive三个部分,每一部分都是一个embedding向量,其中

anchor指的是基准图片

positive指的是与anchor同一分类下的一张图片

negative指的是与anchor不同分类的一张图片

网络没经过学习之前,A和P的欧式距离可能很大,A和N的欧式距离可能很小,如上图左边,在网络的学习过程中,A和P的欧式距离会逐渐减小,而A和N的距离会逐渐拉大。

损失函数定义如下:

这时可以通过最小化上述损失函数,a与p之间的距离d(a,p)=0,而a与n之间的距离d(a,n)大于d(a,p)+margin。当negative example很好识别时,上述损失函数为0,否则是一个比较大的值。

也就是说,期望通过优化triplet loss,使得类内紧凑,类间远离的目标。

上面提到,该loss是为了解决人脸识别相关问题,举个例子,我们的数据库中有1000个人,每个人有20张不同的图片,那么计算该loss需要计算**(1000*20)*(999*20)**次,后面的999表示negative example,即 的复杂度。

所以triplet loss的三元组的选取也很关键。

三元组选取

基于triplet loss的定义,可以将triplet(三元组)分为三类:

easy triplets(简单三元组): triplet对应的损失为0的三元组,

hard triplets(困难三元组): negative example与anchor距离小于anchor与positive example的距离,形式化定义为
semi-hard triplets(一般三元组): negative example与anchor距离大于anchor与positive example的距离,但还不至于使得loss为0,即

在模型实际训练中,可以采用如下步骤选取以上三元组:

1.在每个mini-batch中,每个类别都有一定数量的正样本,和负样本

2.在mini-batch中挑选所有的anchor positive图像对,同时,选择最为困难的anchor negative图像对

选择的原则是,选择同类图片中最不像的,作为postive pair,选择不同类图片中最像的,作为negative pair。

代码实现(pytorch版本)

Pytorch官方已经支持了该函数:

torch.nn.TripletMarginLoss (margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction=‘mean’)

from torch import nn  class TripletLoss(nn.Module):     """Triplet loss with hard positive/negative mining.          Args:         margin (float, optional): margin for triplet. Default is 0.3.     """          def __init__(self, margin=0.3, global_feat, labels):         super(TripletLoss, self).__init__()         self.margin = margin         self.ranking_loss = nn.MarginRankingLoss(margin=margin)     def forward(self, inputs, targets):         """         Args:             inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim).             targets (torch.LongTensor): ground truth labels with shape (num_classes).         """         n = inputs.size(0)                   计算embeddings之间的距离         dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)         dist = dist + dist.t()         dist.addmm_(1, -2, inputs, inputs.t())         dist = dist.clamp(min=1e-12).sqrt()   for numerical stability                   关键,三元组的选取:为每个sample找到hardest positive and negative         mask = targets.expand(n, n).eq(targets.expand(n, n).t())         dist_ap, dist_an = [], []         for i in range(n):             dist_ap.append(dist[i][mask[i]].max().unsqueeze(0))             dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0))         dist_ap = torch.cat(dist_ap)         dist_an = torch.cat(dist_an)                   计算loss         y = torch.ones_like(dist_an)         return self.ranking_loss(dist_an, dist_ap, y)


上一题