Files
xiaomi/scenarios/tc_generator.py
lmflash f4f51b5a1f init
2026-04-22 11:43:10 +08:00

484 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import re
import sys
import json
import time
import concurrent.futures
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Any
import pytest
import yaml
from glrocky.core.logger import logger
from glrocky.framework.marks import Marks as M
from glrocky.framework.schemas import Device
from glrocky.services.dify.dify import run_workflow
from loguru._logger import Logger
from pydantic import BaseModel, ConfigDict, Field, field_serializer, create_model
import socket
# Env var that, when set, forces production mode on/off regardless of local IP.
PRODUCTION_ENV_NAME = "GLROCKY_PRODUCTION_MODE"
# IP address of the production executor host; a matching local IP implies production.
PRODUCTION_EXECUTOR = "192.168.0.139"
# Test-case definitions live next to this module.
YAML_FILE = Path(__file__).parent / "test_cases.yaml"
def _get_local_ip() -> str:
    """Return the local IP address used for outbound traffic.

    "Connecting" a UDP socket to a public address sends no packet; it only
    makes the kernel choose the outbound interface, whose address is then
    read back with getsockname().
    """
    # `with` guarantees the socket is closed even if connect() raises,
    # replacing the manual try/finally.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
        s.connect(("8.8.8.8", 80))
        return str(s.getsockname()[0])
def _is_production_mode() -> bool:
    """Decide whether this run targets the production environment.

    The GLROCKY_PRODUCTION_MODE environment variable, when present, wins;
    otherwise fall back to comparing the local IP with the known
    production-executor host.
    """
    override = os.getenv(PRODUCTION_ENV_NAME)
    if override is None:
        return _get_local_ip() == PRODUCTION_EXECUTOR
    return override.lower() in ("true", "yes", "1")
def pre_process_material(input: list[str]) -> str:
    """Split material strings on "Qn:" question markers and JSON-encode them.

    Each element of *input* is split on `_Q_PATTERN` (markers such as
    "Q1:"), blank fragments are dropped, and the stripped parts are
    returned as a JSON array string (non-ASCII kept as-is).

    Raises:
        TypeError: if *input* is not a list.
    """
    # Explicit raise instead of `assert`: asserts are stripped under -O.
    if not isinstance(input, list):
        raise TypeError(f"expected list[str], got {type(input).__name__}")
    # Comprehension replaces the append loop; an empty input naturally
    # yields "[]", so no special case is needed.
    parts = [
        fragment.strip()
        for text in input
        for fragment in _Q_PATTERN.split(text, maxsplit=50)
        if fragment.strip()
    ]
    return json.dumps(parts, ensure_ascii=False)
