最有效的方法不是使用 np.setdiff1d 和 np.in1d,而是删除具有唯一值的一维数组的公共值

回答 2 浏览 6922 2022-10-14

我需要一个更快的代码来删除一个一维数组(数组长度~10-15)中与另一个一维数组(数组长度~1e5-5e5--> rarely up to 7e5)共有的值,这些数组的索引数组包含整数。数组中没有重复的,而且它们没有被排序,修改后的数值顺序必须在主数组中保持。我知道可以用这样的np.setdiff1dnp.in1d来实现(这两个都不支持无python模式下的numba jitted),其他类似的帖子(比如这个)也没有什么更有效的方法,但这里的性能很重要,因为主索引数组中的所有值都会在循环中逐渐被删除。

import numpy as np
import numba as nb

n = 500000
r = 10
arr1 = np.random.permutation(n)
arr2 = np.random.randint(0, n, r)

# @nb.jit
def setdif1d_np(a, b):
    return np.setdiff1d(a, b, assume_unique=True)


# @nb.jit
def setdif1d_in1d_np(a, b):
    return a[~np.in1d(a, b)]

还有一个相关的帖子,由norok2提出的二维数组的解决方案(使用numba的类似散列的方式)比那里描述的通常方法快15倍。如果能够为一维数组准备,这个解决方案可能是最好的。

@nb.njit
def mul_xor_hash(arr, init=65537, k=37):
    result = init
    for x in arr.view(np.uint64):
        result = (result * k) ^ x
    return result


@nb.njit
def setdiff2d_nb(arr1, arr2):
    # : build `delta` set using hashes
    delta = {mul_xor_hash(arr2[0])}
    for i in range(1, arr2.shape[0]):
        delta.add(mul_xor_hash(arr2[i]))
    # : compute the size of the result
    n = 0
    for i in range(arr1.shape[0]):
        if mul_xor_hash(arr1[i]) not in delta:
            n += 1
    # : build the result
    result = np.empty((n, arr1.shape[-1]), dtype=arr1.dtype)
    j = 0
    for i in range(arr1.shape[0]):
        if mul_xor_hash(arr1[i]) not in delta:
            result[j] = arr1[i]
            j += 1
    return result

我试图为一维数组做准备,但我有一些问题/疑惑。

  • 一开始,IDU的mul_xor_hash到底是做什么的,以及initk是否是任意选择的呢?
  • 为什么没有nb.njitmul_xor_hash就不能工作。
  File "C:/Users/Ali/Desktop/test - Copy - Copy.py", line 21, in mul_xor_hash
    result = (result * k) ^ x
TypeError: ufunc 'bitwise_xor' not supported for the input types, and the inputs could not be safely coerced to any supported types according to the casting rule ''safe''
  • 我不知道如何在一维数组上实现mul_xor_hash(如果可以的话),我想这可能会使它比二维的速度更快,所以我通过[None, :]将输入数组广播到二维,这时仅对arr2得到了以下错误。
    print(mul_xor_hash(arr2[0]))
ValueError: new type not compatible with array
  • delta是做什么的?

我正在寻找这方面的最有效方法。在没有比norok2解决方案更好的方法的情况下,如何为一维数组准备这种解决方案?

Ali_Sh 提问于2022-10-14
setdiff1dnumpy代码只是做了你的第2个版本所做的事--使用in1d。你自己看一下吧。在一个数组小得多的情况下,in1darr2的每个元素做了一个mask |= (arr1 == a)。对于一个大的尺寸差异,这被认为比排序更快。很容易读懂in1d的代码。hpaulj 2022-10-14
2 个回答
#1楼 已采纳
得票数 7

了解基于哈希值的解决方案

首先,IDU的mul_xor_hash到底是做什么的,init和k是否是任意选择的?

mul_xor_hash是一个自定义的哈希函数。众所周知,混合了xor和multiply(可能还有shifts)的函数在计算原始数据缓冲区的哈希值时相对较快。乘法倾向于洗位,而xor则用于以某种方式将结果组合/累积到一个固定大小的小数值中(即最终的哈希值)。有许多不同的散列函数。有些比其他的快,有些在给定的情况下比其他的引起更多的碰撞。一个快速的散列函数导致太多的碰撞,在实践中可能是无用的,因为它将导致一种病态的情况,所有冲突的值都需要被比较。这就是为什么快速散列函数难以实现的原因。

