TensorFlow XLA
「XLAとは、から、最近の利用事例について」
fpgax #11+TensorFlow User Groupハード部
「ディープラーニング専用ハードウェアについてわいわい話す会」
@ Google
作成:2019/12/29, 2019/1/6, 14, 20, 27
Slideshareにて公開 :2019/2/2
@Vengineer
ブログ (2007年~) : Vengineerの戯言
 http://blogs.yahoo.co.jp/verification_engineer
SlideShare :
 https://www.slideshare.net/ssuser479fa3
Twitter (2009年~) :
@Vengineer
ソースコード解析職人
CQ出版社:雑誌インターフェース に
「TensorFlow XLA および Lite」
に関することを寄稿しました
2017年8月号 2017年9月号 2018年2月号 2018年8月号 2019年1月号
XLA AOT XLA AOT XLA JIT Lite & XLA Lite
今日お話する内容
 
 TensorFlow r1.13ベースのお話
・XLA とは?
・XRT とは?
XLAの最近の利用事例
・JuliaでTPUを使う
・PyTorchでTPUを使う
コースコード解析ベースなので、コード多いです
XLAとは
TensorFlow XLAとは
https://www.tensorflow.org/performance/xla/
XLA(Accelerated Linear Algebra)は、TensorFlow計算を最適化する線形代数のドメ
イン固有のコンパイラです。 結果として、サーバーおよびモバイルプラットフォーム
での速度、メモリ使用率、移植性が向上します。 当初、ほとんどのユーザーはXLA
の大きなメリットは見られませんが、JIT(Just-In-Time)コンパイルや
AOT(Ahead-Of-Time)コンパイルを使用してXLAを使用することで実験を開始でき
ます。 新しいハードウェアアクセラレータをターゲットとする開発者は、XLAを試すこ
とを特にお勧めします。
原文(英語)をそのまま、Google翻訳にお願いしました。
TensorFlow XLAのソースコード
r1.0 ~ r1.11 と r1.12 ~
では、違います
Slideshareにアップしてある
TensroFlow XLA : JIT編 (r1.3版)
の内容は古いです
サンプルコードを見てみよう
def test_xla_gpu(self):
with tf.Session() as sess:
x = tf.placeholder(tf.float32, [2], name="x")
with tf.device("device:XLA_GPU:0"):
y = x * 2
result = sess.run(y, {x: [1.5, 0.5]})
Mul
Const
Feed(x)
Fetch(y)
0)、最初
Mul
_Recv
Const
_Send
Feed(x)
Fetch(y)
1)、Feed/Fetchノードの追加
Mul
_Recv
Const
_Send
cpu : Feed(x)
cpu : Fetch(y)
XLA_GPU
XLA_GPU
2)、Placement
3)、グラフの分割
_Recv
_Recv
_Send
_Send _Recv _Send
XLA_GPU
Feed(x) Fetch(y)cpu
Mul
Const
4)、XlaLauch Opに変換
XlaLaunch
_Recv
_Recv _Send
_Send _Recv _Send
XLA_GPU
Feed(x) Fetch(y)cpu
複数Opsを XlaLaunch Op に変換
XlaLaunch
MulConst
TensorFlow XLA : JITでは!
同じデバイス内で実行される Subgraph単位 の
ノードをギュギュッと1つにまとめて、
XlaLaunch Op
内で実行する
XlaLaunchは、
TensorFlow XLA専用のOpとして実装されている
Passを使ってグラフを変形してるよ
compiler/jit/jit_compilation_pass_registration.cc
REGISTER_OPTIMIZATIONマクロを使って、
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC
Passを追加
 ・MarkForCompilationPass // コンパイル可能なものにマーク
 ・EncapsulateSubgraphsPass // サブグラフを関数ノード
 ・BuildXlaLaunchOpsPass // 関数ノードを_XlaLaunchに置換
