PYTORCH FX
AGENDA
● Execution modes
● Torch FX
● Torch FX Example
● torch.compile
EXECUTION MODES
PyTorch is an open-source machine learning library for Python, originally based on the Torch library.
PyTorch supports two execution modes:
1. Eager mode
2. Graph mode
EAGER MODE
● In eager mode (also called define-by-run), the operators in a model are executed immediately, as they are encountered.
● There is no separate graph-construction step: you do not first build a graph and then execute it through a session object, as with a TensorFlow 1.x Session.
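A minimal eager-mode sketch (the tensor values and the branch condition are arbitrary choices for illustration): each statement runs as soon as Python reaches it, so intermediate tensors can be printed, and an ordinary if statement can branch on their values.

```python
import torch

x = torch.tensor([0.5, -1.0, 2.0])

# Runs immediately: y is a concrete tensor as soon as this line finishes
y = torch.sin(x) ** 2
print(y)

# Plain Python control flow can inspect tensor values, because they already exist
if y.sum() > 1:
    z = y * 2
else:
    z = y / 2
print(z)
```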
GRAPH MODE
● Graph execution (also called define-and-run) extracts the tensor computations from Python and builds a graph that can be optimized before any of it is evaluated.
● Seeing the whole graph ahead of time enables optimizations such as operator fusion, where several operators are merged into a single kernel so that intermediate results do not have to be written to memory and read back.
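To make define-and-run concrete, here is a sketch that uses torch.jit.trace (TorchScript, a graph-capture API not covered in these slides, chosen only for illustration): the function is run once on an example input to record a graph, and later calls execute that recorded graph. Fusion itself is not visible here; the point is only that the graph exists before the real inputs are evaluated.

```python
import torch

def f(x):
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

# "Define": run f once on an example input and record its ops as a graph
traced = torch.jit.trace(f, torch.randn(4))
print(traced.graph)  # TorchScript IR listing the recorded aten:: ops

# "Run": later calls execute the recorded graph rather than re-running the Python body
out = traced(torch.randn(4))
print(out)  # sin^2 + cos^2, so every entry should be close to 1
```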
TORCH FX
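● torch.fx is a toolkit, shipped as part of PyTorch, for capturing PyTorch programs as graphs.
● It represents a PyTorch program as a DAG (directed acyclic graph) of operations.
● Passes written in ordinary Python can walk and rewrite that graph to transform or optimize the code.
● It is a building block of torch.compile, whose front end (TorchDynamo) captures Python code into FX graphs.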
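As an illustration of a graph transformation, the toy pass below symbolically traces a function, retargets every torch.add node to torch.mul, and regenerates the Python code. The names f and replace_add_with_mul are hypothetical; the pass uses only the public torch.fx calls symbolic_trace, Graph.nodes, Graph.lint and GraphModule.recompile. It is a sketch of the mechanics, not a real optimization.

```python
import torch
import torch.fx

def f(x, y):
    return torch.add(x, y) * 2

def replace_add_with_mul(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    """Hypothetical toy pass: rewrite every torch.add call into torch.mul."""
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target is torch.add:
            node.target = torch.mul
    gm.graph.lint()   # sanity-check the modified graph
    gm.recompile()    # regenerate gm.forward / gm.code from the edited graph
    return gm

gm = replace_add_with_mul(torch.fx.symbolic_trace(f))
print(gm.code)

x, y = torch.tensor(2.0), torch.tensor(3.0)
print(f(x, y), gm(x, y))  # (2 + 3) * 2 versus (2 * 3) * 2
```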
TORCH FX EXAMPLE
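import torch
import torch.fx

def add_eg(x, y):
    a = torch.sin(x) ** 2
    b = torch.cos(y) ** 2
    return torch.add(a, b)

# Symbolically trace add_eg into a GraphModule (no real input data is needed)
traced = torch.fx.symbolic_trace(add_eg)
print(traced.graph)            # the FX graph IR (output shown below)
traced.graph.print_tabular()   # same graph as a table; requires the tabulate package
print(traced.code)             # Python code regenerated from the graph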
Output of print(traced.graph):
graph():
    %x : [num_users=1] = placeholder[target=x]
    %y : [num_users=1] = placeholder[target=y]
    %sin : [num_users=1] = call_function[target=torch.sin](args = (%x,), kwargs = {})
    %pow_1 : [num_users=1] = call_function[target=operator.pow](args = (%sin, 2), kwargs = {})
    %cos : [num_users=1] = call_function[target=torch.cos](args = (%y,), kwargs = {})
    %pow_2 : [num_users=1] = call_function[target=operator.pow](args = (%cos, 2), kwargs = {})
    %add : [num_users=1] = call_function[target=torch.add](args = (%pow_1, %pow_2), kwargs = {})
    return add
[Figure: DAG of add_eg: x → sin → sin²(x) and y → cos → cos²(y), both feeding ADD → OUTPUT]
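The same DAG can be read directly off the traced graph: every FX node lists the nodes it consumes (Node.all_input_nodes), so walking graph.nodes yields the edges. A minimal sketch that re-traces add_eg from the example; the edges list is an illustrative helper, not part of the FX API.

```python
import torch
import torch.fx

def add_eg(x, y):
    a = torch.sin(x) ** 2
    b = torch.cos(y) ** 2
    return torch.add(a, b)

traced = torch.fx.symbolic_trace(add_eg)

# Each node knows which nodes it consumes; turning those into
# (producer, consumer) pairs gives the edges of the DAG.
edges = [
    (inp.name, node.name)
    for node in traced.graph.nodes
    for inp in node.all_input_nodes
]
for src, dst in edges:
    print(f"{src} -> {dst}")

# The GraphModule is still callable and matches the original function
a, b = torch.randn(3), torch.randn(3)
print(torch.allclose(traced(a, b), add_eg(a, b)))
```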
TORCH COMPILE
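torch.compile JIT-compiles PyTorch code into optimized kernels: TorchDynamo captures the code as FX graphs, and a backend compiler (TorchInductor by default) turns those graphs into fast code.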
EXAMPLE
import torch

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

# Compilation happens lazily, on the first call with concrete inputs
opt_foo1 = torch.compile(foo)
print(opt_foo1(torch.randn(10, 10), torch.randn(10, 10)))
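Choosing a different backend:
import torch._dynamo
# A minimal example backend: it receives the FX GraphModule captured by Dynamo
# and must return a callable; this one prints the graph and runs it unchanged
def custom_backend(gm: torch.fx.GraphModule, example_inputs):
    print(gm.graph)
    return gm.forward
# Reset Dynamo's caches, since foo is being recompiled with a different backend
torch._dynamo.reset()
opt_model = torch.compile(foo, backend=custom_backend)
x, y = torch.randn(10, 10), torch.randn(10, 10)
eager = foo(x, y)
graph_res = opt_model(x, y)  # same inputs as eager, so the results are comparable
print(torch.allclose(eager, graph_res))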
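Because torch.compile is a JIT, nothing is compiled when torch.compile(foo) itself is called; Dynamo traces foo on the first call and reuses the result while later inputs still match. A sketch that makes this visible, assuming a hypothetical counting_backend that follows the custom-backend convention (takes a GraphModule plus example inputs, returns a callable):

```python
import torch
import torch._dynamo
import torch.fx

def foo(x, y):
    a = torch.sin(x)
    b = torch.cos(y)
    return a + b

compile_count = 0

def counting_backend(gm: torch.fx.GraphModule, example_inputs):
    # Hypothetical helper backend: count each graph Dynamo hands over, then run it unchanged
    global compile_count
    compile_count += 1
    return gm.forward

torch._dynamo.reset()  # clear any cached compilations of foo
opt_foo = torch.compile(foo, backend=counting_backend)

x, y = torch.randn(10, 10), torch.randn(10, 10)
first = opt_foo(x, y)   # first call: Dynamo traces foo and invokes the backend
second = opt_foo(x, y)  # same shapes and dtypes: the cached compiled code is reused
print(compile_count)
print(torch.allclose(first, foo(x, y)))
```

If later calls use inputs with different shapes or dtypes, Dynamo may recompile, so the count can grow in that case.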
