Cloud ML Engineでモデルの作成からGAEからのアクセスまで

TensorFlowで作成したモデルをGoogle Cloud PlatformのML Engineにデプロイし,Google App Engine上に立てたサーバからアクセスしてみます.

全体の流れ

  1. TensorFlowのSavedModelを作成
  2. SavedModelをCloud Storageに保存
  3. ML Engineにモデル,バージョンを作成しデプロイ
  4. 正常にデプロイできているかcurlから確認
  5. GAEからML Engineにアクセス

1. SavedModelの作成

import os
import tensorflow as tf
from sklearn import datasets
from sklearn.model_selection import train_test_split


def inference(x_data):
    """Build a single-layer softmax classifier on top of ``x_data``.

    Creates a zero-initialised weight matrix of shape (4, 3) and a bias of
    shape (3,), and returns ``softmax(x_data @ W + b)``.

    Args:
        x_data: float tensor with 4 features per row (``main`` feeds it via a
            ``[None, 4]`` placeholder).

    Returns:
        Tensor of per-row class probabilities over 3 classes (softmax output,
        NOT raw logits).
    """
    weights = tf.Variable(tf.zeros([4, 3]))
    bias = tf.Variable(tf.zeros([3]))
    logits = tf.matmul(x_data, weights) + bias
    return tf.nn.softmax(logits)


def loss_func(y_target, y):
    """Return the batch-mean cross-entropy between one-hot targets and predictions.

    Args:
        y_target: one-hot target tensor (``main`` feeds a float32 ``[None, 3]``
            placeholder).
        y: predicted class *probabilities*, i.e. the softmax output of
            ``inference``.

    Returns:
        Scalar tensor holding the mean cross-entropy over the batch.
    """
    # softmax_cross_entropy_with_logits_v2 applies its own softmax. Passing the
    # already-softmaxed ``y`` as logits therefore applied softmax twice, which
    # is not the cross-entropy of ``y`` and weakens the training signal.
    # log(y) is a valid set of logits for y (softmax(log(y)) == y when each row
    # sums to 1), so this yields the true cross-entropy. Clipping keeps log()
    # finite when a probability underflows to 0.
    logits = tf.log(tf.clip_by_value(y, 1e-10, 1.0))
    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_target, logits=logits))
    return cross_entropy


def training(loss, learning_rate=0.5):
    """Create a plain gradient-descent op that minimises ``loss``.

    Args:
        loss: scalar loss tensor to minimise.
        learning_rate: gradient-descent step size. Defaults to 0.5, the value
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        The training op; each ``sess.run`` of it performs one update step.
    """
    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss)
    return train_step


