分类目录归档:参考手册

Keras 获取参数数量,模型可视化

本文介绍如何获取Keras模型的参数数量和模型结构等。

1. 参数数量

TensorFlow 2.0 以上

import tensorflow.keras.backend as K
import numpy as np

trainable_count = np.sum([K.count_params(w) for w in model.trainable_weights])
non_trainable_count = np.sum([K.count_params(w) for w in model.non_trainable_weights])

print('Total params: {:,}'.format(trainable_count + non_trainable_count))
print('Trainable params: {:,}'.format(trainable_count))
print('Non-trainable params: {:,}'.format(non_trainable_count))

2. 网络结构可视化

注意:如需可视化模型,需要安装Graphvizpydot

函数原型请参见:Keras 可视化(中文版文档已过期)

Keras 函数原型:

tf.keras.utils.plot_model(
    model,
    to_file="model.png",
    show_shapes=False,
    show_layer_names=True,
    rankdir="TB",
    expand_nested=False,
    dpi=96,
)

参数列表

  • model: Keras 模型
  • to_file: 输出图像文件名。(注:矢量图以svg,eps等为后缀,位图以png,jpg等为后缀)
  • show_shapes: 显示Tensor尺寸(注意,如果没有定义输入Tensor尺寸,显示为‘?’)
  • show_layer_names: 是否显示每一层的名称
  • rankdirrankdir Graphviz 参数 : ‘TB’ 从上至下绘图; ‘LR’ 从左至右绘图。
  • expand_nested: 是否展开嵌套的模型。
  • dpi: 输出位图的DPI。

Tensorflow 2.0 以上版本

from tensorflow.keras.utils import plot_model

plot_model(model, to_file='model.png', show_shapes=True)

设置CMake 使用的Visual Studio 2017工具链版本

自Cmake 3.12 起使用Visual Studio 15 2017生成器时可以指定Visual Studio 2017的工具链版本。

本文使用的工具如下:

  1. CMake 3.12.0
  2. Visual Studio 2017

Visual Studio 2017工具链版本:

  1. 14.14
  2. 14.11

工具链版本设置方法如下:

  • 使用cmake命令的方法
    cmake -G "Visual Studio 15 2017" -T host=x64,version=14.xx ../
  • cmake gui使用方法
    在配置项目时增加version参数

    host=x64为指定编译器为64位版本

附加文件

CMake 3.12.0 下载页

参考资料

CMake 3.12.0 Release Note