Files
ServerGuard/utils.py
2026-03-02 14:14:40 +08:00

420 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
ServerGuard - 通用工具库
提供命令执行、日志配置、输出解析等通用功能。
"""
import subprocess
import logging
import sys
import os
import re
import json
from typing import List, Dict, Any, Optional, Tuple, Union
from datetime import datetime
class ServerGuardError(Exception):
"""ServerGuard 基础异常类"""
pass
class CommandExecutionError(ServerGuardError):
"""命令执行异常"""
pass
class PermissionError(ServerGuardError):
"""权限异常"""
pass
def execute_command(
cmd_list: List[str],
timeout: int = 60,
check_returncode: bool = True,
capture_output: bool = True,
shell: bool = False,
input_data: Optional[str] = None
) -> Tuple[int, str, str]:
"""
安全地执行外部命令。
Args:
cmd_list: 命令及其参数的列表
timeout: 命令超时时间(秒)
check_returncode: 是否在非零返回码时抛出异常
capture_output: 是否捕获输出
shell: 是否使用 shell 执行
input_data: 输入到命令的字符串数据
Returns:
Tuple[returncode, stdout, stderr]
Raises:
CommandExecutionError: 命令执行失败
PermissionError: 权限不足
"""
logger = logging.getLogger(__name__)
# 安全:禁止使用 shell=True 时传递未经验证的命令字符串
if shell and isinstance(cmd_list, list):
cmd_str = ' '.join(cmd_list)
logger.warning(f"Using shell=True with command: {cmd_str}")
try:
logger.debug(f"Executing command: {' '.join(cmd_list)}")
kwargs = {
'timeout': timeout,
'shell': shell,
'universal_newlines': True # Python 3.6 compatible version of text=True
}
if capture_output:
kwargs['stdout'] = subprocess.PIPE
kwargs['stderr'] = subprocess.PIPE
if input_data:
kwargs['input'] = input_data
result = subprocess.run(cmd_list, **kwargs)
stdout = result.stdout if result.stdout else ""
stderr = result.stderr if result.stderr else ""
if check_returncode and result.returncode != 0:
error_msg = f"Command failed with code {result.returncode}: {' '.join(cmd_list)}\nstderr: {stderr}"
logger.error(error_msg)
raise CommandExecutionError(error_msg)
return result.returncode, stdout, stderr
except subprocess.TimeoutExpired:
error_msg = f"Command timed out after {timeout}s: {' '.join(cmd_list)}"
logger.error(error_msg)
raise CommandExecutionError(error_msg)
except FileNotFoundError:
error_msg = f"Command not found: {cmd_list[0]}"
logger.error(error_msg)
raise CommandExecutionError(error_msg)
except PermissionError as e:
error_msg = f"Permission denied executing: {' '.join(cmd_list)}"
logger.error(error_msg)
raise PermissionError(error_msg) from e
def check_root_privileges() -> bool:
"""
检查当前是否以 root 用户运行。
Returns:
bool: 是否为 root 用户
"""
return os.geteuid() == 0
def require_root(func):
"""
装饰器:要求函数必须以 root 权限运行。
"""
def wrapper(*args, **kwargs):
if not check_root_privileges():
logging.warning(f"Function {func.__name__} requires root privileges")
return {
"status": "error",
"error": "This function requires root privileges. Please run with sudo."
}
return func(*args, **kwargs)
return wrapper
def setup_logging(
log_file: Optional[str] = None,
level: int = logging.INFO,
console_output: bool = True
) -> logging.Logger:
"""
配置日志系统。
Args:
log_file: 日志文件路径None 则不写入文件
level: 日志级别
console_output: 是否输出到控制台
Returns:
logging.Logger: 配置好的 logger 实例
"""
logger = logging.getLogger()
logger.setLevel(level)
# 清除已有的 handlers
logger.handlers = []
formatter = logging.Formatter(
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
datefmt='%Y-%m-%d %H:%M:%S'
)
if console_output:
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
if log_file:
os.makedirs(os.path.dirname(log_file) or '.', exist_ok=True)
file_handler = logging.FileHandler(log_file)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
return logger
def parse_key_value_output(text: str, delimiter: str = ':') -> Dict[str, str]:
"""
解析 key: value 格式的文本输出。
Args:
text: 要解析的文本
delimiter: 键值分隔符
Returns:
Dict[str, str]: 解析后的字典
"""
result = {}
for line in text.strip().split('\n'):
line = line.strip()
if not line or line.startswith('#'):
continue
parts = line.split(delimiter, 1)
if len(parts) == 2:
key = parts[0].strip()
value = parts[1].strip()
result[key] = value
return result
def parse_table_output(text: str, headers: Optional[List[str]] = None) -> List[Dict[str, str]]:
"""
解析表格格式的文本输出。
Args:
text: 要解析的文本
headers: 表头列表None 则从第一行自动提取
Returns:
List[Dict[str, str]]: 解析后的列表
"""
lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
if not lines:
return []
if headers is None:
# 尝试自动检测表头
headers = [h.strip() for h in lines[0].split() if h.strip()]
data_lines = lines[1:]
else:
data_lines = lines
result = []
for line in data_lines:
values = line.split()
if len(values) >= len(headers):
row = {headers[i]: values[i] for i in range(len(headers))}
result.append(row)
return result
def extract_with_regex(text: str, pattern: str, group: int = 1, default: Any = None) -> Any:
"""
使用正则表达式从文本中提取内容。
Args:
text: 要搜索的文本
pattern: 正则表达式模式
group: 捕获组索引
default: 未匹配时的默认值
Returns:
匹配结果或默认值
"""
match = re.search(pattern, text)
if match:
try:
return match.group(group)
except IndexError:
return default
return default
def safe_int(value: Any, default: int = 0) -> int:
"""
安全地将值转换为整数。
Args:
value: 要转换的值
default: 转换失败时的默认值
Returns:
int: 转换后的整数
"""
try:
# 移除常见单位后缀
if isinstance(value, str):
value = value.strip().lower()
value = re.sub(r'[\s,]', '', value)
# 处理带单位的数值 (如 "32 GB", "2.5GHz")
value = re.sub(r'[^\d.-]', '', value)
return int(float(value))
except (ValueError, TypeError):
return default
def safe_float(value: Any, default: float = 0.0) -> float:
"""
安全地将值转换为浮点数。
Args:
value: 要转换的值
default: 转换失败时的默认值
Returns:
float: 转换后的浮点数
"""
try:
if isinstance(value, str):
value = value.strip().lower()
value = re.sub(r'[\s,]', '', value)
value = re.sub(r'[^\d.-]', '', value)
return float(value)
except (ValueError, TypeError):
return default
def get_timestamp() -> str:
"""
获取当前时间戳字符串。
Returns:
str: 格式化的时间戳
"""
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
def get_file_timestamp() -> str:
"""
获取适合文件名的当前时间戳字符串。
Returns:
str: 格式化的文件名时间戳
"""
return datetime.now().strftime('%Y%m%d_%H%M%S')
def read_file_lines(filepath: str, max_lines: int = 1000) -> List[str]:
"""
安全地读取文件内容。
Args:
filepath: 文件路径
max_lines: 最大读取行数
Returns:
List[str]: 文件行列表
"""
try:
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
lines = []
for i, line in enumerate(f):
if i >= max_lines:
break
lines.append(line.rstrip('\n'))
return lines
except (IOError, OSError) as e:
logging.getLogger(__name__).warning(f"Failed to read file {filepath}: {e}")
return []
def check_command_exists(command: str) -> bool:
"""
检查命令是否存在。
Args:
command: 命令名称
Returns:
bool: 命令是否存在
"""
try:
# Python 3.6 compatible version
subprocess.run(
['which', command],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
check=True
)
return True
except (subprocess.CalledProcessError, FileNotFoundError):
return False
def format_bytes(size_bytes: int) -> str:
"""
将字节数格式化为人类可读的字符串。
Args:
size_bytes: 字节数
Returns:
str: 格式化后的字符串
"""
if size_bytes == 0:
return "0 B"
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
unit_index = 0
size = float(size_bytes)
while size >= 1024 and unit_index < len(units) - 1:
size /= 1024
unit_index += 1
return f"{size:.2f} {units[unit_index]}"
def sanitize_filename(filename: str) -> str:
"""
清理文件名,移除不安全字符。
Args:
filename: 原始文件名
Returns:
str: 清理后的文件名
"""
# 移除或替换不安全字符
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
filename = filename.strip('. ')
return filename
def merge_dicts(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
"""
递归合并两个字典。
Args:
base: 基础字典
update: 更新字典
Returns:
Dict[str, Any]: 合并后的字典
"""
result = base.copy()
for key, value in update.items():
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
result[key] = merge_dicts(result[key], value)
else:
result[key] = value
return result