上から順番に実行される
xla.compile
ポイント1 : xla.compile (Python API)
compiler/xla/g3doc/tutorials/xla_compile.ipynb
def build_mnist_model(x, y_):
y = tf.keras.layers.Dense(NUM_CLASSES).apply(x)
cross_entropy = tf.losses.sparse_softmax_cross_entropy(labels=y_, logits=y)
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
return y, train_step
[y] = xla.compile(build_mnist_model, inputs=[images, labels])
Passを使ってグラフを変形してるよ
compiler/jit/jit_compilation_pass_registration.cc
REGISTER_OPTIMIZATIONマクロを使って、
OptimizationPassRegistry::PRE_PLACEMENT
Passを追加 
// EncapsulateXlaComputationsPass rewrites computations generated by the
// xla.compile() Python code into XlaLaunch nodes.
REGISTER_OPTIMIZATION(
OptimizationPassRegistry::PRE_PLACEMENT, 26,
EncapsulateXlaComputationsPass); // r1.12 にて導入
// Pythonコードで xla.compile() を実行するとここが呼ばれる
グラフから XlaLauch Op に変換
compiler/jit/encapsulate_xla_computations_pass.cc
Status EncapsulateXlaComputationsPass::Run(
const GraphOptimizationPassOptions& options) {
Encapsulate(options.graph, options.flib_def);
BuildXlaLaunchOps(options.graph->get());
return Status::OK();
}
XlaLaunch => XlaCompile + XlaRun
- XlaCompile
TF function を LocalExecutable にコンパイルする
- XlaRun
XlaCompile でコンパイルした LocalExecutable を実行する
Split XlaLaunch into XlaCompile and XlaRun; NFC , 21 Sep 2018
TensorFlow r1.12にて導入された
Passを使ってグラフを変形してるよ
compiler/jit/jit_compilation_pass_registration.cc
REGISTER_OPTIMIZATIONマクロを使って、
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC
Passを追加
 ・MarkForCompilationPass
 ・IncreaseDynamismForAutoJitPass (r1.12にて導入)
 ・PartiallyDeclusterPass (r1.11にて導入)
 ・EncapsulateSubgraphsPass // サブグラフを関数ノード
 ・BuildXlaOpsPass // Xla Ops をコンパイル置換
