交叉熵损失函数与三元组损失函数之间的关系

本文通过公式推导, 展示了交叉熵损失函数中存在着类似三元组的结构, 进而推导了交叉熵损失函数与三元组损失函数之间的关系.

交叉熵损失函数中的三元组

设某样本的特征为 $\vec{x}$, 标签为 $y$. 另设分类器权重为 $W = [\vec{w}^T_1;\vec{w}^T_2;\cdots;\vec{w}^T_C]$, 其中 $;$ 表示向量纵向连接, $C$ 为分类器的类别数.

$y$ 的 one hot 编码为 $\vec{e}_y=\left[ 1\left\{ y = 1 \right\}; \cdots ; 1\left\{ y = C \right\} \right]$, 其是一个 $C$ 维向量, $1\left\{ \cdot \right\}$ 是示性函数.

$\vec{x}$ 经过分类器且归一化后的输出为 $\vec{q} = \mathrm{softmax}(W\vec{x})$, 其也是一个 $C$ 维向量.

样本的交叉熵损失函数为:
$$\begin{equation}\begin{aligned} -\langle \vec{e}_y, \log^\circ{\vec{q}} \rangle &= -\sum_{j=1}^{C} {1\left\{ y = j \right\} \log q_j} \\ &= -\sum_{j=1}^{C} {1\left\{ y = j \right\} \log{\mathrm{softmax}_j(W\vec{x})}} \\ &= -\log \mathrm{softmax}_y(W\vec{x}) \\ &= -\log \frac{\exp(\vec{w}_{y}^T \vec{x})}{\sum_{j=1}^{C} \exp(\vec{w}_j^T \vec{x})} \\ &= \log \left( \sum_{j=1}^{C} \exp \left(\vec{w}_j^T \vec{x} - \vec{w}_y^T \vec{x}\right) \right) \\ &= \log \left( 1 + \sum_{j=1, j\neq y}^{C} \exp \left(\vec{w}_j^T \vec{x} - \vec{w}_y^T \vec{x} \right) \right) \end{aligned}\end{equation}$$

如果把分类器权重当做标签的特征, 则当前样本特征 $\vec{x}$ 与当前样本标签特征 $\vec{w}_y$ 构成正样本对, 当前样本特征 $\vec{x}$ 与非当前标签特征 $\vec{w}_j$ ($j=1, \cdots, C$ 且 $j \neq y$) 构成负样本对. 由公式 (1) 可见, 交叉熵损失函数中隐藏着 $C-1$ 个三元组 (包含 $1$ 个正样本对, $C-1$ 个负样本对). 这就引发笔者思考: 交叉熵损失函数是不是与三元组损失函数存在更深层次的联系? 于是有了下面的推导.

交叉熵损失函数与三元组损失函数的关系推导

先给出一个结论: 当 $k > 0$ 时 $\log(1 + kx)$ 是凹函数.

证明: 令 $f(x) = \log(1 + kx)$, 则 $f'(x) = \frac{k}{1 + kx}$, $f''(x) = \frac{-k^2}{(1 + kx)^2}$. 因为 $k > 0$, 所以 $f''(x) < 0$. 根据凹函数的判定方法, 可得证.

利用上面的结论和凹函数的性质, 接着对交叉熵损失函数进行推导:
$$\begin{equation}\begin{aligned} \log \left( 1 + \sum_{j=1, j\neq y}^{C} \exp \left(\vec{w}_j^T \vec{x} - \vec{w}_y^T \vec{x} \right) \right) &= \log \left( 1 + C\sum_{j=1, j\neq y}^{C} \frac{1}{C} \exp \left(\vec{w}_j^T \vec{x} - \vec{w}_y^T \vec{x} \right) \right) \\ &\ge \sum_{j=1, j\neq y}^{C} \frac{1}{C}\log \left( 1 + C \exp \left(\vec{w}_j^T \vec{x} - \vec{w}_y^T \vec{x} \right) \right) \\ &= \frac{1}{C} \sum_{j=1, j\neq y}^{C} \log \left( 1 + \exp \left(\vec{w}_j^T \vec{x} - \vec{w}_y^T \vec{x} + d\right) \right) \end{aligned}\end{equation}$$
式中 $d = \log C$, 因为 $C > 1$, 所以 $d > 0$.

由于 $\log(1 + e^x)$ (即 softplus 函数) 是 $\max(x, 0)$ (即 ReLU 函数) 的平滑近似, 所以
$$\begin{equation}\begin{aligned} \frac{1}{C} \sum_{j=1, j\neq y}^{C} \log \left( 1 + \exp \left(\vec{w}_j^T \vec{x} - \vec{w}_y^T \vec{x} + d\right) \right) \end{aligned}\end{equation}$$

$$\begin{equation}\begin{aligned} \frac{1}{C} \sum_{j=1, j\neq y}^{C} \max\left(\vec{w}_j^T \vec{x} - \vec{w}_y^T \vec{x} + d, 0\right) \end{aligned}\end{equation}$$
的平滑近似.

公式 (4) 即是三元组损失函数 (这里用的是向量内积, 一些地方用的是向量 L2 距离, 两者本质上是等效的). 由公式 (2)(3)(4) 可见: 交叉熵损失函数可以认为是三元组损失函数的平滑近似.

更新记录

  • 20231103, 创建文档
  • 20231202, 修改文档
  • 20231228, 完善文档并发布
版权归属: 采石工
本文链接: https://quarryman.cn/article/20231228
版权声明: 除特别声明外, 文章采用《署名-非商业性使用-相同方式共享 4.0 国际》许可协议.