initk是参数肯定会导致哈希值相当平衡。这在这样的哈希函数中是很常见的。k需要足够大的乘法来洗刷比特,它通常也应该是一个素数(由于模块化的算术行为,像二的幂这样的值往往会增加碰撞)。init只对非常小的数组(例如有1个项目)起重要作用:它有助于通过对最终散列的非三常数进行异或来减少碰撞。事实上,如果arr.size = 1,那么result = (init * k) ^ arr[0],其中init * k是一个常数。众所周知,拥有一个等于arr[0]的身份哈希函数是不好的,因为它往往会导致许多碰撞(这是一个复杂的话题,但简单地说,arr[0]可以除以哈希表中的桶的数量)。因此,init应该是一个相对大的数字,init * k也应该是一个大的非琐碎的值(一个质数是一个好的目标值)。

为什么mul_xor_hash在没有nb.njit的情况下不能工作

这取决于输入的情况。输入需要是一个一维数组,其原始大小为能被8除以的字节(例如,64位的项目,2n x 32位的,4n x 16位的或8n 8位的)。下面是一些例子。

mul_xor_hash(np.random.rand(10))
mul_xor_hash(np.arange(10)) # Do not work with 9

delta是做什么的?

它是一个set,包含arr2行的哈希值,所以要比没有哈希值的比较更快地找到匹配的行。

如何为一维数组准备这种解决方案?

AFAIK,哈希值只用于避免行的比较,但这是因为输入的是二维数组。在一维中,不存在这样的问题。

这个方法有一个很大的缺陷:它只在没有哈希碰撞的情况下才会起作用。否则,实现就会错误地认为数值是相等的,即使它们不是相等的。虽然@norok在评论中明确提到了这一点。

请注意,哈希值的碰撞处理也应该被实现。


更快的执行速度

将@norok2的二维解决方案用于一维并不是一个好主意,因为哈希值的使用方式不会使其更快。事实上,一个set已经在内部使用了哈希函数。更不用说碰撞需要适当的实现(这是由set完成的)。

使用set是一个相对较好的主意,因为它使n = len(arr1)m = len(arr2)的复杂度达到了O(n + m)。也就是说,如果arr1被转换为set,那么它将会太大,无法装入L1缓存(由于你的例子中arr1的大小),导致缓慢的缓存错过。此外,set的大小不断增加将导致数值被重新洗牌,这并不高效。如果arr2被转换为set,那么许多哈希表的获取将不是很有效,因为arr2在你的情况下非常小。这就是为什么这个解决方案是次优的。

一个解决方案是arr1分块,然后根据目标分块建立一个set。然后你可以有效地检查一个值是否在这个集合中。由于体积越来越大,构建集合的效率仍然不高。这个问题是由Python本身造成的,它没有像其他语言(例如C++)那样为数据结构保留一些空间。避免这个问题的一个解决方案是简单地重新实现一个哈希表,这并不简单,而且很麻烦。实际上,Bloom过滤器可以用来加速这个过程,因为它们可以快速找到两个集合arr1arr2之间是否平均没有碰撞(尽管它们的实现并不简单)。

另一个优化是使用多个线程来并行计算这些块,因为它们是独立的。也就是说,向最终数组的追加并不容易有效地进行,特别是由于你不希望顺序被修改。一个解决方案是将拷贝从并行循环中移开,并以串行方式进行,但这很慢,而且AFAIK目前在Numba中没有简单的方法可以做到这一点(因为并行层非常有限)。考虑使用本地语言如C/C++来实现高效的并行。

最后,散列可能是相当复杂的,与带有两个嵌套循环的天真实现相比,速度可能相当小,因为arr2只有少数项目,而且现代处理器可以使用SIMD指令快速比较值(而基于散列的方法在主流处理器上几乎无法受益)。解卷可以帮助编写一个相当简单和快速的实现。同样,不幸的是,Numba内部使用的是LLVM-Jit,它似乎无法将如此简单的代码矢量化(当然是由于LLVM-Jit甚至LLVM本身的优化缺失)。因此,非矢量化的代码最终会慢一点(而不是在现代主流处理器上快4~10倍)。一个解决方案是用C/C++代码代替来做这个事情(也可能是Cython)。

这里是一个使用基本的Bloom过滤器的串行实现。

@nb.njit('uint32(int32)')
def hash_32bit_4k(value):
    return (np.uint32(value) * np.uint32(27_644_437)) & np.uint32(0x0FFF)