BuildXlaOpsPass::Run
compiler/jit/build_xla_ops_pass.cc
tatus BuildXlaOpsPass::Run(const GraphOptimizationPassOptions& options) {
Graph* graph = options.graph->get();
….
for (Node* n : xla_compiled_kernels) {
// ここで、Xla Ops をコンパイルした後に、実行しています。
TF_RETURN_IF_ERROR(ReplaceNodeWithXlaCompileAndXlaRun(
*options.flib_def, lazy_compilation_enabled, graph, n));
}
…
ReplaceNodeWithXlaCompileAndXlaRun
compiler/jit/build_xla_ops_pass.cc
XlaClusterInfo cluster_info;
TF_RETURN_IF_ERROR(GetXlaClusterInfo(n, &cluster_info));
ops::_XlaCompile xla_compile(root.WithOpName("xla_compile"),
/*constants=*/cluster_info.constant_inputs,
/*args=*/cluster_info.non_constant_inputs,
/*resources=*/cluster_info.resource_inputs,
/*must_compile=*/requires_compilation,
cluster_info.function);
REGISTER_OP("_XlaCompile")
.Input("constants: Tconstants")
.Attr("Tconstants: list(type) >= 0")
.Attr("must_compile: bool")
.Input("args: Targs")
.Attr("Targs: list(type) >= 0")
.Input("resources: Nresources * resource")
.Attr("Nresources: int >= 0")
.Output("key: string") // コンパイルがOKの時の key
.Output("compilation_successful: bool") // コンパイルがOK/NG
.Attr("function: func")
// The compilation cache is stateful.
.SetIsStateful()
.Doc(R"(XLA Compile Op. For use by the XLA JIT only.
void XlaCompileOp::Compute(OpKernelContext* ctx) {
xla::LocalClient* client;
const XlaCompiler::CompilationResult* kernel;
xla::LocalExecutable* executable;
std::map<int, OptionalTensor> variables;
…
Status status = CompileToLocalExecutable (
ctx, function_, platform_info_, resources_, constants_,
/*lazy=*/!must_compile_, &client, &variables, &kernel, &executable);
if (must_compile_ || status.code() != error::UNIMPLEMENTED) {
OP_REQUIRES_OK(ctx, status);
}
// Each execution of an XlaCompile op creates a new XlaExecutableClosure, even
// if it didn't have to compile the cluster because of a compilation-cache
// hit. This is because we at least need new snapshots of the resource
// variables.
XlaExecutableClosureStore::KeyT key =
XlaExecutableClosureStore::Global()->Produce(XlaExecutableClosure(
client, executable, kernel, std::move(variables), constants_.size()));
Tensor compilation_key(cpu_allocator, DT_STRING, TensorShape({}));
compilation_key.flat<string>()(0) = key;
Tensor compilation_successful(cpu_allocator, DT_BOOL, TensorShape({}));
compilation_successful.flat<bool>()(0) = true;
ctx->set_output(0, compilation_key) ;
ctx->set_output(1, compilation_successful) ;
}
static Status CompileToLocalExecutable (
…
XlaCompilationCache* cache ;
TF_RETURN_IF_ERROR(rm->LookupOrCreate<XlaCompilationCache>(
rm->default_container(), "xla_cache", &cache,
[&](XlaCompilationCache** cache) {
return BuildCompilationCache(ctx, platform_info, cache);
}));
...
std::vector<XlaCompiler::Argument> args;
TF_RETURN_IF_ERROR(XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_args, *variables, ctx, &args));
return cache->Compile(options, function, args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
: XlaCompilationCache::CompileMode::kStrict,
kernel, executable);
ReplaceNodeWithXlaCompileAndXlaRun
compiler/jit/build_xla_ops_pass.cc
std::vector<Output> xla_run_args = cluster_info.non_constant_inputs;
absl::c_copy(cluster_info.resource_inputs,
std::back_inserter(xla_run_args));
ops::_XlaRun xla_run(root.WithOpName("xla_run"), xla_run_args,
xla_compile.key, n->output_types());
MoveOutgoingEdges(g, /*old_node=*/n,
/*new_node=*/xla_run.operation.node());
REGISTER_OP("_XlaRun")
.Input("args: Targs") // 引数
.Attr("Targs: list(type) >= 0")
.Output("results: Tresults")
.Attr("Tresults: list(type) >= 0")
.Input("key: string") // _XlaCompile の結果の key
// XLA random-number generation ops are stateful.
// TODO(phawkins): create stateful and non-stateful variants of _XlaRun.
.SetIsStateful()
.Doc(R"(XLA Run Op. For use by the XLA JIT only.
void XlaRunOp::Compute(OpKernelContext* ctx) {
Tensor key_tensor = ctx->input(ctx->num_inputs() - 1) ;
const XlaExecutableClosureStore::KeyT& key = key_tensor.flat<string>()(0) ;
XlaExecutableClosure closure =
XlaExecutableClosureStore::Global()->Consume( key);
…
Env* env = Env::Default();
auto start_time = env->NowMicros();
auto run_result =
closure.executable()->Run (launch_context.arguments(), run_options);
OP_REQUIRES(ctx, run_result.ok(), run_result.status());
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time in computation: " << elapsed << "us";
ポイント2:XlaCompile + XlaRun
XlaCompile XlaRun
XlaLaunch~ r1.11
r1.12にて
追加
LocalExecutable生成 実行
Eager Modeでは?
The XLA compile API
# xla.compile() doesn't work with Keras model.fit() API or
TF eager mode yet.
しかし、
ソースコードは語る
EagerLocalExecute
// If we are running a function on explicitly requested TPU,
// compile it with XLA.
// Note that it is not ideal, but currently ok, to set this
// attribute after computing the kernel cache key above.
bool compile_with_xla = false;
if (op->is_function() && device != nullptr &&
(device->device_type() == " TPU" || device->device_type() == "XLA_GPU" ||
device->device_type() == "XLA_CPU")) {
op->MutableAttrs()->Set(kXlaCompileAttr, true);
compile_with_xla = true;
}
kXlaCompileAttr を true にすると、
MarkForCompilationPass::Run にて、XLA化の準備をする
Passを使ってグラフを変形してるよ
compiler/jit/jit_compilation_pass_registration.cc
REGISTER_OPTIMIZATIONマクロを使って、
OptimizationPassRegistry::POST_REWRITE_FOR_EXEC
Passを追加
 ・MarkForCompilationPass
 ・IncreaseDynamismForAutoJitPass (r1.12にて導入)
 ・PartiallyDeclusterPass (r1.11にて導入)
 ・EncapsulateSubgraphsPass // サブグラフを関数ノード
 ・BuildXlaOpsPass // Xla Ops をコンパイル置換
XRTとは?
r1.11で導入された
XRT
TensorFlow以外からXLAを
利用するための仕組み?
XRTのテストコードを見てみよう
xla::XlaComputation AddAndTuple () {
xla::XlaBuilder builder("AddAndTuple");
auto p0 = xla::Parameter(&builder, 0,
xla::ShapeUtil::MakeShape(xla::F32, {2}), "P0");
auto p1 = xla::Parameter(&builder, 1,
xla::ShapeUtil::MakeShape(xla::F32, {2}), "P1");
auto sum = xla::Add(p0, p1);
xla::Tuple(&builder, {sum});
return builder.Build().ValueOrDie();
}
sum = xla::Add( p0, p1 )
TEST(RawApiTest, CompileAndExecuteReturnTuple ) {
xrt::XLAAllocation p0;
p0.set_device_ordinal(0);
*p0.mutable_value() = FloatVector({1.0f, 2.0f});
xrt::XLAAllocation p1;
p1.set_device_ordinal(0);
*p1.mutable_value() = FloatVector({8.0f, 5.0f});
xrt::XLAComputation c;
auto config = c.mutable_config();
auto shapes = config->mutable_program_shape();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->add_parameters() =
xla::ShapeUtil::MakeShape(xla::F32, {2}).ToProto();
*shapes->mutable_result() =
xla::ShapeUtil::MakeTupleShape({xla::ShapeUtil::MakeShape(xla::F32, {2})})
.ToProto();
StoreComputationSnapshot( AddAndTuple(), c.mutable_hlo_snapshot());
xrt::XRTExecutionConfig e;
e.set_release_input_handles(true);
e.set_release_compilation_handle(true);
Scope root = Scope::NewRootScope().WithDevice(DeviceFromFlag());
auto e_config =
ops::Const(root.WithDevice("/device:CPU:0"), e.SerializeAsString());
auto computation =
ops::Const(root.WithDevice("/device:CPU:0"), c.SerializeAsString());
auto c_handle = ops::XRTCompile(root, computation);
auto p0_value =
ops::Const(root.WithDevice("/device:CPU:0"), p0.SerializeAsString());
auto p0_handle = ops::XRTAllocate(root, p0_value);
auto p1_value =
ops::Const(root.WithDevice("/device:CPU:0"), p1.SerializeAsString());
auto p1_handle = ops::XRTAllocate(root, p1_value);
auto result = ops::XRTExecute(root, c_handle, e_config,
{Output(p0_handle), Output(p1_handle)});
auto read_back = ops::XRTReadLiteralAndRelease (root, result);
TF_ASSERT_OK(root.status());
ClientSession session(root);
std::vector<Tensor> outputs;
TF_EXPECT_OK(session.Run({ read_back}, &outputs));
xla::LiteralProto response;
EXPECT_TRUE(response.ParseFromString(outputs[0].scalar<string>()()));
auto sum = xla::LiteralUtil::CreateR1<float>({9.0f, 7.0f});
auto expected = xla::LiteralUtil::MakeTuple({&sum});
EXPECT_TRUE(CompareLiteralToLiteralProto(expected, response));
}
XRTCompile
XRTExecute
XRTReadLiteralAndRelease
xrt::XLAComputation
xrt::XRTExecutionConfig
xrt::XLAAllocation (入力データ)
read_back
XRTAllocate
XRTAllocateConst
Const
Const
Const
REGISTER_OP("XRTAllocate")
 ・入力
allocation: string
 ・出力
handle: int64
void XRTAllocate::Compute(OpKernelContext* ctx) override {
const Tensor& allocation_info = ctx->input(0); // 入力:0
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(allocation_info.shape()),
errors::Internal("allocation input should be a string scalar"));
xrt::XLAAllocation allocation_proto;
OP_REQUIRES(
ctx,
allocation_proto.ParseFromString(allocation_info.scalar<string>()()),
errors::InvalidArgument(
"Unable to parse allocation input to XLAAllocation"));
xla::Literal literal;
OP_REQUIRES_OK(
ctx, XRTStateHelpers::MakeLiteral(allocation_proto.value(), &literal));
....
XRTTupleAllocation* allocation;
OP_REQUIRES_OK(ctx, XRTTupleAllocation::CreateAndTransfer(
literal, device_ref.backend(),
device_ref.device_ordinal(), &allocation));
// Intern takes ownership of our reference to allocation.
int64 key;
OP_REQUIRES_OK(ctx, allocation->Intern(rm, &key));
Tensor output(DT_INT64, TensorShape({}));
output.scalar<int64>()() = key;
ctx->set_output(0, output); // 出力:0
REGISTER_OP("XRTCompile")
 ・入力
computation: string
 ・出力
handle: int64
program_shape: string
void XRTCompileOp::Compute(OpKernelContext* ctx) {
....
const Tensor& computation_input = ctx->input(0);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(computation_input.shape()),
errors::Internal("computation input should be a string scalar"));
xrt::XLAComputation computation_proto ;
OP_REQUIRES(
ctx,
computation_proto.ParseFromString(computation_input.scalar<string>()() ),
errors::InvalidArgument(
"Unable to parse computation input to XLAComputation"));
....
int64 uid;
OP_REQUIRES_OK(
ctx, cache->CompileIfKeyAbsent(
key, &uid, [&](std::unique_ptr<xla::LocalExecutable>* program) {
VLOG(1) << "Compiling XLA executable";
return Compile(ctx, computation_proto, program);
}));
std::unique_ptr<XRTCompilationCacheEntryRef> entry;
OP_REQUIRES_OK(ctx, cache->Lookup(uid, &entry));
status XRTCompileOp::Compile(OpKernelContext* ctx,
const xrt::XLAComputation& computation_proto ,
std::unique_ptr<xla::LocalExecutable>* program) {
....
TF_ASSIGN_OR_RETURN( xla::XlaComputation computation ,
client->LoadSnapshot( computation_proto.hlo_snapshot ()));
....
auto compile_result =
client->Compile(computation, argument_layout_ptrs, build_options);
if (!compile_result.ok()) {
return compile_result.status();
}
*program = std::move(compile_result.ValueOrDie());
return Status::OK();
}
....
Tensor handle_output(DT_INT64, TensorShape({}));
handle_output.scalar<int64>()() = uid;
ctx->set_output(0, handle_output); // 出力:0
xla::LocalExecutable* executable = entry->get().get_executable();
xla::ProgramShapeProto program_shape = executable->executable()
->module()
.config()
.entry_computation_layout()
.ComputeProgramShape()
.ToProto();
Tensor program_shape_output(DT_STRING, TensorShape({1}));
program_shape_output.vec<string>()(0) = program_shape.SerializeAsString();
ctx->set_output(1, program_shape_output); // 出力:1
}
REGISTER_OP("XRTExecute")
.Attr("Ninputs: int >= 0")
.Input("computation_handle: int64")
.Input("execution_config: string")
.Input("input_handles: Ninputs * int64")
.Output("output_handle: int64")
 ・属性
Ninputs: int >= 0
 ・入力
computation_handle: int64
execution_config: string
input_handles: Ninputs * int64
 ・出力
output_handle: int64
void XRTExecuteOp::ComputeAsync(OpKernelContext* context, DoneCallback done) {
// Schedule onto the default queue, for unbounded concurrency. See b/73520706
Env::Default()->SchedClosure([this, context, done]() {
OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
done();
});
}
Status XRTExecuteOp::DoWork (OpKernelContext* context) {
…
const Tensor& execution_input = context->input(0); // 入力:0
TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_input.shape()));
int64 compilation_handle = execution_input.scalar<int64>()();
const Tensor& execution_config = context->input(1); // 入力:1
TF_RET_CHECK(TensorShapeUtils::IsScalar(execution_config.shape()));
xrt::XRTExecutionConfig config_proto;
TF_RET_CHECK(
config_proto.ParseFromString(execution_config.scalar<string>()()));
…
std::vector<xla::ShapedBuffer> input_allocations;
std::vector<xla::ShapedBuffer*> input_pointers;
TF_RETURN_IF_ERROR( GetComputationInputs (context, rm, release_inputs,
&input_tuples, &input_allocations,
&input_pointers));
…
Status GetComputationInputs (OpKernelContext* context, ResourceMgr* rm,
bool release_inputs,
std::vector<XRTTupleAllocation*>* input_tuples,
std::vector<xla::ShapedBuffer>* input_allocations,
std::vector<xla::ShapedBuffer*>* input_pointers) {
std::vector<int64> input_uids ;
OpInputList arg_list;
TF_RETURN_IF_ERROR( context->input_list("input_handles", &arg_list ));
Env* env = Env::Default();
auto start_time = env->NowMicros();
xla::LocalExecutable* executable = entry->get().get_executable();
auto run_result = executable->Run(input_pointers, run_options) ;
if (!run_result.ok()) {
return run_result.status();
}
auto elapsed = env->NowMicros() - start_time;
VLOG(2) << "Elapsed time: " << elapsed << "us";
Tensor* output_tensor;
TF_RETURN_IF_ERROR(
context->allocate_output(0, TensorShape({}), &output_tensor) );
int64 key;
TF_RETURN_IF_ERROR(output_tuple->Intern(rm, &key));
output_tensor->scalar<int64>()() = key ;
REGISTER_OP("XRTReadLiteralAndRelease")
 ・入力
handle: int64
 ・出力
literal: string
void XRTReadLiteralOp::Compute(OpKernelContext* ctx) override {
const Tensor& allocation_handle = ctx->input(0) ; // 入力:0
OP_REQUIRES(
ctx, TensorShapeUtils::IsScalar(handle_tensor.shape()),
errors::Internal("computation input should be an int64 scalar"));
int64 allocation_handle = handle_tensor.scalar<int64>()();
…
xla::Literal literal;
OP_REQUIRES_OK(
ctx, allocation->ToLiteral(device_ref.backend(),
device_ref.device_ordinal(), &literal));
xla::LiteralProto literal_proto = literal.ToProto();
Tensor output(DT_STRING, TensorShape({}));
literal_proto.SerializeToString(&output.scalar<string>()());
ctx->set_output(0, output) ; // 出力:0
}
ポイント3:XRTCompile + XRTExecute
XRTReadLiteralA
ndRelease
XRTAllocate XRTExecute
XRTCompile
JuliaでTPUを使う
Automatic Full Compilation of Julia
Programs and ML Models to Cloud
TPUs
https://arxiv.org/abs/1810.09868
Qiita : XLA.jl を試してみた
Qiita : JuliaからCloud TPUを使う論文の、ざっくりまとめ
引用:https://kiszk.github.io/2018/12/19/TensorFlow-Julia-TPU-XLA/
PyTorchでTPUを使う
Introducing PyTorch across Google
Cloud , 2018.10.3
https://cloud.google.com/blog/products/ai-machine-learning/introducing-p
ytorch-across-google-cloud
Today, we’re pleased to announce that engineers on Google’s TPU team
are actively collaborating with core PyTorch developers to connect
PyTorch to Cloud TPUs. The long-term goal is to enable everyone to enjoy
the simplicity and flexibility of PyTorch while benefiting from the
performance, scalability, and cost-efficiency of Cloud TPUs.
As a starting point, the engineers involved have produced a prototype that
connects PyTorch to Cloud TPUs via XLA, an open source linear algebra
compiler.
This prototype has successfully enabled us to train a PyTorch
implementation of ResNet-50 on a Cloud TPU, and we’re planning to open
source the prototype and then expand it in collaboration with the PyTorch
community.
Please email us at pytorch-tpu@googlegroups.com to tell us what types of
PyTorch workloads you would be most interested in accelerating with Cloud
TPUs!
PyTorch For TPU : 2018.11.13公開
https://github.com/pytorch/xla
PyTorch + XLA のソースコードが公開された:2018.11.16のブログ
利用ケース:
 1)、CPUのXLAにて実行する場合。
 2)、XRT経由でXLAを利用して、CPUのXLAにて実行する場合、
 3)、XRT経由でXLAを利用して、TPUのXLAにて実行する場合
