diff --git a/main.py b/main.py
index 6d8404b..ae0cbbb 100644
--- a/main.py
+++ b/main.py
@@ -1,28 +1,50 @@
+# from src import Plugins
 from src import ExecutePipeline
-from src import Plugins
-from src import EchoEngine, TeeEngine
+from src import EchoProcessor
+from viztracer import VizTracer

-def main():
-    from plugins.example import MyEngine
+async def main():
     from pathlib import Path
-    plug_conf = Plugins(Path(__file__).resolve().parent)
-    pipe = ExecutePipeline(
-        # plug_conf.load_engine("asr_engine"),
-        # MyEngine(),
-        plug_conf.load_engine("chat_ai_engine"),
-        TeeEngine(),
-        plug_conf.load_engine("tts_engine"),
-        plug_conf.load_engine("sounds_play_engine"),
-        EchoEngine()
-    )
+    # plug_conf = Plugins(Path(__file__).resolve().parent)
+    from src.offical.requests.tongyiqianwen import ChatAiTongYiQianWen
+    pipe = ExecutePipeline([
+        ChatAiTongYiQianWen(api_key='sk-ab4a3e0e29d54ebaad560c1472933d41', use_stream_api=True),
+        EchoProcessor()
+    ]).start()
+    # # plug_conf.load_engine("asr_engine"),
+    # # MyEngine(),
+    # plug_conf.load_engine("chat_ai_engine"),
+    # TeeProcessor(),
+    # plug_conf.load_engine("tts_engine"),
+    # plug_conf.load_engine("sounds_play_engine"),
+    # EchoEngine()

     # exe_input = './tests/offical/sounds/asr_example.wav'
     # exe_input = input('input: ')
-    exe_input = '你好'
-    res = pipe.execute_engines(exe_input)
-    print(res)
+    exe_input = '给我一个100字的文章'
+    await pipe.write(exe_input)
+    await asyncio.sleep(0)
+    # loop.run_in_executor(None, inputdata)
+    loop = asyncio.get_running_loop()
+    def inputdata(pipe, loop):
+        while True:
+            try:
+                data = input('input :')
+                asyncio.run_coroutine_threadsafe(pipe.write(data), loop)
+            except KeyboardInterrupt:
+                pipe.cancel()
+                break
+    with VizTracer(log_async=True):
+        thread = threading.Thread(target=inputdata, args=(pipe, loop,))
+        thread.start()
+        await pipe.process()
+        thread.join()
+    # asyncio.gather([i for i in pipe._tasks])
+    # await asyncio.to_thread(inputdata, pipe)

 if __name__ == '__main__':
     from dotenv import load_dotenv
     load_dotenv()
-    main()
+    import asyncio
+    import threading
+    asyncio.run(main())
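# A minimal sketch (not part of the patch) of the input-feeding pattern used in the new
# main(): a blocking input() loop runs in a worker thread and hands each line to the
# running event loop via asyncio.run_coroutine_threadsafe, while the pipeline runs inside
# the loop. It assumes only the ExecutePipeline / EchoProcessor API introduced by this
# patch; names like `run` and `feed` are placeholders.
import asyncio
import threading

from src import EchoProcessor, ExecutePipeline


async def run() -> None:
    pipe = ExecutePipeline([EchoProcessor()]).start()
    loop = asyncio.get_running_loop()

    def feed() -> None:
        # Blocking reads stay off the event loop; each line is marshalled back onto it.
        while True:
            try:
                line = input('input: ')
            except EOFError:
                pipe.cancel()
                break
            asyncio.run_coroutine_threadsafe(pipe.write(line), loop)

    threading.Thread(target=feed, daemon=True).start()
    await pipe.process()


if __name__ == '__main__':
    asyncio.run(run())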
diff --git a/src/__init__.py b/src/__init__.py
index 9800aac..be81001 100644
--- a/src/__init__.py
+++ b/src/__init__.py
@@ -1,23 +1,21 @@
-from .core.interface import MiddlewareInterface as MiddlewareInterface
-from .core.interface import EngineInterface as EngineInterface
+# from .core.interface import MiddlewareInterface as MiddlewareInterface
+# from .core.interface import EngineInterface as EngineInterface

-from .core.core import Middleware as Middleware
-from .core.core import Engine as Engine
+from .core.core import ExecuteDataStream as ExecuteDataStream
+from .core.core import ExecuteProcessor as ExecuteProcessor
 from .core.core import ExecutePipeline as ExecutePipeline

-from .core.echo import EchoMiddleware as EchoMiddleware
-from .core.echo import EchoEngine as EchoEngine
-from .core.tee import TeeEngine as TeeEngine
+from .core.tools import EchoProcessor as EchoProcessor
+from .core.tools import TeeProcessor as TeeProcessor

-from .engine.stream_engine import StreamEngine as StreamEngine
-from .engine.asr_engine import ASREngine as ASREngine
-from .engine.tts_engine import TTSEngine as TTSEngine
-from .engine.nlu_engine import NLUEngine as NLUEngine
-from .engine.chat_ai_engine import ChatAIEngine as ChatAIEngine
+# from .engine.stream_engine import StreamEngine as StreamEngine
+# from .engine.asr_engine import ASREngine as ASREngine
+# from .engine.tts_engine import TTSEngine as TTSEngine
+# from .engine.nlu_engine import NLUEngine as NLUEngine
+# from .engine.chat_ai_engine import ChatAIEngine as ChatAIEngine

-
-from .plugins.dynamic_package_import import dynamic_package_import as dynamic_package_import
-from .plugins.plugins_conf import PluginsConfig as PluginsConfig
-from .plugins.plugins import Plugins as Plugins
+# from .plugins.dynamic_package_import import dynamic_package_import as dynamic_package_import
+# from .plugins.plugins_conf import PluginsConfig as PluginsConfig
+# from .plugins.plugins import Plugins as Plugins

 # from .utils.logger import setup_logger as setup_logger
 import logging
diff --git a/src/core/__init__.py b/src/core/__init__.py
deleted file mode 100644
index 0947334..0000000
--- a/src/core/__init__.py
+++ /dev/null
@@ -1,5 +0,0 @@
-from .interface import MiddlewareInterface as MiddlewareInterface
-from .interface import EngineInterface as EngineInterface
-
-from .core import Middleware as Middleware
-from .core import Engine as Engine
\ No newline at end of file
diff --git a/src/core/core.py b/src/core/core.py
index dd50b5d..67aaf50 100644
--- a/src/core/core.py
+++ b/src/core/core.py
@@ -1,168 +1,483 @@
-from .interface import EngineInterface, MiddlewareInterface
-from typing import Any, Callable, List, Optional, Union
-
+from abc import ABC, abstractmethod
+import asyncio
+from typing import Any, AsyncGenerator, Coroutine
 from logging import getLogger
 logger = getLogger(__name__)

-class Middleware(MiddlewareInterface):
+class ExecuteDataStream:
     """
-    中间件基类
-    注意 next_middleware 参数是可选的,如果未提供,可能有两种模式
+    Represents a stream of data with associated metadata.

-    1.如果process返回值为None(只有next_middleware为None返回值才为None)
-    则返回数据将在Pipeline里返回,否则继续执行Pipeline的下一个中间件。
+    This class provides methods to manage a stream of data, including reading and writing data
+    asynchronously. It uses an asyncio Queue to store data items and manages metadata such as
+    the name of the stream and its operational status.

-    2.中间件的process返回值为process_logic处理后的值
+    Attributes:
+    - __metadata: A dictionary containing metadata about the data stream.
+    - __data_queue: An asyncio Queue for storing data items.
+
+    表示了含有元数据的数据流。
+
+    该类提供了用于管理数据流的方法,包括异步读写数据。它使用 asyncio 队列来存储数据项,
+    并管理诸如数据流名称和操作状态等元数据。
+
+    属性:
+    - __metadata: 包含关于数据流元数据的字典。
+    - __data_queue: 用于存储数据项的 asyncio 队列。
     """
-    def __init__(self, process_logic: Callable[[Any], Any],
-                 next_middleware: Optional[List[MiddlewareInterface]] = []):
-        self._process_logic = process_logic
-        self._next_middleware = next_middleware
-        self._data = None
+    def __init__(self,
+                 name: str = "ExecuteDataStream",
+                 maxsize: int = 8):
+        """
+        Initialize the ExecuteDataStream with an empty data buffer and metadata dictionary.
+
+        Parameters:
+        - name: str, the name of the ExecuteDataStream, defaults to "ExecuteDataStream".
+        - maxsize: int, the maximum size of the data queue, defaults to 8.
+
+        初始化 ExecuteDataStream 类的一个实例,创建一个空的数据缓冲区和元数据字典。
+
+        参数:
+        - name: str, 数据流的名字,默认为 "ExecuteDataStream"。
+        - maxsize: int, 数据队列的最大容量,默认为 8。
+        """
+        self.__metadata = {
+            "__name__": name,
+            "__running__": True,
+            "__open_life__": True,
+        }
+        self.__data_queue = asyncio.Queue(maxsize)
+
+    async def iter_read(self, timeout: float | None = None) -> AsyncGenerator:
+        """
+        Asynchronously iterate over data in the stream, yielding each item as it becomes available.
+
+        Parameters:
+        - timeout: float | None, the time to wait for data before timing out.
+
+        Returns:
+        - AsyncGenerator: An asynchronous generator that yields data items as they become available.
+
+        异步迭代数据流中的数据,当数据可用时生成每个数据项(阻塞等待数据)。
+
+        参数:
+        - timeout: float | None, 等待数据的时间,超时则停止等待。
+
+        返回:
+        - AsyncGenerator: 当数据可用时生成数据项的异步生成器。
+        """
+        while True:
+            res = await self.read(timeout)
+            if res is None:
+                break
+            yield res
+
+    async def read(self, timeout: float | None = None) -> Any | None:
+        """
+        Asynchronously read data from the data stream.
+
+        Parameters:
+        - timeout: float | None, the time to wait for data before timing out.
+
+        Returns:
+        - Any | None: The data item or None if no data is available within the timeout period.
+
+        异步从数据流中读取数据。
+
+        参数:
+        - timeout: float | None, 等待数据的时间,超时则停止等待。
+
+        返回:
+        - Any | None: 数据项或 None ,如果在超时期间内没有数据可用。
+        """
+        if not self.__metadata.get('__running__', False):
+            return None

-    def process(self, data: Any) -> Any:
-        """处理数据并返回处理后的结果,支持自动调用下一个中间件
-        (如果next_middleware为None则返回None)"""
         try:
-            data = self._process_logic(data)
-            self._data = data
-            if self._next_middleware:
-                for middleware in self._next_middleware:
-                    self._data = middleware.process(self._data)
-            return None if self._next_middleware is None else data
+            data = await asyncio.wait_for(self.__data_queue.get(), timeout)
+            return data.get("data", None)
+        except asyncio.TimeoutError:
+            return None
+
+    async def write(self,
+                    data: Any,
+                    timeout: float | None = None,
+                    life_time: int = -1) -> bool:
+        """
+        Asynchronously write data to the data stream.
+
+        #### Note: Please do not write None as it will be considered an error and thus discarded.
+
+        Parameters:
+        - data: Any, the data to be written.
+        - timeout: float | None, the time to wait for writing before timing out.
+        - life_time: int, the number of iterations the data should remain in the stream.
+
+        Returns:
+        - bool: True if the data was successfully written, False otherwise.
+
+        异步向数据流中写入数据。
+
+        #### 注意请不要写入 None 因为 None 会被认为是 Error ,所以会被抛弃 。
+
+        参数:
+        - data: Any, 要写入的数据。
+        - timeout: float | None, 写入等待时间,超时则停止等待。
+        - life_time: int, 数据在数据流中的存活周期数。
+
+        返回:
+        - bool: 如果数据成功写入则返回 True ,否则返回 False。
+        """
+        if not self.__metadata.get('__running__', False):
+            return False
+        try:
+            if life_time == 0 or data is None:
+                return False
+            if self.__metadata.get('__open_life__', False):
+                life_time -= 1
+            await asyncio.wait_for(self.__data_queue.put({
+                'data': data,
+                'life': life_time,
+            }), timeout)
+            return True
+        except asyncio.TimeoutError:
+            return False
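# Illustrative sketch (not part of the patch): basic ExecuteDataStream usage as described
# by the docstrings above -- write() queues an item (None is rejected), read()/iter_read()
# wait for data until the optional timeout elapses and then return None / stop iterating.
import asyncio

from src import ExecuteDataStream


async def demo() -> None:
    stream = ExecuteDataStream(name="demo", maxsize=4)
    assert await stream.write("hello")        # queued successfully
    assert not await stream.write(None)       # None is treated as an error and discarded
    print(await stream.read(timeout=0.1))     # -> "hello"
    print(await stream.read(timeout=0.1))     # -> None (queue empty, timed out)

asyncio.run(demo())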
+
+
+class ExecuteProcessor(ABC):
+    """
+    An abstract base class for processing data streams.
+
+    This class represents a processor that can handle incoming data from one stream and send processed data to another stream.
+
+    Usage:
+    - Inherit from this class and implement the `execute` method.
+    - Call the `set_stream` method to set `read_stream` and `write_stream`.
+    - Call the `process` method to start the processor (if there are custom requirements, you can refactor this function, but the refactoring must comply with the interface standard).
+
+    Attributes:
+    - _read_stream: (Should not be used directly) An instance of `ExecuteDataStream` for reading data.
+    - _write_stream: (Should not be used directly) An instance of `ExecuteDataStream` for writing data.
+
+    一个抽象基类,用于处理数据流。
+
+    该类表示一个处理器,可以从一个数据流读取数据,并将处理后的数据发送到另一个数据流。
+
+    使用方法:
+    - 继承该类并实现 `execute` 方法。
+    - 调用 `set_stream` 方法设置 `read_stream` 和 `write_stream`。
+    - 调用 `process` 方法启动处理器(如果有定制化要求可重构该函数,但重构必须符合接口标准)。
+
+    属性:
+    - _read_stream: (不应当直接使用)用于读取数据的 `ExecuteDataStream` 实例。
+    - _write_stream: (不应当直接使用)用于写入数据的 `ExecuteDataStream` 实例。
+    """
+    def __init__(self):
+        """
+        Initialize the `ExecuteProcessor` with default read and write streams set to `None`.
+
+        初始化 `ExecuteProcessor` 类的一个实例,将默认的读取和写入数据流设置为 `None`。
+        """
+        self._read_stream: ExecuteDataStream | None = None
+        self._write_stream: ExecuteDataStream | None = None
+
+    def set_stream(self,
+                   read_stream: ExecuteDataStream | None = None,
+                   write_stream: ExecuteDataStream | None = None):
+        """
+        Set the read and write streams for this processor.
+
+        #### Note: Once `read_stream` and `write_stream` are set, this instance should not accept new settings.
+
+        Parameters:
+        - read_stream: An instance of `ExecuteDataStream` for reading data.
+        - write_stream: An instance of `ExecuteDataStream` for writing data.
+
+        设置此处理器的读取和写入数据流。
+
+        #### 注意一旦设置了 `read_stream` 和 `write_stream` ,则该实例将不应该接受新的设置。
+
+        参数:
+        - read_stream: 用于读取数据的 `ExecuteDataStream` 实例。
+        - write_stream: 用于写入数据的 `ExecuteDataStream` 实例。
+        """
+        self._read_stream = read_stream or self._read_stream
+        self._write_stream = write_stream or self._write_stream
+
+    async def _iter_process(self):
+        """
+        Iterate over the data in the read stream and process it, then write the results to the write stream.
+
+        #### Note: If the `process` method is refactored in the future, this function may serve as a scaffold.
+
+        迭代读取数据流中的数据并对其进行处理,然后将结果写入写入数据流。
+
+        #### 注意如果未来重构 `process` 方法,可能需要该函数作为脚手架。
+        """
+        if self._read_stream is None or self._write_stream is None:
+            raise ValueError("read_stream or write_stream is None")
+        try:
+            async for data in self._read_stream.iter_read():
+                result = await self.execute(data)
+                if isinstance(result, AsyncGenerator):
+                    async for res in result:
+                        await self._write_stream.write(res)
+                        await asyncio.sleep(0)
+                else:
+                    await self._write_stream.write(result)
+                    await asyncio.sleep(0)
         except Exception as e:
-            raise e
+            logger.exception(f"An error occurred during processing: {e}")
+            # await self._write_stream.write(e)

-    def get_next(self) -> Optional[List[MiddlewareInterface]]:
-        return self._next_middleware
-
-class Engine(EngineInterface):
-    """
-    引擎基类
-
-    注意在继承该类时需要注意使用的时特有的格式,方便上下文传输
-
-    下面示例给出继承时的基本操作
-    ```python
-    class MyEngine(Engine):
-        def __init__(is_stream: bool = False):
-            super().__init__(is_stream)
-
-        def execute(date):
-            # you need to do something
-            pass
-    ```
-
-    高级操作
-    ```python
-    class MyEngine(Engine):
-        def __init__(is_stream: bool = False):
-            super().__init__(is_stream)
-
-        def prepare_input(data):
-            return super().prepare_input(data)
-
-        def process_output(data):
-            return super().process_output(data)
-
-        def execute(date):
-            # you need to do something
-            pass
-    ```
-
-    """
-    def __init__(self) -> None:
-        self._metadata = None
-
-    def prepare_input(self, data: Any) -> Any:
-        self._metadata = data.get('__metadata__', None)
-        return data['__data__']
-
-    def process_output(self, data: Any) -> Any:
-        return {'__data__': data, '__metadata__': self._metadata}
-
-    def execute(self, data: Any) -> Any:
-        raise NotImplementedError("execute method not implemented")
-
-    def to_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
-        return Middleware(lambda data:
-                          self.process_output(
-                              self.execute(
-                                  self.prepare_input(data))), next_middleware)
-
-    def to_prepare_input_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
-        return Middleware(self.prepare_input, next_middleware)
-    def to_process_output_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
-        return Middleware(self.process_output, next_middleware)
-    def to_execute_middleware(self, next_middleware: Optional[List[MiddlewareInterface]] = []) -> MiddlewareInterface:
-        return Middleware(self.execute, next_middleware)
-
-class ExecutePipeline:
-    """
-    执行流 用于按顺序执行中间件
-
-    示例:
-    ```python
-    pipeline = ExecutePipeline(middleware1, middleware2)
-    pipeline.add(middleware3)
-    result = pipeline.execute(data)
-    ```
-    执行顺序 data -> middleware1 -> middleware2 -> middleware3 -> result
-
-    ```python
-    pipeline.reorder([0,2,1])
-    result = pipeline.execute(data)
-    ```
-    执行顺序 data -> middleware1 -> middleware3 -> middleware2 -> result
-
-    ```python
-    pipeline.reorder([0,1])
-    result = pipeline.execute(data)
-    ```
-    执行顺序 data -> middleware1 -> middleware2 -> result
-    """
-    def __init__(self, *components: Union[MiddlewareInterface, EngineInterface]):
-        self.middlewares: List[MiddlewareInterface] = []
-        for component in components:
-            if isinstance(component, EngineInterface):
-                self.middlewares.append(component.to_middleware())
-            elif isinstance(component, MiddlewareInterface):
-                self.middlewares.append(component)
-            else:
-                raise TypeError(f"Unsupported type {type(component)}")
-
-    def add(self, middleware: MiddlewareInterface):
-        """添加中间件到执行流里"""
-        self.middlewares.append(middleware)
-
-    def reorder(self, new_order: List[int]):
-        """重新排序中间件,按照new_order指定的索引顺序排列,注意可以实现删除功能"""
-        self.middlewares = [self.middlewares[i] for i in new_order]
-
-    def execute(self, data):
-        """执行中间件
-
-        注意中间件的next不参与主执行流
-
-        注意只有中间件的next为None,返回值才不会被输入到执行流中,即停止执行后续内容
-
-        TODO: 1. 添加中间件执行失败时的处理机制
+    @abstractmethod
+    async def execute(self, data) -> AsyncGenerator[Any, Any] | Coroutine[Any, Any, Any]:
         """
-        try:
-            for middleware in self.middlewares:
-                if data is None:
-                    break
-                data = middleware.process(data)
-            return data
-        except ValueError:
-            return None
+        Process the given data.
-    def execute_engines(self, data: Any, metadata = None) -> Any:
+        Parameters:
+        - data: The data to be processed.
+
+        Returns:
+        - AsyncGenerator | Any: The processed data or an asynchronous generator of processed data.
+
+        处理给定的数据。
+
+        参数:
+        - data: 要处理的数据。
+
+        返回:
+        - AsyncGenerator | Any: 处理后的数据或处理后数据的异步生成器。
         """
-        执行引擎
+        pass
+
+    async def process(self):
         """
-        res = self.execute({'__data__': data, '__metadata__': metadata})
-        if not res:
-            return None
-        else:
-            return res['__data__']
+        Future entry point for starting the processor.
+
+        未来启动处理器的调用入口。
+        """
+        await self._iter_process()
+
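# Illustrative sketch (not part of the patch): a custom ExecuteProcessor whose execute()
# returns an async generator. _iter_process() above detects the AsyncGenerator result and
# fans the items out one by one to the next stream; returning a single value (see
# ExampleProcessor further down in this file) is the other supported shape.
from src import ExecuteProcessor


class WordSplitProcessor(ExecuteProcessor):
    async def execute(self, data):
        async def words():
            for word in str(data).split():
                yield word
        return words()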
+
+
+class ExecutePipeline(ExecuteProcessor):
+    """
+    A class representing a pipeline of processors.
+
+    This class manages a sequence of `ExecuteProcessor` instances and coordinates their execution.
+
+    Attributes:
+    - _processors: A list of `ExecuteProcessor` instances.
+    - _streams: A list of `ExecuteDataStream` instances, including the first and last streams which can be set externally.
+    - _tasks: A list of asyncio Tasks representing the execution of each processor.
+
+    一个表示处理器管道的类。
+
+    该类管理一系列 `ExecuteProcessor` 实例,并协调它们的执行。
+
+    属性:
+    - _processors: `ExecuteProcessor` 实例列表。
+    - _streams: `ExecuteDataStream` 实例列表,包括可以外部设置的第一个和最后一个数据流。
+    - _tasks: 表示每个处理器执行的 asyncio Task 列表。
+    """
+
+    def __init__(self, processors: list[ExecuteProcessor]):
+        """
+        Initialize the `ExecutePipeline` with a list of processors.
+
+        Parameters:
+        - processors: A list of `ExecuteProcessor` instances.
+
+        使用处理器列表初始化 `ExecutePipeline`。
+
+        参数:
+        - processors: `ExecuteProcessor` 实例列表。
+        """
+        self._processors = processors
+        self._streams = [None] + [ExecuteDataStream() for _ in range(len(self._processors) - 1)] + [None]
+        self._tasks: list[asyncio.Task] = []
+        self.__is_executed = False
+
+    def set_stream(self,
+                   read_stream: ExecuteDataStream | None = None,
+                   write_stream: ExecuteDataStream | None = None):
+        """
+        Set the input and output streams for the pipeline.
+
+        Parameters:
+        - read_stream: The input stream for the pipeline.
+        - write_stream: The output stream for the pipeline.
+
+        设置管道的输入和输出数据流。
+
+        参数:
+        - read_stream: 管道的输入数据流。
+        - write_stream: 管道的输出数据流。
+        """
+        # streams[0] feeds the first processor (the pipeline's input);
+        # streams[-1] collects the last processor's output (the pipeline's output).
+        self._streams[0] = read_stream or self._streams[0]
+        self._streams[-1] = write_stream or self._streams[-1]
+
+    def _get_stream(self, index: int) -> ExecuteDataStream:
+        """
+        Get the stream at the specified index.
+
+        Parameters:
+        - index: The index of the stream to retrieve.
+
+        Returns:
+        - ExecuteDataStream: The stream at the specified index.
+
+        获取指定索引处的数据流。
+
+        参数:
+        - index: 要获取的数据流的索引。
+
+        返回:
+        - ExecuteDataStream: 指定索引处的数据流。
+        """
+        stream = self._streams[index]
+        if stream is None:
+            raise ValueError("Stream not found")
+        return stream
+
+    async def write(self,
+                    data,
+                    index = 0,
+                    timeout = None,
+                    life_time = -1):
+        """
+        Write data to the pipeline's input data stream.
+
+        Parameters:
+        - data: The data to write.
+        - index: The index of the stream to write to. Default is 0 (the first stream).
+        - timeout: The timeout for writing the data.
+        - life_time: The life time of the data.
+
+        将数据写入管道的输入数据流。
+
+        参数:
+        - data: 要写入的数据。
+        - index: 写入数据的数据流索引,默认为 0 (第一个数据流)。
+        - timeout: 写入数据的超时时间。
+        - life_time: 数据的生命周期。
+        """
+        await self._get_stream(index).write(data, timeout, life_time)
+
+    async def read(self, index: int = -1, timeout: float | None = None):
+        """
+        Read data from the last processor's data stream.
+
+        Parameters:
+        - index: The index of the stream to read from. Default is -1 (the last stream).
+        - timeout: The timeout for reading the data.
+
+        Returns:
+        - Any: The data read from the stream.
+
+        从最后一个处理器的数据流中读取数据。
+
+        参数:
+        - index: 读取数据的数据流索引,默认为 -1 (最后一个数据流)。
+        - timeout: 读取数据的超时时间。
+
+        返回:
+        - Any: 从数据流中读取的数据。
+        """
+        return await self._get_stream(index).read(timeout)
+
+    def iter_read(self, index: int = -1, timeout: float | None = None):
+        """
+        Iterate over the data in the specified stream.
+
+        Parameters:
+        - index: The index of the stream to iterate over. Default is -1 (the last stream).
+        - timeout: The timeout for reading the data.
+
+        Returns:
+        - AsyncIterator: An asynchronous iterator over the data in the stream.
+
+        迭代指定数据流中的数据。
+
+        参数:
+        - index: 迭代数据的数据流索引,默认为 -1 (最后一个数据流)。
+        - timeout: 读取数据的超时时间。
+
+        返回:
+        - AsyncIterator: 对数据流中的数据进行迭代的异步迭代器。
+        """
+        return self._get_stream(index).iter_read(timeout)
+
+    def cancel(self):
+        """
+        Cancel all running tasks in the pipeline.
+
+        取消管道中所有正在运行的任务。
+        """
+        if self._tasks is None:
+            return
+        for task in self._tasks:
+            task.cancel()
+
+    def start(self):
+        """
+        Start the pipeline by executing the `execute` method.
+
+        通过执行 `execute` 方法来启动管道。
+        """
+        self.set_stream(ExecuteDataStream(), ExecuteDataStream())
+        self.execute()
+        return self
+
+    def execute(self):
+        """
+        Execute the pipeline by setting up the streams for each processor and creating tasks for their execution.
+
+        通过为每个处理器设置数据流并为其执行创建任务来执行管道。
+        """
+        if self.__is_executed:
+            return
+        self.__is_executed = True
+        for i, processor in enumerate(self._processors):
+            processor.set_stream(self._get_stream(i), self._get_stream(i + 1))
+            self._tasks.append(
+                asyncio.create_task(processor.process())
+            )
+
+    async def process(self):
+        """
+        Execute the pipeline and wait for all tasks to complete.
+
+        执行管道并等待所有任务完成。
+        """
+        self.execute()
+        await asyncio.gather(*self._tasks)
+
+# Example processor / 示例处理器
+class ExampleProcessor(ExecuteProcessor):
+
+    async def execute(self, data):
+        return data + " processed"
+
+async def main():
+    processors: list[ExecuteProcessor] = [ExampleProcessor()]
+    pipeline = ExecutePipeline(processors)
+    pipeline.start()
+
+    input_data = "Hello, world!"
+    await pipeline.write(input_data)
+    print("Pipeline started.")
+    async def fun():
+        async for data in pipeline.iter_read():
+            print(data)
+    task = asyncio.create_task(fun())
+
+    while True:
+        data = input("input data: ")  # note: blocking input() stalls the event loop here
+        await pipeline.write(data)
+        await asyncio.sleep(0)
+        # print(await pipeline.read())
+
+# Usage example / 使用示例
+if __name__ == "__main__":
+    asyncio.run(main())
\ No newline at end of file
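# Illustrative sketch (not part of the patch): because ExecutePipeline is itself an
# ExecuteProcessor, a pipeline can be nested inside another pipeline; the outer pipeline
# wires it up through set_stream() like any other processor. TeeProcessor and EchoProcessor
# come from this patch's src/core/tools.py; the 0.1 s sleep is just a crude way to let the
# tasks drain before cancelling.
import asyncio

from src import EchoProcessor, ExecutePipeline, TeeProcessor


async def demo() -> None:
    inner = ExecutePipeline([TeeProcessor()])                # echoes and passes data through
    outer = ExecutePipeline([inner, EchoProcessor()]).start()
    await outer.write("nested pipelines\n")
    await asyncio.sleep(0.1)
    outer.cancel()

asyncio.run(demo())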
diff --git a/src/core/echo.py b/src/core/echo.py
deleted file mode 100644
index e86ef68..0000000
--- a/src/core/echo.py
+++ /dev/null
@@ -1,33 +0,0 @@
-import sys
-from typing import Generator, Iterator, List, Optional
-from .. import Middleware, MiddlewareInterface, Engine
-
-class EchoMiddleware(Middleware):
-    def __init__(self, next_middleware: Optional[List[MiddlewareInterface]] = []):
-        super().__init__(self._process, next_middleware)
-    def _process(self, data):
-        if isinstance(data, Iterator):
-            ret = ''
-            for i in data:
-                ret += i
-                sys.stdout.write(i)
-                sys.stdout.flush()
-            sys.stdout.write('\n')
-        else:
-            ret = data
-            print(data)
-        return ret
-
-class EchoEngine(Engine):
-    def execute(self, data):
-        if isinstance(data, Generator):
-            ret = ''
-            for i in data:
-                ret += i
-                sys.stdout.write(i)
-                sys.stdout.flush()
-            sys.stdout.write('\n')
-        else:
-            ret = data
-            print(data)
-        return ret
\ No newline at end of file
diff --git a/src/core/interface.py b/src/core/interface.py
index 3957e3f..e69de29 100644
--- a/src/core/interface.py
+++ b/src/core/interface.py
@@ -1,55 +0,0 @@
-from abc import ABC, abstractmethod
-from typing import Any, List, Optional
-
-class MiddlewareInterface(ABC):
-    """
-    中间件接口
-    """
-    @abstractmethod
-    def process(self, data: Any) -> Any:
-        """处理数据并返回处理后的结果"""
-        pass
-
-    @abstractmethod
-    def get_next(self) -> Optional[List['MiddlewareInterface']]:
-        """获取下一个中间件"""
-        pass
-
-class EngineInterface(ABC):
-    """
-    引擎接口
-    """
-    @abstractmethod
-    def execute(self, data: Any) -> Any:
-        """执行引擎逻辑并返回结果"""
-        pass
-
-    @abstractmethod
-    def prepare_input(self, data: Any) -> Any:
-        """将输入引擎数据剔除元数据"""
-        pass
-
-    @abstractmethod
-    def process_output(self, data: Any) -> Any:
-        """将输出数据转换为引擎的格式"""
-        pass
-
-    @abstractmethod
-    def to_execute_middleware(self) -> MiddlewareInterface:
-        """将引擎具体执行过程转换为遵循MiddlewareInterface的中间件."""
-        pass
-
-    @abstractmethod
-    def to_prepare_input_middleware(self) -> MiddlewareInterface:
-        """将原生引擎数据转换为传输的数据."""
-        pass
-
-    @abstractmethod
-    def to_process_output_middleware(self) -> MiddlewareInterface:
-        """将输出数据包装成原生引擎数据."""
-        pass
-
-    @abstractmethod
-    def to_middleware(self) -> MiddlewareInterface:
-        """将引擎执行流转换为遵循MiddlewareInterface的中间件."""
-        pass
\ No newline at end of file
diff --git a/src/core/tee.py b/src/core/tee.py
deleted file mode 100644
index 1437241..0000000
--- a/src/core/tee.py
+++ /dev/null
@@ -1,19 +0,0 @@
-import sys
-from typing import Iterator, List, Optional
-from .. import Middleware, MiddlewareInterface, Engine
-
-class TeeEngine(Engine):
-    def execute(self, data):
-        if isinstance(data, Iterator):
-            def _():
-                for item in data:
-                    sys.stdout.write(str(item))
-                    sys.stdout.flush()
-                    yield item
-                sys.stdout.write('\n')
-                sys.stdout.flush()
-            ret = _()
-        else:
-            ret = data
-            print(data)
-        return ret
\ No newline at end of file
diff --git a/src/core/tools.py b/src/core/tools.py
new file mode 100644
index 0000000..50c11aa
--- /dev/null
+++ b/src/core/tools.py
@@ -0,0 +1,15 @@
+from sys import stdout
+from typing import Any, AsyncGenerator
+from .core import ExecuteProcessor
+
+class EchoProcessor(ExecuteProcessor):
+    async def execute(self, data):
+        stdout.write(data)
+        stdout.flush()
+        return None
+
+class TeeProcessor(ExecuteProcessor):
+    async def execute(self, data):
+        stdout.write(data)
+        stdout.flush()
+        return data
\ No newline at end of file
diff --git a/src/offical/requests/__init__.py b/src/offical/requests/__init__.py
index ec76656..b732b29 100644
--- a/src/offical/requests/__init__.py
+++ b/src/offical/requests/__init__.py
@@ -1,5 +1,5 @@
-from src import dynamic_package_import
-dynamic_package_import([
-    ('requests', None),
-    ])
\ No newline at end of file
+# from src import dynamic_package_import
+# dynamic_package_import([
+#     ('requests', None),
+#     ])
\ No newline at end of file
diff --git a/src/offical/requests/remote_engine.py b/src/offical/requests/remote_engine.py
index dec1a24..04b0785 100644
--- a/src/offical/requests/remote_engine.py
+++ b/src/offical/requests/remote_engine.py
@@ -2,6 +2,7 @@ from logging import getLogger
 from typing import Any, Dict
 logger = getLogger(__name__)
 import requests
+import httpx

 class RemoteEngine():
     """基础远程服务类, 提供远程调用的基础框架。"""
@@ -14,14 +15,11 @@ class RemoteEngine():
             'Authorization': 'Bearer ' + self._api_key
         }

-    def post_prompt(self, messages : Dict, is_stream : bool = False) -> requests.Response:
+    async def post_prompt(self, messages : Dict) -> httpx.Response:
         try:
-            response = requests.post(self.url, headers=self.header, json=messages, stream=is_stream)
-            if (response.status_code != 200):
-                logger.error(f"post_prompt error: {response.text}")
-                raise Exception(f"post_prompt error: {response.text}")
+            async with httpx.AsyncClient() as client:
+                response = await client.post(self.url, headers=self.header, json=messages)
+                response.raise_for_status()
+                return response
         except Exception as e:
             logger.exception("post_prompt")
-            raise e
-        finally:
-            return response
\ No newline at end of file
+            raise e
\ No newline at end of file
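# Note (illustrative, not part of the patch): httpx.AsyncClient.post() buffers the whole
# response body, so the SSE stream that tongyiqianwen.py later consumes via
# response.aiter_lines() is only parsed once the request has completed. If truly
# incremental delivery is wanted, a streaming variant of post_prompt could look roughly
# like this hypothetical helper (same url/header values as RemoteEngine holds):
import httpx


async def post_prompt_streaming(url: str, headers: dict, messages: dict):
    async with httpx.AsyncClient() as client:
        async with client.stream("POST", url, headers=headers, json=messages) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                yield line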
diff --git a/src/offical/requests/tongyiqianwen.py b/src/offical/requests/tongyiqianwen.py
index 7b0df78..e6ec167 100644
--- a/src/offical/requests/tongyiqianwen.py
+++ b/src/offical/requests/tongyiqianwen.py
@@ -1,35 +1,40 @@
 # https://dashscope.console.aliyun.com/model
+import asyncio
 import json
-from typing import Any, Generator, Iterator
+from typing import Any, AsyncGenerator, AsyncIterator, Generator, Iterator, override
 from .remote_engine import RemoteEngine
-from src import ChatAIEngine as ChatAIEngine
-from src import StreamEngine as StreamEngine
+from src import ExecuteProcessor as ExecuteProcessor
 from logging import getLogger
 logger = getLogger(__name__)

-class ChatAiTongYiQianWen(ChatAIEngine, RemoteEngine, StreamEngine):
+class ChatAiTongYiQianWen(ExecuteProcessor):
     API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

     def __init__(self, api_key : str, _model : str = "qwen-1.8b-chat", *args, **kwargs) -> None:
+        self._modual_name = kwargs.get("model", _model)
-        StreamEngine.__init__(self, use_stream_api=kwargs.get("use_stream_api", False), return_as_stream=kwargs.get("return_as_stream", False))
-        ChatAIEngine.__init__(self)
-        RemoteEngine.__init__(self, api_key=api_key,
-                              base_url=self.API_URL)
+        self.use_stream_api = kwargs.get("use_stream_api", False)
+        self.remote_engine = RemoteEngine(api_key=api_key,
+                                          base_url=self.API_URL)

     def _transform_raw(self, response):
         return response["output"]["choices"][0]["message"]["content"]

-    def _transform_iterator(self, response):
-        for chunk in response:
-            input_string = chunk.decode('utf-8')
-            json_string = input_string[input_string.find("data:") + 5:].strip()
-            yield self._transform_raw(json.loads(json_string))
+    async def _transform_iterator(self, response: AsyncIterator[str]):
+        current_msg = ''
+        async for line in response:
+            if line.startswith("data:"):
+                current_msg += line[5:].strip()
+            elif len(line) == 0:
+                if current_msg:
+                    yield self._transform_raw(json.loads(current_msg))
+                    current_msg = ''

-    def execute_stream(self, data) -> Generator[Any, None, None] | Iterator:
-        self.header['X-DashScope-SSE'] = 'enable'
-        response = self.post_prompt({
+    async def execute_stream(self, data):
+        self.remote_engine.header['X-DashScope-SSE'] = 'enable'
+        response = await self.remote_engine.post_prompt({
             'model': self._modual_name,
             "input": {
                 "messages": data
@@ -38,14 +43,11 @@ class ChatAiTongYiQianWen(ChatAIEngine, RemoteEngine, StreamEngine):
                 "result_format": "message",
                 "incremental_output": True
             }
-        },
-        is_stream=True)
-
-        ret = response.iter_content(chunk_size=None)
-        return self._transform_iterator(ret)
+        })
+        return self._transform_iterator(response.aiter_lines())

-    def execute_nonstream(self, data) -> Any:
-        response = self.post_prompt({
+    async def execute_nonstream(self, data) -> Any:
+        response = await self.remote_engine.post_prompt({
             'model': self._modual_name,
             "input": {
                 "messages": data
@@ -54,11 +56,19 @@ class ChatAiTongYiQianWen(ChatAIEngine, RemoteEngine, StreamEngine):
                 "result_format": "message",
                 "incremental_output": False
             }
-        },
-        is_stream=False)
+        })
         ret = response.json()
         return self._transform_raw(ret)

     def prepare_input(self, data: Any) -> Any:
-        data['__data__'] = {'role': 'user', 'content': f'{data['__data__']}'}
-        return super().prepare_input(data)
\ No newline at end of file
+        return [{'role': 'user', 'content': f'{data}'}]
+
+    @override
+    async def execute(self, data):
+        data = self.prepare_input(data)
+        if self.use_stream_api:
+            res = await self.execute_stream(data)
+        else:
+            res = await self.execute_nonstream(data)
+        return res
\ No newline at end of file
diff --git a/tests/core/test_echo.py b/tests/core/test_echo.py
index 5981032..3978eb7 100644
--- a/tests/core/test_echo.py
+++ b/tests/core/test_echo.py
@@ -1,11 +1,11 @@
 import unittest
-from src.core import echo
+from src.core import tools
 from io import StringIO
 import sys

 class TestEcho(unittest.TestCase):
     def test_base(self):
-        echoMid = echo.EchoMiddleWare()
+        echoMid = tools.EchoProcessor()
         captured_out = StringIO()
         sys.stdout = captured_out
         echoMid.process('hello')
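# Illustrative sketch (not part of the patch): an asyncio-aware variant of the echo test
# that exercises EchoProcessor.execute() directly. The test class and method names are
# hypothetical; tools.py binds `stdout` at import time, so the module attribute is patched
# rather than sys.stdout.
import io
import unittest
from unittest import mock

from src.core import tools


class TestEchoProcessor(unittest.IsolatedAsyncioTestCase):
    async def test_execute_writes_to_stdout(self):
        captured = io.StringIO()
        with mock.patch('src.core.tools.stdout', captured):
            result = await tools.EchoProcessor().execute('hello')
        self.assertIsNone(result)                      # EchoProcessor consumes the data
        self.assertEqual(captured.getvalue(), 'hello')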