SPARK SUMMIT
EUROPE 2016
SCALING FACTORIZATION
MACHINES ON APACHE SPARK
WITH PARAMETER SERVERS
Nick Pentreath
Principal Engineer, IBM
About
• About me
– @MLnick
– Principal Engineer at IBM working on machine
learning & Spark
– Apache Spark PMC
– Author of Machine Learning with Spark
Agenda
• Brief Intro to Factorization Machines
• Distributed FMs with Spark and Glint
• Results
• Challenges
• Future Work
FACTORIZATION MACHINES
Factorization Machines
Feature interactions
Factorization Machines
Linear Models
𝑤" + $ 𝑤% 𝑥%
'
%()
Not expressive
enough!
w0
Bias terms
Factorization Machines
Polynomial Regression
𝑤" + $ 𝑤% 𝑥% + $ $ 𝑤%* 𝑥% 𝑥*
'
*(%+)
'
%()
'
%()
w0
Bias terms Interaction term
O(d2)
n << d
Factorization Machines
Factorization Machine
$w_0 + \sum_{i=1}^{d} w_i x_i + \sum_{i=1}^{d} \sum_{j=i+1}^{d} \langle v_i, v_j \rangle x_i x_j$
w_0: bias term; ⟨v_i, v_j⟩ x_i x_j: factorized interaction term, with factor vectors v_i ∈ R^k
Naive evaluation: O(d^2 k)
Factorization Machines
Factorization Machine
O(d^2 k) → O(dk) via a reformulation of the interaction term ("math hack")
Not convex, but efficient to train using SGD, coordinate descent, or MCMC
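The reformulation in question is the standard identity from Rendle's FM paper, which lets the pairwise term be computed in a single pass over the features:

$$\sum_{i=1}^{d} \sum_{j=i+1}^{d} \langle v_i, v_j \rangle x_i x_j = \frac{1}{2} \sum_{f=1}^{k} \left[ \left( \sum_{i=1}^{d} v_{i,f}\, x_i \right)^2 - \sum_{i=1}^{d} v_{i,f}^2\, x_i^2 \right]$$

For a sparse example only the non-zero x_i contribute, so each prediction costs O(nnz(x) · k) rather than O(d^2 k).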
Factorization Machines
Factorization Machine
With one-hot user and item indicator features, an FM recovers standard matrix factorization (with biases)
Extra columns can carry contextual features and metadata features
Model size can still be very large! e.g. video sharing, online ads, social networks
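To make the reduction above concrete (an observation from Rendle's FM paper): if x contains exactly two non-zero entries, the indicators for user u and item i, the FM prediction collapses to biased matrix factorization:

$$\hat{y}(x) = w_0 + w_u + w_i + \langle v_u, v_i \rangle$$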
DISTRIBUTED FM MODELS ON
SPARK
Linear Models on Spark
• Data parallel
– Master broadcasts the full model to the workers at each iteration
– Each worker computes partial gradients over its partitions using the latest full model
– Partial gradients are tree-aggregated back to the master, which updates the model
– After the final iteration the master holds the final model
– Broadcasting and tree-aggregating the full model every iteration is the bottleneck!
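A minimal sketch of this data-parallel pattern follows. It mirrors how MLlib's GradientDescent drives an iteration (broadcast, per-partition partial gradients, treeAggregate), but the step function and the per-example gradient argument here are illustrative stand-ins rather than MLlib's actual code:

import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

// One synchronous, data-parallel gradient-descent iteration (sketch).
// `gradient` computes a per-example gradient given the current weights.
def step(
    data: RDD[LabeledPoint],
    weights: BDV[Double],
    gradient: (BDV[Double], LabeledPoint) => BDV[Double],
    stepSize: Double): BDV[Double] = {
  // Broadcast the FULL model to every worker: this is the bottleneck
  // once the model no longer comfortably fits in a single broadcast.
  val bcWeights = data.sparkContext.broadcast(weights)

  // Each worker sums per-example gradients over its partitions; the
  // partial sums are then combined up a tree back to the driver.
  val (gradSum, count) =
    data.treeAggregate((BDV.zeros[Double](weights.length), 0L))(
      seqOp = { case ((acc, n), point) => (acc + gradient(bcWeights.value, point), n + 1) },
      combOp = { case ((g1, n1), (g2, n2)) => (g1 + g2, n1 + n2) }
    )

  bcWeights.destroy()
  // Driver-side model update with the averaged gradient.
  weights - gradSum * (stepSize / math.max(count, 1L).toDouble)
}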
Parameter Servers
• Model & data parallel
– The model is sharded across parameter servers; the data stays partitioned across workers; the master only schedules work
– Each worker pulls the partial model covering the features in its partition, computes partial gradients using that partial model, and pushes partial gradients back
– The parameter servers apply the pushed updates to their model shards
– Only the required features are pushed/pulled
Distributed FMs
• spark-libFM
– Uses old MLlib GradientDescent and LBFGS interfaces
• DiFacto
– Async SGD implementation using parameter server (ps-lite)
– Adagrad, L1 regularization, frequency-adaptive model-size
• The key is that most real-world datasets are highly sparse
(especially high-cardinality categorical data), e.g. online
ads, social networks, recommender systems
• Workers only need access to a small piece of the model
GlintFM
• Procedure:
1. Construct Glint Client
2. Create distributed
parameters
3. Pre-compute required
feature indices (per
partition)
4. Iterate:
• Pull partial model
(blocking)
• Compute partial gradient
& update
• Push partial update to
parameter servers (can
be async)
5. Done!
Glint is a parameter server built using Akka; a minimal sketch of steps 3 and 4 is shown below
For more info, see Rolf’s talk at 17:15!
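A minimal sketch of steps 3 and 4, as they would run inside a Spark mapPartitions/foreachPartition closure. The PsVector/PsMatrix traits and the fmUpdate argument below are hypothetical stand-ins for Glint's BigVector/BigMatrix handles and for the FM gradient computation (the real Glint API differs in detail); treat this as a sketch of the flow, not the glint-fm implementation:

import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

// Hypothetical parameter-server handles, standing in for Glint's
// BigVector[Double] / BigMatrix[Double].
trait PsVector {
  def pull(keys: Array[Long]): Future[Array[Double]]
  def push(keys: Array[Long], values: Array[Double]): Future[Boolean]
}
trait PsMatrix {
  def pull(rows: Array[Long]): Future[Array[Array[Double]]]
  def push(rows: Array[Long], values: Array[Array[Double]]): Future[Boolean]
}

// One pass over a single partition: pull -> compute -> push.
// `examples` holds (label, sparse features as (index, value) pairs);
// `fmUpdate` (hypothetical) computes deltas for w and V from the pulled partial model.
def trainPartition(
    examples: Array[(Double, Array[(Int, Double)])],
    linear: PsVector,   // w, length d
    factors: PsMatrix,  // V, d x k
    fmUpdate: (Array[(Double, Array[(Int, Double)])], Array[Long],
               Array[Double], Array[Array[Double]]) => (Array[Double], Array[Array[Double]])
): Unit = {
  // 3. Pre-compute the feature indices this partition actually touches.
  val activeIdx: Array[Long] =
    examples.flatMap(_._2.map(_._1.toLong)).distinct.sorted

  // 4a. Pull only the required slice of the model (blocking here).
  val w = Await.result(linear.pull(activeIdx), 5.minutes)
  val v = Await.result(factors.pull(activeIdx), 5.minutes)

  // 4b. Compute the partial gradient / local update against the partial model.
  val (wDelta, vDelta) = fmUpdate(examples, activeIdx, w, v)

  // 4c. Push only the updates for the active indices (can be left asynchronous).
  linear.push(activeIdx, wDelta)
  factors.push(activeIdx, vDelta)
}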
RESULTS
Data
Criteo Display Advertising Challenge Dataset
• 45m examples, 34m unique features, ~48 non-zero features per example
[Charts: unique values per feature (millions) and feature occurrence (%)]
Feature Extraction
Raw Data → StringIndexer → OneHotEncoder → VectorAssembler: OOM!
Feature Extraction
Solution? “Stringify” + CountVectorizer
Row(i1=u'1', i2=u'1', i3=u'5', i4=u'0', i5=u'1382', ...)
Row(raw=[u'i1=1', u'i2=1', u'i3=5', u'i4=0', u'i5=1382', ...])
Convert the set of String features into a Seq[String]
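A rough sketch of the "stringify" step plus CountVectorizer in the Spark ML Pipelines API. The column names, the treatment of the label column, and the vocabulary size are assumptions for illustration, not the exact code used in these experiments:

import org.apache.spark.ml.feature.CountVectorizer
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions._

// "Stringify": turn every raw column into a "colName=value" token and collect
// the tokens into a single array column (label column handling omitted).
def stringify(raw: DataFrame): DataFrame = {
  val tokens = raw.columns.map(c => concat_ws("=", lit(c), col(c)))
  raw.withColumn("raw", array(tokens: _*))
}

// CountVectorizer builds the vocabulary and emits wide, sparse vectors.
val cv = new CountVectorizer()
  .setInputCol("raw")
  .setOutputCol("features")
  .setVocabSize(40 * 1000 * 1000)  // large enough for ~34m unique features (assumption)

// Usage (rawDF is the raw Criteo DataFrame):
// val stringified = stringify(rawDF)
// val features = cv.fit(stringified).transform(stringified)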
Feature Extraction
Raw Data → Stringify → CountVectorizer
Feature Extraction
Raw Data → Stringify → HashingTF
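Swapping CountVectorizer for HashingTF avoids the vocabulary fit entirely and hashes tokens straight to indices. The numFeatures value below is an assumption for illustration, not necessarily the one used in the talk:

import org.apache.spark.ml.feature.HashingTF

val hashingTF = new HashingTF()
  .setInputCol("raw")
  .setOutputCol("features")
  .setNumFeatures(1 << 26)  // ~67m hash buckets, comparable to the ~34m unique features

// val features = hashingTF.transform(stringify(rawDF))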
Performance
[Chart: total run time (s) for MLlib FM vs Glint FM at k = 0, 6, 16, 32; MLlib FM model sizes of 1.9GB and 4.6GB shown against Spark's 2GB limit, with MLlib FM N/A at the larger k]
*10 iterations, 48 partitions, fit intercept
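The model sizes above are consistent with a dense layout of roughly (k + 1) · d 8-byte doubles for the linear weights plus the d × k factor matrix (d ≈ 34m), which is why a full-model broadcast collides with Spark's 2GB block limit; a rough check:

$$k = 6:\ 34\text{m} \times (6 + 1) \times 8\ \text{bytes} \approx 1.9\ \text{GB} \qquad k = 16:\ 34\text{m} \times (16 + 1) \times 8\ \text{bytes} \approx 4.6\ \text{GB}$$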
Performance
[Timeline: MLlib FM iteration breakdown into broadcast read, gradient computation, and aggregation & collect]
*k = 6, 10 iterations, 48 partitions, fit intercept
Performance
[Charts: median time per iteration (s) for MLlib FM vs Glint FM, broken down into communication, compute and other; and data transfer (MB / iteration) for MLlib FM vs Glint FM]
*k = 6, 10 iterations, 48 partitions, fit intercept
CHALLENGES & FUTURE WORK
Challenges
• Tuning configuration
– Glint: models per parameter server, message size, Akka frame size
– Spark: data partitioning (can be seen as “mini-batch” size)
• Lack of server-side processing in Glint
– For L1 regularization, adaptive sparsity, Adagrad
– These result in better performance, faster execution
• Backpressure / concurrency control
Challenges
• Tuning models / server
[Charts: total runtime (s), and median / max iteration time (s), vs. models per parameter server (6, 12, 24)]
*k = 16, 10 iterations, 48 partitions, fit intercept
Challenges
• Index partitioning for “hot features”
[Chart: total run time (s), CountVectorizer features vs hashed features]
– CountVectorizer features lead to hot spots & straggler tasks, because indices are assigned by sorting on occurrence
– OneHotEncoder OOMed... but can also face this problem
– Spreading out features is critical (used feature hashing)
*k = 16, 10 iterations, 48 partitions, fit intercept
Future Work
• Glint enhancements
– Add features from DiFacto, i.e. L1 regularization, Adagrad &
memory-adaptive k
– Requires support for UDFs on the server
– Built-in backpressure (Akka Artery / Streams?)
– Key caching – 2x decrease in message size
• Mini-batch SGD within partitions
• Distributed solvers for ALS, MCMC, CD
• Relational data / block structure formulation
– www.vldb.org/pvldb/vol6/p337-rendle.pdf
References
• Factorization Machines
– http://www.csie.ntu.edu.tw/~b97053/paper/Rendle2010FM.pdf
– https://github.com/ibayer/fastFM
– www.libfm.org
– https://github.com/zhengruifeng/spark-libFM
– https://github.com/scikit-learn-contrib/polylearn
• DiFacto
– https://github.com/dmlc/difacto
– www.cs.cmu.edu/~yuxiangw/docs/fm.pdf
• Glint / parameter servers
– https://github.com/rjagerman/glint
– https://github.com/dmlc/ps-lite
SPARK SUMMIT
EUROPE 2016
THANK YOU.
@MLnick
github.com/MLnick/glint-fm
spark.tc