最新コードでは、2) と 3) のみ
テストコード:test/test_train_imagenet.py
…
import torch_xla_py.xla_model as xm # xla_model をインポート
…
def train_imagenet():
…
model = torchvision.models.resnet50 () # モデル (resnet50)
cross_entropy_loss = nn.CrossEntropyLoss()
devices = [':{}'.format(n) for n in range(0, FLAGS.num_cores)]
inputs = torch.zeros(FLAGS.batch_size, 3, 224, 224)
target = torch.zeros(FLAGS.batch_size, dtype=torch.int64)
xla_model = xm.XlaModel( # XlaModel にてモデルコンパイル
model, [inputs], loss_fn=cross_entropy_loss,
target=target, num_cores=FLAGS.num_cores, devices=devices)
…
optimizer = optim.SGD(
xla_model.parameters_list(), lr=lr, momentum=momentum, weight_decay=5e-4)
log_fn = xm.get_log_fn(logdir=FLAGS.logdir)
for epoch in range(1, FLAGS.num_epochs + 1):
xla_model.train( # xla_model.train にて学習
train_loader,
optimizer,
FLAGS.batch_size,
log_interval=log_interval,
metrics_debug=FLAGS.metrics_debug,
log_fn=log_fn)
accuracy = xla_model.test( # xla_mode.test にて推論
test_loader,
_cross_entropy_loss_eval_fn(cross_entropy_loss),
FLAGS.batch_size,
log_fn=log_fn)
xm.update_optimizer_state(optimizer, 'lr', lambda x: x / 1.025)
return accuracy
PyTorch で TPU を使うときのポイント
1)、xla_model をインポート
import torch_xla_py.xla_model as xm
2)、XlaModel にてモデルコンパイル
xla_model = xm.XlaModel( model, [inputs], …
3)、train にて学習
for epoch in range(1, FLAGS.num_epochs + 1):
xla_model.train(...)
4)、test にて推論
accuracy = xla_model.test(...)
def train(self,
samples_loader,
optimizer,
batch_size,
log_interval=1,
log_fn=print,
metrics_debug=False):
wloader = LoaderWrapper(
samples_loader,
self._loader_prefetch,
batch_size,
num_cores=self._num_cores,
devices=self._devices,
fused_mode=True)
wloader_cleaner = xu.Cleaner(wloader.close)
loss = None
start_time = time.time()
self._epoch += 1
for batch_number, (inputs, targets) in wloader:
self._step += 1
optimizer.zero_grad()
xla_outputs = xla_run_model(self._xla_model, inputs, devices=self._devices)
xla_run_grad(self._xla_model, self._get_backward_grads(xla_outputs),
devices=self._devices)
optimizer.step()
if (log_fn is not None and log_interval is not None and
batch_number % log_interval == 0):
if metrics_debug:
log_fn(torch_xla._XLAC._xla_metrics_report())
loss = self._compute_loss(xla_outputs)
log_fn(
TrainStepMetrics(self._epoch, self._num_cores, batch_number,
len(samples_loader), batch_size, loss,
time.time() - start_time, self._step))
return loss
def test(self, samples_loader, eval_fn, batch_size, log_fn=print):
wloader = LoaderWrapper(
samples_loader,
self._loader_prefetch,
batch_size,
num_cores=self._num_cores,
devices=self._devices,
fused_mode=True)
wloader_cleaner = xu.Cleaner(wloader.close)
test_loss = 0
count = 0
correct = 0
start_time = time.time()
with torch.no_grad():
for batch_number, (inputs, targets) in wloader:
xla_outputs = xla_run_model(self._xla_model, inputs, devices=self._devices)
for i, replica_xla_outputs in enumerate(xla_outputs):
output = replica_xla_outputs[1].to_tensor()
closs, ccorrect = eval_fn(output, inputs[i][1].to_tensor())
test_loss += closs
correct += ccorrect
count += batch_size
test_loss /= count
accuracy = 100.0 * correct / count
if log_fn is not None:
log_fn(
TestStepMetrics(test_loss, correct, count,
time.time() - start_time, self._step))
return accuracy
# Run an XLA model with the given tensors.
def xla_run_model(xla_model, inputs, devices=None):
return xla_model(*convert_to_xla_tensors(inputs, devices=devices))
…
 C++側のコード (Pybind11を利用)
