"""
ServerGuard - common utility library.

Provides shared helpers for running external commands, configuring logging,
and parsing command output.
"""

import builtins
import functools
import itertools
import json
import logging
import os
import re
import shutil
import subprocess
import sys
from datetime import datetime
from typing import Any, Dict, List, Optional, Tuple, Union

# Matches every character that cannot be part of a plain decimal number.
_NON_NUMERIC_RE = re.compile(r'[^\d.-]')


class ServerGuardError(Exception):
    """Base class for all ServerGuard exceptions."""


class CommandExecutionError(ServerGuardError):
    """Raised when an external command fails, times out, or cannot be found."""


class PermissionError(ServerGuardError):
    """Raised when an operation fails because of insufficient privileges.

    NOTE: the name is kept for backward compatibility, but it shadows the
    built-in ``PermissionError`` inside this module. Code in this file that
    needs the built-in must spell it ``builtins.PermissionError``.
    """


def _format_command(cmd: Union[str, List[str]]) -> str:
    """Render a command (string or argument sequence) as one display string."""
    if isinstance(cmd, str):
        return cmd
    return ' '.join(str(part) for part in cmd)


def execute_command(
    cmd_list: List[str],
    timeout: int = 60,
    check_returncode: bool = True,
    capture_output: bool = True,
    shell: bool = False,
    input_data: Optional[str] = None
) -> Tuple[int, str, str]:
    """
    Run an external command and return its exit code and output.

    Args:
        cmd_list: The command and its arguments as a list.
        timeout: Timeout for the command, in seconds.
        check_returncode: Raise when the command exits with a non-zero code.
        capture_output: Capture stdout/stderr instead of inheriting them.
        shell: Run the command through the shell. The argument list is
            joined with spaces (without quoting) into one command line.
        input_data: Text fed to the command's stdin. An empty string is
            passed through as empty input.

    Returns:
        Tuple[returncode, stdout, stderr]. stdout/stderr are empty strings
        when output is not captured.

    Raises:
        CommandExecutionError: The command failed, timed out, or was not found.
        PermissionError: The OS denied permission to execute the command
            (this module's PermissionError, a ServerGuardError).
    """
    logger = logging.getLogger(__name__)
    cmd_str = _format_command(cmd_list)

    run_args: Union[str, List[str]] = cmd_list
    if shell and not isinstance(cmd_list, str):
        # With shell=True, subprocess treats only the first list item as the
        # command line; the remaining items become arguments of the shell
        # itself. Pass the joined string so the whole command is executed.
        logger.warning(f"Using shell=True with command: {cmd_str}")
        run_args = cmd_str

    kwargs: Dict[str, Any] = {
        'timeout': timeout,
        'shell': shell,
        'universal_newlines': True,  # Python 3.6 compatible version of text=True
    }
    if capture_output:
        kwargs['stdout'] = subprocess.PIPE
        kwargs['stderr'] = subprocess.PIPE
    if input_data is not None:
        kwargs['input'] = input_data

    try:
        logger.debug(f"Executing command: {cmd_str}")
        result = subprocess.run(run_args, **kwargs)
    except subprocess.TimeoutExpired as e:
        error_msg = f"Command timed out after {timeout}s: {cmd_str}"
        logger.error(error_msg)
        raise CommandExecutionError(error_msg) from e
    except FileNotFoundError as e:
        program = cmd_list if isinstance(cmd_list, str) else cmd_list[0]
        error_msg = f"Command not found: {program}"
        logger.error(error_msg)
        raise CommandExecutionError(error_msg) from e
    except builtins.PermissionError as e:
        # Must name the built-in explicitly: the bare name PermissionError
        # refers to this module's own exception class, which subprocess
        # never raises.
        error_msg = f"Permission denied executing: {cmd_str}"
        logger.error(error_msg)
        raise PermissionError(error_msg) from e

    stdout = result.stdout or ""
    stderr = result.stderr or ""

    if check_returncode and result.returncode != 0:
        error_msg = f"Command failed with code {result.returncode}: {cmd_str}\nstderr: {stderr}"
        logger.error(error_msg)
        raise CommandExecutionError(error_msg)

    return result.returncode, stdout, stderr


def check_root_privileges() -> bool:
    """
    Check whether the process is running with an effective UID of 0.

    Returns:
        bool: True when the effective user is root.
    """
    return os.geteuid() == 0


def require_root(func):
    """
    Decorator: only run the wrapped function when executing as root.

    When the process is not root, the wrapped function is not called; a
    warning is logged and an error dictionary with the keys ``status`` and
    ``error`` is returned instead.
    """
    @functools.wraps(func)  # keep the wrapped function's name and docstring
    def wrapper(*args, **kwargs):
        if not check_root_privileges():
            logging.getLogger(__name__).warning(
                f"Function {func.__name__} requires root privileges"
            )
            return {
                "status": "error",
                "error": "This function requires root privileges. Please run with sudo."
            }
        return func(*args, **kwargs)
    return wrapper


