Python脚本开发文件初始化
This commit is contained in:
648
FFT_IMU/FFT_IMU_dc_v1.py
Normal file
648
FFT_IMU/FFT_IMU_dc_v1.py
Normal file
@@ -0,0 +1,648 @@
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
import matplotlib
|
||||
|
||||
matplotlib.use('Agg')
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy import signal
|
||||
import os
|
||||
import glob
|
||||
from datetime import datetime
|
||||
import time
|
||||
from multiprocessing import Pool, cpu_count
|
||||
from matplotlib.colors import Normalize
|
||||
from matplotlib.ticker import MaxNLocator
|
||||
import re
|
||||
from colorama import Fore, Style, init
|
||||
from concurrent.futures import ProcessPoolExecutor, as_completed
|
||||
import warnings
|
||||
import threading
|
||||
|
||||
# 初始化colorama
|
||||
init(autoreset=True)
|
||||
|
||||
# 忽略特定的matplotlib警告
|
||||
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
|
||||
warnings.filterwarnings("ignore", category=FutureWarning, module="matplotlib")
|
||||
|
||||
# 创建线程锁,确保文件操作和日志输出的线程安全
|
||||
file_lock = threading.Lock()
|
||||
log_lock = threading.Lock()
|
||||
|
||||
|
||||
class IMUDataAnalyzer:
|
||||
def __init__(self, file_path):
|
||||
self.file_path = file_path
|
||||
self.data = None
|
||||
self.sampling_rate = None
|
||||
self.fig_size = (15, 10)
|
||||
|
||||
# 从文件名推断数据类型和采样率
|
||||
file_name = os.path.basename(file_path).lower()
|
||||
if 'calib' in file_name:
|
||||
self.data_type = 'calib'
|
||||
self.default_sampling_rate = 5
|
||||
elif 'raw' in file_name:
|
||||
self.data_type = 'raw'
|
||||
self.default_sampling_rate = 1000
|
||||
else:
|
||||
self.data_type = 'unknown'
|
||||
self.default_sampling_rate = 5
|
||||
|
||||
# 解析文件路径和文件名
|
||||
file_dir = os.path.dirname(os.path.abspath(file_path))
|
||||
file_base_name = os.path.splitext(os.path.basename(file_path))[0]
|
||||
self.timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
|
||||
# 创建文件名称+时间戳尾缀的输出目录
|
||||
self.output_dir = os.path.join(file_dir, f"{file_base_name}_output_{self.timestamp}")
|
||||
|
||||
# 使用锁确保目录创建的线程安全
|
||||
with file_lock:
|
||||
if not os.path.exists(self.output_dir):
|
||||
os.makedirs(self.output_dir)
|
||||
self.log_progress(f"创建输出目录:{self.output_dir}", "INFO")
|
||||
|
||||
# 字体设置
|
||||
plt.rcParams['font.sans-serif'] = ['Arial Unicode MS', 'SimHei', 'Arial']
|
||||
plt.rcParams['axes.unicode_minus'] = False
|
||||
|
||||
# 设置matplotlib兼容性选项,避免布局引擎冲突
|
||||
plt.rcParams['figure.constrained_layout.use'] = False
|
||||
plt.rcParams['figure.constrained_layout.h_pad'] = 0.02
|
||||
plt.rcParams['figure.constrained_layout.w_pad'] = 0.02
|
||||
plt.rcParams['figure.constrained_layout.hspace'] = 0.02
|
||||
plt.rcParams['figure.constrained_layout.wspace'] = 0.02
|
||||
|
||||
self.log_progress(f"处理文件:{self.file_path}", "INFO")
|
||||
self.log_progress(f"数据类型:{self.data_type}", "INFO")
|
||||
self.log_progress(f"输出路径:{self.output_dir}", "INFO")
|
||||
|
||||
def log_progress(self, message, level="INFO"):
|
||||
"""带颜色和级别的日志输出(线程安全)"""
|
||||
timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
|
||||
with log_lock:
|
||||
if level == "INFO":
|
||||
print(f"{Fore.CYAN}[{timestamp}] {Fore.GREEN}{message}")
|
||||
elif level == "WARNING":
|
||||
print(f"{Fore.CYAN}[{timestamp}] {Fore.YELLOW}警告: {message}")
|
||||
elif level == "ERROR":
|
||||
print(f"{Fore.CYAN}[{timestamp}] {Fore.RED}错误: {message}")
|
||||
elif level == "SUCCESS":
|
||||
print(f"{Fore.CYAN}[{timestamp}] {Fore.GREEN}✓ {message}")
|
||||
else:
|
||||
print(f"{Fore.CYAN}[{timestamp}] {message}")
|
||||
|
||||
def check_imu_columns_in_file(self):
|
||||
"""检查文件是否包含IMU数据列(通过读取文件头)"""
|
||||
try:
|
||||
# 只读取第一行来检查列名
|
||||
with open(self.file_path, 'r', encoding='utf-8', errors='ignore') as f:
|
||||
first_line = f.readline().strip()
|
||||
|
||||
# 检查第一行是否包含imu关键词(不区分大小写)
|
||||
if re.search(r'imu', first_line, re.IGNORECASE):
|
||||
return True
|
||||
else:
|
||||
self.log_progress(f"文件头部不包含'imu'关键词,跳过处理,first_line {first_line}", "WARNING")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
self.log_progress(f"检查文件头部时出错: {str(e)}", "ERROR")
|
||||
return False
|
||||
|
||||
def detect_imu_columns(self):
|
||||
"""自动检测IMU数据列"""
|
||||
all_columns = self.data.columns.tolist()
|
||||
|
||||
# 查找imu前缀(如imu1, imu2等)
|
||||
imu_prefixes = set()
|
||||
for col in all_columns:
|
||||
match = re.match(r'^(imu\d+)_', col.lower())
|
||||
if match:
|
||||
imu_prefixes.add(match.group(1))
|
||||
|
||||
if not imu_prefixes:
|
||||
self.log_progress("未检测到IMU数据列,尝试使用默认列名", "WARNING")
|
||||
# 尝试使用常见列名
|
||||
self.acc_columns = ['imu1_acc_x', 'imu1_acc_y', 'imu1_acc_z']
|
||||
self.gyro_columns = ['imu1_gyro_x', 'imu1_gyro_y', 'imu1_gyro_z']
|
||||
self.temp_columns = ['imu1_temp']
|
||||
return
|
||||
|
||||
# 使用第一个检测到的IMU前缀
|
||||
imu_prefix = list(imu_prefixes)[0]
|
||||
self.log_progress(f"检测到IMU前缀: {imu_prefix}", "INFO")
|
||||
|
||||
# 查找加速度计列
|
||||
self.acc_columns = [col for col in all_columns
|
||||
if col.lower().startswith(f"{imu_prefix}_acc") and
|
||||
any(axis in col.lower() for axis in ['_x', '_y', '_z'])]
|
||||
|
||||
# 查找陀螺仪列
|
||||
self.gyro_columns = [col for col in all_columns
|
||||
if col.lower().startswith(f"{imu_prefix}_gyro") and
|
||||
any(axis in col.lower() for axis in ['_x', '_y', '_z'])]
|
||||
|
||||
# 查找温度列
|
||||
self.temp_columns = [col for col in all_columns
|
||||
if col.lower().startswith(f"{imu_prefix}_temp")]
|
||||
|
||||
# 如果没有找到温度列,尝试其他常见名称
|
||||
if not self.temp_columns:
|
||||
self.temp_columns = [col for col in all_columns
|
||||
if any(name in col.lower() for name in ['temp', 'temperature'])]
|
||||
|
||||
self.log_progress(f"加速度计列: {self.acc_columns}", "INFO")
|
||||
self.log_progress(f"陀螺仪列: {self.gyro_columns}", "INFO")
|
||||
self.log_progress(f"温度列: {self.temp_columns}", "INFO")
|
||||
|
||||
def estimate_sampling_rate(self):
|
||||
"""估计实际采样率"""
|
||||
if 'time' in self.data.columns and len(self.data) > 10:
|
||||
time_diff = np.diff(self.data['time'].values)
|
||||
valid_diffs = time_diff[(time_diff > 0) & (time_diff < 10)] # 排除异常值
|
||||
if len(valid_diffs) > 0:
|
||||
estimated_rate = 1.0 / np.median(valid_diffs)
|
||||
self.log_progress(f"根据时间戳估计的采样率: {estimated_rate:.2f} Hz")
|
||||
return estimated_rate
|
||||
|
||||
# 如果没有时间列或无法估计,使用基于文件名的默认值
|
||||
self.log_progress(f"使用基于文件名的默认采样率: {self.default_sampling_rate} Hz")
|
||||
return self.default_sampling_rate
|
||||
|
||||
def load_data(self):
|
||||
"""加载并预处理数据"""
|
||||
self.log_progress("开始加载数据...")
|
||||
start_time = time.time()
|
||||
|
||||
# 首先检查文件是否包含IMU数据
|
||||
if not self.check_imu_columns_in_file():
|
||||
raise ValueError("文件不包含IMU数据列,跳过处理")
|
||||
|
||||
# 使用锁确保文件读取的线程安全
|
||||
with file_lock:
|
||||
self.data = pd.read_csv(self.file_path)
|
||||
|
||||
self.log_progress(f"数据加载完成,共 {len(self.data)} 行,耗时 {time.time() - start_time:.2f}秒")
|
||||
|
||||
# 检测IMU数据列
|
||||
self.detect_imu_columns()
|
||||
|
||||
# 估计采样率
|
||||
self.sampling_rate = self.estimate_sampling_rate()
|
||||
|
||||
# 创建时间序列并处理异常时间值
|
||||
if 'time' in self.data.columns:
|
||||
valid_time_mask = (self.data['time'] > 0) & (self.data['time'] < 1e6)
|
||||
self.data = self.data[valid_time_mask].copy()
|
||||
self.data['time'] = np.arange(len(self.data)) / self.sampling_rate
|
||||
else:
|
||||
# 如果没有时间列,创建基于采样率的时间序列
|
||||
self.data['time'] = np.arange(len(self.data)) / self.sampling_rate
|
||||
|
||||
def remove_dc(self, signal_data):
|
||||
"""不移除直流分量(保留以在频谱中显示 DC)"""
|
||||
return signal_data
|
||||
|
||||
# def compute_spectrogram(self, signal_data):
|
||||
# """计算频谱图(保留直流分量)"""
|
||||
# # 保留直流分量
|
||||
# signal_data = self.remove_dc(signal_data)
|
||||
#
|
||||
# # 自适应窗口大小 - 根据采样率调整
|
||||
# if self.sampling_rate <= 10: # 低采样率
|
||||
# nperseg = min(64, max(16, len(signal_data) // 4))
|
||||
# else: # 高采样率
|
||||
# nperseg = min(1024, max(64, len(signal_data) // 8))
|
||||
#
|
||||
# noverlap = nperseg // 2
|
||||
#
|
||||
# f, t, Sxx = signal.spectrogram(
|
||||
# signal_data,
|
||||
# fs=self.sampling_rate,
|
||||
# window='hann',
|
||||
# nperseg=nperseg,
|
||||
# noverlap=noverlap,
|
||||
# scaling='density',
|
||||
# detrend=False, # 保留直流
|
||||
# mode='psd' # 更高效的模式
|
||||
# )
|
||||
# return f, t, Sxx
|
||||
|
||||
def compute_spectrogram(self, signal_data):
|
||||
"""计算频谱图(保留直流分量),优化频谱分辨率和减少颗粒感"""
|
||||
# 保留直流分量
|
||||
signal_data = self.remove_dc(signal_data)
|
||||
|
||||
# 数据长度
|
||||
n_samples = len(signal_data)
|
||||
|
||||
# 根据采样率和数据长度自适应选择参数
|
||||
if self.sampling_rate <= 10: # 低采样率(5Hz)
|
||||
# 对于低采样率,使用较长的窗口以获得更好的频率分辨率
|
||||
nperseg = min(256, max(64, n_samples // 2))
|
||||
noverlap = int(nperseg * 0.75) # 增加重叠比例
|
||||
|
||||
else: # 高采样率(1000Hz)
|
||||
# 对于高采样率,平衡时间分辨率和频率分辨率
|
||||
if n_samples < 10000: # 较短的数据
|
||||
nperseg = min(512, max(256, n_samples // 4))
|
||||
else: # 较长的数据
|
||||
nperseg = min(1024, max(512, n_samples // 8))
|
||||
|
||||
noverlap = int(nperseg * 0.66) # 适中的重叠比例
|
||||
|
||||
# 确保窗口大小合理
|
||||
nperseg = max(16, min(nperseg, n_samples))
|
||||
noverlap = min(noverlap, nperseg - 1)
|
||||
|
||||
# 使用更平滑的窗口函数
|
||||
f, t, Sxx = signal.spectrogram(
|
||||
signal_data,
|
||||
fs=self.sampling_rate,
|
||||
window='hamming', # 使用汉明窗,比汉宁窗更平滑
|
||||
nperseg=nperseg,
|
||||
noverlap=noverlap,
|
||||
scaling='density',
|
||||
# detrend='linear', # 使用线性去趋势,减少低频干扰
|
||||
detrend=False, # 保留直流
|
||||
mode='psd'
|
||||
)
|
||||
|
||||
# 应用平滑处理以减少颗粒感
|
||||
if Sxx.size > 0:
|
||||
# 使用小范围的高斯滤波平滑(可选)
|
||||
from scipy.ndimage import gaussian_filter
|
||||
Sxx_smoothed = gaussian_filter(Sxx, sigma=0.7)
|
||||
return f, t, Sxx_smoothed
|
||||
|
||||
return f, t, Sxx
|
||||
|
||||
def process_signal(self, args):
|
||||
"""并行处理单个信号"""
|
||||
signal_data, axis = args
|
||||
f, t, Sxx = self.compute_spectrogram(signal_data)
|
||||
|
||||
# 防止 log10(0)
|
||||
eps = np.finfo(float).eps
|
||||
Sxx_log = 10 * np.log10(Sxx + eps)
|
||||
|
||||
# 降采样以加速绘图
|
||||
if len(t) > 1000: # 如果时间点太多,进行降采样
|
||||
time_indices = np.linspace(0, len(t) - 1, 1000, dtype=int)
|
||||
freq_indices = np.linspace(0, len(f) - 1, 500, dtype=int)
|
||||
t = t[time_indices]
|
||||
f = f[freq_indices]
|
||||
Sxx_log = Sxx_log[freq_indices, :][:, time_indices]
|
||||
dc_idx = int(np.argmin(np.abs(f - 0.0)))
|
||||
dc_log = Sxx_log[dc_idx, :] # shape: (len(t),)
|
||||
|
||||
# 更健壮的 0 Hz 索引选择
|
||||
zero_idx = np.where(np.isclose(f, 0.0))[0]
|
||||
dc_idx = int(zero_idx[0]) if len(zero_idx) > 0 else int(np.argmin(np.abs(f - 0.0)))
|
||||
dc_log = Sxx_log[dc_idx, :] # 每个时间窗的 0 Hz PSD(dB)
|
||||
|
||||
return {
|
||||
'f': f,
|
||||
't': t,
|
||||
'Sxx_log': Sxx_log,
|
||||
'dc_log': dc_log,
|
||||
'axis': axis
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def robust_dc_ylim(results, p_low=5, p_high=95, pad_ratio=0.05, fallback=(0.0, 1.0)):
|
||||
"""
|
||||
计算统一 DC 纵轴范围(分位数 + 少许边距),并过滤 inf/NaN
|
||||
"""
|
||||
if not results:
|
||||
return fallback
|
||||
dc_all = np.concatenate([r['dc_log'].ravel() for r in results])
|
||||
dc_all = dc_all[np.isfinite(dc_all)]
|
||||
if dc_all.size == 0:
|
||||
return fallback
|
||||
lo, hi = np.percentile(dc_all, [p_low, p_high])
|
||||
span = max(1e-9, hi - lo)
|
||||
lo -= span * pad_ratio
|
||||
hi += span * pad_ratio
|
||||
return lo, hi
|
||||
|
||||
def plot_time_series(self):
|
||||
"""绘制时间序列图"""
|
||||
self.log_progress("开始绘制时间序列图...")
|
||||
start_time = time.time()
|
||||
|
||||
# 确定子图数量
|
||||
n_plots = 1 # 至少有一个加速度图
|
||||
if self.gyro_columns: # 如果有陀螺仪数据
|
||||
n_plots += 1
|
||||
if self.temp_columns: # 如果有温度数据
|
||||
n_plots += 1
|
||||
|
||||
fig, axes = plt.subplots(n_plots, 1, figsize=(12, 3 * n_plots), dpi=120)
|
||||
if n_plots == 1:
|
||||
axes = [axes] # 确保axes是列表
|
||||
|
||||
plot_idx = 0
|
||||
|
||||
# 加速度计数据
|
||||
if self.acc_columns:
|
||||
ax = axes[plot_idx]
|
||||
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
|
||||
labels = ['X', 'Y', 'Z']
|
||||
for i, col in enumerate(self.acc_columns):
|
||||
if i < 3: # 只绘制前三个轴
|
||||
ax.plot(self.data['time'], self.data[col],
|
||||
label=labels[i], color=colors[i], linewidth=1.0, alpha=0.8)
|
||||
ax.set_title('加速度时间序列', fontsize=12)
|
||||
ax.set_ylabel('加速度 (g)', fontsize=10)
|
||||
ax.legend(loc='upper right', fontsize=8, framealpha=0.5)
|
||||
ax.grid(True, linestyle=':', alpha=0.5)
|
||||
ax.set_xlim(0, self.data['time'].max())
|
||||
plot_idx += 1
|
||||
|
||||
# 陀螺仪数据(如果有)
|
||||
if self.gyro_columns and plot_idx < n_plots:
|
||||
ax = axes[plot_idx]
|
||||
colors = ['#1f77b4', '#ff7f0e', '#2ca02c']
|
||||
labels = ['X', 'Y', 'Z']
|
||||
for i, col in enumerate(self.gyro_columns):
|
||||
if i < 3: # 只绘制前三个轴
|
||||
ax.plot(self.data['time'], self.data[col],
|
||||
label=labels[i], color=colors[i], linewidth=1.0, alpha=0.8)
|
||||
ax.set_title('陀螺仪时间序列', fontsize=12)
|
||||
ax.set_ylabel('角速度 (deg/s)', fontsize=10)
|
||||
ax.legend(loc='upper left', fontsize=8, framealpha=0.5)
|
||||
ax.grid(True, linestyle=':', alpha=0.5)
|
||||
ax.set_xlim(0, self.data['time'].max())
|
||||
plot_idx += 1
|
||||
|
||||
# 温度数据(如果有)
|
||||
if self.temp_columns and plot_idx < n_plots:
|
||||
ax = axes[plot_idx]
|
||||
ax.plot(self.data['time'], self.data[self.temp_columns[0]],
|
||||
label='温度', color='#9467bd', linewidth=1.0, alpha=0.8)
|
||||
ax.set_title('温度时间序列', fontsize=12)
|
||||
ax.set_xlabel('时间 (s)', fontsize=10)
|
||||
ax.set_ylabel('温度 (°C)', fontsize=10)
|
||||
ax.legend(loc='upper right', fontsize=8, framealpha=0.5)
|
||||
ax.grid(True, linestyle=':', alpha=0.5)
|
||||
ax.set_xlim(0, self.data['time'].max())
|
||||
|
||||
plt.tight_layout()
|
||||
output_path = os.path.join(self.output_dir, f'time_series_{self.timestamp}.png')
|
||||
plt.savefig(output_path, bbox_inches='tight', dpi=150)
|
||||
plt.close(fig)
|
||||
self.log_progress(f"时间序列图已保存: {output_path}")
|
||||
self.log_progress(f"时间序列图已保存为 {output_path},耗时 {time.time() - start_time:.2f}秒")
|
||||
|
||||
def plot_rainfall_spectrograms(self):
|
||||
"""并行绘制所有频谱雨点图(修复colorbar布局问题)"""
|
||||
self.log_progress("开始并行绘制频谱雨点图...")
|
||||
start_time = time.time()
|
||||
|
||||
# 准备加速度计数据
|
||||
self.log_progress("准备加速度计数据...")
|
||||
acc_signals = [(self.data[col], f'Acc {["X", "Y", "Z"][i]}')
|
||||
for i, col in enumerate(self.acc_columns) if i < 3] # 只处理前三个轴
|
||||
|
||||
# 准备陀螺仪数据(如果有)
|
||||
gyro_signals = []
|
||||
if self.gyro_columns:
|
||||
self.log_progress("准备陀螺仪数据...")
|
||||
gyro_signals = [(self.data[col], f'Gyro {["X", "Y", "Z"][i]}')
|
||||
for i, col in enumerate(self.gyro_columns) if i < 3] # 只处理前三个轴
|
||||
|
||||
# 如果没有数据可处理,直接返回
|
||||
if not acc_signals and not gyro_signals:
|
||||
self.log_progress("警告: 没有有效的数据列可供处理", "WARNING")
|
||||
return
|
||||
|
||||
# 使用多进程处理信号(避免线程冲突)
|
||||
self.log_progress("使用多进程并行处理...")
|
||||
all_signals = acc_signals + gyro_signals
|
||||
with Pool(processes=min(len(all_signals), cpu_count())) as pool:
|
||||
results = pool.map(self.process_signal, all_signals)
|
||||
|
||||
# 分离结果
|
||||
self.log_progress("分离结果...")
|
||||
acc_results = [r for r in results if r['axis'].startswith('Acc')]
|
||||
gyro_results = [r for r in results if r['axis'].startswith('Gyro')]
|
||||
|
||||
# 统一颜色标尺(5%-95%分位)
|
||||
if acc_results:
|
||||
self.log_progress("计算加速度计全局最小和最大值...")
|
||||
acc_all_Sxx = np.concatenate([r['Sxx_log'].ravel() for r in acc_results])
|
||||
acc_vmin, acc_vmax = np.percentile(acc_all_Sxx, [5, 95])
|
||||
|
||||
# 统一 DC Y 轴范围
|
||||
acc_dc_ymin, acc_dc_ymax = self.robust_dc_ylim(acc_results)
|
||||
self.log_progress(f"加速度 DC (dB) 范围: {acc_dc_ymin:.1f} 到 {acc_dc_ymax:.1f}")
|
||||
|
||||
if gyro_results:
|
||||
self.log_progress("计算陀螺仪全局最小和最大值...")
|
||||
gyro_all_Sxx = np.concatenate([r['Sxx_log'].ravel() for r in gyro_results])
|
||||
gyro_vmin, gyro_vmax = np.percentile(gyro_all_Sxx, [5, 95])
|
||||
|
||||
# 统一 DC Y 轴范围
|
||||
gyro_dc_ymin, gyro_dc_ymax = self.robust_dc_ylim(gyro_results)
|
||||
self.log_progress(f"陀螺仪 DC (dB) 范围: {gyro_dc_ymin:.1f} 到 {gyro_dc_ymax:.1f}")
|
||||
|
||||
# ========= 绘制加速度计频谱雨点图 =========
|
||||
if acc_results:
|
||||
self._plot_single_spectrogram(acc_results, acc_vmin, acc_vmax, acc_dc_ymin, acc_dc_ymax,
|
||||
'加速度', 'acc_rainfall_spectrogram')
|
||||
self.log_progress(f"加速度功率谱密度范围: {acc_vmin:.1f} dB 到 {acc_vmax:.1f} dB")
|
||||
|
||||
# ========= 绘制陀螺仪频谱雨点图 =========
|
||||
if gyro_results:
|
||||
self._plot_single_spectrogram(gyro_results, gyro_vmin, gyro_vmax, gyro_dc_ymin, gyro_dc_ymax,
|
||||
'角速度', 'gyro_rainfall_spectrogram')
|
||||
self.log_progress(f"陀螺仪功率谱密度范围: {gyro_vmin:.1f} dB 到 {gyro_vmax:.1f} dB")
|
||||
|
||||
total_time = time.time() - start_time
|
||||
self.log_progress(f"频谱雨点图生成完成,总耗时 {total_time:.2f}秒")
|
||||
|
||||
def _plot_single_spectrogram(self, results, vmin, vmax, dc_ymin, dc_ymax, title_prefix, filename_prefix):
|
||||
"""绘制单个频谱雨点图"""
|
||||
rows = len(results)
|
||||
fig = plt.figure(constrained_layout=True, figsize=(14, 4 * rows), dpi=150)
|
||||
gs = fig.add_gridspec(nrows=rows, ncols=2, width_ratios=[22, 1], wspace=0.05, hspace=0.12)
|
||||
|
||||
axes_main = []
|
||||
axes_cbar = []
|
||||
for i in range(rows):
|
||||
axes_main.append(fig.add_subplot(gs[i, 0]))
|
||||
axes_cbar.append(fig.add_subplot(gs[i, 1]))
|
||||
|
||||
for i, result in enumerate(results):
|
||||
ax = axes_main[i]
|
||||
cax = axes_cbar[i]
|
||||
|
||||
sc = ax.scatter(
|
||||
np.repeat(result['t'], len(result['f'])),
|
||||
np.tile(result['f'], len(result['t'])),
|
||||
c=result['Sxx_log'].T.ravel(),
|
||||
cmap='jet',
|
||||
s=3,
|
||||
alpha=0.7,
|
||||
vmin=vmin,
|
||||
vmax=vmax,
|
||||
rasterized=True
|
||||
)
|
||||
|
||||
ax.set_title(f'{title_prefix}频谱雨点图 - {result["axis"][-1]}(右侧为DC分量 dB)', fontsize=10)
|
||||
ax.set_xlabel('时间 (s)', fontsize=9)
|
||||
ax.set_ylabel('频率 (Hz)', fontsize=9)
|
||||
ax.set_ylim(0, self.sampling_rate / 2)
|
||||
ax.grid(True, linestyle=':', alpha=0.4)
|
||||
|
||||
ax2 = ax.twinx()
|
||||
ax2.plot(result['t'], result['dc_log'], color='black', linewidth=1.2, alpha=0.85, label='DC (dB)')
|
||||
ax2.set_ylabel('直流分量 (dB)', fontsize=9, color='black')
|
||||
ax2.set_ylim(dc_ymin, dc_ymax)
|
||||
ax2.tick_params(axis='y', labelcolor='black')
|
||||
ax2.yaxis.set_major_locator(MaxNLocator(nbins=6))
|
||||
ax2.grid(False)
|
||||
ax2.legend(loc='upper right', fontsize=8, framealpha=0.5)
|
||||
|
||||
cbar = fig.colorbar(sc, cax=cax)
|
||||
cbar.set_label('功率谱密度 (dB)', fontsize=9)
|
||||
cax.tick_params(labelsize=8)
|
||||
|
||||
output_path = os.path.join(self.output_dir, f'{filename_prefix}_{self.timestamp}.png')
|
||||
plt.savefig(output_path, bbox_inches='tight', dpi=150)
|
||||
plt.close(fig)
|
||||
self.log_progress(f"{title_prefix}频谱雨点图已保存为 {output_path}")
|
||||
|
||||
def run_analysis(self):
|
||||
"""运行完整分析流程"""
|
||||
try:
|
||||
self.log_progress("开始数据分析流程", "INFO")
|
||||
start_time = time.time()
|
||||
|
||||
self.load_data()
|
||||
self.plot_time_series()
|
||||
self.plot_rainfall_spectrograms()
|
||||
|
||||
total_time = time.time() - start_time
|
||||
self.log_progress(f"分析完成,总耗时 {total_time:.2f}秒", "SUCCESS")
|
||||
self.log_progress(f"所有输出文件已保存到: {self.output_dir}", "INFO")
|
||||
return True
|
||||
|
||||
except ValueError as e:
|
||||
# 跳过不包含IMU数据的文件
|
||||
self.log_progress(f"跳过文件: {str(e)}", "WARNING")
|
||||
return False
|
||||
except Exception as e:
|
||||
self.log_progress(f"分析过程中出现错误: {str(e)}", "ERROR")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return False
|
||||
|
||||
|
||||
def process_single_file(file_path):
|
||||
"""处理单个文件的函数(使用进程隔离)"""
|
||||
try:
|
||||
print(f"{Fore.BLUE}开始处理文件: {os.path.basename(file_path)}")
|
||||
analyzer = IMUDataAnalyzer(file_path)
|
||||
success = analyzer.run_analysis()
|
||||
if success:
|
||||
return (file_path, True, "处理成功")
|
||||
else:
|
||||
return (file_path, False, "文件不包含IMU数据,已跳过")
|
||||
except Exception as e:
|
||||
return (file_path, False, str(e))
|
||||
|
||||
|
||||
def main():
|
||||
"""主函数,支持多文件处理和进度显示"""
|
||||
print("=" * 60)
|
||||
print(f"{Fore.CYAN}IMU数据频谱分析工具 - 多文件批量处理")
|
||||
print("=" * 60)
|
||||
|
||||
# 获取输入路径
|
||||
print(f"{Fore.WHITE}请输入包含CSV文件的目录路径: ")
|
||||
input_path = input("> ").strip()
|
||||
|
||||
if not os.path.exists(input_path):
|
||||
print(f"{Fore.RED}错误: 路径 '{input_path}' 不存在!")
|
||||
return
|
||||
|
||||
# 查找所有包含imu的CSV文件(不区分大小写)
|
||||
if os.path.isdir(input_path):
|
||||
# 使用单个glob模式匹配所有文件,然后过滤包含imu的文件
|
||||
all_csv_files = glob.glob(os.path.join(input_path, "**", "*.csv"), recursive=True)
|
||||
csv_files = [f for f in all_csv_files if re.search(r'imu', f, re.IGNORECASE)]
|
||||
csv_files = list(set(csv_files)) # 去重
|
||||
csv_files.sort()
|
||||
else:
|
||||
# 对于单个文件,检查是否包含imu(不区分大小写)
|
||||
if re.search(r'imu', input_path, re.IGNORECASE):
|
||||
csv_files = [input_path]
|
||||
else:
|
||||
csv_files = []
|
||||
|
||||
if not csv_files:
|
||||
print(f"{Fore.YELLOW}警告: 未找到包含'imu'的CSV文件")
|
||||
return
|
||||
|
||||
print(f"{Fore.GREEN}找到 {len(csv_files)} 个IMU数据文件:")
|
||||
for i, file in enumerate(csv_files, 1):
|
||||
print(f" {i}. {os.path.basename(file)}")
|
||||
|
||||
# 使用多进程处理文件(避免matplotlib线程冲突)
|
||||
print(f"\n{Fore.CYAN}开始多线程处理文件 (使用 {min(len(csv_files), cpu_count())} 个线程)...")
|
||||
|
||||
success_count = 0
|
||||
skipped_count = 0
|
||||
failed_files = []
|
||||
|
||||
# 使用ProcessPoolExecutor而不是ThreadPoolExecutor
|
||||
with ProcessPoolExecutor(max_workers=min(len(csv_files), cpu_count())) as executor:
|
||||
# 提交所有任务
|
||||
future_to_file = {executor.submit(process_single_file, file): file for file in csv_files}
|
||||
|
||||
# 处理完成的任务
|
||||
for future in as_completed(future_to_file):
|
||||
file_path = future_to_file[future]
|
||||
try:
|
||||
result = future.result()
|
||||
file_path, success, message = result
|
||||
if success:
|
||||
print(f"{Fore.GREEN}✓ 完成: {os.path.basename(file_path)}")
|
||||
success_count += 1
|
||||
else:
|
||||
if "跳过" in message:
|
||||
print(f"{Fore.YELLOW}↷ 跳过: {os.path.basename(file_path)} - {message}")
|
||||
skipped_count += 1
|
||||
else:
|
||||
print(f"{Fore.RED}✗ 失败: {os.path.basename(file_path)} - {message}")
|
||||
failed_files.append((file_path, message))
|
||||
except Exception as e:
|
||||
print(f"{Fore.RED}✗ 异常: {os.path.basename(file_path)} - {str(e)}")
|
||||
failed_files.append((file_path, str(e)))
|
||||
|
||||
# 输出统计信息
|
||||
print(f"\n{Fore.CYAN}处理完成统计:")
|
||||
print(f"{Fore.GREEN}成功: {success_count} 个文件")
|
||||
print(f"{Fore.YELLOW}跳过: {skipped_count} 个文件(不包含IMU数据)")
|
||||
print(f"{Fore.RED}失败: {len(failed_files)} 个文件")
|
||||
|
||||
if failed_files:
|
||||
print(f"\n{Fore.YELLOW}失败文件详情:")
|
||||
for file, error in failed_files:
|
||||
print(f" {os.path.basename(file)}: {error}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
main()
|
||||
except KeyboardInterrupt:
|
||||
print(f"\n{Fore.YELLOW}用户中断程序执行")
|
||||
except Exception as e:
|
||||
print(f"{Fore.RED}程序运行出错: {str(e)}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
Reference in New Issue
Block a user