阅读 181

React 实现图片识别App

(首先声明,我并不是这方面的学习者,我也不懂这个什么神经网络的学习,写这一篇和做这一个demo 完全是因为觉得好玩。所以里面的代码,除了 react 我懂,其他的,我就。。。)

在这里插入图片描述

先把效果图给大家放上来 在这里插入图片描述 在这里插入图片描述

在这里插入图片描述 个人觉得效果还行。识别不太准确是因为这个 app学习图片的时间太短(电脑太卡)。


(笔者是 window10) 安装运行环境:

  1. npm install --global windows-build-tools(这个时间很漫长。。。)

  2. npm install @tensorflow/tfjs-node(这个时间很漫长。。。)

项目目录如下 在这里插入图片描述

train文件夹 index.js(入口文件)

const tf = require('@tensorflow/tfjs-node') const getData = require('./data') const TRAIN_DIR = '../垃圾分类/train' const OUTPUT_DIR = '../outputDir' const MOBILENET_URL = 'http://ai-sample.oss-cn-hangzhou.aliyuncs.com/pipcook/models/mobilenet/web_model/model.json' const main = async () => {   // 加载数据   const { ds, classes} = await getData(TRAIN_DIR, OUTPUT_DIR)   // 定义模型   const mobilenet = await tf.loadLayersModel(MOBILENET_URL)   mobilenet.summary()   // console.log(mobilenet.layers.map((l, i) => [l.name, i]))   const model = tf.sequential()   for (let i = 0; i <= 86; i += 1) {     const layer = mobilenet.layers[i]     layer.trainable = false     model.add(layer)   }   model.add(tf.layers.flatten())   model.add(tf.layers.dense({     units: 10,     activation: 'relu'   }))   model.add(tf.layers.dense({     units: classes.length,     activation: 'softmax'   }))   // 训练模型   model.compile({     loss: 'sparseCategoricalCrossentropy',     optimizer: tf.train.adam(),     metrics: ['acc']   })   await model.fitDataset(ds, { epochs: 20 })   await model.save(`file://${process.cwd()}/${OUTPUT_DIR}`) } main() 复制代码

data.js(处理数据)

const fs = require('fs') const tf = require('@tensorflow/tfjs-node') const img2x = (imgPath) => {   const buffer = fs.readFileSync(imgPath)   return tf.tidy(() => {     const imgTs = tf.node.decodeImage(new Uint8Array(buffer))     const imgTsResized = tf.image.resizeBilinear(imgTs, [224, 224])     return imgTsResized.toFloat().sub(255/2).div(255/2).reshape([1, 224, 224, 3])   }) } const getData = async (trainDir, outputDir) => {   const classes = fs.readdirSync(trainDir)   fs.writeFileSync(`${outputDir}/classes.json`, JSON.stringify(classes))   const data = []   classes.forEach((dir, dirIndex) => {     fs.readdirSync(`${trainDir}/${dir}`)       .filter(n => n.match(/jpg$/))       .slice(0, 10)       .forEach(filename => {         console.log('读取', dir, filename)         const imgPath = `${trainDir}/${dir}/${filename}`         data.push({ imgPath, dirIndex })       })   })   tf.util.shuffle(data)   const ds = tf.data.generator(function* () {     const count = data.length     const batchSize = 32     for (let start = 0; start < count; start += batchSize) {       const end = Math.min(start + batchSize, count)       yield tf.tidy(() => {         const inputs = []         const labels = []         for (let j = start; j < end; j += 1) {           const { imgPath, dirIndex } = data[j]           const x = img2x(imgPath)           inputs.push(x)           labels.push(dirIndex)         }         const xs = tf.concat(inputs)         const ys = tf.tensor(labels)         return { xs, ys }       })     }   })   return {     ds,     classes   } } module.exports = getData 复制代码

安装一些运行项目需要的插件 在这里插入图片描述

app 文件夹

import React, { PureComponent } from 'react' import { Button, Progress, Spin, Empty } from 'antd' import 'antd/dist/antd.css' import * as tf from '@tensorflow/tfjs' import { file2img, img2x } from './utils' import intro from './intro' const DATA_URL = 'http://127.0.0.1:8080/' class App extends PureComponent {   state = {}   async componentDidMount() {     this.model = await tf.loadLayersModel(DATA_URL + '/model.json')     // this.model.summary()     this.CLASSES = await fetch(DATA_URL + '/classes.json').then(res => res.json())   }   predict = async (file) => {     const img = await file2img(file)     this.setState({       imgSrc: img.src,       isLoading: true     })     setTimeout(() => {       const pred = tf.tidy(() => {         const x = img2x(img)         return this.model.predict(x)       })       const results = pred.arraySync()[0]         .map((score, i) => ({score, label: this.CLASSES[i]}))         .sort((a, b) => b.score - a.score)       this.setState({         results,         isLoading: false       })     }, 0)   }   renderResult = (item) => {     const finalScore = Math.round(item.score * 100)     return (       <tr key={item.label}>         <td style={{ width: 80, padding: '5px 0' }}>{item.label}</td>         <td>           <Progress percent={finalScore} status={finalScore === 100 ? 'success' : 'normal'} />         </td>       </tr>     )   }   render() {     const { imgSrc, results, isLoading } = this.state     const finalItem = results && {...results[0], ...intro[results[0].label]}     return (       <div style={{padding: 20}}>         <span           style={{ color: '#cccccc', textAlign: 'center', fontSize: 12, display: 'block' }}         >识别可能不准确</span>         <Button           type="primary"           size="large"           style={{width: '100%'}}           onClick={() => this.upload.click()}         >           选择图片识别         </Button>         <input           type="file"           onChange={e => this.predict(e.target.files[0])}           ref={el => {this.upload = el}}           style={{ display: 'none' }}         />         {           !results && !imgSrc && <Empty style={{ marginTop: 40 }} />         }         {imgSrc && <div style={{ marginTop: 20, textAlign: 'center' }}>           <img src={imgSrc} style={{ maxWidth: '100%' }} />         </div>}         {finalItem && <div style={{marginTop: 20}}>识别结果: </div>}         {finalItem && <div style={{display: 'flex', alignItems: 'flex-start', marginTop: 20}}>           <img             src={finalItem.icon}             width={120}           />           <div>             <h2 style={{color: finalItem.color}}>               {finalItem.label}             </h2>             <div style={{color: finalItem.color}}>               {finalItem.intro}             </div>           </div>         </div>}         {           isLoading && <Spin size="large" style={{display: 'flex', justifyContent: 'center', alignItems: 'center', marginTop: 40 }} />         }         {results && <div style={{ marginTop: 20 }}>           <table style={{width: '100%'}}>             <tbody>               <tr>                 <td>类别</td>                 <td>匹配度</td>               </tr>               {results.map(this.renderResult)}             </tbody>           </table>         </div>}       </div>     )   } } export default App 复制代码

index.html

<!DOCTYPE html> <html>   <head>     <title>垃圾分类</title>     <meta name="viewport" content="width=device-width, inital-scale=1">   </head>   <body>     <div id="app"></div>     <script src="./index.js"></script>   </body> </html> 复制代码

index.js

import React from 'react' import ReactDOM from 'react-dom' import App from './App' ReactDOM.render(<App />, document.querySelector('#app')) 复制代码

intro.js

export default {   '可回收物': {     icon: 'https://lajifenleiapp.com/static/svg/1_3F6BA8.svg',     color: '#3f6ba8',     intro: '是指在日常生活中或者为日常生活提供服务的活动中产生的,已经失去原有全部或者部分使用价值,回收后经过再加工可以成为生产原料或者经过整理可以再利用的物品,包括废纸类、塑料类、玻璃类、金属类、织物类等。'   },   '有害垃圾': {     icon: 'https://lajifenleiapp.com/static/svg/2v_B43953.svg',     color: '#b43953',     intro: '是指生活垃圾中对人体健康或者自然环境造成直接或者潜在危害的物质,包括废充电电池、废扣式电池、废灯管、弃置药品、废杀虫剂(容器)、废油漆(容器)、废日用化学品、废水银产品、废旧电器以及电子产品等。'   },   '厨余垃圾': {     icon: 'https://lajifenleiapp.com/static/svg/3v_48925B.svg',     color: '#48925b',     intro: '是指居民日常生活中产生的有机易腐垃圾,包括菜叶、剩菜、剩饭、果皮、蛋壳、茶渣、骨头等。'   },   '其他垃圾': {     icon: 'https://lajifenleiapp.com/static/svg/4_89918B.svg',     color: '#89918b',     intro: '是指除可回收物、有害垃圾和厨余垃圾之外的,混杂、污染、难分类的其他生活垃圾。'   } } 复制代码

utils.js

import * as tf from '@tensorflow/tfjs' export const file2img = async (f) => {   return new Promise(reslove => {     const reader = new FileReader()     reader.readAsDataURL(f)     reader.onload = (e) => {       const img = document.createElement('img')       img.src = e.target.result       img.width = 224       img.height = 224       img.onload = () => { reslove(img) }     }   }) } export function img2x(imgEl) {   return tf.tidy(() => {     return tf.browser.fromPixels(imgEl)         .toFloat().sub(255/2).div(255/2)         .reshape([1, 224, 224, 3])   }) } 复制代码

运行项目代码之前,我们需要先在 train 目录下运行,node index.js,生成 model.json 以供识别系统使用。之后需要在根目录下运行 hs outputDir --cors, 使得生成的 model.json 运行在 http 环境下,之后才可以运行 npm start ,不然项目是会报错的。

主要的代码就是上面这些。前面笔者也说了。自己对这方面完全不懂,所以也无法解说其中的代码。各位感兴趣就自己研究一下。代码地址奉上。


作者:淘淘是只狗
链接:https://juejin.cn/post/7048472221798367262


文章分类
代码人生
版权声明:本站是系统测试站点,无实际运营。本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容, 请发送邮件至 XXXXXXo@163.com 举报,一经查实,本站将立刻删除。
相关推荐