自定义tf.keras.Model需要注意的点

model.save()

  • subclass Model 是不能直接save的,save成.h5,但是能够save_weights,或者save_format="tf"
  1. NotImplementedError: Saving the model to HDF5 format requires the model to be a Functional model or a Sequential model. It does not work for subclassed models, because such models are defined via the body of a Python method, which isn't safely serializable. Consider saving to the Tensorflow SavedModel format (by setting save_format="tf") or using `save_weights`.

model.trainable_variables

  • __init__若没有注册该layers,那么在后面应用梯度时会找不到model.trainable_variables。

    像下面这样是不行的:
  1. class Map_model(tf.keras.Model):
  2. def __init__(self, is_train=False):
  3. super(Map_model, self).__init__()
  4. def call(self, x):
  5. x = tf.keras.layers.Dense(10, activation='relu')
  6. return x

model.summary()

  • 需要先指定input_shape,或者你直接fit一遍它也能自动确定
  1. model.build(input_shape=(None, 448, 448, 3))
  2. print(model.summary())
  1. class Map_model(tf.keras.Model):
  2. def __init__(self, is_train=False):
  3. super(Map_model, self).__init__()
  4. self.map_f1 = tf.keras.layers.Dense(10, activation='relu', trainable=is_train)
  5. # self.map_f2 = tf.keras.layers.Dense(6, activation='relu')
  6. self.map_f3 = tf.keras.layers.Dense(3, activation='softmax', trainable=is_train)
  7. def call(self, x):
  8. x = self.map_f1(x)
  9. # x = self.map_f2(x)
  10. return self.map_f3(x)
  11. @tf.function
  12. def train_step(mmodel, label, L_label, loss_object, train_loss, train_accuracy, optimizer):
  13. with tf.GradientTape() as tape:
  14. L_label_pred = mmodel(label)
  15. loss = loss_object(L_label, L_label_pred)
  16. gradient_l = tape.gradient(loss, mmodel.trainable_variables)
  17. train_loss(loss)
  18. train_accuracy(L_label, L_label_pred)
  19. optimizer.apply_gradients(zip(gradient_l, mmodel.trainable_variables))
  20. def train():
  21. mmodel = Map_model(is_train=True)
  22. optimizer = tf.keras.optimizers.Adam(0.01)
  23. loss_object = tf.keras.losses.CategoricalCrossentropy()
  24. train_loss = tf.keras.metrics.Mean(name='train_loss')
  25. train_accuracy = tf.keras.metrics.CategoricalAccuracy(name='train_accuracy')
  26. EPOCHS = 0
  27. labels = range(1, 30) # labels = truth_label -1
  28. L_labels = [int(prpcs.map2Lclass(l)) for l in labels]
  29. labels = [l - 1 for l in labels]
  30. labels_onehot = tf.one_hot(labels, depth=29)
  31. L_labels_onehot = tf.one_hot(L_labels, depth=3)
  32. EPS = 1e-6
  33. loss_e = 0x7f7f7f
  34. while loss_e > EPS:
  35. EPOCHS += 1
  36. train_loss.reset_states()
  37. train_accuracy.reset_states()
  38. train_step(mmodel, labels_onehot, L_labels_onehot, loss_object, train_loss, train_accuracy, optimizer)
  39. template = 'Epoch {}, Loss: {}, Accuracy: {}'
  40. print(template.format(EPOCHS,
  41. train_loss.result(),
  42. train_accuracy.result() * 100))
  43. loss_e = train_loss.result()
  44. print("labels_onehot shape:", labels_onehot.shape)
  45. model_path = r'./models/'
  46. if not os.path.exists(model_path):
  47. os.makedirs(model_path)
  48. mmodel.save(os.path.join(model_path, 'map_model_{}'.format(EPS)))
  49. mmodel.save_weights(os.path.join(model_path, 'map_model_weights_{}'.format(EPS)))
  50. print("Save model!")

