阅读 149

VGGnet相似图像检测

VGGnet相似图像检测

利用VGGnet的预训练模型来实现图像的检索,先用预训练模型来抽取图片的特征,然后把待检索的图像和数据库中的所有图像进行匹配,找出相似度最高的

在jupyter notebook上实现

文件路径设置:
root|____ code
       |____ images|____ img_class_1
                           |____ img_class_2
                           |____ img_class_3
                           |.... .....
                           |____ img_class_n
       |____models
       |____queryimg

  • root: 根目录
  • images: 存放各类别的图片文件夹
  • img_class_i: 存放相应类别的图片
  • database: 用于存放数据
  • queryimg: 存放待检索图片

Step 1. 构造特征提取器

这里用了Keras的应用模块(Keras.applications)提供的带有预训练权值的模型

初始化一个模型的时候,会自动下载权重到~/.keras/models/目录下

详细参考

这里用VGG16预训练模型构造一个特征提取器

import numpy as np
from keras.applications.vgg16 import VGG16
from keras.applications.vgg16 import preprocess_input as preprocess_input_vgg
from keras.preprocessing import image
from numpy import linalg as LA
class VGGNet:
    def __init__(self):
        self.input_shape = (224, 224, 3)
        self.weight = ‘imagenet‘
        self.pooling = ‘max‘
        self.model_vgg = VGG16(weights=self.weight,
                               input_shape=(self.input_shape[0], self.input_shape[1], self.input_shape[2]),
                               pooling=self.pooling, include_top=False)
    # 提取vgg16最后一层卷积特征
    def vgg_extract_feat(self, img_path):
        img = image.load_img(img_path, target_size=(self.input_shape[0], self.input_shape[1]))
        img = image.img_to_array(img)
        img = np.expand_dims(img, axis=0)
        img = preprocess_input_vgg(img)
        feat = self.model_vgg.predict(img)
        norm_feat = feat[0] / LA.norm(feat[0])
        return norm_feat
keras.applications.vgg16.VGG16()

参数设置:
include_top: 是否包括顶层的全连接层。
weights: None 代表随机初始化, ‘imagenet‘ 代表加载在 ImageNet 上预训练的权值。
input_tensor: 可选,Keras tensor 作为模型的输入(即 layers.Input() 输出的 tensor)。
input_shape: 可选,输入尺寸元组,仅当 include_top=False 时有效,否则输入形状必须是 (244, 244, 3)(对于 channels_last 数据格式),或者 (3, 244, 244)(对于 channels_first 数据格式)。它必须拥有 3 个输入通道,且宽高必须不小于 32。例如 (200, 200, 3) 是一个合法的输入尺寸。
pooling: 可选,当 include_top 为 False 时,该参数指定了特征提取时的池化方式。

  • None 代表不池化,直接输出最后一层卷积层的输出,该输出是一个四维张量。
  • ‘avg‘ 代表全局平均池化(GlobalAveragePooling2D),相当于在最后一层卷积层后面再加一层全局平均池化层,输出是一个二维张量。
  • ‘max‘ 代表全局最大池化

classes: 可选,图片分类的类别数,仅当 include_top 为 True 并且不加载预训练权值时可用。

Step 2. 保存图片数据特征

用VGGnet提取图片特征
把图片的特征向量和文件路径存到文件中

import os
import h5py
import numpy as np

root = os.path.abspath(‘..‘)
save_path = os.path.join(root,‘database‘,‘vgg_featureCNN.h5‘)

print("--------------------------------------------------")
print("         feature extraction starts")
print("--------------------------------------------------")
imgdir = os.path.join(root,‘images‘)

imgpaths = []
for subdir in os.listdir(imgdir)[:]:
    curpath = os.path.join(imgdir,subdir)
    for imgname in os.listdir(curpath):
        imgpaths += [os.path.join(curpath,imgname)]             # 添加图片路径

feats = []          # 保存图片特征向量

model = VGGNet()
for i, img_path in enumerate(imgpaths):
    norm_feat = model.vgg_extract_feat(img_path)
    feats.append(norm_feat)
    print("extracting feature from image No. %d , %d images in total" % ((i + 1), len(imgpaths)))

feats = np.array(feats)
print("--------------------------------------------------")
print("      writing feature extraction results ...")
print("--------------------------------------------------")

h5f = h5py.File(save_path, ‘w‘)
h5f.create_dataset(‘dataset_1‘, data = feats)
h5f.create_dataset(‘dataset_2‘, data = np.string_(imgpaths))
h5f.close()
print("             writing has ended.            ")

Step 3. 图片检索

把待检索图片存到queryimg中, 进行检索,输出前maxres张匹配度最高的图片


import h5py
from cv2 import cv2
import matplotlib.pyplot as plt
import numpy as np
import os

from extract_cnn_vgg16_keras import VGGNet

root = os.path.abspath(‘..‘)
save_path = os.path.join(root,‘database‘,‘vgg_featureCNN.h5‘)
h5f = h5py.File(save_path, ‘r‘)
feats = h5f[‘dataset_1‘][:]
imgpaths = h5f[‘dataset_2‘][:]
h5f.close()

querydir = os.path.join(root,‘queryimg‘)

# init VGGNet16 model
model = VGGNet()


# 待检索图片名
imgname = ‘xxx.jpg‘

print("--------------------------------------------------")
print("               searching starts")
print("--------------------------------------------------")

# 待检索图片地址
querypath = os.path.join(querydir,imgname)
queryImg = cv2.imread(querypath)
queryImg = cv2.cvtColor(queryImg, cv2.COLOR_BGR2RGB)
plt.title("Query Image")
plt.imshow(queryImg)
plt.show()


# 提取待检索图片的特征
queryVec = model.vgg_extract_feat(querypath)

# 和数据库中的每张图片的特征匹配,计算匹配分数
scores = np.dot(queryVec, feats.T)
# 按匹配分数从大到小排序
rank_ID = np.argsort(scores)[::-1]
rank_score = scores[rank_ID]

maxres = 3  # 检索出三张相似度最高的图片
imlist = []
for i, index in enumerate(rank_ID[0:maxres]):
    imlist.append(imgpaths[index])
    print("image names: " + str(imgpaths[index]) + " scores: %f" % rank_score[i])
print("top %d images in order are: " % maxres, imlist)

# 输出检索到的图片
for i, im in enumerate(imlist):
    impath = str(im)[2:-1]        # 得到的im是一个byte型的数据格式,需要转换成字符串
    print(impath)
    image = cv2.imread(impath)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    plt.title("search output %d" % (i + 1))
    plt.imshow(image)
    plt.show()

参考:

原文:https://www.cnblogs.com/kikokiko/p/14882723.html

文章分类
代码人生
文章标签
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