init 基础聊天框架

This commit is contained in:
ZZY 2024-05-15 20:33:56 +08:00
commit 98eda16d0f
28 changed files with 516 additions and 0 deletions

0
README.md Normal file
View File

1
VERSION Normal file
View File

@ -0,0 +1 @@
0.0.1

9
main.py Normal file
View File

@ -0,0 +1,9 @@
if __name__ == '__main__':
from src.core.core import ExecutePipeline
from src.core.echo import EchoMiddleWare
from src.plugins.plugins_conf import PluginsConfig
engine = PluginsConfig().get_engine_class("chat_ai_engine")
pipe = ExecutePipeline(engine.to_middleware(), EchoMiddleWare(), EchoMiddleWare())
# data = '端午是几号'
data = input('请输入: ')
pipe.execute([{'role': 'user', 'content': f'{data}'}])

17
requirements.txt Normal file
View File

@ -0,0 +1,17 @@
# if you want to use not .wav file,
# you need to install ffmpeg or libav in your system
# pydub
# simpleaudio
# yaml
pyyaml
# remote server connect
requests
# openai # optional
# server optional
# tornado
# tqdm
# rich # optional

0
src/__init__.py Normal file
View File

0
src/core/__init__.py Normal file
View File

View File

@ -0,0 +1,4 @@
from .core import BaseEngine
class BaseASREngine(BaseEngine):
pass

View File

@ -0,0 +1,10 @@
from typing import Dict
from .core import BaseEngine
from logging import getLogger
logger = getLogger(__name__)
class BaseChatAIEngine(BaseEngine):
"""基础AI引擎类, 提供历史记录管理的基础框架。"""
def __init__(self, history:Dict = None, is_stream = False) -> None:
self._is_stream = is_stream
self._history = history

View File

@ -0,0 +1,4 @@
from .core import BaseEngine
class BaseNLUEngine(BaseEngine):
pass

View File

@ -0,0 +1,4 @@
from .core import BaseEngine
class BaseTTSEngine(BaseEngine):
pass

148
src/core/core.py Normal file
View File

@ -0,0 +1,148 @@
from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional
from logging import getLogger
logger = getLogger(__name__)
class MiddlewareInterface(ABC):
"""中间件接口
"""
@abstractmethod
def process(self, data: Any) -> Any:
"""处理数据并返回处理后的结果"""
pass
@abstractmethod
def get_next(self) -> Optional[List['MiddlewareInterface']]:
"""获取下一个中间件"""
pass
class EngineInterface(ABC):
"""引擎接口
"""
@abstractmethod
def execute(self, data: Any) -> Any:
"""执行引擎逻辑并返回结果"""
pass
@abstractmethod
def transform(self, data: Any) -> Any:
"""将输出数据转换为标准化的格式"""
pass
@abstractmethod
def to_execute_middleware(self) -> MiddlewareInterface:
"""将引擎执行转换为遵循MiddlewareInterface的中间件."""
pass
@abstractmethod
def to_transform_middleware(self) -> MiddlewareInterface:
"""将引擎数据标准化转换为遵循MiddlewareInterface的中间件."""
pass
@abstractmethod
def to_middleware(self) -> MiddlewareInterface:
"""将引擎转换为遵循MiddlewareInterface的中间件."""
pass
class BaseMiddleware(MiddlewareInterface):
"""中间件基类
注意 next_middleware 参数是可选的,如果未提供,可能有两种模式
1.如果process返回值为None(只有next_middleware为None返回值才为None)
则返回数据将在Pipeline里返回,否则继续执行Pipeline的下一个中间件
2.中间件的process返回值为process_logic处理后的值
"""
def __init__(self, process_logic: Callable[[Any], Any],
next_middleware: Optional[List[MiddlewareInterface]] = []):
self._process_logic = process_logic
self._next_middleware = next_middleware
self._data = None
def process(self, data: Any) -> Any:
"""处理数据并返回处理后的结果,支持自动调用下一个中间件
(如果next_middleware为None则返回None)"""
try:
data = self._process_logic(data)
self._data = data
if self._next_middleware:
for middleware in self._next_middleware:
self._data = middleware.process(self._data)
return None if self._next_middleware is None else data
except Exception as e:
logger.exception(f"Error occurred during processing: {e}")
raise e
def get_next(self) -> Optional[List[MiddlewareInterface]]:
return self._next_middleware
class BaseEngine(EngineInterface):
"""引擎基类
"""
def transform(self, data: Any) -> Any:
raise NotImplementedError("transform method not implemented")
def execute(self, data: Any) -> Any:
raise NotImplementedError("execute method not implemented")
def to_transform_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
return BaseMiddleware(self.transform, next_middleware)
def to_execute_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
return BaseMiddleware(self.execute, next_middleware)
def to_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
return BaseMiddleware(lambda data: self.transform(self.execute(data)), next_middleware)
class ExecutePipeline:
"""
执行流 用于按顺序执行中间件
示例:
pipeline = ExecutePipeline(middleware1, middleware2)
pipeline.add(middleware3)
result = pipeline.execute(data)
执行顺序 data -> middleware1 -> middleware2 -> middleware3 -> result
pipeline.reorder([0,2,1])
result = pipeline.execute(data)
执行顺序 data -> middleware1 -> middleware3 -> middleware2 -> result
pipeline.reorder([0,1])
result = pipeline.execute(data)
执行顺序 data -> middleware1 -> middleware2 -> result
"""
def __init__(self, *middlewares:MiddlewareInterface):
self.middlewares = [*middlewares]
def add(self, middleware: MiddlewareInterface):
"""添加中间件到执行流里"""
self.middlewares.append(middleware)
def reorder(self, new_order: List[int]):
"""重新排序中间件,按照new_order指定的索引顺序排列,注意可以实现删除功能"""
self.middlewares = [self.middlewares[i] for i in new_order]
def execute(self, data):
"""执行中间件
注意中间件的next不参与主执行流
注意只有中间件的next为None,返回值才不会被输入到执行流中,即停止执行后续内容
TODO: 1. 添加中间件执行失败时的处理机制
"""
try:
for middleware in self.middlewares:
if data is None:
break
data = middleware.process(data)
return data
except ValueError:
return None
except Exception as e:
raise e