@nb.njit(['int32[:](int32[:], int32[:])', 'int32[:](int32[::1], int32[::1])'])
def setdiff1d_nb_faster(arr1, arr2):
    out = np.empty_like(arr1)
    bloomFilter = np.zeros(4096, dtype=np.uint8)
    for j in range(arr2.size):
        bloomFilter[hash_32bit_4k(arr2[j])] = True
    cur = 0
    for i in range(arr1.size):
        # If the bloom-filter value is true, we know arr1[i] is not in arr2.
        # Otherwise, there is maybe a false positive (conflict) and we need to check to be sure.
        if bloomFilter[hash_32bit_4k(arr1[i])] and arr1[i] in arr2:
            continue
        out[cur] = arr1[i]
        cur += 1
    return out[:cur]

这里有一个未经测试的变体,它应该适用于64位整数(浮点数需要内存视图,可能也需要一个质数常数)。

@nb.njit('uint64(int64)')
def hash_32bit_4k(value):
    return (np.uint64(value) * np.uint64(67_280_421_310_721)) & np.uint64(0x0FFF)

请注意,如果小数组中的所有值都包含在每个循环的主数组中,那么我们可以通过在发现值时从arr2中删除来加快arr1[i] in arr2部分的速度。也就是说,碰撞和发现应该是非常罕见的,所以我并不指望这能明显加快(更不用说它增加了一些开销和复杂性)。如果项目是分块计算的,那么最后的分块可以直接复制,不需要任何检查,但好处应该还是比较小的。请注意,这种策略对之前提到的天真的(C/C++)SIMD实现倒是很有效(可以快2倍左右)。


通用化和并行实现

本节重点讨论了关于输入大小的算法。它特别详细介绍了基于SIMD的实现,并讨论了多线程的使用。

首先,关于数值r,使用的最佳算法可能是不同的。更具体地说。

  • r为0时,最好的做法是返回未经修改的输入数组arr1(可能是一个副本,以避免与原位算法有关的问题)。
  • r为1时,我们可以使用一个基本的循环对数组进行迭代,但最好的实现可能是使用Numpy的np.where,它对此进行了高度的优化
  • r很小的时候,比如<10,那么使用基于SIMD的实现应该是特别有效的,特别是当基于arr2的循环的迭代范围在编译时是已知的,并且是解卷的时候。
  • 对于较大的r值仍然相对较小(例如r < 1000r << n),所提供的基于哈希值的解决方案应该是最好的解决方案之一。
  • 对于较大的r值和r << n,基于散列的解决方案可以通过将布尔值打包成bloomFilter中的比特,以及使用多个散列函数而不是一个,以便更好地处理碰撞,同时更有利于缓存(事实上,这就是实际的Bloom过滤器的作用);注意,当r很大和r << n时,可以使用多线程来加快查找速度。
  • r很大并且不比n小多少时,那么这个问题就很难有效解决,最好的解决办法当然是对两个数组进行排序(通常是用radix排序),并使用基于合并的方法来删除重复的内容,当rn都很大时,可能会使用多个线程(很难实现)。

我们先来看看基于SIMD的解决方案。下面是一个实现。

@nb.njit('int32[:](int32[::1], int32[::1])')
def setdiff1d_nb_simd(arr1, arr2):
    out = np.empty_like(arr1)
    limit = arr1.size // 4 * 4
    limit2 = arr2.size // 2 * 2
    cur = 0
    z32 = np.int32(0)

    # Tile (x4) based computation
    for i in range(0, limit, 4):
        f0, f1, f2, f3 = z32, z32, z32, z32
        v0, v1, v2, v3 = arr1[i], arr1[i+1], arr1[i+2], arr1[i+3]
        # Unrolled (x2) loop searching for a match in `arr2`
        for j in range(0, limit2, 2):
            val1 = arr2[j]
            val2 = arr2[j+1]
            f0 += (v0 == val1) + (v0 == val2)
            f1 += (v1 == val1) + (v1 == val2)
            f2 += (v2 == val1) + (v2 == val2)
            f3 += (v3 == val1) + (v3 == val2)
        # Remainder of the previous loop
        if limit2 != arr2.size:
            val = arr2[arr2.size-1]
            f0 += v0 == val
            f1 += v1 == val
            f2 += v2 == val
            f3 += v3 == val
        if f0 == 0: out[cur] = arr1[i+0]; cur += 1
        if f1 == 0: out[cur] = arr1[i+1]; cur += 1
        if f2 == 0: out[cur] = arr1[i+2]; cur += 1
        if f3 == 0: out[cur] = arr1[i+3]; cur += 1

    # Remainder
    for i in range(limit, arr1.size):
        if arr1[i] not in arr2:
            out[cur] = arr1[i]
            cur += 1

    return out[:cur]

