Python3 左バイナリ法の実装

こんにちは、kisseです。

大学で、「Pythonで左バイナリ法の実装してもらうから、Python勉強してきてね。」っていわれたので、左バイナリ法の実装までしちゃおうと思います。笑

左バイナリ法とは??

wikipediaさんに書いてありました。笑
左バイナリ法とは、指数を2進表記して上位ビットから順に評価を行っていく方式です。




実装した

Githubに同じコード置いたので、ダウンロードする際はそっちからのが便利かなと思います。
(このとき、どっちが左か右か分からなかったので、Github上ではupper_binary()とlower_binary()っていう関数名でごまかしています。笑笑
左バイナリ法は指数の上位ビットから順に評価しているのでupper_binary()としました。)
以下のような、コードを記述しました。

def left_binary(a, k):
    
    # 特殊な場合はあらかじめ省いておこう。
    if not isinstance(k, int):
        raise Exception("exponent must be int.")
    
    if k == 0:
        return 1
    elif k < 0:
        raise Exception("left_binary() does'nt work support when exponent smaller than zero.")
    
    # 最上位ビットが1のマスクを作成
    mask = 1 << k.bit_length() - 1
    
    ans = 1
    
    # ビットマスクが移動し終わるまで、繰り返す。
    while mask:
        ans *= ans
        
        if k & mask:
            ans *= a
        
        # ビットマスクを移動
        mask >>= 1
        
    return ans

関数内のはじめの方は、引数のバリデーションを行っています。
遠慮なく例外投げていきます。笑

指数を2進表記した際のビット列長をbit_length()で取得しました。
1を(指数ビット列長 – 1)だけ左シフトすると、最上位ビットのみが1のビットマスクが出来上がりますね。

答えの初期値を1として、処理をはじめます。
ビットマスクを右にシフトしていくのですが、最終的にビットマスクが0になったら処理を終了するループに入ります。

ループ内では、現時点での答えを2乗して、ビットマスクと指数のAND演算が0でないときには、底をかけてあげる処理を行います。
この処理を行うことによって指数計算を高速に行うことができるそうです。

左バイナリ法の計算の妥当性を雑に説明

拙い説明なので読み飛ばしてもらっても構いません。知らなくても実装できます!(実際僕も理解したのは、実装してからしばらく経ってからです。笑)

例えば、3の5乗などの計算について考えましょう。
(当ブログでは数式を表示する機能は導入してません。ごめんなさい。m(._.)m)

5を2進表記すると、0b101となります。
では、左バイナリ法の手順に沿って、計算を行います。

5の最上位のビットは1ですね。
最初に現在の答えの2乗を行います。
答えの初期値は1なので、1 * 1 = 1です。
次に1 * 3 = 3となります。

指数の次のビットは0です。
ここでは、答えの2乗のみ行います。
3 * 3 = 9です。

最下位のビットは1です。
ここでは、答えの2乗を行ったあとに、底を掛け合わせます。
9 * 9 = 81。
81 * 3 = 243です。

ちゃんと答えに辿りつきました。

しかし、なぜこの手順で答えが出るのでしょうか??

指数の性質を思いだしながら確認します。
指数を2倍するのは、答えを2乗するのと等しいです。
例えば(3の1乗 * 3の1乗) = 3の2乗ですよね。

ところで、2進表記の数字を1つ左にシフトするというのは、どのような処理でしょうか?
たとえば、0b001を1ビット左シフトして、0b010にすると値は2倍になります。(1 -> 2)
2進数を1ビット左シフトすることは、値を2倍することに等しいのです。

上記の2点を合わせると、指数を1ビット左シフトすると、結果はその2乗になるということになります。

また、指数を1増やすと、結果はそれまでの結果に底を掛けたものになります。

以上の点を踏まえてもう一度3の5乗について考えてみましょう。

初期状態は3の0乗です。
現在の指数0を1ビット左シフトすると同時に、その結果を2乗します。(指数も結果も変わらない。)
次に、5の最上位ビット1を指数に加えると同時に、結果を3倍します。
すると、現在の指数は1、結果は3となります。

次に、現在の指数1(0b001)を1ビット左シフトすると同時に、その結果を2乗します。
すると指数は2(0b010)、結果は9となります。
5の上位から2番目のビットは0なので、ここでは何もしません。

もう一度、現在の指数2(0b010)を1ビット左シフトすると同時に、その結果を2乗します。
すると、指数は4(0b100)、結果は81となります。
5の最下位ビットは1なので、指数に1を加えて結果を3倍します。
すると、指数は5(0b101)、結果は243となります。

ちゃんと計算できました!
指数を1ビット左シフトすると結果は2乗になる点と、指数を1増加させると底を1回掛けることを利用して、左バイナリ法は実装されてたんですね。

剰余計算ありバージョン

ほとんど変更はないのですが、乗算処理の際にその乗算結果に対して、剰余計算を行うと最終結果も剰余計算を行ったものと等しくなります。

# ベキ乗の剰余計算
def upper_binary_mod(base, exponent, mod):
    # 特殊な場合はあらかじめ省いておこう。
    if not isinstance(exponent, int):
        raise Exception("exponent must be int.")
    
    if exponent == 0:
        return 1
    elif exponent < 0:
        raise Exception("upper_binary() does'nt work support when exponent smaller than zero.")
        
    if not isinstance(mod, int):
        raise Exception("mod must be int.")
        
    if mod <= 0:
        raise Exception("mod must be larger than zero.")
        
    # 最上位ビットが1のマスクを作成
    mask = 1 << exponent.bit_length() - 1
    
    ans = 1
    
    # ビットマスクが移動し終わるまで、繰り返す。
    while mask:
        ans = (ans * ans) % mod
        
        if exponent & mask:
            ans = (ans * base) % mod
        
        # ビットマスクを移動
        mask >>= 1
        
    return ans

これとかは、RSA暗号っていう暗号でよく使われるやつらしいです。




おわり

雑な説明ですが、左バイナリ法の実装と説明を行いました。
特に、説明の部分では雰囲気掴んでもらえたら嬉しいです。
(実際、雰囲気でやってる。)

最後まで読んでいただきありがとうございます!

あわせて読みたい