2. 이영빈
AI 교육 @모두의연구소 아이펠
JAX/Flax LAB랩짱 @모두의연구소
오거나이저 @GDG SongDo
진행자 소개
JAX Document 한국어 번역
https://jax-kr.readthedocs.io/ko/
3. 목차
● JAX를 알아보자!
● Flax에 대해 알아보자!
● JAX User가 바라본 Keras-core
● Flax와 Keras 비교하기
4. JAX를 알아보자!
“JAX is Autograd and XLA, brought together for high-performance numerical computing.”
JAX는 XLA와 Autograd를 이용해 머신러닝 연구와 고성능 연산작업을 위해 만든 프레임워크
JAX란?
5. JAX를 알아보자!
“JAX is Autograd and XLA, brought together for high-performance numerical computing.”
JAX는 XLA와 Autograd를 이용해 머신러닝 연구와 고성능 연산작업을 위해 만든 프레임워크
JAX란?
8. JAX를 알아보자!
“JAX is Autograd and XLA, brought together for high-performance numerical computing.”
JAX는 XLA와 Autograd를 이용해 머신러닝 연구와 고성능 연산작업을 위해 만든 프레임워크
JAX란?
9. JAX를 알아보자!
“JAX is Autograd and XLA, brought together for high-performance numerical computing.”
JAX는 XLA와 Autograd를 이용해 머신러닝 연구와 고성능 연산작업을 위해 만든 프레임워크
XLA(Accelerated Linear Algebra)는 CPU, GPU 및 맞춤형 액셀러레이터(예: Google의 TPU)와 같은 기기에 대해 JIT 컴파일 기법을 사용하여 런타임
에 사용자가 생성한 TensorFlow 그래프를 분석하고 실제 런타임 차원과 유형에 맞게 최적화하며, 여러 연산을 함께 합성하고 이에 대한 효율적인 네이
티브 기계어 코드를 내보냅니다.
JAX란?
14. JAX를 알아보자!
OpenXLA 아키텍쳐 (New Version)
HLO (High Level Operator) => 컴파일러 입력 언어
Compilers => IREE (MLIR 기반 컴파일러) + XLA HLO
PJRT => 다양한 HW를 받을 수 있게 만드는 언어
15. JAX를 알아보자!
“JAX is Autograd and XLA, brought together for high-performance numerical computing.”
JAX는 XLA와 Autograd를 이용해 머신러닝 연구와 고성능 연산작업을 위해 만든 프레임워크
XLA(Accelerated Linear Algebra)는 CPU, GPU 및 맞춤형 액셀러레이터(예: Google의 TPU)와 같은 기기에 대해 JIT 컴파일 기법을 사용하여 런타임
에 사용자가 생성한 TensorFlow 그래프를 분석하고 실제 런타임 차원과 유형에 맞게 최적화하며, 여러 연산을 함께 합성하고 이에 대한 효율적인 네이
티브 기계어 코드를 내보냅니다.
그래서… 정리하면?
19. JAX를 알아보자!
JAX의 함수형 프로그래밍
Pure function
- 같은 입력값이 주어졌을 때 언제나 같은 결과값을 리턴
- 부수효과(side effect)를 만들지 않음
Stateless, Immutability
- 데이터의 변경이 필요한 경우, 원본 데이터 구조를 변경하지 않고
그 데이터의 복사본을 만들어서 그 일부를 변경하고, 변경한 복사본을 사용해 작업을 진행
문제점
- 거의 모든 딥러닝 프레임워크는 상태가 변하는 걸 가정하고 있음
- 모델 파라미터
- 옵티마이저
20. Flax를 알아보자!
Flax란?
Google DeepMind에서 개발해서 사용하고 있는 High level API
=>Flexibility라는 이름에서 나옴
HuggingFace에서도 Flax community week를 만들어서 변환하고 있음
현재 Google DeepMind에서 나온 대부분의 논문 구현은 Flax로 이루어져 있음
- ViT
- PaLM
- Imagen
- JaxNeRF
21. Flax를 알아보자!
(구)DeepMind JAX Ecosystem
DeepMind의 경우에도 JAX를 적극적으로 도입하고 있음
DeepMind에서 사용하고 있는 JAX 기반 대표적인 Framework
- Haiku : Neural Network => Flax와 통합!
- Optax : Optimizer만 따로 사용
- RLax : Reinforcement Learning
- Chex : 테스트 환경
- Jraph : GNNs
https://www.deepmind.com/blog/using-jax-to-accelerate-our-research
22. JAX 유저가 바라본 Keras-core
Keras-core
Keras-core 발표에 대한 딥러닝 프레임워크 유저 반응 (주관적인 느낌)
- TensorFlow User
- 그래서 뭐가 바뀌었는가…? (우린 똑같음)
- PyTorch User
- Keras에 PyTorch가 붙는구나… (노관심)
- JAX User
- 와우!! 드디어 JAX를 편하게 쓸 수 있게 됐어… (감격)
- Flax 어떻게 되는거지…? 망하는건가..? ㅠㅠ
24. Flax를 알아보자!
Flax vs JAX backend Keras
Flax
- 장점
- 비교적 예제가 다양함 (절대 많은건 아님…)
- TrainState를 정의해서 사용하다보니 직관적임
- JAX를 100% 활용하기 위해 나온 Tool (뼛속까지 함수형 프로그래밍)
- Hugging Face가 쓰고 있는 JAX 프레임워크 표준
- 단점
- Research팀이 담당하다보니 라이브러리를 전부 쪼개놓고 있음 (Optax, orbax)
- 유저 눈치를 보지 않는 업데이트 (다음 달에는 이 메소드는 deprecated됩니다)
25. Flax를 알아보자!
Flax vs JAX backend Keras
JAX backend Keras
- 장점
- 개발과정에서 Keras-core 파트마다 메인테이너를 지정해서 관리
- 커스터마이징을 하지 않는 선에서는 Keras가 편의성에서는 압도적임
- 단점
- State 정의가 model안에 숨겨져 있다보니 헷갈림
- 빼놓아도 collection을 사용해야 함 (불편함 증대)
- JAX를 100% 활용하고 있지 않는 상황 => PRNG
- 예제가 없음 (숄레가 만든것 말고는 존재하지 않음)
26. 결론
Flax vs JAX backend Keras
JAX 생태계를 최대한도로 활용하고 싶다 => Flax
JAX를 명확하게 잘 사용하고 싶다 => Flax
커스터마이징이 필요 없는 간단한 모델을 돌릴 예정이다 => JAX backend Keras
딥러닝 프레임워크 코드에 기여하고 싶다 => JAX backend Keras
27. 결론
Keras가 딥러닝 프레임워크계의 황제로 등극하려면?
TensorFlow로 모델을 만들고
.keras로 저장하고 나중에 JAX backend로 변환해서 사용하는
것!