読者です 読者をやめる 読者になる 読者になる

ゆとりデータサイエンティストの諸々所感

データ分析会社で研究開発をしている、ゆとり世代データサイエンティストが学んだ内容や最新トピックについて諸々語る予定

Python高速化 Numba入門 その2

今回は、QuickStartを読んでいきます。

Quick Start — numba 0.15.1 documentation

とりあえず、前回の@jitデコレータだけで動くのは理解した。

from numba import jit

@jit
def sum(x, y):
    return x + y

引数と戻り値の型が陽にわかっている場合には、@jitの引数に『戻り値の型』(引数1の型, 引数2の型, ...)で指定できる。恐らく早くなるのだろう

@jit('f8(f8,f8)')
def sum(x, y):
    return x + y

numpyの配列もサポート。多次元配列はこんな感じf8[:, :, :]

@jit('f8(f8[:])')
def sum1d(array):
    sum = 0.0
    for i in range(array.shape[0]):
        sum += array[i]
    return sum

サポートされている型は以下。主要なものは大体ある。

Type Name Alias Result Type
boolean b1 uint8 (char)
bool_ b1 uint8 (char)
byte u1 unsigned char
uint8 u1 uint8 (char)
uint16 u2 uint16
uint32 u4 uint32
uint64 u8 uint64
char i1 signed char
int8 i1 int8 (char)
int16 i2 int16
int32 i4 int32
int64 i8 int64
float_ f4 float32
float32 f4 float32
double f8 float64
float64 f8 float64
complex64 c8 float complex
complex128 c16 double complex

型は文字列で指定しなくても、importして指定もできる。ハードコーディングにならないのでこっちのが嬉しい。

from numba import jit, f8

@jit(f8(f8[:]))
def sum1d(array):
    ...

型がわからない場合は、通常のPythonで処理するので異常に遅くなってしまう。それを避けるためにPythonでの処理を禁止することもできる。
どうしても型がわからない場合はエラー吐くのかな?

@jit(nopython=True)
def sum1d(array):
    ...

Numbaの型推論の結果を取得するには、inspect_typesメソッドが用意されている。これを参考にデコレータの引数を決めるのも良さそう

sum1d.inspect_types()

実際にやってみる

とりあえず、型がわからない時に、nopython=Trueするとどうなるか

from numba.decorators import jit
import numpy

class DummyClass(object):
    def __init__(self):
        hoge = 0
        huga = 2
        hogo = "hogehoge"

@jit('f8[:, :](f8[:, :], f8[:, :])', nopython=True)
def pairwise_numba3(X, D):
    M = X.shape[0]
    N = X.shape[1]
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = numpy.sqrt(d)

    return DummyClass()

結果は、やはり推論出来ない型でnopython=Trueはエラーらしい

TypingError: Failed at nopython frontend
Untyped global name 'DummyClass'

型指定したほうが速いのか?

検証のために以下を実行

# coding: utf-8
import numpy
from numba import double
from numba.decorators import jit
import time
import traceback

@jit
def pairwise_numba(X, D):
    M = X.shape[0]
    N = X.shape[1]
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = numpy.sqrt(d)

    return D

@jit('f8[:, :](f8[:, :], f8[:, :])', nopython=True)
def pairwise_numba2(X, D):
    M = X.shape[0]
    N = X.shape[1]
    for i in range(M):
        for j in range(M):
            d = 0.0
            for k in range(N):
                tmp = X[i, k] - X[j, k]
                d += tmp * tmp
            D[i, j] = numpy.sqrt(d)

    return D

if __name__ == '__main__':

    t = time.time()
    X = numpy.random.random((1000, 3))
    D = numpy.empty((1000, 1000))
    pairwise_numba(X, D)
    print "numba:", time.time() - t

    t = time.time()
    X = numpy.random.random((1000, 3))
    D = numpy.empty((1000, 1000))
    pairwise_numba2(X, D)
    print "numba2:", time.time() - t

結果は・・・

numba: 0.134999990463
numba2: 0.0120000839233

10倍高速化!

前回、普通のPythonから33倍高速になったので、型指定まですると330倍高速化したことに。

Numba結構イケるやん!

次回は、ユーザーガイドの続きか、First Steps with numbaを読む予定