
Machine learning in production with scikit-learn


Presented at PyOhio 2017: https://pyohio.org/schedule/presentation/284/

The Python data ecosystem provides amazing tools to quickly get up and running with machine learning models, but the path to stably serving them in production is not so clear. We'll discuss details of wrapping a minimal REST API around scikit-learn, training and persisting models in batch, and logging decisions, then compare to some other common approaches to productionizing models.



  1. Machine Learning in Production with scikit-learn (Jeff Klukas - Data Engineer at Simple)
  2. (image-only slide)
  3. Overview
     • What’s the problem we’re solving?
     • Why machine learning?
     • Walkthrough of developing the model
     • ✨ Live demo ✨
     • Complications of moving this workflow to production
     • Other potential approaches
  4. (image-only slide)
  5. Categorizing chats
     # SELECT subject, body, category FROM chats;
         subject    |           body            |   category
     ---------------+---------------------------+------------
      Check deposit | Hi how are you? I was…    | education
      Lost Card     | Can you send me a new…    | urgent
      my transfer   | My transfer of $10 isn’t… | education
      Mail deposits | I have a large check…     | education
     Categories: urgent, customer education, new product, incidents, other
  6. (image-only slide)
  7. (image-only slide)
  8. ✨ 💖 Machine Learning 💖 ✨
  9. (image-only slide)
  10. (image-only slide)
  11. sklearn.pipeline
      from sklearn.pipeline import Pipeline
      from sklearn.feature_extraction.text import (
          CountVectorizer, TfidfTransformer)
      from xgboost import XGBClassifier

      stopwords, lemmatizer = …

      pipeline = Pipeline([
          ('preprocess', MessagePreprocessor(subject_weight=2)),
          ('text', TextProcessor(stopwords, lemmatizer)),
          ('vect', CountVectorizer()),
          ('tfidf', TfidfTransformer()),
          ('clf', XGBClassifier(objective='multi:softmax')),
      ])
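
The MessagePreprocessor and TextProcessor steps above are custom transformers whose implementations aren't shown in the deck. A minimal sketch of what they might look like, assuming the pipeline is fit on a DataFrame with subject and body columns (as on the next slide) and an NLTK-style lemmatizer:

    # Hypothetical sketch of the custom transformers used in the pipeline above;
    # the actual implementations from the talk are not shown here.
    from sklearn.base import BaseEstimator, TransformerMixin

    class MessagePreprocessor(BaseEstimator, TransformerMixin):
        """Concatenate subject and body, repeating the subject to up-weight it."""
        def __init__(self, subject_weight=1):
            self.subject_weight = subject_weight

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            return [" ".join([row["subject"]] * self.subject_weight + [row["body"]])
                    for _, row in X.iterrows()]

    class TextProcessor(BaseEstimator, TransformerMixin):
        """Lowercase, strip stopwords, and lemmatize each document."""
        def __init__(self, stopwords, lemmatizer):
            self.stopwords = stopwords
            self.lemmatizer = lemmatizer

        def fit(self, X, y=None):
            return self

        def transform(self, X):
            return [" ".join(self.lemmatizer.lemmatize(token)
                             for token in doc.lower().split()
                             if token not in self.stopwords)
                    for doc in X]

Storing the constructor arguments under their own names keeps both transformers compatible with GridSearchCV, which is what makes the 'preprocess__subject_weight' and 'text__stopwords' parameters on slide 23 tunable.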
  12. Training the model
      import pandas as pd
      from sklearn.model_selection import train_test_split

      data_frame = pd.read_sql(
          "SELECT category, subject, body FROM chats;",
          redshift_connection)
      X = data_frame[['subject', 'body']]
      y = data_frame['category']

      X_train, X_test, y_train, y_test = train_test_split(
          X, y, test_size=0.33, random_state=0)

      pipeline.fit(X_train, y_train)
  13. Overfitting (https://en.wikipedia.org/wiki/Overfitting)
  14. Testing the model
      from sklearn.metrics import classification_report

      y_predicted = pipeline.predict(X_test)
      print(classification_report(y_test, y_predicted))

                   precision    recall  f1-score   support

          class 0       0.67      1.00      0.80         2
          class 1       0.00      0.00      0.00         1
          class 2       1.00      0.50      0.67         2

      avg / total       0.67      0.60      0.59         5
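
In a batch run, this report can be written out next to the serialized model so every training job leaves an auditable record (the file name below is an assumption, not from the talk):

    # Sketch: persist the evaluation report alongside the trained model so each
    # batch training run leaves an auditable artifact.
    report = classification_report(y_test, y_predicted)
    with open("classification_report.txt", "w") as f:
        f.write(report)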
  15. Serving the model in Flask
      import pandas as pd
      from flask import Flask, jsonify, request

      app = Flask(__name__)

      @app.route('/chat-classification-api/messages', methods=['POST'])
      def classify_messages():
          """Classify given chat messages"""
          messages = request.get_json()
          y = pipeline.predict(pd.DataFrame(messages))
          # join class labels back with identifiers
          predictions = [{"chat_id": message["chat_id"], "class_label": label}
                         for message, label in zip(messages, y)]
          return jsonify(predictions)
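
A quick way to exercise the endpoint from a client; the host, port, and payload values here are assumptions for illustration:

    # Hypothetical client call against the endpoint above, assuming the app
    # is running locally on port 5000.
    import requests

    payload = [
        {"chat_id": 1, "subject": "Lost Card", "body": "Can you send me a new card?"},
        {"chat_id": 2, "subject": "my transfer", "body": "My transfer of $10 isn't showing up."},
    ]
    resp = requests.post(
        "http://localhost:5000/chat-classification-api/messages", json=payload)
    print(resp.json())  # e.g. [{"chat_id": 1, "class_label": "urgent"}, ...]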
  16. Live Demo
  17. How do we take this to production?
  18. How do we take this to production?
  19. Step 1: Separate training and serving
  20. Model Persistence
      import pickle
      import boto3

      def write_to_s3(pipeline, key, bucket):
          s3_client = boto3.client("s3")
          kms_client = boto3.client("kms")
          pkl = pickle.dumps(pipeline)
          enc_pkl = my_encrypt_function(pkl, kms_client)
          s3_client.put_object(Bucket=bucket, Key=key, Body=enc_pkl,
                               ServerSideEncryption="AES256")
  21. Model Persistence
      import pickle
      import boto3
      from flask import current_app

      def load_message_classifier(app):
          conf = app.config["MESSAGE_CLASSIFIER"]
          s3_client = boto3.client("s3")
          kms_client = boto3.client("kms")
          resp = s3_client.get_object(Bucket=conf["bucket"], Key=conf["path"])
          untrusted_bytes = resp["Body"].read()
          pkl = decrypt(untrusted_bytes, kms_client)
          with app.app_context():
              current_app._message_classifier = pickle.loads(pkl)
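
With the model loaded onto the application object this way, the view from slide 15 can look the classifier up at request time instead of importing a module-level pipeline. A sketch, reusing the attribute name from slide 21 (the DataFrame conversion is an assumption about the preprocessor's input):

    # Sketch: the request handler pulls the unpickled pipeline off the app
    # context that load_message_classifier populated.
    import pandas as pd
    from flask import current_app, jsonify, request

    @app.route('/chat-classification-api/messages', methods=['POST'])
    def classify_messages():
        messages = request.get_json()
        classifier = current_app._message_classifier
        labels = classifier.predict(pd.DataFrame(messages))
        return jsonify([{"chat_id": m["chat_id"], "class_label": label}
                        for m, label in zip(messages, labels)])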
  22. Step 2: Provide an environment for batch training and evaluation
  23. Optimizing Parameter Values
      from sklearn.model_selection import GridSearchCV

      params = {
          'preprocess__subject_weight': (1, 2, 3, 4, 5),
          'text__stopwords': ([], IGNORE, PUNCTUATION),
          'vect__max_df': (0.5, 0.75, 1.0),
          'vect__ngram_range': ((1, 1), (1, 2)),
          'tfidf__use_idf': (True, False),
          'tfidf__norm': ('l1', 'l2'),
      }
      search = GridSearchCV(pipeline, params)
      search.fit(X_train, y_train)
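
Once the search finishes, the winning configuration and the refit pipeline are available on the search object; a short sketch of inspecting them and handing the best estimator to the persistence function from slide 20 (the bucket and key names are hypothetical):

    # Sketch: inspect the grid search results and persist the refit best pipeline.
    print(search.best_score_)   # mean cross-validated score of the best setting
    print(search.best_params_)  # e.g. {'preprocess__subject_weight': 2, ...}

    best_pipeline = search.best_estimator_  # refit on the full training set by default
    write_to_s3(best_pipeline, key="models/chat-classifier.pkl.enc",
                bucket="my-model-bucket")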
  24. Step 3: Monitor performance, adapt to production load, degrade gracefully
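
The abstract also calls out logging decisions; one way to cover both that and graceful degradation is to log every prediction and fall back to a default label when no model is loaded. A sketch, where the logger name and the 'other' fallback label are assumptions:

    # Sketch: log each decision and degrade gracefully when the model is missing.
    import logging

    import pandas as pd
    from flask import current_app

    logger = logging.getLogger("chat_classifier")

    def classify_with_fallback(messages):
        classifier = getattr(current_app, "_message_classifier", None)
        if classifier is None:
            logger.warning("no model loaded; falling back to 'other'")
            return ["other"] * len(messages)
        labels = classifier.predict(pd.DataFrame(messages))
        for message, label in zip(messages, labels):
            logger.info("classified chat_id=%s as %s", message["chat_id"], label)
        return list(labels)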
  25. Other Approaches
  26. Considerations
      • How big is your team?
      • How large of a problem space do you need to cover?
      • What is your existing stack?
  27. Off-the-Shelf
  28. Off-the-Shelf
  29. Off-the-Shelf
  30. Recap
      • Train and test in a batch environment
      • Output serialized model and classification report
      • sklearn.pipeline is convenient for storing code+params
      • Serve on-demand predictions separately
      • Treat this like any production service
  31. Thank You
  32. Questions ✨ 💖
  33. Machine Learning in Production with scikit-learn (Jeff Klukas - Data Engineer at Simple)
