今年こそは再帰関数を理解しよう!

あけましておめでとうございます。みんなで再帰をマスターしてサイキッカーになりましょう。
当記事は主にPythonでの利用がメインなので、普段から呼吸をするように再帰を使ってる関数型言語利用者の方はブラウザバックするかマサカリを構えながら見ることをおすすめします。

再帰関数とは

再帰関数とは関数の中で自分自身の関数を呼び出す構造の関数、またはメソッドです。またその関数の呼出しについて再帰呼出しと言ったりもします。以降はひっくるめて単に再帰といいます。

具体的なコード例をみてみましょう。この手の話でよく見かける階乗の例です。ご存知の通り1以上の整数から1までの積ですね。正確な定義はWikipediaでどうぞ。計算イメージは「f(5) = 5 * 4 * 3 * 2 * 1 = 120」のような感じです。

def fact1(n):
    if n <= 1:
        return 1
    return n * fact1(n - 1)

再帰を使うと漸化式を表現しやすいというのがメリットとしてあるようです。自分はあんまり意識したことがなかったですが。

再帰において重要なのは終了条件で、この場合は「nが1以下になること」が終了条件です。これがないとプログラムは正常に終了しません。
終了条件に達すると呼び出された関数達が呼出とは逆順でreturnすることで結果を得ることができるというわけですね。

では実際に実行してみましょう。

>>> fact1(5)
120
>>> fact1(6)
720
>>> fact1(100)  # 結構でかい
93326215443944152681699238856266700490715968264381621468592963895217599993229915608941463976156518286253697920827223758251185210916864000000000000000000000000L
>>> fact1(1000)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "<stdin>", line 4, in fact1
  File "<stdin>", line 4, in fact1
  File "<stdin>", line 4, in fact1
  (長いので中略)
RecursionError: maximum recursion depth exceeded in comparison

途中まではうまく行ってるように見えましたが、1000のように比較的大きな数を入力したらRecursionErrorとなりました(Python3系)
(Python2系だと “RuntimeError: maximum recursion depth exceeded” となるようです)

Cをはじめとする高級言語では関数呼出の際にスタックフレームと呼ばれる情報をコールスタックという領域に積み(PUSH)、関数の実行が終わると取り出されます(POP)。
一般的にスタックフレームには次に実行するべき命令のアドレスやローカル変数などが含まれます。関数の呼出毎にスタックフレームが作られるため別関数で同名の変数などを宣言しても混同せずに扱ってくれるわけですね。

ただし、上記例のように連続して関数を呼び出しまくるとこのコールスタックから溢れたり、システムの制限に引っかかりエラーとなります。
CPython3の場合は再帰呼出し可能な回数がデフォルトで1000回と決まっているようです。

>>> import sys
>>> sys.getrecursionlimit()
1000

この回数を超えたので例外が起きたんですね。ではどのように対応したらよいでしょうか。

ループに直す

再帰という構造はループに変換することが可能と言われています。先程の例をループに直してみましょう。
#変数名もっといいのあるかなぁ

def fact2(n):
    total = 1
    for i in range(n, 1, -1):
        total *= i
    return total

こんな感じになりました。実行してみます。

>>> fact2(5)
120
>>> fact2(6)
720
>>> fact2(1000)
402387260077093773543702433923003985719374864210714632543799910429938512398629020592044208486969404800479988610197196058631666872994808558901323829669944590997424504087073759918823627727188732519779505950995276120874975462497043601418278094646496291056393887437886487337119181045825783647849977012476632889835955735432513185323958463075557409114262417474349347553428646576611667797396668820291207379143853719588249808126867838374559731746136085379534524221586593201928090878297308431392844403281231558611036976801357304216168747609675871348312025478589320767169132448426236131412508780208000261683151027341827977704784635868170164365024153691398281264810213092761244896359928705114964975419909342221566832572080821333186116811553615836546984046708975602900950537616475847728421889679646244945160765353408198901385442487984959953319101723355556602139450399736280750137837615307127761926849034352625200015888535147331611702103968175921510907788019393178114194545257223865541461062892187960223838971476088506276862967146674697562911234082439208160153780889893964518263243671616762179168909779911903754031274622289988005195444414282012187361745992642956581746628302955570299024324153181617210465832036786906117260158783520751516284225540265170483304226143974286933061690897968482590125458327168226458066526769958652682272807075781391858178889652208164348344825993266043367660176999612831860788386150279465955131156552036093988180612138558600301435694527224206344631797460594682573103790084024432438465657245014402821885252470935190620929023136493273497565513958720559654228749774011413346962715422845862377387538230483865688976461927383814900140767310446640259899490222221765904339901886018566526485061799702356193897017860040811889729918311021171229845901641921068884387121855646124960798722908519296819372388642614839657382291123125024186649353143970137428531926649875337218940694281434118520158014123344828015051399694290153483077644569099073152433278288269864602789864321139083506217095002597389863554277196742822248757586765752344220207573630569498825087968928162753848863396909959826280956121450994871701244516461260379029309120889086942028510640182154399457156805941872748998094254742173582401063677404595741785160829230135358081840096996372524230560855903700624271243416909004153690105933983835777939410970027753472000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000

