https://nine-num-98.blogspot.com/2020/03/ai-socket-01.html
公開したプログラムの内容について解説します。
GitHub公開プログラム
https://github.com/kotetsu99/cnn_socket
前回は↑のプログラム群のうち、以下のプログラムについて解説しました。
(1)01-cnn_train.py: 画像学習プログラム
・画像をAIに学習させ、画像認識AIモデルを生成。
・AIモデルは、CIFAR-10 画像データセット向けのAIモデル(9層畳み込みネットワーク)
https://nine-num-98.blogspot.com/2020/03/ai-socket-02.html
今回は、それに続き以下のソケット通信部のサンプルコードについて解説します。
(2)02-cnn_server.py: 画像認識AIソケットサーバー
・ソケット通信サーバー。(3)のクライアントから画像ファイルを受信し保存。
・受信した画像ファイルに対し、(1)で作ったAIで画像認識させ、結果をクライアントに返す。
(3)03-socket_client.py: 画像送信ソケットクライアント
・ソケット通信クライアント。画像ファイルを(2)のサーバーにバイナリデータで送信。
・サーバーの画像認識結果を受信して、コンソール画面に表示。
イメージをつかむためのシーケンス図のほうも再掲しておきます。

画像認識AIソケットサーバー
02-cnn_server.py ソケット通信サーバーのプログラムです。
後述のクライアントプログラムから画像ファイルデータを受信し、ファイルとして保存。
画像ファイルに対し、(1)で作ったAIで画像認識させ、結果をクライアントに返します。
まずは序盤のmain関数までを見ていきます。
#!/usr/bin/env python # -*- coding: utf-8 -*- import keras from keras.preprocessing import image import numpy as np import socket import os, sys # サーバーIPアドレス定義 host = "0.0.0.0" # サーバーの待ち受けポート番号定義 port = 50001 # 入力画像リサイズ定義 img_width, img_height = 64, 107 # 学習用画像ファイル保存ディレクトリ train_data_dir = 'dataset/train' # 受信画像保存ディレクトリ sc_dir = 'dataset/sc' # 受信画像ファイル名 sc_file = 'sc_file.png' def main(): # 環境設定(ディスプレイの出力先をlocalhostにする) os.environ['DISPLAY'] = ':0' # データセットのサブディレクトリ名(クラス名)を取得 classes = [] for d in os.listdir(train_data_dir): if os.path.isdir(os.path.join(train_data_dir, d)): classes.append(d) print 'クラス名リスト = ', classes # 学習済ファイルの確認 if not len(sys.argv)==2: print('使用法: python 02-cnn_server.py 学習済ファイル名.h5') sys.exit() savefile = sys.argv[1] # モデルのロード model = keras.models.load_model(savefile) # ソケット定義(IPv4,TCPによるソケット) s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: # 次の実行に備え、ソケットをTIME-WAIT切れを待つことなく、再利用できるようにしてお>く s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # IPとPORTを指定してバインド(ソケットに紐づけ) s.bind((host,port)) # ソケット接続待受(キューの最大数を指定) s.listen(10) while True: # ソケット接続受信待ち try: print 'クライアントからの接続待ち...' # 接続が来たら対応する新しいソケットオブジェクト作成、接続先アドレスを格納 clientsock, client_address = s.accept() # 接続待ちの間に強制終了が入った時の例外処理 except KeyboardInterrupt: print 'Ctrl + C により強制終了' break # 接続待ちの間に強制終了なく、クライアントからの接続が来た場合 else: # 受信処理を行い、画像認識結果を返す recv_client_data(clientsock, model, classes) except Exception as e: print(e) finally: # ソケットを閉じる s.close()
サーバーIPアドレス、待ち受けポート番号を定義等、諸々の設定を定義した後にmain関数に入ります。
main関数では、まず01-cnn_train.py で既に作った画像認識AIモデルを読み込んだ後、ソケット通信のサーバー側処理に移行します。
ソケット通信サーバープログラムの一般的な作法は、pythonのドキュメントにある、ソケット通信のサンプルプログラムに記されています。
https://docs.python.org/ja/2.7/library/socket.html
ソケット定義、bind、listen、acceptで、クライアントからの接続待ち状態にできます。ここでは、前に実行したプログラムがソケットを一定時間占有(TIME_WAIT)せず、すぐに再利用できるような設定をいれてます。listen(10)ということで、最大で10個の接続要求を同時に受け付けることができます。
ソケット定義から後の処理は、try,except,finally という例外処理内で記述されます。これは、ネットワーク切断などの何らかのトラブルが起きたときに、異常内容を表示させ、ソケットを閉じる処理をfinallyで確実に行えるようにするためのものです。
try文の中には、さらにネストでtry,except,elseと例外処理が accept以降で入っています。これはacceptで接続待ち状態のときに、Ctrl + C コマンドで強制終了させるためのものです。
accept中にクライアントからの接続が入った場合、clientsock(ソケットオブジェクト)を作成して、else文の受信処理を行います。
受信処理は以下の関数で定義しています。
def recv_client_data(clientsock, model, classes): # 受信データ保存用変数の初期化 all_data = '' try: # ソケット接続開始後の処理 while True: # データ受信。受信バッファサイズ1024バイト data = clientsock.recv(1024) # 全データ受信完了(受信路切断)時に、ループ離脱 if not data: break # 受信データを追加し繋げていく all_data += data # 受信画像ファイル保存 with open(sc_dir + '/' + sc_file, 'wb') as f: # ファイルにデータ書込 f.write(all_data) # 受信画像ファイルに対しAIで画像認識を実行 res = cnn_recognition(model, classes) # 認識結果をクライアントに送信 clientsock.sendall(res) except Exception as e: print '受信処理エラー発生' print(e) finally: # コネクション切断 clientsock.close()
try, except, finally という例外処理がここでも実装されていますが、データ受信中にネットワークが落ちるなどのトラブルが起きた場合に備えています。トラブルが発生した場合でも、コネクションが確実に残らないよう、finallyで切断処理を書いています。
try文の中ですが、これが正常系の動作になります。クライアントから送られてくるバイナリーデータを、clientsock.recvにより1024バイト単位でループしながら受け取ります。
受け取ったデータ(data)は、all_dataという変数につなぎ合わせて保存されていきます。
クライアントから受信するデータがなくなったとき(if not data がtrue)に、ループを抜けますが、これはクライアント側の方で送信路を閉じるという処理がトリガーになります。
ループを抜けた後、all_dataを画像ファイルに書き込みます。そして保存されたファイルに対し、01-cnn_train.pyで既に作っていたAIモデルで画像認識させ、最後にクライアントに認識結果を送信します。
画像ファイルに対し、AIによる画像認識を実行しているのが、以下の関数です。
def cnn_recognition(model, classes): # 画像ファイル取得 filename = os.path.join(sc_dir, sc_file) img = image.load_img(filename, target_size=(img_width, img_height)) # 入力画像定義(画像の行, 画像の列, チャネル数の3次元テンソル) x = image.img_to_array(img) # 4次元テンソル(サンプル数, チャネル数, 画像の行数, 画像の列数)に変換 x = np.expand_dims(x, axis=0) # 学習時に正規化してるので、ここでも正規化 x = x / 255 # 画像サンプルが1枚のみなので、最初の1枚[0]の認識結果を格納 pred = model.predict(x)[0] # 予測確率が高い順番に認識結果を出力 top = 4 top_indices = pred.argsort()[-top:][::-1] result = [(classes[i], pred[i]) for i in top_indices] #print('file name is', sc_file) print '受信ファイル認識結果:' print(result) print('=======================================') return result[0][0]
kerasで定義されている関数を用いて、画像データをテンソル表現し、modelのpredict関数で認識結果を出力しています。
認識結果は、各クラス該当する確率の高い順にソートしてコンソール画面に表示します。
最後に、その中で一番確率の高いクラス名を返します。この画像認識結果が、クライアントに返されるわけです。
# 上記処理については、こちらの記事が参考になりました。
https://qiita.com/nanako_ut/items/d35d74b7d692659b1e03
画像送信ソケットクライアント
最後に、03-socket_client.py の説明です。
画像認識AIソケットサーバーに向けて、画像ファイルをバイナリデータで送信。
その後サーバーの画像認識結果を受信して、コンソール画面に表示します。
以下にコードを示します。
#!/usr/bin/env python # -*- coding: utf-8 -*- import socket import os, sys # サーバーIPアドレス定義 host = "0.0.0.0" # サーバー待ち受けポート番号定義 port = 50001 def main(): # 送信画像ファイルパス引数の取得 if not len(sys.argv)==2: print('使用法: python 03-socket_client.py 画像ファイル名') sys.exit() image_file = sys.argv[1] # ファイルをバイナリデータとして読み込み with open(image_file, 'rb') as f: binary = f.read() # ソケットクライアント作成 s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) try: # 送信先サーバーに接続 s.connect((host, port)) # サーバーにバイナリデータを送る print(image_file + 'をサーバーに送信') s.sendall(binary) # データ送信完了後、送信路を閉じる s.shutdown(1) # サーバーからの応答を取得。バッファサイズ1024バイト res = s.recv(1024) # サーバー応答を表示 print('認識結果: ' + res) except Exception as e: # 例外が発生した場合、内容を表示 print(e) finally: # ソケットを閉じて終了 s.close() if __name__ == '__main__': main()
実行コマンドの引数として渡された、画像ファイルをバイナリデータとして読み込んだ後、AIサーバーとソケット通信を行います。
ソケット通信は、pythonのドキュメントにある、ソケット通信のサンプルプログラムに示されているように、一般的な作法を参考に書いています。
https://docs.python.org/ja/2.7/library/socket.html
まずソケットクライアントを作成し、サーバー接続。その後画像データをすべて送信しています。
ここで、重要なのが s.shutdown(1) という処理です。データをsendallで送り切ったあと、「もうこれ以上送るデータはないよ。」と送信終了のお知らせを、サーバー側に伝える必要があります。sendallにはそのような機能が付加されていないようなので、shutdown(1) により送信路を閉じることで、これを実現します。
一方で受信路は活かしておき、サーバー側から返ってくる画像認識結果を受け取り、printでそれを表示させます。
上記の処理が行われている間に、ネットワーク接続がきれるなど、何らかの異常が発生した場合に備えて、try,except,finallyによる例外処理を実装しています。何か問題がおきても、最終的にfinallyの処理でソケットが閉じられるようにしておきます。
以上が解説になりますが、ソケット通信でハマったのが、先述したs.shutdown(1)の箇所ですね。データの送信終了をサーバー側に知らせるやり方として、これが使えるのは中々見えてこなかったので。
もしかしたらもっとスマートな方法があるのかもしれませんが、一つの実装法として参考にしてもらえれば幸いです。