19
src/core/echo.py Normal file
View File

@ -0,0 +1,19 @@
import sys
from typing import Iterator, List, Optional
from .core import BaseMiddleware, MiddlewareInterface
class EchoMiddleWare(BaseMiddleware):
def __init__(self, next_middleware: Optional[List[MiddlewareInterface]] = []):
super().__init__(self._process, next_middleware)
def _process(self, data):
if isinstance(data, Iterator):
ret = ''
for i in data:
ret += i
sys.stdout.write(i)
sys.stdout.flush()
sys.stdout.write('\n')
else:
ret = data
print(data)
return ret

0
src/plugins/__init__.py Normal file
View File

View File

@ -0,0 +1,55 @@
# https://dashscope.console.aliyun.com/model
import json
from typing import Any
from remote_engine import BaseRemoteEngine
from plugins_conf import PluginsConfig
from logging import getLogger
logger = getLogger(__name__)
PluginsConfig().add_core_path()
from core.base_chat_ai_engine import BaseChatAIEngine
class ChatAiTongYiQianWen(BaseChatAIEngine, BaseRemoteEngine):
API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
def __init__(self, api_key : str,
model_name : str = "qwen-1.8b-chat",
is_stream : bool = False,
*args, **kwargs) -> None:
self._modual_name = model_name
BaseChatAIEngine.__init__(self, is_stream=is_stream)
BaseRemoteEngine.__init__(self, api_key=api_key,
base_url=self.API_URL)
def _transform_raw(self, response):
return response["output"]["choices"][0]["message"]["content"]
def _transform_iterator(self, response):
for chunk in response:
input_string = chunk.decode('utf-8')
json_string = input_string[input_string.find("data:") + 5:].strip()
yield self._transform_raw(json.loads(json_string))
def transform(self, data):
return self._transform_iterator(data) if self._is_stream else\
self._transform_raw(data)
def execute(self, data: Any) -> Any:
if self._is_stream:
self.header['X-DashScope-SSE'] = 'enable'
response = self.post_prompt({
'model': self._modual_name,
"input": {
"messages": data
},
"parameters": {
"result_format": "message",
"incremental_output": self._is_stream
}
},
is_stream=self._is_stream)
if self._is_stream:
ret = response.iter_content(chunk_size=None)
else:
ret = response.json()
return ret

20
src/plugins/defalut.yml Normal file
View File

@ -0,0 +1,20 @@
global:
plugin_path: ./
core_path: ../
asr_engine: {}
chat_ai_engine:
plugin: chat_ai_tongyiqianwen
chat_ai_tongyiqianwen:
class_name: ChatAiTongYiQianWen
api_key: null # must be set when using this api
model_name: qwen-1.8b-chat
is_stream: true
nlu_engine: {}
tts_engine: {}
engine: {}

