Flask + 머신러닝 모델 연동하기-Rest API 작성

2020. 9. 20. 00:24Python/Flask

반응형

앞에서 생성 된 모델 파일을 이용하여서 Rest API를 통하여서 사용자가 입력한 데이터를 받은 후 어떤 종류의 꽃인지 반환 해주는 API작성

코드작성

app.py 작성

from flask_restful import reqparse
from flask import Flask, jsonify

import numpy as np
import pickle as p
import json


app = Flask(__name__)


@app.route('/predict/', methods=['POST'])
def predict():
    parser = reqparse.RequestParser()
    parser.add_argument('petal_length')
    parser.add_argument('petal_width')
    parser.add_argument('sepal_length')
    parser.add_argument('sepal_width')

    # creates dict
    args = parser.parse_args()

    # convert input to array
    X_new = np.fromiter(args.values(), dtype=float)

    # predict - return ndarray
    prediction = model.predict([X_new])

    # result
    out = {'Prediction': get_label(prediction[0])}

    return out, 200

def get_label(label_num):
    labels = {'0' : 'iris-setosa',
              '1' : 'iris-versicolor',
              '2' : 'iris-virginica'}

    return labels.get(str(label_num))

if __name__ == '__main__':
    modelfile = 'models/iris_prediction.pickle'
    model = p.load(open(modelfile, 'rb'))
    app.run(debug=True, port=9090)

※ 참고로 프로젝트 생성 방법은 jydlove.tistory.com/21 글 참고

테스트

Postman 을 통하여 테스트 실행

앞에 모델 작성에서 step3 : 분류별 데이터 시각화 부분의 차트를 확인해보면 제대로 예측한것을 확인 할수 있다.

반응형