过拟合和欠拟合

简单来说过拟合就是模型训练集精度高,测试集训练精度低;欠拟合则是模型训练集和测试集训练精度都低。

官方文档地址为 https://tensorflow.google.cn/tutorials/keras/overfit_and_underfit

过拟合和欠拟合

以IMDB dataset为例,对于过拟合和欠拟合,不同模型的测试集和验证集损失函数图如下:

baseline模型结构为:10000-16-16-1

smaller_model模型结构为:10000-4-4-1

bigger_model模型结构为:10000-512-512-1

造成过拟合的原因通常是参数过多或者数据较少,欠拟合往往是训练次数不够。

解决方法

正则化

正则化简单来说就是稀疏化参数,使得模型参数较少。类似于降维。

正则化参考: https://blog.csdn.net/jinping_shi/article/details/52433975

tf.keras通常在损失函数后添加正则项,l1正则化和l2正则化。

l2_model = keras.models.Sequential([
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
                       activation=tf.nn.relu, input_shape=(10000,)),
    keras.layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),#权重l2正则化
                       activation=tf.nn.relu),
    keras.layers.Dense(1, activation=tf.nn.sigmoid)
])

l2_model.compile(optimizer='adam',
                 loss='binary_crossentropy',
                 metrics=['accuracy', 'binary_crossentropy'])

l2_model_history = l2_model.fit(train_data, train_labels,
                                epochs=20,
                                batch_size=512,
                                validation_data=(test_data, test_labels),
                                verbose=2)

dropout

Dropout将在训练过程中每次更新参数时按一定概率(rate)随机断开输入神经元,使得比例为rate的神经元不被训练。

具体见: https://yq.aliyun.com/articles/68901

dpt_model = keras.models.Sequential([
    keras.layers.Dense(16, activation=tf.nn.relu, input_shape=(10000,)),
    keras.layers.Dropout(0.3), #百分之30的神经元失效
    keras.layers.Dense(16, activation=tf.nn.relu),
    keras.layers.Dropout(0.7), #百分之70的神经元失效
    keras.layers.Dense(1, activation=tf.nn.sigmoid)
])

dpt_model.compile(optimizer='adam',
                  loss='binary_crossentropy',
                  metrics=['accuracy','binary_crossentropy'])

dpt_model_history = dpt_model.fit(train_data, train_labels,
                                  epochs=20,
                                  batch_size=512,
                                  validation_data=(test_data, test_labels),
                                  verbose=2)

总结

常用防止过拟合的方法有:

  1. 增加数据量
  2. 减少网络结构参数
  3. 正则化
  4. dropout
  5. 数据扩增data-augmentation
  6. 批标准化