View File

@ -0,0 +1,50 @@
from pathlib import Path
import importlib
import sys
from typing import Any, Callable, Dict, Optional, List
import yaml
CURRENT_DIR_PATH = Path(__file__).resolve().parent
class PluginsConfig:
def __init__(self,
base_path : Path | str = CURRENT_DIR_PATH,
conf_name : str = ".plugins.yml"):
base_path = Path(base_path)
self.base_path = base_path.resolve()
self.conf_path = base_path / conf_name
if not self.conf_path.exists():
raise FileExistsError(f"{self.conf_path} not exists")
if not self.conf_path.is_file():
raise TypeError(f"{self.conf_path} is not a file")
self.config:Dict = yaml.safe_load(self.conf_path.read_text("utf-8"))
global_conf = self.config.get('global')
self.core_path = self.base_path / global_conf.get('core_path')
self.plugin_path = self.base_path / global_conf.get('plugin_path')
sys.path.append(str(self.base_path))
def add_core_path(self):
sys.path.append(str(self.core_path.resolve()))
def get_plugin_full_name(self, engine_name):
find_name = 'plugin'
plugin_name = self.config[engine_name][find_name]
if not plugin_name:
raise ValueError(f"{engine_name}.{find_name} is not found in {self.conf_path}")
return plugin_name
def get_engine_modual(self, engine_name):
plugin_name = self.get_plugin_full_name(engine_name)
sys.path.append(str(self.plugin_path.resolve()))
modual = importlib.import_module(plugin_name)
return modual
def get_engine_config(self, engine_name):
plugin_name = self.get_plugin_full_name(engine_name)
return self.config[engine_name][plugin_name]
def get_engine_class(self, engine_name):
engine_modual = self.get_engine_modual(engine_name)
engine_config = self.get_engine_config(engine_name)
engine = getattr(engine_modual, engine_config['class_name'])
return engine(**engine_config)

View File

@ -0,0 +1,23 @@
from logging import getLogger
logger = getLogger(__name__)
import requests
class BaseRemoteEngine():
"""基础远程服务类, 提供远程调用的基础框架。"""
def __init__(self, api_key: str, base_url: str):
self._api_key = api_key
self._base_url = base_url
self.url = self._base_url
self.header = {
'Content-Type': 'application/json',
'Authorization': 'Bearer ' + self._api_key
}
def post_prompt(self, messages : list[dict[str, str]], is_stream : bool = False) -> requests.Response:
try:
response = requests.post(self.url, headers=self.header, json=messages, stream=is_stream)
except Exception as e:
logger.exception("post_prompt")
raise e
finally:
return response

0
tests/__init__.py Normal file
View File

0
tests/core/__init__.py Normal file
View File

54
tests/core/test_core.py Normal file
View File

@ -0,0 +1,54 @@
import unittest
from src.core import core
class TestBaseMiddleware(unittest.TestCase):
def test_process_base_func1(self):
middleware = core.BaseMiddleware(lambda x: x + 1)
self.assertEqual(middleware.process(1), 2)
self.assertEqual(middleware.get_next(), [])
def test_process_base_func2(self):
mid = [core.BaseMiddleware(lambda x: x + 2)]
middleware = core.BaseMiddleware(lambda x: x + 1, next_middleware=mid)
self.assertEqual(middleware.process(1), 2)
self.assertEqual(middleware._data, 4)
self.assertIs(middleware.get_next(), mid)
class TestBaseEngine(unittest.TestCase):
pass
class ExecutePipeline(unittest.TestCase):
def test_base(self):
middleware = core.BaseMiddleware(lambda x: x + 1)
pipeline = core.ExecutePipeline(middleware)
self.assertEqual(pipeline.execute(1), 2)
def test_add(self):
middleware = core.BaseMiddleware(lambda x: x + 1)
pipeline = core.ExecutePipeline(middleware)
pipeline.add(core.BaseMiddleware(lambda x: x + 2))
self.assertEqual(pipeline.execute(1), 4)
def test_reorder(self):
middleware = core.BaseMiddleware(lambda x: x + 1)
pipeline = core.ExecutePipeline(middleware)
pipeline.reorder([])
pipeline.execute(1)
def test_reorder1(self):
middleware_p1 = core.BaseMiddleware(lambda x: x + 1)
middleware_p2 = core.BaseMiddleware(lambda x: x + 2)
middleware_p3 = core.BaseMiddleware(lambda x: x + 3)
pipeline = core.ExecutePipeline(middleware_p1, middleware_p2, middleware_p3)
self.assertEqual(pipeline.execute(1), 7)
pipeline.reorder([1, 0, 2])
self.assertEqual(pipeline.execute(1), 7)
pipeline.reorder([0, 2])
self.assertEqual(pipeline.execute(1), 6)
pipeline.reorder([1])
self.assertEqual(pipeline.execute(1), 4)
def test_execute_none(self):
middleware = core.BaseMiddleware(lambda x: x + 1)
pipeline = core.ExecutePipeline(middleware)
self.assertIsNone(pipeline.execute(None))

