Text Classification Using the TensorFlow Lite Plugin for Flutter

By Amish Garg, Google Summer of Code (GSoC) intern. Translated by Yuan (谷创字幕组); reviewed by Xinlei and Lynn Wang, CFUG community.


If you have been looking for an easy, efficient, and flexible way to integrate TensorFlow models into your Flutter apps, don't miss the new plugin introduced here: tflite_flutter. The plugin was developed by Amish Garg, a Google Summer of Code (GSoC) intern, and this article is translated from his Medium post "Text Classification using TensorFlow Lite Plugin for Flutter".

If you wished that there was an easy, efficient, and flexible way to integrate TensorFlow trained models with your Flutter apps, I am glad to announce the release of a new plugin, tflite_flutter.

Key features of tflite_flutter:

  • It provides a Dart API similar to the TFLite Java and Swift APIs, so there is no compromise on the flexibility offered on those platforms.

  • It binds directly to the TensorFlow Lite C API using dart:ffi, making it more efficient than platform-channel-based integrations.

  • No need to write any platform-specific code.

  • Offers acceleration support using NNAPI and the GPU delegate on Android, and the Metal delegate on iOS, as sketched below.
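
As a rough sketch, assuming the useNnApiForAndroid option and the GpuDelegateV2 class described in the plugin readme at the time of writing, acceleration is enabled by passing an InterpreterOptions object when the interpreter is created:

import 'dart:io' show Platform;

import 'package:tflite_flutter/tflite_flutter.dart';

// Illustrative helper: build an interpreter with hardware acceleration on Android.
// Verify the option and delegate names against your tflite_flutter version.
Future<Interpreter> createAcceleratedInterpreter(String modelFile) async {
  final options = InterpreterOptions();
  if (Platform.isAndroid) {
    // Enable NNAPI acceleration...
    options.useNnApiForAndroid = true;
    // ...or, alternatively, attach the GPU delegate instead:
    // options.addDelegate(GpuDelegateV2());
  }
  // On iOS, a Metal-based GpuDelegate can be attached via addDelegate in the same way.
  return Interpreter.fromAsset(modelFile, options: options);
}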

In this article, I will walk you through building a text classification Flutter app using tflite_flutter. Let's get started by creating a new Flutter project, text_classification_app.

(Important) Initial setup

Linux and Mac users

Copy install.sh into the root folder of your app and execute sh install.sh in that root folder, text_classification_app/ in our case.

Windows users

Copy install.bat into the root folder of your app and run install.bat in that root folder, text_classification_app/ in our case.

This will automatically download the latest binaries from the release assets and place them in the appropriate folders for you.

Refer to the readme for more info on the initial setup.

Getting the plugin

In pubspec.yaml, include tflite_flutter: ^<latest_version> (details here).
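
For reference, the dependencies section of pubspec.yaml then looks roughly like this (keep the placeholder until you look up the latest published version on pub.dev):

dependencies:
  flutter:
    sdk: flutter
  tflite_flutter: ^<latest_version>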

Downloading the model

To use any TensorFlow trained model on mobile, we need to obtain it in .tflite format. For more information on how to convert a TensorFlow trained model to .tflite format, refer to the official guide.

We are going to use the pre-trained text classification model available on the TensorFlow website. Click here to download it.

This pretrained model predicts whether a paragraph's sentiment is positive or negative. It was trained on the Large Movie Review Dataset v1.0 from Maas et al., which consists of IMDB movie reviews labeled as either positive or negative. Find more info here.

Place text_classification.tflite and text_classification_vocab.txt in the text_classification_app/assets/ directory.

Include assets/ in pubspec.yaml:

assets:    
  - assets/

Now we are all set to begin coding. 🚀

Coding the classifier

Pre-processing

As mentioned on the text classification model's page, here are the steps to classify a paragraph with the model:

  1. Tokenize the paragraph and convert it to a list of word IDs using a predefined vocabulary.

  2. Feed the list to the TensorFlow Lite model.

  3. Get the probability of the paragraph being positive or negative from the model outputs.

We will first write a method to tokenize the raw string, using text_classification_vocab.txt as the vocabulary.

Create a new file, classifier.dart, under the lib/ folder.

Let's first write the code to load text_classification_vocab.txt into a dictionary.

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';
  
  Map<String, int> _dict;

  Classifier() {
    _loadDictionary();
  }

  void _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      var entry = vocabList[i].trim().split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }
  
}

Loading the dictionary

Now we will write a function to tokenize the raw string.

import 'package:flutter/services.dart';

class Classifier {
  final _vocabFile = 'text_classification_vocab.txt';

  // Maximum length of a sentence
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;
  
  List<List<double>> tokenizeInputText(String text) {
    
    // Tokenize by whitespace
    final toks = text.split(' ');
    
    // Create a list of length _sentenceLen, filled with the dictionary value of <PAD>
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // Look up the dictionary id for each word in the sentence
    for (var tok in toks) {
      if (index >= _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // Return a List<List<double>> of shape [1, 256], the shape our interpreter's input tensor expects
    return [vec];
  }
}


Tokenization

Inference using tflite_flutter

This is the main section of this article, as here we are going to discuss the usage of the tflite_flutter plugin.

The term inference refers to the process of executing a TensorFlow Lite model on-device in order to make predictions based on input data. To perform an inference with a TensorFlow Lite model, you must run it through an interpreter. Learn more.

Creating the interpreter and loading the model

tflite_flutter provides a method to create the interpreter directly from assets:

static Future<Interpreter> fromAsset(String assetName, {InterpreterOptions options})

As our model is in the assets/ directory, we will just use the above method to create the interpreter. For info on InterpreterOptions, refer to this.

import 'package:flutter/services.dart';

// Import the tflite_flutter package
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // Name of the model file
  final _modelFile = 'text_classification.tflite';

  // TensorFlow Lite interpreter object
  Interpreter _interpreter;

  Classifier() {
    // Load the model when the classifier is created
    _loadModel();
  }

  void _loadModel() async {
    
    // Create the interpreter using Interpreter.fromAsset
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

}

Code to create the interpreter

If you don't want to put your model in the assets/ directory, tflite_flutter also provides factory constructors to create an interpreter; refer to the readme. A rough sketch follows.
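
As a minimal sketch, assuming the Interpreter.fromFile and Interpreter.fromBuffer factory constructors described in the plugin readme, and using a purely hypothetical file path:

import 'dart:io';

import 'package:tflite_flutter/tflite_flutter.dart';

// Sketch only: check the exact constructor signatures in your plugin version.
Future<void> loadFromFileOrBuffer() async {
  final modelFile = File('/path/to/text_classification.tflite'); // hypothetical path

  // Create an interpreter from a file on the device.
  final fromFile = Interpreter.fromFile(modelFile);

  // Create an interpreter from an in-memory buffer (e.g. downloaded bytes).
  final bytes = await modelFile.readAsBytes();
  final fromBuffer = Interpreter.fromBuffer(bytes);

  print(fromFile.getInputTensors());
  print(fromBuffer.getOutputTensors());
}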

Let's perform inference!

We are going to use the following method for inference:

void run(Object input, Object output);

Notice that this method is the same as the one provided by the Java API.

The Object input and Object output must be multi-dimensional lists with the same shapes as the input tensor and the output tensor.

To view the shapes and sizes of the input and output tensors, you can do the following:

_interpreter.allocateTensors();
// Print the list of input tensors
print(_interpreter.getInputTensors());
// Print the list of output tensors
print(_interpreter.getOutputTensors());

In the case of our text_classification model, the output looks like this:

InputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf280, name: embedding_input, type: TfLiteType.float32, shape: [1, 256], data: 1024}]
OutputTensorList:
[Tensor{_tensor: Pointer<TfLiteTensor>: address=0xbffcf140, name: dense_1/Softmax, type: TfLiteType.float32, shape: [1, 2], data: 8}]

Now let's write the classify method, which returns 1 for positive and 0 for negative.

int classify(String rawText) {
    
    // tokenizeInputText returns a List<List<double>> of shape [1, 256]
    List<List<double>> input = tokenizeInputText(rawText);
   
    // Output of shape [1, 2]
    var output = List<double>(2).reshape([1, 2]);
    
    // run performs the inference and stores the result in output
    _interpreter.run(input, output);

    var result = 0;
    // If the first element of the output is larger than the second, the sentence is negative
    
    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

Code for inference

There are some useful extensions defined under extension ListShape on List in tflite_flutter:

// Reshapes the list to the provided shape; the total number of elements must stay the same
// Usage: List(400).reshape([2,10,20])
// Returns a List<dynamic>
List reshape(List<int> shape)

// Returns the shape of the list
List<int> get shape

// Returns the number of elements in a list of any shape
int get computeNumElements
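
Here is a quick sketch of how these extensions behave, using illustrative values:

import 'package:tflite_flutter/tflite_flutter.dart';

void main() {
  // Reshape a flat list of 6 doubles into a 2 x 3 nested list.
  final flat = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
  final nested = flat.reshape([2, 3]);
  print(nested);                    // [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
  print(nested.shape);              // [2, 3]
  print(nested.computeNumElements); // 6
}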

The final classifier.dart should look like this:

import 'package:flutter/services.dart';

// Import the tflite_flutter package
import 'package:tflite_flutter/tflite_flutter.dart';

class Classifier {
  // Name of the model file
  final _modelFile = 'text_classification.tflite';
  final _vocabFile = 'text_classification_vocab.txt';

  // Maximum length of a sentence
  final int _sentenceLen = 256;

  final String start = '<START>';
  final String pad = '<PAD>';
  final String unk = '<UNKNOWN>';

  Map<String, int> _dict;

  // TensorFlow Lite interpreter object
  Interpreter _interpreter;

  Classifier() {
    // Load the model and the dictionary when the classifier is created
    _loadModel();
    _loadDictionary();
  }

  void _loadModel() async {
    // Create the interpreter using Interpreter.fromAsset
    _interpreter = await Interpreter.fromAsset(_modelFile);
    print('Interpreter loaded successfully');
  }

  void _loadDictionary() async {
    final vocab = await rootBundle.loadString('assets/$_vocabFile');
    var dict = <String, int>{};
    final vocabList = vocab.split('\n');
    for (var i = 0; i < vocabList.length; i++) {
      var entry = vocabList[i].trim().split(' ');
      dict[entry[0]] = int.parse(entry[1]);
    }
    _dict = dict;
    print('Dictionary loaded successfully');
  }

  int classify(String rawText) {
    // tokenizeInputText returns a List<List<double>> of shape [1, 256]
    List<List<double>> input = tokenizeInputText(rawText);

    // Output of shape [1, 2]
    var output = List<double>(2).reshape([1, 2]);

    // run performs the inference and stores the result in output
    _interpreter.run(input, output);

    var result = 0;
    // If the first element of the output is larger than the second, the sentence is negative

    if ((output[0][0] as double) > (output[0][1] as double)) {
      result = 0;
    } else {
      result = 1;
    }
    return result;
  }

  List<List<double>> tokenizeInputText(String text) {
    // Tokenize by whitespace
    final toks = text.split(' ');

    // Create a list of length _sentenceLen, filled with the dictionary value of <PAD>
    var vec = List<double>.filled(_sentenceLen, _dict[pad].toDouble());

    var index = 0;
    if (_dict.containsKey(start)) {
      vec[index++] = _dict[start].toDouble();
    }

    // Look up the dictionary id for each word in the sentence
    for (var tok in toks) {
      if (index >= _sentenceLen) {
        break;
      }
      vec[index++] = _dict.containsKey(tok)
          ? _dict[tok].toDouble()
          : _dict[unk].toDouble();
    }

    // Return a List<List<double>> of shape [1, 256], the shape our interpreter's input tensor expects
    return [vec];
  }
}

Now it's up to you to code the desired UI for this; using the classifier is simple:

// Create the Classifier instance
Classifier _classifier = Classifier();
// Call classify with the target sentence
_classifier.classify("I liked the movie");
// returns 1 (positive)
_classifier.classify("I didn't like the movie");
// returns 0 (negative)

Check out the complete text classification example app with UI here.

Text Classification Example App

Visit the repository am15h/tflite_flutter_plugin on GitHub to learn more about the tflite_flutter plugin.

FAQs

Q: How is this plugin, tflite_flutter, different from tflite v1.0.5?

While tflite v1.0.5 focuses on offering some high-level features for building apps with specific use cases, such as image classification and object detection, the new tflite_flutter offers the same flexibility and features as the Java API and can be used with any tflite model. It also offers support for delegates.

tflite_flutter is fast (it has low latency) because it uses dart:ffi (dart ↔️ (ffi) ↔️ C), while tflite uses platform integration (dart ↔️ platform-channel ↔️ (Java/Swift) ↔️ JNI ↔️ C).

Q: How do I create an image classification app using tflite_flutter? Is there any package similar to the TensorFlow Lite Android Support Library?

Update (07/01/2020): the TFLite Flutter Helper library has been released.

The TensorFlow Lite Flutter Helper library provides a simple architecture for processing and manipulating the input and output of TFLite models. Its API design and documentation are identical to the TensorFlow Lite Android Support Library. More info here.

That's all for this article. I would love to hear your feedback on the tflite_flutter plugin; feel free to file an issue to report bugs or request features.

Thanks for reading.

Thanks to Michael Thomsen.

Further reading

For more content about TensorFlow and Google AI, see the resources below.