最优传输和 Sinkhorn 迭代算法

引言

最优传输 (optimal transport),又称作 wasserstein distance,离散情况下也被称为 earth mover's distance (EMD), 是一种度量概率分布之间距离的一个度量。 最优传输最近几年逐渐成为机器学习领域的一个热点,特别是 WGAN [2] 的出现,引发了引发了大家对最优传输 的兴趣。 有很多近期的工作将最优传输应用在诸多任务中,比如小样本学习 [3],自监督学习 [4],信息检索 [5]。 但是目前关于最优传输的中文资料较少,特别是关于最优传输的 Sinkhorn 迭代解法,几乎找不到相关的中文材料。

本文首先简要介绍离散情况下的最优传输问题,然后介绍文献 [1] 所提出的最优传输问题的快速解法以及对应的对应梯度的计算。 为了理解的方便,本文只介绍离散情况下的最优传输问题。 值得注意的是,有很多关于最优传输的资料会用到不同的关键词, 例如 optimal transport/wasserstein distance/earth mover's distance / sinkhorn distance 等等。 离散条件下 optimal transport/wasserstein distance/earth mover's distance 可以认为是等价的,Sinkhorn iteration 是一种最优传输的快速迭代解法,后文中会介绍到。

最优传输:一个简单的例子

假设地面上有 $d_1$ 个土堆,第 $i$ 个土堆有 $\mu_i$ 单位的土; 同时有 $d_2$ 个坑,第 $j$ 个坑可以容纳的土 $\mu_j$。 我们用 $$ \begin{split} \mu\in\mathbb{R}^{d1} \\ \nu\in\mathbb{R}^{d2} \end{split} $$ 来表示这两个土堆,并 假设 $\sum_i \mu_i = \sum_j \nu_j$,也就是所有的“土”刚好够填满所有的“坑”。 这里提醒大家注意 $\mu, \nu$ (读音 mu, nu) 并不是后面要出现的 $u, v$。 定义 $M\in\mathbb{R}^{d_1\times d_2}_+$ 为距离矩阵,其中 $m_{ij}$ 表示“从第 $i$ 个土堆搬运一份土到第 $j$ 个坑的成本”。 现在要将所有的土搬运到坑中,对应的搬运成本是 $\in\mathbb{R}^{d_1\times d_2}_+$,很显然 $$ \begin{equation} \begin{cases} \sum_i T_{i,j} = \nu \\ \sum_j T_{i,j} = \mu \end{cases}\label{eq:t-constrain} \end{equation} $$ 值得一提的是,满足公式$\ref{eq:t-constrain}$的所有 $T$ 的集合都在一个高维空间的多面体内。 因为公式$\ref{eq:t-constrain}$ 的所有约束都是线性的,平面上多个线性约束条件围成的区域是多边形, 拓展到高维空间后线性约束围成的集合就是多面体(polytope)。 这里用集合 $\mathbf{U}$ 来表示所有满足条件的传输,在优化中集合 $\mathbf{U}$ 又被称为“可行域”。

最优传输 $T^*$ 定义为“使得总体运输成本 $\langle T, M\rangle$ 最小的运输方案”: $$ \begin{equation} T^* = \inf_{T\in \mathbf{U}} \langle T, M\rangle \label{eq:ot} \end{equation} $$ 其中 $\langle X, Y\rangle = \sum_{i,j}X_{i,j}Y_{i,j}$ 为矩阵内积。 此时的传输方案 $T^*$ 称为最优传输。 这个例子也是最优传输的另一个名字“earth mover's distance”的由来。

简而言之,最优传输就是要在满足公式$\ref{eq:t-constrain}$约束条件下找到使得公式$\ref{eq:ot}$最小的传输方案 $T^*$。 很显然这是一个线性规划问题,因为不论是公式$\ref{eq:t-constrain}$中的约束条件还是公式$\ref{eq:ot}$中的目标函数 都是线性的。

