Numba加速Python代码

Numba 简介与使用

Numba 是一个即时编译器(JIT),它能够将 Python 代码编译成高效的机器码,从而显著提高数值计算的性能。Numba 特别适合用于加速那些对性能要求较高的数值和科学计算任务,比如 NumPy 数组操作、循环等。本文将介绍如何安装 Numba 以及它的基本用法。

安装 Numba

Numba 可以通过 condapip 来安装。推荐使用 conda,因为它能更好地处理依赖关系。

使用 Conda 安装

1
conda install numba

使用 Pip 安装

如果你更喜欢使用 pip,可以运行以下命令:

1
pip install numba

基本概念

JIT 编译

Numba 的核心功能是其 JIT 编译能力。当你在函数定义前加上 @jit 装饰器时,Numba 会在第一次调用该函数时将其编译为机器码。之后每次调用这个函数都会执行编译后的版本,而不是解释执行 Python 代码。

函数签名

你可以选择是否给定函数签名。如果提供了函数签名,那么 Numba 将只接受指定类型的参数,并且可以在编译时进行更多的优化。如果不提供签名,Numba 会根据实际传入的参数类型来决定如何编译。

缓存

Numba 支持缓存已编译的函数,这样即使你重新启动 Python 解释器,也不需要再次编译相同的函数。这可以通过将 cache=True 作为装饰器的一个参数来启用。

基本用法

简单的例子

下面是一个简单的例子,展示了如何使用 Numba 加速一个求平方根总和的函数。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from numba import jit
import numpy as np

# 不带函数签名的 @jit 装饰器
@jit(nopython=True) # nopython 模式可以实现最佳性能
def sum_of_sqrts(n):
total = 0.0
for i in range(n):
total += np.sqrt(i)
return total

# 测试函数
n = 10000000
print("Sum of square roots:", sum_of_sqrts(n))

提供函数签名

如果你想为函数提供签名,可以像下面这样写:

1
2
3
4
5
6
7
8
9
10
11
12
13
from numba import jit, float64, int32

# 带有函数签名的 @jit 装饰器
@jit(float64(int32), nopython=True, cache=True)
def sum_of_sqrts_with_signature(n):
total = 0.0
for i in range(n):
total += np.sqrt(i)
return total

# 测试带有签名的函数
n = 10000000
print("Sum of square roots with signature:", sum_of_sqrts_with_signature(n))

并行化

对于支持并行化的循环,Numba 提供了 parallel=True 参数。当设置了这个参数后,Numba 会尝试自动并行化循环中的迭代。

1
2
3
4
5
6
7
8
9
10
11
12
from numba import prange

@jit(nopython=True, parallel=True)
def sum_of_sqrts_parallel(n):
total = 0.0
for i in prange(n): # 使用 prange 替代 range
total += np.sqrt(i)
return total

# 测试并行化的函数
n = 10000000
print("Parallel sum of square roots:", sum_of_sqrts_parallel(n))

高级特性

GPU 加速

Numba 还支持 CUDA GPU 编程,允许你编写可以直接在 NVIDIA GPU 上运行的 Python 代码。要使用这个功能,你需要安装 numba[cuda] 和相应的 CUDA 工具包。

1
conda install numba cudatoolkit

然后你可以使用 @cuda.jit 装饰器来编译 CUDA 内核。

其他装饰器

除了 @jit 之外,Numba 还提供了其他几个有用的装饰器,例如:

  • @vectorize:用于创建通用函数(ufuncs),这些函数可以应用于整个数组。
  • @guvectorize:用于创建广义通用函数,可以接受多个输入数组并产生一个输出数组。
  • @stencil:用于定义卷积或模板操作,通常用于图像处理等领域。

性能提示

  • 避免全局变量:Numba 在 nopython 模式下无法访问全局变量。应该将所有必要的数据作为参数传递给函数。
  • 最小化 Python API 调用:尽量减少对 Python C API 的调用,因为它们可能会导致性能下降。
  • 使用合适的数据类型:明确指定数据类型可以帮助 Numba 更好地优化代码。
  • 考虑并行化:对于独立的循环迭代,开启并行化可以带来显著的性能提升。

实践:Numba加速冒泡排序

冒泡排序(Bubble Sort)是一种简单的排序算法。它重复地遍历要排序的数列,一次比较两个元素,如果它们的顺序错误就交换它们的位置。这个过程会持续进行,直到没有需要再交换的元素为止。本文将介绍如何使用纯 Python 实现冒泡排序,并提供代码示例和性能分析。

算法原理

  1. 比较相邻元素:从列表的第一个元素开始,依次比较每对相邻元素。
  2. 交换位置:如果前一个元素大于后一个元素,则交换它们的位置。
  3. 一轮遍历:对于每一趟遍历,最大的元素会被移动到列表的最后面。
  4. 重复过程:重复上述步骤,对于每次遍历,不考虑已经被放置在正确位置的元素。
  5. 优化:如果某一轮遍历中没有任何交换发生,那么列表已经排序完成,可以提前结束算法。

纯Python实现

以下是用纯 Python 编写的冒泡排序函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
def bubble_sort(arr):
n = len(arr)
# 遍历所有数组元素
for i in range(n):
# 标记是否发生了交换
swapped = False
# 最后i个元素已经是排好序的
for j in range(0, n-i-1):
# 交换如果元素找到的顺序是错误的
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
swapped = True
# 如果没有发生交换,说明数组已经排好序了
if not swapped:
break
return arr

使用Numba加速排序过程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
@jit
def bubble_sort_numba(arr):
n = len(arr)
# 遍历所有数组元素
for i in range(n):
# 标记是否发生了交换
swapped = False
# 最后i个元素已经是排好序的
for j in range(0, n-i-1):
# 交换如果元素找到的顺序是错误的
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
swapped = True
# 如果没有发生交换,说明数组已经排好序了
if not swapped:
break
return arr

完整代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
from time import time
from numba import jit
from random import uniform
def bubble_sort(arr):
n = len(arr)
# 遍历所有数组元素
for i in range(n):
# 标记是否发生了交换
swapped = False
# 最后i个元素已经是排好序的
for j in range(0, n-i-1):
# 交换如果元素找到的顺序是错误的
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
swapped = True
# 如果没有发生交换,说明数组已经排好序了
if not swapped:
break
return arr

@jit
def bubble_sort_numba(arr):
n = len(arr)
# 遍历所有数组元素
for i in range(n):
# 标记是否发生了交换
swapped = False
# 最后i个元素已经是排好序的
for j in range(0, n-i-1):
# 交换如果元素找到的顺序是错误的
if arr[j] > arr[j+1]:
arr[j], arr[j+1] = arr[j+1], arr[j]
swapped = True
# 如果没有发生交换,说明数组已经排好序了
if not swapped:
break
return arr

if __name__=="__main__":
length=10000

numbers1=[uniform(-2*length,2*length) for i in range(length)]
numbers2=numbers1.copy()

a=time()
temp=bubble_sort(numbers1)
b=time()
print(f"Pure python costs {b-a}s")

a=time()
temp=bubble_sort_numba(numbers2)
b=time()
print(f"Python with numba costs {b-a}s")

性能分析

  • 时间复杂度:最坏情况下(当列表完全逆序时),冒泡排序的时间复杂度为 (O(n^2)),其中 (n) 是列表的长度。最好情况下(当列表已经有序时),由于我们添加了一个 swapped 标志来检测是否进行了交换,所以时间复杂度可以达到 (O(n))。
  • 空间复杂度:冒泡排序是一个原地排序算法,它的空间复杂度为 (O(1)),因为它只需要一个额外的空间用于临时变量。
  • 稳定性:冒泡排序是一个稳定的排序算法,即相等元素的相对顺序不会改变。

Numba加速Python代码
http://example.com/2024/12/05/Numba加速Python代码/
作者
Morningstars
发布于
2024年12月5日
许可协议