窝牛号

Segment Anything Meta开源分割一切模型,为进军元宇宙更近一步

上期图文教程,我们分享了Segment Anything分割一切模型的原理,Segment Anything Model 是一种以最少的人工干预构建全自动可提示图像分割模型的方法。模型提供了一键分割图片的方法,当然模型也可以运行我们输入一个坐标点,一个输入框,或者输入一个对象的文本来分割输入的对象。

它是一个单一的模型,可以轻松地执行交互式分割和自动分割。该模型允许以灵活的方式使用它,只需为模型设计正确的提示(点击、分割框、文本等),就可以完成分割任务。此外,segment Anything Meta SAM在包含超过 10 亿个掩码的多样化、高质量数据集上进行训练,这使其能够泛化到新类型的对象和图像。

在 Segment Anything Meta SAM 中,该模式包含三个重要组成部分:

图像编码器。提示编码器。掩码解码器。

更多模型介绍,可以参考上期图文教程,本期教程,我们分享一下Segment Anything的代码实现过程。在运行代码前,首先需要确认一下有N卡的驱动,且成功安装了torch等第三方库

import torch import torchvision import sys !{sys.executable} -m pip install opencv-python matplotlib !{sys.executable} -m pip install &39; !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

首先我们需要使用到torch与torchvision库,并使用Facebook开源的segment-anything模型,安装相关的第三方库文件,并下载预训练模型。

import numpy as np import torch import matplotlib.pyplot as plt import cv2 def show_anns(anns): if len(anns) == 0: return sorted_anns = sorted(anns, key=(lambda x: x[&39;]), reverse=True) ax = plt.gca() ax.set_autoscale_on(False) img = np.ones((sorted_anns[0][&39;].shape[0], sorted_anns[0][&39;].shape[1], 4)) img[:,:,3] = 0 for ann in sorted_anns: m = ann[&39;] color_mask = np.concatenate([np.random.random(3), [0.35]]) img[m] = color_mask ax.imshow(img)

然后我们建立一个show函数,此函数主要用于可视化Segment Anything模型预测的结果。

image = cv2.imread(&39;) image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) plt.figure(figsize=(20,20)) plt.imshow(image) plt.axis(&39;) plt.show()

然后我们读取一张本地的照片,上张照片是transformer进行对象检测的结果,我们使用原图片上传给模型。在模型读取图片前,我们需要把图片转换到RGB空间,当然这里可以show一下图片。

import sys sys.path.append(&34;) from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor sam_checkpoint = &34; model_type = &34; device = &34; sam = sam_model_registry[model_type](checkpoint=sam_checkpoint) sam.to(device=device) mask_generator = SamAutomaticMaskGenerator(sam) masks = mask_generator.generate(image)

然后我们就可以把图片传递给模型了,这里我们使用sam_vit_h_4b8939.pth预训练模型中的vit模型,SamAutomaticMaskGenerator函数帮我们建立了一个自动分割的函数,此函数会自动分割图片中的所有检测到的分类,当然此模型可以根据自己输入的坐标信息,对象种类来单独分割某个单独的对象,此部分代码我们后期进行分享。最后我们就可以把图片传递给mask generator函数进行预测了,预测后的结果保存在mask里面。

print(len(masks)) print(masks[0].keys()) plt.figure(figsize=(20,20)) plt.imshow(image) show_anns(masks) plt.axis(&39;) plt.show()

模型预测完成后,我们可以打印出来模型预测的种类以及模型预测的结果,这里模型输出如下

segmentation : the mask area : mask区域 bbox : XYWH format 格式的边框 predicted_iou : 模型自己对掩模质量的预测 point_coords : 生成此掩码的采样输入点 stability_score : 掩码质量的附加度量 crop_box : 用于生成 XYWH 格式蒙版的图像裁剪区域 192 dict_keys([&39;, &39;, &39;, &39;, &39;, &39;, &39;])

有了以上的输出,我们可以使用前面建立的可视化函数来进行mask图片的可视化操作

本站所发布的文字与图片素材为非商业目的改编或整理,版权归原作者所有,如侵权或涉及违法,请联系我们删除

窝牛号 wwww.93ysy.com   沪ICP备2021036305号-1