四步训练出自己的CNN手写识别模型 | 《阿里云机器学习PAI-DSW入门指南》

点击即可参与机器学习PAI-DSW动手实验室

点击可下载完整电子书《阿里云机器学习PAI-DSW入门指南》

虽然已经 9102 年了MNIST手写数据集也早已经被各路神仙玩出了各种花样,比如其中比较秀的有用MINST训练手写日语字体的。但是目前还是很少有整体的将训练完之后的结果部署为一个可使用的服务的。大多数还是停留在最终Print出一个Accuracy。

这一次/ b 4 9 [,借助阿里云的PAI-DSW来y G T 7快速构建训练一个手写模型并且部署出一个生产可用级P h @ 8 g别的服务的教程让大家可以在/ o Hf 3 , g k ~他的产品中调用这个服f 9 K d ~务作出更加有意思的项目。

这篇文章里+ 4 V _ { ]我们先讲讲如何构建训练并导出这个手写字体识别的模型。整个教程的代码基于Snapchat的ML大佬 Aymeric Damien 的Tensorflow 入门教程F s Q X C系列。

第一步: 下载代码

首先我们可以把代码Clone到本地或者直接Clone到DSW的实例。如何Clone到) G e P G 3DSW实例的方法H = X D B U c可以参考我0 4 P % i , i的这篇文章。Clone完代码之后- F @ + g X我们还需要准备训练所需要的数据集这边可以直接从Yann Lecun的网站下载。我这边然后我们可先运行一遍看一下效果。
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
我们可以看到代码Clone下来之后直接运行就已经帮我们训练出了model并且给出了现在这个Model的精+ 6 | d 8 a度。在500个batchF m | F U 5 v q之后准确率达到了95%以上而且基于GPU的DSW实例训练这5N G y V #00个Batch只需要十几秒的时间。

第二步: 修改部分代码使得可以自动导出SavedModel

这一步就是比较重要的地方了我们第一个需要关注的就是当前的这个Model里面的Input和Output.
Input还比较清楚我们直接找所有placeholder就可以了
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
Output这一块就比较复杂了,在当前的model里我们可以看到output并不是直/ G . V # $ -接定义的Y而是softmaxo X b I l _ 2之后的prediction
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
找到了这些之后就比较简单了。首先我们创建一个 Saver , 它可以帮助我们保存所有的tf变量以便之后导出模型使用

# 8 X q h H 4 H [ )'Saver' op to save and restore all the variables
saver = tf.train.Saver()+ H K 8 g  _

然后我们在模型训练的session结束的时候导出模型就好了。我们可以通过以下这段代码来导出我们训练好的模型。

import datetime
# 声明导出模型路径这边加入了时间作为路径名 这样每次训练的时候就可以保存多个8 ! t = @ ; X j b版本的模型了
er U s y zxport_path = "./model-" + datetime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S')
# 保存训练的日志文件方便如果出问题了我们可以用 tensorboard 来可视化神经网络排查问题
tf.summary.FileWriter('./graph-' + datetb r `ime.datetime.now().strftime('%Y-%m-%d_%H:%M:%S') , sess.graph)
# 构建我们的Builder
builder = tf.saved_model.builderq V I Q }.SavedModelBuilder(export_path)
# 声明各种输入这里有一个X和一个keep_prob作为输入然后
tensor_info_x = tf.saved_model.utR - ? 5 3 {ils.build_tensor_info(X)
tensor_info_keep_prob = tf.saved_b I {model.utils.build_tensor_info(keJ d d f n q & J Lep_prob)
tensor_info_y = tf.save9 z : f 4 - @d_model.V 5 z ? d z M mutils.build_tens8 J q 1 Lor_info(prediction)
prediction_signature = (
tf.saved_model.signature_def_utils.build_signature_def(
# 声明输入
inputs={
'images': tens{ o } b / gor_info_x,
'keep_prob' : tensor_info_keep_prob
},
# 声明输出
outputs={
'scores': tensor_info_y
},
method_name=tf.saved_model.signature_cons; w htants.PREDICT_METHOD_NAME
)
)
legacy_init_op = tG ` ?f.group(tf.tables_initiaM h R 6 - ; . _lizer(), name='legacy_init_op')
builder.add_meta_graph_and_variables(
sess, [tf.saved_modef . q T  3 ^ El.tag_cH ) J S ) gonstants.SERVING],
signature_def_g # - 9 map={
'predict_images':
prediction_signature,
},
legacy_init_op=legacy_init_op)
# 保存模型
builder.save()

