分类或检索数据集清洗方法几则

最近完成了一个植物识别的训练 (训练集是自建的, 共 4,066 类, 1,306,738 张图像), 在测试集 (145,168 张图像) 上性能表现不佳 (Top1 Acc 为 0.848, Top5 Acc 为 0.959), 怀疑是训练集比较脏, 遂拟对植物识别训练集进行清洗, 想到并尝试了如下方法.

利用混淆矩阵指导分类数据集清洗

1) 用训练得到的模型在待清洗分类数据集上计算混淆矩阵, 其第 $i$ 行第 $j$ 列元素表示类别为 $i$ 的样本被预测为类别 $j$ 的数目. 对混淆矩阵按所在行的实际样本数进行归一化 (为了可以设置一个与实际样本数无关的阈值).
2) 比较 $\operatorname{recall}_i$ (混淆矩阵第 $i$ 行第 $i$ 列的元素) 和 $\operatorname{top1}_i$ (混淆矩阵第 $i$ 行的最大值). 如果两者不相等, 即认为第 $i$ 类可能存在较多噪声, 不妨称之为 dirty class.
3) 当 $\operatorname{recall}_i$ 等于 $\operatorname{top1}_i$ 时, 比较 $\operatorname{recall}_i$ 和 $\operatorname{top2}_i$ (混淆矩阵第 $i$ 行的第二大值). 如果 $\operatorname{recall}_i$ 和 $\operatorname{top1}_i$ 之差小于一定的阈值 thresh, 即认为第 $i$ 类可能存在较多噪声 (即为 dirty class).
4) 对于每个 dirty class, 计算其混淆矩阵对应行上的 top_k 个值, 在 top_k 类中排除 dirty class, 剩余的类称为该 dirty class 的干扰类 (distract classes).
5) 人工检查上面得到的 dirty class 及其 distract classes 的数据.

在植物识别训练集上, 该方法在 thresh 为 0.1 时, 得到 17 个 dirty class, 如下所示, 其中每一组的第一行是 dirty class 的召回率和标签, 第二行是该 dirty class 的 distract classes 对应列的漏检率和标签:

[1] 0.5015, 兰科_手参属 Gymnadenia nigra
    0.4715, 兰科_手参属 Gymnadenia rhellicani
[2] 0.4767, 兰科_杓兰属 Cypripedium calcicola #褐花杓兰
    0.4651, 兰科_杓兰属 Cypripedium tibeticum #西藏杓兰
[3] 0.4355, 唇形科_夏枯草属 Prunella hispida #硬毛夏枯草
    0.5242, 唇形科_夏枯草属 Prunella vulgaris #夏枯草
[4] 0.4689, 唇形科_大青属 Clerodendrum lindleyi #尖齿臭茉莉
    0.4379, 唇形科_大青属 Clerodendrum bungei #臭牡丹
[5] 0.4254, 天门冬科_玉簪属 Hosta albomarginata #紫玉簪
    0.4377, 天门冬科_玉簪属 Hosta ventricosa #紫萼
[6] 0.2006, 杜鹃花科_杜鹃花属 Rhododendron × pulchrum #锦绣杜鹃
    0.6783, 杜鹃花科_杜鹃花属 Rhododendron simsii #杜鹃
[7] 0.2059, 玄参科_醉鱼草属 Buddleja fallowiana #紫花醉鱼草
    0.5980, 玄参科_醉鱼草属 Buddleja davidii #大叶醉鱼草
[8] 0.3960, 百合科_大百合属 Cardiocrinum giganteum var. yunnanense #云南大百合
    0.4430, 百合科_大百合属 Cardiocrinum giganteum #大百合
[9] 0.4472, 睡莲科_睡莲属 Nymphaea alba #白睡莲
    0.4797, 睡莲科_睡莲属 Nymphaea
[10] 0.3261, 睡莲科_睡莲属 Nymphaea nouchali #延药睡莲
     0.6594, 睡莲科_睡莲属 Nymphaea
[11] 0.3036, 美人蕉科_美人蕉属 Canna generalis #大花美人蕉
     0.6145, 美人蕉科_美人蕉属 Canna
[12] 0.3588, 美人蕉科_美人蕉属 Canna indica var. flava #黄花美人蕉
     0.4941, 美人蕉科_美人蕉属 Canna
[13] 0.4628, 美人蕉科_美人蕉属 Canna orchioides #兰花美人蕉
     0.4662, 美人蕉科_美人蕉属 Canna
[14] 0.3429, 茄科_木曼陀罗属 Brugmansia suaveolens #大花木曼陀罗
     0.4190, 茄科_曼陀罗属 Datura stramonium #曼陀罗
[15] 0.4802, 菊科_茼蒿属 Glebionis segetum #南茼蒿
     0.4484, 菊科_茼蒿属 Glebionis coronaria #茼蒿
[16] 0.4959, 蓼科_金线草属 Antenoron filiforme var. neofiliforme #短毛金线草
     0.4008, 蓼科_金线草属 Antenoron filiforme #金线草
