昊虹君 发表于 2022-11-9 10:33

详解Python_Numpy库函数take_along_axis()【由索引矩阵生成新的矩阵】

详解Python_Numpy库函数take_along_axis()【由索引矩阵生成新的矩阵】

提问:由已有矩阵的索引生成新的矩阵为什么要用函数take_along_axis(),我用Numpy库ndarray对象的切片操作不行么?
答案是:Numpy库ndarray对象的切片操作不是万能的,比如下面的两种情况它就不能解决,而下面两种情况可以用函数take_along_axis()解决。

情况一:
我由argsort()函数得到了矩阵元素按从小到大排序的索引,接下来我想由个这个排序索引得到一个新的矩阵,这个新矩阵的元素就是按从小到大排列的。这种情况下光靠切片操作就很难实现这个功能了。不信的话诸君可以试一试,反正昊虹君是试了的,很麻烦。但是此时用函数take_along_axis()就很方便,示例如下:
import numpy as np

A = np.array([, ])
B = np.sort(A, axis=1)
index1 = np.argsort(A, axis=1)
C = np.take_along_axis(A, index1, axis=1)
运行结果如下:
http://pic1.hhai.cc/pic1/2022/2022-11/003/01.png
从这个示例可以看出,函数take_along_axis()很方便的帮我们通过索引值矩阵index1按序取出了A中的元素形成了数组C。

情况二:
现有三维矩阵A如下:
A = np.arange(2*3*4).reshape()
http://pic1.hhai.cc/pic1/2022/2022-11/003/02.png

http://pic1.hhai.cc/pic1/2022/2022-11/003/03.png

现在要实现下面这个目标:
选取A的第0页的第1行和A的第1页的第2行构成一个新的三维矩阵B,B矩阵的形状为(2, 1, 4)。
这个目标用切片操作是无法实现的,昊虹君也尝试过直接用切片实现这个目标,但无奈没有成功。
不过这个目标用函数take_along_axis()就很容易实现了,实现的代码如下:
# -*- coding: utf-8 -*-
# 出处:昊虹AI笔记网(hhai.cc)
# 用心记录计算机视觉和AI技术

# 博主微信/QQ 2487872782
# QQ群 271891601
# 欢迎技术交流与咨询

# OpenCV的版本为4.4.0

import numpy as np
A = np.arange(2*3*4).reshape()

index1 = np.zeros().astype('int')

index1 = 1
index1 = 2

B = np.take_along_axis(A, index1, axis=1)
运行结果如下:
http://pic1.hhai.cc/pic1/2022/2022-11/003/04.png
                                          
http://pic1.hhai.cc/pic1/2022/2022-11/003/05.png
                                          
http://pic1.hhai.cc/pic1/2022/2022-11/003/06.png

具体是怎么实现的,参考博文https://blog.csdn.net/baidu_37157624/article/details/123124561,
并仔细思考后得到其实现原理的精炼理解如下:
①显然,B矩阵的形状为为(2, 1, 4),又加上我们是以行为单位进行数据选取,即最小选取单位为一行,此时元素的列索引无意义,所以索引矩阵的形状为index1(2,1,1)。
②当axis=1时,有:
索引矩阵每个元素自身索引值的行索引值代表新矩阵B中的行索引值;
索引矩阵每个元素自身索引值的页索引值代表原矩阵A和新矩阵B的页索引值;
索引矩阵每个元素的值代表原矩阵A中的行索引值;
由于最小选取单位为一行,所以这里列索引值不用考虑。
基于以上认识,所以有:
index1 = 1 代表取A的第0页的第1行形成B的第0页第0行;
index1 = 2 代表取A的第1页的第2行形成B的第0页第0行。

补充说明:
①使用函数take_along_axis()时要注意,索引矩阵index1的维度数应该和原矩阵A的维度数相同。
②二维以下时实现上面的功能是完全可以用ndarray的切片或方法take()实现的。
关于ndarray的切片操作的详细介绍见博文 https://www.hhai.cc/thread-117-1-1.html
关于方法take()的详细介绍见博文 https://www.hhai.cc/thread-121-1-1.html
页: [1]
查看完整版本: 详解Python_Numpy库函数take_along_axis()【由索引矩阵生成新的矩阵】