.def("__call__",
[](XlaModule& xla_module , py::args args) -> py::object {
auto inputs = XlaCreateTensorList(args);
XlaModule::TensorBatchVector outputs;
{
NoGilSection nogil;
outputs = xla_module.forward(inputs);
}
return XlaPackTensorList(outputs);
})
# Runs the backward pass for the given XLA model and the gradient outputs.
def xla_run_grad(xla_model, grad_outputs, devices=None):
# Trace and symbolically differentiate
grads_output_xla = convert_to_xla_tensors(grad_outputs, devices=devices)
xla_model.backward (*grads_output_xla)
…
 C++側のコード (Pybind11を利用)
.def("backward",
[](XlaModule& xla_module , py::args args) {
auto inputs = XlaCreateTensorList(args);
NoGilSection nogil;
xla_module.backward (inputs);
})
torch_xla/csrc/translator.cpp この部分は、2019.1.20に追記
at::aten
add、div、sub、mul、gt、type_as、convolution、
thnn_conv2d_forward、thnn_conv2d_backward、t、addmm、mm、max_pool2d_with_indices、
max_pool2d_with_indices_backward、
avg_pool2d、avg_pool2d_backward、
adaptive_avg_pool2d、adaptive_avg_pool2d_backward、
sqrt、rsqrt、neg、tanh、sigmoid、relu、threshold、threshold_backward、
log_softmax、_log_softmax_backward_data、
reshape、view、expand、stack、cat、chunk、
native_batch_norm、batch_norm、native_batch_norm_backward、
sum、nll_loss、nll_loss_backward、size
at::prim
Constant、Undefined、SumToSize、ListConstruct
今日のまとめ
・XLA は、r1.12で変わった
  ポイント1 : xla.compile (Python API)
  ポイント2 : XlaCompile + XlaRun
・XRT が、r1.11 にて導入され
・Julia や PyTorch でTPUで使える
  ポイント3 : XRTCompile + XRTExecute
おまけ
まだあります、XLA を利用したもの
・Google/jax : JAX: Autograd and XLA
https://github.com/google/jax
現在、頻繁に更新されています!(‘xla’ or ‘xrt’, TPUはまだの模様)
ブログ:今週の月曜日(2019.1.28)から金曜日(2019.2.1)まで
・LeFlow : XLA => FPGA
https://github.com/danielholanda/LeFlow
XLA => LLVM => (LegUp) => Verilog HDL
(TensorFlow r1.6ベース、リリース後更新無し?)
あたしは、
ディープラーニング職人 ではありません
コンピュータエンジニア です
ありがとうございました
@Vengineer
ソースコード解析職人

TensorFlow XLA 「XLAとは、から、最近の利用事例について」