回答思路
原理
Triplet Loss是在谷歌的FaceNet论文中的提出来的,用于解决人脸识别相关的问题,原文为:《FaceNet: A Unified Embedding for Face Recognition and Clustering》,得到了广泛的应用。
Triplet三元组指的是anchor, negative, positive三个部分,每一部分都是一个embedding向量,其中
anchor指的是基准图片
positive指的是与anchor同一分类下的一张图片
网络没经过学习之前,A和P的欧式距离可能很大,A和N的欧式距离可能很小,如上图左边,在网络的学习过程中,A和P的欧式距离会逐渐减小,而A和N的距离会逐渐拉大。
这时可以通过最小化上述损失函数,a与p之间的距离d(a,p)=0,而a与n之间的距离d(a,n)大于d(a,p)+margin。当negative example很好识别时,上述损失函数为0,否则是一个比较大的值。
也就是说,期望通过优化triplet loss,使得类内紧凑,类间远离的目标。
上面提到,该loss是为了解决人脸识别相关问题,举个例子,我们的数据库中有1000个人,每个人有20张不同的图片,那么计算该loss需要计算**(1000*20)*(999*20)**次,后面的999表示negative example,即 的复杂度。
所以triplet loss的三元组的选取也很关键。
三元组选取
基于triplet loss的定义,可以将triplet(三元组)分为三类:
在模型实际训练中,可以采用如下步骤选取以上三元组:
1.在每个mini-batch中,每个类别都有一定数量的正样本,和负样本
2.在mini-batch中挑选所有的anchor positive图像对,同时,选择最为困难的anchor negative图像对
选择的原则是,选择同类图片中最不像的,作为postive pair,选择不同类图片中最像的,作为negative pair。
代码实现(pytorch版本)
Pytorch官方已经支持了该函数:
torch.nn.TripletMarginLoss (margin=1.0, p=2.0, eps=1e-06, swap=False, size_average=None, reduce=None, reduction=‘mean’)
from torch import nn class TripletLoss(nn.Module): """Triplet loss with hard positive/negative mining. Args: margin (float, optional): margin for triplet. Default is 0.3. """ def __init__(self, margin=0.3, global_feat, labels): super(TripletLoss, self).__init__() self.margin = margin self.ranking_loss = nn.MarginRankingLoss(margin=margin) def forward(self, inputs, targets): """ Args: inputs (torch.Tensor): feature matrix with shape (batch_size, feat_dim). targets (torch.LongTensor): ground truth labels with shape (num_classes). """ n = inputs.size(0) 计算embeddings之间的距离 dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) dist = dist + dist.t() dist.addmm_(1, -2, inputs, inputs.t()) dist = dist.clamp(min=1e-12).sqrt() for numerical stability 关键,三元组的选取:为每个sample找到hardest positive and negative mask = targets.expand(n, n).eq(targets.expand(n, n).t()) dist_ap, dist_an = [], [] for i in range(n): dist_ap.append(dist[i][mask[i]].max().unsqueeze(0)) dist_an.append(dist[i][mask[i] == 0].min().unsqueeze(0)) dist_ap = torch.cat(dist_ap) dist_an = torch.cat(dist_an) 计算loss y = torch.ones_like(dist_an) return self.ranking_loss(dist_an, dist_ap, y)