在神经网络的训练过程中,每一层的参数都在变化着,这就引起了该层的输出在不断变化,也就是其后续层的输入的分布在不断变化。这种变化的特性,导致人们必须仔细的选择神经网络初始化的数据并且使用较小的学习率,这就降低了神经网络的训练速度,由此可见训练神经网络并不是一项简单的工作。本文称神经网络层的这种输出不断变化的现象为内部协变量转移 (Internal Covariate Shift, ICS)。对此问题,本文提出了使用批标准化 (Batch Normalization, BN)来解决此问题。BN的使用可以让我们使用更高的学习率,在一些情况下还可以免去使用Dropout,而泛化性依旧可以保持。

文章目录

  1. 摘要
  2. 背景
  3. Batch Normalization
    1. BN 算法
    2. 在CNN中使用
    3. 使用更大的学习率
  4. 实验
    1. 使用参考
  5. 其它
  6. 引用

摘要

在神经网络的训练过程中,每一层的参数都在变化着,这就引起了该层的输出在不断变化,也就是其后续层的输入的分布在不断变化。这种变化的特性,导致人们必须仔细的选择神经网络初始化的数据并且使用较小的学习率,这就降低了神经网络的训练速度,由此可见训练神经网络并不是一项简单的工作。本文称神经网络层的这种输出不断变化的现象为内部协变量转移 (Internal Covariate Shift, ICS)。对此问题,本文提出了使用批标准化 (Batch Normalization, BN)来解决此问题。BN的使用可以让我们使用更高的学习率,在一些情况下还可以免去使用Dropout,而泛化性依旧可以保持。

Internal covariate shift/covariate shift 这里中文随手翻译的,看得懂意思就行。

背景

目前,神经网络的训练,我们一般会使用随机梯度下降的方法(或者其变体),这些方法虽然简单高效,但是它们可能需要仔细的设置超参数,特别是学习率、初始化参数。随着深度学习的发展,人们使用的神经网络深度也在不断的提升。由于神经网络后面层的输入依赖于之前所有层的输出,前面层中一些参数的微小的变化,在后面的层中可能被放大很多(特别是在一些层数比较多的神经网络中)。

神经网络层输入的分布不断变化的特性,导致后续层需要不断的调整以适应这种新的分布。一个学习系统输入分布的变化被称为协变量转移 (covariate shift),这个问题典型的解决方案就是domain adaptation。本文将协变量转移的概念由学习系统扩展到神经网络的每一个子层上去。显然如果每一个子层输出的分布不总是变来变去的,那么训练过程会变的更加稳定且容易。

针对ICS问题,本文提出了一种新的机制: Batch Normalization来减少ICS带来的影响,从而加速神经网络的训练。BN通过限制每一层网络输入的分布的均值和方差来相对稳定分布的变化;通过引入缩放参数来解决梯度大小对于输入参数大小过于敏感的问题(也即是输入参数如果过大(或过小),会导致相应梯度会过大(或过小))。

Batch Normalization

一般我们在数据预处理时,会对数据进行标准化,也就是将数值数据特征转换为均值为0,方差为1的数据来帮助稳定训练。假设我们有一组数据,其中每一条数据有$d$维的特征,每条数据可表示为$x=(x^{(1)}, x^{(2)}, \ldots, x^{(d)})$,标准化后该数据表为:

$$ \hat{x}^{(k)} = \frac{x^{(k)} - E[x^{(k)}]}{\sqrt{Var[x^{(k)}]}}, k = 1, 2, ..., d $$

由于该变换将特征值限制在0附近,并且不会离的太远(方差为1),这在某种程度上就限制了每一层网络的表达。比如,假设我们在激活函数sigmoid之前使用BN,那么其实就是将大部分输入限制在了sigmoid函数近似线性的那一段。如下图所示,我们用标准正态分布类比标准化后的输入,可以看出大部分数据都是在0附近,0附近的sigmoid函数接近线性。

sigmoid函数与标准正态分布

为了解决这个问题,作者为每一个输入值$x^{(k)}$,引入了一对参数$\gamma^{(k)}, \beta^{(k)}$,进行如下变换:

$$ y^{(k)} = \gamma^{(k)}\hat{x}^{(k)} + \beta^{(k)} $$

这两个参数并不是预先设定的超参数,而是在训练过程中学习得到的。

理想状态下,我们希望设置:

$$ \begin{align} \gamma^{(k)} &= \sqrt{Var[x^{(k)}]}\\ \beta^{(k)} &= E[x^{(k)}] \end{align} $$

计算过程是针对所有训练样本的。但是在实际训练时,对于每一个BN层计算数据集中的所有数据效率是极低的。因此,我们一般在训练时都会使mini-batch的方法(所以本文称作batch normalization吧?)。

BN 算法

设我们有一个包含m条数据的mini-batch$\mathcal{B} = {x_{1\ldots m}}$,那么$y_i = BN_{\gamma, \beta}(x_i)$($\gamma, \beta$为参数)的计算步骤为:

$$ \begin{align} \mu_{\mathcal{B}} & \leftarrow\frac{1}{m}\sum_{i=1}^m x_i\\ \sigma_{\mathcal{B}}^2 & \leftarrow\frac{1}{m}\sum_{i=1}^m(x_i - \mu_{\mathcal{B}})^2\\ \hat{x}_i & \leftarrow \frac{x_i - \mu_{\mathcal{B}}}{\sqrt{\sigma_{\mathcal{B}}^2 + \epsilon}}\\ y_i & \leftarrow \gamma \hat{x}_i + \beta \equiv BN_{\gamma, \beta}(x_i) \end{align} $$

其中$\epsilon$为一个常量,用于防止除零。

BN的每一步都是可导的,因此很容易集成到深度学习网络中去。应用BN之后,假设原先每一层的输入为$x$,那么之后就变成了$BN(x)$。

按照原文的意思,此处,BN应该是应用到了激活函数之后。网上也有许多讨论,就是应该在激活函数之前还是之后使用BN。一种说法是如果使用ReLU,那么BN在前,如果使用s形激活函数比如sigmoid/tanh,BN在后。论文下文在讨论将BN用到CNN中时,将BN用到了激活函数之前。据作者在Google的同事网上发帖说,作者在代码中肯定是将BN用在了激活函数之后的(不知道说的哪一部分)。大家在实际使用时,建议都试试。讨论可以参考[2]。

使用BN进行训练时,显然我们需要batch size > 1,使用较大的batch size在一定程度上可以加快训练。但是在训练完成以后,我们进行测试或前向推理的时候,显然我们不希望或者无法对于每一层的激活值使用BN操作。一来, 我们可能仅计算一个样本,方差为0;二来,显然我们不希望样本的前向推理依赖于其它样本(BN的计算显然是依赖于输入的所有样本的)。所以我们需要得到一个新的均值$E[x]$和方差$Var[x]$用于在推理的时候替代$\mu_{\mathcal{B}}$和$\sigma_{\mathcal{B}}^2$,以便得到:

$$ \hat{x} = \frac{x - E[x]}{\sqrt{Var[x] + \epsilon}} $$

为了得到$E[x]$以及$Var[x]$的无偏估计,显然我们应该使用所有训练样本来得到这两个值,以保证准确性($\gamma, \beta$训练得到,直接使用即可)。我们以某一个BN层为例, BN的输入为$x$,具体步骤为:

  1. 设计一个神经网络$N_{\Theta}$, 其中$\Theta$为网络的参数
  2. 在需要的地方插入BN层,这样得到训练用网络$N_{BN}$
  3. 正常训练$N_{BN}$
  4. 训练完成后,对于每一个 batch $\mathcal{B_i}, i = 1, 2, ..., m$ (假设我们有 m 个 batch), 计算得到 BN的输入$x$的均值集合$\{\mu_{\mathcal{B_i}}\}$以及方差集合$\{ \sigma_{\mathcal{B_i}}^2 \}$
  5. 计算$E[x]$和$Var[x]$:
$$ \begin{align} E[x] & = \frac{1}{m} \sum_{i=1}^{m}\mu_{\mathcal{B_i}}\\ Var[x]& = \frac{m}{m-1}(\frac{1}{m}\sum_{i=1}^m \sigma_{\mathcal{B}_i}^2) \end{align} $$

有了$E[x]$和$Var[x]$之后,在推理阶段,我们就可直接使用单个样本或多个样本进行推理,而样本间无依赖关系。

在CNN中使用

为了保持CNN的网络特性,BN在CNN中的使用方法与全连接网络中的略有不同。