関数呼び出しをなくしたらエラーが発生しなくなりましたね。

末尾再帰にする

末尾再帰とは再帰の呼出が関数の(returnをのぞく)最終ステップで実行されている再帰関数という説明が多かったです。
少し補足すると呼出元関数が呼出先関数の返却値を必要としない関数かな。

末尾再帰にすると処理系が内部的にループへ変換してくれるためスタックが消費されず、いくら呼び出しても時間はかかるがスタックオーバーフローなどのエラーが発生しないというメリットがあります。

実装例を見てもらおうと思うんですが、Pythonでは言語レベルの末尾再帰最適化は行われません。
Rubyでは設定をいじると最適化してくれるみたいなので先にそっちをやってみましょう。

使っているRubyのバージョンは2.3.1です。

RubyVM::InstructionSequence.compile_option = {
  tailcall_optimization: true,
  trace_instruction: false
}
 
def fact1(n)
    return 1 if n <= 1
    return n * fact1(n - 1)
end
 
def fact2(n, total=1)
    return total if n <= 1
    return fact2(n - 1, total * n)
end

先頭4行が有効にする設定です。多分上しか必要ないんだろうけどコピペなので許してください。最後のreturnはなくても大丈夫です。
上(fact1)が普通の再帰、下(fact2)が末尾再帰です。
上の関数も呼び出しは末尾にあるため末尾再帰と誤解しがちですが、return時にfact1の実行結果をnとかける必要がある、つまり返却値を待つ必要があるため末尾再帰ではありません。

末尾再帰構造にする上で重要なポイントは最終的に返却すべき値をどこで計算しているか、どのように渡っているかです。
fact1では値の返却時に計算を行うのに対し、fact2では呼び出し時に引数(total)に渡して引き継いでいきます。これを継続渡し形式(CPS)というらしいです。BPSは関係ない。

呼出フェーズの流れ(往路)、返却フェーズの流れ(復路)に分けて表現すると以下のようになります。

非末尾再帰(fact1) 末尾再帰(fact2)
往路
(呼出)
  • 5を渡して呼び出す(基底)
  • 4を渡して呼び出す
  • 3を渡して呼び出す
  • 2を渡して呼び出す
  • 1を渡して呼び出す
  • 5と1を渡して呼び出す
  • 4と5( 1*5 )を渡して呼び出す
  • 3と20( (1*5)*4 )を渡して呼び出す
  • 2と60( ((1*5)*4)*3 )を渡して呼び出す
  • 1と120( (((1*5)*4)*3)*2 )を渡して呼び出す
復路
(返却)
  • 1を返却する
  • 2( 2*1 )を返却する
  • 6( 3*(2*1) )を返却する
  • 24( 4*(3*(2*1)) )を返却する
  • 120( 5*(4*(3*(2*1))) )を返却する
  • 120を返却する
  • 120を返却する
  • 120を返却する
  • 120を返却する
  • 120を返却する

末尾再帰では往路の時点で返却値が求まっています。復路が不要な再帰と言い換えてもよいかもしれません。

実際に最適化されているか確かめてみましょう。

irb(main):015:0> fact1(5)
=> 120
# 非末尾再帰関数はエラー
irb(main):016:0> fact1(20000)
SystemStackError: stack level too deep
	from (irb):8:in 'fact1'
	from (irb):8:in 'fact1'
... 10034 levels...
	from (irb):8:in 'fact1'
	from /usr/bin/irb:11:in '<main>'
# 末尾再帰関数はエラーにならなかった。結果が大きすぎたので省略
irb(main):017:0> fact2(20000)
=> 18192063202303451348276417568664587660716099014787526489180622186345694610385575344538360958277587...

