NumPy Argsort:排序不再返回元素,而是索引

Okay, here is a detailed article about NumPy's argsort function, aiming for approximately 3000 Chinese characters.


NumPy Argsort 深度解析:排序不再返回元素,而是索引的奥秘

在数据处理和科学计算的广阔领域中,排序是一项基础且至关重要的操作。无论是整理数据以便于观察,还是作为更复杂算法(如查找、合并)的预处理步骤,排序都扮演着不可或缺的角色。Python 的 NumPy 库,作为科学计算生态的核心支柱,提供了强大而高效的数组操作功能,其中自然也包括了多种排序工具。

我们通常熟悉的排序操作,比如 Python 内置的 sorted() 函数或列表的 .sort() 方法,以及 NumPy 的 numpy.sort() 函数,它们的目标都是直接返回一个包含排序后元素的新列表或数组。例如,对数组 [3, 1, 4, 2] 进行排序,我们会得到 [1, 2, 3, 4]。这种方式直观易懂,满足了许多基本的排序需求。

然而,在更复杂的场景下,我们可能不仅仅关心排序后的值本身,更关心这些值在原始数组中的位置。想象一下,你有一个记录学生姓名和对应分数的两个列表或数组,你想根据分数对学生进行排名,但最终需要的是按分数排序后的学生姓名列表。如果直接对分数数组排序,你会得到排序后的分数,但丢失了分数与姓名的对应关系。这时,一种能够返回元素排序后对应原始索引的排序方法就显得尤为重要。

这正是 NumPy 中 numpy.argsort() 函数的核心价值所在。argsort 不会返回排序后的元素数组,而是返回一个包含索引的数组,这些索引指向原始数组中的元素,并且按照这些索引对应的元素值排序。理解并熟练运用 argsort,能够解锁许多高级的数据操作技巧,极大地提升数据处理的灵活性和效率。

本文将深入探讨 numpy.argsort() 的功能、用法、参数、应用场景及其与 numpy.sort() 的区别,帮助你全面掌握这个强大的排序工具。

一、 什么是 numpy.argsort()

numpy.argsort() 的核心功能可以概括为:对输入数组进行排序,但不返回排序后的元素值,而是返回一个整数数组,该数组包含了排序后元素在原始输入数组中的索引。

换句话说,如果 indices = np.argsort(a),那么 a[indices] 将会得到一个与 np.sort(a) 相同(或在处理 NaN 等特殊值时行为一致)的排序后数组。indices 数组的第一个元素是原始数组 a 中最小元素的索引,第二个元素是第二小元素的索引,依此类推,最后一个元素是最大元素的索引。

基本示例:一维数组

让我们通过一个简单的例子来直观理解:

```python
import numpy as np

创建一个简单的一维 NumPy 数组

arr = np.array([80, 20, 50, 10, 90, 30])

使用 argsort 获取排序后的索引

indices = np.argsort(arr)

print("原始数组 (arr):", arr)

输出: 原始数组 (arr): [80 20 50 10 90 30]

print("argsort 返回的索引数组 (indices):", indices)

输出: argsort 返回的索引数组 (indices): [3 1 5 2 0 4]

使用索引数组访问原始数组,得到排序后的结果

sorted_arr_via_indices = arr[indices]
print("通过索引访问得到的排序后数组:", sorted_arr_via_indices)

输出: 通过索引访问得到的排序后数组: [10 20 30 50 80 90]

对比直接使用 np.sort() 的结果

sorted_arr_direct = np.sort(arr)
print("直接使用 np.sort() 得到的排序后数组:", sorted_arr_direct)

输出: 直接使用 np.sort() 得到的排序后数组: [10 20 30 50 80 90]

```

在这个例子中:
1. arr 是我们的原始数据 [80, 20, 50, 10, 90, 30]
2. np.argsort(arr) 返回了 [3, 1, 5, 2, 0, 4]
3. 这个索引数组告诉我们:
* 原始数组中最小的元素是 10,它位于索引 3
* 第二小的元素是 20,它位于索引 1
* 第三小的元素是 30,它位于索引 5
* 第四小的元素是 50,它位于索引 2
* 第五小的元素是 80,它位于索引 0
* 最大的元素是 90,它位于索引 4
4. 当我们使用这个 indices 数组来索引 arr 时 (arr[indices]),我们实际上是按照 [arr[3], arr[1], arr[5], arr[2], arr[0], arr[4]] 的顺序取值,这正好得到了排序后的数组 [10, 20, 30, 50, 80, 90],与 np.sort(arr) 的结果一致。

