本文介绍如何获取Keras模型的参数数量和模型结构等。
1. 参数数量
TensorFlow 2.0 以上
import tensorflow.keras.backend as K import numpy as np trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights]) non_trainable_count = np.sum([K.count_params(w) for w in model.non_trainable_weights]) print('Total params: {:,}'.format(trainable_count + non_trainable_count)) print('Trainable params: {:,}'.format(trainable_count)) print('Non-trainable params: {:,}'.format(non_trainable_count))
2. 网络结构可视化
函数原型请参见:Keras 可视化(中文版文档已过期)
Keras 函数原型:
tf.keras.utils.plot_model( model, to_file="model.png", show_shapes=False, show_layer_names=True, rankdir="TB", expand_nested=False, dpi=96, )
参数列表
- model: Keras 模型
- to_file: 输出图像文件名。(注:矢量图以svg,eps等为后缀,位图以png,jpg等为后缀)
- show_shapes: 显示Tensor尺寸(注意,如果没有定义输入Tensor尺寸,显示为‘?’)
- show_layer_names: 是否显示每一层的名称
- rankdir:
rankdir
Graphviz 参数 : ‘TB’ 从上至下绘图; ‘LR’ 从左至右绘图。 - expand_nested: 是否展开嵌套的模型。
- dpi: 输出位图的DPI。
Tensorflow 2.0 以上版本
from tensorflow.keras.utils import plot_model plot_model(model, to_file='model.png', show_shapes=True)