484 lines
16 KiB
Python
484 lines
16 KiB
Python
import os
|
||
import re
|
||
import sys
|
||
import json
|
||
import time
|
||
import concurrent.futures
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
from typing import Callable, Any
|
||
|
||
import pytest
|
||
import yaml
|
||
from glrocky.core.logger import logger
|
||
from glrocky.framework.marks import Marks as M
|
||
from glrocky.framework.schemas import Device
|
||
from glrocky.services.dify.dify import run_workflow
|
||
from loguru._logger import Logger
|
||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, create_model
|
||
import socket
|
||
|
||
# Env var that, when set, overrides production/testing auto-detection
# (truthy values: "true"/"yes"/"1", case-insensitive — see _is_production_mode).
PRODUCTION_ENV_NAME = "GLROCKY_PRODUCTION_MODE"
# IP address of the production executor host; a matching local IP implies
# production mode when the env var above is unset.
PRODUCTION_EXECUTOR = "192.168.0.139"
# YAML file (sibling of this module) declaring the test cases to generate.
YAML_FILE = Path(__file__).parent / "test_cases.yaml"
||
def _get_local_ip() -> str:
    """Return the local IPv4 address used for outbound traffic.

    "Connects" a UDP socket to a public address: no packet is actually sent,
    but the OS selects the outbound interface, whose address is then read
    back via getsockname().

    Returns:
        The local IP address as a string.
    """
    # Context manager guarantees the socket is closed on every path
    # (replaces the original try/finally + close()).
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        ip: str = s.getsockname()[0]
        return str(ip)
|
||
|
||
|
||
def _is_production_mode() -> bool:
    """Decide whether we are running on the production executor.

    The GLROCKY_PRODUCTION_MODE env var, when present, wins; otherwise the
    local IP is compared against the known production executor address.
    """
    override = os.getenv(PRODUCTION_ENV_NAME)
    if override is None:
        # No explicit override — fall back to IP-based detection.
        return _get_local_ip() == PRODUCTION_EXECUTOR
    return override.lower() in ("true", "yes", "1")
|
||
|
||
|
||
def pre_process_material(input: list[str]) -> str:
    """Split question-prefixed material texts and JSON-encode the fragments.

    Each text is split on "Q<n>:"-style prefixes (see _Q_PATTERN); blank
    fragments are dropped and the remaining ones are stripped.

    Args:
        input: List of raw material strings (parameter name kept for
            backward compatibility although it shadows the builtin).

    Returns:
        A JSON array string (non-ASCII characters preserved).

    Raises:
        TypeError: If *input* is not a list.
    """
    if not isinstance(input, list):
        # Explicit check instead of `assert`: asserts are stripped under -O.
        raise TypeError(f"expected list, got {type(input).__name__}")
    if not input:
        return json.dumps([], ensure_ascii=False)
    # maxsplit=50 caps the number of question fragments per text.
    fragments = [
        part.strip()
        for text in input
        for part in _Q_PATTERN.split(text, maxsplit=50)
        if part.strip()
    ]
    return json.dumps(fragments, ensure_ascii=False)
|
||
|
||
|
||
# Evaluated once at import time. NOTE(review): this may open a UDP socket to
# derive the local IP when the env override is unset — a module-import side
# effect; confirm it is acceptable in all import contexts.
IS_PRODUCTION = _is_production_mode()

# Maps raw Dify output keys to "name|label" metric identifiers; the part
# after "|" is the human-readable (Chinese) label split off by
# MetricManager._smart_name_label. Unknown keys pass through unchanged.
METRIC_FIELD_MAPPING: dict[str, str] = {
    "result": "result|结果",
    "_dify_result": "_dify_result|业务完成",
    "_dify_message": "_dify_message|业务错误信息",
    "resultText": "resultText|文本结果",
    "recordVideo": "recordVideo|录像",
    "recordAudio": "recordAudio|录音",
    "screenshotList": "screenshotList|截图",
    "firstToken": "firstToken|首字符时长",
    "timeSeries": "timeSeries|时间序列",
    "fileList": "fileList|文件",
}