tensorflow 2.0 技巧 | 自定义tf.keras.Model的坑的更多相关文章

  1. tf.keras遇见的坑:Output tensors to a Model must be the output of a TensorFlow `Layer`

    经过网上查找,找到了问题所在:在使用keras编程模式是,中间插入了tf.reshape()方法便遇到此问题. 解决办法:对于遇到相同问题的任何人,可以使用keras的Lambda层来包装张量流操作, ...

  2. [TensorFlow 2.0] Keras 简介

    Keras 是一个用于构建和训练深度学习模型的高阶 API.它可用于快速设计原型.高级研究和生产. keras的3个优点: 方便用户使用.模块化和可组合.易于扩展 简单点说就是,简单.好用.快(构建) ...

  3. 三分钟快速上手TensorFlow 2.0 (上)——前置基础、模型建立与可视化

    本文学习笔记参照来源:https://tf.wiki/zh/basic/basic.html 学习笔记类似提纲,具体细节参照上文链接 一些前置的基础 随机数 tf.random uniform(sha ...

  4. python 3.7 安装 sklearn keras(tf.keras)

    # 1   sklearn  一般方法 网上有很多教程,不再赘述. 注意顺序是 numpy+mkl     ,然后 scipy的环境,scipy,然后 sklearn # 2 anoconda ana ...

  5. 【tf.keras】TensorFlow 1.x 到 2.0 的 API 变化

    TensorFlow 2.0 版本将 keras 作为高级 API,对于 keras boy/girl 来说,这就很友好了.tf.keras 从 1.x 版本迁移到 2.0 版本,需要修改几个地方. ...

  6. 一文上手Tensorflow2.0之tf.keras(三)

    系列文章目录: Tensorflow2.0 介绍 Tensorflow 常见基本概念 从1.x 到2.0 的变化 Tensorflow2.0 的架构 Tensorflow2.0 的安装(CPU和GPU ...

  7. TensorFlow2.0(11):tf.keras建模三部曲

    .caret, .dropup > .btn > .caret { border-top-color: #000 !important; } .label { border: 1px so ...

  8. 【tf.keras】tf.keras使用tensorflow中定义的optimizer

    Update:2019/09/21 使用 tf.keras 时,请使用 tf.keras.optimizers 里面的优化器,不要使用 tf.train 里面的优化器,不然学习率衰减会出现问题. 使用 ...

  9. 【tf.keras】Resource exhausted: OOM when allocating tensor with shape [9216,4096] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc

    运行以下类似代码: while True: inputs, outputs = get_AlexNet() model = tf.keras.Model(inputs=inputs, outputs= ...

随机推荐

  1. php的异步非阻塞swoole模块使用(一)实现简易tcp服务器--服务端

    绑定tcp服务器的地址 $swserver = new swoole_server("127.0.0.1",9501); 设置tcp服务器装机容量(太危言耸听了-其实就是设置属性) ...

  2. qt5-信号和槽

    信号函数: connect(btn,&QPushButton::clicked,this,&QWidget::close); //参数1 信号发送者://参数2 信号:---& ...

  3. JAVA笔记9-多态(动态绑定、池绑定)

    1.动态绑定:执行期间(而非编译期间)判断所引用对象的实际类型,根据实际的类型调用相应方法. 2.多态存在的三个必要条件(同时):继承.重写.父类引用指向子类对象. 这三个条件满足后,当调用父类中被重 ...

  4. IE大文件断点续传

    IE的自带下载功能中没有断点续传功能,要实现断点续传功能,需要用到HTTP协议中鲜为人知的几个响应头和请求头. 一. 两个必要响应头Accept-Ranges.ETag 客户端每次提交下载请求时,服务 ...

  5. 《python cookbook》学习笔记

    2016.5.3 第8章  类与对象 8.1 改变对象的字符串显示 __str__ 和 __repr__   %s 和 %r,提到了eval,我没有用过 8.2 自定义字符串的格式化  __forma ...

  6. 4.JSP内置对象

    JSP内置对象,JSP提供了由容器实现和管理的内置对象,也可以称之为隐含对象,这些内置对象不需要通过 JSP页面编写来实例化,在所有的JSP页面中都可以直接使用,它起到了简化页面的作用. 在JSP中一 ...

  7. AbpUser 扩展

    AbpUser表存放的信息比较少,现扩展一下信息 1.在Core层添加UserExtend 类,继承 AbpUser<User>,写入以上各项属性字段,并添加Discriminator 字 ...

  8. [CSP-S模拟测试]:字符交换(贪心+模拟)

    题目传送门(内部题136) 输入格式 输入文件第一行为两个正整数$n,k$,第二行为一个长度为$n$的小写字母字符串$s$. 输出格式 输出一个整数,为对字符串$s$进行至多$k$次交换相邻字符的操作 ...

  9. localhost与127.0.0.1的区别是什么?

    localhost与127.0.0.1的区别是什么?都代表本地服务器 相信有人会说是本地ip,曾有人说,用127.0.0.1比localhost好,可以减少一次解析. 看来这个问题还有人不清楚,其实这 ...

  10. C++入门经典-例4.7-变量的作用域

    1:代码如下: // 4.7.cpp : 定义控制台应用程序的入口点. // #include "stdafx.h" #include <iostream> using ...