关键点检测¶
关键点检测用于识别和定位图像中的特定兴趣点。这些关键点也称为标志点,代表物体的有意义特征,例如面部特征或物体部分。这些模型输入一张图像,并返回以下输出:
- 关键点和分数:兴趣点及其置信度分数。
- 描述符:每个关键点周围图像区域的表示,捕捉其纹理、梯度、方向等属性。
在本指南中,我们将展示如何从图像中提取关键点。
在本教程中,我们将使用 SuperPoint,这是一个用于关键点检测的基础模型。
In [ ]:
from transformers import AutoImageProcessor, SuperPointForKeypointDetection
# 初始化处理器和模型
processor = AutoImageProcessor.from_pretrained("magic-leap-community/superpoint")
model = SuperPointForKeypointDetection.from_pretrained("magic-leap-community/superpoint")
我们将在以下图像上测试模型。

In [ ]:
import torch
from PIL import Image
import requests
import cv2
# 下载图像
url_image_1 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/bee.jpg"
image_1 = Image.open(requests.get(url_image_1, stream=True).raw)
url_image_2 = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cats.png"
image_2 = Image.open(requests.get(url_image_2, stream=True).raw)
# 将图像放入列表中
images = [image_1, image_2]
# 处理输入并进行推理
inputs = processor(images, return_tensors="pt").to(model.device, model.dtype)
outputs = model(**inputs)
模型输出包括每个图像中的相对关键点、描述符、掩码和分数。掩码突出显示图像中存在关键点的区域。
In [ ]:
# 输出示例
outputs.keypoints
outputs.scores
outputs.descriptors
outputs.mask
为了在图像中标绘实际的关键点,我们需要对输出进行后处理。为此,我们需要将实际的图像尺寸传递给 post_process_keypoint_detection 函数,同时传递输出。
In [ ]:
# 获取图像尺寸
image_sizes = [(image.size[1], image.size[0]) for image in images]
# 进行后处理
outputs = processor.post_process_keypoint_detection(outputs, image_sizes)
现在,输出是一个字典列表,每个字典包含处理后的关键点、分数和描述符。
In [ ]:
# 输出示例
outputs[0]['keypoints']
outputs[0]['scores']
outputs[0]['descriptors']
我们可以使用这些数据来标绘关键点。
In [ ]:
import matplotlib.pyplot as plt
import torch
# 遍历每个图像
for i in range(len(images)):
keypoints = outputs[i]["keypoints"]
scores = outputs[i]["scores"]
descriptors = outputs[i]["descriptors"]
# 将张量转换为 NumPy 数组
keypoints = outputs[i]["keypoints"].detach().numpy()
scores = outputs[i]["scores"].detach().numpy()
image = images[i]
image_width, image_height = image.size
# 绘制图像和关键点
plt.axis('off')
plt.imshow(image)
plt.scatter(
keypoints[:, 0],
keypoints[:, 1],
s=scores * 100, # 置信度分数影响点的大小
c='cyan', # 点的颜色
alpha=0.4 # 透明度
)
plt.show()
下面是输出结果。
