Modeling Electronic Health Records with Recurrent Neural Networks

Time series data is increasingly ubiquitous. This trend is especially obvious in health and wellness, with both the adoption of electronic health record (EHR) systems in hospitals and clinics and the proliferation of wearable sensors. In 2009, intensive care units in the United States treated nearly 55,000 patients per day, generating digital-health databases containing millions of individual measurements, most of those forming time series. In the first quarter of 2015 alone, over 11 million health-related wearables were shipped by vendors. Recording hundreds of measurements per day per user, these devices are fueling a health time series data explosion. As a result, we will need ever more sophisticated tools to unlock the true value of this data to improve the lives of patients worldwide.

Deep learning, specifically with recurrent neural networks (RNNs), has emerged as a central tool in a variety of complex temporal-modeling problems, such as speech recognition. However, RNNs are also among the most challenging models to work with, particularly outside the domains where they are widely applied. Josh Patterson, David Kale, and Zachary Lipton bring the open source deep learning library DL4J to bear on the challenge of analyzing clinical time series using RNNs. DL4J provides a reliable, efficient implementation of many deep learning models embedded within an enterprise-ready open source data ecosystem (e.g., Hadoop and Spark), making it well suited to complex clinical data. Josh, David, and Zachary offer an overview of deep learning and RNNs and explain how they are implemented in DL4J. They then demonstrate a workflow example that uses a pipeline based on DL4J and Canova to prepare publicly available clinical data from PhysioNet and apply the DL4J RNN.

Modeling Electronic Health Records with Recurrent Neural Networks

  1. Modeling electronic health records with recurrent neural networks. David C. Kale [1,2], Zachary C. Lipton [3], Josh Patterson [4]. STRATA, San Jose, 2016. [1] University of Southern California; [2] Virtual PICU, Children’s Hospital Los Angeles; [3] University of California, San Diego; [4] Patterson Consulting
  2. Outline • Machine (and deep) learning • Sequence learning with recurrent neural networks • Clinical sequence classification using LSTM RNNs • A real world case study using DL4J • Conclusion and looking forward
  3. We need functions, brah
  4. Various Inputs and Outputs. “Time underlies many interesting human behaviors” {0,1} {A,B,C…} captions, email mom, fire nukes, eject pop tart
  5. But how do we produce functions? We need a function for that…
  6. One function-generator: Programmers
  7. Which are expensive
  8. When/why does this fail? • Sometimes the correct function cannot be encoded a priori (what is spam?) • The optimal solution might change over time • Programmers are expensive
  9. Sometimes We Need to Learn These Functions From Data
  10. One Class of Learnable Functions: Feedforward Neural Network
  11. Artificial Neurons
  12. Activation Functions • At internal nodes, common choices for the activation function are the sigmoid, tanh, and ReLU functions (standard definitions sketched below). • At the output, the activation function can be linear (regression), sigmoid (multilabel classification), or softmax (multi-class classification)
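     For reference (not shown on the original slide), the standard definitions of the three internal activations named above:

        \sigma(x) = \frac{1}{1 + e^{-x}}, \qquad
        \tanh(x) = \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}, \qquad
        \mathrm{ReLU}(x) = \max(0, x)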
  13. Training with Backpropagation • Goal: calculate the rate of change of the loss function with respect to each parameter (weight) in the model • Update the weights by following the gradient (the update rule is sketched below):
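     The update equation on this slide did not survive extraction; the standard stochastic gradient descent rule it presumably showed, with learning rate \eta and loss \mathcal{L}:

        w \leftarrow w - \eta \, \frac{\partial \mathcal{L}}{\partial w}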
  14. Forward Pass
  15. Backward Pass
  16. Deep Networks • It used to be difficult (seemed impossible) to train nets with many hidden layers • TL;DR: Turns out we just needed to do everything 1000x faster…
  17. Outline • Machine (and deep) learning • Sequence learning with recurrent neural networks • Clinical sequence classification using LSTM RNNs • A real world case study using DL4J • Conclusion and looking forward
  18. Feedforward Nets work for Fixed-Size Data
  19. Less Suitable for Text
  20. We would like to capture temporal/sequential dynamics in the data • Standard approaches address sequential structure: Markov models, conditional random fields, linear dynamical systems • Problem: we desire a system that learns representations, captures nonlinear structure, and captures long-term sequential relationships.
  21. To Model Sequential Data: Recurrent Neural Networks
  22. Recurrent Net (Unfolded)
  23. Vanishing / Exploding Gradients
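     This slide’s figure did not survive extraction. The core issue, in a standard formulation: backpropagation through time multiplies one Jacobian per step, so with hidden states s_t,

        \frac{\partial \mathcal{L}}{\partial s_t}
          = \frac{\partial \mathcal{L}}{\partial s_T}
            \prod_{k=t+1}^{T} \frac{\partial s_k}{\partial s_{k-1}}

     If the Jacobian norms sit below 1 the product shrinks exponentially in T - t (vanishing gradients); above 1 it grows exponentially (exploding gradients).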
  24. LSTM Memory Cell (Hochreiter & Schmidhuber, 1997)
  25. Memory Cell with Forget Gate (Gers et al., 2000)
  26. LSTM Forward Pass
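     The forward-pass equations on this slide did not survive extraction; a standard formulation of the LSTM with a forget gate (the Gers et al., 2000 variant named on the previous slide), writing \odot for elementwise multiplication:

        i_t = \sigma(W_{xi} x_t + W_{hi} h_{t-1} + b_i)   % input gate
        f_t = \sigma(W_{xf} x_t + W_{hf} h_{t-1} + b_f)   % forget gate
        o_t = \sigma(W_{xo} x_t + W_{ho} h_{t-1} + b_o)   % output gate
        g_t = \tanh(W_{xg} x_t + W_{hg} h_{t-1} + b_g)    % candidate update
        c_t = f_t \odot c_{t-1} + i_t \odot g_t           % memory cell state
        h_t = o_t \odot \tanh(c_t)                        % hidden state / output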
  27. LSTM (full network)
  28. Large Scale Architecture [figure labels: standard supervised learning; image captioning; sentiment analysis; video captioning, natural language translation; part-of-speech tagging; generative models for text]
  29. Outline • Machine (and deep) learning • Sequence learning with recurrent neural networks • Clinical sequence classification using LSTM RNNs • A real world case study using DL4J • Conclusion and looking forward
  30. ICU data generated in hospital • Patient-level info (e.g., age, gender) • Physiologic measurements (e.g., heart rate) – Manually verified observations – High-frequency measurements – Waveforms • Lab results (e.g., glucose) • Clinical assessments (e.g., Glasgow Coma Scale) • Medications and treatments • Clinical notes • Diagnoses • Outcomes • Billing codes
  31. ICU data gathered in EHR • Patient-level info (e.g., age, gender) • Physiologic measurements (e.g., vital signs) – Manually verified observations – High-frequency measurements – Waveforms • Lab results (e.g., glucose) • Clinical assessments (e.g., Glasgow Coma Scale) • Medications and treatments • Clinical notes • Diagnoses (often buried in free-text notes) • Outcomes • Billing codes
  32. ICU data in our experiments • Patient-level info (e.g., age, gender) • Physiologic measurements (e.g., vital signs) – Manually verified observations – High-frequency measurements – Waveforms • Lab results (e.g., glucose) • Clinical assessments (e.g., cognitive function) • One treatment: mechanical ventilation • Clinical notes • Diagnoses (often buried in free-text notes) • Outcomes: in-hospital mortality • Billing codes
  33. Challenges: sampling rates, missingness • Sparse, irregular, unaligned sampling in time, across variables • Sample selection bias (e.g., more likely to record abnormal values) • Entire sequences (non-random) missing [Figure: HR, RR, and ETCO2 series recorded between admission and discharge; courtesy of Ben Marlin, UMass Amherst]
  34. Challenges: alignment, variable length • Observations begin at time of admission, not at onset of illness • Sequences vary in length from hours to weeks (or longer) • Variable dynamics across patients, even with same disease • Long-term dependencies: future state depends on earlier condition [Figure: two HR series of different lengths, each spanning admission to discharge; courtesy of Ben Marlin, UMass Amherst]
  35. PhysioNet Challenge 2012 • Task: predict mortality from only the first 48 hours of data • Classic models (SAPS, APACHE, PRISM): expert features + regression – Useful: quantifying illness at admission, standardized performance – Not accurate enough to be used for decision support • Each record includes – patient descriptors (age, gender, weight, height, unit) – irregular sequences of ~40 vitals and labs from the first 48 hours – one treatment variable: mechanical ventilation – binary outcome: in-hospital survival or mortality (~13% mortality) • Only 4000 labeled records publicly available (“set A”) – 4000 unlabeled records (“set B”) used for tuning during competition (we didn’t use) – 4000 test examples (“set C”) not available • Very challenging task: temporal outcome, unobserved treatment effects • Winning entry score: minimum(Precision, Recall) = 0.5353 • https://www.physionet.org/challenge/2012/
  36. PhysioNet Challenge 2012: predict in-hospital mortality from observations x_1, x_2, x_3, …, x_T during the first 48 hours of ICU stay. Solution: recurrent neural network (RNN)*

        p(y_mort = 1 | x_1, x_2, x_3, …, x_T) ≈ p(y_mort = 1 | s_T), with s_t = f(s_{t-1}, x_t)
        s_t = φ(W s_{t-1} + U x_t + b)
        y_t = σ(V s_t + c)

     • Efficient parameterization: s_t represents exponential # states vs. # nodes • Can encode (“remember”) longer histories • During learning, pass future info backward via backprop through time [Figure: unrolled RNN with states s_0, s_1, …, s_T, inputs x_1, x_2, …, x_T, and outputs y_1, y_2, …, y_T] * We actually use a long short-term memory network
  37. Outline • Machine (and deep) learning • Sequence learning with recurrent neural networks • Clinical sequence classification using LSTM RNNs • A real world case study using DL4J • Conclusion and looking forward
  38. PhysioNet Raw Data • Set-a – Directory of single files – One file per patient – 48 hours of ICU data • Format – Header line – 6 descriptor values at 00:00 (collected at admission) – 37 irregularly sampled columns over 48 hours

        Time,Parameter,Value
        00:00,RecordID,132601
        00:00,Age,74
        00:00,Gender,1
        00:00,Height,177.8
        00:00,ICUType,2
        00:00,Weight,75.9
        00:15,pH,7.39
        00:15,PaCO2,39
        00:15,PaO2,137
        00:56,pH,7.39
        00:56,PaCO2,37
        00:56,PaO2,222
        01:26,Urine,250
        01:26,Urine,635
        01:31,DiasABP,70
        01:31,FiO2,1
        01:31,HR,103
        01:31,MAP,94
        01:31,MechVent,1
        01:31,SysABP,154
        01:34,HCT,24.9
        01:34,Platelets,115
        01:34,WBC,16.4
        01:41,DiasABP,52
        01:41,HR,102
        01:41,MAP,65
        01:41,SysABP,95
        01:56,DiasABP,64
        01:56,GCS,3
        01:56,HR,104
        01:56,MAP,85
        01:56,SysABP,132
        …
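     A minimal Java sketch (our own illustration, not the talk’s actual DL4J/Canova pipeline) of reading one patient file of Time,Parameter,Value triples into per-parameter time series:

        import java.io.BufferedReader;
        import java.io.FileReader;
        import java.util.HashMap;
        import java.util.Map;
        import java.util.TreeMap;

        public class PhysioNetReader {
            public static void main(String[] args) throws Exception {
                // parameter -> (HH:MM timestamp -> value), sorted by timestamp;
                // zero-padded hours make lexicographic order correct over 48 hours
                Map<String, TreeMap<String, Double>> series = new HashMap<>();
                try (BufferedReader in = new BufferedReader(new FileReader(args[0]))) {
                    String line = in.readLine(); // skip header: Time,Parameter,Value
                    while ((line = in.readLine()) != null) {
                        String[] f = line.split(",");
                        if (f[1].equals("RecordID")) continue; // identifier, not a measurement
                        // duplicate timestamps overwrite; acceptable for a sketch
                        series.computeIfAbsent(f[1], k -> new TreeMap<>())
                              .put(f[0], Double.parseDouble(f[2]));
                    }
                }
                System.out.println(series.keySet()); // e.g., [Age, DiasABP, GCS, HR, ...]
            }
        }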
  39. Preparing Input Data • Input is a 3D tensor – Mini-batch as first dimension – Feature columns as second dimension – Time steps as third dimension • At a mini-batch size of 20, 43 columns, and 202 time steps, we have 20 × 43 × 202 = 173,720 values per input tensor (a sketch of allocating this shape follows)
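     A minimal ND4J sketch of allocating that shape (dimension sizes are the slide’s; the variable name is ours):

        import org.nd4j.linalg.api.ndarray.INDArray;
        import org.nd4j.linalg.factory.Nd4j;

        // [mini-batch, feature columns, time steps]
        INDArray input = Nd4j.zeros(20, 43, 202);
        System.out.println(input.length()); // 20 * 43 * 202 = 173720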
  40. A Single Training Example

     Vector columns with values (a single time step):

        albumin  0.0
        alp      1.0
        alt      0.5
        ast      0.0
        …

     A single training example gets the added dimension of time steps for each column:

        timestep   0    1    2    3    4   …
        albumin   0.0  0.0  0.5  0.0  0.0
        alp       0.0  0.1  0.0  0.0  0.2
        alt       0.0  0.0  0.0  0.9  0.0
        ast       0.0  0.0  0.0  0.0  0.4
        …
  41. PhysioNet Timeseries Vectorization

        @RELATION UnitTest_PhysioNet_Schema_ZUZUV
        @DELIMITER ,
        @MISSING_VALUE -1
        @ATTRIBUTE recordid NOMINAL DESCRIPTOR !SKIP !ZERO
        @ATTRIBUTE age NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
        @ATTRIBUTE gender NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !ZERO
        @ATTRIBUTE height NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
        @ATTRIBUTE weight NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !AVG
        @ATTRIBUTE icutype NUMERIC DESCRIPTOR !ZEROMEAN_ZEROUNITVARIANCE !ZERO
        @ATTRIBUTE albumin NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
        @ATTRIBUTE alp NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
        @ATTRIBUTE alt NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
        @ATTRIBUTE ast NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
        @ATTRIBUTE bilirubin NUMERIC TIMESERIES !ZEROMEAN_ZEROUNITVARIANCE !PAD_TAIL_WITH_ZEROS
        [ more … ]
  42. Uneven Time Steps and Masking

     Single input (columns + time steps):

        timestep   0    1    2    3    4   …
        albumin   0.0  0.0  0.5  0.0  0.0
        alp       0.0  0.1  0.0  0.0  0.0
        alt       0.0  0.0  0.0  0.9  0.0
        ast       0.0  0.0  0.0  0.0  0.0
        …

     Input mask (time steps only; a sketch of building such a mask follows):

        1.0  1.0  1.0  1.0  0.0  0.0
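     A minimal ND4J sketch of the masking idea (our own illustration, not the talk’s pipeline code): the mask has one row per example and one column per time step, with 1.0 marking real observations and 0.0 marking padding.

        import org.nd4j.linalg.api.ndarray.INDArray;
        import org.nd4j.linalg.factory.Nd4j;

        int miniBatch = 20, maxTimeSteps = 202;
        INDArray mask = Nd4j.zeros(miniBatch, maxTimeSteps);
        // Suppose example 0 has only 4 valid time steps; flag them in the mask.
        int example = 0, validSteps = 4;
        for (int t = 0; t < validSteps; t++) {
            mask.putScalar(new int[]{example, t}, 1.0);
        }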
  43. DL4J • “The Hadoop of Deep Learning” – Command-line driven – Java, Scala, and Python APIs – ASF 2.0 licensed • Java implementation – Parallelization (YARN, Spark) – GPU support, including multi-GPU per host • Runtime neutral – Local – Hadoop / YARN + Spark – AWS • https://github.com/deeplearning4j/deeplearning4j
  44. RNNs in DL4J

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).iterations(1)
            .learningRate(learningRate)
            .rmsDecay(0.95)
            .seed(12345)
            .regularization(true)
            .l2(0.001)
            .list(3)
            // Two stacked LSTM layers followed by a softmax output layer
            .layer(0, new GravesLSTM.Builder().nIn(iter.inputColumns()).nOut(lstmLayerSize)
                .updater(Updater.RMSPROP)
                .activation("tanh").weightInit(WeightInit.DISTRIBUTION)
                .dist(new UniformDistribution(-0.08, 0.08)).build())
            .layer(1, new GravesLSTM.Builder().nIn(lstmLayerSize).nOut(lstmLayerSize)
                .updater(Updater.RMSPROP)
                .activation("tanh").weightInit(WeightInit.DISTRIBUTION)
                .dist(new UniformDistribution(-0.08, 0.08)).build())
            .layer(2, new RnnOutputLayer.Builder(LossFunction.MCXENT).activation("softmax")
                .updater(Updater.RMSPROP)
                .nIn(lstmLayerSize).nOut(nOut).weightInit(WeightInit.DISTRIBUTION)
                .dist(new UniformDistribution(-0.08, 0.08)).build())
            .pretrain(false).backprop(true)
            .build();

        // Train for a fixed number of epochs over the data set iterator
        for (int epoch = 0; epoch < max_epochs; ++epoch)
            net.fit(dataset_iter);
  45. Experimental Results • Winning entry: min(P,R) = 0.5353 (two others scored over 0.5) – Trained on full set A (4K), tuned on set B (4K), tested on set C – All used extensively hand-engineered features • Our best model so far: min(P,R) = 0.4907 – 60/20/20 training/validation/test split of set A – LSTM with 2 x 300-cell layers on inputs – Different test sets, so not directly comparable – Disadvantage: much smaller training set – Required no feature engineering or domain knowledge
  46. Map sequences into fixed vector representation • Not perfectly separable in 2D, but some cluster structure related to mortality • Can repurpose “representation” for other tasks (e.g., searching for similar patients, clustering, etc.)
  47. Final comments • We believe we could improve performance to well over 0.5 – Overfitting: training min(P,R) > 0.6 (vs. test: 0.49) – Remedies: smaller or simpler RNN layers, adding dropout, multitask training • Flexible NN architectures are well suited to complex clinical data – but will likely demand much larger data sets – may be better matched to “raw” signals (e.g., waveforms) • More general challenges – missing (or unobserved) inputs and outcomes – treatment effects confound predictive models – outcomes often have temporal components (posing the task as binary classification ignores that) • You can try it out: https://github.com/jpatanooga/dl4j-rnn-timeseries-examples/ • See related paper to appear at ICLR 2016: http://arxiv.org/abs/1511.03677
  48. Questions? Thank you for your time and attention. • Gibson & Patterson. Deep Learning: A Practitioner’s Approach. O’Reilly, Q2 2016. • Lipton et al. A Critical Review of Recurrent Neural Networks for Sequence Learning. arXiv, 2015. • Lipton & Kale. Learning to Diagnose with LSTM Recurrent Neural Networks. ICLR 2016.
  49. Sepp Hochreiter. Father of LSTMs,* renowned beer thief. S. Hochreiter and J. Schmidhuber. Long short-term memory. Neural Computation, 9(8): 1735-1780, 1997.