IS_PRODUCTION = _is_production_mode()
# Maps raw Dify output keys to "name|label" identifiers; MetricManager
# splits on "|" to get the metric name and its display label.
METRIC_FIELD_MAPPING: dict[str, str] = {
    "result": "result|结果",
    "_dify_result": "_dify_result|业务完成",
    "_dify_message": "_dify_message|业务错误信息",
    "resultText": "resultText|文本结果",
    "recordVideo": "recordVideo|录像",
    "recordAudio": "recordAudio|录音",
    "screenshotList": "screenshotList|截图",
    "firstToken": "firstToken|首字符时长",
    "timeSeries": "timeSeries|时间序列",
    "fileList": "fileList|文件",
}
# File suffixes treated as image/video when reporting Path-valued metrics.
IMAGE_EXTENSIONS = {".jpg", ".png", ".gif", ".bmp"}
VIDEO_EXTENSIONS = {".mp4", ".mkv"}
# Matches question markers such as "Q1:" or "Q12." used to split material text.
_Q_PATTERN = re.compile(r"Q\d+\s*[:\.]+\s*", flags=re.MULTILINE)
@dataclass
class _TimeoutCtx:
client: Any = None
task_id: str | None = None
def stop(self):
if self.client and self.task_id:
try:
logger.info(f"Stopping workflow: {self.task_id}")
self.client.stop_workflow(self.task_id)
logger.info(f"Stopped workflow: {self.task_id}")
except Exception as e:
logger.debug(f"Stop workflow failed: {e}")
class MetricManager:
def __init__(self, metric):
self.metric = metric
def new_span(self, m_type: str = "default", m_id: str = "default", m_iter: int = 1):
self.metric.span(m_type, m_id, m_iter)
def _smart_name_label(self, name: str, label: str | None) -> tuple[str, str]:
"""retruns: (name,label)"""
if not name or not name.strip():
raise ValueError("Empty name")
if label:
return name, label
sep = next((s for s in ("|", ":") if s in name), None)
if sep:
_n, _l = name.split(sep, maxsplit=1)
return (_n, _l)
return name, name
def _add_metric(self, name: str, value: Any, m_type: str, label: str | None = None):
_n, _l = self._smart_name_label(name, label)
self.metric.add(name=_n, label=_l, value=value, type=m_type)
def add_text_metric(self, name: str, value: str, label: str | None = None):
self._add_metric(name, value, "text", label)
def add_number_metric(self, name: str, value: float, label: str | None = None):
self._add_metric(name, value, "number", label)
def add_image_metric(self, name: str, value: Path | str, label: str | None = None):
self._add_metric(name, str(value), "image", label)
def add_video_metric(self, name: str, value: Path | str, label: str | None = None):
self._add_metric(name, value, "video", label)
def send(self):
self.metric.send_all()
def from_dict(self, the_dict: dict[str, Any]):
if not the_dict:
raise ValueError("dict is null")
for _k, v in the_dict.items():
k = METRIC_FIELD_MAPPING.get(_k, _k)
logger.info(k)
match v:
case None:
self.add_text_metric(name=k, value="")
case int() | float():
self.add_number_metric(name=k, value=v)
case str():
self.add_text_metric(name=k, value=v)
case Path():
match v.suffix.lower():
case s if s in IMAGE_EXTENSIONS:
self.add_image_metric(name=k, value=str(v.resolve()))
case s if s in VIDEO_EXTENSIONS:
self.add_video_metric(name=k, value=str(v.resolve()))
case _:
self.add_text_metric(name=k, value=str(v.resolve()))
case _:
raise TypeError(f"{type(v)} not supported")
class DifySettings(BaseModel):
    """Connection settings for one Dify workflow endpoint."""

    # Base URL of the Dify service.
    difyUrl: str = Field(
        title="Dify服务地址",
    )
    # Identifier of the workflow to run.
    difyWorkflowId: str = Field(
        title="Dify工作流ID",
    )
    # API key used to authenticate against the service.
    difyApiKey: str = Field(
        title="Dify API密钥",
    )
class MaterialForDify(BaseModel):
    """One material group fed into the Dify workflow."""

    # Parameter-group UUID(s); call_dify uses the first entry for reporting.
    paramGroupUuid: list[str] = Field(
        default_factory=list, title="UUID", description="UUID"
    )
    # Conversation turns; supports multi-round dialogue.
    inputTextList: list[str] = Field(
        default_factory=list,
        title="对话列表,支持多轮对话",
        description="对话列表,支持多轮对话",
    )
    # Prompt text(s).
    prompt: list[str] = Field(
        default_factory=list, title="提示词", description="提示词"
    )
    model_config = ConfigDict(arbitrary_types_allowed=True)  # allows uuid extra...
class DifyPayload(BaseModel):
    """Payload posted to the Dify workflow; list fields serialize to strings."""

    inputTextList: list[str] | str
    prompt: list[str] | str
    deviceId: str
    address: str
    caseId: str
    # Optional extra parameters; the caller pre-serializes dicts to JSON.
    parameters: dict[str, Any] | str | None = None

    @field_serializer("inputTextList")
    def serialize_input_text(self, value: list[str] | str) -> str:
        # Lists are split on "Qn:" markers and JSON-encoded; strings pass through.
        if isinstance(value, list):
            return pre_process_material(value)
        return value

    @field_serializer("prompt")
    def serialize_prompt_as_single_str(self, value: list[str] | str) -> str:
        # Only the first prompt is used; an empty list serializes to "".
        if isinstance(value, list):
            return value[0] if len(value) > 0 else ""
        return value
