Model reproducibility is becoming the next frontier for successful AI model building and deployment, in both research and production. In this talk we show how to build reproducible AI models and workflows using PyTorch and MLflow that can be shared across your teams, adding traceability and speeding up collaboration on AI projects.
4. ~1,619 CONTRIBUTORS
50%+ YOY GROWTH
34K+ PYTORCH FORUM USERS
5. GROWING USAGE IN OPEN SOURCE
Source: https://paperswithcode.com/trends
6. GROWTH OF DATA IN ML PIPELINES @ FACEBOOK
• 30% of FB data used in an ML pipeline in 2018
• 50% of FB data used in an ML pipeline today
• 3X data warehouse growth since 2018
• 2X ML data growth since 2018
7. GROWTH OF ML TRAINING @ FACEBOOK
• 5X increase in workflows
• 2X increase in unique users
• 8X increase in compute consumed
9. TRADITIONAL SOFTWARE VS MACHINE LEARNING
• Continuous, iterative process; optimize for a metric
• Quality depends on data and tuning parameters
• Experiment tracking is difficult
• Data changes over time, causing model drift
• Compare and combine many libraries and models
• Diverse deployment environments
10. REPRODUCIBILITY CHALLENGE
RESEARCH
• Difficult to reproduce the results of a paper
• Missing data, model weights, scripts
PRODUCTION
• Hyperparameters, features, data, vocabulary, and other artifacts lost
• People leaving the company
12. REPRODUCIBILITY CHECKLIST
• Dependencies — does a repository have information on
dependencies or instructions on how to set up the environment?
• Training scripts — does a repository contain a way to train/fit
the model(s) described in the paper?
• Evaluation scripts — does a repository contain a script to
calculate the performance of the trained model(s) or run
experiments on models?
• Pretrained models — does a repository provide free access to
pretrained model weights?
• Results — does a repository contain a table/plot of main results
and a script to reproduce those results?
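A checklist like this can be partially automated. Below is a minimal sketch that audits a local repository for the files each check looks for; the specific file names (requirements.txt, train.py, eval.py, results.md) are illustrative assumptions, not a standard.

```python
# Sketch: audit a repository directory against a reproducibility checklist.
# The candidate file names per check are assumptions for illustration.
import os

CHECKLIST = {
    "dependencies": ["requirements.txt", "environment.yml", "setup.py"],
    "training_script": ["train.py"],
    "evaluation_script": ["eval.py", "evaluate.py"],
    "results": ["results.md"],
}

def audit_repo(repo_path):
    """Return {check_name: bool} — True if any candidate file is present."""
    present = set(os.listdir(repo_path))
    return {
        check: any(name in present for name in candidates)
        for check, candidates in CHECKLIST.items()
    }
```

For example, `audit_repo(".")` on a repo containing only requirements.txt and train.py would flag the missing evaluation script and results.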
13. ARXIV + PWC → REPRODUCIBLE RESEARCH
https://medium.com/paperswithcode/papers-with-code-partners-with-arxiv-ecc362883167
15. MLFLOW + PYTORCH FOR REPRODUCIBILITY
• Tracking: record and query experiments: code, data, config, and results
• Projects: package data science code in a format that enables reproducible runs on many platforms
• Models: deploy machine learning models in diverse serving environments
• Model Registry: store, annotate, and manage models in a central repository
PyTorch integrations: auto logging, PyTorch examples w/ MLProjects, TorchScripted models with save/load of artifacts, MLflow TorchServe deployment plugin
16. MLFLOW AUTOLOGGING
• PyTorch auto logging with the Lightning training loop
• Model hyperparameters such as learning rate, model summary, optimizer name, min delta, best score
• Early stopping and other callbacks
• Log every N iterations
• User-defined metrics such as F1 score and test accuracy
• ….
from mlflow.pytorch.pytorch_autolog import autolog

parser = LightningMNISTClassifier.add_model_specific_args(parent_parser=parser)

autolog()  # just add this and your autologging should work!

mlflow.set_tracking_uri(dict_args["tracking_uri"])
model = LightningMNISTClassifier(**dict_args)

early_stopping = EarlyStopping(monitor="val_loss", mode="min", verbose=True)
checkpoint_callback = ModelCheckpoint(
    filepath=os.getcwd(), save_top_k=1, verbose=True,
    monitor="val_loss", mode="min", prefix="",
)
lr_logger = LearningRateLogger()

trainer = pl.Trainer.from_argparse_args(
    args,
    callbacks=[lr_logger],
    early_stop_callback=early_stopping,
    checkpoint_callback=checkpoint_callback,
    train_percent_check=0.1,
)
trainer.fit(model)
trainer.test()
18. SAVE ARTIFACTS
• Additional artifacts for model reproducibility
• For example: vocabulary files for NLP models, requirements.txt and other extra files for TorchServe deployment

mlflow.pytorch.save_model(
    model,
    path=args.model_save_path,
    requirements_file="requirements.txt",
    extra_files=["class_mapping.json", "bert_base_uncased_vocab.txt"],
)

:param requirements_file: An (optional) string containing the path to a requirements file.
    If ``None``, no requirements file is added to the model.
:param extra_files: An (optional) list containing the paths to corresponding extra files.
    For example, consider the following ``extra_files`` list::

        extra_files = ["s3://my-bucket/path/to/my_file1",
                       "s3://my-bucket/path/to/my_file2"]

    In this case, the ``my_file1`` and ``my_file2`` extra files are downloaded from S3.
    If ``None``, no extra files are added to the model.
19. TORCHSCRIPTED MODEL
• Log the TorchScripted model
• Serialize and optimize models for Python-free execution
• Recommended for production inference

mlflow.set_tracking_uri(dict_args["tracking_uri"])
model = LightningMNISTClassifier(**dict_args)

# Convert to a TorchScripted model
scripted_model = torch.jit.script(model)

mlflow.start_run()
# Log the scripted model using log_model
mlflow.pytorch.log_model(scripted_model, "scripted_model")

# If you need to reload the model, just call load_model
uri_path = mlflow.get_artifact_uri()
scripted_loaded_model = mlflow.pytorch.load_model(os.path.join(uri_path, "scripted_model"))
mlflow.end_run()
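Independent of MLflow, the scripting step itself can be sketched on a toy module (TinyNet below is a made-up example, not from the talk): torch.jit.script compiles the model into a serialized, self-describing archive that can be reloaded and run without the original Python class definition, e.g. from a C++ process via libtorch.

```python
import os
import tempfile

import torch

class TinyNet(torch.nn.Module):
    """Toy module used only to illustrate the TorchScript round trip."""
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(4, 2)

    def forward(self, x):
        return torch.relu(self.fc(x))

# Compile to TorchScript and save a self-contained archive.
scripted = torch.jit.script(TinyNet())
path = os.path.join(tempfile.gettempdir(), "tiny_net.pt")
scripted.save(path)

# Reload and run without the TinyNet class being needed at all.
loaded = torch.jit.load(path)
out = loaded(torch.ones(1, 4))
```

The same round trip is what makes the logged "scripted_model" artifact portable across serving environments.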
20. TORCHSERVE
• Default handlers for common use cases (e.g., image segmentation, text classification), support for custom handlers for other use cases, and a model zoo
• Multi-model serving, model versioning, and the ability to roll back to an earlier version
• Automatic batching of individual inferences across HTTP requests
• Logging of common metrics, plus the ability to incorporate custom metrics
• Robust HTTP APIs for management and inference
21. DEPLOYMENT PLUGIN
New TorchServe deployment plugin: test models during the development cycle by pulling models from the MLflow model repository and running them against a local or remote TorchServe instance.
• CLI
• Python API

mlflow deployments predict --name mnist_test --target torchserve --input_path sample.json --output_path output.json

import os
import matplotlib.pyplot as plt
from torchvision import transforms
from mlflow.deployments import get_deploy_client

img = plt.imread(os.path.join(os.getcwd(), "test_data/one.png"))
mnist_transforms = transforms.Compose([transforms.ToTensor()])
image = mnist_transforms(img)

plugin = get_deploy_client("torchserve")
config = {
    "MODEL_FILE": "mnist_model.py",
    "HANDLER_FILE": "mnist_handler.py",
}
plugin.create_deployment(name="mnist_test", model_uri="mnist_cnn.pt", config=config)
prediction = plugin.predict("mnist_test", image)
23. RESEARCH TO PRODUCTION CYCLE @ FACEBOOK
Research (PyText): new idea / paper → model authoring → training → evaluation → parameter sweeping, producing a PyTorch model served from a Python service for small-scale metrics
Production (PyText): export to TorchScript → export validation → performance tuning, producing a PyTorch TorchScript model served from a C++ inference service