diff --git a/README.md b/README.md index 5b03c22..59e0d10 100644 --- a/README.md +++ b/README.md @@ -20,8 +20,9 @@ MegFlow 提供快速视觉应用落地流程,最快 15 分钟搭建起视频 * [build on win10](docs/how-to-build-and-run/build-on-win10.zh.md) * [generate rtsp](docs/how-to-build-and-run/generate-rtsp.zh.md) * how to use - * [add my first service](docs/how-to-add-my-service/01-single-classification-model.zh.md) - * [how to optimize and debug](docs/how-to-debug.zh.md) + * [tutorial01: image classification service](docs/how-to-add-my-service/01-single-classification-model.zh.md) + * [tutorial02: detect and classify on video stream](docs/how-to-add-my-service/02-single-det-classify.zh.md) +* [how to debug](docs/how-to-debug.zh.md) * [how to contribute](docs/how-to-contribute.zh.md) * [FAQ](docs/FAQ.zh.md) diff --git a/ci/run_pylint_check.sh b/ci/run_pylint_check.sh index 7434b65..be0de0a 100755 --- a/ci/run_pylint_check.sh +++ b/ci/run_pylint_check.sh @@ -1,5 +1,5 @@ python -m pip install pylint==2.5.2 -CHECK_DIR="flow-python/examples/simple_classification flow-python/examples/cat_finder flow-python/examples/electric_bicycle" +CHECK_DIR="flow-python/examples/simple_classification flow-python/examples/simple_det_classify flow-python/examples/cat_finder flow-python/examples/electric_bicycle" pylint $CHECK_DIR || pylint_ret=$? if [ "$pylint_ret" ]; then exit $pylint_ret diff --git a/docs/how-to-add-my-service/01-single-classification-model.zh.md b/docs/how-to-add-my-service/01-single-classification-model.zh.md index 3cfa6c0..8650a58 100644 --- a/docs/how-to-add-my-service/01-single-classification-model.zh.md +++ b/docs/how-to-add-my-service/01-single-classification-model.zh.md @@ -17,7 +17,7 @@ $ python3 dump.py $ ls -lah model.mge ... ``` -`dump.py` 正在 PR 到 MegEngine/models +`dump.py` 已经 PR 到 [MegEngine/models 分类模型目录](https://github.com/MegEngine/Models/tree/master/official/vision/classification) ```bash $ cat dump.py ... @@ -125,20 +125,49 @@ class Classify: * `__init__` 里加载模型,做个 warmup 防止首次推理太慢 * 解码成 BGR 的 data 在 `envelope.msg['data']`,推理,send 返回 json string -[更多 node 说明](appendix-B-python-plugin.zh.md) +[classify.py 各参数说明](appendix-B-python-plugin.zh.md) ## 运行测试 运行服务 ```bash $ cd flow-python/examples -$ run_with_plugins -c simple_classification/image_cpu.toml -p simple_classification # 源码/docker 编译方式用这条命令 +$ run_with_plugins -c simple_classification/image_cpu.toml -p simple_classification ``` -浏览器打开 8084 端口服务(例如 http://10.122.101.175:8084/docs ),选择一张图“try it out”即可。 +### WebUI 方式 +浏览器打开 8084 端口服务(例如 http://127.0.0.1:8084/docs ),选择一张图“try it out”即可。 -## 其他 +### 命令行方式 +```bash +$ curl http://127.0.0.1:8081/analyze/image_name -X POST --header "Content-Type:image/*" --data-binary @test.jpeg +``` + +`image_name` 是用户自定义参数,用在需要 POST 内容的场景。这里随便填即可;`test.jpeg` 是测试图片 + +### Python Client + +```Python +$ cat ${MegFlow_DIR}/flow-python/examples/simple_classification/client.py -一、http 客户端开发 +import requests +import cv2 + +def test(): + ip = 'localhost' + port = '8084' + url = 'http://{}:{}/analyze/any_content'.format(ip, port) + img = cv2.imread("./test.jpg") + _, data = cv2.imencode(".jpg", img) + data = data.tobytes() + + headers = {'Content-Length': '%d' % len(data), 'Content-Type': 'image/*'} + res = requests.post(url, data=data, headers=headers) + print(res.content) + +if __name__ == "__main__": + test() +``` -rweb/Swagger 提供了 http RESTful API 描述文件,例如在 http://10.122.101.175:8084/openapi.json 。`swagger_codegen` 可用描述文件生成各种语言的调用代码。更多教程见 [swagger codegen tutorial ](https://swagger.io/tools/swagger-codegen/)。 +### 其他语言 +rweb/Swagger 提供了 http RESTful API 描述文件,例如在 http://127.0.0.1:8084/openapi.json 。`swagger_codegen` 可用描述文件生成 java/go 等语言的调用代码。更多教程见 [swagger codegen tutorial ](https://swagger.io/tools/swagger-codegen/)。 diff --git a/docs/how-to-add-my-service/02-single-det-classify.zh.md b/docs/how-to-add-my-service/02-single-det-classify.zh.md new file mode 100644 index 0000000..ec5b3a2 --- /dev/null +++ b/docs/how-to-add-my-service/02-single-det-classify.zh.md @@ -0,0 +1,147 @@ +# 串连检测和分类 + +本文将在 [tutorial01](01-single-classification-model.zh.md) 的基础上扩展计算图:先检测、再扣图分类。对外提供视频解析服务。完整的代码在 [simple_det_classify](../../flow-python/examples/simple_det_classify) 。 + +## 移除分类预处理 + +之前提到过:MegEngine 除了不需要转模型,还能消除预处理。我们修改 `dump.py` 把预处理从 SDK/业务代码提到模型内。这样的好处是:**划清工程和算法的边界**,预处理本来就应该由 scientist 维护,每次只需要 release mge 文件,减少交接内容 + +```bash +$ cat ${MegFlow}/flow-python/examples/simple_det_classify/dump.py +... + data = mge.Tensor(np.ones(shape, dtype=np.uint8)) + + @jit.trace(capture_as_const=True) + def pred_func(data): + out = data.astype(np.float32) + # resnet18 预处理 + output_h, output_w = 224, 224 + # resize + M = mge.tensor(np.array([[1,0,0], [0,1,0], [0,0,1]], dtype=np.float32).reshape((1,3,3))) + out = F.vision.warp_perspective(out, M, (output_h, output_w), format='NHWC') + # mean + _mean = mge.Tensor(np.array([103.530, 116.280, 123.675], dtype=np.float32)) + out = F.sub(out, _mean) + # div + _div = mge.Tensor(np.array([57.375, 57.120, 58.395], dtype=np.float32)) + out = F.div(out, _div) + # dimshuffile + out = F.transpose(out, (0,3,1,2)) + + outputs = model(out) + return outputs +... +``` +具体实现是在 trace inference 里增加预处理动作,fuse opr 优化加速的事情交给 MegEngine 即可。更多 cv 操作参照 [MegEngine API 文档](https://megengine.org.cn/doc/stable/zh/reference/api/megengine.functional.vision.warp_perspective.html?highlight=warp_perspective)。 + +因为推理输入变成了 BGR,所以 dump 模型的时候参数也应该跟着变 +```bash +$ python3 dump.py -a resnet18 -s 1 224 224 3 +``` + +## 准备检测模型 +这里直接用现成的 YOLOX mge 模型。复用 [cat_finder 的检测](../../flow-python/examples/cat_finder/det.py) 或者从 [YOLOX 官网](https://github.com/Megvii-BaseDetection/YOLOX/tree/main/demo/MegEngine/python) 下载最新版。 + +## 配置计算图 +`flow-python/examples` 增加 `simple_det_classify/video_cpu.toml` + +```bash +$ cat flow-python/examples/simple_det_classify/video_cpu.toml + +main = "tutorial_02" + +# 重资源结点要先声明 +[[nodes]] +name = "det" +ty = "Detect" +model = "yolox-s" +conf = 0.25 +nms = 0.45 +tsize = 640 +path = "models/simple_det_classify_models/yolox_s.mge" +interval = 5 +visualize = 1 +device = "cpu" +device_id = 0 + +[[nodes]] +name = "classify" +ty = "Classify" +path = "models/simple_det_classify_models/resnet18_preproc_inside.mge" +device = "cpu" +device_id = 0 + +[[graphs]] +name = "subgraph" +inputs = [{ name = "inp", cap = 16, ports = ["det:inp"] }] +outputs = [{ name = "out", cap = 16, ports = ["classify:out"] }] +# 描述连接关系 +connections = [ + { cap = 16, ports = ["det:out", "classify:inp"] }, +] + +... +# ty 改成 VdieoServer + [[graphs.nodes]] + name = "source" + ty = "VideoServer" + port = 8085 + +... +``` +想对上一期的配置,需要关注 3 点: +* 视频流中的重资源结点,需要声明在 `[[graphs]]` 之外,因为多路视频需要复用这个结点。如果每一路都要启一个 det 结点,资源会爆掉 +* `connections` 不再是空白,因为两个结点要描述连接关系 +* Server 类型改成 `VideoServer`,告诉 UI 是要处理视频的 + +## 实现细节 +* 可以看到此时 [resnet18 的 lite.py](../../flow-python/examples/simple_det_classify/lite.py) 已经删除了 preprocess 函数 +* det.py 可以直接用 `cat_finder` 的 + +## 运行测试 + +运行服务 +```bash +$ cd flow-python/examples +$ run_with_plugins -c simple_det_classify/video_cpu.toml -p simple_det_classify +``` + +### WebUI 方式 +浏览器打开 8085 端口服务(例如 http://127.0.0.1:8085/docs ) + +* 参照 [如何生成 rtsp](../how-to-build-and-run/generate-rtsp.zh.md),提供一个 rtsp 流地址 +* 或者干脆给 .mp4 文件的绝对路径(文件和 8085 服务在同一台机器上) + +### 命令行方式 +```bash +$ curl -X POST 'http://127.0.0.1:8085/start/rtsp%3A%2F%2F127.0.0.1%3A8554%2Ftest1.ts' # start rtsp://127.0.0.1:8554/test1.ts +start stream whose id is 2% +$ curl 'http://127.0.0.1:8085/list' # list all stream +[{"id":1,"url":"rtsp://10.122.101.175:8554/test1.ts"},{"id":0,"url":"rtsp://10.122.101.175:8554/test1.ts"}]% +``` +路径中的 `%2F`、`%3A` 是 [URL](https://www.ietf.org/rfc/rfc1738.txt) 的转义字符 + +### Python Client + +```Python +$ cat ${MegFlow_DIR}/flow-python/examples/simple_det_classify/client.py + +import requests +import urllib + + +def test(): + ip = 'localhost' + port = '8085' + video_path = 'rtsp://127.0.0.1:8554/vehicle.ts' + video_path = urllib.parse.quote(video_path, safe='') + url = 'http://{}:{}/start/{}'.format(ip, port, video_path) + + res = requests.post(url) + ret = res.content + print(ret) + + +if __name__ == "__main__": + test() +``` diff --git a/docs/how-to-build-and-run/build-from-source.zh.md b/docs/how-to-build-and-run/build-from-source.zh.md index 4f65608..36f2926 100644 --- a/docs/how-to-build-and-run/build-from-source.zh.md +++ b/docs/how-to-build-and-run/build-from-source.zh.md @@ -80,7 +80,7 @@ P.S. 默认 ffmpeg 依赖自动从 github 上拉取源码构建,这会使得 ```bash $ cd examples $ cargo build --example run_with_plugins --release # 编译出 megflow bin -$ ln -s ../../target/example/run_with_plugins +$ ln -s ../../target/release/examples/run_with_plugins $ ./run_with_plugins -p logical_test ``` `logical_test` 是 examples 下最基础的计算图测试用例,运行能正常结束表示 MegFlow 编译成功、基本语义无问题。 diff --git a/docs/how-to-debug.zh.md b/docs/how-to-debug.zh.md index 4e768b5..4a291df 100644 --- a/docs/how-to-debug.zh.md +++ b/docs/how-to-debug.zh.md @@ -1 +1,12 @@ -# \ No newline at end of file +# 如何 Debug 常见问题 + +一、`run_with_plugins` 无法启动服务,直接 core dump 报错退出 + +如果“Python 开机自检”的 `run_with_plugins -p logical_test` 能够正常结束,排查方向应该是 Python import error。调试方法举例 +```bash +$ gdb --args ./run_with_plugins -c electric_bicycle/electric_bicycle_cpu.toml -p electric_bicycle +... +illegal instruction +... +``` +可以看到 crash 发生在哪个 import \ No newline at end of file diff --git a/flow-python/examples/cat_finder/README.md b/flow-python/examples/cat_finder/README.md index eb189b9..3349bba 100644 --- a/flow-python/examples/cat_finder/README.md +++ b/flow-python/examples/cat_finder/README.md @@ -1,7 +1,7 @@ # 猫猫围栏 ## 一、功能概述 -注册的猫猫离开围栏,会收到一条告警信息。未注册的不会报警。 +注册的猫猫离开围栏,会收到一条告警信息。未注册的不会报警。 CPU 配置已提供,没有 GPU 也可以运行。 ## 二、模型和自测数据下载 diff --git a/flow-python/examples/simple_classification/client.py b/flow-python/examples/simple_classification/client.py new file mode 100644 index 0000000..4b3394e --- /dev/null +++ b/flow-python/examples/simple_classification/client.py @@ -0,0 +1,30 @@ +# MegFlow is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2019-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +#!/usr/bin/env python +# coding=utf-8 +import cv2 +import requests + + +def test(): + ip = 'localhost' + port = '8084' + user_define_string = 'content' + url = f'http://{ip}:{port}/analyze/{user_define_string}' + img = cv2.imread("./test.jpg") + _, data = cv2.imencode(".jpg", img) + data = data.tobytes() + + headers = {'Content-Length': f'{len(data)}', 'Content-Type': 'image/*'} + res = requests.post(url, data=data, headers=headers) + print(res.content) + + +if __name__ == "__main__": + test() diff --git a/flow-python/examples/simple_classification/lite.py b/flow-python/examples/simple_classification/lite.py index 034a7c4..68b27c9 100644 --- a/flow-python/examples/simple_classification/lite.py +++ b/flow-python/examples/simple_classification/lite.py @@ -1,6 +1,13 @@ -#!/usr/bin/env python3 -# -*- coding:utf-8 -*- -# Copyright (c) Megvii, Inc. and its affiliates. +# MegFlow is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2019-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +#!/usr/bin/env python +# coding=utf-8 import argparse import cv2 diff --git a/flow-python/examples/simple_det_classify/__init__.py b/flow-python/examples/simple_det_classify/__init__.py new file mode 100644 index 0000000..f57558f --- /dev/null +++ b/flow-python/examples/simple_det_classify/__init__.py @@ -0,0 +1,10 @@ +# MegFlow is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2019-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +#!/usr/bin/env python +# coding=utf-8 diff --git a/flow-python/examples/simple_det_classify/classify.py b/flow-python/examples/simple_det_classify/classify.py new file mode 100644 index 0000000..1cfc848 --- /dev/null +++ b/flow-python/examples/simple_det_classify/classify.py @@ -0,0 +1,72 @@ +# MegFlow is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2019-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +#!/usr/bin/env python +# coding=utf-8 + +import json +import numpy as np +from loguru import logger +from megflow import register + +from .lite import PredictorLite + + +@register(inputs=['inp'], outputs=['out']) +class Classify: + def __init__(self, name, arg): + logger.info("loading Resnet18 Classification...") + self.name = name + + # load ReID model and warmup + self._model = PredictorLite(path=arg['path'], + device=arg['device'], + device_id=arg['device_id']) + warmup_data = np.zeros((224, 224, 3), dtype=np.uint8) + self._model.inference(warmup_data) + logger.info("Resnet18 loaded.") + + def expand(self, box, max_w, max_h, ratio): + l = box[0] + r = box[2] + t = box[1] + b = box[3] + center_x = (l + r) / 2 + center_y = (t + b) / 2 + w_side = (r - l) * ratio / 2 + h_side = (b - t) * ratio / 2 + + l = center_x - w_side + r = center_x + w_side + t = center_y - h_side + b = center_y + h_side + l = max(0, l) + t = max(0, t) + r = min(max_w, r) + b = min(max_h, b) + return int(l), int(t), int(r), int(b) + + def exec(self): + envelope = self.inp.recv() + if envelope is None: + return + + data = envelope.msg['data'] + items = envelope.msg['items'] + results = [] + for _, item in enumerate(items): + assert 'bbox' in item + bbox = item['bbox'] + l, t, r, b = self.expand(bbox, data.shape[1], data.shape[0], 1.1) + _type = self._model.inference(data[t:b, l:r]) + results.append({ + "type": str(_type), + "frame_id": str(envelope.partial_id) + }) + + self.out.send(envelope.repack(json.dumps(results))) diff --git a/flow-python/examples/simple_det_classify/client.py b/flow-python/examples/simple_det_classify/client.py new file mode 100644 index 0000000..22ba281 --- /dev/null +++ b/flow-python/examples/simple_det_classify/client.py @@ -0,0 +1,28 @@ +# MegFlow is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2019-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +#!/usr/bin/env python +# coding=utf-8 +import urllib +import requests + + +def test(): + ip = 'localhost' + port = '8085' + video_path = 'rtsp://127.0.0.1:8554/vehicle.ts' + video_path = urllib.parse.quote(video_path, safe='') + url = 'http://{}:{}/start/{}'.format(ip, port, video_path) + + res = requests.post(url) + ret = res.content + print(ret) + + +if __name__ == "__main__": + test() diff --git a/flow-python/examples/simple_det_classify/det.py b/flow-python/examples/simple_det_classify/det.py new file mode 100644 index 0000000..60a1d38 --- /dev/null +++ b/flow-python/examples/simple_det_classify/det.py @@ -0,0 +1,78 @@ +# MegFlow is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2019-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +#!/usr/bin/env python +# coding=utf-8 + +import numpy as np +from loguru import logger +from megflow import register +from warehouse.detection_yolox import PredictorLite + + +@register(inputs=['inp'], outputs=['out']) +class Detect: + def __init__(self, name, args): + logger.info("loading YOLOX detection...") + self._tsize = args['tsize'] + self._interval = args['interval'] + self._visualize = args['visualize'] + self.name = name + + # load detect model and warmup + self._predictor = PredictorLite(path=args['path'], + confthre=args['conf'], + nmsthre=args['nms'], + test_size=(self._tsize, self._tsize), + device=args['device'], + device_id=args['device_id']) + warmup_data = np.zeros((224, 224, 3), dtype=np.uint8) + self._predictor.inference(warmup_data) + logger.info(" YOLOX loaded.") + + @staticmethod + def restrict(val, min, max): + assert min < max + if val < min: + val = min + if val > max: + val = max + return val + + def exec(self): + envelope = self.inp.recv() + if envelope is None: + return + + msg = envelope.msg + msg['items'] = [] + + process = envelope.partial_id % self._interval == 0 + if process: + data = msg['data'] + outputs = self._predictor.inference(data) + # skip if detect nothing + if outputs is not None: + items = [] + + for i in range(outputs.shape[0]): + output = outputs[i] + item = dict() + item["bbox"] = output[0:4] + item["det_score"] = output[4] * output[5] + items.append(item) + msg['items'] = items + + # import cv2 + # x = self._predictor.visual(outputs, data) + # name = 'frame{0:07d}.jpg'.format(envelope.partial_id) + # cv2.imwrite(name, x) + + if self._visualize == 1: + msg['data'] = self._predictor.visual(outputs, data) + self.out.send(envelope) diff --git a/flow-python/examples/simple_det_classify/dump.py b/flow-python/examples/simple_det_classify/dump.py new file mode 100644 index 0000000..acae30f --- /dev/null +++ b/flow-python/examples/simple_det_classify/dump.py @@ -0,0 +1,97 @@ +# -*- coding: utf-8 -*- +# MegEngine is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +# pylint: skip-file +import argparse +import sys + +# pylint: disable=import-error +import resnet.model as resnet_model +# pylint: disable=import-error +import shufflenet.model as snet_model + +import numpy as np + +import megengine as mge +import megengine.functional as F +from megengine import jit + + +def dump_static_graph(model, graph_name, shape): + model.eval() + + data = mge.Tensor(np.ones(shape, dtype=np.uint8)) + + @jit.trace(capture_as_const=True) + def pred_func(data): + out = data.astype(np.float32) + + output_h, output_w = 224, 224 + # resize + M = mge.tensor(np.array([[1,0,0], [0,1,0], [0,0,1]], dtype=np.float32).reshape((1,3,3))) + out = F.vision.warp_perspective(out, M, (output_h, output_w), format='NHWC') + # mean + _mean = mge.Tensor(np.array([103.530, 116.280, 123.675], dtype=np.float32)) + out = F.sub(out, _mean) + # div + _div = mge.Tensor(np.array([57.375, 57.120, 58.395], dtype=np.float32)) + out = F.div(out, _div) + # dimshuffile + out = F.transpose(out, (0,3,1,2)) + + outputs = model(out) + return outputs + + pred_func(data) + pred_func.dump( + graph_name, + arg_names=["data"], + optimize_for_inference=True, + enable_fuse_conv_bias_nonlinearity=True, + ) + + +def main(): + parser = argparse.ArgumentParser(description="MegEngine Classification Dump .mge") + parser.add_argument( + "-a", + "--arch", + default="resnet18", + help="model architecture (default: resnet18)", + ) + parser.add_argument( + "-s", + "--shape", + type=int, + nargs='+', + default="1 3 224 224", + help="input shape (default: 1 3 224 224)" + ) + parser.add_argument( + "-o", + "--output", + type=str, + default="model.mge", + help="output filename" + ) + + args = parser.parse_args() + + if 'resnet' in args.arch: + model = getattr(resnet_model, args.arch)(pretrained=True) + elif 'shufflenet' in args.arch: + model = getattr(snet_model, args.arch)(pretrained=True) + else: + print('unavailable arch {}'.format(args.arch)) + sys.exit() + dump_static_graph(model, args.output, tuple(args.shape)) + + +if __name__ == "__main__": + main() diff --git a/flow-python/examples/simple_det_classify/lite.py b/flow-python/examples/simple_det_classify/lite.py new file mode 100644 index 0000000..d8ec894 --- /dev/null +++ b/flow-python/examples/simple_det_classify/lite.py @@ -0,0 +1,82 @@ +# MegFlow is Licensed under the Apache License, Version 2.0 (the "License") +# +# Copyright (c) 2019-2021 Megvii Inc. All rights reserved. +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + +#!/usr/bin/env python +# coding=utf-8 + +import argparse +import time +import cv2 +import numpy as np +import megenginelite as mgelite +from loguru import logger + + +class PredictorLite: + def __init__( + self, + path, + device="gpu", + device_id=0, + ): + + if "gpu" in device.lower(): + device_type = mgelite.LiteDeviceType.LITE_CUDA + else: + device_type = mgelite.LiteDeviceType.LITE_CPU + + net_config = mgelite.LiteConfig(device_type=device_type) + ios = mgelite.LiteNetworkIO() + ios.add_input(mgelite.LiteIO("data", is_host=True)) + + net = mgelite.LiteNetwork(config=net_config, io=ios) + net.device_id = device_id + net.load(path) + + self.net = net + + def inference(self, img): + t0 = time.time() + + img = cv2.resize(img, (224, 224)) + # build input tensor + inp_data = self.net.get_io_tensor("data") + inp_data.set_data_by_share(img) + + # forward + self.net.forward() + self.net.wait() + + # postprocess + output_keys = self.net.get_all_output_name() + output = self.net.get_io_tensor(output_keys[0]).to_numpy() + logger.debug("resnet18 infer time: {:.4f}s".format(time.time() - t0)) + + return np.argmax(output[0]) + + +def make_parser(): + parser = argparse.ArgumentParser("Classification Demo!") + parser.add_argument("--path", + default="./test.png", + help="path to images or video") + parser.add_argument("--model", + default=None, + type=str, + help=".mge for eval") + return parser + + +if __name__ == "__main__": + args = make_parser().parse_args() + predictor = PredictorLite(args.model) + image = cv2.imread(args.path) + if image is None: + logger.error(f"open {args.path} failed") + out = predictor.inference(image) + logger.info(f'{out}') diff --git a/flow-python/examples/simple_det_classify/video_cpu.toml b/flow-python/examples/simple_det_classify/video_cpu.toml new file mode 100644 index 0000000..f63a662 --- /dev/null +++ b/flow-python/examples/simple_det_classify/video_cpu.toml @@ -0,0 +1,45 @@ +main = "tutorial_02" + +[[nodes]] +name = "det" +ty = "Detect" +model = "yolox-s" +conf = 0.25 +nms = 0.45 +tsize = 640 +path = "models/simple_det_classify_models/yolox_s.mge" +interval = 5 +visualize = 1 +device = "cpu" +device_id = 0 + +[[nodes]] +name = "classify" +ty = "Classify" +path = "models/simple_det_classify_models/resnet18_preproc_inside.mge" +device = "cpu" +device_id = 0 + +[[graphs]] +name = "subgraph" +inputs = [{ name = "inp", cap = 16, ports = ["det:inp"] }] +outputs = [{ name = "out", cap = 16, ports = ["classify:out"] }] +connections = [ + { cap = 16, ports = ["det:out", "classify:inp"] }, +] + +[[graphs]] +name = "tutorial_02" +connections = [ + { cap = 16, ports = ["source:out", "destination:inp"] }, + { cap = 16, ports = ["source:inp", "destination:out"] } +] + + [[graphs.nodes]] + name = "source" + ty = "VideoServer" + port = 8085 + + [[graphs.nodes]] + name = "destination" + ty = "subgraph"