def setup_logging(
    log_file: Optional[str] = None,
    level: int = logging.INFO,
    console_output: bool = True
) -> logging.Logger:
    """
    Configure the root logger.

    Existing root handlers are removed and closed, then replaced by an
    optional stdout handler and an optional file handler.

    Args:
        log_file: Path of the log file; None disables file logging.
        level: Logging level for the root logger.
        console_output: Whether to also log to stdout.

    Returns:
        logging.Logger: The configured root logger.
    """
    logger = logging.getLogger()
    logger.setLevel(level)

    # Remove previous handlers and close them so that file descriptors from
    # an earlier configuration are released.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()

    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )

    if console_output:
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)

    if log_file:
        os.makedirs(os.path.dirname(log_file) or '.', exist_ok=True)
        # StreamHandler.emit() flushes the stream after every record, so no
        # extra flush hook is needed for FileHandler.
        file_handler = logging.FileHandler(log_file, mode='a')
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger


class ProgressLogger:
    """
    Progress logger: records the steps of a test run so that an interrupted
    run can be diagnosed from the log.
    """

    def __init__(self, log_file: Optional[str] = None):
        self.logger = logging.getLogger(__name__)
        self.log_file = log_file
        self.steps = []           # finished steps, in completion order
        self.current_step = None  # dict describing the running step, if any
        self.start_time = None    # datetime at which the running step began

    def start(self, operation: str):
        """Begin a new step named *operation*."""
        now = datetime.now()
        self.current_step = {
            "operation": operation,
            "start_time": now.isoformat(),
            "status": "running"
        }
        self.start_time = now

        self._emit("info", f"[START] {operation}")
        self._flush_log()

    def end(self, status: str = "success", message: str = ""):
        """Finish the running step; does nothing when no step is running."""
        if not self.current_step:
            return

        end_time = datetime.now()
        duration = (end_time - self.start_time).total_seconds() if self.start_time else 0

        self.current_step["end_time"] = end_time.isoformat()
        self.current_step["status"] = status
        self.current_step["duration_seconds"] = duration
        self.current_step["message"] = message
        self.steps.append(self.current_step)

        msg = f"[END] {self.current_step['operation']} - Status: {status}"
        if message:
            msg += f" - {message}"
        msg += f" (Duration: {duration:.2f}s)"

        # Only "error" and "warning" map to their own level; any other
        # status is logged at INFO.
        self._emit(status if status in ("error", "warning") else "info", msg)
        self._flush_log()
        self.current_step = None

    def log(self, message: str, level: str = "info"):
        """Log an intermediate message for the running step."""
        operation = self.current_step['operation'] if self.current_step else 'UNKNOWN'
        self._emit(level, f"[PROGRESS] {operation} - {message}")
        self._flush_log()

    def _emit(self, level: str, msg: str):
        """Log *msg* at *level*; unrecognized levels fall back to INFO."""
        if level == "error":
            self.logger.error(msg)
        elif level == "warning":
            self.logger.warning(msg)
        elif level == "debug":
            self.logger.debug(msg)
        else:
            self.logger.info(msg)

    def _flush_log(self):
        """Flush every handler that receives this logger's records."""
        # Handlers are usually attached to an ancestor (e.g. the root logger),
        # so follow the same propagation chain that records take.
        current = self.logger
        while current is not None:
            for handler in current.handlers:
                if hasattr(handler, 'flush'):
                    handler.flush()
            if not current.propagate:
                break
            current = current.parent

    def get_summary(self) -> Dict[str, Any]:
        """Return the finished steps, their count, and the running step."""
        return {
            "total_steps": len(self.steps),
            "steps": self.steps,
            "current_running": self.current_step
        }


def parse_key_value_output(text: str, delimiter: str = ':') -> Dict[str, str]:
    """
    Parse ``key: value`` formatted text.

    Blank lines, lines starting with ``#`` and lines without the delimiter
    are skipped. Each line is split on the first delimiter only.

    Args:
        text: The text to parse.
        delimiter: Separator between key and value.

    Returns:
        Dict[str, str]: Mapping of stripped keys to stripped values.
    """
    result = {}
    for raw_line in text.strip().split('\n'):
        line = raw_line.strip()
        if not line or line.startswith('#'):
            continue
        if delimiter in line:
            key, _, value = line.partition(delimiter)
            result[key.strip()] = value.strip()
    return result


def parse_table_output(text: str, headers: Optional[List[str]] = None) -> List[Dict[str, str]]:
    """
    Parse whitespace-separated tabular text.

    Args:
        text: The text to parse.
        headers: Column names; when None they are taken from the first
            non-empty line.

    Returns:
        List[Dict[str, str]]: One dict per data row. Rows with fewer fields
        than headers are skipped; surplus fields are ignored.
    """
    lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
    if not lines:
        return []

    if headers is None:
        headers = lines[0].split()
        data_lines = lines[1:]
    else:
        data_lines = lines

    result = []
    for line in data_lines:
        values = line.split()
        if len(values) >= len(headers):
            # zip() stops at the last header, dropping surplus fields.
            result.append(dict(zip(headers, values)))
    return result