从概率上来讲, 如果我们限制 $\sum_i \mu_i = \sum_j \nu_j = 1$,那么 $\mu, \nu$ 可以看作是两个概率分布, 而任意满足条件的传输 $T$ 可以认为是边缘概率分别为 $\mu, \nu$ 的一个联合概率分布。 因此最优传输可以看作是两个概率分布之间距离的一个度量。

Why optimal transport?

那么我们为什么要 optimal transport 呢?如果为了都两个分布之间的距离,有很多现成的度量可以用, 比如非常简单的 KL 散度: $$ \text{KL}(p, q) = \sum_i p_i \cdot\log\frac{p_i}{q_i}. $$ 除了“无法处理两个分布的支撑集不相交的情况”以及“不满足对称性”等原因之外,一个重要的原因就是这种 逐点计算的度量没有考虑分布内的结构信息。 所谓的结构信息,就是分布内不同事件之间的相关性。

就以我们常见的分类任务为例,分类任务通常用交叉熵损失来度量模型预测和样本标签之间的距离, 交叉熵损失实际上就是在计算 onehot 化的标签和模型预测之间的 KL 散度。 这种逐点计算的损失函数(不论是交叉熵还是 L2)都无法考虑分布内不同事件的相关性。 例如将“汽车”误分类成“卡车”显然没有把“汽车”误分类成“斑马”严重。 但是用 KL 散度来度量的话,这两种错误的损失是一样的。

假设$\mu \in \mathbb{R}^{d1}_+, \nu \in \mathbb{R}^{d2}_+$是两个离散概率分布, 满足 $\mu^T\mathbf{1}_{d1} = 1, \nu^T\mathbf{1}_{d2} = 1$,其中$\mathbf{1}_{d1},\mathbf{1}_{d1}$ 表示全1列向量。

假设传输矩阵 $T\in\mathbb{R}^{d_1\times d_2}$ 为: $$ \begin{equation} f_T(x) = \sum_{n=-\infty}^{+\infty} c_n e^{in\omega_0 x} \label{fourier-series-final} \end{equation} $$

Optimal Transport with Entropic Constraints

定义联合分布 $T$ 的熵 (entropy) 为: $$ \begin{equation} H(T) = -\sum_{i,j} T_{i,j}\log(T_{i,j}) \end{equation} $$

熵正则 (entropy regularized)后的传输 cost 为: $$ \begin{equation} D_{M, \lambda}(\mu,\nu) = \inf_{T\in \mathbf{U}} \langle T, M\rangle + \lambda\cdot H(T) \end{equation} $$

基于 sinkhorn 迭代的快速解法。

假设有向量 $u \in \mathbb{R}^{d1}, v\in \mathbb{R}^{d2}$ ,这里注意区分 $u, v$ 和前面的 $\mu, \nu$ (读音 mu, nu)。 令 $u, v$ 的初始值为 $$ \begin{split} u &= \mathbf{1}_{d1} / d1 \\ v &= \mathbf{1}_{d2} / d2 \end{split} $$ 然后用 sinkhorn 迭代算法求解: $$ \begin{equation} \begin{split} v_j = \nu_j / (K^T\mu)_j \\ u_i = \mu_i / (K\nu)_i \end{split}\label{eq:sinkhorn} \end{equation} $$ 迭代收敛后最优传输矩阵 $T^*$ 和对应的传输消耗 $C^*$ 可以由以下公式给出: $$ \begin{equation}\begin{split} T^* &= \text{diag}(u)K \text{diag}(v) \\ C^* &= \langle T^*, M\rangle \end{split}\end{equation} $$

引用

如果本文的内容对你撰写学术论文有帮助,希望能考虑引用:
@misc{zhao2020fourier,
title   = {最优传输与其 sinkhorn 迭代快速解法},
author  = {Kai Zhao},
year    = 2020,
note    = {\url{http://kaizhao.net/posts/optimal-transport}}
}

(The comment system is provided by Disqus that is blocked by the GFW. For users from mainland China you may need a VPN.)