def create_dynamic_model(model_name: str, parameters: list[dict]) -> type[BaseModel]:
    """Build a pydantic model from a YAML parameter-definition list.

    Each entry must provide a non-empty "name"; "label", "description",
    "type" (only "string" is supported) and "defaultValue" are optional and
    fall back to sensible defaults.

    Args:
        model_name: name of the generated model class.
        parameters: list of parameter-definition dicts from the YAML case.

    Returns:
        A BaseModel subclass with one str field per definition, extra keys
        ignored, validated once with all defaults.

    Raises:
        ValueError: for missing/empty names, unsupported types, or any
            failure while creating/validating the model.
    """
    field_definitions = {}
    for idx, param in enumerate(parameters):
        name: str = (
            param.get("name", "").strip() if isinstance(param.get("name"), str) else ""
        )
        if not name:
            raise ValueError(f"参数定义错误: 第 {idx + 1} 个参数缺少 'name' 字段或为空")
        label = (
            param.get("label", "").strip()
            if isinstance(param.get("label"), str)
            else name
        )
        description = (
            param.get("description", label).strip()
            if isinstance(param.get("description"), str)
            else label
        )
        param_type = (
            param.get("type", "string").strip()
            if isinstance(param.get("type"), str)
            else "string"
        )
        if param_type not in ("string",):
            raise ValueError(
                f"参数 '{name}' 不支持的类型 '{param_type}',目前仅支持 'string'"
            )
        default = param.get("defaultValue", "")
        if default and isinstance(default, str):
            default = default.strip()
        field_definitions[name] = (
            str,
            Field(default=default, title=label, description=description),
        )
    try:
        model = create_model(model_name, __base__=BaseModel, **field_definitions)
        model.model_config["extra"] = "ignore"
        # Instantiate once so bad defaults fail here, not at test time.
        model()
        return model
    except Exception as e:
        # BUGFIX: chain the original exception (`from e`) for debuggability.
        raise ValueError(f"创建参数模型 '{model_name}' 失败: {e}") from e