def main():
    """Train a softmax classifier on iris and export it as a SavedModel.

    Splits the iris dataset into train/test sets, runs 1000 full-batch
    gradient-descent steps, prints the test accuracy with two decimals, then
    writes a SavedModel to ``./model``. The model has a single default serving
    signature that maps the ``"input"`` key (4 floats per instance) to the
    ``"output"`` key (3 class probabilities).

    If ``./model`` already exists, prints a message and returns without
    training.
    """
    output_dir = "model"
    # SavedModelBuilder will not export into an existing directory, so stop
    # early instead of training only to fail at save time.
    if os.path.exists(output_dir):
        print(f"{output_dir} is already exists.")
        return

    # data
    iris = datasets.load_iris()
    X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target)

    # Integer class labels are turned into one-hot float vectors inside the
    # graph. depth=3 matches the 3 output classes built by inference().
    train_labels = tf.placeholder(tf.int64, [None])
    train_one_hot = tf.one_hot(train_labels, depth=3, dtype=tf.float32)
    test_labels = tf.placeholder(tf.int64, [None])
    test_one_hot = tf.one_hot(test_labels, depth=3, dtype=tf.float32)

    # Model inputs: 4 features per sample and 3-way one-hot targets.
    x_data = tf.placeholder(tf.float32, [None, 4])
    y_target = tf.placeholder(tf.float32, [None, 3])

    with tf.Session() as sess:
        model_output = inference(x_data)
        loss = loss_func(y_target, model_output)
        training_op = training(loss)

        sess.run(tf.global_variables_initializer())

        # learn: the one-hot *training* targets are computed once, then reused
        # for every full-batch step.
        train_targets = sess.run(train_one_hot, feed_dict={train_labels: y_train})
        for _ in range(1000):
            sess.run(training_op, feed_dict={x_data: X_train, y_target: train_targets})

        # test
        test_targets = sess.run(test_one_hot, feed_dict={test_labels: y_test})
        predictions_correct = tf.cast(tf.equal(tf.argmax(model_output, 1), tf.argmax(y_target, 1)), tf.float32)
        accuracy = tf.reduce_mean(predictions_correct)
        res = sess.run(accuracy, feed_dict={x_data: X_test, y_target: test_targets})
        print(f"{res:.2f}")

        # save: the "input"/"output" keys are the field names clients use in
        # the prediction request/response JSON.
        input_signatures = {
            "input": tf.saved_model.utils.build_tensor_info(x_data),
        }
        output_signatures = {
            "output": tf.saved_model.utils.build_tensor_info(model_output)
        }
        predict_signature_def = tf.saved_model.signature_def_utils.build_signature_def(
            input_signatures,
            output_signatures,
            tf.saved_model.signature_constants.PREDICT_METHOD_NAME
        )

        builder = tf.saved_model.builder.SavedModelBuilder(output_dir)
        builder.add_meta_graph_and_variables(sess=sess,
                                             tags=[tf.saved_model.tag_constants.SERVING],
                                             signature_def_map={tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: predict_signature_def},
                                             assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS))

        builder.save()


if __name__ == "__main__":
    main()

2〜3. SavedModelをCloud Storageに保存し,ML Engineにモデル,バージョンを作成してデプロイ

# upload: copy the SavedModel directory written by the Python script
# ("./model") to Cloud Storage.
gsutil cp -r ./model "gs://$BUCKET_NAME/"

# make model: --enable-logging turns on logging for online prediction
# (the flag was previously misspelled as --eneble-logging).
# NOTE(review): the documented flag for this command is --regions with a
# lower-case region name (e.g. asia-northeast1) — confirm against your gcloud version.
gcloud ml-engine models create "$MODEL_NAME" --region ASIA-NORTHEAST1 --enable-logging

# deploy: MODEL_DIR defaults to the directory uploaded above unless set.
MODEL_DIR="${MODEL_DIR:-gs://$BUCKET_NAME/model}"
gcloud ml-engine versions create "$VERSION_NAME" --model "$MODEL_NAME" --origin "$MODEL_DIR" --runtime-version=1.12 --framework TENSORFLOW --python-version=3.5

4. curlから確認

# Send two 4-feature instances to the version's online-prediction endpoint,
# authenticating with the current gcloud user's access token.
curl -X POST \
  -H "Content-Type: application/json" \
  -d '{"instances": [{"input": [6.8,  2.8,  4.8,  1.4]}, {"input": [6.0,  3.4,  4.5,  1.6]}]}' \
  -H "Authorization: Bearer $(gcloud auth print-access-token)" \
  "https://ml.googleapis.com/v1/projects/$PROJECT/models/$MODEL_NAME/versions/$VERSION_NAME:predict"

結果

{"predictions": [{"output": [0.00039858807576820254, 0.9749729633331299, 0.02462848834693432]}, {"output": [0.0073119597509503365, 0.9235738515853882, 0.06911419332027435]}]}

5. GAEからML Engineにアクセスする(Goで書いたAPIサーバ)

とりあえず,予測結果はログに出力するようにします.

package main

import (
    "context"
    "encoding/json"
    "net/http"
    "strings"
    "time"

    "golang.org/x/oauth2/google"
    "google.golang.org/api/ml/v1"
    "google.golang.org/appengine"
    "google.golang.org/appengine/log"
)

type Instance struct {
    Input []float32 `json:"input"`
}

type MlRequest struct {
    Instances []Instance `json:"instances"`
}

type Prediction struct {
    Output []float32 `json:"output"`
}

