TCO in Python via bytecode manipulation.

Published in: Software

  1. Optimizing tail recursion in Python using bytecode manipulations.
     Allison Kaptur, Paul Tagliamonte, Liuda Nikolaeva (all errors are my own)
  2. Problem: Python has a limit on recursion depth:

       def factorial(n, accum):
           if n <= 1:
               return accum
           else:
               return factorial(n-1, accum*n)

       >>> factorial(1000, 1)
       RuntimeError: maximum recursion depth exceeded
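The same problem in Python 3, as a minimal sketch: since Python 3.5 the exception is `RecursionError`, a subclass of `RuntimeError`, and `sys.getrecursionlimit()` reports the current limit (1000 by default in CPython). The depth below is computed from the limit, so the sketch fails the same way even if the limit has been changed.

```python
import sys

def factorial(n, accum):
    # Tail-recursive: the recursive call is the last thing the function does,
    # but CPython still pushes a new frame for every call.
    if n <= 1:
        return accum
    return factorial(n - 1, accum * n)

print(sys.getrecursionlimit())  # CPython's default is 1000

try:
    # Ask for more frames than the interpreter allows.
    factorial(sys.getrecursionlimit() + 100, 1)
except RecursionError as exc:  # RecursionError subclasses RuntimeError (3.5+)
    print("RecursionError:", exc)
```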
  3. Challenge:
     • Optimize recursive function calls so that they don’t create new frames, thus avoiding stack overflow.
     • What we want: eliminate the recursive call; instead, reset the variables and jump to the beginning of the function.
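At the source level, "reset the variables and jump to the beginning" is what a hand-written loop does. The sketch below (the name `factorial_loop` is ours, for illustration) rebinds the parameters to the values the recursive call would have passed and then restarts the body, so it runs at constant stack depth. That is the behavior the bytecode rewrite aims for, written by hand instead of produced automatically.

```python
import math

def factorial_loop(n, accum):
    # Same logic as the tail-recursive factorial, with the recursive call
    # replaced by "reset the parameters, jump back to the top".
    while True:
        if n <= 1:
            return accum
        # Both new values are computed from the old ones before rebinding,
        # exactly as the arguments of the recursive call would be.
        n, accum = n - 1, accum * n

print(factorial_loop(5, 1))                              # 120
print(factorial_loop(5000, 1) == math.factorial(5000))   # True
```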
  4. Problem: How do you change the insides of a function?
  5. Solution: Bytecode! (obviously)
  6. Quick intro to bytecode.

       def f(n, accum):
           if n <= 1:
               return accum
           else:
               return f(n-1, accum*n)

       >>> f.__code__.co_code
       '|\x00\x00d\x01\x00k\x01\x00r\x10\x00|\x01\x00St\x00\x00|\x00\x00d\x01\x00\x18|\x01\x00|\x00\x00\x14\x83\x02\x00Sd\x00\x00S'
       >>> print [ord(b) for b in f.__code__.co_code]
       [124, 0, 0, 100, 1, 0, 107, 1, 0, 114, 16, 0, 124, 1, 0, 83, 116, 0, 0, 124, 0, 0, 100, 1, 0, 24, 124, 1, 0, 124, 0, 0, 20, 131, 2, 0, 83, 100, 0, 0, 83]
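The bytes on this slide are Python 2.7 bytecode, where an instruction with an argument takes three bytes. Since CPython 3.6 every instruction occupies two bytes ("wordcode"), and opcode numbers change between releases, so running the same function today gives different bytes. In Python 3, `co_code` is a `bytes` object whose elements are already integers, so the `ord` call is no longer needed. A minimal sketch:

```python
def f(n, accum):
    if n <= 1:
        return accum
    else:
        return f(n-1, accum*n)

raw = f.__code__.co_code        # a bytes object in Python 3
codes = list(raw)               # iterating bytes yields ints: no ord() needed
print(type(raw).__name__)       # bytes
print(codes[:10])               # exact values depend on the CPython version
print(len(raw) % 2)             # 0: since 3.6 every code unit is two bytes
```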
  7-9.
       def f(n, accum):
           if n <= 1:
               return accum
           else:
               return f(n-1, accum*n)

       >>> import dis
       >>> dis.dis(f)
         2           0 LOAD_FAST                0 (n)
                     3 LOAD_CONST               1 (1)
                     6 COMPARE_OP               1 (<=)
                     9 POP_JUMP_IF_FALSE       16

         3          12 LOAD_FAST                1 (accum)
                    15 RETURN_VALUE

         5     >>   16 LOAD_GLOBAL              0 (f)
                    19 LOAD_FAST                0 (n)
                    22 LOAD_CONST               1 (1)
                    25 BINARY_SUBTRACT
                    26 LOAD_FAST                1 (accum)
                    29 LOAD_FAST                0 (n)
                    32 BINARY_MULTIPLY
                    33 CALL_FUNCTION            2
                    36 RETURN_VALUE
                    37 LOAD_CONST               0 (None)
                    40 RETURN_VALUE
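To work with instructions rather than raw bytes, `dis.get_instructions` (Python 3.4+) yields `Instruction` objects with `offset`, `opname`, `argval` and `argrepr` fields. The sketch below lists the instructions of `f` and picks out the two the optimizer targets: the load of the global `f` and the call. On modern Python the listing differs from the slides (for example, since 3.11 the call opcode is `CALL` instead of `CALL_FUNCTION`), so the sketch accepts both names.

```python
import dis

def f(n, accum):
    if n <= 1:
        return accum
    else:
        return f(n-1, accum*n)

instrs = list(dis.get_instructions(f))
for ins in instrs:
    print(ins.offset, ins.opname, ins.argrepr)

# The recursive call is set up by loading the global name "f" ...
loads_of_f = [ins for ins in instrs
              if ins.opname == "LOAD_GLOBAL" and ins.argval == "f"]
# ... and performed by a call instruction, whose name depends on the version.
calls = [ins for ins in instrs
         if ins.opname in ("CALL_FUNCTION", "CALL")]
print(len(loads_of_f), len(calls))
```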
  10. Before optimization:
                0 LOAD_FAST                0 (n)
                3 LOAD_CONST               1 (1)
                6 COMPARE_OP               1 (<=)
                9 POP_JUMP_IF_FALSE       16
               12 LOAD_FAST                1 (accum)
               15 RETURN_VALUE
          >>   16 LOAD_GLOBAL              0 (f)
               19 LOAD_FAST                0 (n)
               22 LOAD_CONST               1 (1)
               25 BINARY_SUBTRACT
               26 LOAD_FAST                1 (accum)
               29 LOAD_FAST                0 (n)
               32 BINARY_MULTIPLY
               33 CALL_FUNCTION            2
               36 RETURN_VALUE

      After optimization:
          >>    0 LOAD_FAST                0 (n)
                3 LOAD_CONST               1 (1)
                6 COMPARE_OP               1 (<=)
                9 POP_JUMP_IF_FALSE       16
               12 LOAD_FAST                1 (accum)
               15 RETURN_VALUE
          >>   16 LOAD_FAST                0 (n)
               19 LOAD_CONST               1 (1)
               22 BINARY_SUBTRACT
               23 LOAD_FAST                1 (accum)
               26 LOAD_FAST                0 (n)
               29 BINARY_MULTIPLY
               30 STORE_FAST               1 (accum)
               33 STORE_FAST               0 (n)
               36 JUMP_ABSOLUTE            0
               39 RETURN_VALUE
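To check the "after" listing by hand, here is a toy stack machine that understands only the opcodes on this slide. It is an illustration, not CPython: instructions are `(opname, arg)` tuples, jump targets are instruction indices instead of byte offsets (offset 16 becomes index 6), `LOAD_CONST` carries the constant value itself rather than an index into `co_consts`, and `COMPARE_OP` only implements `<=`.

```python
import math

# "After optimization" listing, with byte-offset jump targets turned into
# instruction indices.
AFTER = [
    ("LOAD_FAST", 0),          # n
    ("LOAD_CONST", 1),
    ("COMPARE_OP", "<="),
    ("POP_JUMP_IF_FALSE", 6),
    ("LOAD_FAST", 1),          # accum
    ("RETURN_VALUE", None),
    ("LOAD_FAST", 0),          # n
    ("LOAD_CONST", 1),
    ("BINARY_SUBTRACT", None),
    ("LOAD_FAST", 1),          # accum
    ("LOAD_FAST", 0),          # n
    ("BINARY_MULTIPLY", None),
    ("STORE_FAST", 1),         # accum = accum * n   (top of stack)
    ("STORE_FAST", 0),         # n = n - 1
    ("JUMP_ABSOLUTE", 0),
    ("RETURN_VALUE", None),    # unreachable, kept from the original
]

def run(code, args):
    """Execute a toy instruction list with the given positional arguments."""
    local_vars = list(args)
    stack = []
    pc = 0
    while True:
        op, arg = code[pc]
        pc += 1
        if op == "LOAD_FAST":
            stack.append(local_vars[arg])
        elif op == "LOAD_CONST":
            stack.append(arg)
        elif op == "STORE_FAST":
            local_vars[arg] = stack.pop()
        elif op == "COMPARE_OP":          # only "<=" is needed here
            right, left = stack.pop(), stack.pop()
            stack.append(left <= right)
        elif op == "BINARY_SUBTRACT":
            right, left = stack.pop(), stack.pop()
            stack.append(left - right)
        elif op == "BINARY_MULTIPLY":
            right, left = stack.pop(), stack.pop()
            stack.append(left * right)
        elif op == "POP_JUMP_IF_FALSE":
            if not stack.pop():
                pc = arg
        elif op == "JUMP_ABSOLUTE":
            pc = arg
        elif op == "RETURN_VALUE":
            return stack.pop()
        else:
            raise ValueError("unsupported opcode: " + op)

print(run(AFTER, [5, 1]))                               # 120
print(run(AFTER, [1000, 1]) == math.factorial(1000))    # True
```

Since the two `STORE_FAST`s and the `JUMP_ABSOLUTE` replace the call, `run` never recurses, however large `n` is.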
  11. Simplified algorithm (pseudocode):

        def recursion_optimizer(f):
            new_bytecode = ''
            for byte in f.__code__.co_code:
                if instruction[byte] == 'LOAD_GLOBAL f':
                    pass  # get rid of this instruction
                elif instruction[byte] == 'CALL_FUNCTION':
                    # replace it with resetting the variables and jumping to 0;
                    # the last argument is on top of the stack, so store in reverse
                    for arg in reversed(args):
                        new_bytecode.add_instr(store_new_val(arg))
                    new_bytecode.add_instr(jump_to_0)
                else:  # regular byte
                    new_bytecode.add(byte)
            # co_code is read-only: build a new code object (hypothetical helper)
            f.__code__ = new_code_object(f.__code__, new_bytecode)
            return f
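A runnable version of this simplified algorithm can be sketched on a toy representation: `(opname, arg)` tuples with jump targets as instruction indices. It is an illustration of the rewrite, not the author's implementation, and it does not touch real CPython bytecode. Like the slide, it drops every `LOAD_GLOBAL` of the function's own name and turns every `CALL_FUNCTION n` into `n` `STORE_FAST`s (last argument first, since it is on top of the stack) plus a jump to 0. It also does the bookkeeping the pseudocode leaves out: because instructions are removed and inserted, jump targets must be remapped. It assumes the call's `n` positional arguments map to local slots `0..n-1`, as function parameters do in CPython.

```python
# Toy listing of f(n, accum) from the slides.
BEFORE = [
    ("LOAD_FAST", 0),            # n
    ("LOAD_CONST", 1),
    ("COMPARE_OP", "<="),
    ("POP_JUMP_IF_FALSE", 6),    # -> LOAD_GLOBAL f
    ("LOAD_FAST", 1),            # accum
    ("RETURN_VALUE", None),
    ("LOAD_GLOBAL", "f"),
    ("LOAD_FAST", 0),
    ("LOAD_CONST", 1),
    ("BINARY_SUBTRACT", None),
    ("LOAD_FAST", 1),
    ("LOAD_FAST", 0),
    ("BINARY_MULTIPLY", None),
    ("CALL_FUNCTION", 2),
    ("RETURN_VALUE", None),
]

JUMPS = {"POP_JUMP_IF_FALSE", "POP_JUMP_IF_TRUE", "JUMP_ABSOLUTE"}

def recursion_optimizer(code, self_name):
    """Turn calls into 'reset the arguments, jump to 0' (v1: every call)."""
    new_code = []
    new_index = {}   # old instruction index -> index of its replacement
    fixups = []      # positions of original jumps whose targets need remapping
    for old_i, (op, arg) in enumerate(code):
        new_index[old_i] = len(new_code)
        if op == "LOAD_GLOBAL" and arg == self_name:
            continue                          # get rid of this instruction
        if op == "CALL_FUNCTION":
            # The last argument is on top of the stack: store it first.
            for slot in reversed(range(arg)):
                new_code.append(("STORE_FAST", slot))
            new_code.append(("JUMP_ABSOLUTE", 0))
            continue
        if op in JUMPS:
            fixups.append(len(new_code))
        new_code.append((op, arg))            # regular instruction
    for pos in fixups:
        op, target = new_code[pos]
        new_code[pos] = (op, new_index[target])
    return new_code

for instr in recursion_optimizer(BEFORE, "f"):
    print(instr)
```

The output matches the "After optimization" listing on slide 10, with byte offsets replaced by instruction indices: the removed `LOAD_GLOBAL` maps to the instruction that follows it, so the conditional jump still lands on `LOAD_FAST 0 (n)`.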
  12. Not only does it work, it runs slightly faster than the original function:
      • Timed 10000 calls to fact(450):
          Original fact:  1.7009999752
          Optimized fact: 1.6970000267
      • It is also faster than other ways of optimizing this.
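A sketch of how such a measurement can be set up with the standard `timeit` module. The "optimized" side here is a hand-written loop standing in for the bytecode-rewritten function (an assumption: the slide's numbers are for the rewritten bytecode, not for this loop), and absolute times will vary by machine and Python version.

```python
import timeit

def fact(n, accum=1):
    # Tail-recursive original: one new frame per step.
    if n <= 1:
        return accum
    return fact(n - 1, accum * n)

def fact_loop(n, accum=1):
    # Stand-in for the rewritten function: reset the arguments and loop.
    while n > 1:
        n, accum = n - 1, accum * n
    return accum

CALLS = 1000  # the slide timed 10000 calls; fewer here keeps the run short
t_rec = timeit.timeit(lambda: fact(450), number=CALLS)
t_loop = timeit.timeit(lambda: fact_loop(450), number=CALLS)
print(f"recursive: {t_rec:.3f} s, loop: {t_loop:.3f} s for {CALLS} calls each")
```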
  13. Here is the most interesting part so far:
  14. If our function calls another function…

        def sq(x):
            return x*x

        @tailbytes_v1
        def sum_squares(n, accum):
            if n < 1:
                return accum
            else:
                return sum_squares(n-1, accum+sq(n))

      • Our initial algorithm rewrote every function call, not only the recursive ones, so this would break: the call to sq(n) would also be turned into a reset-and-jump.
  15. How do you battle this?
      • We need to keep track of function calls and remove only the recursive calls.
      • Unfortunately, bytecode doesn’t know which function it’s calling: it just calls whatever is on the stack:
          29 CALL_FUNCTION 2
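  16. So we just need to keep track of the stack…
      • When we hit LOAD_GLOBAL for the function's own name ("self"), we start keeping track of the stack size (stack_size = 0).
      • Then, with every instruction, we update stack_size by that instruction's net effect on the stack.
      • Once stack_size is back to 0, the instruction that brought it there is the recursive call (it consumed the function and its arguments and left the result in their place), so that is the one we rewrite.
      • This lets us leave calls to other functions (e.g., identity) intact.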
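A sketch of this bookkeeping on a toy listing of `(opname, arg)` tuples, with a hand-written stack-effect table for a few Python 2.7 opcodes. In 2.7, `CALL_FUNCTION n` with only positional arguments pops `n` arguments plus the function and pushes the result, a net effect of `-n`. The `SUM_SQUARES` listing is our hand-compiled approximation of `sum_squares`, not real `dis` output, and the scan is linear, so it assumes the call's argument expressions contain no jumps (no conditional expressions or short-circuit `and`/`or`).

```python
def stack_effect(op, arg):
    """Net stack effect of each toy opcode (Python 2.7 semantics)."""
    if op in ("LOAD_FAST", "LOAD_CONST", "LOAD_GLOBAL"):
        return 1
    if op in ("STORE_FAST", "COMPARE_OP", "BINARY_ADD", "BINARY_SUBTRACT",
              "BINARY_MULTIPLY", "POP_JUMP_IF_FALSE", "RETURN_VALUE"):
        return -1
    if op == "CALL_FUNCTION":
        return -arg  # pops arg arguments and the callable, pushes the result
    raise ValueError("unknown opcode: " + op)

def find_recursive_calls(code, self_name):
    """Return the indices of the calls that consume LOAD_GLOBAL self_name."""
    found = []
    for start, (op, arg) in enumerate(code):
        if op != "LOAD_GLOBAL" or arg != self_name:
            continue
        stack_size = 0  # number of items above the loaded function
        for i in range(start + 1, len(code)):
            stack_size += stack_effect(*code[i])
            if stack_size == 0:
                # In this opcode set only a call can consume the function.
                found.append(i)
                break
    return found

SUM_SQUARES = [
    ("LOAD_FAST", 0), ("LOAD_CONST", 1), ("COMPARE_OP", "<"),
    ("POP_JUMP_IF_FALSE", 6), ("LOAD_FAST", 1), ("RETURN_VALUE", None),
    ("LOAD_GLOBAL", "sum_squares"),                  # index 6
    ("LOAD_FAST", 0), ("LOAD_CONST", 1), ("BINARY_SUBTRACT", None),
    ("LOAD_FAST", 1),
    ("LOAD_GLOBAL", "sq"), ("LOAD_FAST", 0),
    ("CALL_FUNCTION", 1),                            # index 13: sq(n)
    ("BINARY_ADD", None),
    ("CALL_FUNCTION", 2),                            # index 15: recursive call
    ("RETURN_VALUE", None),
]

naive = [i for i, (op, _) in enumerate(SUM_SQUARES) if op == "CALL_FUNCTION"]
print(naive)                                            # [13, 15]
print(find_recursive_calls(SUM_SQUARES, "sum_squares"))  # [15]
```

The naive approach from slide 11 would rewrite both calls; the stack count singles out index 15 and leaves the call to `sq` alone.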
  17. Road ahead:
      • Make it harder to break.
      • Translate “normal” (non-tail) recursion into tail recursion (possibly with ASTs).
      • Handle mutual recursion.
      …And some crazy ideas:
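As an illustration of the second item, the usual manual translation adds an accumulator parameter that carries the pending work. In `n * fact(n - 1)` the multiplication happens after the call returns, so the call is not in tail position; folding the multiplication into `accum` before the call gives the tail-recursive form used throughout this talk. Doing this automatically is the open problem; the pair below is only a hand-made example.

```python
def fact(n):
    # Non-tail: the multiplication runs after the recursive call returns,
    # so the caller's frame must be kept alive.
    if n <= 1:
        return 1
    return n * fact(n - 1)

def fact_tail(n, accum=1):
    # Tail form: the pending multiplication is folded into accum first.
    if n <= 1:
        return accum
    return fact_tail(n - 1, accum * n)

print(fact(10), fact_tail(10))  # 3628800 3628800
```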
  18. https://github.com/lohmataja/recursion
      Or: http://tinyurl.com/tailbytes
      Liuda Nikolaeva