事实证明,在我的机器上,这个实现总是比基于哈希的实现慢,因为Numba显然为基于arr2的内部循环产生了低效率,这似乎来自与==有关的破碎优化。Numba在这个操作中根本没有使用SIMD指令(没有明显的原因)。这使得许多与SIMD相关的替代代码只要使用Numba就会变得很快。

Numba的另一个问题是np.where很慢,因为它使用的是天真的实现,而Numpy的实现已经被大量优化。由于前面的问题,Numpy中的优化很难被应用到Numba的实现中。这使得在Numba代码中使用np.where的速度无法提高。

在实践中,基于哈希的实现是相当快的,在我的机器上,复制需要大量的时间。计算部分可以使用多线程来加速。这并不容易,因为Numba的并行性模型非常有限。Numba不容易对拷贝进行优化(可以使用非时间存储,但Numba还不支持),除非计算可能是在原地进行。

为了使用多线程,一种策略是先将范围分成几块,然后:

  • 建立一个布尔数组,为arr1的每一个项目确定该项目是否在arr2中找到(完全并行)。
  • 计数器是按分块找到的项目的数量(完全并行)。
  • 计算目标块的偏移量(很难并行化,特别是用Numba,但由于有了块,所以速度很快)。
  • 将大块数据复制到目标位置,而不复制发现的项目(完全并行)。

这里有一个高效的基于哈希的并行实现。

@nb.njit('int32[:](int32[:], int32[:])', parallel=True)
def setdiff1d_nb_faster_par(arr1, arr2):
    # Pre-computation of the bloom-filter
    bloomFilter = np.zeros(4096, dtype=np.uint8)
    for j in range(arr2.size):
        bloomFilter[hash_32bit_4k(arr2[j])] = True

    chunkSize = 1024 # To tune regarding the kind of input
    chunkCount = (arr1.size + chunkSize - 1) // chunkSize

    # Find for each item of `arr1` if the value is in `arr2` (parallel)
    # and count the number of item found for each chunk on the fly.
    # Note: thanks to page fault, big parts of `found` are not even written in memory if `arr2` is small
    found = np.zeros(arr1.size, dtype=nb.bool_)
    foundCountByChunk = np.empty(chunkCount, dtype=nb.uint16)
    for i in nb.prange(chunkCount):
        start, end = i * chunkSize, min((i + 1) * chunkSize, arr1.size)
        foundCountInChunk = 0
        for j in range(start, end):
            val = arr1[j]
            if bloomFilter[hash_32bit_4k(val)] and val in arr2:
                found[j] = True
                foundCountInChunk += 1
        foundCountByChunk[i] = foundCountInChunk

    # Compute the location of the destination chunks (sequential)
    outChunkOffsets = np.empty(chunkCount, dtype=nb.uint32)
    foundCount = 0
    for i in range(chunkCount):
        outChunkOffsets[i] = i * chunkSize - foundCount
        foundCount += foundCountByChunk[i]

    # Parallel chunk-based copy
    out = np.empty(arr1.size-foundCount, dtype=arr1.dtype)
    for i in nb.prange(chunkCount):
        srcStart, srcEnd = i * chunkSize, min((i + 1) * chunkSize, arr1.size)
        cur = outChunkOffsets[i]
        # Optimization: we can copy the whole chunk if there is nothing found in it 
        if foundCountByChunk[i] == 0:
            out[cur:cur+(srcEnd-srcStart)] = arr1[srcStart:srcEnd]
        else:
            for j in range(srcStart, srcEnd):
                if not found[j]:
                    out[cur] = arr1[j]
                    cur += 1

    return out

这个实现对于我的机器上的目标输入是最快的。当n相当大,并且在目标平台上创建线程的开销相对较小(例如在PC上,但通常不是具有许多内核的计算服务器)时,它通常很快。并行实现的开销很大,所以目标机器上的核心数至少要有4个,这样实现的速度才能明显快于顺序实现。