できてましたね。

自分の最適化のイメージです。

最適化前 最適化後
(
 ((5)        * # 5
  ((5-1)     * # 4
   ((4-1)    * # 3
    ((3-1)   * # 2
     ((2-1))   # 1
    )
   )
  )
 )
)
(
 (5)   * # 5
 (5-1) * # 4
 (4-1) * # 3
 (3-1) * # 2
 (2-1)   # 1
)

すべての再帰関数が末尾再帰に変換できるわけではありません。
複数回関数を呼び出していたり、呼び出した結果を参照して何かする必要があったりと末尾化できない理由はいろいろ考えられます。無理に変換してわかりづらくなるくらいならループで書いたほうが無難でしょう。

Pythonで末尾再帰最適化

Pythonでは末尾再帰最適化はできないのかということで調べたらそんな感じのデコレータを実装してくれた方がいるみたいです。

Python3でも動くようにいじったのが以下。まぁ例外のキャッチくらいしか変えてないのだけど。

import sys
 
 
class TailRecurseException(BaseException):
  def __init__(self, args, kwargs):
    self.args = args
    self.kwargs = kwargs
 
 
def tail_call_optimized(g):
  """
  This function decorates a function with tail call
  optimization. It does this by throwing an exception
  if it is it's own grandparent, and catching such
  exceptions to fake the tail call optimization.
 
  This function fails if the decorated
  function recurses in a non-tail context.
  """
  def func(*args, **kwargs):
    f = sys._getframe()
    if f.f_back and f.f_back.f_back \
        and f.f_back.f_back.f_code == f.f_code:
      raise TailRecurseException(args, kwargs)
    else:
      while 1:
        try:
          return g(*args, **kwargs)
        except TailRecurseException as e:
          args = e.args
          kwargs = e.kwargs
  func.__doc__ = g.__doc__
  return func

このデコレータを使った関数を作って実行しましょう。
実行箇所がわかりやすいようにトレースバックも表示してみます。

>>> import traceback
 
>>> def fact1(n):
...     print('-----', n, '-----')
...     for i, line in enumerate(traceback.format_stack()):
...         print(i, line.strip())
...     if n <= 1: return 1
...     return n * fact1(n - 1)
...
>>> @tail_call_optimized
... def fact2(n):
...     print('-----', n, '-----')
...     for i, line in enumerate(traceback.format_stack()):
...         print(i, line.strip())
...     if n <= 1: return 1
...     return n * fact2(n - 1)
...
>>> @tail_call_optimized
... def fact3(n, total=1):
...     print('-----', n, '-----')
...     for i, line in enumerate(traceback.format_stack()):
...         print(i, line.strip())
...     if n <= 1: return total
...     return fact3(n - 1, total * n)
...
 
>>> fact1(5)
----- 5 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 3, in fact1
----- 4 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 6, in fact1
2 File "<stdin>", line 3, in fact1
----- 3 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 6, in fact1
2 File "<stdin>", line 6, in fact1
3 File "<stdin>", line 3, in fact1
----- 2 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 6, in fact1
2 File "<stdin>", line 6, in fact1
3 File "<stdin>", line 6, in fact1
4 File "<stdin>", line 3, in fact1
----- 1 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 6, in fact1
2 File "<stdin>", line 6, in fact1
3 File "<stdin>", line 6, in fact1
4 File "<stdin>", line 6, in fact1
5 File "<stdin>", line 3, in fact1
120
 
>>> fact2(5)
----- 5 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact2
----- 4 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact2
----- 3 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact2
----- 2 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact2
----- 1 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact2
1
 
>>> fact3(5)
----- 5 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact3
----- 4 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact3
----- 3 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact3
----- 2 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact3
----- 1 -----
0 File "<stdin>", line 1, in <module>
1 File "<stdin>", line 19, in func
2 File "<stdin>", line 4, in fact3
120

デコレータを使っていない関数(fact1)は段々と呼出の階層が深くなっていくのに対し、デコレータを使った再帰関数(fact2, fact3)では呼出の階層が常に一定なのがわかるでしょうか。
そして、正しい結果が得られるのは末尾再帰であるfact3だけです。このデコレータは対象関数が末尾再帰になっているかどうかの判定まではしてくれないので非末尾再帰も強制的に最適化してしまいます。結果としてfact2では終了条件の返却値(1)が返ってきてしまいました。
というわけでこれを使うときは正しい末尾再帰関数を書いてあげる必要があります。

