# IMU data spectral-analysis tool: batch-processes CSV files and renders
# time-series plots, "rainfall" spectrograms, and an HTML report per file.
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib
|
||
|
||
matplotlib.use('Agg')
|
||
import matplotlib.pyplot as plt
|
||
from scipy import signal
|
||
import os
|
||
import glob
|
||
from datetime import datetime
|
||
import time
|
||
from multiprocessing import Pool, cpu_count
|
||
from matplotlib.colors import Normalize
|
||
from matplotlib.ticker import MaxNLocator
|
||
import re
|
||
from colorama import Fore, Style, init
|
||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||
import warnings
|
||
import threading
|
||
|
||
# Initialize colorama (autoreset restores the default color after each print)
init(autoreset=True)

# Silence specific matplotlib warnings that are noise in batch runs
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
warnings.filterwarnings("ignore", category=FutureWarning, module="matplotlib")

# Locks that keep file operations and log output thread-safe
file_lock = threading.Lock()
log_lock = threading.Lock()
||
|
||
|
||
class IMUDataAnalyzer:
|
||
    def __init__(self, file_path):
        """Set up the analyzer for one CSV file.

        Infers the data type and default sampling rate from the file name,
        creates a timestamped output directory next to the input file, and
        configures matplotlib for headless, CJK-capable rendering.

        Parameters
        ----------
        file_path : str
            Path to the IMU CSV file to analyze.
        """
        self.file_path = file_path
        self.data = None             # pandas DataFrame, populated by load_data()
        self.sampling_rate = None    # estimated later in load_data()
        self.fig_size = (15, 10)
        self.spectrogram_params = {}  # spectrogram parameters, shown in the HTML report

        # Infer data type and default sampling rate from the file name.
        file_name = os.path.basename(file_path).lower()
        if 'calib' in file_name:
            self.data_type = 'calib'
            self.default_sampling_rate = 5
        elif 'raw' in file_name:
            self.data_type = 'raw'
            self.default_sampling_rate = 1000
        else:
            self.data_type = 'unknown'
            self.default_sampling_rate = 5

        # Resolve the input location and build a unique output directory name.
        file_dir = os.path.dirname(os.path.abspath(file_path))
        file_base_name = os.path.splitext(os.path.basename(file_path))[0]
        self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

        # Output directory: "<input name>_output_<timestamp>", beside the input file.
        self.output_dir = os.path.join(file_dir, f"{file_base_name}_output_{self.timestamp}")

        # Lock guards directory creation when several workers start at once.
        with file_lock:
            if not os.path.exists(self.output_dir):
                os.makedirs(self.output_dir)
                self.log_progress(f"创建输出目录:{self.output_dir}", "INFO")

        # Font setup so CJK labels render; unicode_minus avoids broken minus signs.
        plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'Arial']
        plt.rcParams['axes.unicode_minus'] = False

        # Compatibility options to avoid layout-engine conflicts between figures.
        plt.rcParams['figure.constrained_layout.use'] = False
        plt.rcParams['figure.constrained_layout.h_pad'] = 0.02
        plt.rcParams['figure.constrained_layout.w_pad'] = 0.02
        plt.rcParams['figure.constrained_layout.hspace'] = 0.02
        plt.rcParams['figure.constrained_layout.wspace'] = 0.02

        self.log_progress(f"处理文件:{self.file_path}", "INFO")
        self.log_progress(f"数据类型:{self.data_type}", "INFO")
        self.log_progress(f"输出路径:{self.output_dir}", "INFO")