二、 为什么需要返回索引?argsort 的核心优势

既然 np.sort() 可以直接给出排序结果,为什么我们还需要 argsort 这种返回索引的方式呢?argsort 的优势在于它保留了原始位置信息,这在以下场景中至关重要:

  1. 同步排序多个相关数组(协同排序):
    这是 argsort 最经典的应用之一。当你有多个数组,它们之间存在一一对应的关系(例如,数组 A 存姓名,数组 B 存年龄,数组 C 存分数),而你想根据其中一个数组(比如分数 C)来对所有数组进行排序时,argsort 就是完美的解决方案。

    ```python
    import numpy as np

    names = np.array(['Alice', 'Bob', 'Charlie', 'David'])
    ages = np.array([25, 30, 22, 30])
    scores = np.array([88, 92, 85, 95])

    print("原始数据:")
    print("Names:", names)
    print("Ages:", ages)
    print("Scores:", scores)

    根据分数 (scores) 获取排序索引

    score_sort_indices = np.argsort(scores)
    print("\n按分数排序的索引:", score_sort_indices)

    输出: 按分数排序的索引: [2 0 1 3] (对应分数 85, 88, 92, 95)

    使用这些索引来同步排序所有数组

    sorted_names = names[score_sort_indices]
    sorted_ages = ages[score_sort_indices]
    sorted_scores = scores[score_sort_indices] # 等同于 np.sort(scores)

    print("\n根据分数排序后的数据:")
    print("Sorted Names:", sorted_names)

    输出: Sorted Names: ['Charlie' 'Alice' 'Bob' 'David']

    print("Sorted Ages:", sorted_ages)

    输出: Sorted Ages: [22 25 30 30]

    print("Sorted Scores:", sorted_scores)

    输出: Sorted Scores: [85 88 92 95]

    ``
    在这个例子中,我们只对
    scores应用了argsort,得到的索引[2, 0, 1, 3]被用于重新排列names,agesscores自身,从而保持了它们之间原有的对应关系,同时实现了按分数排序的整体效果。如果直接对每个数组单独使用np.sort()`,这种对应关系就会丢失。

  2. 获取 Top-K 或 Bottom-K 元素的索引(及对应值):
    有时我们不需要对整个数组排序,只关心最大或最小的 K 个元素及其原始位置。argsort 可以轻松实现这一点。

    ```python
    import numpy as np

    data = np.array([10, 50, 20, 80, 30, 90, 40, 70, 60])

    获取排序索引

    indices = np.argsort(data)

    获取最小的 3 个元素的索引

    bottom_3_indices = indices[:3]
    print("最小 3 个元素的索引:", bottom_3_indices)

    输出: 最小 3 个元素的索引: [0 2 4]

    获取最小的 3 个元素的值

    bottom_3_values = data[bottom_3_indices]
    print("最小 3 个元素的值:", bottom_3_values)

    输出: 最小 3 个元素的值: [10 20 30]

    获取最大的 3 个元素的索引 (注意索引从后往前取)

    top_3_indices = indices[-3:]

    或者可以这样写: top_3_indices = indices[len(data)-3:]

    或者更简洁,利用[::-1]反转索引再取前3:top_3_indices = indices[::-1][:3] (注意这得到的是降序索引)

    为了得到升序索引中的最大3个,用 indices[-3:] 是最直接的

    print("\n最大 3 个元素的索引:", top_3_indices)

    输出: 最大 3 个元素的索引: [7 3 5]

    获取最大的 3 个元素的值

    top_3_values = data[top_3_indices]
    print("最大 3 个元素的值:", top_3_values)

    输出: 最大 3 个元素的值: [70 80 90]

    如果想要按降序排列的最大3个元素的值及其索引

    descending_indices = np.argsort(data)[::-1] # 获取降序索引
    top_3_desc_indices = descending_indices[:3]
    top_3_desc_values = data[top_3_desc_indices]
    print("\n按降序排列的最大 3 个元素的索引:", top_3_desc_indices)

    输出: 按降序排列的最大 3 个元素的索引: [5 3 7]

    print("按降序排列的最大 3 个元素的值:", top_3_desc_values)

    输出: 按降序排列的最大 3 个元素的值: [90 80 70]

    ``
    通过对
    argsort` 返回的索引数组进行切片,我们可以高效地定位到所需范围的元素,而无需对整个数组进行排序后再次查找。

  3. 需要稳定排序(Stable Sort)的应用:
    当数组中存在重复值时,稳定排序保证了这些重复值在排序后的相对顺序与它们在原始数组中的相对顺序保持一致。argsort 通过 kind 参数支持稳定排序算法(如 'mergesort''stable')。这在协同排序等场景下尤其重要,可以确保关联数据的正确对应。例如,上面姓名、年龄、分数的例子中,Bob 和 David 年龄相同(30),如果按年龄排序,稳定排序能保证他们俩在排序后的相对位置取决于他们在原始数组中的先后顺序。

  4. 作为复杂算法的中间步骤:
    在一些更高级的算法中,比如计算排名(Rank)、实现基于邻近度的算法(如 K-近邻)等,直接操作索引往往比操作排序后的值更方便、更高效。argsort 提供的索引数组是这些算法的关键输入。