在CNN中,每一个filter会对整个"图像"进行操作,每一个filter根据输入图像的大小,会输出很多值。一个卷积层的filter对应的输出我们称之为一个feature map, 假设这个feature map的大小为$p\times q$。在全连接网络中,我们对于每一个激活值都计算了一个对应的BN值。而在CNN中,为了保持CNN的特性,自然很合理的可以想到对于每个filter设置一个BN层,这个BN层在filter对应的所有feature map间是共享的。

换句话说,每一个filter的feature map用的都是同样的均值和方差,原先在全连接网络中我们一个batch对应了一个$\mu_{\mathcal{B}}$,而在卷积网络中我们一个batch 对应了$p\times q$个$\mu_{\mathcal{B}}$ ($p\times q$为 feature map大小),因此我们有$m^{\prime} = m\times p\times q$个数据参与求均值与方差,均值为:

$$ \begin{align} E[x] & = \frac{1}{m^{\prime}} \sum_{i=1}^{m^{\prime}}\mu_{\mathcal{B_i}}\\ & = \frac{1}{m\times p \times q} \sum_{i=1}^{m\times p \times q}\mu_{\mathcal{B_i}} \end{align} $$

$Var(x)$求法也一样,是对$m^{\prime}$个数据求的。

使用更大的学习率

在传统深度神经网络中,使用大的学习率可能会导致梯度爆炸或者梯度消失问题,或者陷入局部最优。使用BN可以帮助缓解这些问题。激活值在经过BN之后,范围被限制在了0附近的的一个合理范围内(在使用$\gamma,\beta$进行线性变换之前),因此激活值过大或者过小都会被重新映射到另一个区间。这就可以让激活函数(比如sigmoid)的激活区域保持在合理的区域(non-saturated regimes),在这些区域中,激活函数的梯度不会消失(比如sigmoid函数在0附近的梯度是最大的,BN应该可以帮助变换后的值在0附近,而不会导致过大或过小,从而导致激活函数的梯度趋于0)。

在不使用BN的情况下,如果使用大的学习率,可能会导致每一层的参数在更新后数值的数量级发生变化(变大),这就很容易引起梯度爆炸问题。使用BN之后,这个问题就不存在了。因为BN首先将激活值全部都变为均值为0,方差为1了,也就是:

$$ BN((aW)u) = BN(Wu) $$

其中,标量$a$可以看作是对参数W缩放,$W$为参数,$u$为输入。因此:

$$ \frac{\partial BN((aW)u)}{\partial u} = \frac{\partial BN(Wu)}{\partial u} $$

所以,参数$W$即使被放大较多,也不会导致梯度爆炸问题。

因此,使用BN可以让用户使用更大的学习率。

实验

我们简单看下实验,下图展示了使用和不使用BN的情况下,训练MNIST数据集使用steps(应该是使用训练mini-batch的个数吧)与准确率变化的对比。

准确率 VS 训练 steps

可以看出使用BN后收敛速度加快了很多。

下图展示了BN在ImageNet上的表现。

Image Net 数据集上的表现

显然,使用BN以后,模型的收敛速度得到了极大的提升。其中$x5/x30$表示训练过程中使用BN后,调整学习率为原先的5倍以及30倍。其它实验结果可参见原论文。

使用参考

在训练ImageNet分类器时,对于使用BN的模型,作者对于网络进行了如下修改,可以作为参考:

  1. 提高了学习率
  2. 去除dropout
  3. 对于训练数据进行更加彻底的打乱 (为了得到更具代表性的batch数据)
  4. 减少 $L_2$正则系数
  5. 提高学习率衰减系数
  6. 移除*Local Response Normalization*
  7. 减少在图片上的光学上的调整,让网络多看到更多”真实“的图片

其它

作者在文中说BN可与你帮助解决ICS问题,后面一些研究[3]指出实际情况可能并不是这样的,BN和ICS可能并没有什么联系,在一些情况下还可能会增大ICS。以后有时间我们再分享。

引用

[1] Ioffe, Sergey, and Christian Szegedy. "Batch normalization: Accelerating deep network training by reducing internal covariate shift." International conference on machine learning. PMLR, 2015.

[2] Batch Normalization before or after ReLU? https://www.reddit.com/r/MachineLearning/comments/67gonq/d_batch_normalization_before_or_after_relu/

[3] Santurkar, Shibani, et al. "How does batch normalization help optimization?." arXiv preprint arXiv:1805.11604 (2018).


[本]通信工程@河海大学 & [硕]CS@清华大学
这个人很懒,他什么也没有写!

1
964
0

More Recommendations


Nov. 30, 2022