我们可以把这段代码插在这里这样训练完成的时候就会自动导出了。
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
导出之后应该会有如下的文件结构3 ` 8 1 6 ` C我们也可以在左边的文件管理器中查看。

./model-2019-05-20_13:50:26
├─h = [ s % ) ?  .─ sav ` 0ed_model.pb
└── variables
├── variables.data-00000-of-00001
└── variables.index
1 directory, 3 files

第三步: 部署我们的模型

终于到了可以部署的阶段了。但是在部署之前先别那s L g O H P么着急建议用 tensorboard 把训练日志下载到本地之后看一下。

这一步除了可以可视化的解释我q ` @们的模型之外还可以7 ^ e ~ j a h . -帮助我^ l f | a w n X们理清我们的模型的输入和输出分别是什么。

这边我先在有日志文件的路径打开一个tensorboard 通过这个命令

tensorboard --logdir ./

然后我们在游览器里输入默认地址 localhost:6006 就可以看到了。
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
从这个图里也可以看到我们的这个Model里有2个输入源分别叫做images和keep_prob。并且点击它们之后我们还能看到对应的数据格式应该是什么样的。不过没有办法使用 Tensorboard 的同学也不用担心因为EAS这个产品也为我们提供了构造请求的方Z v B式。这一次部署我们先使用WEB界面来部署我们的服务这一步也可以通过EASCMD来实现之后我会再写一篇如何用好EASCMD的文章。

我们可b / D c以把模型文件下载完k c S ~ ! ! w j q之后用zipa u = 0 9 ; $ ] d打包然后到PAI产品的控制台点击EAS-模型在线服务。
ZIP打包可以用这个命令如果你是Unix的用户的话

zip% ? : z -r model.zip path/to/model_files

进入EAS4 z 1 K , K之后我们点击模型V j = v部署上传
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
然后继续配置我们的processor这一次因为我们是用tensorflow训练的所以选择Tensorflow
然后资源选择CPU有需要的同学可以考虑GPU然后上传我们的模型文件。
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
点击下一步我H i i们选新建服务然后给我们的服务起个名字,并且配置资源数量。
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
然后最后确认一下就可以点击部署了。
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》

第四步: 调P = y }试我们的模型J 4 Z j x b D v }