三、 numpy.argsort() 详解:参数与用法

numpy.argsort() 函数的完整签名如下:

python
numpy.argsort(a, axis=-1, kind=None, order=None)

让我们逐一解析它的参数:

  1. a (array_like):
    必需参数,表示需要进行排序的输入数组。它可以是 NumPy 数组,也可以是任何可以被 NumPy 转换成数组的对象(如列表、元组)。

  2. axis (int or None, optional):
    指定沿着哪个轴进行排序。

    • axis=-1 (默认值): 沿着最后一个轴进行排序。对于二维数组,这意味着按行排序(对每一行内部的元素进行排序)。
    • axis=0: 沿着第一个轴进行排序。对于二维数组,这意味着按列排序(对每一列内部的元素进行排序,保持列的结构)。
    • axis=None: 将数组扁平化(flatten)为一维数组后进行排序。返回的索引也是对应于扁平化后数组的索引。

    示例:多维数组排序

    ```python
    import numpy as np

    arr_2d = np.array([[30, 10, 20],
    [60, 50, 40]])

    print("原始二维数组:\n", arr_2d)

    默认按行排序 (axis=-1 或 axis=1)

    indices_axis1 = np.argsort(arr_2d, axis=1)
    print("\n按行排序的索引 (axis=1):\n", indices_axis1)

    输出:

    [[1 2 0] <- 10(idx 1), 20(idx 2), 30(idx 0) in row 0

    [2 1 0]] <- 40(idx 2), 50(idx 1), 60(idx 0) in row 1

    按列排序 (axis=0)

    indices_axis0 = np.argsort(arr_2d, axis=0)
    print("\n按列排序的索引 (axis=0):\n", indices_axis0)

    输出:

    [[0 0 0] <- 30(row 0) < 60(row 1), 10(row 0) < 50(row 1), 20(row 0) < 40(row 1)

    [1 1 1]]

    扁平化后排序 (axis=None)

    indices_axisNone = np.argsort(arr_2d, axis=None)
    print("\n扁平化后排序的索引 (axis=None):\n", indices_axisNone)

    原始扁平化数组: [30, 10, 20, 60, 50, 40]

    排序后: 10, 20, 30, 40, 50, 60

    对应索引: 1, 2, 0, 5, 4, 3

    输出: [1 2 0 5 4 3]

    验证按行排序结果

    sorted_arr_axis1 = np.take_along_axis(arr_2d, indices_axis1, axis=1)
    print("\n使用 axis=1 的索引重建的按行排序数组:\n", sorted_arr_axis1)

    输出:

    [[10 20 30]

    [40 50 60]]

    验证按列排序结果

    sorted_arr_axis0 = np.take_along_axis(arr_2d, indices_axis0, axis=0)
    print("\n使用 axis=0 的索引重建的按列排序数组:\n", sorted_arr_axis0)

    输出:

    [[30 10 20]

    [60 50 40]] (注意: 每一列内部是排序的,但行的整体顺序可能打乱)

    ``np.take_along_axis()是一个非常有用的函数,可以方便地使用argsort` 返回的索引数组来重建排序后的数组,尤其在处理多维数组时。

  3. kind ({'quicksort', 'mergesort', 'heapsort', 'stable'}, optional):
    指定使用的排序算法。不同的算法在性能、稳定性、空间复杂度上有所差异。

    • 'quicksort' (快速排序): 默认值(NumPy 1.22 及之后版本可能默认为 'stable' 或基于类型的选择,早期版本默认为 'quicksort')。平均时间复杂度 O(N log N),最坏 O(N^2)。速度通常最快,但不稳定
    • 'mergesort' (归并排序): 时间复杂度稳定在 O(N log N)。稳定排序。需要额外的存储空间(通常 O(N))。
    • 'heapsort' (堆排序): 时间复杂度 O(N log N)。不稳定。空间复杂度 O(1)(原地排序,但 argsort 仍然需要 O(N) 空间存储索引)。
    • 'stable' (稳定排序): NumPy 会自动选择一种稳定的排序算法,通常是归并排序的变种(如 Timsort)。保证稳定。推荐在需要保持相等元素相对顺序时使用。

    稳定性(Stability) 指的是,如果数组中有两个或多个相等的元素,经过稳定排序后,这些相等元素在输出数组中的相对顺序与它们在输入数组中的相对顺序保持一致。这在协同排序等场景下非常重要。

    ```python
    import numpy as np

    演示稳定性差异

    结构化数组,包含一个值和原始索引信息,方便观察

    dtype = [('value', int), ('original_index', int)]
    a = np.array([(3, 0), (1, 1), (2, 2), (3, 3), (1, 4)], dtype=dtype)

    print("原始结构化数组:\n", a)

    使用不稳定排序 (quicksort) 按 'value' 排序

    注意:NumPy 1.22+ 可能默认kind行为不同,显式指定 'quicksort'

    indices_qsort = np.argsort(a, order='value', kind='quicksort')
    print("\nQuicksort (不稳定) 排序后的索引:", indices_qsort)
    print("按 Quicksort 索引重排后的数组:\n", a[indices_qsort])

    可能的输出 (两个 3 和两个 1 的相对顺序可能改变):

    按 Quicksort 索引重排后的数组:

    [(1, 4) (1, 1) (2, 2) (3, 3) (3, 0)] <-- (1, 4) 排在了 (1, 1) 前面

    使用稳定排序 (mergesort 或 stable) 按 'value' 排序

    indices_stable = np.argsort(a, order='value', kind='stable')
    print("\nStable Sort (稳定) 排序后的索引:", indices_stable)
    print("按 Stable Sort 索引重排后的数组:\n", a[indices_stable])

    稳定输出 (两个 3 和两个 1 的相对顺序保持不变):

    按 Stable Sort 索引重排后的数组:

    [(1, 1) (1, 4) (2, 2) (3, 0) (3, 3)] <-- (1, 1) 保持在 (1, 4) 前,(3, 0) 保持在 (3, 3) 前

    ```

  4. order (str or list of str, optional):
    仅当 a 是结构化数组(structured array)时使用。order 参数指定了排序时应比较的字段(列)名。可以提供一个字符串(单个字段)或一个字符串列表(按顺序比较多个字段)。

    示例:结构化数组排序

    ```python
    import numpy as np

    创建一个结构化数组

    dtype = [('name', 'S10'), ('height', float), ('age', int)]
    data = np.array([('Alice', 1.65, 25),
    ('Bob', 1.75, 30),
    ('Charlie', 1.75, 22)], dtype=dtype)

    print("原始结构化数组:\n", data)

    按 'age' 字段排序

    age_sort_indices = np.argsort(data, order='age')
    print("\n按 'age' 排序的索引:", age_sort_indices)
    print("按 'age' 排序后的数组:\n", data[age_sort_indices])

    输出:

    按 'age' 排序后的数组:

    [(b'Charlie', 1.75, 22) (b'Alice', 1.65, 25) (b'Bob', 1.75, 30)]

    先按 'height' 排序,如果 'height' 相同,再按 'age' 排序

    注意:Bob 和 Charlie 身高相同

    height_age_sort_indices = np.argsort(data, order=['height', 'age'])
    print("\n按 'height' 再按 'age' 排序的索引:", height_age_sort_indices)
    print("按 'height' 再按 'age' 排序后的数组:\n", data[height_age_sort_indices])

    输出:

    按 'height' 再按 'age' 排序后的数组:

    [(b'Alice', 1.65, 25) (b'Charlie', 1.75, 22) (b'Bob', 1.75, 30)]

    Alice 最矮排第一。Charlie 和 Bob 身高相同 (1.75),按年龄排序,Charlie (22) 在 Bob (30) 前面。

    ```

