420 lines
11 KiB
Python
420 lines
11 KiB
Python
"""
|
||
ServerGuard - 通用工具库
|
||
|
||
提供命令执行、日志配置、输出解析等通用功能。
|
||
"""
|
||
|
||
import subprocess
|
||
import logging
|
||
import sys
|
||
import os
|
||
import re
|
||
import json
|
||
from typing import List, Dict, Any, Optional, Tuple, Union
|
||
from datetime import datetime
|
||
|
||
|
||
class ServerGuardError(Exception):
|
||
"""ServerGuard 基础异常类"""
|
||
pass
|
||
|
||
|
||
class CommandExecutionError(ServerGuardError):
|
||
"""命令执行异常"""
|
||
pass
|
||
|
||
|
||
class PermissionError(ServerGuardError):
|
||
"""权限异常"""
|
||
pass
|
||
|
||
|
||
def execute_command(
|
||
cmd_list: List[str],
|
||
timeout: int = 60,
|
||
check_returncode: bool = True,
|
||
capture_output: bool = True,
|
||
shell: bool = False,
|
||
input_data: Optional[str] = None
|
||
) -> Tuple[int, str, str]:
|
||
"""
|
||
安全地执行外部命令。
|
||
|
||
Args:
|
||
cmd_list: 命令及其参数的列表
|
||
timeout: 命令超时时间(秒)
|
||
check_returncode: 是否在非零返回码时抛出异常
|
||
capture_output: 是否捕获输出
|
||
shell: 是否使用 shell 执行
|
||
input_data: 输入到命令的字符串数据
|
||
|
||
Returns:
|
||
Tuple[returncode, stdout, stderr]
|
||
|
||
Raises:
|
||
CommandExecutionError: 命令执行失败
|
||
PermissionError: 权限不足
|
||
"""
|
||
logger = logging.getLogger(__name__)
|
||
|
||
# 安全:禁止使用 shell=True 时传递未经验证的命令字符串
|
||
if shell and isinstance(cmd_list, list):
|
||
cmd_str = ' '.join(cmd_list)
|
||
logger.warning(f"Using shell=True with command: {cmd_str}")
|
||
|
||
try:
|
||
logger.debug(f"Executing command: {' '.join(cmd_list)}")
|
||
|
||
kwargs = {
|
||
'timeout': timeout,
|
||
'shell': shell,
|
||
'universal_newlines': True # Python 3.6 compatible version of text=True
|
||
}
|
||
if capture_output:
|
||
kwargs['stdout'] = subprocess.PIPE
|
||
kwargs['stderr'] = subprocess.PIPE
|
||
if input_data:
|
||
kwargs['input'] = input_data
|
||
|
||
result = subprocess.run(cmd_list, **kwargs)
|
||
|
||
stdout = result.stdout if result.stdout else ""
|
||
stderr = result.stderr if result.stderr else ""
|
||
|
||
if check_returncode and result.returncode != 0:
|
||
error_msg = f"Command failed with code {result.returncode}: {' '.join(cmd_list)}\nstderr: {stderr}"
|
||
logger.error(error_msg)
|
||
raise CommandExecutionError(error_msg)
|
||
|
||
return result.returncode, stdout, stderr
|
||
|
||
except subprocess.TimeoutExpired:
|
||
error_msg = f"Command timed out after {timeout}s: {' '.join(cmd_list)}"
|
||
logger.error(error_msg)
|
||
raise CommandExecutionError(error_msg)
|
||
except FileNotFoundError:
|
||
error_msg = f"Command not found: {cmd_list[0]}"
|
||
logger.error(error_msg)
|
||
raise CommandExecutionError(error_msg)
|
||
except PermissionError as e:
|
||
error_msg = f"Permission denied executing: {' '.join(cmd_list)}"
|
||
logger.error(error_msg)
|
||
raise PermissionError(error_msg) from e
|
||
|
||
|
||
def check_root_privileges() -> bool:
|
||
"""
|
||
检查当前是否以 root 用户运行。
|
||
|
||
Returns:
|
||
bool: 是否为 root 用户
|
||
"""
|
||
return os.geteuid() == 0
|
||
|
||
|
||
def require_root(func):
|
||
"""
|
||
装饰器:要求函数必须以 root 权限运行。
|
||
"""
|
||
def wrapper(*args, **kwargs):
|
||
if not check_root_privileges():
|
||
logging.warning(f"Function {func.__name__} requires root privileges")
|
||
return {
|
||
"status": "error",
|
||
"error": "This function requires root privileges. Please run with sudo."
|
||
}
|
||
return func(*args, **kwargs)
|
||
return wrapper
|
||
|
||
|
||
def setup_logging(
|
||
log_file: Optional[str] = None,
|
||
level: int = logging.INFO,
|
||
console_output: bool = True
|
||
) -> logging.Logger:
|
||
"""
|
||
配置日志系统。
|
||
|
||
Args:
|
||
log_file: 日志文件路径,None 则不写入文件
|
||
level: 日志级别
|
||
console_output: 是否输出到控制台
|
||
|
||
Returns:
|
||
logging.Logger: 配置好的 logger 实例
|
||
"""
|
||
logger = logging.getLogger()
|
||
logger.setLevel(level)
|
||
|
||
# 清除已有的 handlers
|
||
logger.handlers = []
|
||
|
||
formatter = logging.Formatter(
|
||
'%(asctime)s - %(name)s - %(levelname)s - %(message)s',
|
||
datefmt='%Y-%m-%d %H:%M:%S'
|
||
)
|
||
|
||
if console_output:
|
||
console_handler = logging.StreamHandler(sys.stdout)
|
||
console_handler.setFormatter(formatter)
|
||
logger.addHandler(console_handler)
|
||
|
||
if log_file:
|
||
os.makedirs(os.path.dirname(log_file) or '.', exist_ok=True)
|
||
file_handler = logging.FileHandler(log_file)
|
||
file_handler.setFormatter(formatter)
|
||
logger.addHandler(file_handler)
|
||
|
||
return logger
|
||
|
||
|
||
def parse_key_value_output(text: str, delimiter: str = ':') -> Dict[str, str]:
|
||
"""
|
||
解析 key: value 格式的文本输出。
|
||
|
||
Args:
|
||
text: 要解析的文本
|
||
delimiter: 键值分隔符
|
||
|
||
Returns:
|
||
Dict[str, str]: 解析后的字典
|
||
"""
|
||
result = {}
|
||
for line in text.strip().split('\n'):
|
||
line = line.strip()
|
||
if not line or line.startswith('#'):
|
||
continue
|
||
|
||
parts = line.split(delimiter, 1)
|
||
if len(parts) == 2:
|
||
key = parts[0].strip()
|
||
value = parts[1].strip()
|
||
result[key] = value
|
||
|
||
return result
|
||
|
||
|
||
def parse_table_output(text: str, headers: Optional[List[str]] = None) -> List[Dict[str, str]]:
|
||
"""
|
||
解析表格格式的文本输出。
|
||
|
||
Args:
|
||
text: 要解析的文本
|
||
headers: 表头列表,None 则从第一行自动提取
|
||
|
||
Returns:
|
||
List[Dict[str, str]]: 解析后的列表
|
||
"""
|
||
lines = [line.strip() for line in text.strip().split('\n') if line.strip()]
|
||
if not lines:
|
||
return []
|
||
|
||
if headers is None:
|
||
# 尝试自动检测表头
|
||
headers = [h.strip() for h in lines[0].split() if h.strip()]
|
||
data_lines = lines[1:]
|
||
else:
|
||
data_lines = lines
|
||
|
||
result = []
|
||
for line in data_lines:
|
||
values = line.split()
|
||
if len(values) >= len(headers):
|
||
row = {headers[i]: values[i] for i in range(len(headers))}
|
||
result.append(row)
|
||
|
||
return result
|
||
|
||
|
||
def extract_with_regex(text: str, pattern: str, group: int = 1, default: Any = None) -> Any:
|
||
"""
|
||
使用正则表达式从文本中提取内容。
|
||
|
||
Args:
|
||
text: 要搜索的文本
|
||
pattern: 正则表达式模式
|
||
group: 捕获组索引
|
||
default: 未匹配时的默认值
|
||
|
||
Returns:
|
||
匹配结果或默认值
|
||
"""
|
||
match = re.search(pattern, text)
|
||
if match:
|
||
try:
|
||
return match.group(group)
|
||
except IndexError:
|
||
return default
|
||
return default
|
||
|
||
|
||
def safe_int(value: Any, default: int = 0) -> int:
|
||
"""
|
||
安全地将值转换为整数。
|
||
|
||
Args:
|
||
value: 要转换的值
|
||
default: 转换失败时的默认值
|
||
|
||
Returns:
|
||
int: 转换后的整数
|
||
"""
|
||
try:
|
||
# 移除常见单位后缀
|
||
if isinstance(value, str):
|
||
value = value.strip().lower()
|
||
value = re.sub(r'[\s,]', '', value)
|
||
# 处理带单位的数值 (如 "32 GB", "2.5GHz")
|
||
value = re.sub(r'[^\d.-]', '', value)
|
||
return int(float(value))
|
||
except (ValueError, TypeError):
|
||
return default
|
||
|
||
|
||
def safe_float(value: Any, default: float = 0.0) -> float:
|
||
"""
|
||
安全地将值转换为浮点数。
|
||
|
||
Args:
|
||
value: 要转换的值
|
||
default: 转换失败时的默认值
|
||
|
||
Returns:
|
||
float: 转换后的浮点数
|
||
"""
|
||
try:
|
||
if isinstance(value, str):
|
||
value = value.strip().lower()
|
||
value = re.sub(r'[\s,]', '', value)
|
||
value = re.sub(r'[^\d.-]', '', value)
|
||
return float(value)
|
||
except (ValueError, TypeError):
|
||
return default
|
||
|
||
|
||
def get_timestamp() -> str:
|
||
"""
|
||
获取当前时间戳字符串。
|
||
|
||
Returns:
|
||
str: 格式化的时间戳
|
||
"""
|
||
return datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
|
||
|
||
def get_file_timestamp() -> str:
|
||
"""
|
||
获取适合文件名的当前时间戳字符串。
|
||
|
||
Returns:
|
||
str: 格式化的文件名时间戳
|
||
"""
|
||
return datetime.now().strftime('%Y%m%d_%H%M%S')
|
||
|
||
|
||
def read_file_lines(filepath: str, max_lines: int = 1000) -> List[str]:
|
||
"""
|
||
安全地读取文件内容。
|
||
|
||
Args:
|
||
filepath: 文件路径
|
||
max_lines: 最大读取行数
|
||
|
||
Returns:
|
||
List[str]: 文件行列表
|
||
"""
|
||
try:
|
||
with open(filepath, 'r', encoding='utf-8', errors='ignore') as f:
|
||
lines = []
|
||
for i, line in enumerate(f):
|
||
if i >= max_lines:
|
||
break
|
||
lines.append(line.rstrip('\n'))
|
||
return lines
|
||
except (IOError, OSError) as e:
|
||
logging.getLogger(__name__).warning(f"Failed to read file {filepath}: {e}")
|
||
return []
|
||
|
||
|
||
def check_command_exists(command: str) -> bool:
|
||
"""
|
||
检查命令是否存在。
|
||
|
||
Args:
|
||
command: 命令名称
|
||
|
||
Returns:
|
||
bool: 命令是否存在
|
||
"""
|
||
try:
|
||
# Python 3.6 compatible version
|
||
subprocess.run(
|
||
['which', command],
|
||
stdout=subprocess.DEVNULL,
|
||
stderr=subprocess.DEVNULL,
|
||
check=True
|
||
)
|
||
return True
|
||
except (subprocess.CalledProcessError, FileNotFoundError):
|
||
return False
|
||
|
||
|
||
def format_bytes(size_bytes: int) -> str:
|
||
"""
|
||
将字节数格式化为人类可读的字符串。
|
||
|
||
Args:
|
||
size_bytes: 字节数
|
||
|
||
Returns:
|
||
str: 格式化后的字符串
|
||
"""
|
||
if size_bytes == 0:
|
||
return "0 B"
|
||
|
||
units = ['B', 'KB', 'MB', 'GB', 'TB', 'PB']
|
||
unit_index = 0
|
||
size = float(size_bytes)
|
||
|
||
while size >= 1024 and unit_index < len(units) - 1:
|
||
size /= 1024
|
||
unit_index += 1
|
||
|
||
return f"{size:.2f} {units[unit_index]}"
|
||
|
||
|
||
def sanitize_filename(filename: str) -> str:
|
||
"""
|
||
清理文件名,移除不安全字符。
|
||
|
||
Args:
|
||
filename: 原始文件名
|
||
|
||
Returns:
|
||
str: 清理后的文件名
|
||
"""
|
||
# 移除或替换不安全字符
|
||
filename = re.sub(r'[<>:"/\\|?*]', '_', filename)
|
||
filename = filename.strip('. ')
|
||
return filename
|
||
|
||
|
||
def merge_dicts(base: Dict[str, Any], update: Dict[str, Any]) -> Dict[str, Any]:
|
||
"""
|
||
递归合并两个字典。
|
||
|
||
Args:
|
||
base: 基础字典
|
||
update: 更新字典
|
||
|
||
Returns:
|
||
Dict[str, Any]: 合并后的字典
|
||
"""
|
||
result = base.copy()
|
||
for key, value in update.items():
|
||
if key in result and isinstance(result[key], dict) and isinstance(value, dict):
|
||
result[key] = merge_dicts(result[key], value)
|
||
else:
|
||
result[key] = value
|
||
return result
|