
20180728 halide-study

Slides from the Halide study session at Fixstars.

  1. The Versatility of Halide IR as Seen in Tensor Comprehensions
     Fixstars Solutions, Inc. Takuro Iizuka
  2. Takuro Iizuka / @iitaku
     I work on Halide's FPGA backend and the "GENESIS" toolchain at Fixstars Solutions, Inc., the North American subsidiary.
  3. Agenda
      TC: Tensor Comprehensions overview
      The TC language
      Inside TC
      Summary
  4. What is TC (Tensor Comprehensions)?
      A description language and optimizing compiler framework for tensor computation
      Released by Facebook AI Research on 2018.2.14
      Write your algorithm in the TC language and the library optimizes it for you
      Integrates seamlessly with PyTorch
  5. What is TC (Tensor Comprehensions)?
      A description language and optimizing compiler framework for tensor computation
      Released by Facebook AI Research in 2018.2
      Write your algorithm in the TC language and the library optimizes it for you
      Integrates seamlessly with PyTorch
     → Adopts Halide IR as the compiler's intermediate representation
  6. TC architecture
     (Figure: https://research.fb.com/announcing-tensor-comprehensions/)
  7. TC benchmark results
     MLP: Multi-Layer Perceptron
     TMM: Transposed Matrix Multiplication
     TBMM: Transposed Batched Matrix Multiplication
     GCOV: Grouped Convolutions
     (Figure: https://research.fb.com/announcing-tensor-comprehensions/)
  8. TC in PyTorch: write it in TC, run it from PyTorch
  9. TC in PyTorch
     1. Set up TC:
     $ conda create -y --name pytorch python=3.6
     $ conda activate pytorch
     $ conda install -y -c pytorch -c tensorcomp tensor_comprehensions
     2. Write a custom layer in the TC language (matmul.py):
     import tensor_comprehensions as tc
     import torch
     mm = """
     def matmul(float(M,K) A, float(K,N) B) -> (output) {
         output(m, n) +=! A(m, r_k) * B(r_k, n)
     }
     """
     matmul = tc.define(mm, name="matmul")
     A, B = torch.randn(3, 4).cuda(), torch.randn(4, 5).cuda()
     C = matmul(A, B, options=tc.Options("naive"))
     print(C)
     3. Run it:
     $ python ./matmul.py
     Variable containing:
     -2.4028  2.8492  7.6141  3.3159  3.7171
      1.3839  0.6650 -1.7253  0.7447  1.3988
      0.1396 -0.0661 -1.0574  0.2163  0.1711
     [torch.cuda.FloatTensor of size 3x5 (GPU 0)]
  10. The TC language
  11. num ::= <number literal with C syntax>
      id ::= [_a-zA-Z0-9]*[_a-zA-Z][_a-zA-Z0-9]*
      exp ::= num
            | ( '-' | '!' | ... ) exp
            | exp ( [+-*/%] | '==' | '!=' | '<=' | ... ) exp
            | exp '?' exp ':' exp
            | id '.' num              # range of num-th dimension of id
            | id '(' exp_list ')'     # builtin call or tensor access
      reduction ::= <associative reduction operator>
            | '+=' | '*=' | 'min=' | 'max='
            | '+=!' | '*=!' | 'min=!' | 'max=!'
      range_constraint ::= id 'in' exp ':' exp
      stmt ::= id '(' id_list ')' [ '=' | reduction ] exp [ 'where' range_constraint_list ]
            | id_list = id '(' id_list ')'   # TC function call
      arg ::= type id
      return ::= id                    # inferred return type and range
      scalar_type ::= 'double' | 'float' | 'half' | 'int32' | 'byte' | 'uint32' | ...
      type ::= scalar_type [ '(' id_list ')' ]
      func ::=                         # TC function definition
          'def' id '(' arg_list ')' '->' '(' return_list ')' '{' stmt_list '}'
      id_list ::= <comma separated id list>
      exp_list ::= <comma separated exp list>
      arg_list ::= <comma separated arg list>
      stmt_list ::= <whitespace separated stmt list>
      return_list ::= <comma separated return list>
      range_constraint_list ::= <non-empty comma separated range_constraint list>
  12. Features of the TC language
       A special-purpose language for describing tensor computations
       An even simpler language design than the Halide language
       An extremely minimal set of primitive types
       Dedicated operators for reductions such as convolutions
       Range constraints written with where clauses (sketched below)
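      Grounded in the grammar on the previous slide, a hedged one-line sketch of a where clause (a hypothetical kernel, not from this deck): it pins down a range the compiler cannot infer from the inputs alone.

      def shift(float(N) A) -> (O) {
          O(i) = A(i) + A(i + 1) where i in 0:N - 1
      }

      Without the constraint, i + 1 could index past the end of A; the where clause restricts i to [0, N - 1).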
  13. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      A function is defined with def <name> (...) -> (...) { ... }
  14. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      type(X,Y,…) attaches constraints to an input argument (an element-type constraint and range constraints)
  15. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      Using the same symbol in the range constraints of different arguments expresses a constraint between them (here K ties A's second dimension to B's first)
  16. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      Constraints on the output are inferred automatically by the compiler
  17. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      An operator for writing reductions (e.g. convolutions) with initialization: the trailing ! resets the output before accumulating
  18. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      Using an induction variable defined on the left-hand side also on the right-hand side forms a loop:
      for m := 0, M
        for n := 0, N
  19. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      The loop ranges are determined from the input constraints, e.g. m := [0, M)
  20. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      Using an induction variable on the right-hand side that was not defined on the left-hand side forms a new loop at the innermost position of the existing nest:
      for m := 0, M
        for n := 0, N
          for r_k := 0, K
  21. def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      In that case, too, the range is checked against the input constraints
  22. Comparing with the Halide language
      TC language:
      def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      Halide language:
      Func matmul;
      ImageParam A(Float(32), 2);
      ImageParam B(Float(32), 2);
      Var m, n;
      RDom r_k(0, K);
      matmul(m, n) = sum(A(m, r_k) * B(r_k, n));
      matmul.realize({M, N});
  23. Comparing with the Halide language
      TC language (range constraints are inferred forward, from the inputs):
      def matmul(float(M,K) A, float(K,N) B) -> (output) {
          output(m, n) +=! A(m, r_k) * B(r_k, n)
      }
      Halide language (range constraints are inferred backward, from the realized output):
      Func matmul;
      ImageParam A(Float(32), 2);
      ImageParam B(Float(32), 2);
      Var m, n;
      RDom r_k(0, K);
      matmul(m, n) = sum(A(m, r_k) * B(r_k, n));
      matmul.realize({M, N});
  24. Inside TC
  25. TC GPU backend compilation flow
      TC language → AST: parsing / AST construction (tc/core/compiler.cc, parse(…))
      AST → Halide IR: Halide IR conversion (tc/core/tc2halide.cc, translate(…))
      Halide IR → Halide IR: range inference / loop formation / simplification (tc/core/tc2halide.cc, translate(…))
      Halide IR → Polyhedral IR: Polyhedral IR conversion (tc/core/polyhedral/scop.cc, makeScop(…))
      Polyhedral IR → Polyhedral IR: loop transformation / threading (tc/core/polyhedral/mapped_scop.cc, makeWithOuterBlockInnerThreadStrategy(…))
      Polyhedral IR → CUDA C: code generation (tc/core/polyhedral/mapped_scop.cc, codegen(…))
  26. Halide IR
      In the original Halide implementation, the IR that flows through lowering falls into two broad forms: the Call/Provide form before storage_flattening and the Load/Store form after it.
  27. Call/Provide form: high abstraction
      for (y, 0, out.extent.1) {
        for (x, 0, out.extent.0) {
          Provide(out, {x, y}) = Call(in, {x, y})
        }
      }
      Load/Store form: low abstraction
      for (y, 0, out.extent.1) {
        for (x, 0, out.extent.0) {
          Store(out, y*out.stride.1+x, Load(in, y*in.stride.1+x))
        }
      }
  28. TC handles only the high-abstraction Call/Provide form of Halide IR:
      for (y, 0, out.extent.1) {
        for (x, 0, out.extent.0) {
          Provide(out, {x, y}) = Call(in, {x, y})
        }
      }
      For the intermediate representation close to the target code, TC uses Polyhedral IR instead of the Load/Store form of Halide IR.
  29. Lowering on Halide IR
       Halide IR as compiler infrastructure
      – Halide already implements many functions that operate on Halide IR (Halide::Stmt/Expr):
        • the simplifier
        • solvers
        • interval arithmetic
        • CSE
        • and more
      – Using the IRMutator and IRVisitor classes, you can implement your own transformations and analyses over Halide IR (see the sketch below)
       TC's Halide IR lowering performs:
      – range inference
      – loop formation
      – simplification
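      As a concrete illustration of that last point, here is a minimal custom-pass sketch built on IRMutator (written against a recent Halide, where visit methods return the mutated node; older releases used a member-variable style, so treat the exact API shape as an assumption):

      #include "Halide.h"
      using namespace Halide;
      using namespace Halide::Internal;

      // Rewrites every "x + 0" into "x" while recursing over the whole IR tree.
      class StripAddZero : public IRMutator {
          using IRMutator::visit;
          Expr visit(const Add *op) override {
              Expr a = mutate(op->a);
              Expr b = mutate(op->b);
              const IntImm *imm = b.as<IntImm>();
              if (imm && imm->value == 0) return a;  // x + 0 -> x
              return Add::make(a, b);
          }
      };

      // Usage: Expr cleaned = StripAddZero().mutate(e);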
  30. Halide APIs used for range inference
       solve_for_inner_interval(c, v)
      – computes the largest interval of variable v that always satisfies condition c
       and_condition_over_domain(c, varying)
      – simplifies condition c under the assumption that its variables range over the intervals in varying
       simplify(e)
      – simplifies expression e
      These are combined to perform range inference and to check up front that the conditions required by the later polyhedral optimization hold (see the sketch below).
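      A hedged sketch of how these calls compose (the signatures follow Halide's Solve.h and Simplify.h; exact spellings can differ between Halide versions):

      #include "Halide.h"
      using namespace Halide;
      using namespace Halide::Internal;

      int main() {
          Expr m = Variable::make(Int(32), "m");
          Expr M = Variable::make(Int(32), "M");
          // Largest interval of m that always satisfies 0 <= m < M:
          Interval iv = solve_for_inner_interval(0 <= m && m < M, "m");
          // After simplification: iv.min == 0 and iv.max == M - 1,
          // i.e. exactly the loop range m := [0, M) inferred for matmul.
          Expr lo = simplify(iv.min);
          Expr hi = simplify(iv.max);
          return 0;
      }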
  31. Halide APIs used for loop formation
       realization_order(…)
      – topologically sorts the dependence graph between Provides and Calls (illustrated below)
       schedule_functions(…)
      – forms all of the compute loops the output depends on and returns Halide IR (Halide::Stmt)
      This builds the loop structure on which the later analyses and transformations are based.
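      realization_order's exact signature is internal and has changed across Halide versions, so here is only a conceptual illustration (a hypothetical helper, not the Halide API): ordering definitions so that everything a Provide Calls is produced first is a plain depth-first topological sort of the dependence graph.

      #include <map>
      #include <set>
      #include <string>
      #include <vector>

      // Emit a definition only after everything it calls has been emitted.
      void topo_visit(const std::string &f,
                      const std::map<std::string, std::vector<std::string>> &calls,
                      std::set<std::string> &visited,
                      std::vector<std::string> &order) {
          if (!visited.insert(f).second) return;  // already handled
          auto it = calls.find(f);
          if (it != calls.end()) {
              for (const std::string &dep : it->second) {
                  topo_visit(dep, calls, visited, order);
              }
          }
          order.push_back(f);
      }

      // For matmul, calls = {{"output", {"A", "B"}}} yields the order A, B, output.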
  32. Loop formation result
      produce output {
        let output.s0.n.loop_max = output.s0.n.max
        let output.s0.n.loop_min = output.s0.n.min
        let output.s0.n.loop_extent = ((output.s0.n.max + 1) - output.s0.n.min)
        let output.s0.m.loop_max = output.s0.m.max
        let output.s0.m.loop_min = output.s0.m.min
        let output.s0.m.loop_extent = ((output.s0.m.max + 1) - output.s0.m.min)
        for (output.s0.m, output.s0.m.loop_min, output.s0.m.loop_extent) {
          for (output.s0.n, output.s0.n.loop_min, output.s0.n.loop_extent) {
            output(output.s0.m, output.s0.n) = 0.000000f              // initialization
          }
        }
        let output.s1.r_k.loop_extent = ((output.s1.r_k.max - output.s1.r_k.min) + 1)
        let output.s1.r_k.loop_max = output.s1.r_k.max
        let output.s1.r_k.loop_min = output.s1.r_k.min
        let output.s1.n.loop_max = output.s1.n.max
        let output.s1.n.loop_min = output.s1.n.min
        let output.s1.n.loop_extent = ((output.s1.n.max + 1) - output.s1.n.min)
        let output.s1.m.loop_max = output.s1.m.max
        let output.s1.m.loop_min = output.s1.m.min
        let output.s1.m.loop_extent = ((output.s1.m.max + 1) - output.s1.m.min)
        for (output.s1.m, output.s1.m.loop_min, output.s1.m.loop_extent) {
          for (output.s1.n, output.s1.n.loop_min, output.s1.n.loop_extent) {
            for (output.s1.r_k, output.s1.r_k.loop_min, output.s1.r_k.loop_extent) {
              output(output.s1.m, output.s1.n) =
                ReductionUpdate((output(output.s1.m, output.s1.n) +
                  (A(output.s1.m, output.s1.r_k) * B(output.s1.r_k, output.s1.n))))   // computation
            }
          }
        }
      }
      This builds the loop structure on which the later analyses and transformations are based.
  33. Halide APIs used for simplification
       LetStmt::make(n, e, s)
      – binds expression e to symbol n within statement s
       simplify(s)
      – simplifies statement s
      The range constraints are applied and the loop bounds are simplified (see the sketch below).
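      A minimal sketch of that pattern, mirroring the loop-extent expressions from the previous slide (assuming Halide's internal Variable/Evaluate/LetStmt constructors; treat the exact spellings as an assumption):

      #include "Halide.h"
      using namespace Halide;
      using namespace Halide::Internal;

      int main() {
          Expr max_v = Variable::make(Int(32), "output.s0.m.max");
          Expr min_v = Variable::make(Int(32), "output.s0.m.min");
          Expr M = Variable::make(Int(32), "M");
          // The loop extent as emitted by loop formation: (max + 1) - min
          Stmt s = Evaluate::make((max_v + 1) - min_v);
          // Bind the range constraints over the statement, then simplify:
          s = LetStmt::make("output.s0.m.min", 0, s);
          s = LetStmt::make("output.s0.m.max", M - 1, s);
          s = simplify(s);  // the extent folds to M, matching the range [0, M)
          return 0;
      }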
  34. Simplification result
      for (output.s0.m, 0, M) {
        for (output.s0.n, 0, N) {
          output(output.s0.m, output.s0.n) = 0.000000f
        }
      }
      for (output.s1.m, 0, M) {
        for (output.s1.n, 0, N) {
          for (output.s1.r_k, 0, K) {
            output(output.s1.m, output.s1.n) =
              ReductionUpdate((output(output.s1.m, output.s1.n) +
                (A(output.s1.m, output.s1.r_k) * B(output.s1.r_k, output.s1.n))))
          }
        }
      }
      The loops have been simplified into a structure that the later polyhedral transformation can analyze and transform; the loop bounds now correspond directly to the TC range constraints.
  35. Polyhedral IR
       tc::polyhedral::Scop holds:
      – a schedule computed with ISL (Integer Set Library), a polyhedral compilation library
      – RAW, WAR, and WAW dependences
      – the memory layout
      – the Halide IR remnants:
        • parameters
        • inputs and outputs
        • the Stmt of each Provide
      (Diagram: the Polyhedral IR conversion carries the parameters, inputs/outputs, and the Stmt of each Provide over from Halide IR into a Polyhedral IR made of the schedule, dependences, and memory layout.)
  36. Lowering on Polyhedral IR
      A polyhedral transformation views the loop structure as a polytope and applies affine transformations to it, yielding loop transformations that are guaranteed legal (tiling is sketched below). TC performs:
      1. Loop fusion
      2. Tiling (parameter: tiling strategy)
      3. Thread mapping (parameter: thread sizes)
      4. Block mapping (parameter: block sizes)
      5. Memory mapping (parameters: enabled/disabled, mapping target, mapping amount)
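      To make "affine transformation" concrete, this is what step 2 (tiling) does to a loop nest, written as ordinary C++ loops purely for illustration; the real transformation is applied to the ISL schedule tree, not to source code:

      #include <algorithm>

      // Before: for (m = 0..M) for (n = 0..N) visit(m, n);
      // After 32x32 tiling: the same iteration space, re-indexed affinely so
      // that each (mo, no) tile touches a small, cache-friendly block.
      void tiled(int M, int N, void (*visit)(int, int)) {
          const int T = 32;  // tile size: the "tiling strategy" parameter above
          for (int mo = 0; mo < M; mo += T) {
              for (int no = 0; no < N; no += T) {
                  for (int m = mo; m < std::min(mo + T, M); m++) {
                      for (int n = no; n < std::min(no + T, N); n++) {
                          visit(m, n);
                      }
                  }
              }
          }
      }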
  37. The initial schedule tree (rough gloss: band = a loop, filter = a statement)
      domain(
        [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
        [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K })
      sequence()
        filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] })
          band(n(1) permutable(0) coincident(0) unroll(0)
            | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] })
          band(n(1) permutable(0) coincident(0) unroll(0)
            | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] })
        filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
          band(n(1) permutable(0) coincident(0) unroll(0)
            | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] })
          band(n(1) permutable(0) coincident(0) unroll(0)
            | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] })
          band(n(1) permutable(0) coincident(0) unroll(0)
            | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] })
  38. After loop fusion
      domain(
        [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
        [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K })
      band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
        ---
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
        ---
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] })
        sequence()
          filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] })
          filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
  39. After tiling
      domain(
        [K, M, N] -> { S_0[output_s0_m, output_s0_n] : 0 <= output_s0_m < M and 0 <= output_s0_n < N }
        [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : 0 <= output_s1_m < M and 0 <= output_s1_n < N and 0 <= output_s1_r_k < K })
      band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
        ---
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
        ---
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] })
        band(n(2) permutable(1) coincident(1, 0) unroll(0, 0)
          | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
          | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] }
          ---
          | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
          | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] })
          sequence()
            filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] })
            filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
  40. After thread mapping
      domain(
        [K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 }
        [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 })
      context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 })
      band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
        ---
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
        ---
        | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
        | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] })
        mapping_filter(ids(t0, )
          [K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 }
          [K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 })
          band(n(1) permutable(1) coincident(1) unroll(0)
            | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
            | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] })
            thread_specific()
              band(n(1) permutable(1) coincident(0) unroll(0)
                | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
                | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] })
                sequence()
                  filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] })
                  filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
  41. After block mapping
      domain(
        [K, M, N] -> { S_0[output_s0_m, output_s0_n] : K = 4 and M = 3 and N = 5 and 0 <= output_s0_m <= 2 and 0 <= output_s0_n <= 4 }
        [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : K = 4 and M = 3 and N = 5 and 0 <= output_s1_m <= 2 and 0 <= output_s1_n <= 4 and 0 <= output_s1_r_k <= 3 })
      context([K, M, N, t1, t0, t2, b2, b1, b0] -> { [] : t1 = 0 and t2 = 0 and b2 = 0 and b1 = 0 and 0 <= t0 <= 127 and 0 <= b0 <= 127 })
      mapping_filter(ids(b0, )
        [K, M, N, b0] -> { S_0[output_s0_m, output_s0_n] : (-b0 + output_s0_m) mod 128 = 0 and 0 <= b0 <= 127 }
        [K, M, N, b0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-b0 + output_s1_m) mod 128 = 0 and 0 <= b0 <= 127 })
        band(n(3) permutable(1) coincident(1, 1, 0) unroll(0, 0, 0)
          | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_m)] }
          | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_m)] }
          ---
          | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
          | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] }
          ---
          | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
          | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(0)] })
          mapping_filter(ids(t0, )
            [K, M, N, t0] -> { S_0[output_s0_m, output_s0_n] : (-t0 + output_s0_n) mod 128 = 0 and 0 <= t0 <= 127 }
            [K, M, N, t0] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] : (-t0 + output_s1_n) mod 128 = 0 and 0 <= t0 <= 127 })
            band(n(1) permutable(1) coincident(1) unroll(0)
              | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(output_s0_n)] }
              | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_n)] })
              thread_specific()
                band(n(1) permutable(1) coincident(0) unroll(0)
                  | [K, M, N] -> { S_0[output_s0_m, output_s0_n] -> [(0)] }
                  | [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] -> [(output_s1_r_k)] })
                  sequence()
                    filter( [K, M, N] -> { S_0[output_s0_m, output_s0_n] })
                    filter( [K, M, N] -> { S_1[output_s1_m, output_s1_n, output_s1_r_k] })
  42. Code generation
      template<typename T> inline __device__ T floord(T n, T d) {
        return n < 0 ? - (-n + d - 1)/d : n / d;
      }
      #define if_then_else(cond,a,b) ((cond) ? (a) : (b))
      // Can't include system dependencies with NVRTC
      // Can't include cuda_fp16.h with NVRTC due to transitive system dependencies
      // #include <cuda_fp16.h>
      // Halide type handling
      typedef char int8;
      typedef short int16;
      typedef int int32;
      typedef long int64;
      typedef unsigned char uint8;
      typedef unsigned short uint16;
      typedef unsigned int uint32;
      typedef unsigned long uint64;
      // typedef half float16;
      typedef float float32;
      typedef double float64;
      #define inff __int_as_float(0x7f800000)
      #define inf __longlong_as_double(0x7ff0000000000000LL)
      // Before CUDA 9, syncwarp is a noop since warps are always synchronized.
      #if __CUDACC_VER_MAJOR__ < 9
      __device__ void __syncwarp(unsigned mask = 0xFFFFFFFF) {}
      #endif
      extern "C" {
      __global__ void matmul_4_3_5(int32 K, int32 M, int32 N, float32* poutput, const float32* pA, const float32* pB) {
        int b0 = blockIdx.x; int b1 = blockIdx.y; int b2 = blockIdx.z;
        int t0 = threadIdx.x; int t1 = threadIdx.y; int t2 = threadIdx.z;
        float32 (*output)[5] = reinterpret_cast<float32 (*)[5]>(poutput);
        const float32 (*A)[4] = reinterpret_cast<const float32 (*)[4]>(pA);
        const float32 (*B)[5] = reinterpret_cast<const float32 (*)[5]>(pB);
        output[b0][t0] = 0.000000f;
        for (int c4 = 0; c4 <= 3; c4 += 1) {
          output[b0][t0] = (output[b0][t0] + (A[b0][c4]*B[c4][t0]));
        }
      }
      }
  43. Autotuning the parameters
       A genetic algorithm searches for better parameter settings
      (Figure: https://research.fb.com/announcing-tensor-comprehensions/)
  44. Summary
  45. Halide IR is well designed
      – Constraints are a virtue
        • Advanced optimization techniques such as polyhedral transformation become applicable
        • Analyses rarely suffer combinatorial blow-up
      – A rich set of IR-manipulation functions is already implemented
        • Much of it can be reused for your own analyses and transformations
      – IR transformations are easy to write
        • The IRMutator/IRVisitor classes are simple and pleasant to use
      – Projects that reuse Halide IR are emerging: TC, TVM, Tiramisu, and others
      Keep an eye on Halide as a high-level compiler IR!