# File suffixes (lowercase, with dot) used to classify Path metric values.
IMAGE_EXTENSIONS = {".jpg", ".png", ".gif", ".bmp"}
VIDEO_EXTENSIONS = {".mp4", ".mkv"}
# Matches "Q<digits>" followed by a full/half-width colon or dots, used to
# split multi-question material text into individual questions.
_Q_PATTERN = re.compile(r"Q\d+\s*[::\.]+\s*", flags=re.MULTILINE)
|
||
|
||
|
||
@dataclass
|
||
class _TimeoutCtx:
|
||
client: Any = None
|
||
task_id: str | None = None
|
||
|
||
def stop(self):
|
||
if self.client and self.task_id:
|
||
try:
|
||
logger.info(f"Stopping workflow: {self.task_id}")
|
||
self.client.stop_workflow(self.task_id)
|
||
logger.info(f"Stopped workflow: {self.task_id}")
|
||
except Exception as e:
|
||
logger.debug(f"Stop workflow failed: {e}")
|
||
|
||
|
||
class MetricManager:
|
||
def __init__(self, metric):
|
||
self.metric = metric
|
||
|
||
def new_span(self, m_type: str = "default", m_id: str = "default", m_iter: int = 1):
|
||
self.metric.span(m_type, m_id, m_iter)
|
||
|
||
def _smart_name_label(self, name: str, label: str | None) -> tuple[str, str]:
|
||
"""retruns: (name,label)"""
|
||
if not name or not name.strip():
|
||
raise ValueError("Empty name")
|
||
|
||
if label:
|
||
return name, label
|
||
sep = next((s for s in ("|", ":") if s in name), None)
|
||
if sep:
|
||
_n, _l = name.split(sep, maxsplit=1)
|
||
return (_n, _l)
|
||
return name, name
|
||
|
||
def _add_metric(self, name: str, value: Any, m_type: str, label: str | None = None):
|
||
_n, _l = self._smart_name_label(name, label)
|
||
self.metric.add(name=_n, label=_l, value=value, type=m_type)
|
||
|
||
def add_text_metric(self, name: str, value: str, label: str | None = None):
|
||
self._add_metric(name, value, "text", label)
|
||
|
||
def add_number_metric(self, name: str, value: float, label: str | None = None):
|
||
self._add_metric(name, value, "number", label)
|
||
|
||
def add_image_metric(self, name: str, value: Path | str, label: str | None = None):
|
||
self._add_metric(name, str(value), "image", label)
|
||
|
||
def add_video_metric(self, name: str, value: Path | str, label: str | None = None):
|
||
self._add_metric(name, value, "video", label)
|
||
|
||
def send(self):
|
||
self.metric.send_all()
|
||
|
||
def from_dict(self, the_dict: dict[str, Any]):
|
||
if not the_dict:
|
||
raise ValueError("dict is null")
|
||
for _k, v in the_dict.items():
|
||
k = METRIC_FIELD_MAPPING.get(_k, _k)
|
||
logger.info(k)
|
||
|
||
match v:
|
||
case None:
|
||
self.add_text_metric(name=k, value="")
|
||
case int() | float():
|
||
self.add_number_metric(name=k, value=v)
|
||
case str():
|
||
self.add_text_metric(name=k, value=v)
|
||
case Path():
|
||
match v.suffix.lower():
|
||
case s if s in IMAGE_EXTENSIONS:
|
||
self.add_image_metric(name=k, value=str(v.resolve()))
|
||
case s if s in VIDEO_EXTENSIONS:
|
||
self.add_video_metric(name=k, value=str(v.resolve()))
|
||
case _:
|
||
self.add_text_metric(name=k, value=str(v.resolve()))
|
||
case _:
|
||
raise TypeError(f"{type(v)} not supported")
|
||
|
||
|
||
class DifySettings(BaseModel):
    """Connection settings for one Dify workflow endpoint."""

    # Base URL of the Dify service.
    difyUrl: str = Field(title="Dify服务地址")
    # Identifier of the workflow to run.
    difyWorkflowId: str = Field(title="Dify工作流ID")
    # API key used to authenticate against the service.
    difyApiKey: str = Field(title="Dify API密钥")
|
||
|
||
|
||
class MaterialForDify(BaseModel):
    """One group of test material handed to the Dify workflow."""

    # Allow non-pydantic extra value types (e.g. uuid objects).
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Parameter-group UUID(s); the first entry identifies the group.
    paramGroupUuid: list[str] = Field(
        default_factory=list, title="UUID", description="UUID"
    )
    # Conversation turns; multi-turn dialogs are supported.
    inputTextList: list[str] = Field(
        default_factory=list,
        title="对话列表,支持多轮对话",
        description="对话列表,支持多轮对话",
    )
    # Prompt text(s) for the workflow.
    prompt: list[str] = Field(
        default_factory=list, title="提示词", description="提示词"
    )
