tkm2261's blog

研究員(OR屋) → データ分析官 → MLエンジニア → ニートがデータ分析諸々書いてます

Python高速化 Numba入門 その1

みなさん、こんにちは
今日からPython高速化 Numbaに入門したいと思います。

入門資料を探しに来た皆様すみませんが、 本記事は私がこれから入門する内容になります。

結果として入門資料に慣れば幸いですが、過度な期待は御無用でお願いします。

基本的には以下を読み進めて行きます。

http://numba.pydata.org/

Numbaとは

JIT(just-in-time)コンパイラを使ってPythonを高速化しよう!』というPythonモジュールです。

LLVMコンパイラを使っており、これはJuliaが高速な理由でもあるので期待大です。

学生時代はCythonを使って高速化をよくしていましたが、以下の理由により今回はNumbaを学びます

候補 今回諦めた理由
Cython cdefとか結構手を入れるのでPythonに戻すのが面倒。pyximportも面倒
C拡張 C言語は極力触りたくない
PyPy numpyサポートが怪しい。いつか触るかも
Julia なるべくPythonでやりたい。これでダメなら入門するかも。

Numbaはどうやらデコレータ一発で一応動くらしい。Cythonよりは使いやすいことを期待したい。

とりあえず通常Pythonと比較

以下のサイトのコードを参考に速度を計測

Numba vs Cython

ちょっと手を加えて、実行したのが以下のコードで比較

# coding: utf-8
import numpy
from numba import double
from numba.decorators import jit
import time

@jit
def pairwise_numba(X, D):
    M = X.shape[0]
    N = X.shape[1]
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = numpy.sqrt(d)


def pairwise_python(X, D):
    M = X.shape[0]
    N = X.shape[1]
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = numpy.sqrt(d)


if __name__ == '__main__':
    t = time.time()
    X = numpy.random.random((1000, 3))
    D = numpy.empty((1000, 1000))
    pairwise_python(X, D)
    print "python:", time.time() - t

    t = time.time()
    X = numpy.random.random((1000, 3))
    D = numpy.empty((1000, 1000))
    pairwise_numba(X, D)
    print "numba:", time.time() - t

結果は

python: 4.64800000191
numba: 0.137000083923

34倍の高速化!イケるやん!

次回はQuick Startを読む予定