Python高速化 Numba入門 その2
今回は、QuickStartを読んでいきます。
Quick Start — numba 0.15.1 documentation
とりあえず、前回の@jitデコレータだけで動くのは理解した。
from numba import jit @jit def sum(x, y): return x + y
引数と戻り値の型が陽にわかっている場合には、@jitの引数に『戻り値の型』(引数1の型, 引数2の型, ...)で指定できる。恐らく早くなるのだろう
@jit('f8(f8,f8)') def sum(x, y): return x + y
numpyの配列もサポート。多次元配列はこんな感じf8[:, :, :]
@jit('f8(f8[:])') def sum1d(array): sum = 0.0 for i in range(array.shape[0]): sum += array[i] return sum
サポートされている型は以下。主要なものは大体ある。
Type Name | Alias | Result Type |
---|---|---|
boolean | b1 | uint8 (char) |
bool_ | b1 | uint8 (char) |
byte | u1 | unsigned char |
uint8 | u1 | uint8 (char) |
uint16 | u2 | uint16 |
uint32 | u4 | uint32 |
uint64 | u8 | uint64 |
char | i1 | signed char |
int8 | i1 | int8 (char) |
int16 | i2 | int16 |
int32 | i4 | int32 |
int64 | i8 | int64 |
float_ | f4 | float32 |
float32 | f4 | float32 |
double | f8 | float64 |
float64 | f8 | float64 |
complex64 | c8 | float complex |
complex128 | c16 | double complex |
型は文字列で指定しなくても、importして指定もできる。ハードコーディングにならないのでこっちのが嬉しい。
from numba import jit, f8 @jit(f8(f8[:])) def sum1d(array): ...
型がわからない場合は、通常のPythonで処理するので異常に遅くなってしまう。それを避けるためにPythonでの処理を禁止することもできる。
どうしても型がわからない場合はエラー吐くのかな?
@jit(nopython=True) def sum1d(array): ...
Numbaの型推論の結果を取得するには、inspect_typesメソッドが用意されている。これを参考にデコレータの引数を決めるのも良さそう
sum1d.inspect_types()
実際にやってみる
とりあえず、型がわからない時に、nopython=Trueするとどうなるか
from numba.decorators import jit import numpy class DummyClass(object): def __init__(self): hoge = 0 huga = 2 hogo = "hogehoge" @jit('f8[:, :](f8[:, :], f8[:, :])', nopython=True) def pairwise_numba3(X, D): M = X.shape[0] N = X.shape[1] for i in range(M): for j in range(M): d = 0.0 for k in range(N): tmp = X[i, k] - X[j, k] d += tmp * tmp D[i, j] = numpy.sqrt(d) return DummyClass()
結果は、やはり推論出来ない型でnopython=Trueはエラーらしい
TypingError: Failed at nopython frontend Untyped global name 'DummyClass'
型指定したほうが速いのか?
検証のために以下を実行
# coding: utf-8 import numpy from numba import double from numba.decorators import jit import time import traceback @jit def pairwise_numba(X, D): M = X.shape[0] N = X.shape[1] for i in range(M): for j in range(M): d = 0.0 for k in range(N): tmp = X[i, k] - X[j, k] d += tmp * tmp D[i, j] = numpy.sqrt(d) return D @jit('f8[:, :](f8[:, :], f8[:, :])', nopython=True) def pairwise_numba2(X, D): M = X.shape[0] N = X.shape[1] for i in range(M): for j in range(M): d = 0.0 for k in range(N): tmp = X[i, k] - X[j, k] d += tmp * tmp D[i, j] = numpy.sqrt(d) return D if __name__ == '__main__': t = time.time() X = numpy.random.random((1000, 3)) D = numpy.empty((1000, 1000)) pairwise_numba(X, D) print "numba:", time.time() - t t = time.time() X = numpy.random.random((1000, 3)) D = numpy.empty((1000, 1000)) pairwise_numba2(X, D) print "numba2:", time.time() - t
結果は・・・
numba: 0.134999990463 numba2: 0.0120000839233
約10倍高速化!
前回、普通のPythonから33倍高速になったので、型指定まですると330倍高速化したことに。
Numba結構イケるやん!
次回は、ユーザーガイドの続きか、First Steps with numbaを読む予定