最有效的方法不是使用 np.setdiff1d 和 np.in1d,而是删除具有唯一值的一维数组的公共值
我需要一个更快的代码来删除一个一维数组(数组长度~10-15)中与另一个一维数组(数组长度~1e5-5e5--> rarely up to 7e5)共有的值,这些数组的索引数组包含整数。数组中没有重复的,而且它们没有被排序,修改后的数值顺序必须在主数组中保持。我知道可以用这样的np.setdiff1d
或np.in1d
来实现(这两个都不支持无python模式下的numba jitted),其他类似的帖子(比如这个)也没有什么更有效的方法,但这里的性能很重要,因为主索引数组中的所有值都会在循环中逐渐被删除。
import numpy as np
import numba as nb
n = 500000
r = 10
arr1 = np.random.permutation(n)
arr2 = np.random.randint(0, n, r)
# @nb.jit
def setdif1d_np(a, b):
return np.setdiff1d(a, b, assume_unique=True)
# @nb.jit
def setdif1d_in1d_np(a, b):
return a[~np.in1d(a, b)]
还有一个相关的帖子,由norok2提出的二维数组的解决方案(使用numba的类似散列的方式)比那里描述的通常方法快15倍。如果能够为一维数组准备,这个解决方案可能是最好的。
@nb.njit
def mul_xor_hash(arr, init=65537, k=37):
result = init
for x in arr.view(np.uint64):
result = (result * k) ^ x
return result
@nb.njit
def setdiff2d_nb(arr1, arr2):
# : build `delta` set using hashes
delta = {mul_xor_hash(arr2[0])}
for i in range(1, arr2.shape[0]):
delta.add(mul_xor_hash(arr2[i]))
# : compute the size of the result
n = 0
for i in range(arr1.shape[0]):
if mul_xor_hash(arr1[i]) not in delta:
n += 1
# : build the result
result = np.empty((n, arr1.shape[-1]), dtype=arr1.dtype)
j = 0
for i in range(arr1.shape[0]):
if mul_xor_hash(arr1[i]) not in delta:
result[j] = arr1[i]
j += 1
return result
我试图为一维数组做准备,但我有一些问题/疑惑。
- 一开始,IDU的
mul_xor_hash
到底是做什么的,以及init
和k
是否是任意选择的呢? - 为什么没有
nb.njit
,mul_xor_hash
就不能工作。
File "C:/Users/Ali/Desktop/test - Copy - Copy.py", line 21, in mul_xor_hash
result = (result * k) ^ x
TypeError: ufunc 'bitwise_xor' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
- 我不知道如何在一维数组上实现
mul_xor_hash
(如果可以的话),我想这可能会使它比二维的速度更快,所以我通过[None, :]
将输入数组广播到二维,这时仅对arr2
得到了以下错误。
print(mul_xor_hash(arr2[0]))
ValueError: new type not compatible with array
- 而
delta
是做什么的?
我正在寻找这方面的最有效方法。在没有比norok2解决方案更好的方法的情况下,如何为一维数组准备这种解决方案?
setdiff1d
的numpy
代码只是做了你的第2个版本所做的事--使用in1d
。你自己看一下吧。在一个数组小得多的情况下,in1d
为arr2
的每个元素做了一个mask |= (arr1 == a)
。对于一个大的尺寸差异,这被认为比排序更快。很容易读懂in1d
的代码。
- hpaulj 2022-10-14
了解基于哈希值的解决方案
首先,IDU的mul_xor_hash到底是做什么的,init和k是否是任意选择的?
mul_xor_hash
是一个自定义的哈希函数。众所周知,混合了xor和multiply(可能还有shifts)的函数在计算原始数据缓冲区的哈希值时相对较快。乘法倾向于洗位,而xor则用于以某种方式将结果组合/累积到一个固定大小的小数值中(即最终的哈希值)。有许多不同的散列函数。有些比其他的快,有些在给定的情况下比其他的引起更多的碰撞。一个快速的散列函数导致太多的碰撞,在实践中可能是无用的,因为它将导致一种病态的情况,所有冲突的值都需要被比较。这就是为什么快速散列函数难以实现的原因。
init
和k
是参数肯定会导致哈希值相当平衡。这在这样的哈希函数中是很常见的。k
需要足够大的乘法来洗刷比特,它通常也应该是一个素数(由于模块化的算术行为,像二的幂这样的值往往会增加碰撞)。init
只对非常小的数组(例如有1个项目)起重要作用:它有助于通过对最终散列的非三常数进行异或来减少碰撞。事实上,如果arr.size = 1
,那么result = (init * k) ^ arr[0]
,其中init * k
是一个常数。众所周知,拥有一个等于arr[0]
的身份哈希函数是不好的,因为它往往会导致许多碰撞(这是一个复杂的话题,但简单地说,arr[0]
可以除以哈希表中的桶的数量)。因此,init
应该是一个相对大的数字,init * k
也应该是一个大的非琐碎的值(一个质数是一个好的目标值)。
为什么mul_xor_hash在没有nb.njit的情况下不能工作
这取决于输入的情况。输入需要是一个一维数组,其原始大小为能被8除以的字节(例如,64位的项目,2n x 32位的,4n x 16位的或8n 8位的)。下面是一些例子。
mul_xor_hash(np.random.rand(10))
mul_xor_hash(np.arange(10)) # Do not work with 9
delta是做什么的?
它是一个set
,包含arr2
行的哈希值,所以要比没有哈希值的比较更快地找到匹配的行。
如何为一维数组准备这种解决方案?
AFAIK,哈希值只用于避免行的比较,但这是因为输入的是二维数组。在一维中,不存在这样的问题。
这个方法有一个很大的缺陷:它只在没有哈希碰撞的情况下才会起作用。否则,实现就会错误地认为数值是相等的,即使它们不是相等的。虽然@norok在评论中明确提到了这一点。
请注意,哈希值的碰撞处理也应该被实现。
更快的执行速度
将@norok2的二维解决方案用于一维并不是一个好主意,因为哈希值的使用方式不会使其更快。事实上,一个set
已经在内部使用了哈希函数。更不用说碰撞需要适当的实现(这是由set
完成的)。
使用set
是一个相对较好的主意,因为它使n = len(arr1)
和m = len(arr2)
的复杂度达到了O(n + m)
。也就是说,如果arr1
被转换为set
,那么它将会太大,无法装入L1缓存(由于你的例子中arr1
的大小),导致缓慢的缓存错过。此外,set
的大小不断增加将导致数值被重新洗牌,这并不高效。如果arr2
被转换为set
,那么许多哈希表的获取将不是很有效,因为arr2
在你的情况下非常小。这就是为什么这个解决方案是次优的。
一个解决方案是将arr1
分块,然后根据目标分块建立一个set
。然后你可以有效地检查一个值是否在这个集合中。由于体积越来越大,构建集合的效率仍然不高。这个问题是由Python本身造成的,它没有像其他语言(例如C++)那样为数据结构保留一些空间。避免这个问题的一个解决方案是简单地重新实现一个哈希表,这并不简单,而且很麻烦。实际上,Bloom过滤器可以用来加速这个过程,因为它们可以快速找到两个集合arr1
和arr2
之间是否平均没有碰撞(尽管它们的实现并不简单)。
另一个优化是使用多个线程来并行计算这些块,因为它们是独立的。也就是说,向最终数组的追加并不容易有效地进行,特别是由于你不希望顺序被修改。一个解决方案是将拷贝从并行循环中移开,并以串行方式进行,但这很慢,而且AFAIK目前在Numba中没有简单的方法可以做到这一点(因为并行层非常有限)。考虑使用本地语言如C/C++来实现高效的并行。
最后,散列可能是相当复杂的,与带有两个嵌套循环的天真实现相比,速度可能相当小,因为arr2
只有少数项目,而且现代处理器可以使用SIMD指令快速比较值(而基于散列的方法在主流处理器上几乎无法受益)。解卷可以帮助编写一个相当简单和快速的实现。同样,不幸的是,Numba内部使用的是LLVM-Jit,它似乎无法将如此简单的代码矢量化(当然是由于LLVM-Jit甚至LLVM本身的优化缺失)。因此,非矢量化的代码最终会慢一点(而不是在现代主流处理器上快4~10倍)。一个解决方案是用C/C++代码代替来做这个事情(也可能是Cython)。
这里是一个使用基本的Bloom过滤器的串行实现。
@nb.njit('uint32(int32)')
def hash_32bit_4k(value):
return (np.uint32(value) * np.uint32(27_644_437)) & np.uint32(0x0FFF)
@nb.njit(['int32[:](int32[:], int32[:])', 'int32[:](int32[::1], int32[::1])'])
def setdiff1d_nb_faster(arr1, arr2):
out = np.empty_like(arr1)
bloomFilter = np.zeros(4096, dtype=np.uint8)
for j in range(arr2.size):
bloomFilter[hash_32bit_4k(arr2[j])] = True
cur = 0
for i in range(arr1.size):
# If the bloom-filter value is true, we know arr1[i] is not in arr2.
# Otherwise, there is maybe a false positive (conflict) and we need to check to be sure.
if bloomFilter[hash_32bit_4k(arr1[i])] and arr1[i] in arr2:
continue
out[cur] = arr1[i]
cur += 1
return out[:cur]
这里有一个未经测试的变体,它应该适用于64位整数(浮点数需要内存视图,可能也需要一个质数常数)。
@nb.njit('uint64(int64)')
def hash_32bit_4k(value):
return (np.uint64(value) * np.uint64(67_280_421_310_721)) & np.uint64(0x0FFF)
请注意,如果小数组中的所有值都包含在每个循环的主数组中,那么我们可以通过在发现值时从arr2
中删除来加快arr1[i] in arr2
部分的速度。也就是说,碰撞和发现应该是非常罕见的,所以我并不指望这能明显加快(更不用说它增加了一些开销和复杂性)。如果项目是分块计算的,那么最后的分块可以直接复制,不需要任何检查,但好处应该还是比较小的。请注意,这种策略对之前提到的天真的(C/C++)SIMD实现倒是很有效(可以快2倍左右)。
通用化和并行实现
本节重点讨论了关于输入大小的算法。它特别详细介绍了基于SIMD的实现,并讨论了多线程的使用。
首先,关于数值r
,使用的最佳算法可能是不同的。更具体地说。
- 当
r
为0时,最好的做法是返回未经修改的输入数组arr1
(可能是一个副本,以避免与原位算法有关的问题)。 - 当
r
为1时,我们可以使用一个基本的循环对数组进行迭代,但最好的实现可能是使用Numpy的np.where
,它对此进行了高度的优化 - 当
r
很小的时候,比如<10,那么使用基于SIMD的实现应该是特别有效的,特别是当基于arr2
的循环的迭代范围在编译时是已知的,并且是解卷的时候。 - 对于较大的
r
值仍然相对较小(例如r < 1000
和r << n
),所提供的基于哈希值的解决方案应该是最好的解决方案之一。 - 对于较大的
r
值和r << n
,基于散列的解决方案可以通过将布尔值打包成bloomFilter
中的比特,以及使用多个散列函数而不是一个,以便更好地处理碰撞,同时更有利于缓存(事实上,这就是实际的Bloom过滤器的作用);注意,当r
很大和r << n
时,可以使用多线程来加快查找速度。 - 当
r
很大并且不比n
小多少时,那么这个问题就很难有效解决,最好的解决办法当然是对两个数组进行排序(通常是用radix排序),并使用基于合并的方法来删除重复的内容,当r
和n
都很大时,可能会使用多个线程(很难实现)。
我们先来看看基于SIMD的解决方案。下面是一个实现。
@nb.njit('int32[:](int32[::1], int32[::1])')
def setdiff1d_nb_simd(arr1, arr2):
out = np.empty_like(arr1)
limit = arr1.size // 4 * 4
limit2 = arr2.size // 2 * 2
cur = 0
z32 = np.int32(0)
# Tile (x4) based computation
for i in range(0, limit, 4):
f0, f1, f2, f3 = z32, z32, z32, z32
v0, v1, v2, v3 = arr1[i], arr1[i+1], arr1[i+2], arr1[i+3]
# Unrolled (x2) loop searching for a match in `arr2`
for j in range(0, limit2, 2):
val1 = arr2[j]
val2 = arr2[j+1]
f0 += (v0 == val1) + (v0 == val2)
f1 += (v1 == val1) + (v1 == val2)
f2 += (v2 == val1) + (v2 == val2)
f3 += (v3 == val1) + (v3 == val2)
# Remainder of the previous loop
if limit2 != arr2.size:
val = arr2[arr2.size-1]
f0 += v0 == val
f1 += v1 == val
f2 += v2 == val
f3 += v3 == val
if f0 == 0: out[cur] = arr1[i+0]; cur += 1
if f1 == 0: out[cur] = arr1[i+1]; cur += 1
if f2 == 0: out[cur] = arr1[i+2]; cur += 1
if f3 == 0: out[cur] = arr1[i+3]; cur += 1
# Remainder
for i in range(limit, arr1.size):
if arr1[i] not in arr2:
out[cur] = arr1[i]
cur += 1
return out[:cur]
事实证明,在我的机器上,这个实现总是比基于哈希的实现慢,因为Numba显然为基于arr2
的内部循环产生了低效率,这似乎来自与==
有关的破碎优化。Numba在这个操作中根本没有使用SIMD指令(没有明显的原因)。这使得许多与SIMD相关的替代代码只要使用Numba就会变得很快。
Numba的另一个问题是np.where
很慢,因为它使用的是天真的实现,而Numpy的实现已经被大量优化。由于前面的问题,Numpy中的优化很难被应用到Numba的实现中。这使得在Numba代码中使用np.where
的速度无法提高。
在实践中,基于哈希的实现是相当快的,在我的机器上,复制需要大量的时间。计算部分可以使用多线程来加速。这并不容易,因为Numba的并行性模型非常有限。Numba不容易对拷贝进行优化(可以使用非时间存储,但Numba还不支持),除非计算可能是在原地进行。
为了使用多线程,一种策略是先将范围分成几块,然后:
- 建立一个布尔数组,为
arr1
的每一个项目确定该项目是否在arr2
中找到(完全并行)。 - 计数器是按分块找到的项目的数量(完全并行)。
- 计算目标块的偏移量(很难并行化,特别是用Numba,但由于有了块,所以速度很快)。
- 将大块数据复制到目标位置,而不复制发现的项目(完全并行)。
这里有一个高效的基于哈希的并行实现。
@nb.njit('int32[:](int32[:], int32[:])', parallel=True)
def setdiff1d_nb_faster_par(arr1, arr2):
# Pre-computation of the bloom-filter
bloomFilter = np.zeros(4096, dtype=np.uint8)
for j in range(arr2.size):
bloomFilter[hash_32bit_4k(arr2[j])] = True
chunkSize = 1024 # To tune regarding the kind of input
chunkCount = (arr1.size + chunkSize - 1) // chunkSize
# Find for each item of `arr1` if the value is in `arr2` (parallel)
# and count the number of item found for each chunk on the fly.
# Note: thanks to page fault, big parts of `found` are not even written in memory if `arr2` is small
found = np.zeros(arr1.size, dtype=nb.bool_)
foundCountByChunk = np.empty(chunkCount, dtype=nb.uint16)
for i in nb.prange(chunkCount):
start, end = i * chunkSize, min((i + 1) * chunkSize, arr1.size)
foundCountInChunk = 0
for j in range(start, end):
val = arr1[j]
if bloomFilter[hash_32bit_4k(val)] and val in arr2:
found[j] = True
foundCountInChunk += 1
foundCountByChunk[i] = foundCountInChunk
# Compute the location of the destination chunks (sequential)
outChunkOffsets = np.empty(chunkCount, dtype=nb.uint32)
foundCount = 0
for i in range(chunkCount):
outChunkOffsets[i] = i * chunkSize - foundCount
foundCount += foundCountByChunk[i]
# Parallel chunk-based copy
out = np.empty(arr1.size-foundCount, dtype=arr1.dtype)
for i in nb.prange(chunkCount):
srcStart, srcEnd = i * chunkSize, min((i + 1) * chunkSize, arr1.size)
cur = outChunkOffsets[i]
# Optimization: we can copy the whole chunk if there is nothing found in it
if foundCountByChunk[i] == 0:
out[cur:cur+(srcEnd-srcStart)] = arr1[srcStart:srcEnd]
else:
for j in range(srcStart, srcEnd):
if not found[j]:
out[cur] = arr1[j]
cur += 1
return out
这个实现对于我的机器上的目标输入是最快的。当n
相当大,并且在目标平台上创建线程的开销相对较小(例如在PC上,但通常不是具有许多内核的计算服务器)时,它通常很快。并行实现的开销很大,所以目标机器上的核心数至少要有4个,这样实现的速度才能明显快于顺序实现。
对目标输入的chunkSize
变量进行调整可能是有用的。如果r << n
,最好使用一个相当大的chunkSize。也就是说,chunk的数量需要足够大,以便多个线程对许多chunk进行操作。因此,chunkSize
应该明显小于n / numberOfThreads
。
在我的机器上,大部分时间(65-70%)都花在最后的拷贝上,而最后的拷贝大部分都是占用内存的,很难用Numba来进一步优化。
结果是什么
以下是在我的基于i5-9600KF的机器上的结果(有6个核心)。
setdif1d_np: 2.65 ms
setdif1d_in1d_np: 2.61 ms
setdiff1d_nb: 2.33 ms
setdiff1d_nb_simd: 1.85 ms
setdiff1d_nb_faster: 0.73 ms
setdiff1d_nb_faster_par: 0.49 ms
所提供的最佳实施方案比其他方案快了4~5倍。
hash_32bit_4k
函数中的掩码相匹配。2.它需要是2的幂(所以掩码可以等同于一个快速模数)。3.它需要适合于L1高速缓存,也可能适合于少数内存页。4. 它需要足够大,以避免与arr2
中的项目数量发生冲突。例如,当我在我的机器上测试时,8192也很好用。代码中考虑到了碰撞,所以结果应该总是正确的(至少对于这种输入数据类型)。
- Jérôme Richard 2022-10-16
np.setdiff1d
(Numpy的实现,而不是Numba)中实现,尽管这可能是一个重要的工作。
- Jérôme Richard 2022-10-16
setdiff1d_nb_faster
不需要任何改变,只是签名),也不需要删除arr2
的值(对性能的影响可忽略不计)?似乎你考虑了两种不同的签名方式;如果我们知道它们的类型,例如C
--> 只是[::1]
,是否会比只使用一种签名方式有负面作用?just对于r=1
来说,解决方案比32位数组的numpy慢一点(非常接近);但是,对于64位(我的情况),它至少快1.5倍,即使对于小的r
值,例如1
。
- Ali_Sh 2022-10-16
::1
意味着轴是连续的,所以访问速度一般比较快,但是你不能向接受连续数组的Numba方法提供非连续数组(而相反的情况是可以的)。
- Jérôme Richard 2022-10-16
我发现散列并没有帮助。它只是二维情况下的技巧,将一维数组转换为单个数字,并将它们作为一个集合。
下面是我转换为一维数组的norok2方法(并添加了注释以加快编译速度)。 请注意,这只比你已有的方法稍快(20-30%)。当然,在第二次函数调用后,由于编译的原因,第一次会稍微慢一些。
@nb.njit('int32[:](int32[:], int32[:])')
def setdiff1d_nb(arr1, arr2):
delta = set(arr2)
# : build the result
result = np.empty(len(arr1), dtype=arr1.dtype)
j = 0
for i in range(arr1.shape[0]):
if arr1[i] not in delta:
result[j] = arr1[i]
j += 1
return result[:j]
result[:j]
不存在,导致了速度下降。
- dankal444 2022-10-17