T l / n [ b Q y |到EAS的控制台我们可以看到我们的服务正在被构建中。等到状态显示Running的时候我们就可以开始调试了。
我们可以先点击在线调试。
四步训练出自己的CNN手写识别模型  | 《阿里云机器学习PAI-DSW入门指南》
会让我们跳转到一个Debug 接口的页面。什么都不需B 3 m y g要填直接点击提交我们就可以看到服务的u L o ;数据格式了。
然后我们用一段pyt@ = X G lhon2的代码来调试这个刚刚部署完的服务。c h 1 a p P Spython3的SDK暂时还在研发中。注意要把下面的
app_key, app_secret, url 换成我们刚刚部署好的内容。点击模型名字就可以看见了。
其中测试图片的数据大家可以在这下载到。

#!/usr/bin/env pytht m Bon
# -*- coding: utf-8 -*-
import json
from urlparse impor } Y A h & 1t urlparse
from com.aliyun# $ 5 S H j L.api.gateway.sdk import client
from com.aliyun.api.gateway.sdk.http import req8 ~ O O ! Puest
from com.aliyun.api.gateway.sdk.common import constant
from pF 6 ! * y Y [ 0ai_tf_pred% C | x T Yict_proto import tf_predict_I e N w [ 3 9 +pb2
import cv2
import nump7 1 J G %y as np
with open('4 Z ( Y H9.jpg', 'rb') as infile:
buf = infile.read()
# 使用numpy将字节流转换成array
x = np.fromstring(buf, dtype='uint8K F [ b r ` 3 8 H')
# 将读取到的array进行图片解码获& d 3 } Y得28  28的矩阵
img = cv2.imdecode(x, cv2.IMREAD_UNCHANGED)
# 由于预测服务APII Z e @  S需要长度为784的一维向量将矩阵reshape成784
img = np.reZ @ = 4 ( T 0shape(F - &img, 784)
def predict(url, app_key, app_secret, request_data):
cli = client.DefaultClient(app_key=app_key, app_secret=r 8 R * 7 B 0 Bapp_secret)
body = request_data
url_el6 _ a - J 2e = urlparse(url)
host = 'http://' + url_ele.hostname
path = url_ele.path
reG x 4 k :q_post = request.Request(host=host, protocol=constant.HTTP, url=path, method="POST", tim: C J g F d /e_@ # a ) @ nout=6000)
req_post.set_body(body)
req_post.set_c7 N _ 9  ? vontent_type(constant.CONTENT_TYPE_STREAM)
stat,header, content = cli.execute(req_post)
return stat, dict(header) if header is not None eP u 7 Y j 0 S Vlse {}, content
def d4 / ^ kemox ` A 8():
# 输入模型信息,点击模型名字就可以获取到了
app_keT O K V + jy = 'YOUR_APP_KEY'
app_secret = 'YOU- H S = )R_APP_SECRET'
url = 'YOUR_APP_URL'
# 构造服务
requn { 2 a L F S - |est% O T l r | T = tf_pr9 l o [ : ^ 0 vedict_pb2.PredictRequest()
request.signature_name = 'predicd : 2 3 Z @ O = Nt_images'
request.inputs['images'X r N t b 2].dH J m A , @type = tf_predict_pbi U . u r2.DT_FLOAT  # images 参数类型
request.inputs['images'].array_shape.dim.extend([1, 784])  # images参数的形状
requn & I  yest.inputs['images'].float* s % N A D Y c X_val.extend(img)  # 数据
ri V Q y [ ^ ] n Oequest.inputh O s['keep_prob'].dtype = tf_predict_pb2.DT@ p w ;_FLOAT  # keep_pJ e * ] a o ! A Erob 参数的类型
requestH c u v ) P.inputs['keep_prob'].float_val.extend([0.75])  # 默认填写一个
# †pb—–string› “
request_daf a ~ta = request.SerializeToString()
st~ O gat, header, content = predict(url, app_keyZ T B M V & r % ., app_secret, request_data)
if stat !# v G j ?= 200:
print 'Http status code: ', stat
print 'Error msg in header: ', head4 [ K 7er['x-ca-err: ^ & F 8 $or-message'] if 'x-ca-error-message' in header else ''
print 'Error msg in body: ', conteK E v P R ; lnt
else:
response = tf_predict_pb2.PredictResponse()
response.ParseFromString(content)
print(response)
if __name__ == '__main__':
demo()

运行这个python代码然后我们会得到

outputs {
key: "scoO o [ }res"
value {
dtype: DT_FLc i ; . E cOAT0 | A i x - u a 
array_shape {
dim:- M } 1
dim: 10
}
float_val: 0.0
float_val: 0.0
float_val: 0.0
fl` 4 V 2 | x 5 qoat_val: 0.0
float_vJ ; Ial: 0.0
float_val: 0.0
float_val: 0.0
float_val: 0.0
float_vS D | m 7 2 N . 6al: 0.0
float_val: 1.0
}
}

我们可以看到从0开始数的最后一个也就1 O S $ # ) C ? N是第9个的结果A / g C是1 其他都是0 说明我们的结果是9和我们输入的一样。这样我们就简单轻松的构建了一个在线服务能够将用@ n - ; 6户的图片中手写数字识别出来。配合其他Web框架或者更多的东西我们就可以作 9 W k I v 7 S出更好玩的玩具啦。