四、argsort vs sort: 何时选择?

总结一下 np.argsort()np.sort() 的关键区别和适用场景:

特性 np.argsort(a) np.sort(a) a.sort() (原地排序方法)
返回值 排序后的索引数组 排序后的的新数组 None (直接修改原数组 a)
原始数组 不改变 不改变 改变
核心用途 需要保留原始位置信息;协同排序;获取 Top-K 索引;稳定排序需求;算法中间步骤 只需要排序后的值;简单排序展示 只需要排序后的值;希望节省内存(原地修改)
空间 需要额外空间存储索引数组 (O(N)) 需要额外空间存储排序后的新数组 (O(N)) 通常空间效率更高(取决于算法)
参数 axis, kind, order axis, kind, order kind, order (没有 axis)

选择指南:

  • 当你需要根据一个数组的顺序来排列其他相关联的数据时,使用 argsort
  • 当你需要知道排序后的元素在原始数组中的位置(索引)时,使用 argsort
  • 当你需要获取最大/最小的 K 个元素的索引时,使用 argsort
  • 当稳定性很重要,并且你需要确保相等元素的原始顺序得以保留时,使用 argsort 并指定 kind='stable''mergesort'np.sorta.sort() 也支持 kind 参数)。
  • 当你仅仅需要得到一个排序好的数组副本,并且不关心原始索引时,使用 np.sort()
  • 当你希望直接在原数组上进行排序以节省内存,并且不需要保留原始数组时,使用数组自身的 .sort() 方法。