|
||
|
||
|
||
class DifyPayload(BaseModel):
    """Wire payload for run_workflow; list fields are flattened on serialization."""

    inputTextList: list[str] | str
    prompt: list[str] | str
    deviceId: str
    address: str
    caseId: str
    parameters: dict[str, Any] | str | None = None

    @field_serializer("inputTextList")
    def serialize_input_text(self, value: list[str] | str) -> str:
        # Lists are split/cleaned and JSON-encoded; strings pass through as-is.
        return pre_process_material(value) if isinstance(value, list) else value

    @field_serializer("prompt")
    def serialize_prompt_as_single_str(self, value: list[str] | str) -> str:
        # Only the first prompt is sent; an empty list serializes to "".
        if not isinstance(value, list):
            return value
        return value[0] if value else ""
|
||
|
||
|
||
def create_dynamic_model(model_name: str, parameters: list[dict]) -> type[BaseModel]:
    """Build a pydantic model class from YAML parameter definitions.

    Each entry must provide a non-empty string 'name'. 'label' defaults to
    the name, 'description' to the label, 'type' to "string" (the only
    supported type), and 'defaultValue' to "".

    Args:
        model_name: Name for the generated model class.
        parameters: List of parameter-definition dicts from the YAML file.

    Returns:
        The generated BaseModel subclass (instantiable with no arguments).

    Raises:
        ValueError: On a malformed definition, an unsupported type, or any
            failure while creating/instantiating the model.
    """
    field_definitions = {}
    for idx, param in enumerate(parameters):
        name: str = (
            param.get("name", "").strip() if isinstance(param.get("name"), str) else ""
        )
        if not name:
            raise ValueError(f"参数定义错误: 第 {idx + 1} 个参数缺少 'name' 字段或为空")

        label = (
            param.get("label", "").strip()
            if isinstance(param.get("label"), str)
            else name
        )
        description = (
            param.get("description", label).strip()
            if isinstance(param.get("description"), str)
            else label
        )
        param_type = (
            param.get("type", "string").strip()
            if isinstance(param.get("type"), str)
            else "string"
        )

        if param_type not in ("string",):
            raise ValueError(
                f"参数 '{name}' 不支持的类型 '{param_type}',目前仅支持 'string'"
            )

        default = param.get("defaultValue", "")
        if default and isinstance(default, str):
            default = default.strip()

        field_definitions[name] = (
            str,
            Field(default=default, title=label, description=description),
        )

    try:
        model = create_model(model_name, __base__=BaseModel, **field_definitions)
        # NOTE(review): mutating model_config after create_model may not
        # affect the already-built validation schema in pydantic v2 — confirm
        # that extra fields are actually ignored at validation time.
        model.model_config["extra"] = "ignore"

        # Smoke-check: every field must have a usable default.
        model()
        return model
    except Exception as e:
        # FIX: chain the cause so the original traceback is preserved.
        raise ValueError(f"创建参数模型 '{model_name}' 失败: {e}") from e