def call_dify(
    case_meta_info: dict[str, str],
    logger: Logger,
    device_info: Device,
    material: list[MaterialForDify],
    dify_cfg: DifySettings,
    metric,
    material_reporter,
    parameters: dict[str, Any] | None = None,
    timeout_sec: int | None = None,
):
    """Run the configured Dify workflow once per material item.

    For every item a payload is built from the material plus device/case
    context, the workflow is executed (optionally bounded by *timeout_sec*),
    and its outputs are reported as metrics.  A failure of one item is
    logged and does not stop the remaining items.

    Args:
        case_meta_info: case metadata; "id" is forwarded as caseId.
        logger: per-test logger instance.
        device_info: device under test; its serial becomes deviceId.
        material: non-empty list of material groups to execute.
        dify_cfg: Dify endpoint settings.
        metric: executor metric collector.
        material_reporter: reports per-material begin/end status.
        parameters: optional extra workflow parameters (JSON-encoded into
            the payload).
        timeout_sec: per-item wall-clock limit; on expiry the workflow is
            asked to stop.

    Raises:
        RuntimeError: when *material* is empty.
        ValueError: when a material item lacks a param-group UUID.
    """
    logger.info(dify_cfg)
    event_callbacks: dict[str, Callable[..., None]] = {}

    def on_node_started(_, __, d: dict[str, Any]):
        logger.info(f"开始执行:{d.get('title', '')}")
        logger.info(f"输入节点参数:{d.get('inputs')}")

    def on_node_finished(_, __, d):
        logger.info(f"结束执行:{d.get('title', '')}")
        logger.info(f"节点输出:{d.get('outputs')}")

    event_callbacks["on_node_started"] = on_node_started
    event_callbacks["on_node_finished"] = on_node_finished
    if timeout_sec:
        # The started-callback captures the client/task id so the timeout
        # handler below can cancel the run.
        ctx = _TimeoutCtx()

        def on_workflow_started(client, event_name, event_data):
            logger.info("开始执行dify任务")
            ctx.client = client
            ctx.task_id = event_data.get("task_id")

        event_callbacks["on_workflow_started"] = on_workflow_started
    if not material:
        raise RuntimeError("缺少素材")
    else:
        logger.info(f"下发素材:{material}")
    metric_manager = MetricManager(metric)
    local_ip = _get_local_ip()
    for material_index, item in enumerate(material, 1):
        # BUGFIX: guard the empty-list case explicitly; the original
        # `item.paramGroupUuid[0]` raised IndexError before its assert could
        # fire, and `assert` is stripped under -O anyway.
        if not item.paramGroupUuid or not item.paramGroupUuid[0]:
            raise ValueError("The param Group Uuid Value must be set.")
        paramGroupUuid = item.paramGroupUuid[0]
        material_reporter.begin(paramGroupUuid)
        dify_final_status = False
        try:
            # IMPORTANT: the group uuid for params
            payload_data = {
                **item.model_dump(),
                "deviceId": device_info.device_serial,
                "address": local_ip,
                "caseId": case_meta_info.get("id", "Unknown"),
            }
            if parameters:
                payload_data["parameters"] = json.dumps(parameters)
            payload = DifyPayload(**payload_data)
            logger.info(f"payload send to dify:\n{payload.model_dump_json(indent=2)}\n")
            metric_manager.new_span(
                "default", f"default-{material_index}", material_index
            )

            def _run_workflow():
                return run_workflow(
                    api_key=dify_cfg.difyApiKey,
                    base_url=dify_cfg.difyUrl,
                    workflow_id=dify_cfg.difyWorkflowId,
                    inputs=payload.model_dump(exclude_none=True),
                    event_callbacks=event_callbacks,
                )

            if timeout_sec:
                # Run in a worker thread so we can enforce the wall-clock limit.
                with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                    future = executor.submit(_run_workflow)
                    try:
                        result = future.result(timeout=timeout_sec)
                    except concurrent.futures.TimeoutError:
                        logger.error(f"Workflow timeout after {timeout_sec}s,(yaml config)")
                        time.sleep(0.1)
                        logger.info("尝试停止工作流...")
                        ctx.stop()
                        # NOTE: caught by the outer except below, so the loop
                        # continues with the next material item.
                        raise TimeoutError(f"Workflow timeout after {timeout_sec}s")
            else:
                result = _run_workflow()
            logger.info(f"工作流返回结果:{result.outputs}")
            logger.info(f"工作流最终状态:{result.success}")
            if result.outputs:
                logger.info("提交结果到执行器")
                metric_manager.from_dict(result.outputs)
                metric.send_all()
                logger.info("提交结果到执行器完成")
            else:
                logger.warning("dify 工作流无返回")
            if not result.success:
                logger.error(f"工作流执行失败:{result.error}")
            dify_final_status = result.success
        except Exception as e:
            # One failed material must not abort the remaining ones.
            logger.error(e)
            dify_final_status = False
            continue  # do not break next round
        finally:
            material_reporter.end_with(paramGroupUuid, dify_final_status)
def make_dify_test(
    meta: dict[str, str],
    dify_settings: DifySettings | None = None,
    param_model: type[BaseModel] | None = None,
    timeout: int | None = None,
) -> Callable[..., None]:
    """Build a pytest test function that drives a Dify workflow.

    Args:
        meta: case metadata (id/description) attached via the M.meta mark.
        dify_settings: endpoint config.  NOTE(review): if this stays None,
            call_dify will fail when dereferencing it — confirm a default is
            supplied upstream.
        param_model: pydantic model describing the per-case `cfg` fixture
            parameters; defaults to an empty model.
        timeout: per-material timeout in seconds, forwarded to call_dify.

    Returns:
        The decorated test function, ready to be injected into a module.
    """
    if param_model is None:
        # No declared parameters: accept (and ignore) an empty cfg model.
        param_model = create_model("EmptyParams", __base__=BaseModel)

    # Always resolve the `cfg` fixture so case parameters get injected.
    @pytest.mark.usefixtures("cfg")
    def _func_impl(
        logger: Logger,
        device_info: Device,
        material: list[MaterialForDify],
        metric,
        material_reporter,
        cfg: param_model,
    ) -> None:
        param_dict = cfg.model_dump() if cfg else {}
        call_dify(
            meta,
            logger,
            device_info,
            material,
            dify_cfg=dify_settings,
            metric=metric,
            material_reporter=material_reporter,
            parameters=param_dict,
            timeout_sec=timeout,
        )

    return M.meta(**meta)(_func_impl)