def extract_with_regex(text: str, pattern: str, group: int = 1, default: Any = None) -> Any:
    """
    Extract a capture group from the first match of *pattern* in *text*.

    Args:
        text: The text to search.
        pattern: Regular expression pattern.
        group: Index of the capture group to return.
        default: Value returned when there is no match or no such group.

    Returns:
        The matched group, or *default*.
    """
    match = re.search(pattern, text)
    if match is None:
        return default
    try:
        return match.group(group)
    except IndexError:
        return default


def _clean_numeric_text(value: str) -> str:
    """Strip everything except digits, '.' and '-' (units, spaces, commas)."""
    return _NON_NUMERIC_RE.sub('', value)


def safe_int(value: Any, default: int = 0) -> int:
    """
    Convert a value to int without raising.

    Strings may carry unit suffixes, spaces or thousands separators
    (e.g. "32 GB", "2.5GHz"); fractional values are truncated toward zero.

    Args:
        value: The value to convert.
        default: Value returned when conversion fails.

    Returns:
        int: The converted integer, or *default*.
    """
    try:
        if isinstance(value, str):
            cleaned = _clean_numeric_text(value)
            try:
                # Parse integers directly so large values keep full precision.
                return int(cleaned)
            except ValueError:
                return int(float(cleaned))
        if isinstance(value, int):
            return int(value)
        return int(float(value))
    except (ValueError, TypeError, OverflowError):
        # OverflowError: infinities, or ints too large for float().
        return default


def safe_float(value: Any, default: float = 0.0) -> float:
    """
    Convert a value to float without raising.

    Strings may carry unit suffixes, spaces or thousands separators.

    Args:
        value: The value to convert.
        default: Value returned when conversion fails.

    Returns:
        float: The converted number, or *default*.
    """
    try:
        if isinstance(value, str):
            value = _clean_numeric_text(value)
        return float(value)
    except (ValueError, TypeError, OverflowError):
        return default


def get_timestamp() -> str:
    """
    Return the current local time as ``YYYY-MM-DD HH:MM:SS``.

    Returns:
        str: The formatted timestamp.
    """
    return datetime.now().strftime('%Y-%m-%d %H:%M:%S')


def get_file_timestamp() -> str:
    """
    Return the current local time as ``YYYYMMDD_HHMMSS``, for file names.

    Returns:
        str: The formatted timestamp.
    """
    return datetime.now().strftime('%Y%m%d_%H%M%S')


def read_file_lines(filepath: str, max_lines: int = 1000) -> List[str]:
    """
    Read up to *max_lines* lines from a text file.

    The file is decoded as UTF-8 with undecodable bytes ignored. Read
    failures are logged as warnings and yield an empty list.

    Args:
        filepath: Path of the file.
        max_lines: Maximum number of lines to read.

    Returns:
        List[str]: The lines without their trailing newline.
    """
    try:
        with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
            # islice() rejects negative stops, so clamp to zero.
            return [line.rstrip('\n') for line in itertools.islice(f, max(max_lines, 0))]
    except (IOError, OSError) as e:
        logging.getLogger(__name__).warning(f"Failed to read file {filepath}: {e}")
        return []


def check_command_exists(command: str) -> bool:
    """
    Check whether an executable named *command* can be found on PATH.

    Args:
        command: The command name.

    Returns:
        bool: True when the command exists.
    """
    # shutil.which avoids spawning a process and does not depend on an
    # external `which` binary being installed.
    return shutil.which(command) is not None


def format_bytes(size_bytes: int) -> str:
    """
    Format a byte count as a human-readable string using 1024-based units.

    Args:
        size_bytes: Number of bytes; negative values keep their sign.

    Returns:
        str: For example "1.50 KB"; zero is rendered as "0 B".
    """
    if size_bytes == 0:
        return "0 B"

    units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
    sign = '-' if size_bytes < 0 else ''
    size = abs(float(size_bytes))  # scale the magnitude, re-attach the sign
    unit_index = 0

    while size >= 1024 and unit_index < len(units) - 1:
        size /= 1024
        unit_index += 1

    return f"{sign}{size:.2f} {units[unit_index]}"


def sanitize_filename(filename: str) -> str:
    """
    Replace characters that are unsafe in file names.

    The characters ``<>:"/\\|?*`` and ASCII control characters (including
    NUL) are replaced with ``_``; leading/trailing dots and spaces are
    stripped. The result may be empty.

    Args:
        filename: The original file name.

    Returns:
        str: The sanitized file name.
    """
    filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', filename)
    return filename.strip('. ')


def merge_dicts(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
    """
    Recursively merge two dictionaries into a new one.

    When both sides hold a dict under the same key, the two are merged
    recursively; otherwise the value from *update* wins. Neither input is
    modified, but unmerged nested values are shared, not copied.

    Args:
        base: The base dictionary.
        update: The dictionary whose values take precedence.

    Returns:
        Dict[str, Any]: The merged dictionary.
    """
    result = base.copy()
    for key, value in update.items():
        existing = result.get(key)
        if isinstance(existing, dict) and isinstance(value, dict):
            result[key] = merge_dicts(existing, value)
        else:
            result[key] = value
    return result