|
||
|
||
|
||
def call_dify(
    case_meta_info: dict[str, str],
    logger: Logger,
    device_info: Device,
    material: list[MaterialForDify],
    dify_cfg: DifySettings,
    metric,
    material_reporter,
    parameters: dict[str, Any] | None = None,
    timeout_sec: int | None = None,
):
    """Run the configured Dify workflow once per material group and report metrics.

    For each item in `material`: builds a DifyPayload, runs the workflow
    (optionally bounded by `timeout_sec`), converts the workflow outputs
    into metrics and submits them. Per-item failures are logged and the
    loop continues; `material_reporter` always receives begin/end events.

    Raises:
        RuntimeError: if `material` is empty.
    """
    # NOTE(review): `logger` shadows the module-level logger import on purpose
    # (a per-test logger fixture is passed in).
    # logger.info(device_info.device_serial)
    logger.info(dify_cfg)

    event_callbacks: dict[str, Callable[..., None]] = {}

    # Progress logging for individual workflow nodes.
    def on_node_started(_, __, d: dict[str, Any]):
        logger.info(f"开始执行:{d.get('title', '')}")
        logger.info(f"输入节点参数:{d.get('inputs')}")

    def on_node_finished(_, __, d):
        logger.info(f"结束执行:{d.get('title', '')}")
        logger.info(f"节点输出:{d.get('outputs')}")

    event_callbacks["on_node_started"] = on_node_started
    event_callbacks["on_node_finished"] = on_node_finished

    if timeout_sec:
        # Capture the client/task id as soon as the workflow starts so a
        # timeout can cancel the still-running task (see except block below).
        ctx = _TimeoutCtx()

        def on_workflow_started(client, event_name, event_data):
            logger.info("开始执行dify任务")
            ctx.client = client
            ctx.task_id = event_data.get("task_id")

        event_callbacks["on_workflow_started"] = on_workflow_started
    if not material:
        raise RuntimeError("缺少素材")
    else:
        logger.info(f"下发素材:{material}")
    metric_manager = MetricManager(metric)
    local_ip = _get_local_ip()

    # material_index is 1-based: used both for span ids and human-readable logs.
    for material_index, item in enumerate(material, 1):
        paramGroupUuid = item.paramGroupUuid[0]
        assert paramGroupUuid, "The param Group Uuid Value must be set."

        material_reporter.begin(paramGroupUuid)
        dify_final_status = False
        try:
            # IMPORTANT: the group uuid for params
            payload_data = {
                **item.model_dump(),
                "deviceId": device_info.device_serial,
                "address": local_ip,
                "caseId": case_meta_info.get("id", "Unknown"),
            }
            if parameters:
                payload_data["parameters"] = json.dumps(parameters)
            payload = DifyPayload(**payload_data)
            logger.info(f"payload send to dify:\n{payload.model_dump_json(indent=2)}\n")
            metric_manager.new_span(
                "default", f"default-{material_index}", material_index
            )

            def _run_workflow():
                return run_workflow(
                    api_key=dify_cfg.difyApiKey,
                    base_url=dify_cfg.difyUrl,
                    workflow_id=dify_cfg.difyWorkflowId,
                    inputs=payload.model_dump(exclude_none=True),
                    event_callbacks=event_callbacks,
                )

            if timeout_sec:
                # Run in a single-worker pool so the result wait can time out.
                # NOTE(review): the `with` block only exits after the worker
                # thread finishes — a timed-out workflow keeps running until
                # ctx.stop() takes effect; confirm this shutdown behavior.
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                    future = executor.submit(_run_workflow)
                    try:
                        result = future.result(timeout=timeout_sec)
                    except concurrent.futures.TimeoutError:
                        logger.error(f"Workflow timeout after {timeout_sec}s,(yaml config)")
                        # Brief grace period before attempting cancellation.
                        time.sleep(0.1)
                        logger.info("尝试停止工作流...")
                        ctx.stop()
                        raise TimeoutError(f"Workflow timeout after {timeout_sec}s")
            else:
                result = _run_workflow()
            logger.info(f"工作流返回结果:{result.outputs}")
            logger.info(f"工作流最终状态:{result.success}")

            if result.outputs:
                logger.info("提交结果到执行器")
                metric_manager.from_dict(result.outputs)
                metric.send_all()
                logger.info("提交结果到执行器完成")
            else:
                logger.warning("dify 工作流无返回")
            if not result.success:
                logger.error(f"工作流执行失败:{result.error}")
            dify_final_status = result.success
        except Exception as e:
            # Per-item isolation: log the failure and move to the next
            # material group instead of aborting the whole run.
            logger.error(e)
            dify_final_status = False
            continue  # do not break next round
        finally:
            # Always report the final status for this material group.
            material_reporter.end_with(paramGroupUuid, dify_final_status)
|
||
|
||
|
||
def make_dify_test(
    meta: dict[str, str],
    dify_settings: DifySettings | None = None,
    param_model: type[BaseModel] | None = None,
    timeout: int | None = None,
) -> Callable[..., None]:
    """Create a pytest test function that drives call_dify.

    Args:
        meta: Case metadata (id/description) attached via M.meta.
        dify_settings: Endpoint settings. NOTE(review): may be None and is
            passed through to call_dify unguarded — confirm a default is
            applied downstream.
        param_model: Pydantic model for case parameters; an empty model is
            generated when omitted.
        timeout: Optional per-run timeout in seconds.

    Returns:
        The decorated test function (fixtures: logger, device_info,
        material, metric, material_reporter, cfg).
    """
    if param_model is None:
        param_model = create_model("EmptyParams", __base__=BaseModel)

    @pytest.mark.usefixtures("cfg")
    def _func_impl(
        logger: Logger,
        device_info: Device,
        material: list[MaterialForDify],
        metric,
        material_reporter,
        # Annotation is the dynamically-built model so the cfg fixture can
        # validate/parse the case parameters.
        cfg: param_model,
    ) -> None:
        param_dict = cfg.model_dump() if cfg else {}
        call_dify(
            meta,
            logger,
            device_info,
            material,
            dify_cfg=dify_settings,
            metric=metric,
            material_reporter=material_reporter,
            parameters=param_dict,
            timeout_sec=timeout,
        )

    # always use cfg fixture
    # _func_impl = pytest.mark.usefixtures("cfg")(_func_impl)
    return M.meta(**meta)(_func_impl)