def make_skip_test(meta: dict[str, str]) -> Callable[..., None]:
    """Build a placeholder test that is always skipped as not implemented.

    Args:
        meta: case metadata (must contain "id") used for the skip reason
            and the M.meta mark.
    """

    @M.skip(reason=f"{meta['id']} 未实现")
    @M.meta(**meta)
    def _func(
        logger: Logger,
        device_info: Device,
        material: list[MaterialForDify],
        metric,
    ) -> None:
        logger.error(f"此用例在手机{device_info.device_serial}上暂未实现")
        # Explicit raise instead of `assert False`, which -O would strip.
        raise AssertionError("not implemented")

    return _func
def generate_cases_from_yaml(module_name: str, yaml_path: Path):
    """Read *yaml_path* and inject one test function per case into *module_name*.

    Each YAML case must carry `id` and `description`; its `action` selects
    the generator: "dify" builds a workflow test, "skipped" (the default)
    builds a skip placeholder, and "custom"/unknown actions are left to
    hand-written tests.

    Args:
        module_name: name of the module (usually __name__) to receive the
            generated test functions via setattr.
        yaml_path: path to the test-case definition YAML file.
    """
    if not yaml_path.exists():
        logger.warning(f"Test case definition file not found: {yaml_path}")
        return
    try:
        with open(file=yaml_path, mode="r", encoding="utf-8") as f:
            # BUGFIX(annotation): safe_load yields the top-level mapping here
            # (we call .get("cases") below), not a list.
            all_cases: dict[str, Any] = yaml.safe_load(f)
    except yaml.YAMLError as e:
        logger.error(f"Failed to parse YAML file {yaml_path}: {e}")
        pytest.fail(f"YAML解析失败: {yaml_path}")
        return
    if not all_cases:
        return
    for case_info in all_cases.get("cases", []):
        if not isinstance(case_info, dict):
            logger.debug(f"Skipping non-dictionary item in YAML file: {case_info}")
            continue
        case_id: str = case_info.get("id", "")
        if not case_id:
            logger.warning(f"用例 缺少 case_id 字段。{case_info=}")
            pytest.fail(f"用例 缺少 case_id 字段。{case_info=}")
        description: str = case_info.get("description", "")
        if not description:
            logger.warning(f"用例{case_id} 缺少 description 字段。")
            pytest.fail(f"用例{case_id} 缺少 description 字段。")
        action: str = case_info.get("action", "skipped")
        meta = {"id": case_id, "description": description}
        fn_name = f"test_{case_id.lower().replace('-', '_')}"
        if action == "dify":
            dify_settings: DifySettings | None = None
            param_model: type[BaseModel] | None = None
            timeout = case_info.get("timeout")
            if parameters_def := case_info.get("parameters"):
                model_name = f"{case_id.replace('-', '_')}_Params"
                param_model = create_dynamic_model(model_name, parameters_def)
            if dify_config_block := case_info.get("dify_config"):
                # Pick the environment-specific endpoint block.
                env_key = "production" if IS_PRODUCTION else "testing"
                if env_config := dify_config_block.get(env_key):
                    dify_settings = DifySettings(
                        difyUrl=env_config.get("url"),
                        difyWorkflowId=env_config.get("workflow_id"),
                        difyApiKey=env_config.get("api_key"),
                    )
                else:
                    logger.warning(
                        f"No '{env_key}' config found for {case_id}, will use default."
                    )
            fn = make_dify_test(meta, dify_settings, param_model, timeout)
        elif action == "skipped":
            fn = make_skip_test(meta)
        elif action == "custom":
            logger.info(f"Case {case_id} is a custom test.")
            continue
        else:
            logger.warning(f"Unknown action '{action}' for case {case_id}. Skipping.")
            continue
        setattr(sys.modules[module_name], fn_name, fn)
        logger.info(f"Generated {fn_name} for {case_id} with action '{action}'")
generate_cases_from_yaml(__name__, YAML_FILE)