np.where() 是 NumPy 中非常强大且灵活的函数,主要有两种使用模式。让我为你详细解析这两种神奇用法:
这是 np.where() 最基本的形式,类似于三元运算符:
import numpy as np
# 语法:np.where(condition, x, y)
# 当condition为True时返回x,否则返回y
arr = np.array([1, 2, 3, 4, 5])
result = np.where(arr > 3, arr, 0)
print(result) # 输出:[0 0 0 4 5]
# 实际示例:成绩等级划分
scores = np.array([85, 62, 90, 45, 78])
grades = np.where(scores >= 90, 'A',
np.where(scores >= 80, 'B',
np.where(scores >= 70, 'C',
np.where(scores >= 60, 'D', 'F'))))
print(grades) # 输出:['B' 'D' 'A' 'F' 'C']
特点:
当只传入一个参数(condition)时,np.where() 返回满足条件的元素索引:
import numpy as np
arr = np.array([1, 2, 3, 4, 5, 4, 3, 2, 1])
# 查找值为4的所有位置
indices = np.where(arr == 4)
print(indices) # 输出:(array([3, 5]),)
print(arr[indices]) # 输出:[4 4]
# 二维数组示例
matrix = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9]])
# 查找大于5的元素位置
rows, cols = np.where(matrix > 5)
print("行索引:", rows) # 输出:[1 2 2 2]
print("列索引:", cols) # 输出:[2 0 1 2]
# 获取满足条件的值
values = matrix[rows, cols]
print("值:", values) # 输出:[6 7 8 9]
高级应用示例:
import numpy as np
# 示例1:数据清洗 - 替换异常值
data = np.array([1, 2, 99, 4, 5, 99, 7])
cleaned = np.where(data > 50, np.mean(data[data <= 50]), data)
print("清洗后:", cleaned)
# 示例2:图像处理 - 阈值化
image = np.random.randint(0, 256, size=(5, 5))
thresholded = np.where(image > 128, 255, 0)
print("阈值化图像:\n", thresholded)
# 示例3:查找极值点
x = np.linspace(0, 10, 100)
y = np.sin(x)
# 查找局部最大值(简化示例)
peak_indices = np.where((y[1:-1] > y[:-2]) & (y[1:-1] > y[2:]))[0] + 1
print("极值点位置:", peak_indices)
# 示例4:多条件查找
arr = np.array([[1, 0, 3],
[0, 5, 0],
[7, 0, 9]])
# 查找非零元素
rows, cols = np.where(arr != 0)
print("非零元素位置:")
for r, c in zip(rows, cols):
print(f" ({r}, {c}) = {arr[r, c]}")
import numpy as np
import time
# 创建大型数组
arr = np.random.randn(1000000)
# 方法1:使用np.where(推荐)
start = time.time()
result1 = np.where(arr > 0, arr, 0)
time1 = time.time() - start
# 方法2:使用Python循环(不推荐)
start = time.time()
result2 = np.zeros_like(arr)
for i in range(len(arr)):
if arr[i] > 0:
result2[i] = arr[i]
time2 = time.time() - start
print(f"np.where() 耗时: {time1:.4f}秒")
print(f"Python循环耗时: {time2:.4f}秒")
print(f"加速比: {time2/time1:.1f}倍")
| 特性 | 第一种用法(三目运算) | 第二种用法(索引查找) |
|---|---|---|
| 参数数量 | 3个(condition, x, y) | 1个(condition) |
| 返回值 | 新数组(与condition同形) | 元组(索引数组) |
| 主要用途 | 条件赋值、数据转换 | 查找元素位置 |
| 性能 | 向量化操作,速度快 | 同样高效 |
| 广播支持 | ✓ | ✗ |
# 技巧1:使用np.where进行复杂条件操作
arr = np.array([1, -2, 3, -4, 5])
result = np.where(arr > 0, np.sqrt(arr), np.sqrt(-arr))
print("绝对值的平方根:", result)
# 技巧2:结合其他NumPy函数
matrix = np.array([[1, 2], [3, 4]])
# 查找最大值的所有位置
max_positions = np.where(matrix == np.max(matrix))
print("最大值位置:", max_positions)
# 技巧3:处理缺失值(配合掩码)
data = np.array([1, 2, np.nan, 4, np.nan, 6])
mask = np.isnan(data)
filled = np.where(mask, 0, data)
print("填充后:", filled)
np.where() 的核心优势在于:
掌握这两种用法,可以让你在处理数组数据时更加得心应手!