|
||
|
||
|
||
def make_skip_test(meta: dict[str, str]) -> Callable[..., None]:
    """Create a placeholder test that is always skipped as unimplemented."""

    def _func(
        logger: Logger,
        device_info: Device,
        material: list[MaterialForDify],
        metric,
    ) -> None:
        # Defensive failure in case the skip mark is ever removed.
        logger.error(f"此用例在手机{device_info.device_serial}上暂未实现")
        assert False, "not implemented"

    # Apply meta first, then the skip mark (same order as stacked decorators).
    decorated = M.meta(**meta)(_func)
    return M.skip(reason=f"{meta['id']} 未实现")(decorated)
|
||
|
||
|
||
def generate_cases_from_yaml(module_name: str, yaml_path: Path):
    """Read case definitions from YAML and inject test functions into a module.

    For each entry under the top-level "cases" key, builds a test function
    named test_<case_id> according to its 'action' ("dify", "skipped", or
    "custom"/unknown which are skipped) and sets it on `module_name`.

    Args:
        module_name: Target module (typically __name__) to receive the tests.
        yaml_path: Path to the case-definition YAML file.
    """
    if not yaml_path.exists():
        logger.warning(f"Test case definition file not found: {yaml_path}")
        return

    try:
        with open(file=yaml_path, mode="r", encoding="utf-8") as f:
            # The document root is a mapping holding a "cases" list (see
            # .get("cases") below) — annotated as a dict accordingly.
            all_cases: dict[str, Any] = yaml.safe_load(  # pyright: ignore[reportAny]
                f
            )
    except yaml.YAMLError as e:
        logger.error(f"Failed to parse YAML file {yaml_path}: {e}")
        pytest.fail(f"YAML解析失败: {yaml_path}")
        # pytest.fail raises, so this return is effectively unreachable.
        return
    if not all_cases:
        return
    for case_info in all_cases.get("cases", []):
        if not isinstance(case_info, dict):
            logger.debug(f"Skipping non-dictionary item in YAML file: {case_info}")
            continue
        case_id: str = case_info.get("id", "")
        if not case_id:
            logger.warning(f"用例 缺少 case_id 字段。{case_info=}")
            pytest.fail(f"用例 缺少 case_id 字段。{case_info=}")
        description: str = case_info.get("description", "")
        if not description:
            logger.warning(f"用例{case_id} 缺少 description 字段。")
            pytest.fail(f"用例{case_id} 缺少 description 字段。")
        # Default action is "skipped" so undeclared cases fail visibly.
        action: str = case_info.get("action", "skipped")

        meta = {"id": case_id, "description": description}
        fn_name = f"test_{case_id.lower().replace('-', '_')}"

        if action == "dify":
            dify_settings: DifySettings | None = None
            param_model: type[BaseModel] | None = None
            timeout = case_info.get("timeout")

            if parameters_def := case_info.get("parameters"):
                # Build a pydantic model so the cfg fixture can validate params.
                model_name = f"{case_id.replace('-', '_')}_Params"
                param_model = create_dynamic_model(model_name, parameters_def)

            if dify_config_block := case_info.get("dify_config"):
                # Pick the environment-specific endpoint config.
                env_key = "production" if IS_PRODUCTION else "testing"
                if env_config := dify_config_block.get(env_key):
                    dify_settings = DifySettings(
                        difyUrl=env_config.get("url"),
                        difyWorkflowId=env_config.get("workflow_id"),
                        difyApiKey=env_config.get("api_key"),
                    )
                else:
                    # NOTE(review): dify_settings stays None here despite the
                    # "use default" message — confirm a default exists downstream.
                    logger.warning(
                        f"No '{env_key}' config found for {case_id}, will use default."
                    )
            fn = make_dify_test(meta, dify_settings, param_model, timeout)
        elif action == "skipped":
            fn = make_skip_test(meta)
        elif action == "custom":
            # Custom cases are implemented by hand elsewhere in the module.
            logger.info(f"Case {case_id} is a custom test.")
            continue
        else:
            logger.warning(f"Unknown action '{action}' for case {case_id}. Skipping.")
            continue

        setattr(sys.modules[module_name], fn_name, fn)
        logger.info(f"Generated {fn_name} for {case_id} with action '{action}'")
|
||
|
||
|
||
|
||
# Module-import side effect: materialize the YAML-defined test functions
# into this module so pytest collection can discover them.
generate_cases_from_yaml(__name__, YAML_FILE)
|