ここまで書いておいてなんですが、なんでもかんでも再帰で表現しようとするのは考えものです。
特に末尾再帰でかけるような単方向な繰り返しをわざわざ再帰で表すメリットはあまり大きくないはずです。たとえ最適化してくれるとしてもです。ご利用は計画的に。

でも再帰を使ったほうが便利なケースというのは確かに存在します。少しだけ見てみましょう。

応用

ツリー構造の走査

入れ子になったリストの各要素を走査する関数を考えてみましょう。

schema = [
  'a',
  [
    'b',
    [
      ['c'],
      'd',
      ['e', 'f'],
    ],
    'g',
    ['h'],
    'i',
  ],
  [
    'j',
    ['k', 'l'],
    'm',
    ['n', 'o'],
  ],
  'p'
]

ツリー構造で表すと以下のようになります。
nested_list_graph

ループと再帰、それぞれに対して「深さ優先」「幅優先」の探索視点で計4パターン作ってみました。
ツリーの子要素はforで回し、ツリーを更に掘り進めるためにwhileもしくは再帰による繰り返しを使うのが実装のポイントです。
ちなみに深さ優先とは階層を掘り進めて行き止まりにぶつかったら戻るという方法を繰り返す走査方法です。期待する出力順は「a,b,c,d,e,f,g,h,i,j,k,l,m,n,o,p」です。
対する幅優先とは階層の浅い順に評価する走査方法です。期待する出力順は「a,p,b,g,i,j,m,d,h,k,l,n,o,c,e,f」です

再帰 ループ
深さ優先探索
def walk1(schema):
    for e in schema:
        if isinstance(e, list):
            walk1(e)
        else:
            print(e)
def walk2(schema):
    stack = []
    it = iter(schema)
    while True:
        for e in it:
            if isinstance(e, list):
                stack.append(it)
                it = iter(e)
                break
            else:
                print(e)
        else:
            if stack:
                it = stack.pop()
            else:
                break
>>> walk1(schema)
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
>>> walk2(schema)
a
b
c
d
e
f
g
h
i
j
k
l
m
n
o
p
幅優先探索
def walk3(schema):
    queue = []
    def _walk3(schema):
        for e in schema:
            if isinstance(e, list):
                queue.append(e)
            else:
                print(e)
        while queue:
            _walk3(queue.pop(0))
    return _walk3(schema)
def walk4(schema):
    queue = [iter(schema)]
    while queue:
        it = queue.pop(0)
        for e in it:
            if isinstance(e, list):
                queue.append(iter(e))
            else:
                print(e)
>>> walk3(schema)
a
p
b
g
i
j
m
d
h
k
l
n
o
c
e
f
>>> walk4(schema)
a
p
b
g
i
j
m
d
h
k
l
n
o
c
e
f

だれがどう見ても再帰を使って深さ優先探索した時(walk1)がもっとも簡潔ですよね。

深さ優先探索で最も考慮すべきなのは終端に突き当たった時に前の親に戻るという動作です。
これを実現するためには自前でスタック的(LIFO)な変数を用意する必要があります。実際に上の例(ループ)ではイテレータというオブジェクトをスタックにPUSH/POPするのを手動で制御しています。
(再帰)関数を使うとコールスタックの仕組みによりローカル変数が呼び出し時の状態で自動的に退避されるためうまく活用することで簡潔なコードを書くことができるわけです。
ただ、幅優先で探索しようとすると途中結果を別途保持しないといけません。上記の実装例ではクロージャとして再帰を定義することにしました(walk3)。

ハノイの塔

というものがあります。
説明はWikipediaさまにおまかせします。

以下のルールに従ってすべての円盤を右端の杭に移動させられれば完成。

  • 3本の杭と、中央に穴の開いた大きさの異なる複数の円盤から構成される。
  • 最初はすべての円盤が左端の杭に小さいものが上になるように順に積み重ねられている。
  • 円盤を一回に一枚ずつどれかの杭に移動させることができるが、小さな円盤の上に大きな円盤を乗せることはできない。

n枚の円盤すべてを移動させるには最低 2^n – 1 回の手数がかかる[1]。
解法に再帰的アルゴリズムが有効な問題として有名であり、プログラミングにおける再帰的呼出しの例題としてもよく用いられる。
tower_of_hanoi_4