对目标输入的chunkSize变量进行调整可能是有用的。如果r << n,最好使用一个相当大的chunkSize。也就是说,chunk的数量需要足够大,以便多个线程对许多chunk进行操作。因此,chunkSize应该明显小于n / numberOfThreads

在我的机器上,大部分时间(65-70%)都花在最后的拷贝上,而最后的拷贝大部分都是占用内存的,很难用Numba来进一步优化。


结果是什么

以下是在我的基于i5-9600KF的机器上的结果(有6个核心)。

setdif1d_np:               2.65 ms
setdif1d_in1d_np:          2.61 ms
setdiff1d_nb:              2.33 ms
setdiff1d_nb_simd:         1.85 ms
setdiff1d_nb_faster:       0.73 ms
setdiff1d_nb_faster_par:   0.49 ms

所提供的最佳实施方案比其他方案快了4~5倍。

Jérôme Richard 提问于2022-10-16
Jérôme Richard 修改于2022-10-23
4096是相当谨慎的选择。1.它需要与hash_32bit_4k函数中的掩码相匹配。2.它需要是2的幂(所以掩码可以等同于一个快速模数)。3.它需要适合于L1高速缓存,也可能适合于少数内存页。4. 它需要足够大,以避免与arr2中的项目数量发生冲突。例如,当我在我的机器上测试时,8192也很好用。代码中考虑到了碰撞,所以结果应该总是正确的(至少对于这种输入数据类型)。Jérôme Richard 2022-10-16
我编辑了帖子,增加了关于64位整数和关于arr2包含在arr1中的信息。我同意这个概括。事实上,我在想这样的策略是否可以直接在np.setdiff1d(Numpy的实现,而不是Numba)中实现,尽管这可能是一个重要的工作。Jérôme Richard 2022-10-16
那么,对于64位来说,4096不需要改变(setdiff1d_nb_faster不需要任何改变,只是签名),也不需要删除arr2的值(对性能的影响可忽略不计)?似乎你考虑了两种不同的签名方式;如果我们知道它们的类型,例如C --> 只是[::1],是否会比只使用一种签名方式有负面作用?just对于r=1来说,解决方案比32位数组的numpy慢一点(非常接近);但是,对于64位(我的情况),它至少快1.5倍,即使对于小的r值,例如1Ali_Sh 2022-10-16
4096可以不动,是的。当然,对于64位的方法,签名需要适应输入类型。::1意味着轴是连续的,所以访问速度一般比较快,但是你不能向接受连续数组的Numba方法提供非连续数组(而相反的情况是可以的)。Jérôme Richard 2022-10-16
Colab上的结果通常不是很稳定,也没有很好的可重复性(因为AFAIK机器是与其他用户共享的)。此外,结果还取决于平台。目标处理器可能会在这样的平台上使核心RAM带宽达到饱和。如果是这样,没有代码会更快。如果不进行分析,就很难判断。请注意,你可能需要调整参数。Jérôme Richard 2022-10-16
#2楼
得票数 4

我发现散列并没有帮助。它只是二维情况下的技巧,将一维数组转换为单个数字,并将它们作为一个集合。

下面是我转换为一维数组的norok2方法(并添加了注释以加快编译速度)。 请注意,这只比你已有的方法稍快(20-30%)。当然,在第二次函数调用后,由于编译的原因,第一次会稍微慢一些。

@nb.njit('int32[:](int32[:], int32[:])')
def setdiff1d_nb(arr1, arr2):
    delta = set(arr2)

    # : build the result
    result = np.empty(len(arr1), dtype=arr1.dtype)
    j = 0
    for i in range(arr1.shape[0]):
        if arr1[i] not in delta:
            result[j] = arr1[i]
            j += 1
    return result[:j]
dankal444 提问于2022-10-14
dankal444 修改于2022-10-17
我检查了一些指定数组大小的随机值,你的第一个解决方案比这个快? 是吗?我使用了64位整数布局C。Ali_Sh 2022-10-17
@Ali_Sh 我还没有测试编辑后的速度。做了norok2建议的两个修改(他删除了他的评论)。嗯,也许最后一行result[:j]不存在,导致了速度下降。dankal444 2022-10-17
澄清一下:Norok2的建议是正确的,必须考虑到它们。我的原始版本没有考虑到arr2中的数字可能重复的事实(在你的情况下很少,但仍然......)。dankal444 2022-10-17