refactor(main): replace sequential execution with async stream processing
Refactor the `main` function, replacing the original sequential execution model with asynchronous stream processing built on the `aiostream` library for more efficient concurrent data handling.
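At its heart, the new model wraps a source async iterator with `stream.iterate` and pipes it through `aiostream` operators; awaiting the resulting stream drives all stages. A minimal runnable sketch of that shape (the stage functions below are illustrative stand-ins, not this repository's classes):

    import asyncio
    from aiostream import stream, pipe

    async def chunks():
        # stand-in for a streaming producer such as cai.execute_stream(...)
        for part in ("Hello, ", "async ", "world"):
            yield part
            await asyncio.sleep(0)

    async def main():
        s = (
            stream.iterate(chunks())   # source: async iterator of chunks
            | pipe.map(str.upper)      # per-item transform stage
            | pipe.print()             # tap stage: print items as they pass
        )
        await s                        # awaiting the stream runs the pipeline

    asyncio.run(main())

Because each stage consumes items as they arrive, downstream work (such as TTS and playback in this commit) can overlap with the network-bound producer instead of waiting for it to finish.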
parent 6cc18ed15c
commit 6120e7ba57

main.py (64 lines changed)
@@ -1,50 +1,34 @@
# from src import Plugins
from src import ExecutePipeline
from src import EchoProcessor
# from src import ExecutePipeline
# from src import EchoProcessor
import os
from viztracer import VizTracer


async def main():
    from pathlib import Path
    from dotenv import load_dotenv
    load_dotenv()

    # plug_conf = Plugins(Path(__file__).resolve().parent)
    from src.offical.requests.tongyiqianwen import ChatAiTongYiQianWen
    pipe = ExecutePipeline([
        ChatAiTongYiQianWen(api_key='sk-ab4a3e0e29d54ebaad560c1472933d41', use_stream_api=True),
        EchoProcessor()
    ]).start()
    # # plug_conf.load_engine("asr_engine"),
    # # MyEngine(),
    # plug_conf.load_engine("chat_ai_engine"),
    # TeeProcessor(),
    # plug_conf.load_engine("tts_engine"),
    # plug_conf.load_engine("sounds_play_engine"),
    # EchoEngine()
    # exe_input = './tests/offical/sounds/asr_example.wav'
    # exe_input = input('input: ')
from pathlib import Path
from src.dashscope.tongyiqianwen import CAITongYiQianWen as CAI
from src.dashscope.sambert import TTSSambert as TTS
from src.pyaudio.sounds_play_engine import SoundsPlayEngine as sounds
from aiostream import stream, pipe


async def main():

    cai = CAI(api_key=os.environ.get("DASH_SCOPE_API_KEY", ""), use_stream_api=True)
    tts = TTS(api_key=os.environ.get("DASH_SCOPE_API_KEY", ""))
    exe_input = '给我一个100字的文章'
    await pipe.write(exe_input)
    await asyncio.sleep(0)
    # loop.run_in_executor(None, inputdata)
    loop = asyncio.get_running_loop()

    def inputdata(pipe, loop):
        while True:
            try:
                data = input('input :')
                asyncio.run_coroutine_threadsafe(pipe.write(data), loop)
            except KeyboardInterrupt:
                pipe.cancel()
                break

    with VizTracer(log_async=True):
        thread = threading.Thread(target=inputdata, args=(pipe, loop))
        thread.start()
        await pipe.process()
        thread.join()
    # asyncio.gather([i for i in pipe._tasks])
    # await asyncio.to_thread(inputdata, pipe)
    s = (
        stream.iterate(await cai.execute_stream(cai.prepare_input(exe_input)))
        | pipe.map(cai.transform_chunks)
        | pipe.print()
        | pipe.flatmap(tts.execute_stream)
        | pipe.map(sounds().execute)
    )
    await s


if __name__ == '__main__':
    from dotenv import load_dotenv
    load_dotenv()
    import asyncio
    import threading
    asyncio.run(main())
@@ -1,8 +1,8 @@
from src import Engine
from aiostream import stream, pipe, Stream


class MyEngine(Engine):
class MyExecute():
    def __init__(self):
        super().__init__()

        pass

    def execute(self, data):
        print(data)
        pass
@@ -5,5 +5,13 @@ python-dotenv
# server optional
# tornado

# async part
aiohttp
aiostream

# debug part
viztracer

# optional
tqdm
rich # optional
@@ -1,23 +1,3 @@
# from .core.interface import MiddlewareInterface as MiddlewareInterface
# from .core.interface import EngineInterface as EngineInterface

from .core.core import ExecuteDataStream as ExecuteDataStream
from .core.core import ExecuteProcessor as ExecuteProcessor
from .core.core import ExecutePipeline as ExecutePipeline
from .core.tools import EchoProcessor as EchoProcessor
from .core.tools import TeeProcessor as TeeProcessor

# from .engine.stream_engine import StreamEngine as StreamEngine
# from .engine.asr_engine import ASREngine as ASREngine
# from .engine.tts_engine import TTSEngine as TTSEngine
# from .engine.nlu_engine import NLUEngine as NLUEngine
# from .engine.chat_ai_engine import ChatAIEngine as ChatAIEngine

# from .plugins.dynamic_package_import import dynamic_package_import as dynamic_package_import
# from .plugins.plugins_conf import PluginsConfig as PluginsConfig
# from .plugins.plugins import Plugins as Plugins

# from .utils.logger import setup_logger as setup_logger
import logging
from sys import stdout
logger = logging.getLogger(__name__)
src/core/core.py (deleted, 483 lines)
@@ -1,483 +0,0 @@
from abc import ABC, abstractmethod
import asyncio
from typing import Any, AsyncGenerator, Coroutine
from logging import getLogger
logger = getLogger(__name__)


class ExecuteDataStream:
    """
    Represents a stream of data with associated metadata.

    This class provides methods to manage a stream of data, including reading and writing data
    asynchronously. It uses an asyncio Queue to store data items and manages metadata such as
    the name of the stream and its operational status.

    Attributes:
    - __metadata: A dictionary containing metadata about the data stream.
    - __data_queue: An asyncio Queue for storing data items.

    表示了含有元数据的数据流。

    该类提供了用于管理数据流的方法,包括异步读写数据。它使用 asyncio 队列来存储数据项,
    并管理诸如数据流名称和操作状态等元数据。

    属性:
    - __metadata: 包含关于数据流元数据的字典。
    - __data_queue: 用于存储数据项的 asyncio 队列。
    """
    def __init__(self,
                 name: str = "ExecuteDataStream",
                 maxsize: int = 8):
        """
        Initialize the ExecuteDataStream with an empty data buffer and metadata dictionary.

        Parameters:
        - name: str, the name of the ExecuteDataStream, defaults to "ExecuteDataStream".
        - maxsize: int, the maximum size of the data queue, defaults to 8.

        初始化 ExecuteDataStream 类的一个实例,创建一个空的数据缓冲区和元数据字典。

        参数:
        - name: str, 数据流的名字,默认为 "ExecuteDataStream"。
        - maxsize: int, 数据队列的最大容量,默认为 8。
        """
        self.__metadata = {
            "__name__": name,
            "__running__": True,
            "__open_life__": True,
        }
        self.__data_queue = asyncio.Queue(maxsize)

    async def iter_read(self, timeout: float | None = None) -> AsyncGenerator:
        """
        Asynchronously iterate over data in the stream, yielding each item as it becomes available.

        Parameters:
        - timeout: float | None, the time to wait for data before timing out.

        Returns:
        - AsyncGenerator: An asynchronous generator that yields data items as they become available.

        异步迭代数据流中的数据,当数据可用时生成每个数据项(阻塞等待数据)。

        参数:
        - timeout: float | None, 等待数据的时间,超时则停止等待。

        返回:
        - AsyncGenerator: 当数据可用时生成数据项的异步生成器。
        """
        while True:
            res = await self.read(timeout)
            if res is None:
                break
            yield res

    async def read(self, timeout: float | None = None) -> Any | None:
        """
        Asynchronously read data from the data stream.

        Parameters:
        - timeout: float | None, the time to wait for data before timing out.

        Returns:
        - Any | None: The data item or None if no data is available within the timeout period.

        异步从数据流中读取数据。

        参数:
        - timeout: float | None, 等待数据的时间,超时则停止等待。

        返回:
        - Any | None: 数据项或 None ,如果在超时期间内没有数据可用。
        """
        if not self.__metadata.get('__running__', False):
            return None

        try:
            data = await asyncio.wait_for(self.__data_queue.get(), timeout)
            return data.get("data", None)
        except asyncio.TimeoutError:
            return None

    async def write(self,
                    data: Any,
                    timeout: float | None = None,
                    life_time: int = -1) -> bool:
        """
        Asynchronously write data to the data stream.

        #### Note: Please do not write None as it will be considered an error and thus discarded.

        Parameters:
        - data: Any, the data to be written.
        - timeout: float | None, the time to wait for writing before timing out.
        - life_time: int, the number of iterations the data should remain in the stream.

        Returns:
        - bool: True if the data was successfully written, False otherwise.

        异步向数据流中写入数据。

        #### 注意请不要写入 None 因为 None 会被认为是 Error ,所以会被抛弃 。

        参数:
        - data: Any, 要写入的数据。
        - timeout: float | None, 写入等待时间,超时则停止等待。
        - life_time: int, 数据在数据流中的存活周期数。

        返回:
        - bool: 如果数据成功写入则返回 True ,否则返回 False。
        """
        if not self.__metadata.get('__running__', False):
            return False
        try:
            if life_time == 0 or data is None:
                return False
            if self.__metadata.get('__open_life__', False):
                life_time -= 1
            await asyncio.wait_for(self.__data_queue.put({
                'data': data,
                'life': life_time,
            }), timeout)
            return True
        except asyncio.TimeoutError:
            return False


class ExecuteProcessor(ABC):
    """
    An abstract base class for processing data streams.

    This class represents a processor that can handle incoming data from one stream and send processed data to another stream.

    Usage:
    - Inherit from this class and implement the `execute` method.
    - Call the `set_stream` method to set `read_stream` and `write_stream`.
    - Call the `process` method to start the processor (if there are custom requirements, you can refactor this function, but the refactoring must comply with the interface standard).

    Attributes:
    - _read_stream: (Should not be used directly) An instance of `ExecuteDataStream` for reading data.
    - _write_stream: (Should not be used directly) An instance of `ExecuteDataStream` for writing data.

    一个抽象基类,用于处理数据流。

    该类表示一个处理器,可以从一个数据流读取数据,并将处理后的数据发送到另一个数据流。

    使用方法:
    - 继承该类并实现 `execute` 方法。
    - 调用 `set_stream` 方法设置 `read_stream` 和 `write_stream`。
    - 调用 `process` 方法启动处理器(如果有定制化要求可重构该函数,但重构必须符合接口标准)。

    属性:
    - _read_stream: (不应当直接使用)用于读取数据的 `ExecuteDataStream` 实例。
    - _write_stream: (不应当直接使用)用于写入数据的 `ExecuteDataStream` 实例。
    """
    def __init__(self):
        """
        Initialize the `ExecuteProcessor` with default read and write streams set to `None`.

        初始化 `ExecuteProcessor` 类的一个实例,将默认的读取和写入数据流设置为 `None`。
        """
        self._read_stream: ExecuteDataStream | None = None
        self._write_stream: ExecuteDataStream | None = None

    def set_stream(self,
                   read_stream: ExecuteDataStream | None = None,
                   write_stream: ExecuteDataStream | None = None):
        """
        Set the read and write streams for this processor.

        #### Note: Once `read_stream` and `write_stream` are set, this instance should not accept new settings.

        Parameters:
        - read_stream: An instance of `ExecuteDataStream` for reading data.
        - write_stream: An instance of `ExecuteDataStream` for writing data.

        设置此处理器的读取和写入数据流。

        #### 注意一旦设置了 `read_stream` 和 `write_stream` ,则该实例将不应该接受新的设置。

        参数:
        - read_stream: 用于读取数据的 `ExecuteDataStream` 实例。
        - write_stream: 用于写入数据的 `ExecuteDataStream` 实例。
        """
        self._read_stream = read_stream or self._read_stream
        self._write_stream = write_stream or self._write_stream

    async def _iter_process(self):
        """
        Iterate over the data in the read stream and process it, then write the results to the write stream.

        #### Note: If the `process` method is refactored in the future, this function may serve as a scaffold.

        迭代读取数据流中的数据并对其进行处理,然后将结果写入写入数据流。

        #### 注意如果未来重构 `process` 方法,可能需要该函数作为脚手架。
        """
        if self._read_stream is None or self._write_stream is None:
            raise ValueError("read_stream or write_stream is None")
        try:
            async for data in self._read_stream.iter_read():
                result = await self.execute(data)
                if isinstance(result, AsyncGenerator):
                    async for res in result:
                        await self._write_stream.write(res)
                        await asyncio.sleep(0)
                else:
                    await self._write_stream.write(result)
                    await asyncio.sleep(0)
        except Exception as e:
            logger.exception(f"An error occurred during processing: {e}")
            # await self._write_stream.write(e)

    @abstractmethod
    async def execute(self, data) -> AsyncGenerator[Any, Any] | Coroutine[Any, Any, Any]:
        """
        Process the given data.

        Parameters:
        - data: The data to be processed.

        Returns:
        - AsyncGenerator | Any: The processed data or an asynchronous generator of processed data.

        处理给定的数据。

        参数:
        - data: 要处理的数据。

        返回:
        - AsyncGenerator | Any: 处理后的数据或处理后数据的异步生成器。
        """
        pass

    async def process(self):
        """
        Future entry point for starting the processor.

        未来启动处理器的调用入口。
        """
        await self._iter_process()


class ExecutePipeline(ExecuteProcessor):
    """
    A class representing a pipeline of processors.

    This class manages a sequence of `ExecuteProcessor` instances and coordinates their execution.

    Attributes:
    - _processors: A list of `ExecuteProcessor` instances.
    - _streams: A list of `ExecuteDataStream` instances, including the first and last streams which can be set externally.
    - _tasks: A list of asyncio Tasks representing the execution of each processor.

    一个表示处理器管道的类。

    该类管理一系列 `ExecuteProcessor` 实例,并协调它们的执行。

    属性:
    - _processors: `ExecuteProcessor` 实例列表。
    - _streams: `ExecuteDataStream` 实例列表,包括可以外部设置的第一个和最后一个数据流。
    - _tasks: 表示每个处理器执行的 asyncio Task 列表。
    """

    def __init__(self, processors: list[ExecuteProcessor]):
        """
        Initialize the `ExecutePipeline` with a list of processors.

        Parameters:
        - processors: A list of `ExecuteProcessor` instances.

        使用处理器列表初始化 `ExecutePipeline`。

        参数:
        - processors: `ExecuteProcessor` 实例列表。
        """
        self._processors = processors
        self._streams = [None] + [ExecuteDataStream() for _ in range(len(self._processors) - 1)] + [None]
        self._tasks: list[asyncio.Task] = []
        self.__is_executed = False

    def set_stream(self,
                   read_stream: ExecuteDataStream | None = None,
                   write_stream: ExecuteDataStream | None = None):
        """
        Set the input and output streams for the pipeline.

        Parameters:
        - read_stream: The input stream for the pipeline.
        - write_stream: The output stream for the pipeline.

        设置管道的输入和输出数据流。

        参数:
        - read_stream: 管道的输入数据流。
        - write_stream: 管道的输出数据流。
        """
        self._streams[0] = write_stream or self._streams[0]
        self._streams[-1] = read_stream or self._streams[-1]

    def _get_stream(self, index: int) -> ExecuteDataStream:
        """
        Get the stream at the specified index.

        Parameters:
        - index: The index of the stream to retrieve.

        Returns:
        - ExecuteDataStream: The stream at the specified index.

        获取指定索引处的数据流。

        参数:
        - index: 要获取的数据流的索引。

        返回:
        - ExecuteDataStream: 指定索引处的数据流。
        """
        stream = self._streams[index]
        if stream is None:
            raise ValueError("Stream not found")
        return stream

    async def write(self,
                    data,
                    index=0,
                    timeout=None,
                    life_time=-1):
        """
        Write data to the pipeline's input data stream.

        Parameters:
        - data: The data to write.
        - index: The index of the stream to write to. Default is 0 (the first stream).
        - timeout: The timeout for writing the data.
        - life_time: The life time of the data.

        将数据写入管道的输入数据流。

        参数:
        - data: 要写入的数据。
        - index: 写入数据的数据流索引,默认为 0 (第一个数据流)。
        - timeout: 写入数据的超时时间。
        - life_time: 数据的生命周期。
        """
        await self._get_stream(index).write(data, timeout, life_time)

    async def read(self, index: int = -1, timeout: float | None = None):
        """
        Read data from the last processor's data stream.

        Parameters:
        - index: The index of the stream to read from. Default is -1 (the last stream).
        - timeout: The timeout for reading the data.

        Returns:
        - Any: The data read from the stream.

        从最后一个处理器的数据流中读取数据。

        参数:
        - index: 读取数据的数据流索引,默认为 -1 (最后一个数据流)。
        - timeout: 读取数据的超时时间。

        返回:
        - Any: 从数据流中读取的数据。
        """
        return await self._get_stream(index).read(timeout)

    def iter_read(self, index: int = -1, timeout: float | None = None):
        """
        Iterate over the data in the specified stream.

        Parameters:
        - index: The index of the stream to iterate over. Default is -1 (the last stream).
        - timeout: The timeout for reading the data.

        Returns:
        - AsyncIterator: An asynchronous iterator over the data in the stream.

        迭代指定数据流中的数据。

        参数:
        - index: 迭代数据的数据流索引,默认为 -1 (最后一个数据流)。
        - timeout: 读取数据的超时时间。

        返回:
        - AsyncIterator: 对数据流中的数据进行迭代的异步迭代器。
        """
        return self._get_stream(index).iter_read(timeout)

    def cancel(self):
        """
        Cancel all running tasks in the pipeline.

        取消管道中所有正在运行的任务。
        """
        if self._tasks is None:
            return
        for task in self._tasks:
            task.cancel()

    def start(self):
        """
        Start the pipeline by executing the `execute` method.

        通过执行 `execute` 方法来启动管道。
        """
        self.set_stream(ExecuteDataStream(), ExecuteDataStream())
        self.execute()
        return self

    def execute(self):
        """
        Execute the pipeline by setting up the streams for each processor and creating tasks for their execution.

        通过为每个处理器设置数据流并为其执行创建任务来执行管道。
        """
        if self.__is_executed:
            return
        self.__is_executed = True
        for i, processor in enumerate(self._processors):
            processor.set_stream(self._get_stream(i), self._get_stream(i + 1))
            self._tasks.append(
                asyncio.create_task(processor.process())
            )

    async def process(self):
        """
        Execute the pipeline and wait for all tasks to complete.

        执行管道并等待所有任务完成。
        """
        self.execute()
        await asyncio.gather(*self._tasks)


# Example processor
class ExampleProcessor(ExecuteProcessor):

    async def execute(self, data):
        return data + " processed"


async def main():
    processors: list[ExecuteProcessor] = [ExampleProcessor()]
    pipeline = ExecutePipeline(processors)
    pipeline.start()

    input_data = "Hello, world!"
    await pipeline.write(input_data)
    print("Pipeline started.")

    async def fun():
        async for data in pipeline.iter_read():
            print(data)
    task = asyncio.create_task(fun())

    while True:
        data = input("input data: ")
        await pipeline.write(data)
        await asyncio.sleep(0)
        # print(await pipeline.read())


# Usage example
if __name__ == "__main__":
    asyncio.run(main())
@@ -1,15 +0,0 @@
from sys import stdout
from typing import Any, AsyncGenerator
from .core import ExecuteProcessor


class EchoProcessor(ExecuteProcessor):
    async def execute(self, data):
        stdout.write(data)
        stdout.flush()
        return None


class TeeProcessor(ExecuteProcessor):
    async def execute(self, data):
        stdout.write(data)
        stdout.flush()
        return data
@@ -2,7 +2,6 @@
import json
from typing import Any
from src import ASREngine as ASREngine

from http import HTTPStatus
import dashscope
@@ -19,12 +18,11 @@ logger = getLogger(__name__)
# with open('asr_example.wav', 'wb') as f:
#     f.write(r.content)

class ASRParaformer(ASREngine):
class ASRParaformer():
    def __init__(self, api_key: str,
                 _model: str = "paraformer-realtime-v1",
                 _is_stream: bool = False,
                 *args, **kwargs) -> None:
        super().__init__()
        dashscope.api_key = api_key
        self.recognition = Recognition(kwargs.get("model", _model),
                                       format='wav',
@@ -34,7 +32,6 @@ class ASRParaformer(ASREngine):
    def process_output(self, data):
        corrected_data = data.replace("'", '"')
        res = json.loads(corrected_data)["text"]
        return super().process_output(res)

    def execute(self, data: Any) -> Any:
        result = self.recognition.call(data)
@@ -1,16 +1,16 @@
# https://help.aliyun.com/zh/dashscope/developer-reference/quick-start-13?spm=a2c4g.11186623.0.0.26772e5cs8Vl59

import asyncio
import sys
from typing import Any, Generator, Iterable, Iterator
from src import TTSEngine as TTSEngine

import dashscope
from dashscope.api_entities.dashscope_response import SpeechSynthesisResponse
from dashscope.audio.tts import ResultCallback, SpeechSynthesizer, SpeechSynthesisResult
from aiostream import Stream, stream

from logging import getLogger

from src import StreamEngine as StreamEngine
logger = getLogger(__name__)

# import requests
@@ -20,23 +20,20 @@ logger = getLogger(__name__)
# with open('asr_example.wav', 'wb') as f:
#     f.write(r.content)

class TTSSambert(TTSEngine, StreamEngine):
class TTSSambert():

    def __init__(self, api_key: str,
                 _model: str = "sambert-zhichu-v1",
                 _is_stream: bool = False,
                 *args, **kwargs) -> None:
        TTSEngine.__init__(self)
        StreamEngine.__init__(self,
                              use_stream_api=kwargs.get('use_stream_api', False),
                              return_as_stream=kwargs.get('return_as_stream', False))
        dashscope.api_key = api_key
        self.model = kwargs.get("model", _model)

    class Callback(ResultCallback):
        def __init__(self, generator) -> None:
        def __init__(self) -> None:
            self.res = []
            super().__init__()
            self.generator = generator

        def on_open(self):
            logger.debug('Speech synthesizer is opened.')
@@ -46,38 +43,35 @@ class TTSSambert(TTSEngine, StreamEngine):
        def on_error(self, response: SpeechSynthesisResponse):
            logger.error('Speech synthesizer failed, response is %s' % (str(response)))

        def on_close(self):
            self.generator.send(None)
            logger.debug('Speech synthesizer is closed.')

        def on_event(self, result: SpeechSynthesisResult):
            if result.get_audio_frame() is not None:
                logger.debug('audio result length:', sys.getsizeof(result.get_audio_frame()))
                res = result.get_audio_frame()
                self.generator.send(res)
                self.res.append(result.get_audio_frame())

            if result.get_timestamp() is not None:
                logger.debug('timestamp result:', str(result.get_timestamp()))

    async def generate_events(self, callback: Callback):
        for event in callback.res:
            yield event

    def execute_nonstream(self, data) -> bytes:
        result = SpeechSynthesizer.call(model=self.model,
                                        text=data,
                                        sample_rate=48000)
        return result.get_audio_data()

    # def execute_stream(self, data) -> Generator[Any, None, None] | Iterator:
    #     pass
    #     if self._is_stream:
    #         def audio_generator():
    #             while True:
    #                 data = yield
    #                 if data is None:
    #                     break
    #         gen = audio_generator()
    #         next(gen)
    #         callback = self.Callback(gen)
    #         SpeechSynthesizer.call(model=self.model,
    async def process_events(self, callback: Callback) -> Stream[bytes]:
        source = self.generate_events(callback)
        return stream.iterate(source)

    # def execute_nonstream(self, data) -> bytes:
    #     result = SpeechSynthesizer.call(model=self.model,
    #                                     text=data,
    #                                     sample_rate=48000,
    #                                     callback=callback,
    #                                     word_timestamp_enabled=True,
    #                                     phoneme_timestamp_enabled=True)
    #                                     sample_rate=48000)
    #     return result.get_audio_data()

    def execute_stream(self, data, *args):
        callback = self.Callback()
        SpeechSynthesizer.call(model=self.model,
                               text=data,
                               sample_rate=48000,
                               callback=callback,
                               word_timestamp_enabled=True,
                               phoneme_timestamp_enabled=True)
        return self.generate_events(callback)
src/dashscope/tongyiqianwen.py (new file, 78 lines)
@@ -0,0 +1,78 @@
# https://dashscope.console.aliyun.com/model
import json
import aiohttp

from logging import getLogger
logger = getLogger(__name__)


class CAITongYiQianWen():
    API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

    def __init__(self, api_key: str,
                 _model: str = "qwen-1.8b-chat",
                 *args, **kwargs) -> None:
        self.api_key = api_key
        self._modual_name = kwargs.get("model", _model)
        self.use_stream_api = kwargs.get("use_stream_api", False)

    def _transform_raw(self, response) -> str:
        return response["output"]["choices"][0]["message"]["content"]

    def transform_chunks(self, response: str, *args):
        current_msg = ''
        for line in response.splitlines():
            if line.startswith("data:"):
                current_msg += line[5:].strip()
            elif len(line) == 0:
                if current_msg:
                    return self._transform_raw(json.loads(current_msg))

    async def post_prompt(self, messages: dict = {}, is_stream: bool = False, header: dict[str, str] = {}):
        try:
            async with aiohttp.ClientSession() as client:
                # build headers once; add the SSE header only for streaming requests
                headers = {
                    'Content-Type': 'application/json',
                    'Authorization': 'Bearer ' + self.api_key,
                    **header,
                }
                if is_stream:
                    headers['X-DashScope-SSE'] = 'enable'
                async with client.post(self.API_URL,
                                       headers=headers,
                                       json=messages) as response:
                    if is_stream:
                        async for chunk in response.content.iter_any():
                            yield chunk.decode('utf-8')
                    else:
                        data = await response.read()
                        yield data.decode('utf-8')
        except Exception as e:
            logger.exception("post_prompt")
            raise e

    async def execute_stream(self, data):
        response = self.post_prompt({
            'model': self._modual_name,
            "input": {
                "messages": data
            },
            "parameters": {
                "result_format": "message",
                "incremental_output": True
            }
        }, is_stream=True)
        return response

    async def execute_nonstream(self, data):
        response = self.post_prompt({
            'model': self._modual_name,
            "input": {
                "messages": data
            },
            "parameters": {
                "result_format": "message",
                "incremental_output": False
            }
        })
        # ret = response.json()
        # return self._transform_raw(ret)

    def prepare_input(self, data):
        return [{'role': 'user', 'content': f'{data}'}]
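For orientation, a minimal, hypothetical usage sketch of the new `CAITongYiQianWen` class above (the `DASH_SCOPE_API_KEY` variable matches main.py in this commit; the driver loop itself is illustrative only, not part of the repository):

    import asyncio
    import os
    from src.dashscope.tongyiqianwen import CAITongYiQianWen

    async def demo():
        cai = CAITongYiQianWen(api_key=os.environ.get("DASH_SCOPE_API_KEY", ""),
                               use_stream_api=True)
        # execute_stream returns an async generator of raw SSE text chunks;
        # transform_chunks extracts the incremental message content from one chunk.
        async for chunk in await cai.execute_stream(cai.prepare_input("hello")):
            text = cai.transform_chunks(chunk)
            if text:
                print(text, end="", flush=True)

    asyncio.run(demo())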
@@ -1,5 +0,0 @@
from .. import Engine

class ASREngine(Engine):
    def __init__(self) -> None:
        super().__init__()
@@ -1,10 +0,0 @@
from typing import Dict, Optional
from logging import getLogger
logger = getLogger(__name__)
from .. import Engine


class ChatAIEngine(Engine):
    """Base AI engine class; provides the basic framework for chat-history management."""
    def __init__(self, history: Optional[Dict] = None) -> None:
        self._history = history
        Engine.__init__(self)
@@ -1,5 +0,0 @@
from .. import Engine

class NLUEngine(Engine):
    def __init__(self, is_stream: bool = False) -> None:
        super().__init__(is_stream)
@@ -1,56 +0,0 @@
from typing import Any, Generator, Iterable, Iterator
from .. import Engine


class StreamEngine(Engine):
    def __init__(self, use_stream_api: bool = False, return_as_stream: bool = False):
        self._use_stream_api = use_stream_api
        self._return_as_stream = return_as_stream

    def execute_stream(self, data) -> Generator | Iterator:
        raise NotImplementedError

    def execute_nonstream(self, data) -> Any:
        raise NotImplementedError

    def _execute_stream(self, data: Any, return_as_stream: bool) -> Any:
        results = self.execute_stream(data)
        return results if return_as_stream else list(results)

    def _execute_nonstream(self, data: Any, return_as_stream: bool) -> Any:
        result = self.execute_nonstream(data)
        if return_as_stream:
            def _():
                yield from result
            return _()
        else:
            return result

    def execute(self, data: Any) -> Any:
        if not isinstance(data, Generator) or isinstance(data, bytes):
            if self._use_stream_api:
                return self._execute_stream([data], self._return_as_stream)
            else:
                return self._execute_nonstream(data, self._return_as_stream)
        else:
            if self._use_stream_api:
                if self._return_as_stream:
                    def stream_results():
                        for item in data:
                            yield from self._execute_stream([item], True)
                    return stream_results()
                else:
                    return [self._execute_stream([item], False) for item in data]
            else:
                if self._return_as_stream:
                    def non_stream_results():
                        for item in data:
                            yield self._execute_nonstream(item, False)
                    return non_stream_results()
                else:
                    res = self._execute_nonstream(next(data), False)
                    for item in data:
                        _ = self._execute_nonstream(item, False)
                        if _ is None:
                            continue
                        res += _
                    return res
@@ -1,5 +0,0 @@
from .. import Engine

class TTSEngine(Engine):
    def __init__(self, is_stream: bool = False) -> None:
        super().__init__(is_stream)
@@ -1,7 +0,0 @@
import logging
logger = logging.getLogger(__name__)

from src import dynamic_package_import
dynamic_package_import([
    ('dashscope', None),
])
@@ -1,5 +0,0 @@

# from src import dynamic_package_import
# dynamic_package_import([
#     ('requests', None),
# ])
@@ -1,25 +0,0 @@
from logging import getLogger
from typing import Any, Dict
logger = getLogger(__name__)
import requests
import httpx


class RemoteEngine():
    """Base remote-service class; provides the basic framework for remote calls."""
    def __init__(self, api_key: str, base_url: str):
        self._api_key = api_key
        self._base_url = base_url
        self.url = self._base_url
        self.header = {
            'Content-Type': 'application/json',
            'Authorization': 'Bearer ' + self._api_key
        }

    async def post_prompt(self, messages: Dict) -> httpx.Response:
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(self.url, headers=self.header, json=messages)
                return response.raise_for_status()
        except Exception as e:
            logger.exception("post_prompt")
            raise e
@@ -1,74 +0,0 @@
# https://dashscope.console.aliyun.com/model
import asyncio
from concurrent.futures import ThreadPoolExecutor
import json
from typing import Any, AsyncGenerator, AsyncIterator, Generator, Iterator, override
from .remote_engine import RemoteEngine
from src import ExecuteProcessor as ExecuteProcessor

from logging import getLogger
logger = getLogger(__name__)


class ChatAiTongYiQianWen(ExecuteProcessor):
    API_URL = "https://dashscope.aliyuncs.com/api/v1/services/aigc/text-generation/generation"

    def __init__(self, api_key: str,
                 _model: str = "qwen-1.8b-chat",
                 *args, **kwargs) -> None:

        self._modual_name = kwargs.get("model", _model)
        self.use_stream_api = kwargs.get("use_stream_api", False)
        self.remote_engine = RemoteEngine(api_key=api_key,
                                          base_url=self.API_URL)

    def _transform_raw(self, response):
        return response["output"]["choices"][0]["message"]["content"]

    async def _transform_iterator(self, response: AsyncIterator[str]):
        current_msg = ''
        async for line in response:
            if line.startswith("data:"):
                current_msg += line[5:].strip()
            elif len(line) == 0:
                if current_msg:
                    yield self._transform_raw(json.loads(current_msg))
                    current_msg = ''

    async def execute_stream(self, data):
        self.remote_engine.header['X-DashScope-SSE'] = 'enable'
        response = await self.remote_engine.post_prompt({
            'model': self._modual_name,
            "input": {
                "messages": data
            },
            "parameters": {
                "result_format": "message",
                "incremental_output": True
            }
        })
        return self._transform_iterator(response.aiter_lines())

    async def execute_nonstream(self, data) -> Any:
        response = await self.remote_engine.post_prompt({
            'model': self._modual_name,
            "input": {
                "messages": data
            },
            "parameters": {
                "result_format": "message",
                "incremental_output": False
            }
        })
        ret = response.json()
        return self._transform_raw(ret)

    def prepare_input(self, data: Any) -> Any:
        return [{'role': 'user', 'content': f'{data}'}]

    @override
    async def execute(self, data):
        data = self.prepare_input(data)
        executor = ThreadPoolExecutor(max_workers=1)
        if self.use_stream_api:
            res = await self.execute_stream(data)
        else:
            res = await self.execute_nonstream(data)
        return res
@@ -1,4 +0,0 @@
from src import dynamic_package_import
dynamic_package_import([
    ('pyaudio', None),
])
@@ -1,35 +0,0 @@
from typing import Any, Generator
from src import Engine
import wave
import io
import pyaudio

from src import StreamEngine


class SoundsPlayEngine(StreamEngine):
    def __init__(self, **kwargs) -> None:
        # super().__init__(kwargs.get("is_stream", _is_stream))
        StreamEngine.__init__(self,
                              use_stream_api=kwargs.get('use_stream_api', False),
                              return_as_stream=kwargs.get('return_as_stream', False))

    def execute_nonstream(self, data) -> Any:
        if data is None:
            return None
        wf = wave.open(io.BytesIO(data), 'rb')
        _audio = pyaudio.PyAudio()
        _stream = _audio.open(
            format=_audio.get_format_from_width(wf.getsampwidth()),
            channels=wf.getnchannels(),
            rate=wf.getframerate(),
            output=True)
        data = wf.readframes(1024)  # assume a chunk size of 1024
        while data:
            _stream.write(data)
            data = wf.readframes(1024)

        _stream.stop_stream()
        _stream.close()

        _audio.terminate()
        wf.close()
@@ -1,2 +0,0 @@
class SoundsRecordEngine:
    pass
@@ -1,9 +0,0 @@
import pyaudio


class SoundsWrapper:
    def __init__(self, chunk=1024, format=pyaudio.paInt16, channels=1, rate=44100, input=True, output=True, input_device_index=None, output_device_index=None) -> None:
        self.chunk = chunk
        self.format = format
        self.channels = channels
        self.rate = rate
        self.input = input
@@ -1,61 +0,0 @@
r"""
# Example usage
```
required_packages = [
    ("numpy", "1.20.3"),  # Specify version
    ("pandas", None),     # Latest version
]

dynamic_package_import(required_packages)
```
"""

import importlib.metadata
import subprocess
import sys
from typing import List, Optional, Tuple
import logging
logger = logging.getLogger(__name__)


def install_or_upgrade_package(package: str, version: Optional[str] = None) -> None:
    """
    Installs or upgrades a Python package using pip via Popen.

    :param package: The name of the package to install or upgrade.
    :param version: Optional. The version of the package to install.
    """
    # Check if the package is already installed
    try:
        dist = importlib.metadata.distribution(package)
        current_version = dist.version
        if version and current_version != version:
            logger.info(f"Upgrading {package} to version {version}.")
            _execute_pip_command(f"install {package}=={version}")
        else:
            logger.info(f"{package} is already installed at version {current_version}.")
    except importlib.metadata.PackageNotFoundError:
        logger.info(f"{package} is not installed. Installing...")
        _execute_pip_command(f"install {package}{'==' + version if version else ''}")


def _execute_pip_command(command: str) -> None:
    """
    Executes a pip command using subprocess.Popen.

    :param command: The pip command to execute (e.g., 'install numpy').
    """
    pip_cmd = [sys.executable, "-m", "pip", *command.split()]
    with subprocess.Popen(pip_cmd, stderr=subprocess.PIPE, universal_newlines=True) as proc:
        _, stderr = proc.communicate()
        if proc.returncode != 0:
            logger.error(f"executing pip command: {command}\n\tdetails: {stderr}")
        else:
            logger.info(f"Pip command '{command}' executed successfully.")


def dynamic_package_import(required_packages: List[Tuple[str, Optional[str]]]) -> None:
    """
    Checks for the presence of required packages and installs/upgrades them if necessary.

    :param required_packages: A list of tuples, where each tuple contains a package name and optionally a version.
    """
    for package, version in required_packages:
        install_or_upgrade_package(package, version)
@@ -1,38 +0,0 @@
from pathlib import Path
import sys


CURRENT_DIR_PATH = Path(__file__).resolve().parent
sys.path.append(str(CURRENT_DIR_PATH.parent.parent.resolve()))
from src import EngineInterface
from .plugins_conf import PluginsConfig
from .plugins_import import PluginsImport


class Plugins:
    def __init__(self,
                 base_path: Path | str = CURRENT_DIR_PATH,
                 conf_name: str = ".plugins.yml"):
        base_path = Path(base_path)
        self.base_path = base_path.resolve()
        self.conf_path = base_path / conf_name

        self.config = PluginsConfig.load_config(self.conf_path)
        # global_conf: Dict = self.config.get('global')
        # self.core_path: Path = self.base_path / (global_conf.get('core_path', '') or '../')
        # self.plugin_path: Path = self.base_path / (global_conf.get('plugin_path', '') or './')
        # sys.path.append(str(self.base_path))

    def _get_engine_import(self, engine_name: str):
        engine_conf, plg_root_path, plg_path = PluginsConfig.get_engine_conf(engine_name, self.base_path, self.conf_path)
        engine_conf.get('plugin_path', '')
        plg_import = PluginsImport(plg_root_path)
        plg_import.get_module(engine_conf['class_name'], plg_path)
        return plg_import, engine_conf

    def load_engine_class(self, engine_name: str):
        plg_import, engine_conf = self._get_engine_import(engine_name)
        return plg_import.get_class(engine_conf['class_name'], engine_conf['class_name'])

    def load_engine(self, engine_name: str) -> EngineInterface:
        plg_import, engine_conf = self._get_engine_import(engine_name)
        return plg_import.get_instance(engine_conf['class_name'], engine_conf['class_name'], **engine_conf)
@@ -1,54 +0,0 @@
import os
from pathlib import Path

import re

from typing import Any, Callable, Dict, Optional, List

import yaml


class YamlEnvLoader(yaml.SafeLoader):
    def __init__(self, stream):
        super(YamlEnvLoader, self).__init__(stream)
        self.add_implicit_resolver('!env_var', re.compile(r'\$\{(.+?)\}'), None)

    @staticmethod
    def env_var_constructor(loader, node):
        value = loader.construct_scalar(node)
        match = re.match(r'\$\{(.+?)\}', value)
        if match:
            env_var = match.group(1)
            return os.environ.get(env_var, value)
        return value

YamlEnvLoader.add_constructor('!env_var', YamlEnvLoader.env_var_constructor)


class PluginsConfig:
    @staticmethod
    def load_config(path: Path) -> Dict[str, Any]:
        """Load YAML configuration from the given path."""
        if not path.exists():
            raise FileExistsError(f"{path} not exists")
        if not path.is_file():
            raise FileNotFoundError(f"Config file at {path} does not exist.")
        with open(path, 'r') as f:
            return yaml.load(f, Loader=YamlEnvLoader)

    @staticmethod
    def recursive_load(base_path: Path, config: Dict[str, Any] | Any):
        """Recursively load configurations when 'plugin' and 'path' are present."""
        if isinstance(config, dict) and 'plugin' in config and 'path' in config:
            plugin_path: Path = base_path / config.get('path', '')
            nested_config = PluginsConfig.load_config(plugin_path)
            config.update(nested_config[config['plugin']])
            plg_root_path = plugin_path.parent / nested_config.get('root_path', None)
            if plg_root_path is not None:
                plugin_path = plg_root_path / nested_config[config['plugin']]['path']
            return config, plg_root_path, plugin_path

    @staticmethod
    def get_engine_conf(engine_name, base_path: Path, config_file: Path):
        """
        Get the engine configuration.
        """
        conf = PluginsConfig.load_config(config_file)
        return PluginsConfig.recursive_load(base_path, conf.get(engine_name))
@@ -1,48 +0,0 @@
from pathlib import Path
CURRENT_DIR_PATH = Path(__file__).resolve().parent
import sys
sys.path.append(str(CURRENT_DIR_PATH.parent.parent))

import importlib
from types import ModuleType
from typing import Any


class PluginsImport:
    def __init__(self, plugins_path: Path):
        plugins_path = plugins_path.parent
        if plugins_path not in sys.path:
            sys.path.append(str(plugins_path))
        self.plugins_path = plugins_path
        self.modules = {}  # Store loaded modules here

    def get_module(self, module_name: str, module_path: Path) -> ModuleType:
        # Check if the module has already been loaded
        if module_name in self.modules:
            return self.modules[module_name]

        # Check if the module path is a subpath of the plugins path
        if not module_path.is_relative_to(self.plugins_path):
            raise ValueError("Module path must be a subpath of the plugins path.")

        relative_path_parts = module_path.relative_to(self.plugins_path).with_suffix('').parts
        relative_path = '.'.join(relative_path_parts)

        try:
            module = importlib.import_module(relative_path, package=self.plugins_path.name)
            self.modules[module_name] = module
            return module
        except ModuleNotFoundError as e:
            raise ModuleNotFoundError(f"Failed to import module '{module_name}' from path '{module_path}': {e}")

    def get_class(self, class_name, module_name) -> Any:
        module = self.modules.get(module_name, None)
        if module is None:
            raise ValueError(f"Module '{module_name}' has not been loaded.")
        cls = getattr(module, class_name, None)
        if cls is None:
            raise AttributeError(f"Class '{class_name}' not found in module '{module_name}'.")
        return cls

    def get_instance(self, _class_name, module_name, **kwargs) -> Any:
        cls = self.get_class(_class_name, module_name)
        return cls(**kwargs)
src/pyaudio/sounds_play_engine.py (new file, 23 lines)
@@ -0,0 +1,23 @@
from typing import Any, Generator
import wave
import io
import pyaudio


class SoundsPlayEngine():
    def __init__(self, **kwargs) -> None:
        self._player = pyaudio.PyAudio()
        self._stream = self._player.open(
            format=pyaudio.paInt16,
            channels=1,
            rate=48000,
            output=True)

    def execute(self, data: bytes, *args) -> Any:
        self._stream.write(data)

        # _stream.stop_stream()
        # _stream.close()

        # _audio.terminate()
        # wf.close()
@@ -1,54 +0,0 @@
import unittest
from src.core import core


class TestBaseMiddleware(unittest.TestCase):
    def test_process_base_func1(self):
        middleware = core.Middleware(lambda x: x + 1)
        self.assertEqual(middleware.process(1), 2)
        self.assertEqual(middleware.get_next(), [])

    def test_process_base_func2(self):
        mid = [core.Middleware(lambda x: x + 2)]
        middleware = core.Middleware(lambda x: x + 1, next_middleware=mid)
        self.assertEqual(middleware.process(1), 2)
        self.assertEqual(middleware._data, 4)
        self.assertIs(middleware.get_next(), mid)


class TestBaseEngine(unittest.TestCase):
    pass


class TestExecutePipeline(unittest.TestCase):
    def test_base(self):
        middleware = core.Middleware(lambda x: x + 1)
        pipeline = core.ExecutePipeline(middleware)
        self.assertEqual(pipeline.execute(1), 2)

    def test_add(self):
        middleware = core.Middleware(lambda x: x + 1)
        pipeline = core.ExecutePipeline(middleware)
        pipeline.add(core.Middleware(lambda x: x + 2))
        self.assertEqual(pipeline.execute(1), 4)

    def test_reorder(self):
        middleware = core.Middleware(lambda x: x + 1)
        pipeline = core.ExecutePipeline(middleware)
        pipeline.reorder([])
        pipeline.execute(1)

    def test_reorder1(self):
        middleware_p1 = core.Middleware(lambda x: x + 1)
        middleware_p2 = core.Middleware(lambda x: x + 2)
        middleware_p3 = core.Middleware(lambda x: x + 3)
        pipeline = core.ExecutePipeline(middleware_p1, middleware_p2, middleware_p3)
        self.assertEqual(pipeline.execute(1), 7)
        pipeline.reorder([1, 0, 2])
        self.assertEqual(pipeline.execute(1), 7)
        pipeline.reorder([0, 2])
        self.assertEqual(pipeline.execute(1), 6)
        pipeline.reorder([1])
        self.assertEqual(pipeline.execute(1), 4)

    def test_execute_none(self):
        middleware = core.Middleware(lambda x: x + 1)
        pipeline = core.ExecutePipeline(middleware)
        self.assertIsNone(pipeline.execute(None))
@@ -1,13 +0,0 @@
import unittest
from src.core import tools
from io import StringIO
import sys


class TestEcho(unittest.TestCase):
    def test_base(self):
        echoMid = tools.EchoMiddleWare()
        captured_out = StringIO()
        sys.stdout = captured_out
        echoMid.process('hello')
        sys.stdout = sys.__stdout__
        self.assertEqual(captured_out.getvalue(), 'hello\n')
Binary file not shown.
@@ -1,72 +0,0 @@
import pyaudio
import wave


class SoundsWapper:
    def __init__(self, chunk=1024, format=pyaudio.paInt16, channels=1, rate=44100):
        self.CHUNK = chunk
        self.FORMAT = format
        self.CHANNELS = channels
        self.RATE = rate
        self.p = pyaudio.PyAudio()

    def record(self, seconds, filename='output.wav'):
        stream = self.p.open(format=self.FORMAT,
                             channels=self.CHANNELS,
                             rate=self.RATE,
                             input=True,
                             frames_per_buffer=self.CHUNK)

        print("Recording started...")
        frames = []

        for i in range(0, int(self.RATE / self.CHUNK * seconds)):
            data = stream.read(self.CHUNK)
            frames.append(data)

        print("Recording finished.")

        stream.stop_stream()
        stream.close()

        wf = wave.open(filename, 'wb')
        wf.setnchannels(self.CHANNELS)
        wf.setsampwidth(self.p.get_sample_size(self.FORMAT))
        wf.setframerate(self.RATE)
        wf.writeframes(b''.join(frames))
        wf.close()

    def play(self, filename='output.wav'):
        wf = wave.open(filename, 'rb')

        stream = self.p.open(format=self.p.get_format_from_width(wf.getsampwidth()),
                             channels=wf.getnchannels(),
                             rate=wf.getframerate(),
                             output=True)

        print("Playback started...")
        data = wf.readframes(self.CHUNK)

        while data:
            stream.write(data)
            data = wf.readframes(self.CHUNK)

        print("Playback finished.")

        stream.stop_stream()
        stream.close()

    def close(self):
        self.p.terminate()


# Usage example
if __name__ == "__main__":
    sw = SoundsWapper()

    # Record for 5 seconds
    sw.record(5)

    # Play back the recording
    sw.play()

    # Clean up resources
    sw.close()