ということで説明にもある通り、再帰を使うことで簡潔に表現できるケースの一つです。

def hanoi(towers, start='a', end='b', work='c'):
    def _hanoi(n, start, end, work):
        if not n:
            return
        _hanoi(n - 1, start, work, end)  # call 1
        towers[end].append(towers[start].pop())  # move
        _hanoi(n - 1, work, end, start)  # call 2
    _hanoi(len(towers[start]), start, end, work)
    return towers
 
# 「a(開始)」にある円盤3つ(3, 2, 1)を「b(終了)」に移動する。退避場所として「c」を使う。
hanoi({
    'a': [3, 2, 1],
    'b': [],
    'c': [],
})

プログラムとしては上記のような感じにかけるとのこと。ポイントは以下だと思ってます。

  • 常に上2つの円盤の移動について考える
  • 複数ある場合は「一番下」と「それ以外」の2つについて考え、「それ以外」が最上部の1つになるまで細かくしてから移動する(ここを再帰で表現する)
  • 仮に2つの円盤を移動すると以下のような流れになる
    1. 上の円盤をc(退避)に移動する(コメントのcall 1)
    2. 下の円盤をb(終了)に移動する(コメントのmove)
    3. 上の円盤(cに移動した)をb(終了)に移動する(コメントのcall 2)
  • 上記のプログラムでは「下の円盤の移動」は同じ関数内のappend()によって、「上の円盤の移動」は再帰関数(移譲先の関数)によって処理される
  • 下の円盤が一番上を指したら折り返す(再帰の終了条件)

このアルゴリズムを最初に思いついた人は天才ですね。私は凡人なのでちゃんと理解するためにループに直してみます。

と、その前に上のプログラムでは「n」という変数で下の円板の位置を「天辺からの相対位置」を表していましたが、
プログラムを簡潔にするため「底辺からの相対位置」を「limit」という変数に入れ、終了条件として使うことにします。
終了条件だけ直して以下のようになります。注)まだループじゃないです

def hanoi2(towers, start='a', end='b', work='c'):
    limit = len(towers[start])
    def _hanoi(n, start, end, work):
        if n == limit:
            return
        _hanoi(n + 1, start, work, end)  # call 1
        towers[end].append(towers[start].pop())  # move
        # 変数の動きが気になる人は下のコメントアウトを外してみよう!
        # print('towers:', towers)
        _hanoi(n + 1, work, end, start)  # call 2
    _hanoi(0, start, end, work)
    return towers
 
hanoi2({
  'a': [3, 2, 1],
  'b': [],
  'c': [],
})

このhanoi2で3つの円盤を移動する際の変数をトレースすると以下のような流れになります。
hanoi_variable_transition

上記の流れをループに組み込みもうと思ったときに重要なのはスタックフレームという考え方です。
再帰を使ったプログラムでは「start」「end」「work」という変数は関数呼び出しごとに独立しており、プログラムがどこまで実行されたかもスタックフレームが覚えてくれていました。(だから再帰を使うとこんなにシンプルに書けるんですね)

でもループではこれらを自前で管理する必要が出てきます。例えば以下のようにかけます。

def hanoi3(towers, start='a', end='b', work='c'):
    limit = len(towers[start])
    stack = [{'start': start, 'end': end, 'work': work}]
    while stack:
        vars = stack[-1]
        index = vars.get('index', 0)
        if index == 3 or len(stack) > limit:
            stack.pop()
        elif index == 0:
            stack.append({'start': vars['start'], 'end': vars['work'], 'work': vars['end']})
        elif index == 1:
            towers[vars['end']].append(towers[vars['start']].pop())
            # 変数の動きが気になる人は下のコメントアウトを外してみよう!
            # print('towers:', towers)
        elif index == 2:
            stack.append({'start': vars['work'], 'end': vars['end'], 'work': vars['start']})
        vars['index'] = index + 1
    return towers
 
hanoi3({
  'a': [3, 2, 1],
  'b': [],
  'c': [],
})

今回の実装では関数の呼び出しごとに独立していた変数は「stack」というリストに独立した変数空間を表す辞書を積んでいきます。これは円盤の位置と一致します。
実行中の箇所についてはツリー構造の走査ではイテレータを利用していましたが、今回は「index」という変数で管理して分岐処理しています。

以上です。

参考リンク: