Keras中LN层的使用

本文最后更新于:2021年8月15日 上午

1、BN的问题

BN是按照样本数计算归一化统计量的,当样本数很少时,比如说只有4个。这四个样本的均值和方差便不能反映全局的统计分布息,所以基于少量样本的BN的效果会变得很差。

2、LN层的使用

这张图与我们平常看到的feature maps有些不同,立方体的3个维度为别为batch/ channel/ HW,而我们常见的feature maps中,3个维度分别为channel/ H/ W,没有batch。分析上图可知:BN计算均值和标准差时,固定channel(在一个channel内),对HW和batch作平均;LN计算均值和标准差时,固定batch(在一个batch内),对HW和channel作平均,更详细的推导过程可以查阅参考文献。

可以看到与BN不同,LN没有对batch作平均,所以当batch变化时,网络的错误率不会有明显变化。论文的实验显示:LN和IN 在时间序列模型(RNN/LSTM)和生成模型(GAN)上有很好的效果,而GN在视觉模型(CNN)上表现更好。

LN与BN层非常相似,特点主要体现在两个方面:

  1. LN得到的模型更稳定;
  2. LN有正则化的作用,得到的模型更不容易过拟合。

关于Layer normalization, Keras官方API连接如下

https://keras.io/api/layers/normalization_layers/layer_normalization/

但是请注意,这个函数在Keras中可能无法正常使用,但可以通过pip正确安装

Install

1
pip install keras-layer-normalization

Usage

1
2
3
4
5
6
7
8
9
import keras
from keras_layer_normalization import LayerNormalization


input_layer = keras.layers.Input(shape=(2, 3))
norm_layer = LayerNormalization()(input_layer)
model = keras.models.Model(inputs=input_layer, outputs=norm_layer)
model.compile(optimizer='adam', loss='mse', metrics={},)
model.summary()

Reference