Keras 获取参数数量,模型可视化

本文介绍如何获取Keras模型的参数数量和模型结构等。

1. 参数数量

TensorFlow 2.0 以上

import tensorflow.keras.backend as K
import numpy as np

trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
non_trainable_count = np.sum([K.count_params(w) for w in model.non_trainable_weights])

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

2. 网络结构可视化

注意:如需可视化模型,需要安装Graphvizpydot

函数原型请参见:Keras 可视化(中文版文档已过期)

Keras 函数原型:

tf.keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
)

参数列表

  • model: Keras 模型
  • to_file: 输出图像文件名。(注:矢量图以svg,eps等为后缀,位图以png,jpg等为后缀)
  • show_shapes: 显示Tensor尺寸(注意,如果没有定义输入Tensor尺寸,显示为‘?’)
  • show_layer_names: 是否显示每一层的名称
  • rankdirrankdir Graphviz 参数 : ‘TB’ 从上至下绘图; ‘LR’ 从左至右绘图。
  • expand_nested: 是否展开嵌套的模型。
  • dpi: 输出位图的DPI。

Tensorflow 2.0 以上版本

from tensorflow.keras.utils import plot_model

plot_model(model, to_file='model.png', show_shapes=True)