五、 性能考量

argsortsort 的性能主要取决于选择的排序算法 (kind 参数) 和输入数据的大小及特性。
* 大多数情况下,快速排序 ('quicksort') 提供了最好的平均性能(O(N log N)),但其最坏情况性能为 O(N^2),且不稳定。
* 归并排序 ('mergesort', 'stable') 保证 O(N log N) 的时间复杂度和稳定性,但通常比快速排序慢一点,且需要额外的 O(N) 空间。
* 堆排序 ('heapsort') 保证 O(N log N) 的时间复杂度,空间复杂度优秀(原地排序特性),但不稳定。

argsort 本身需要创建一个与输入数组大小相同的整数索引数组,这会带来额外的内存开销(O(N))。对于非常大的数组,这可能是一个需要考虑的因素。然而,其提供的灵活性往往超过了这点内存开销带来的影响。

六、 进阶技巧与应用

  1. 获取排名 (Rank): 可以通过两次 argsort 来计算数组元素的排名(最小元素的排名为 0,次小的为 1,依此类推)。
    python
    import numpy as np
    a = np.array([30, 10, 40, 20])
    temp = np.argsort(a) # 得到排序索引 [1, 3, 0, 2] (对应 10, 20, 30, 40)
    ranks = np.empty_like(temp)
    ranks[temp] = np.arange(len(a)) # 在 temp 指定的位置填入 0, 1, 2, ...
    # temp[0]=1, 所以 ranks[1]=0
    # temp[1]=3, 所以 ranks[3]=1
    # temp[2]=0, 所以 ranks[0]=2
    # temp[3]=2, 所以 ranks[2]=3
    print("Array:", a)
    print("Ranks:", ranks) # 输出: Ranks: [2 0 3 1]
    # 30 的排名是 2,10 的排名是 0,40 的排名是 3,20 的排名是 1

  2. 结合布尔索引: argsort 返回的索引可以用于更复杂的基于排序的选择。

  3. 配合 np.take_along_axis: 如前文多维数组示例所示,take_along_axis 是使用 argsort 索引来构造排序后数组的规范方式,尤其对于多维情况。

七、 总结

numpy.argsort() 是 NumPy 库中一个极其有用的函数,它颠覆了传统排序函数直接返回值的方式,巧妙地通过返回索引来保留原始数据的位置信息。这一特性使得 argsort 在处理需要保持数据关联性的场景(如协同排序)、提取特定范围(Top-K/Bottom-K)元素、实现稳定排序以及作为更复杂算法构建块时,展现出无与伦比的价值。

通过理解 argsort 的核心机制,掌握其 axis, kind, order 等关键参数的用法,并了解它与 np.sort().sort() 方法的区别,数据科学家和工程师能够更加灵活、高效地驾驭 NumPy 进行数据整理、分析和预处理工作。虽然初看起来可能不如 np.sort() 直观,但一旦掌握,argsort 将成为你数据处理工具箱中一把不可或缺的瑞士军刀,助你解决更多样、更复杂的排序相关问题。记住,当你需要的不仅仅是排序后的值,而是排序后的“秩序”本身时,argsort 就是你的答案。


THE END