From 98eda16d0fe8f21188fe951d65b4a814c0dafd5a Mon Sep 17 00:00:00 2001 From: ZZY <2450266535@qq.com> Date: Wed, 15 May 2024 20:33:56 +0800 Subject: [PATCH] =?UTF-8?q?init=20=E5=9F=BA=E7=A1=80=E8=81=8A=E5=A4=A9?= =?UTF-8?q?=E6=A1=86=E6=9E=B6?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 0 VERSION | 1 + main.py | 9 ++ requirements.txt | 17 +++ src/__init__.py | 0 src/core/__init__.py | 0 src/core/base_asr_engine.py | 4 + src/core/base_chat_ai_engine.py | 10 ++ src/core/base_nlu_engine.py | 4 + src/core/base_tts_engine.py | 4 + src/core/core.py | 148 +++++++++++++++++++++++++++ src/core/echo.py | 19 ++++ src/plugins/__init__.py | 0 src/plugins/chat_ai_tongyiqianwen.py | 55 ++++++++++ src/plugins/defalut.yml | 20 ++++ src/plugins/plugins_conf.py | 50 +++++++++ src/plugins/remote_engine.py | 23 +++++ tests/__init__.py | 0 tests/core/__init__.py | 0 tests/core/test_core.py | 54 ++++++++++ tests/core/test_echo.py | 13 +++ tests/plugins/__init__.py | 0 utils/__init__.py | 5 + utils/config.py | 5 + utils/constants.py | 17 +++ utils/logger.py | 52 ++++++++++ utils/singleton.py | 6 ++ utils/utils.py | 0 28 files changed, 516 insertions(+) create mode 100644 README.md create mode 100644 VERSION create mode 100644 main.py create mode 100644 requirements.txt create mode 100644 src/__init__.py create mode 100644 src/core/__init__.py create mode 100644 src/core/base_asr_engine.py create mode 100644 src/core/base_chat_ai_engine.py create mode 100644 src/core/base_nlu_engine.py create mode 100644 src/core/base_tts_engine.py create mode 100644 src/core/core.py create mode 100644 src/core/echo.py create mode 100644 src/plugins/__init__.py create mode 100644 src/plugins/chat_ai_tongyiqianwen.py create mode 100644 src/plugins/defalut.yml create mode 100644 src/plugins/plugins_conf.py create mode 100644 src/plugins/remote_engine.py create mode 100644 tests/__init__.py create mode 100644 tests/core/__init__.py create mode 
100644 tests/core/test_core.py create mode 100644 tests/core/test_echo.py create mode 100644 tests/plugins/__init__.py create mode 100644 utils/__init__.py create mode 100644 utils/config.py create mode 100644 utils/constants.py create mode 100644 utils/logger.py create mode 100644 utils/singleton.py create mode 100644 utils/utils.py diff --git a/README.md b/README.md new file mode 100644 index 0000000..e69de29 diff --git a/VERSION b/VERSION new file mode 100644 index 0000000..8acdd82 --- /dev/null +++ b/VERSION @@ -0,0 +1 @@ +0.0.1 diff --git a/main.py b/main.py new file mode 100644 index 0000000..0df845b --- /dev/null +++ b/main.py @@ -0,0 +1,9 @@ +if __name__ == '__main__': + from src.core.core import ExecutePipeline + from src.core.echo import EchoMiddleWare + from src.plugins.plugins_conf import PluginsConfig + engine = PluginsConfig().get_engine_class("chat_ai_engine") + pipe = ExecutePipeline(engine.to_middleware(), EchoMiddleWare(), EchoMiddleWare()) + # data = '端午是几号' + data = input('请输入: ') + pipe.execute([{'role': 'user', 'content': f'{data}'}]) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..5e7bbb1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +# if you want to use not .wav file, +# you need to install ffmpeg or libav in your system +# pydub +# simpleaudio + +# yaml +pyyaml + +# remote server connect +requests +# openai # optional + +# server optional +# tornado + +# tqdm +# rich # optional \ No newline at end of file diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/__init__.py b/src/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/core/base_asr_engine.py b/src/core/base_asr_engine.py new file mode 100644 index 0000000..cd1da1a --- /dev/null +++ b/src/core/base_asr_engine.py @@ -0,0 +1,4 @@ +from .core import BaseEngine + +class BaseASREngine(BaseEngine): + pass diff --git 
a/src/core/base_chat_ai_engine.py b/src/core/base_chat_ai_engine.py new file mode 100644 index 0000000..dbf0930 --- /dev/null +++ b/src/core/base_chat_ai_engine.py @@ -0,0 +1,10 @@ +from typing import Dict +from .core import BaseEngine +from logging import getLogger +logger = getLogger(__name__) + +class BaseChatAIEngine(BaseEngine): + """基础AI引擎类, 提供历史记录管理的基础框架。""" + def __init__(self, history:Dict = None, is_stream = False) -> None: + self._is_stream = is_stream + self._history = history \ No newline at end of file diff --git a/src/core/base_nlu_engine.py b/src/core/base_nlu_engine.py new file mode 100644 index 0000000..c02ff20 --- /dev/null +++ b/src/core/base_nlu_engine.py @@ -0,0 +1,4 @@ +from .core import BaseEngine + +class BaseNLUEngine(BaseEngine): + pass \ No newline at end of file diff --git a/src/core/base_tts_engine.py b/src/core/base_tts_engine.py new file mode 100644 index 0000000..e0c43cc --- /dev/null +++ b/src/core/base_tts_engine.py @@ -0,0 +1,4 @@ +from .core import BaseEngine + +class BaseTTSEngine(BaseEngine): + pass \ No newline at end of file diff --git a/src/core/core.py b/src/core/core.py new file mode 100644 index 0000000..24f7348 --- /dev/null +++ b/src/core/core.py @@ -0,0 +1,148 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, List, Optional + +from logging import getLogger +logger = getLogger(__name__) + +class MiddlewareInterface(ABC): + """中间件接口 + """ + + @abstractmethod + def process(self, data: Any) -> Any: + """处理数据并返回处理后的结果""" + pass + + @abstractmethod + def get_next(self) -> Optional[List['MiddlewareInterface']]: + """获取下一个中间件""" + pass + +class EngineInterface(ABC): + """引擎接口 + """ + @abstractmethod + def execute(self, data: Any) -> Any: + """执行引擎逻辑并返回结果""" + pass + + @abstractmethod + def transform(self, data: Any) -> Any: + """将输出数据转换为标准化的格式""" + pass + + @abstractmethod + def to_execute_middleware(self) -> MiddlewareInterface: + """将引擎执行转换为遵循MiddlewareInterface的中间件.""" + pass + + @abstractmethod + 
def to_transform_middleware(self) -> MiddlewareInterface: + """将引擎数据标准化转换为遵循MiddlewareInterface的中间件.""" + pass + + @abstractmethod + def to_middleware(self) -> MiddlewareInterface: + """将引擎转换为遵循MiddlewareInterface的中间件.""" + pass + +class BaseMiddleware(MiddlewareInterface): + """中间件基类 + 注意 next_middleware 参数是可选的,如果未提供,可能有两种模式 + + 1.如果process返回值为None(只有next_middleware为None返回值才为None) + 则返回数据将在Pipeline里返回,否则继续执行Pipeline的下一个中间件。 + + 2.中间件的process返回值为process_logic处理后的值 + """ + def __init__(self, process_logic: Callable[[Any], Any], + next_middleware: Optional[List[MiddlewareInterface]] = []): + self._process_logic = process_logic + self._next_middleware = next_middleware + self._data = None + + def process(self, data: Any) -> Any: + """处理数据并返回处理后的结果,支持自动调用下一个中间件 + (如果next_middleware为None则返回None)""" + try: + data = self._process_logic(data) + self._data = data + if self._next_middleware: + for middleware in self._next_middleware: + self._data = middleware.process(self._data) + return None if self._next_middleware is None else data + except Exception as e: + logger.exception(f"Error occurred during processing: {e}") + raise e + + def get_next(self) -> Optional[List[MiddlewareInterface]]: + return self._next_middleware + +class BaseEngine(EngineInterface): + """引擎基类 + """ + def transform(self, data: Any) -> Any: + raise NotImplementedError("transform method not implemented") + + def execute(self, data: Any) -> Any: + raise NotImplementedError("execute method not implemented") + + def to_transform_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface: + return BaseMiddleware(self.transform, next_middleware) + + def to_execute_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface: + return BaseMiddleware(self.execute, next_middleware) + + def to_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface: + return BaseMiddleware(lambda data: 
self.transform(self.execute(data)), next_middleware) + +class ExecutePipeline: + """ + 执行流 用于按顺序执行中间件 + 示例: + pipeline = ExecutePipeline(middleware1, middleware2) + pipeline.add(middleware3) + result = pipeline.execute(data) + + 执行顺序 data -> middleware1 -> middleware2 -> middleware3 -> result + + pipeline.reorder([0,2,1]) + result = pipeline.execute(data) + + 执行顺序 data -> middleware1 -> middleware3 -> middleware2 -> result + + pipeline.reorder([0,1]) + result = pipeline.execute(data) + + 执行顺序 data -> middleware1 -> middleware2 -> result + """ + def __init__(self, *middlewares:MiddlewareInterface): + self.middlewares = [*middlewares] + + def add(self, middleware: MiddlewareInterface): + """添加中间件到执行流里""" + self.middlewares.append(middleware) + + def reorder(self, new_order: List[int]): + """重新排序中间件,按照new_order指定的索引顺序排列,注意可以实现删除功能""" + self.middlewares = [self.middlewares[i] for i in new_order] + + def execute(self, data): + """执行中间件 + + 注意中间件的next不参与主执行流 + + 注意只有中间件的next为None,返回值才不会被输入到执行流中,即停止执行后续内容 + + TODO: 1. 
添加中间件执行失败时的处理机制 + """ + try: + for middleware in self.middlewares: + if data is None: + break + data = middleware.process(data) + return data + except ValueError: + return None + except Exception as e: + raise e diff --git a/src/core/echo.py b/src/core/echo.py new file mode 100644 index 0000000..4359e6d --- /dev/null +++ b/src/core/echo.py @@ -0,0 +1,19 @@ +import sys +from typing import Iterator, List, Optional +from .core import BaseMiddleware, MiddlewareInterface + +class EchoMiddleWare(BaseMiddleware): + def __init__(self, next_middleware: Optional[List[MiddlewareInterface]] = []): + super().__init__(self._process, next_middleware) + def _process(self, data): + if isinstance(data, Iterator): + ret = '' + for i in data: + ret += i + sys.stdout.write(i) + sys.stdout.flush() + sys.stdout.write('\n') + else: + ret = data + print(data) + return ret \ No newline at end of file diff --git a/src/plugins/__init__.py b/src/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/plugins/chat_ai_tongyiqianwen.py b/src/plugins/chat_ai_tongyiqianwen.py new file mode 100644 index 0000000..559b3a6 --- /dev/null +++ b/src/plugins/chat_ai_tongyiqianwen.py @@ -0,0 +1,55 @@ +# https://dashscope.console.aliyun.com/model +import json +from typing import Any +from remote_engine import BaseRemoteEngine +from plugins_conf import PluginsConfig + +from logging import getLogger +logger = getLogger(__name__) + +PluginsConfig().add_core_path() +from core.base_chat_ai_engine import BaseChatAIEngine + +class ChatAiTongYiQianWen(BaseChatAIEngine, BaseRemoteEngine): + API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation" + def __init__(self, api_key : str, + model_name : str = "qwen-1.8b-chat", + is_stream : bool = False, + *args, **kwargs) -> None: + self._modual_name = model_name + BaseChatAIEngine.__init__(self, is_stream=is_stream) + BaseRemoteEngine.__init__(self, api_key=api_key, + base_url=self.API_URL) + + def 
_transform_raw(self, response): + return response["output"]["choices"][0]["message"]["content"] + + def _transform_iterator(self, response): + for chunk in response: + input_string = chunk.decode('utf-8') + json_string = input_string[input_string.find("data:") + 5:].strip() + yield self._transform_raw(json.loads(json_string)) + + def transform(self, data): + return self._transform_iterator(data) if self._is_stream else\ + self._transform_raw(data) + + def execute(self, data: Any) -> Any: + if self._is_stream: + self.header['X-DashScope-SSE'] = 'enable' + response = self.post_prompt({ + 'model': self._modual_name, + "input": { + "messages": data + }, + "parameters": { + "result_format": "message", + "incremental_output": self._is_stream + } + }, + is_stream=self._is_stream) + if self._is_stream: + ret = response.iter_content(chunk_size=None) + else: + ret = response.json() + return ret \ No newline at end of file diff --git a/src/plugins/defalut.yml b/src/plugins/defalut.yml new file mode 100644 index 0000000..8e716d7 --- /dev/null +++ b/src/plugins/defalut.yml @@ -0,0 +1,20 @@ +global: + plugin_path: ./ + core_path: ../ + +asr_engine: {} + +chat_ai_engine: + plugin: chat_ai_tongyiqianwen + + chat_ai_tongyiqianwen: + class_name: ChatAiTongYiQianWen + api_key: null # must be set when using this api + model_name: qwen-1.8b-chat + is_stream: true + +nlu_engine: {} + +tts_engine: {} + +engine: {} \ No newline at end of file diff --git a/src/plugins/plugins_conf.py b/src/plugins/plugins_conf.py new file mode 100644 index 0000000..6b53dee --- /dev/null +++ b/src/plugins/plugins_conf.py @@ -0,0 +1,50 @@ +from pathlib import Path +import importlib +import sys +from typing import Any, Callable, Dict, Optional, List + +import yaml +CURRENT_DIR_PATH = Path(__file__).resolve().parent + +class PluginsConfig: + def __init__(self, + base_path : Path | str = CURRENT_DIR_PATH, + conf_name : str = ".plugins.yml"): + base_path = Path(base_path) + self.base_path = base_path.resolve() + 
self.conf_path = base_path / conf_name + if not self.conf_path.exists(): + raise FileExistsError(f"{self.conf_path} not exists") + if not self.conf_path.is_file(): + raise TypeError(f"{self.conf_path} is not a file") + self.config:Dict = yaml.safe_load(self.conf_path.read_text("utf-8")) + global_conf = self.config.get('global') + self.core_path = self.base_path / global_conf.get('core_path') + self.plugin_path = self.base_path / global_conf.get('plugin_path') + sys.path.append(str(self.base_path)) + + def add_core_path(self): + sys.path.append(str(self.core_path.resolve())) + + def get_plugin_full_name(self, engine_name): + find_name = 'plugin' + plugin_name = (self.config.get(engine_name) or {}).get(find_name) # missing/empty engine sections fall through to the ValueError below instead of a KeyError + if not plugin_name: + raise ValueError(f"{engine_name}.{find_name} is not found in {self.conf_path}") + return plugin_name + + def get_engine_modual(self, engine_name): + plugin_name = self.get_plugin_full_name(engine_name) + sys.path.append(str(self.plugin_path.resolve())) + modual = importlib.import_module(plugin_name) + return modual + + def get_engine_config(self, engine_name): + plugin_name = self.get_plugin_full_name(engine_name) + return self.config[engine_name][plugin_name] + + def get_engine_class(self, engine_name): + engine_modual = self.get_engine_modual(engine_name) + engine_config = self.get_engine_config(engine_name) + engine = getattr(engine_modual, engine_config['class_name']) + return engine(**engine_config) \ No newline at end of file diff --git a/src/plugins/remote_engine.py b/src/plugins/remote_engine.py new file mode 100644 index 0000000..72b304e --- /dev/null +++ b/src/plugins/remote_engine.py @@ -0,0 +1,23 @@ +from logging import getLogger +logger = getLogger(__name__) +import requests + +class BaseRemoteEngine(): + """Base remote-service class providing the basic framework for remote calls.""" + def __init__(self, api_key: str, base_url: str): + self._api_key = api_key + self._base_url = base_url + self.url = self._base_url + self.header = { + 'Content-Type': 'application/json', + 'Authorization':
'Bearer ' + self._api_key + } + + def post_prompt(self, messages : list[dict[str, str]], is_stream : bool = False) -> requests.Response: + try: + response = requests.post(self.url, headers=self.header, json=messages, stream=is_stream) + except Exception as e: + logger.exception("post_prompt") + raise e + else: # return only on success; a return inside finally would discard the re-raised error and hit an unbound 'response' + return response \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/__init__.py b/tests/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/core/test_core.py b/tests/core/test_core.py new file mode 100644 index 0000000..545c9a1 --- /dev/null +++ b/tests/core/test_core.py @@ -0,0 +1,54 @@ +import unittest +from src.core import core + +class TestBaseMiddleware(unittest.TestCase): + def test_process_base_func1(self): + middleware = core.BaseMiddleware(lambda x: x + 1) + self.assertEqual(middleware.process(1), 2) + self.assertEqual(middleware.get_next(), []) + + def test_process_base_func2(self): + mid = [core.BaseMiddleware(lambda x: x + 2)] + middleware = core.BaseMiddleware(lambda x: x + 1, next_middleware=mid) + self.assertEqual(middleware.process(1), 2) + self.assertEqual(middleware._data, 4) + self.assertIs(middleware.get_next(), mid) + +class TestBaseEngine(unittest.TestCase): + pass + +class ExecutePipeline(unittest.TestCase): + def test_base(self): + middleware = core.BaseMiddleware(lambda x: x + 1) + pipeline = core.ExecutePipeline(middleware) + self.assertEqual(pipeline.execute(1), 2) + + def test_add(self): + middleware = core.BaseMiddleware(lambda x: x + 1) + pipeline = core.ExecutePipeline(middleware) + pipeline.add(core.BaseMiddleware(lambda x: x + 2)) + self.assertEqual(pipeline.execute(1), 4) + + def test_reorder(self): + middleware = core.BaseMiddleware(lambda x: x + 1) + pipeline = core.ExecutePipeline(middleware) + pipeline.reorder([]) + pipeline.execute(1) + + def test_reorder1(self): + middleware_p1 =
core.BaseMiddleware(lambda x: x + 1) + middleware_p2 = core.BaseMiddleware(lambda x: x + 2) + middleware_p3 = core.BaseMiddleware(lambda x: x + 3) + pipeline = core.ExecutePipeline(middleware_p1, middleware_p2, middleware_p3) + self.assertEqual(pipeline.execute(1), 7) + pipeline.reorder([1, 0, 2]) + self.assertEqual(pipeline.execute(1), 7) + pipeline.reorder([0, 2]) + self.assertEqual(pipeline.execute(1), 6) + pipeline.reorder([1]) + self.assertEqual(pipeline.execute(1), 4) + + def test_execute_none(self): + middleware = core.BaseMiddleware(lambda x: x + 1) + pipeline = core.ExecutePipeline(middleware) + self.assertIsNone(pipeline.execute(None)) \ No newline at end of file diff --git a/tests/core/test_echo.py b/tests/core/test_echo.py new file mode 100644 index 0000000..5981032 --- /dev/null +++ b/tests/core/test_echo.py @@ -0,0 +1,13 @@ +import unittest +from src.core import echo +from io import StringIO +import sys + +class TestEcho(unittest.TestCase): + def test_base(self): + echoMid = echo.EchoMiddleWare() + captured_out = StringIO() + sys.stdout = captured_out + echoMid.process('hello') + sys.stdout = sys.__stdout__ + self.assertEqual(captured_out.getvalue(), 'hello\n') \ No newline at end of file diff --git a/tests/plugins/__init__.py b/tests/plugins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..0c0ac3f --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,5 @@ +from .logger import setup_logger + +def init(): + from .config import file_encoding + setup_logger(file_encoding=file_encoding) \ No newline at end of file diff --git a/utils/config.py b/utils/config.py new file mode 100644 index 0000000..84aa27e --- /dev/null +++ b/utils/config.py @@ -0,0 +1,5 @@ +from .constants import * + +CONFIG_CONFIG_FILE_PATH = CONFIG_PATH / "config.yml" + +file_encoding = FILE_ENCODING \ No newline at end of file diff --git a/utils/constants.py b/utils/constants.py new file mode 100644 
index 0000000..6d0c5b5 --- /dev/null +++ b/utils/constants.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +from pathlib import Path + +# This code was refactored to use the pathlib module; references: +# https://zhuanlan.zhihu.com/p/139783331 +# https://docs.python.org/zh-cn/3/library/pathlib.html + +# root directory +ROOT_PATH = Path(__file__).resolve().parent.parent + +CONFIG_PATH = ROOT_PATH / "config" +CONFIG_DEFAULT_PATH = CONFIG_PATH / "default" + +# This path needs to be created automatically +TEMP_PATH = ROOT_PATH / ".temp" + +FILE_ENCODING = "utf-8" \ No newline at end of file diff --git a/utils/logger.py b/utils/logger.py new file mode 100644 index 0000000..1701445 --- /dev/null +++ b/utils/logger.py @@ -0,0 +1,52 @@ +import logging +import logging.config +from pathlib import Path +from typing import * +from .constants import CONFIG_DEFAULT_PATH, TEMP_PATH, FILE_ENCODING +CONFIG_LOGGER_FILE_PATH = CONFIG_DEFAULT_PATH / "logger.yml" + +# Type alias to simplify complex annotations +ConfigFunc = Callable[[str], Dict[str, Any]] + +# Default logging configuration +def _default_config(_): + logging.basicConfig(level=logging.NOTSET) + +# Prefix the handler file names in the logging configuration with the base path +def _update_handlers_filenames(config: dict[str, Any], base_filepath: Path) -> None: + for handler_config in config.get('handlers', {}).values(): + if 'filename' in handler_config: + handler_config['filename'] = str(base_filepath.joinpath(handler_config['filename'])) + +# Configure logging from a configuration file +def _setup_logger(config_path : Path, + config_func : ConfigFunc = _default_config, + log_filepath: Optional[Path] = None, + file_encoding: str = 'utf-8') -> dict[str, Any]: + try: + if config_path.exists(): + config = config_func( + config_path.read_text(encoding=file_encoding)) + value = config.get('handlers', None) + if log_filepath is not None and value is not None: + _update_handlers_filenames(config, log_filepath) + logging.config.dictConfig(config) + return config + else: + _default_config(None) # no config file: fall back to basicConfig (config_func needs file content and cannot be called bare) + return None + except Exception as e: + raise RuntimeError(f"Failed to setup logger: {e}") from e + +# Use YAML configuration by default +def yml_config(file_content : str): + import yaml + return
yaml.safe_load(file_content) + +# Public interface +def setup_logger(config_path : Path = CONFIG_LOGGER_FILE_PATH, + config_func : Callable[[str], str| dict[str, Any]] = yml_config, + log_filepath : Path = TEMP_PATH, + file_encoding : str = FILE_ENCODING) -> bool: + return _setup_logger(config_path, config_func, log_filepath, file_encoding) is not None + diff --git a/utils/singleton.py b/utils/singleton.py new file mode 100644 index 0000000..c1b5950 --- /dev/null +++ b/utils/singleton.py @@ -0,0 +1,6 @@ +class Singleton(type): + _instances = {} # single underscore: a double-underscore name is mangled and would not match cls._instances + def __call__(cls, *args, **kwargs): + if cls not in cls._instances: + cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs) + return cls._instances[cls] \ No newline at end of file diff --git a/utils/utils.py b/utils/utils.py new file mode 100644 index 0000000..e69de29