13
tests/core/test_echo.py Normal file
View File

@ -0,0 +1,13 @@
import unittest
from src.core import echo
from io import StringIO
import sys
class TestEcho(unittest.TestCase):
def test_base(self):
echoMid = echo.EchoMiddleWare()
captured_out = StringIO()
sys.stdout = captured_out
echoMid.process('hello')
sys.stdout = sys.__stdout__
self.assertEqual(captured_out.getvalue(), 'hello\n')

View File

5
utils/__init__.py Normal file
View File

@ -0,0 +1,5 @@
from .logger import setup_logger
def init():
from .config import file_encoding
setup_logger(file_encoding=file_encoding)

5
utils/config.py Normal file
View File

@ -0,0 +1,5 @@
from .constants import *
CONFIG_CONFIG_FILE_PATH = CONFIG_PATH / "config.yml"
file_encoding = FILE_ENCODING

17
utils/constants.py Normal file
View File

@ -0,0 +1,17 @@
# -*- coding: utf-8 -*-
from pathlib import Path
# 这段代码使用 pathlib 模块进行重构参考网站为
# https://zhuanlan.zhihu.com/p/139783331
# https://docs.python.org/zh-cn/3/library/pathlib.html
# root directory
ROOT_PATH = Path(__file__).resolve().parent.parent
CONFIG_PATH = ROOT_PATH / "config"
CONFIG_DEFAULT_PATH = CONFIG_PATH / "default"
# 该路径需要自动创建
TEMP_PATH = ROOT_PATH / ".temp"
FILE_ENCODING = "utf-8"

52
utils/logger.py Normal file
View File

@ -0,0 +1,52 @@
import logging
import logging.config
from pathlib import Path
from typing import *
from .constants import CONFIG_DEFAULT_PATH, TEMP_PATH, FILE_ENCODING
CONFIG_LOGGER_FILE_PATH = CONFIG_DEFAULT_PATH / "logger.yml"
# 类型别名简化复杂注解
ConfigFunc = Callable[[str], Dict[str, Any]]
# 默认日志配置
def _default_config(_):
logging.basicConfig(level=logging.NOTSET)
# 更新日志配置文件的文件路径
def _update_handlers_filenames(config: dict[str, Any], base_filepath: Path) -> None:
for handler_config in config.get('handlers', {}).values():
if 'filename' in handler_config:
handler_config['filename'] = str(base_filepath.joinpath(handler_config['filename']))
# 使用文件配置日志
def _setup_logger(config_path : Path,
config_func : ConfigFunc = _default_config,
log_filepath: Optional[Path] = None,
file_encoding: str = 'utf-8') -> dict[str, Any]:
try:
if config_path.exists():
config = config_func(
config_path.read_text(encoding=file_encoding))
value = config.get('handlers', None)
if log_filepath is not None and value is not None:
_update_handlers_filenames(config, log_filepath)
logging.config.dictConfig(config)
return config
else:
config_func()
return None
except Exception as e:
raise RuntimeError(f"Failed to setup logger: {e}")
# 默认使用yaml配置
def yml_config(file_content : str):
import yaml
return yaml.safe_load(file_content)
# 对外接口
def setup_logger(config_path : Path = CONFIG_LOGGER_FILE_PATH,
config_func : Callable[[str], str| dict[str, Any]] = yml_config,
log_filepath : Path = TEMP_PATH,
file_encoding : str = FILE_ENCODING) -> bool:
return _setup_logger(config_path, config_func, log_filepath, file_encoding) is not None

6
utils/singleton.py Normal file
View File

@ -0,0 +1,6 @@
class Singleton(type):
__instances = {}
def __call__(cls, *args, **kwargs):
if cls not in cls._instances:
cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
return cls.__instances[cls]

0
utils/utils.py Normal file
View File