init: basic chat framework
commit 98eda16d0f
main.py (new file, 9 lines)

if __name__ == '__main__':
    from src.core.core import ExecutePipeline
    from src.core.echo import EchoMiddleWare
    from src.plugins.plugins_conf import PluginsConfig
    engine = PluginsConfig().get_engine_class("chat_ai_engine")
    pipe = ExecutePipeline(engine.to_middleware(), EchoMiddleWare(), EchoMiddleWare())
    # data = '端午是几号'
    data = input('请输入: ')
    pipe.execute([{'role': 'user', 'content': f'{data}'}])
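The entry point resolves whichever engine the plugin config points at and feeds its output through two EchoMiddleWare instances. Below is a minimal sketch of the same wiring with a local stand-in engine, so the pipeline can be exercised without a DashScope API key; FakeChatEngine is hypothetical and not part of this commit, and it assumes the src/ package layout above.

# Minimal offline sketch of the wiring in main.py. FakeChatEngine is a
# hypothetical stand-in for the remote engine so no API key is needed.
from src.core.core import BaseEngine, ExecutePipeline
from src.core.echo import EchoMiddleWare


class FakeChatEngine(BaseEngine):
    def execute(self, data):
        # data is the [{'role': ..., 'content': ...}] list that main.py builds
        return {"reply": f"echo: {data[-1]['content']}"}

    def transform(self, data):
        # normalize the raw result to a plain string for downstream middleware
        return data["reply"]


if __name__ == '__main__':
    pipe = ExecutePipeline(FakeChatEngine().to_middleware(), EchoMiddleWare())
    pipe.execute([{'role': 'user', 'content': 'hello'}])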
requirements.txt (new file, 17 lines)

# if you want to use audio formats other than .wav,
# you need to install ffmpeg or libav on your system
# pydub
# simpleaudio

# yaml
pyyaml

# remote server connect
requests
# openai # optional

# server optional
# tornado

# tqdm
# rich # optional
src/__init__.py (new file, empty)
src/core/__init__.py (new file, empty)
src/core/base_asr_engine.py (new file, 4 lines)

from .core import BaseEngine

class BaseASREngine(BaseEngine):
    pass
src/core/base_chat_ai_engine.py (new file, 10 lines)

from typing import Dict
from .core import BaseEngine
from logging import getLogger
logger = getLogger(__name__)

class BaseChatAIEngine(BaseEngine):
    """Base chat-AI engine class; provides the basic framework for chat-history management."""
    def __init__(self, history: Dict = None, is_stream: bool = False) -> None:
        self._is_stream = is_stream
        self._history = history
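BaseChatAIEngine only records the stream flag and an optional history dict; concrete engines are expected to supply execute and transform from BaseEngine. A hedged sketch of a streaming subclass follows: StreamingStubEngine is hypothetical, and its transform yields chunks, which is the shape EchoMiddleWare streams when is_stream is enabled.

# Hypothetical streaming engine on top of BaseChatAIEngine: execute returns a
# raw "response" and transform turns it into an iterator of text chunks,
# mirroring how the DashScope plugin behaves in streaming mode.
from src.core.base_chat_ai_engine import BaseChatAIEngine


class StreamingStubEngine(BaseChatAIEngine):
    def __init__(self):
        super().__init__(is_stream=True)

    def execute(self, data):
        return ["thinking ", "about ", data[-1]["content"]]

    def transform(self, data):
        # yield chunks one by one so downstream middleware can stream them
        yield from data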
src/core/base_nlu_engine.py (new file, 4 lines)

from .core import BaseEngine

class BaseNLUEngine(BaseEngine):
    pass
src/core/base_tts_engine.py (new file, 4 lines)

from .core import BaseEngine

class BaseTTSEngine(BaseEngine):
    pass
src/core/core.py (new file, 148 lines)

from abc import ABC, abstractmethod
from typing import Any, Callable, List, Optional

from logging import getLogger
logger = getLogger(__name__)

class MiddlewareInterface(ABC):
    """Middleware interface."""

    @abstractmethod
    def process(self, data: Any) -> Any:
        """Process the data and return the processed result."""
        pass

    @abstractmethod
    def get_next(self) -> Optional[List['MiddlewareInterface']]:
        """Return the next middleware(s)."""
        pass

class EngineInterface(ABC):
    """Engine interface."""
    @abstractmethod
    def execute(self, data: Any) -> Any:
        """Run the engine logic and return the result."""
        pass

    @abstractmethod
    def transform(self, data: Any) -> Any:
        """Convert the engine output into a normalized format."""
        pass

    @abstractmethod
    def to_execute_middleware(self) -> MiddlewareInterface:
        """Wrap the engine's execute step as a MiddlewareInterface-compliant middleware."""
        pass

    @abstractmethod
    def to_transform_middleware(self) -> MiddlewareInterface:
        """Wrap the engine's output normalization as a MiddlewareInterface-compliant middleware."""
        pass

    @abstractmethod
    def to_middleware(self) -> MiddlewareInterface:
        """Wrap the whole engine as a MiddlewareInterface-compliant middleware."""
        pass

class BaseMiddleware(MiddlewareInterface):
    """Middleware base class.

    Note that next_middleware is optional; depending on its value there are two modes:

    1. process returns None only when next_middleware is None. In that case the
       data is returned from within the Pipeline; otherwise the Pipeline continues
       with its next middleware.

    2. Otherwise process returns the value produced by process_logic.
    """
    def __init__(self, process_logic: Callable[[Any], Any],
                 next_middleware: Optional[List[MiddlewareInterface]] = []):
        self._process_logic = process_logic
        self._next_middleware = next_middleware
        self._data = None

    def process(self, data: Any) -> Any:
        """Process the data and return the result, automatically invoking any chained
        middleware (returns None if next_middleware is None)."""
        try:
            data = self._process_logic(data)
            self._data = data
            if self._next_middleware:
                for middleware in self._next_middleware:
                    self._data = middleware.process(self._data)
            return None if self._next_middleware is None else data
        except Exception as e:
            logger.exception(f"Error occurred during processing: {e}")
            raise e

    def get_next(self) -> Optional[List[MiddlewareInterface]]:
        return self._next_middleware

class BaseEngine(EngineInterface):
    """Engine base class."""
    def transform(self, data: Any) -> Any:
        raise NotImplementedError("transform method not implemented")

    def execute(self, data: Any) -> Any:
        raise NotImplementedError("execute method not implemented")

    def to_transform_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
        return BaseMiddleware(self.transform, next_middleware)

    def to_execute_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
        return BaseMiddleware(self.execute, next_middleware)

    def to_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
        return BaseMiddleware(lambda data: self.transform(self.execute(data)), next_middleware)

class ExecutePipeline:
    """
    Execution pipeline that runs middlewares in order.

    Example:
        pipeline = ExecutePipeline(middleware1, middleware2)
        pipeline.add(middleware3)
        result = pipeline.execute(data)

        Execution order: data -> middleware1 -> middleware2 -> middleware3 -> result

        pipeline.reorder([0, 2, 1])
        result = pipeline.execute(data)

        Execution order: data -> middleware1 -> middleware3 -> middleware2 -> result

        pipeline.reorder([0, 1])
        result = pipeline.execute(data)

        Execution order: data -> middleware1 -> middleware2 -> result
    """
    def __init__(self, *middlewares: MiddlewareInterface):
        self.middlewares = [*middlewares]

    def add(self, middleware: MiddlewareInterface):
        """Append a middleware to the pipeline."""
        self.middlewares.append(middleware)

    def reorder(self, new_order: List[int]):
        """Reorder the middlewares by the indices in new_order; note that leaving an
        index out also works as a way to remove that middleware."""
        self.middlewares = [self.middlewares[i] for i in new_order]

    def execute(self, data):
        """Run the middlewares.

        Note that a middleware's own next chain does not take part in the main pipeline.

        Only when a middleware's next_middleware is None is its return value kept out of
        the pipeline, i.e. the remaining middlewares are not executed.

        TODO: 1. add a handling mechanism for middleware failures
        """
        try:
            for middleware in self.middlewares:
                if data is None:
                    break
                data = middleware.process(data)
            return data
        except ValueError:
            return None
        except Exception as e:
            raise e
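A short sketch of how the two BaseMiddleware modes described above interact with ExecutePipeline; the middleware names and lambdas below are illustrative only, built directly on the classes in src/core/core.py.

# Illustrative only: demonstrates the two BaseMiddleware modes.
from src.core.core import BaseMiddleware, ExecutePipeline

double = BaseMiddleware(lambda x: x * 2)                          # next defaults to [] -> value flows on
terminal = BaseMiddleware(lambda x: x + 1, next_middleware=None)  # next is None -> pipeline stops here

print(ExecutePipeline(double, double).execute(3))            # 12
print(ExecutePipeline(double, terminal, double).execute(3))  # None: terminal returned None, loop breaks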
src/core/echo.py (new file, 19 lines)

import sys
from typing import Iterator, List, Optional
from .core import BaseMiddleware, MiddlewareInterface

class EchoMiddleWare(BaseMiddleware):
    def __init__(self, next_middleware: Optional[List[MiddlewareInterface]] = []):
        super().__init__(self._process, next_middleware)
    def _process(self, data):
        if isinstance(data, Iterator):
            ret = ''
            for i in data:
                ret += i
                sys.stdout.write(i)
                sys.stdout.flush()
            sys.stdout.write('\n')
        else:
            ret = data
            print(data)
        return ret
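EchoMiddleWare prints whatever reaches it and passes the value along: iterators are drained chunk by chunk (the streaming case), anything else is printed whole. A quick illustration with made-up values:

# Purely illustrative: EchoMiddleWare streams iterators chunk by chunk and
# prints plain values as-is, returning what it consumed in both cases.
from src.core.echo import EchoMiddleWare

echo = EchoMiddleWare()
print(echo.process('hello'))              # prints 'hello', returns 'hello'
print(echo.process(iter(['he', 'llo'])))  # writes 'he' then 'llo', returns 'hello'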
src/plugins/__init__.py (new file, empty)
src/plugins/chat_ai_tongyiqianwen.py (new file, 55 lines)

# https://dashscope.console.aliyun.com/model
import json
from typing import Any
from remote_engine import BaseRemoteEngine
from plugins_conf import PluginsConfig

from logging import getLogger
logger = getLogger(__name__)

PluginsConfig().add_core_path()
from core.base_chat_ai_engine import BaseChatAIEngine

class ChatAiTongYiQianWen(BaseChatAIEngine, BaseRemoteEngine):
    API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"
    def __init__(self, api_key: str,
                 model_name: str = "qwen-1.8b-chat",
                 is_stream: bool = False,
                 *args, **kwargs) -> None:
        self._modual_name = model_name
        BaseChatAIEngine.__init__(self, is_stream=is_stream)
        BaseRemoteEngine.__init__(self, api_key=api_key,
                                  base_url=self.API_URL)

    def _transform_raw(self, response):
        return response["output"]["choices"][0]["message"]["content"]

    def _transform_iterator(self, response):
        for chunk in response:
            input_string = chunk.decode('utf-8')
            json_string = input_string[input_string.find("data:") + 5:].strip()
            yield self._transform_raw(json.loads(json_string))

    def transform(self, data):
        return self._transform_iterator(data) if self._is_stream else \
            self._transform_raw(data)

    def execute(self, data: Any) -> Any:
        if self._is_stream:
            self.header['X-DashScope-SSE'] = 'enable'
        response = self.post_prompt({
            'model': self._modual_name,
            "input": {
                "messages": data
            },
            "parameters": {
                "result_format": "message",
                "incremental_output": self._is_stream
            }
        },
            is_stream=self._is_stream)
        if self._is_stream:
            ret = response.iter_content(chunk_size=None)
        else:
            ret = response.json()
        return ret
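The transform helpers only rely on the response shape hard-coded above. Below is an offline sketch of the two parsing paths; fake_response and fake_sse_chunk are made-up examples in the shape the plugin assumes, not captured DashScope traffic.

# Offline sketch of the response shapes the plugin expects.
import json

fake_response = {"output": {"choices": [{"message": {"role": "assistant", "content": "hello"}}]}}

# non-streaming path: response.json() is parsed and the content field is picked out
content = fake_response["output"]["choices"][0]["message"]["content"]
assert content == "hello"

# streaming path: each SSE chunk arrives as bytes prefixed with "data:"; the plugin
# strips that prefix, json.loads the rest, and extracts content the same way
fake_sse_chunk = b"data:" + json.dumps(fake_response).encode("utf-8")
text = fake_sse_chunk.decode("utf-8")
payload = json.loads(text[text.find("data:") + 5:].strip())
assert payload["output"]["choices"][0]["message"]["content"] == "hello"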
src/plugins/defalut.yml (new file, 20 lines)

global:
  plugin_path: ./
  core_path: ../

asr_engine: {}

chat_ai_engine:
  plugin: chat_ai_tongyiqianwen

  chat_ai_tongyiqianwen:
    class_name: ChatAiTongYiQianWen
    api_key: null # must be set when using this api
    model_name: qwen-1.8b-chat
    is_stream: true

nlu_engine: {}

tts_engine: {}

engine: {}
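PluginsConfig resolves an engine from this file with a two-step lookup: engine section, then the plugin name it points at, then that plugin's settings (the constructor actually looks for the file under the name .plugins.yml by default, so this default is presumably copied or renamed at runtime). A rough sketch of that lookup, mirroring only the chat_ai_engine section above:

# Rough sketch of the lookup PluginsConfig performs on this config.
import yaml

conf = yaml.safe_load("""
chat_ai_engine:
  plugin: chat_ai_tongyiqianwen

  chat_ai_tongyiqianwen:
    class_name: ChatAiTongYiQianWen
    api_key: null
    model_name: qwen-1.8b-chat
    is_stream: true
""")

plugin_name = conf['chat_ai_engine']['plugin']      # module to import, e.g. chat_ai_tongyiqianwen
engine_conf = conf['chat_ai_engine'][plugin_name]   # kwargs passed to the class named in class_name
print(plugin_name, engine_conf['class_name'])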
src/plugins/plugins_conf.py (new file, 50 lines)

from pathlib import Path
import importlib
import sys
from typing import Any, Callable, Dict, Optional, List

import yaml
CURRENT_DIR_PATH = Path(__file__).resolve().parent

class PluginsConfig:
    def __init__(self,
                 base_path: Path | str = CURRENT_DIR_PATH,
                 conf_name: str = ".plugins.yml"):
        base_path = Path(base_path)
        self.base_path = base_path.resolve()
        self.conf_path = base_path / conf_name
        if not self.conf_path.exists():
            raise FileNotFoundError(f"{self.conf_path} does not exist")
        if not self.conf_path.is_file():
            raise TypeError(f"{self.conf_path} is not a file")
        self.config: Dict = yaml.safe_load(self.conf_path.read_text("utf-8"))
        global_conf = self.config.get('global')
        self.core_path = self.base_path / global_conf.get('core_path')
        self.plugin_path = self.base_path / global_conf.get('plugin_path')
        sys.path.append(str(self.base_path))

    def add_core_path(self):
        sys.path.append(str(self.core_path.resolve()))

    def get_plugin_full_name(self, engine_name):
        find_name = 'plugin'
        plugin_name = self.config[engine_name][find_name]
        if not plugin_name:
            raise ValueError(f"{engine_name}.{find_name} is not found in {self.conf_path}")
        return plugin_name

    def get_engine_modual(self, engine_name):
        plugin_name = self.get_plugin_full_name(engine_name)
        sys.path.append(str(self.plugin_path.resolve()))
        modual = importlib.import_module(plugin_name)
        return modual

    def get_engine_config(self, engine_name):
        plugin_name = self.get_plugin_full_name(engine_name)
        return self.config[engine_name][plugin_name]

    def get_engine_class(self, engine_name):
        engine_modual = self.get_engine_modual(engine_name)
        engine_config = self.get_engine_config(engine_name)
        engine = getattr(engine_modual, engine_config['class_name'])
        return engine(**engine_config)
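get_engine_class turns a config section into a live engine object: it imports the module named by `plugin`, looks up `class_name` inside it, and calls that class with the section's keys as keyword arguments. The sketch below shows what registering an extra plugin would look like under that contract; the module name chat_ai_local, the class ChatAiLocal, and the greeting parameter are all hypothetical.

# Hypothetical extra plugin, e.g. src/plugins/chat_ai_local.py, following the
# contract used by PluginsConfig.get_engine_class.
from plugins_conf import PluginsConfig

PluginsConfig().add_core_path()
from core.base_chat_ai_engine import BaseChatAIEngine


class ChatAiLocal(BaseChatAIEngine):
    def __init__(self, greeting: str = "hi", is_stream: bool = False, *args, **kwargs):
        super().__init__(is_stream=is_stream)
        self._greeting = greeting

    def execute(self, data):
        return {"content": f"{self._greeting}, you said: {data[-1]['content']}"}

    def transform(self, data):
        return data["content"]

# and the matching section in the config file would look roughly like:
#   chat_ai_engine:
#     plugin: chat_ai_local
#
#     chat_ai_local:
#       class_name: ChatAiLocal
#       greeting: hello
#       is_stream: false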
src/plugins/remote_engine.py (new file, 23 lines)

from logging import getLogger
logger = getLogger(__name__)
import requests

class BaseRemoteEngine():
    """Base remote-service class; provides the basic framework for calling remote APIs."""
    def __init__(self, api_key: str, base_url: str):
        self._api_key = api_key
        self._base_url = base_url
        self.url = self._base_url
        self.header = {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer ' + self._api_key
        }

    def post_prompt(self, messages: list[dict[str, str]], is_stream: bool = False) -> requests.Response:
        try:
            response = requests.post(self.url, headers=self.header, json=messages, stream=is_stream)
        except Exception as e:
            logger.exception("post_prompt")
            raise e
        return response
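BaseRemoteEngine is a thin wrapper over requests.post with a bearer-token header, so any JSON API that follows that pattern can reuse it. A hedged usage sketch; the URL and API key below are placeholders, not a real service.

# Usage sketch only: placeholder endpoint and key, not a real service.
from remote_engine import BaseRemoteEngine

engine = BaseRemoteEngine(api_key="sk-placeholder",
                          base_url="https://example.com/v1/chat")
resp = engine.post_prompt({"messages": [{"role": "user", "content": "ping"}]})
print(resp.status_code, resp.headers.get("content-type"))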
tests/__init__.py (new file, empty)
tests/core/__init__.py (new file, empty)
tests/core/test_core.py (new file, 54 lines)

import unittest
from src.core import core

class TestBaseMiddleware(unittest.TestCase):
    def test_process_base_func1(self):
        middleware = core.BaseMiddleware(lambda x: x + 1)
        self.assertEqual(middleware.process(1), 2)
        self.assertEqual(middleware.get_next(), [])

    def test_process_base_func2(self):
        mid = [core.BaseMiddleware(lambda x: x + 2)]
        middleware = core.BaseMiddleware(lambda x: x + 1, next_middleware=mid)
        self.assertEqual(middleware.process(1), 2)
        self.assertEqual(middleware._data, 4)
        self.assertIs(middleware.get_next(), mid)

class TestBaseEngine(unittest.TestCase):
    pass

class ExecutePipeline(unittest.TestCase):
    def test_base(self):
        middleware = core.BaseMiddleware(lambda x: x + 1)
        pipeline = core.ExecutePipeline(middleware)
        self.assertEqual(pipeline.execute(1), 2)

    def test_add(self):
        middleware = core.BaseMiddleware(lambda x: x + 1)
        pipeline = core.ExecutePipeline(middleware)
        pipeline.add(core.BaseMiddleware(lambda x: x + 2))
        self.assertEqual(pipeline.execute(1), 4)

    def test_reorder(self):
        middleware = core.BaseMiddleware(lambda x: x + 1)
        pipeline = core.ExecutePipeline(middleware)
        pipeline.reorder([])
        pipeline.execute(1)

    def test_reorder1(self):
        middleware_p1 = core.BaseMiddleware(lambda x: x + 1)
        middleware_p2 = core.BaseMiddleware(lambda x: x + 2)
        middleware_p3 = core.BaseMiddleware(lambda x: x + 3)
        pipeline = core.ExecutePipeline(middleware_p1, middleware_p2, middleware_p3)
        self.assertEqual(pipeline.execute(1), 7)
        pipeline.reorder([1, 0, 2])
        self.assertEqual(pipeline.execute(1), 7)
        pipeline.reorder([0, 2])
        self.assertEqual(pipeline.execute(1), 6)
        pipeline.reorder([1])
        self.assertEqual(pipeline.execute(1), 4)

    def test_execute_none(self):
        middleware = core.BaseMiddleware(lambda x: x + 1)
        pipeline = core.ExecutePipeline(middleware)
        self.assertIsNone(pipeline.execute(None))
tests/core/test_echo.py (new file, 13 lines)

import unittest
from src.core import echo
from io import StringIO
import sys

class TestEcho(unittest.TestCase):
    def test_base(self):
        echoMid = echo.EchoMiddleWare()
        captured_out = StringIO()
        sys.stdout = captured_out
        echoMid.process('hello')
        sys.stdout = sys.__stdout__
        self.assertEqual(captured_out.getvalue(), 'hello\n')
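The tests import src.core as a package, so they are meant to run from the repository root. A small runner sketch using only the standard library, assuming nothing project-specific beyond that layout:

# Assuming the repository root is the working directory; this mirrors
# `python -m unittest discover` over the tests/ package.
import unittest

suite = unittest.defaultTestLoader.discover("tests")
unittest.TextTestRunner(verbosity=2).run(suite)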
tests/plugins/__init__.py (new file, empty)
utils/__init__.py (new file, 5 lines)

from .logger import setup_logger

def init():
    from .config import file_encoding
    setup_logger(file_encoding=file_encoding)
utils/config.py (new file, 5 lines)

from .constants import *

CONFIG_CONFIG_FILE_PATH = CONFIG_PATH / "config.yml"

file_encoding = FILE_ENCODING
utils/constants.py (new file, 17 lines)

# -*- coding: utf-8 -*-
from pathlib import Path

# path handling is built around the pathlib module; references:
# https://zhuanlan.zhihu.com/p/139783331
# https://docs.python.org/zh-cn/3/library/pathlib.html

# root directory
ROOT_PATH = Path(__file__).resolve().parent.parent

CONFIG_PATH = ROOT_PATH / "config"
CONFIG_DEFAULT_PATH = CONFIG_PATH / "default"

# this path needs to be created automatically
TEMP_PATH = ROOT_PATH / ".temp"

FILE_ENCODING = "utf-8"
utils/logger.py (new file, 52 lines)

import logging
import logging.config
from pathlib import Path
from typing import *
from .constants import CONFIG_DEFAULT_PATH, TEMP_PATH, FILE_ENCODING
CONFIG_LOGGER_FILE_PATH = CONFIG_DEFAULT_PATH / "logger.yml"

# type alias to simplify a complex annotation
ConfigFunc = Callable[[str], Dict[str, Any]]

# default logging configuration
def _default_config(_=None):
    logging.basicConfig(level=logging.NOTSET)

# rewrite the file paths used by the handlers in the logging config
def _update_handlers_filenames(config: dict[str, Any], base_filepath: Path) -> None:
    for handler_config in config.get('handlers', {}).values():
        if 'filename' in handler_config:
            handler_config['filename'] = str(base_filepath.joinpath(handler_config['filename']))

# configure logging from a config file
def _setup_logger(config_path: Path,
                  config_func: ConfigFunc = _default_config,
                  log_filepath: Optional[Path] = None,
                  file_encoding: str = 'utf-8') -> Optional[dict[str, Any]]:
    try:
        if config_path.exists():
            config = config_func(
                config_path.read_text(encoding=file_encoding))
            value = config.get('handlers', None)
            if log_filepath is not None and value is not None:
                _update_handlers_filenames(config, log_filepath)
            logging.config.dictConfig(config)
            return config
        else:
            # no config file: fall back to the default basic configuration
            _default_config()
            return None
    except Exception as e:
        raise RuntimeError(f"Failed to setup logger: {e}")

# parse the config file as YAML by default
def yml_config(file_content: str):
    import yaml
    return yaml.safe_load(file_content)

# public interface
def setup_logger(config_path: Path = CONFIG_LOGGER_FILE_PATH,
                 config_func: Callable[[str], str | dict[str, Any]] = yml_config,
                 log_filepath: Path = TEMP_PATH,
                 file_encoding: str = FILE_ENCODING) -> bool:
    return _setup_logger(config_path, config_func, log_filepath, file_encoding) is not None
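setup_logger expects config/default/logger.yml to parse into a logging.config.dictConfig mapping, with handler filenames given relative to the log directory (TEMP_PATH by default); that YAML file itself is not part of this commit. A hypothetical example of the kind of mapping it would produce, expressed directly as a Python dict since that is what dictConfig receives:

# Hypothetical dictConfig-style mapping equivalent to a logger.yml file.
# Handler 'filename' values are relative; _update_handlers_filenames prefixes
# them with the log directory before dictConfig is applied.
logger_config = {
    "version": 1,
    "disable_existing_loggers": False,
    "formatters": {
        "plain": {"format": "%(asctime)s %(levelname)s %(name)s: %(message)s"},
    },
    "handlers": {
        "console": {"class": "logging.StreamHandler", "formatter": "plain"},
        "file": {"class": "logging.FileHandler", "formatter": "plain",
                 "filename": "app.log", "encoding": "utf-8"},
    },
    "root": {"level": "INFO", "handlers": ["console", "file"]},
}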
utils/singleton.py (new file, 6 lines)

class Singleton(type):
    _instances = {}
    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
        return cls._instances[cls]
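Singleton is a metaclass: any class declared with metaclass=Singleton routes construction through __call__ above, so repeated instantiation returns the same object. A small usage sketch; AppState is just an illustrative name, not part of the commit.

# Usage sketch for the Singleton metaclass.
from utils.singleton import Singleton


class AppState(metaclass=Singleton):
    def __init__(self):
        self.counter = 0


a = AppState()
b = AppState()
a.counter += 1
print(a is b, b.counter)  # True 1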
utils/utils.py (new file, empty)