Real Time Machine Learning Visualization with Spark
Chester Chen, Ph.D.
Sr. Manager, Data Science & Engineering
GoPro, Inc.
Hadoop Summit, San Jose 2016
2. Who am I?
• Sr. Manager of Data Science & Engineering at GoPro
• Founder and Organizer of SF Big Analytics Meetup (4500+ members)
• Previous Employment:
– Alpine Data, Tinga, Clearwell/Symantec, AltaVista, Ascent Media, ClearStory Systems, WebWare.
• Experience with Spark
– Exposed to Spark since Spark 0.6
– Architect for Alpine Spark Integration on Spark 1.1, 1.3 and 1.5.x
• Hadoop Distribution
– CDH, HDP and MapR
8. What is K-Means?
• Given a set of observations (x1, x2, …, xn), where each observation is a d-dimensional real vector,
• k-means clustering aims to partition the n observations into k (≤ n) sets S = {S1, S2, …, Sk}
• The clusters are determined by minimizing the within-cluster sum of squares (WCSS), i.e. the sum of the squared distances from each point in a cluster to that cluster's center. In other words, the objective is to find
$$\underset{S}{\arg\min} \sum_{i=1}^{k} \sum_{x \in S_i} \lVert x - \mu_i \rVert^2$$
• where μi is the mean of points in Si.
• https://en.wikipedia.org/wiki/K-means_clustering
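To make the objective concrete, here is a minimal Spark-free sketch in Scala that scores a set of centers by the within-cluster sum of squares, assigning each point to its nearest center. Points are plain `Array[Double]` values, and the helper names (`squaredDistance`, `closestCenter`, `wcss`) are hypothetical, not Spark MLlib APIs.

```scala
// Minimal, Spark-free sketch of the k-means objective (WCSS).
// Points and centers are plain Array[Double]; all helper names are hypothetical.

// Squared Euclidean distance between two points of the same dimension.
def squaredDistance(a: Array[Double], b: Array[Double]): Double = {
  require(a.length == b.length, "dimension mismatch")
  a.indices.map { i => val d = a(i) - b(i); d * d }.sum
}

// Index of the nearest center for a point (this defines which S_i the point belongs to).
def closestCenter(p: Array[Double], centers: Array[Array[Double]]): Int =
  centers.indices.minBy(i => squaredDistance(p, centers(i)))

// Sum over all points of the squared distance to their nearest center.
def wcss(points: Seq[Array[Double]], centers: Array[Array[Double]]): Double =
  points.map(p => squaredDistance(p, centers(closestCenter(p, centers)))).sum

val points = Seq(Array(0.0, 0.0), Array(0.0, 2.0), Array(10.0, 0.0), Array(10.0, 2.0))
val centers = Array(Array(0.0, 1.0), Array(10.0, 1.0))
println(wcss(points, centers)) // 4.0: each point is at squared distance 1 from its center
```

When the centers are the cluster means μi, this is exactly the quantity the objective above minimizes.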
10. Real Time ML Visualization
• Use Cases
– Use visualization to determine whether to end the training early
• Need a way to visualize the training process, including the convergence, clustering, or residual plots, etc.
• Need a way to stop the training and save the current model
• Need a way to disable or enable the visualization
12. How to Enable Real Time ML Visualization?
• A callback interface for Spark machine learning algorithms to send messages
– Algorithms decide when and what message to send
– Algorithms don’t care how the message is delivered
• A task channel to handle message delivery from the Spark Driver to the Spark Client
– It doesn’t care about the content of the message or who sent it
• Messages are delivered from the Spark Client to the browser
– We use HTML5 Server-Sent Events (SSE) and HTTP chunked response (push)
– Pull is possible, but requires a message queue
• Visualization using the JavaScript frameworks Plot.ly and D3
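The split between the callback interface and the task channel can be sketched in plain Scala. This is a minimal illustration under stated assumptions: `Listener`, `InProcessChannel`, and `runToyTraining` are hypothetical names, and an in-process `LinkedBlockingQueue` stands in for the Akka-based channel. The algorithm decides when and what to send; the channel only stores and forwards opaque messages.

```scala
import java.util.concurrent.LinkedBlockingQueue

// Callback side (hypothetical): algorithms only know they can emit messages.
trait Listener {
  def onMessage(message: Any): Unit
}

// Delivery side (hypothetical): an in-process stand-in for the Akka task channel.
// It stores and forwards messages without ever inspecting their content.
final class InProcessChannel extends Listener {
  private val queue = new LinkedBlockingQueue[Any]()

  override def onMessage(message: Any): Unit = queue.put(message)

  // Consumer side (e.g. the Spark Client): take everything that has arrived so far, in order.
  def drain(): List[Any] = {
    var out = List.empty[Any]
    var next = queue.poll()
    while (next != null) {
      out = next :: out
      next = queue.poll()
    }
    out.reverse
  }
}

// Toy "algorithm": it decides when and what to send, but not how it is delivered.
def runToyTraining(iterations: Int, listeners: Seq[Listener]): Unit =
  for (i <- 1 to iterations) {
    val cost = 100.0 / i // placeholder metric, not a real model cost
    listeners.foreach(l => l.onMessage((i, cost)))
  }

val channel = new InProcessChannel
runToyTraining(3, Seq(channel))
println(channel.drain())
```

Because the channel's only contract is "accept any message", a different transport can replace the queue without touching the algorithm.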
13. Spark Job in Yarn-Cluster mode
[Diagram components: Application Host (Command Line, Rest API, Servlet); Spark Client; Hadoop Cluster with a Yarn-Container running the Spark Driver (Spark Job, Spark Context, Spark ML algorithm).]
14. Spark Job in Yarn-Cluster mode
[Diagram components: Application Host (Command Line, Rest API, Servlet); Spark Client; Hadoop Cluster running the Spark Job (App Context, Spark ML Algorithms, ML Listener, Message Logger).]
16. Enable Real Time ML Visualization
[Diagram components: Browser (Plotly, D3) receiving SSE; Web Server hosting the Rest API Server; Spark Client; Hadoop Cluster running the Spark Job (App Context, Message Logger, Task Channel, Spark ML Algorithms, ML Listener). Connections are labeled Akka and Chunked Response.]
20. Callback Interface: MLListenerSupport
trait MLListenerSupport {
  // rest of code
  def sendMessage(message: => Any): Unit = {
    if (listenerEnabled()) {
      listeners.foreach(l => l.onMessage(message))
    }
  }
}
21. KMeansExt: KMeans with MLListener
class KMeansExt private (…) extends Serializable
    with Logging
    with MLListenerSupport {
  ...
}
22. KMeansExt: KMeans with MLListener
case class KMeansCoreStats(iteration: Int, centers: Array[Vector], cost: Double)

private def runAlgorithm(data: RDD[VectorWithNorm]): KMeansModel = {
  ...
  while (!stopIteration &&
      iteration < maxIterations && !activeRuns.isEmpty) {
    ...
    if (listenerEnabled()) {
      sendMessage(KMeansCoreStats(…))
    }
    ...
  }
}
23. KMeans ML Listener
class KMeansListener(columnNames: List[String],
                     data: RDD[Vector],
                     logger: MessageLogger) extends MLListener {
  var sampleDataOpt: Option[Array[Vector]] = None

  override def onMessage(message: => Any): Unit = {
    message match {
      case coreStats: KMeansCoreStats =>
        if (sampleDataOpt.isEmpty)
          sampleDataOpt = Some(data.takeSample(withReplacement = false, num = 100))
        // use the KMeans model of the current iteration to predict sample cluster indexes
        val kMeansModel = new KMeansModel(coreStats.centers)
        val cluster = sampleDataOpt.get.map(vector => (vector.toArray, kMeansModel.predict(vector)))
        val msg = KMeansStats(…)
        logger.sendBroadCastMessage(MLConstants.KMEANS_CENTER, msg)
      case _ =>
        println("message lost")
    }
  }
}
24. KMeans Spark Job Setup
val appCtxOpt: Option[ApplicationContext] = …
val kMeans = new KMeansExt().setK(numClusters)
  .setEpsilon(epsilon)
  .setMaxIterations(maxIterations)
  .enableListener(enableVisualization)
  .addListener(
    new KMeansListener(...))
appCtxOpt.foreach(_.addTaskObserver(new MLTaskObserver(kMeans, logger)))
kMeans.run(vectors)
25. ML Task Observer
• Receives commands from the user to update a running Spark job
• Once it receives an UpdateTaskCmd through a notify call, it performs the necessary update operation
trait TaskObserver {
  def notify(task: UpdateTaskCmd): Unit
}
class MLTaskObserver(support: MLListenerSupport, logger: MessageLogger)
    extends TaskObserver {
  // implement notify
}
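One way a stop command could reach the training loop is through a volatile flag that the loop checks on every iteration, as the `!stopIteration` condition in `runAlgorithm` suggests. The sketch below is hypothetical: the talk does not show how `UpdateTaskCmd` is defined, so the `StopTraining` and `SetVisualization` commands, `StoppableTrainer`, and `TrainerTaskObserver` are illustrative stand-ins. The stop is issued from the loop's own callback to keep the example deterministic; in the real system it would arrive from another thread via the task channel.

```scala
// Hypothetical command types: the talk does not show how UpdateTaskCmd is defined.
sealed trait UpdateTaskCmd
case object StopTraining extends UpdateTaskCmd
final case class SetVisualization(enabled: Boolean) extends UpdateTaskCmd

// Hypothetical stand-in for the algorithm: exposes the flags its loop reads.
// @volatile so a stop requested from another thread becomes visible to the loop.
class StoppableTrainer {
  @volatile var stopIteration: Boolean = false
  @volatile var listenerEnabled: Boolean = true

  // Iterates until maxIterations or a stop request; returns the number of iterations run.
  def run(maxIterations: Int, onIteration: Int => Unit = _ => ()): Int = {
    var iteration = 0
    while (!stopIteration && iteration < maxIterations) {
      iteration += 1
      onIteration(iteration)
    }
    iteration
  }
}

trait TaskObserver {
  def notify(task: UpdateTaskCmd): Unit
}

// Applies each command to the running trainer.
class TrainerTaskObserver(trainer: StoppableTrainer) extends TaskObserver {
  override def notify(task: UpdateTaskCmd): Unit = task match {
    case StopTraining              => trainer.stopIteration = true
    case SetVisualization(enabled) => trainer.listenerEnabled = enabled
  }
}

val trainer = new StoppableTrainer
val observer = new TrainerTaskObserver(trainer)
// Simulate a user pressing "stop" during the 5th iteration.
val ran = trainer.run(maxIterations = 100, onIteration = i => if (i == 5) observer.notify(StopTraining))
println(s"stopped after $ran iterations")
```

Saving the current model would happen after `run` returns, using whatever centers the loop held when the flag was observed.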
33. HTTP Chunked Response and SSE
[Diagram components: Browser (Plotly, D3) receiving SSE; Web Server hosting the Rest API Server; Spark Client; Hadoop Cluster running the Spark Job (App Context, Message Logger, Task Channel, Spark ML Algorithms, ML Listener). Connections are labeled Akka and Chunked Response.]
34. HTML5 Server-Sent Events (SSE)
• Server-Sent Events (SSE) is one-way messaging from the server to the browser
– The web page automatically receives updates from the server
• Register an event source (JavaScript):
var source = new EventSource(url);
• Handle each message in the onmessage callback:
source.onmessage = function(message){...}
• Data format (each line ends with \n; an empty line, i.e. \n\n, ends the event):
data: {\n
data: "key": "value"\n
data: }\n\n
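Framing a (possibly multi-line) payload as one SSE event is mechanical: each payload line becomes a `data:` line, and an empty line terminates the event; the browser's `EventSource` joins the data lines with newlines before calling `onmessage`. The Scala sketch below shows a server-side helper, `sseEvent`, and a deliberately simplified parser, `sseData`, for checking the round trip. Both names are hypothetical, and the parser ignores the other SSE fields (`event:`, `id:`, `retry:`).

```scala
// Frames a (possibly multi-line) payload as one SSE event: each payload line
// becomes a "data: " line, and the final extra "\n" (an empty line) ends the event.
def sseEvent(payload: String): String =
  payload.split("\r\n|\r|\n", -1).map(line => s"data: $line\n").mkString + "\n"

// Simplified client-side view (hypothetical, data field only): collect the data
// lines and join them with "\n", which is what EventSource passes to onmessage.
def sseData(event: String): String =
  event.split("\n")
    .filter(_.startsWith("data:"))
    .map(_.stripPrefix("data:").stripPrefix(" "))
    .mkString("\n")

val payload = "{\n\"key\": \"value\"\n}"
print(sseEvent(payload))
// data: {
// data: "key": "value"
// data: }
// (followed by an empty line)
```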
36. Push vs. Pull
Push
• Pros
– The data is streamed (pushed) to the browser via chunked response
– There is no need for a data queue, but data can be lost if it is not consumed
– Multiple pages can be pushed at the same time, which allows multiple visualization views
• Cons
– With a slow network, a slow browser, and fast data iterations, the data might all show up in the browser at once, rather than as a nice iteration-by-iteration display
– If you gate the chunked response on network acknowledgements, the visualization may not show up at all, because data is not pushed while acknowledgements are slow
37. Push vs. Pull
Pull
• Pros
– Messages do not get lost, since they can be temporarily stored in the message queue
– The visualization renders at an even pace
• Cons
– The browser needs to periodically send requests to the server for updates
– A message queue is needed to hold messages until they are consumed
– Hard to support rendering on multiple pages with a simple message queue
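The multi-page drawback of a simple queue can be eased by keeping an append-only log and a separate read offset per page, so each page pulls its own copy at its own pace. This is a minimal hypothetical sketch (`MessageLog`, `append`, and `poll` are illustrative names, not part of the system described in this talk). It keeps every message in memory, so a real implementation would also need to trim messages that all pages have consumed.

```scala
import scala.collection.mutable

// Hypothetical pull-mode buffer: an append-only log plus one read offset per page,
// so several pages can poll the same stream independently and at their own pace.
final class MessageLog[A] {
  private val log = mutable.ArrayBuffer.empty[A]
  private val offsets = mutable.Map.empty[String, Int]

  // Producer side: the server appends each message as it arrives.
  def append(msg: A): Unit = synchronized { log += msg }

  // Consumer side: returns up to maxBatch messages this page has not seen yet
  // and advances the page's offset past them.
  def poll(pageId: String, maxBatch: Int): List[A] = synchronized {
    val from = offsets.getOrElse(pageId, 0)
    val batch = log.slice(from, from + maxBatch).toList
    offsets(pageId) = from + batch.size
    batch
  }
}

val messages = new MessageLog[Int]
(1 to 5).foreach(messages.append)
println(messages.poll("page-a", 2))  // List(1, 2)
println(messages.poll("page-b", 10)) // List(1, 2, 3, 4, 5)
println(messages.poll("page-a", 10)) // List(3, 4, 5)
```

Polling with a small `maxBatch` on a fixed timer is what gives pull its even rendering pace.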
38. Visualization: Plot.ly + D3
[Screenshots: Cost vs. Iteration plots; ArrTime vs. Distance and ArrTime vs. DepTime plots; Alpine Workflow]
39. Use Plot.ly to render the graph
function showCost(dataParsed) {
  var costTrace = { … };
  var data = [ costTrace ];
  var costLayout = {
    xaxis: {…},
    yaxis: {…},
    title: …
  };
  Plotly.newPlot('cost', data, costLayout);
}
40. Real Time ML Visualization: Summary
• Training a machine learning model involves a lot of experimentation, so we need a way to visualize the training process.
• We presented a system to enable real time machine learning visualization with Spark:
– Gives visibility into the training of a model
– Allows us to monitor the convergence of the algorithms during training
– Can stop the iterations when convergence is good enough
Here’s what we saw…
- Business was indeed growing, the product line was expanding in number and sophistication, BUT we were becoming more than a camera company.
- We had a growing ecosystem of software and services
- We had a rich media side of the business that was growing in social and various media distribution channels
- We’re moving now into advanced capture
- And with drones, entirely new categories
- All of this leads to the Big Data landscape that we have today.
So, we brought together a team of bad asses from companies like LinkedIn, Apple, Oracle, and Splice Machine to tackle the problem
Thus the Data Science and Engineering team at GoPro was formed
Steps:
Choose initial centers
Assign each point to the centroid at minimum distance d, then compute each new center as the mean of its assigned points
Converge when the centroids no longer change
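The steps above are the standard (Lloyd's) k-means iteration. Here is a minimal Spark-free Scala sketch, assuming points are plain `Array[Double]` values and that empty clusters keep their previous center. `lloyd`, `mean`, and `squaredDistance` are hypothetical helpers, and `epsilon` plays a role similar to the `setEpsilon` parameter in the job setup: a center counts as unchanged if it moves no more than `epsilon`.

```scala
// Squared Euclidean distance between two points of the same dimension.
def squaredDistance(a: Array[Double], b: Array[Double]): Double =
  a.indices.map { i => val d = a(i) - b(i); d * d }.sum

// Component-wise mean of a non-empty group of points.
def mean(ps: Seq[Array[Double]]): Array[Double] =
  ps.head.indices.map(j => ps.map(p => p(j)).sum / ps.size).toArray

// Runs assignment + update steps until no center moves more than epsilon,
// or until maxIterations is reached. Returns the centers and the iterations run.
def lloyd(points: Seq[Array[Double]],
          initial: Array[Array[Double]],
          maxIterations: Int,
          epsilon: Double = 1e-9): (Array[Array[Double]], Int) = {
  var centers = initial
  var iteration = 0
  var converged = false
  while (!converged && iteration < maxIterations) {
    iteration += 1
    // Assignment step: group points by the index of their nearest center.
    val assigned = points.groupBy(p => centers.indices.minBy(i => squaredDistance(p, centers(i))))
    // Update step: move each center to the mean of its points; empty clusters keep their center.
    val updated = centers.indices.map(i => assigned.get(i).map(mean).getOrElse(centers(i))).toArray
    // Convergence: every center moved by at most epsilon.
    converged = centers.indices.forall(i => squaredDistance(centers(i), updated(i)) <= epsilon * epsilon)
    centers = updated
  }
  (centers, iteration)
}

val points = Seq(Array(0.0, 0.0), Array(0.0, 2.0), Array(10.0, 0.0), Array(10.0, 2.0))
val result = lloyd(points, Array(Array(0.0, 0.0), Array(10.0, 0.0)), maxIterations = 20)
val finalCenters = result._1
val iterations = result._2
finalCenters.foreach(c => println(c.mkString("(", ", ", ")")))
println(s"iterations: $iterations")
```

Each pass through the `while` loop is the natural point to emit per-iteration stats to a listener.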
Once we define MLListenerSupport, we can gather stats at the initial step, at each iteration, and at the final step, and call:
sendMessage(gatherKMeansStats(/*…*/))