昊虹AI笔记网

 找回密码
 立即注册
搜索
查看: 976|回复: 0
收起左侧

详解Python_Numpy库函数take_along_axis()【由索引矩阵生成新的矩阵】

[复制链接]

239

主题

241

帖子

928

积分

管理员

Rank: 9Rank: 9Rank: 9

积分
928
昊虹君 发表于 2022-11-9 10:33 | 显示全部楼层 |阅读模式
详解Python_Numpy库函数take_along_axis()【由索引矩阵生成新的矩阵】

提问:由已有矩阵的索引生成新的矩阵为什么要用函数take_along_axis(),我用Numpy库ndarray对象的切片操作不行么?
答案是:Numpy库ndarray对象的切片操作不是万能的,比如下面的两种情况它就不能解决,而下面两种情况可以用函数take_along_axis()解决。

情况一:
我由argsort()函数得到了矩阵元素按从小到大排序的索引,接下来我想由个这个排序索引得到一个新的矩阵,这个新矩阵的元素就是按从小到大排列的。这种情况下光靠切片操作就很难实现这个功能了。不信的话诸君可以试一试,反正昊虹君是试了的,很麻烦。但是此时用函数take_along_axis()就很方便,示例如下:
  1. import numpy as np

  2. A = np.array([[10, 30, 20], [60, 40, 50]])
  3. B = np.sort(A, axis=1)
  4. index1 = np.argsort(A, axis=1)
  5. C = np.take_along_axis(A, index1, axis=1)
复制代码

运行结果如下:

从这个示例可以看出,函数take_along_axis()很方便的帮我们通过索引值矩阵index1按序取出了A中的元素形成了数组C。

情况二:
现有三维矩阵A如下:
  1. A = np.arange(2*3*4).reshape([2, 3, 4])
复制代码





现在要实现下面这个目标:
选取A的第0页的第1行和A的第1页的第2行构成一个新的三维矩阵B,B矩阵的形状为(2, 1, 4)。
这个目标用切片操作是无法实现的,昊虹君也尝试过直接用切片实现这个目标,但无奈没有成功。
不过这个目标用函数take_along_axis()就很容易实现了,实现的代码如下:
  1. # -*- coding: utf-8 -*-
  2. # 出处:昊虹AI笔记网(hhai.cc)
  3. # 用心记录计算机视觉和AI技术

  4. # 博主微信/QQ 2487872782
  5. # QQ群 271891601
  6. # 欢迎技术交流与咨询

  7. # OpenCV的版本为4.4.0

  8. import numpy as np
  9. A = np.arange(2*3*4).reshape([2, 3, 4])

  10. index1 = np.zeros([2, 1, 1]).astype('int')

  11. index1[0, 0, :] = 1
  12. index1[1, 0, :] = 2

  13. B = np.take_along_axis(A, index1, axis=1)
复制代码

运行结果如下:

                                          

                                          


具体是怎么实现的,参考博文https://blog.csdn.net/baidu_37157624/article/details/123124561
并仔细思考后得到其实现原理的精炼理解如下:
①显然,B矩阵的形状为为(2, 1, 4),又加上我们是以行为单位进行数据选取,即最小选取单位为一行,此时元素的列索引无意义,所以索引矩阵的形状为index1(2,1,1)。
②当axis=1时,有:
索引矩阵每个元素自身索引值的行索引值代表新矩阵B中的行索引值;
索引矩阵每个元素自身索引值的页索引值代表原矩阵A和新矩阵B的页索引值;
索引矩阵每个元素的值代表原矩阵A中的行索引值;
由于最小选取单位为一行,所以这里列索引值不用考虑。
基于以上认识,所以有:
index1[0, 0, :] = 1 代表取A的第0页的第1行形成B的第0页第0行;
index1[1, 0, :] = 2 代表取A的第1页的第2行形成B的第0页第0行。

补充说明:
①使用函数take_along_axis()时要注意,索引矩阵index1的维度数应该和原矩阵A的维度数相同。
②二维以下时实现上面的功能是完全可以用ndarray的切片或方法take()实现的。
关于ndarray的切片操作的详细介绍见博文 https://www.hhai.cc/thread-117-1-1.html
关于方法take()的详细介绍见博文 https://www.hhai.cc/thread-121-1-1.html
您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

QQ|Archiver|昊虹AI笔记网 ( 蜀ICP备2022024117号-1 )

GMT+8, 2024-5-7 12:23 , Processed in 0.025715 second(s), 22 queries .

Powered by Discuz! X3.4

Copyright © 2001-2021, Tencent Cloud.

快速回复 返回顶部 返回列表