[17] 0.4352, 蔷薇科_绣线菊属 Spiraea × bumalda 'Goalden Mound' #金山绣线菊
     0.3472, 蔷薇科_绣线菊属 Spiraea japonica #粉花绣线菊

由上面结果可见植物识别训练集的 dirty class 和其 distract classes 很多都是同属的 (这很符合直觉, 有的同属植物外观相似性大, 本身而言和对人而言都是不易区别的).

利用分类层权重指导该分类器训练集的数据清洗

设分类层权重矩阵 $W$ 的尺寸为 (embedding_size, num_classes).

1) 对分类层权重矩阵按列 (即按类别) 进行 L2 归一化, 得到归一化分类层权重矩阵 $\hat{W}$. 计算 $\hat{W}^T\hat{W}$, 即相当于计算 $\hat{W}$ 的两两列向量之间的内积, 或者说相当于计算 $W$ 的两两列向量之间的余弦相似度, 所以不妨称为相似度矩阵, 记为 $S$, 其尺寸为 (num_classes, num_classes). 因为 $W$ 的列向量可以视作类的模板特征, 所以 $S$ 可以用来表征两两类别之间的相似度.
2) 在相似度矩阵 $S$ 上, 对于每一行, 找到第二大的值及索引, 即相当于对于每一类, 找到除自身外与该类最相似的类. 如果该值大于一定的阈值 thresh, 即认为该行对应的类可能存在较多噪声, 不妨称之为 dirty class.
3) 对于每个 dirty class, 计算其相似度矩阵 $S$ 对应行上的 top_k 个值, 在 top_k 类中排除 dirty class, 剩余的类称为该 dirty class 的干扰类 (distract classes).
4) 人工检查上面得到的 dirty class 及其 distract classes 的数据.

在植物识别训练集上, 该方法在 thresh 为 0.64 时, 得到 10 个 dirty class, 如下所示, 其中每一组的第一行是 dirty class 与其自身的相似度 (恒为 1.0) 和 dirty class 的标签, 第二行是 dirty class 与其 distract classes 的相似度和 distract classes 的标签:

[1] 1.0000, 夹竹桃科_钉头果属 Gomphocarpus fruticosus #钉头果
    0.6692, 夹竹桃科_钉头果属 Gomphocarpus physocarpus #气球果
[2] 1.0000, 夹竹桃科_钉头果属 Gomphocarpus physocarpus #气球果
    0.6692, 夹竹桃科_钉头果属 Gomphocarpus fruticosus #钉头果
[3] 1.0000, 百合科_大百合属 Cardiocrinum giganteum #大百合
    0.6431, 百合科_大百合属 Cardiocrinum giganteum var. yunnanense #云南大百合
[4] 1.0000, 百合科_大百合属 Cardiocrinum giganteum var. yunnanense #云南大百合
    0.6431, 百合科_大百合属 Cardiocrinum giganteum #大百合
[5] 1.0000, 葫芦科_Marah Marah fabacea
    0.6683, 葫芦科_Marah Marah macrocarpa
[6] 1.0000, 葫芦科_Marah Marah macrocarpa
    0.6683, 葫芦科_Marah Marah fabacea
[7] 1.0000, 蓼科_金线草属 Antenoron filiforme #金线草
    0.6587, 蓼科_金线草属 Antenoron filiforme var. neofiliforme #短毛金线草
[8] 1.0000, 蓼科_金线草属 Antenoron filiforme var. neofiliforme #短毛金线草
    0.6587, 蓼科_金线草属 Antenoron filiforme #金线草
[9] 1.0000, 鸢尾科_庭菖蒲属 Sisyrinchium albidum
    0.6781, 鸢尾科_庭菖蒲属 Sisyrinchium campestre
[10] 1.0000, 鸢尾科_庭菖蒲属 Sisyrinchium campestre
     0.6781, 鸢尾科_庭菖蒲属 Sisyrinchium albidum

注意到: 各个 dirty class 和其 distract classes 大都也为其同属类.

利用类平均特征矩阵指导分类或检索数据集的清洗

方法类似 "利用分类层权重指导该分类器训练集的数据清洗", 只是将分类层权重矩阵换成类的平均特征矩阵 (特征一般取分类层前一层的输出). 该方法可以用于人脸识别, 行人重识别中的类级别的数据清洗.

方法比较

上面提到的三种数据清洗方法都是类级别的, 不同之处:

  • 利用分类层权重指导该分类器训练集的数据清洗: 只能对分类器的训练集进行清洗.
  • 利用混淆矩阵指导分类数据集清洗: 只能对出现在训练集类别中的分类数据集 (包括分类器的训练集和测试集) 进行清洗.
  • 利用类平均特征矩阵指导分类或检索数据集的清洗: 可以对任意 (同域) 的分类或检索数据集进行清洗.

更新记录

  • 20211021, 发布
  • 20211022, 添加一些内容

版权声明

自由转载-非商用-非衍生-保持署名 (创意共享3.0许可证)

版权归属: 采石工
本文链接: https://quarryman.cn/article/20211021
版权声明: 除特别声明外, 文章采用《署名-非商业性使用-相同方式共享 4.0 国际》许可协议.