||
def log_progress(self, message, level="INFO"):
|
||
"""带颜色和级别的日志输出(线程安全)"""
|
||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||
with log_lock:
|
||
if level == "INFO":
|
||
print(f"{Fore.CYAN}[{timestamp}] {Fore.GREEN}{message}")
|
||
elif level == "WARNING":
|
||
print(f"{Fore.CYAN}[{timestamp}] {Fore.YELLOW}警告: {message}")
|
||
elif level == "ERROR":
|
||
print(f"{Fore.CYAN}[{timestamp}] {Fore.RED}错误: {message}")
|
||
elif level == "SUCCESS":
|
||
print(f"{Fore.CYAN}[{timestamp}] {Fore.GREEN}✓ {message}")
|
||
else:
|
||
print(f"{Fore.CYAN}[{timestamp}] {message}")
|
||
|
||
def check_imu_columns_in_file(self):
|
||
"""检查文件是否包含IMU数据列(通过读取文件头)"""
|
||
try:
|
||
# 只读取第一行来检查列名
|
||
with open(self.file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||
first_line = f.readline().strip()
|
||
|
||
# 检查第一行是否包含imu关键词(不区分大小写)
|
||
if re.search(r'imu', first_line, re.IGNORECASE):
|
||
return True
|
||
else:
|
||
self.log_progress(f"文件头部不包含'imu'关键词,跳过处理,first_line {first_line}", "WARNING")
|
||
return False
|
||
|
||
except Exception as e:
|
||
self.log_progress(f"检查文件头部时出错: {str(e)}", "ERROR")
|
||
return False
|
||
|
||
def detect_imu_columns(self):
|
||
"""自动检测IMU数据列"""
|
||
all_columns = self.data.columns.tolist()
|
||
|
||
# 查找imu前缀(如imu1, imu2等)
|
||
imu_prefixes = set()
|
||
for col in all_columns:
|
||
match = re.match(r'^(imu\d+)_', col.lower())
|
||
if match:
|
||
imu_prefixes.add(match.group(1))
|
||
|
||
if not imu_prefixes:
|
||
self.log_progress("未检测到IMU数据列,尝试使用默认列名", "WARNING")
|
||
# 尝试使用常见列名
|
||
self.acc_columns = ['imu1_acc_x', 'imu1_acc_y', 'imu1_acc_z']
|
||
self.gyro_columns = ['imu1_gyro_x', 'imu1_gyro_y', 'imu1_gyro_z']
|
||
self.temp_columns = ['imu1_temp']
|
||
return
|
||
|
||
# 使用第一个检测到的IMU前缀
|
||
imu_prefix = list(imu_prefixes)[0]
|
||
self.log_progress(f"检测到IMU前缀: {imu_prefix}", "INFO")
|
||
|
||
# 查找加速度计列
|
||
self.acc_columns = [col for col in all_columns
|
||
if col.lower().startswith(f"{imu_prefix}_acc") and
|
||
any(axis in col.lower() for axis in ['_x', '_y', '_z'])]
|
||
|
||
# 查找陀螺仪列
|
||
self.gyro_columns = [col for col in all_columns
|
||
if col.lower().startswith(f"{imu_prefix}_gyro") and
|
||
any(axis in col.lower() for axis in ['_x', '_y', '_z'])]
|
||
|
||
# 查找温度列
|
||
self.temp_columns = [col for col in all_columns
|
||
if col.lower().startswith(f"{imu_prefix}_temp")]
|
||
|
||
# 如果没有找到温度列,尝试其他常见名称
|
||
if not self.temp_columns:
|
||
self.temp_columns = [col for col in all_columns
|
||
if any(name in col.lower() for name in ['temp', 'temperature'])]
|
||
|
||
self.log_progress(f"加速度计列: {self.acc_columns}", "INFO")
|
||
self.log_progress(f"陀螺仪列: {self.gyro_columns}", "INFO")
|
||
self.log_progress(f"温度列: {self.temp_columns}", "INFO")
|
||
|
||
def estimate_sampling_rate(self):
|
||
"""估计实际采样率"""
|
||
if 'time' in self.data.columns and len(self.data) > 10:
|
||
time_diff = np.diff(self.data['time'].values)
|
||
valid_diffs = time_diff[(time_diff > 0) & (time_diff < 10)] # 排除异常值
|
||
if len(valid_diffs) > 0:
|
||
estimated_rate = 1.0 / np.median(valid_diffs)
|
||
self.log_progress(f"根据时间戳估计的采样率: {estimated_rate:.2f} Hz")
|
||
return estimated_rate
|
||
|
||
# 如果没有时间列或无法估计,使用基于文件名的默认值
|
||
self.log_progress(f"使用基于文件名的默认采样率: {self.default_sampling_rate} Hz")
|
||
return self.default_sampling_rate
|
||
|
||
    def load_data(self):
        """Load the CSV into self.data and preprocess it.

        Verifies the header contains IMU columns, reads the CSV, detects the
        sensor columns, estimates the sampling rate, and rebuilds the 'time'
        column as a uniform axis derived from that rate.

        Raises
        ------
        ValueError
            If the file header does not contain IMU columns.
        """
        self.log_progress("开始加载数据...")
        start_time = time.time()

        # Bail out early if the header has no IMU columns.
        if not self.check_imu_columns_in_file():
            raise ValueError("文件不包含IMU数据列,跳过处理")

        # Lock to keep file reads thread-safe.
        with file_lock:
            self.data = pd.read_csv(self.file_path)

        self.log_progress(f"数据加载完成,共 {len(self.data)} 行,耗时 {time.time() - start_time:.2f}秒")

        # Detect the IMU sensor columns.
        self.detect_imu_columns()

        # Estimate the sampling rate (must happen before the time axis rebuild).
        self.sampling_rate = self.estimate_sampling_rate()

        # Build the time axis: drop rows with abnormal timestamps, then
        # replace 'time' with a uniform axis based on the sampling rate.
        if 'time' in self.data.columns:
            valid_time_mask = (self.data['time'] > 0) & (self.data['time'] < 1e6)
            self.data = self.data[valid_time_mask].copy()
            self.data['time'] = np.arange(len(self.data)) / self.sampling_rate
        else:
            # No time column at all: synthesize one from the sampling rate.
            self.data['time'] = np.arange(len(self.data)) / self.sampling_rate
||
def remove_dc(self, signal_data):
|
||
"""不移除直流分量(保留以在频谱中显示 DC)"""
|
||
return signal_data
|
||
|
||
    def compute_spectrogram(self, signal_data):
        """Compute a PSD spectrogram (DC retained) with adaptive parameters.

        Window length and overlap are chosen from the sampling rate and the
        signal length to balance time vs. frequency resolution, and the result
        is lightly Gaussian-smoothed to reduce graininess in the plots. The
        chosen parameters are stored in self.spectrogram_params for the report.

        Returns
        -------
        (f, t, Sxx) : frequency bins, segment times, and the (possibly
        smoothed) power-spectral-density matrix from scipy.signal.spectrogram.
        """
        # Keep the DC component (remove_dc is a pass-through).
        signal_data = self.remove_dc(signal_data)

        # Signal length drives the adaptive window choice below.
        n_samples = len(signal_data)

        # Adapt parameters to sampling rate and data length.
        if self.sampling_rate <= 10:  # low sampling rate (~5 Hz)
            # Longer windows give better frequency resolution at low rates.
            nperseg = min(256, max(64, n_samples // 2))
            noverlap = int(nperseg * 0.75)  # higher overlap ratio

        else:  # high sampling rate (~1000 Hz)
            # Balance time and frequency resolution.
            if n_samples < 10000:  # shorter recordings
                nperseg = min(512, max(256, n_samples // 4))
            else:  # longer recordings
                nperseg = min(1024, max(512, n_samples // 8))

            noverlap = int(nperseg * 0.66)  # moderate overlap ratio

        # Clamp to sane values: 16 <= nperseg <= n_samples, noverlap < nperseg.
        nperseg = max(16, min(nperseg, n_samples))
        noverlap = min(noverlap, nperseg - 1)

        # Record the parameters for the HTML report.
        self.spectrogram_params = {
            "nperseg": nperseg,
            "noverlap": noverlap,
            "window": "hamming",
            "detrend": False,
            "scaling": "density",
            "mode": "psd"
        }

        # Hamming window: smoother sidelobe behavior.
        f, t, Sxx = signal.spectrogram(
            signal_data,
            fs=self.sampling_rate,
            window='hamming',
            nperseg=nperseg,
            noverlap=noverlap,
            scaling='density',
            detrend=False,  # keep DC
            mode='psd'
        )

        # Light smoothing to reduce graininess in the rendered image.
        if Sxx.size > 0:
            # Small-sigma Gaussian filter (optional polish step).
            from scipy.ndimage import gaussian_filter
            Sxx_smoothed = gaussian_filter(Sxx, sigma=0.7)
            return f, t, Sxx_smoothed

        return f, t, Sxx
|
||
def process_signal(self, args):
|
||
"""并行处理单个信号"""
|
||
signal_data, axis = args
|
||
f, t, Sxx = self.compute_spectrogram(signal_data)
|
||
|
||
# 防止 log10(0)
|
||
eps = np.finfo(float).eps
|
||
Sxx_log = 10 * np.log10(Sxx + eps)
|
||
|
||
# 降采样以加速绘图
|
||
if len(t) > 1000: # 如果时间点太多,进行降采样
|
||
time_indices = np.linspace(0, len(t) - 1, 1000, dtype=int)
|
||
freq_indices = np.linspace(0, len(f) - 1, 500, dtype=int)
|
||
t = t[time_indices]
|
||
f = f[freq_indices]
|
||
Sxx_log = Sxx_log[freq_indices, :][:, time_indices]
|
||
dc_idx = int(np.argmin(np.abs(f - 0.0)))
|
||
dc_log = Sxx_log[dc_idx, :] # shape: (len(t),)
|
||
|
||
# 更健壮的 0 Hz 索引选择
|
||
zero_idx = np.where(np.isclose(f, 0.0))[0]
|
||
dc_idx = int(zero_idx[0]) if len(zero_idx) > 0 else int(np.argmin(np.abs(f - 0.0)))
|
||
dc_log = Sxx_log[dc_idx, :] # 每个时间窗的 0 Hz PSD(dB)
|
||
|
||
return {
|
||
'f': f,
|
||
't': t,
|
||
'Sxx_log': Sxx_log,
|
||
'dc_log': dc_log,
|
||
'axis': axis
|
||
}
|
||
|
||
@staticmethod
|
||
def robust_dc_ylim(results, p_low=5, p_high=95, pad_ratio=0.05, fallback=(0.0, 1.0)):
|
||
"""
|
||
计算统一 DC 纵轴范围(分位数 + 少许边距),并过滤 inf/NaN
|
||
"""
|
||
if not results:
|
||
return fallback
|
||
dc_all = np.concatenate([r['dc_log'].ravel() for r in results])
|
||
dc_all = dc_all[np.isfinite(dc_all)]
|
||
if dc_all.size == 0:
|
||
return fallback
|
||
lo, hi = np.percentile(dc_all, [p_low, p_high])
|
||
span = max(1e-9, hi - lo)
|
||
lo -= span * pad_ratio
|
||
hi += span * pad_ratio
|
||
return lo, hi
|
||
|
||
def get_time_domain_stats(self):
|
||
"""计算时域信号的统计信息"""
|
||
stats = {}
|
||
if self.acc_columns:
|
||
stats['加速度计'] = {col: {
|
||
'均值': self.data[col].mean(),
|
||
'标准差': self.data[col].std(),
|
||
'最大值': self.data[col].max(),
|
||
'最小值': self.data[col].min()
|
||
} for col in self.acc_columns}
|
||
if self.gyro_columns:
|
||
stats['陀螺仪'] = {col: {
|
||
'均值': self.data[col].mean(),
|
||
'标准差': self.data[col].std(),
|
||
'最大值': self.data[col].max(),
|
||
'最小值': self.data[col].min()
|
||
} for col in self.gyro_columns}
|
||
if self.temp_columns:
|
||
stats['温度'] = {col: {
|
||
'均值': self.data[col].mean(),
|
||
'标准差': self.data[col].std(),
|
||
'最大值': self.data[col].max(),
|
||
'最小值': self.data[col].min()
|
||
} for col in self.temp_columns}
|
||
return stats
|
||
|
||
    def generate_html_report(self, time_domain_stats):
        """Write a standalone HTML report into the output directory.

        The report embeds the time-domain statistics tables, the spectrogram
        parameters from self.spectrogram_params, and links to the PNGs
        produced by the plotting methods (same timestamp suffix).

        Parameters
        ----------
        time_domain_stats : dict
            Output of get_time_domain_stats(): sensor type -> column -> stats.
        """
        html_content = f"""
        <!DOCTYPE html>
        <html lang="zh-CN">
        <head>
            <meta charset="UTF-8">
            <meta name="viewport" content="width=device-width, initial-scale=1.0">
            <title>IMU数据分析报告 - {os.path.basename(self.file_path)}</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                h1, h2, h3 {{ color: #333; }}
                table {{ border-collapse: collapse; width: 100%; margin-bottom: 20px; }}
                th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
                th {{ background-color: #f2f2f2; }}
                img {{ max-width: 100%; height: auto; display: block; margin: 10px 0; }}
            </style>
        </head>
        <body>
            <h1>IMU数据分析报告</h1>
            <p><strong>文件路径:</strong> {self.file_path}</p>
            <p><strong>分析时间:</strong> {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}</p>
            <p><strong>采样率:</strong> {self.sampling_rate} Hz</p>

            <h2>时域信号统计信息</h2>
        """

        # One table per sensor type with the per-column statistics.
        for sensor_type, sensors in time_domain_stats.items():
            html_content += f"<h3>{sensor_type}</h3>"
            html_content += "<table>"
            html_content += "<tr><th>传感器</th><th>均值</th><th>标准差</th><th>最大值</th><th>最小值</th></tr>"
            for col, stats in sensors.items():
                html_content += f"<tr><td>{col}</td><td>{stats['均值']:.4f}</td><td>{stats['标准差']:.4f}</td><td>{stats['最大值']:.4f}</td><td>{stats['最小值']:.4f}</td></tr>"
            html_content += "</table>"

        # Spectrogram parameter table (filled by compute_spectrogram).
        html_content += """
            <h2>频域信号计算参数</h2>
            <table>
                <tr><th>参数</th><th>值</th></tr>
        """
        for key, value in self.spectrogram_params.items():
            html_content += f"<tr><td>{key}</td><td>{value}</td></tr>"
        html_content += "</table>"

        # Image links — file names must match those used by the plot methods.
        time_series_image = f'time_series_{self.timestamp}.png'
        acc_spectrogram_image = f'acc_rainfall_spectrogram_{self.timestamp}.png'
        gyro_spectrogram_image = f'gyro_rainfall_spectrogram_{self.timestamp}.png'

        html_content += f"""
            <h2>时域信号图</h2>
            <img src="{time_series_image}" alt="时域信号图">

            <h2>加速度计频谱雨点图</h2>
            <img src="{acc_spectrogram_image}" alt="加速度计频谱雨点图">

            <h2>陀螺仪频谱雨点图</h2>
            <img src="{gyro_spectrogram_image}" alt="陀螺仪频谱雨点图">
        """

        html_content += """
        </body>
        </html>
        """

        # Save the HTML report next to the generated images.
        report_path = os.path.join(self.output_dir, f'report_{self.timestamp}.html')
        with open(report_path, 'w', encoding='utf-8') as f:
            f.write(html_content)

        self.log_progress(f"HTML报告已生成: {report_path}")
|
||
def plot_time_series(self):
|
||
"""绘制时间序列图"""
|
||
self.log_progress("开始绘制时间序列图...")
|
||
start_time = time.time()
|
||
|
||
# 确定子图数量
|
||
n_plots = 1 # 至少有一个加速度图
|
||
if self.gyro_columns: # 如果有陀螺仪数据
|
||
n_plots += 1
|
||
if self.temp_columns: # 如果有温度数据
|
||
n_plots += 1
|
||
|
||
fig, axes = plt.subplots(n_plots, 1, figsize=(12, 3 * n_plots), dpi=120)
|
||
if n_plots == 1:
|
||
axes = [axes] # 确保axes是列表
|
||
|
||
plot_idx = 0
|
||
|
||
# 加速度计数据
|
||
if self.acc_columns:
|
||
ax = axes[plot_idx]
|
||
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
|
||
labels = ['X', 'Y', 'Z']
|
||
for i, col in enumerate(self.acc_columns):
|
||
if i < 3: # 只绘制前三个轴
|
||
ax.plot(self.data['time'], self.data[col],
|
||
label=labels[i], color=colors[i], linewidth=1.0, alpha=0.8)
|
||
ax.set_title('加速度时间序列', fontsize=12)
|
||
ax.set_ylabel('加速度 (g)', fontsize=10)
|
||
ax.legend(loc='upper right', fontsize=8, framealpha=0.5)
|
||
ax.grid(True, linestyle=':', alpha=0.5)
|
||
ax.set_xlim(0, self.data['time'].max())
|
||
plot_idx += 1
|
||
|
||
# 陀螺仪数据(如果有)
|
||
if self.gyro_columns and plot_idx < n_plots:
|
||
ax = axes[plot_idx]
|
||
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
|
||
labels = ['X', 'Y', 'Z']
|
||
for i, col in enumerate(self.gyro_columns):
|
||
if i < 3: # 只绘制前三个轴
|
||
ax.plot(self.data['time'], self.data[col],
|
||
label=labels[i], color=colors[i], linewidth=1.0, alpha=0.8)
|
||
ax.set_title('陀螺仪时间序列', fontsize=12)
|
||
ax.set_ylabel('角速度 (deg/s)', fontsize=10)
|
||
ax.legend(loc='upper left', fontsize=8, framealpha=0.5)
|
||
ax.grid(True, linestyle=':', alpha=0.5)
|
||
ax.set_xlim(0, self.data['time'].max())
|
||
plot_idx += 1
|
||
|
||
# 温度数据(如果有)
|
||
if self.temp_columns and plot_idx < n_plots:
|
||
ax = axes[plot_idx]
|
||
ax.plot(self.data['time'], self.data[self.temp_columns[0]],
|
||
label='温度', color='#9467bd', linewidth=1.0, alpha=0.8)
|
||
ax.set_title('温度时间序列', fontsize=12)
|
||
ax.set_xlabel('时间 (s)', fontsize=10)
|
||
ax.set_ylabel('温度 (°C)', fontsize=10)
|
||
ax.legend(loc='upper right', fontsize=8, framealpha=0.5)
|
||
ax.grid(True, linestyle=':', alpha=0.5)
|
||
ax.set_xlim(0, self.data['time'].max())
|
||
|
||
plt.tight_layout()
|
||
output_path = os.path.join(self.output_dir, f'time_series_{self.timestamp}.png')
|
||
plt.savefig(output_path, bbox_inches='tight', dpi=150)
|
||
plt.close(fig)
|
||
self.log_progress(f"时间序列图已保存: {output_path}")
|
||
self.log_progress(f"时间序列图已保存为 {output_path},耗时 {time.time() - start_time:.2f}秒")
|
||
|
||
    def plot_rainfall_spectrograms(self):
        """Compute and render all "rainfall" spectrograms in parallel.

        Spectrograms for up to three accelerometer and three gyroscope axes
        are computed in worker processes, then each sensor type is drawn with
        a shared color scale (5th–95th percentile of the log PSD) and a
        shared DC y-axis range. Fixes a previous colorbar layout issue by
        delegating to _plot_single_spectrogram's gridspec layout.
        """
        self.log_progress("开始并行绘制频谱雨点图...")
        start_time = time.time()

        # Accelerometer signals (first three axes at most).
        self.log_progress("准备加速度计数据...")
        acc_signals = [(self.data[col], f'Acc {["X", "Y", "Z"][i]}')
                       for i, col in enumerate(self.acc_columns) if i < 3]

        # Gyroscope signals, if present (first three axes at most).
        gyro_signals = []
        if self.gyro_columns:
            self.log_progress("准备陀螺仪数据...")
            gyro_signals = [(self.data[col], f'Gyro {["X", "Y", "Z"][i]}')
                            for i, col in enumerate(self.gyro_columns) if i < 3]

        # Nothing to do without any usable signals.
        if not acc_signals and not gyro_signals:
            self.log_progress("警告: 没有有效的数据列可供处理", "WARNING")
            return

        # Worker processes (not threads) keep matplotlib/FFT work isolated.
        self.log_progress("使用多进程并行处理...")
        all_signals = acc_signals + gyro_signals
        with Pool(processes=min(len(all_signals), cpu_count())) as pool:
            results = pool.map(self.process_signal, all_signals)

        # Split results back by sensor type using the axis label prefix.
        self.log_progress("分离结果...")
        acc_results = [r for r in results if r['axis'].startswith('Acc')]
        gyro_results = [r for r in results if r['axis'].startswith('Gyro')]

        # Shared color scale per sensor type (5%-95% percentiles).
        # NOTE: acc_vmin/..._ymax are only defined when acc_results is
        # non-empty; the plotting below is guarded by the same condition.
        if acc_results:
            self.log_progress("计算加速度计全局最小和最大值...")
            acc_all_Sxx = np.concatenate([r['Sxx_log'].ravel() for r in acc_results])
            acc_vmin, acc_vmax = np.percentile(acc_all_Sxx, [5, 95])

            # Shared DC y-axis range across the accelerometer panels.
            acc_dc_ymin, acc_dc_ymax = self.robust_dc_ylim(acc_results)
            self.log_progress(f"加速度 DC (dB) 范围: {acc_dc_ymin:.1f} 到 {acc_dc_ymax:.1f}")

        if gyro_results:
            self.log_progress("计算陀螺仪全局最小和最大值...")
            gyro_all_Sxx = np.concatenate([r['Sxx_log'].ravel() for r in gyro_results])
            gyro_vmin, gyro_vmax = np.percentile(gyro_all_Sxx, [5, 95])

            # Shared DC y-axis range across the gyroscope panels.
            gyro_dc_ymin, gyro_dc_ymax = self.robust_dc_ylim(gyro_results)
            self.log_progress(f"陀螺仪 DC (dB) 范围: {gyro_dc_ymin:.1f} 到 {gyro_dc_ymax:.1f}")

        # ========= Accelerometer rainfall spectrogram =========
        if acc_results:
            self._plot_single_spectrogram(acc_results, acc_vmin, acc_vmax, acc_dc_ymin, acc_dc_ymax,
                                          '加速度', 'acc_rainfall_spectrogram')
            self.log_progress(f"加速度功率谱密度范围: {acc_vmin:.1f} dB 到 {acc_vmax:.1f} dB")

        # ========= Gyroscope rainfall spectrogram =========
        if gyro_results:
            self._plot_single_spectrogram(gyro_results, gyro_vmin, gyro_vmax, gyro_dc_ymin, gyro_dc_ymax,
                                          '角速度', 'gyro_rainfall_spectrogram')
            self.log_progress(f"陀螺仪功率谱密度范围: {gyro_vmin:.1f} dB 到 {gyro_vmax:.1f} dB")

        total_time = time.time() - start_time
        self.log_progress(f"频谱雨点图生成完成,总耗时 {total_time:.2f}秒")
|
||
    def _plot_single_spectrogram(self, results, vmin, vmax, dc_ymin, dc_ymax, title_prefix, filename_prefix):
        """Render one multi-panel "rainfall" spectrogram figure and save it.

        Each panel scatter-plots (time, frequency) points colored by PSD (dB),
        overlays the DC trace on a twin y-axis with a shared range, and gets
        its own colorbar in a dedicated gridspec column (avoids the layout
        conflicts of fig.colorbar stealing axis space).
        """
        rows = len(results)
        fig = plt.figure(constrained_layout=True, figsize=(14, 4 * rows), dpi=150)
        # Two columns: a wide plot area and a narrow colorbar strip per row.
        gs = fig.add_gridspec(nrows=rows, ncols=2, width_ratios=[22, 1], wspace=0.05, hspace=0.12)

        axes_main = []
        axes_cbar = []
        for i in range(rows):
            axes_main.append(fig.add_subplot(gs[i, 0]))
            axes_cbar.append(fig.add_subplot(gs[i, 1]))

        for i, result in enumerate(results):
            ax = axes_main[i]
            cax = axes_cbar[i]

            # repeat/tile expand the (t, f) grid to one point per PSD cell;
            # Sxx_log is (freq, time), so it is transposed before raveling to
            # match the repeat-over-t / tile-over-f point ordering.
            sc = ax.scatter(
                np.repeat(result['t'], len(result['f'])),
                np.tile(result['f'], len(result['t'])),
                c=result['Sxx_log'].T.ravel(),
                cmap='jet',
                s=3,
                alpha=0.7,
                vmin=vmin,
                vmax=vmax,
                rasterized=True  # rasterize the dense scatter to keep files small
            )

            ax.set_title(f'{title_prefix}频谱雨点图 - {result["axis"][-1]}(右侧为DC分量 dB)', fontsize=10)
            ax.set_xlabel('时间 (s)', fontsize=9)
            ax.set_ylabel('频率 (Hz)', fontsize=9)
            ax.set_ylim(0, self.sampling_rate / 2)  # up to the Nyquist frequency
            ax.grid(True, linestyle=':', alpha=0.4)

            # DC trace on a secondary y-axis; range is shared across panels.
            ax2 = ax.twinx()
            ax2.plot(result['t'], result['dc_log'], color='black', linewidth=1.2, alpha=0.85, label='DC (dB)')
            ax2.set_ylabel('直流分量 (dB)', fontsize=9, color='black')
            ax2.set_ylim(dc_ymin, dc_ymax)
            ax2.tick_params(axis='y', labelcolor='black')
            ax2.yaxis.set_major_locator(MaxNLocator(nbins=6))
            ax2.grid(False)
            ax2.legend(loc='upper right', fontsize=8, framealpha=0.5)

            # Per-panel colorbar in its dedicated axis.
            cbar = fig.colorbar(sc, cax=cax)
            cbar.set_label('功率谱密度 (dB)', fontsize=9)
            cax.tick_params(labelsize=8)

        output_path = os.path.join(self.output_dir, f'{filename_prefix}_{self.timestamp}.png')
        plt.savefig(output_path, bbox_inches='tight', dpi=150)
        plt.close(fig)
        self.log_progress(f"{title_prefix}频谱雨点图已保存为 {output_path}")
|
||
def run_analysis(self):
|
||
"""运行完整分析流程"""
|
||
try:
|
||
self.log_progress("开始数据分析流程", "INFO")
|
||
start_time = time.time()
|
||
|
||
self.load_data()
|
||
self.plot_time_series()
|
||
self.plot_rainfall_spectrograms()
|
||
|
||
# 计算时域统计信息
|
||
time_domain_stats = self.get_time_domain_stats()
|
||
|
||
# 生成HTML报告
|
||
self.generate_html_report(time_domain_stats)
|
||
|
||
total_time = time.time() - start_time
|
||
self.log_progress(f"分析完成,总耗时 {total_time:.2f}秒", "SUCCESS")
|
||
self.log_progress(f"所有输出文件已保存到: {self.output_dir}", "INFO")
|
||
return True
|
||
|
||
except ValueError as e:
|
||
# 跳过不包含IMU数据的文件
|
||
self.log_progress(f"跳过文件: {str(e)}", "WARNING")
|
||
return False
|
||
except Exception as e:
|
||
self.log_progress(f"分析过程中出现错误: {str(e)}", "ERROR")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
|
||
def process_single_file(file_path):
    """Run the full analysis for one CSV file in a worker process.

    Returns a (file_path, success, message) tuple so the parent process can
    aggregate results without sharing state.
    """
    try:
        print(f"{Fore.BLUE}开始处理文件: {os.path.basename(file_path)}")
        ok = IMUDataAnalyzer(file_path).run_analysis()
        message = "处理成功" if ok else "文件不包含IMU数据,已跳过"
        return (file_path, ok, message)
    except Exception as e:
        # Any failure is reported back, never allowed to kill the worker.
        return (file_path, False, str(e))
|
||
|
||
def main():
    """Entry point: discover IMU CSV files and process them in parallel.

    Prompts for a path, collects CSV files whose names contain 'imu'
    (case-insensitive), fans them out to worker processes, and prints a
    success/skip/failure summary.

    Fix: the status message claimed "multi-threaded"(多线程/线程) while the
    code actually uses ProcessPoolExecutor worker processes; the wording now
    matches the implementation.
    """
    print("=" * 60)
    print(f"{Fore.CYAN}IMU数据频谱分析工具 - 多文件批量处理")
    print("=" * 60)

    # Ask the user for the input path.
    print(f"{Fore.WHITE}请输入包含CSV文件的目录路径: ")
    input_path = input("> ").strip()

    if not os.path.exists(input_path):
        print(f"{Fore.RED}错误: 路径 '{input_path}' 不存在!")
        return

    # Find all CSV files whose path contains 'imu' (case-insensitive).
    if os.path.isdir(input_path):
        # One recursive glob for every CSV, then filter for 'imu'.
        all_csv_files = glob.glob(os.path.join(input_path, "**", "*.csv"), recursive=True)
        csv_files = [f for f in all_csv_files if re.search(r'imu', f, re.IGNORECASE)]
        csv_files = list(set(csv_files))  # de-duplicate
        csv_files.sort()
    else:
        # Single file: accept it only when its path contains 'imu'.
        if re.search(r'imu', input_path, re.IGNORECASE):
            csv_files = [input_path]
        else:
            csv_files = []

    if not csv_files:
        print(f"{Fore.YELLOW}警告: 未找到包含'imu'的CSV文件")
        return

    print(f"{Fore.GREEN}找到 {len(csv_files)} 个IMU数据文件:")
    for i, file in enumerate(csv_files, 1):
        print(f" {i}. {os.path.basename(file)}")

    # Worker processes avoid matplotlib's thread-safety problems.
    print(f"\n{Fore.CYAN}开始多进程处理文件 (使用 {min(len(csv_files), cpu_count())} 个进程)...")

    success_count = 0
    skipped_count = 0
    failed_files = []

    # ProcessPoolExecutor (not ThreadPoolExecutor): true per-file isolation.
    with ProcessPoolExecutor(max_workers=min(len(csv_files), cpu_count())) as executor:
        # Submit every file up front.
        future_to_file = {executor.submit(process_single_file, file): file for file in csv_files}

        # Collect results as they complete.
        for future in as_completed(future_to_file):
            file_path = future_to_file[future]
            try:
                result = future.result()
                file_path, success, message = result
                if success:
                    print(f"{Fore.GREEN}✓ 完成: {os.path.basename(file_path)}")
                    success_count += 1
                else:
                    if "跳过" in message:
                        print(f"{Fore.YELLOW}↷ 跳过: {os.path.basename(file_path)} - {message}")
                        skipped_count += 1
                    else:
                        print(f"{Fore.RED}✗ 失败: {os.path.basename(file_path)} - {message}")
                        failed_files.append((file_path, message))
            except Exception as e:
                print(f"{Fore.RED}✗ 异常: {os.path.basename(file_path)} - {str(e)}")
                failed_files.append((file_path, str(e)))

    # Final summary.
    print(f"\n{Fore.CYAN}处理完成统计:")
    print(f"{Fore.GREEN}成功: {success_count} 个文件")
    print(f"{Fore.YELLOW}跳过: {skipped_count} 个文件(不包含IMU数据)")
    print(f"{Fore.RED}失败: {len(failed_files)} 个文件")

    if failed_files:
        print(f"\n{Fore.YELLOW}失败文件详情:")
        for file, error in failed_files:
            print(f" {os.path.basename(file)}: {error}")
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run main() and report interrupts or unexpected
    # failures without a raw traceback reaching the user uncolored.
    try:
        main()
    except KeyboardInterrupt:
        print(f"\n{Fore.YELLOW}用户中断程序执行")
    except Exception as e:
        print(f"{Fore.RED}程序运行出错: {str(e)}")
        import traceback

        traceback.print_exc()