type MlResponse struct {
    Predictions []Prediction `json:"predictions"`
}

// callPredict sends instances to the online-prediction endpoint of the given
// ML Engine model version and returns the decoded predictions.
//
// Credential setup and the predict request are both bounded by a 5-second
// timeout derived from ctx. The latency of the predict call is logged in ms.
// On any failure it returns an empty (non-nil) slice together with the error,
// including when the response body is empty.
func callPredict(instances []Instance, project string, model string, version string, ctx context.Context) ([]Prediction, error) {
    request := &MlRequest{
        Instances: instances,
    }

    // Keep and defer the cancel func so the timer is released on return;
    // discarding it leaked the context until the deadline fired.
    ctxWithTimeOut, cancel := context.WithTimeout(ctx, 5*time.Second)
    defer cancel()

    // NOTE: the client and service are rebuilt on every call, which may add
    // per-request overhead.
    client, err := google.DefaultClient(ctxWithTimeOut, ml.CloudPlatformScope)
    if err != nil {
        return []Prediction{}, err
    }

    data, err := json.Marshal(request)
    if err != nil {
        return []Prediction{}, err
    }

    mlService, err := ml.New(client)
    if err != nil {
        return []Prediction{}, err
    }
    predictReq := &ml.GoogleCloudMlV1__PredictRequest{
        HttpBody: &ml.GoogleApi__HttpBody{
            Data: string(data),
        },
    }

    name := "projects/" + project + "/models/" + model + "/versions/" + version
    start := time.Now()
    // Attach the timeout context. Without .Context the request was not
    // bounded by the 5 s deadline at all.
    googleBody, err := mlService.Projects.Predict(name, predictReq).Context(ctxWithTimeOut).Do()
    if err != nil {
        return []Prediction{}, err
    }
    log.Infof(ctx, "%v[ms]", float64(time.Since(start).Nanoseconds())/1000000.0)

    // An empty body makes Decode fail with io.EOF, so it is reported as an
    // error rather than returning no predictions with a nil error.
    var resp MlResponse
    if err := json.NewDecoder(strings.NewReader(googleBody.Data)).Decode(&resp); err != nil {
        return []Prediction{}, err
    }

    return resp.Predictions, nil
}

// handle runs a fixed two-row prediction against the deployed model version,
// logs each returned output vector, and replies with the JSON string "ok".
// If the prediction call fails, the error is logged and the request is
// answered with HTTP 500 instead of a misleading "ok".
func handle(w http.ResponseWriter, r *http.Request) {
    ctx := appengine.NewContext(r)

    // Two hard-coded rows of 4 features each, matching the model's
    // 4-feature input.
    instances := []Instance{
        {Input: []float32{6.8, 2.8, 4.8, 1.4}},
        {Input: []float32{6.0, 3.4, 4.5, 1.6}},
    }

    // "PROJECT", "MODEL", "VERSION" are placeholders for the real project ID,
    // model name and version name.
    prediction, err := callPredict(instances, "PROJECT", "MODEL", "VERSION", ctx)
    if err != nil {
        log.Errorf(ctx, "%v", err)
        http.Error(w, "prediction failed", http.StatusInternalServerError)
        return
    }

    for _, p := range prediction {
        log.Infof(ctx, "%v", p.Output)
    }

    json.NewEncoder(w).Encode("ok")
}

// main registers handle on "/" and hands control to the App Engine runtime
// via appengine.Main.
func main() {
    http.Handle("/", http.HandlerFunc(handle))
    appengine.Main()
}

6. サーバにデプロイ

# Deploy the Go API server above to App Engine in the project named by $PROJECT.
gcloud app deploy --project ${PROJECT}

7. logを確認

予測結果は取得できたが,Predict呼び出しだけで約174ms(ログ1行目)かかっており,かなり遅い・・・

2019-03-13 11:53:32.222 JST 174.349599[ms]
2019-03-13 11:53:32.222 JST [0.00039858808 0.97497296 0.024628488]
2019-03-13 11:53:32.222 JST [0.0073119598 0.92357385 0.06911419]