From c6e930ddf27d4a1c102b64423a2fa33ba49512eb Mon Sep 17 00:00:00 2001 From: nanako <469449812@qq.com> Date: Mon, 3 Nov 2025 21:45:01 +0800 Subject: [PATCH] =?UTF-8?q?=E4=BB=A3=E7=A0=81=E6=A0=BC=E5=BC=8F=E6=95=B4?= =?UTF-8?q?=E7=90=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/engine/main.cpp | 13 +- src/engine/rpc/engine_rpc.cpp | 13 +- src/host_sandbox/main.cpp | 5 +- src/host_sandbox/rpc/host_rpc.cpp | 7 +- src/host_sandbox/sandbox_grpc_log_sink.cpp | 12 +- src/host_sandbox/sandbox_grpc_log_sink.h | 5 +- src/misc/lazy_singleton.h | 9 +- src/misc/linux/thread_tool.cpp | 14 +- src/misc/logger.cpp | 16 +- src/misc/logger.h | 146 +- src/misc/type_util.h | 150 +- src/misc/windows/thread_tool.cpp | 36 +- src/network/rpc/engine_rpc.h | 8 +- src/network/rpc/host_rpc.h | 8 +- .../shm/interprocess_synchronization.cpp | 24 +- .../shm/interprocess_synchronization.h | 102 +- src/network/shm/lock_free_ring_buffer.h | 63 +- src/network/shm/shared_memory_manager.cpp | 19 +- src/network/shm/shared_memory_manager.h | 26 +- src/network/shm/triple_buffer.h | 55 +- src/network/transport/zmq_client.cpp | 166 +- src/network/transport/zmq_client.h | 136 +- src/network/transport/zmq_client_processor.h | 76 +- src/network/transport/zmq_server.cpp | 106 +- src/network/transport/zmq_server.h | 120 +- src/network/transport/zmq_server_processor.h | 78 +- src/network/transport/zmq_util.h | 33 +- src/simd/aligned_allocator.h | 83 +- .../arm_simd_audio_processing_func.cpp | 398 ++-- .../arm_simd_audio_processing_func.h | 13 +- .../scalar_audio_processing_func.cpp | 96 +- .../scalar_audio_processing_func.h | 46 +- .../x86_simd_audio_processing_func.cpp | 1545 ++++++------ .../x86_simd_audio_processing_func.h | 27 +- src/simd/cpu_features.cpp | 276 +-- src/simd/cpu_features.h | 4 +- src/simd/simd_func_dispatcher.cpp | 4 +- src/simd/simd_func_dispatcher.h | 28 +- tests/network/test_zmq_rpc.cpp | 1091 ++++----- tests/shm/helpers/custom_assertions.h | 446 ++-- tests/shm/helpers/data_generator.h | 697 +++--- tests/shm/helpers/leak_detector.h | 727 +++--- tests/shm/helpers/multiprocess_harness.h | 625 +++-- tests/shm/helpers/performance_timer.h | 910 +++---- tests/shm/helpers/test_environment.h | 444 ++-- tests/shm/test_interprocess_sync.cpp | 592 ++--- tests/shm/test_lock_free_ring_buffer.cpp | 540 ++--- tests/shm/test_shared_memory_manager.cpp | 566 ++--- tests/shm/test_triple_buffer.cpp | 708 +++--- tests/test_audio_processing_comprehensive.cpp | 1127 ++++----- tests/test_simd.cpp | 2119 +++++++++-------- 51 files changed, 7435 insertions(+), 7123 deletions(-) diff --git a/src/engine/main.cpp b/src/engine/main.cpp index 012c4f7..16cd8d2 100644 --- a/src/engine/main.cpp +++ b/src/engine/main.cpp @@ -3,13 +3,13 @@ int main(int argc, char* argv[]) { shared_memory_config shm_config{}; - shm_config.segment_name = "AlichoBackendSharedMemory"; - shm_config.segment_size = 128 * 1024 * 1024; // 128 MB + shm_config.segment_name = "AlichoBackendSharedMemory"; + shm_config.segment_size = 128 * 1024 * 1024; // 128 MB shm_config.create_if_not_exists = true; - shm_config.remove_on_destroy = true; - shm_config.mutex_name = "AlichoBackendSharedMemoryMutex"; - shm_config.condition_name = "AlichoBackendSharedMemoryCondition"; - shm_config.semaphore_name = "AlichoBackendSharedMemorySemaphore"; + shm_config.remove_on_destroy = true; + shm_config.mutex_name = "AlichoBackendSharedMemoryMutex"; + shm_config.condition_name = "AlichoBackendSharedMemoryCondition"; + shm_config.semaphore_name = "AlichoBackendSharedMemorySemaphore"; shared_memory_manager::instance().init(shm_config); zmq_server::instance().init(); @@ -23,4 +23,3 @@ int main(int argc, char* argv[]) { shared_memory_manager::instance().shutdown(); } - diff --git a/src/engine/rpc/engine_rpc.cpp b/src/engine/rpc/engine_rpc.cpp index 0af21a9..1ac4192 100644 --- a/src/engine/rpc/engine_rpc.cpp +++ b/src/engine/rpc/engine_rpc.cpp @@ -1,4 +1,3 @@ - #include "rpc/engine_rpc.h" #include @@ -7,10 +6,10 @@ #include "transport/zmq_server_processor.h" namespace engine_rpc { - ZMQ_SERVER_REGISTER_PROCESSOR(log_message_t) { - // TODO 获取插件名称作为module_name - std::string module_name = "engine_rpc"; - const auto level = static_cast(data.level); - log_module_log(module_name, level, "{}", data.message.c_str()); - } + ZMQ_SERVER_REGISTER_PROCESSOR(log_message_t) { + // TODO 获取插件名称作为module_name + std::string module_name = "engine_rpc"; + const auto level = static_cast(data.level); + log_module_log(module_name, level, "{}", data.message.c_str()); + } } diff --git a/src/host_sandbox/main.cpp b/src/host_sandbox/main.cpp index fffcf08..445e019 100644 --- a/src/host_sandbox/main.cpp +++ b/src/host_sandbox/main.cpp @@ -6,13 +6,14 @@ int main(int argc, char* argv[]) { // 获取命令行参数中的gRPC服务地址 - uint32_t client_id; { + uint32_t client_id; + { if (argc < 2) { std::cerr << "用法: host_sandbox client_id" << std::endl; return -1; } const std::string grpc_address = argv[1]; // 例如 "localhost:50051" - client_id = static_cast(std::stoul(argv[1])); + client_id = static_cast(std::stoul(argv[1])); } // 沙盒进程不需要保存日志到文件, 只在控制台输出 diff --git a/src/host_sandbox/rpc/host_rpc.cpp b/src/host_sandbox/rpc/host_rpc.cpp index 317687f..9da4052 100644 --- a/src/host_sandbox/rpc/host_rpc.cpp +++ b/src/host_sandbox/rpc/host_rpc.cpp @@ -1,16 +1,15 @@ - #include "rpc/host_rpc.h" #include "shm/shared_memory_manager.h" #include "transport/zmq_client_processor.h" namespace host_rpc { - ZMQ_CLIENT_REGISTER_PROCESSOR(setup_t) { + ZMQ_CLIENT_REGISTER_PROCESSOR(setup_t) { shared_memory_config shm_config{}; shm_config.segment_name = data.shm_name; - const auto result = shared_memory_manager::instance().init(shm_config); + const auto result = shared_memory_manager::instance().init(shm_config); if (result != shared_memory_error::SUCCESS) { throw std::runtime_error("在 host_rpc::setup_t 处理器中初始化共享内存失败。"); } - } + } } diff --git a/src/host_sandbox/sandbox_grpc_log_sink.cpp b/src/host_sandbox/sandbox_grpc_log_sink.cpp index 6e5a251..8108807 100644 --- a/src/host_sandbox/sandbox_grpc_log_sink.cpp +++ b/src/host_sandbox/sandbox_grpc_log_sink.cpp @@ -10,7 +10,8 @@ // 构造函数:初始化gRPC客户端连接 grpc_log_sink::grpc_log_sink() { try { - } catch (const std::exception& e) { + } + catch (const std::exception& e) { std::cerr << "Error initializing gRPC log sink: " << e.what() << std::endl; } } @@ -19,10 +20,11 @@ grpc_log_sink::grpc_log_sink() { void grpc_log_sink::sink_it_(const spdlog::details::log_msg& msg) { try { engine_rpc::log_message_t pack; - pack.level = msg.level; + pack.level = msg.level; pack.message = fmt::to_string(msg.payload); zmq_client::instance().send(pack); - } catch (const std::exception& e) { + } + catch (const std::exception& e) { std::cerr << "Exception in grpc_log_sink::sink_it_: " << e.what() << std::endl; } } @@ -36,10 +38,10 @@ void grpc_log_sink::flush_() { // 格式化日志消息 std::string grpc_log_sink::format_message(const spdlog::details::log_msg& msg) { // 使用spdlog的默认格式化器 - spdlog::memory_buf_t formatted; + spdlog::memory_buf_t formatted; spdlog::pattern_formatter formatter; formatter.format(msg, formatted); - + // 转换为字符串 return fmt::to_string(formatted); } diff --git a/src/host_sandbox/sandbox_grpc_log_sink.h b/src/host_sandbox/sandbox_grpc_log_sink.h index 5474b45..344daec 100644 --- a/src/host_sandbox/sandbox_grpc_log_sink.h +++ b/src/host_sandbox/sandbox_grpc_log_sink.h @@ -12,15 +12,16 @@ class grpc_log_sink : public spdlog::sinks::base_sink { public: // 构造函数:创建gRPC客户端连接到指定的服务器地址 explicit grpc_log_sink(); - + ~grpc_log_sink() override = default; protected: // 实现日志输出逻辑 void sink_it_(const spdlog::details::log_msg& msg) override; - + // 实现刷新逻辑 void flush_() override; + private: // 格式化日志消息 std::string format_message(const spdlog::details::log_msg& msg); diff --git a/src/misc/lazy_singleton.h b/src/misc/lazy_singleton.h index 587d101..33e12a6 100644 --- a/src/misc/lazy_singleton.h +++ b/src/misc/lazy_singleton.h @@ -1,6 +1,6 @@ #pragma once -template +template class lazy_singleton { public: static T& instance() { @@ -9,12 +9,13 @@ public: } // 禁止拷贝和赋值 - lazy_singleton(const lazy_singleton&) = delete; + lazy_singleton(const lazy_singleton&) = delete; lazy_singleton& operator=(const lazy_singleton&) = delete; // 禁止移动构造和移动赋值 - lazy_singleton(lazy_singleton&&) = delete; + lazy_singleton(lazy_singleton&&) = delete; lazy_singleton& operator=(lazy_singleton&&) = delete; + protected: - lazy_singleton() = default; + lazy_singleton() = default; virtual ~lazy_singleton() = default; }; diff --git a/src/misc/linux/thread_tool.cpp b/src/misc/linux/thread_tool.cpp index 6b9023e..ad65e68 100644 --- a/src/misc/linux/thread_tool.cpp +++ b/src/misc/linux/thread_tool.cpp @@ -1,13 +1,11 @@ #incldue "thread_tool.h" -bool thread_set_affinity(boost::thread& thread, int core_id) -{ - // Linux implementation can be added here - return false; // Placeholder +bool thread_set_affinity(boost::thread& thread, int core_id) { + // Linux implementation can be added here + return false; // Placeholder } -bool thread_set_name(boost::thread& thread, const char* name) -{ - // Linux implementation can be added here - return false; // Placeholder +bool thread_set_name(boost::thread& thread, const char* name) { + // Linux implementation can be added here + return false; // Placeholder } diff --git a/src/misc/logger.cpp b/src/misc/logger.cpp index d4e9871..a01641f 100644 --- a/src/misc/logger.cpp +++ b/src/misc/logger.cpp @@ -4,11 +4,11 @@ #include void logger::init(const std::string& app_name, - log_level log_level, - bool file_logging, - const std::string& log_dir, - size_t max_file_size, - size_t max_files) { + log_level log_level, + bool file_logging, + const std::string& log_dir, + size_t max_file_size, + size_t max_files) { if (initialized_) { return; } @@ -94,11 +94,11 @@ void logger::add_sink(const spdlog::sink_ptr& sink) { } logger::logger() { -#if ALICHO_DEBUG + #if ALICHO_DEBUG init("alicho backend", log_level::TRACE, true); -#else + #else init("alicho backend", log_level::INFO, true); -#endif + #endif } logger::~logger() { diff --git a/src/misc/logger.h b/src/misc/logger.h index dcf3bf6..538b2bd 100644 --- a/src/misc/logger.h +++ b/src/misc/logger.h @@ -15,25 +15,25 @@ #endif enum class log_level { - TRACE = spdlog::level::trace, - DEBUG = spdlog::level::debug, - INFO = spdlog::level::info, - WARN = spdlog::level::warn, - ERROR = spdlog::level::err, + TRACE = spdlog::level::trace, + DEBUG = spdlog::level::debug, + INFO = spdlog::level::info, + WARN = spdlog::level::warn, + ERROR = spdlog::level::err, CRITICAL = spdlog::level::critical, - OFF = spdlog::level::off + OFF = spdlog::level::off }; class logger : public lazy_singleton { public: friend class lazy_singleton; - void init(const std::string& app_name = "alicho backend", - log_level log_level = log_level::INFO, - bool file_logging = true, - const std::string& log_dir = "./logs", - size_t max_file_size = 5 * 1024 * 1024, - size_t max_files = 3); + void init(const std::string& app_name = "alicho backend", + log_level log_level = log_level::INFO, + bool file_logging = true, + const std::string& log_dir = "./logs", + size_t max_file_size = 5 * 1024 * 1024, + size_t max_files = 3); void shutdown(); @@ -67,45 +67,51 @@ public: } } - template - void log(log_level level, fmt::format_string fmt, Args&&... args) {; + template + void log(log_level level, fmt::format_string fmt, Args&&... args) { + ; if (logger_) { logger_->log(static_cast(level), fmt, std::forward(args)...); } } // 编译时格式字符串版本(用于字符串字面量) - template + template void trace(fmt::format_string fmt, Args&&... args) { if (logger_) { logger_->trace(fmt, std::forward(args)...); } } - template + + template void debug(fmt::format_string fmt, Args&&... args) { if (logger_) { logger_->debug(fmt, std::forward(args)...); } } - template + + template void info(fmt::format_string fmt, Args&&... args) { if (logger_) { logger_->info(fmt, std::forward(args)...); } } - template + + template void warn(fmt::format_string fmt, Args&&... args) { if (logger_) { logger_->warn(fmt, std::forward(args)...); } } - template + + template void error(fmt::format_string fmt, Args&&... args) { if (logger_) { logger_->error(fmt, std::forward(args)...); } } - template + + template void critical(fmt::format_string fmt, Args&&... args) { if (logger_) { logger_->critical(fmt, std::forward(args)...); @@ -113,173 +119,201 @@ public: } // 运行时格式字符串版本(用于动态字符串) - template + template void trace_runtime(const char* fmt, const Args&... args) { if (logger_) { logger_->trace(fmt::runtime(fmt), args...); } } - template + + template void debug_runtime(const char* fmt, const Args&... args) { if (logger_) { logger_->debug(fmt::runtime(fmt), args...); } } - template + + template void info_runtime(const char* fmt, const Args&... args) { if (logger_) { logger_->info(fmt::runtime(fmt), args...); } } - template + + template void warn_runtime(const char* fmt, const Args&... args) { if (logger_) { logger_->warn(fmt::runtime(fmt), args...); } } - template + + template void error_runtime(const char* fmt, const Args&... args) { if (logger_) { logger_->error(fmt::runtime(fmt), args...); } } - template + + template void critical_runtime(const char* fmt, const Args&... args) { if (logger_) { logger_->critical(fmt::runtime(fmt), args...); } } + private: logger(); virtual ~logger(); - auto create_logger(const std::string& logger_name) -> std::shared_ptr; + auto create_logger(const std::string& logger_name) -> std::shared_ptr; std::shared_ptr logger_; - spdlog::sink_ptr console_sink_; - spdlog::sink_ptr file_sink_; - bool initialized_ = false; + spdlog::sink_ptr console_sink_; + spdlog::sink_ptr file_sink_; + bool initialized_ = false; }; // 编译时格式字符串版本 -template +template void log_log(log_level level, fmt::format_string fmt, Args&&... args) { logger::instance().log(level, fmt, std::forward(args)...); } -template + +template void log_trace(fmt::format_string fmt, Args&&... args) { logger::instance().trace(fmt, std::forward(args)...); } -template + +template void log_debug(fmt::format_string fmt, Args&&... args) { logger::instance().debug(fmt, std::forward(args)...); } -template + +template void log_info(fmt::format_string fmt, Args&&... args) { logger::instance().info(fmt, std::forward(args)...); } -template + +template void log_warn(fmt::format_string fmt, Args&&... args) { logger::instance().warn(fmt, std::forward(args)...); } -template + +template void log_error(fmt::format_string fmt, Args&&... args) { logger::instance().error(fmt, std::forward(args)...); } -template + +template void log_critical(fmt::format_string fmt, Args&&... args) { logger::instance().critical(fmt, std::forward(args)...); } // 运行时格式字符串版本 -template +template void log_trace_runtime(const char* fmt, const Args&... args) { logger::instance().trace_runtime(fmt, args...); } -template + +template void log_debug_runtime(const char* fmt, const Args&... args) { logger::instance().debug_runtime(fmt, args...); } -template + +template void log_info_runtime(const char* fmt, const Args&... args) { logger::instance().info_runtime(fmt, args...); } -template + +template void log_warn_runtime(const char* fmt, const Args&... args) { logger::instance().warn_runtime(fmt, args...); } -template + +template void log_error_runtime(const char* fmt, const Args&... args) { logger::instance().error_runtime(fmt, args...); } -template + +template void log_critical_runtime(const char* fmt, const Args&... args) { logger::instance().critical_runtime(fmt, args...); } // 模块日志 - 编译时格式字符串版本 -template +template void log_module_log(const std::string& module_name, log_level level, fmt::format_string fmt, Args&&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->log(static_cast(level), fmt, std::forward(args)...); } -template + +template void log_module_trace(const std::string& module_name, fmt::format_string fmt, Args&&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->trace(fmt, std::forward(args)...); } -template + +template void log_module_debug(const std::string& module_name, fmt::format_string fmt, Args&&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->debug(fmt, std::forward(args)...); } -template + +template void log_module_info(const std::string& module_name, fmt::format_string fmt, Args&&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->info(fmt, std::forward(args)...); } -template + +template void log_module_warn(const std::string& module_name, fmt::format_string fmt, Args&&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->warn(fmt, std::forward(args)...); } -template + +template void log_module_error(const std::string& module_name, fmt::format_string fmt, Args&&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->error(fmt, std::forward(args)...); } -template + +template void log_module_critical(const std::string& module_name, fmt::format_string fmt, Args&&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->critical(fmt, std::forward(args)...); } // 模块日志 - 运行时格式字符串版本 -template +template void log_module_trace_runtime(const std::string& module_name, const char* fmt, const Args&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->trace(fmt::runtime(fmt), args...); } -template + +template void log_module_debug_runtime(const std::string& module_name, const char* fmt, const Args&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->debug(fmt::runtime(fmt), args...); } -template + +template void log_module_info_runtime(const std::string& module_name, const char* fmt, const Args&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->info(fmt::runtime(fmt), args...); } -template + +template void log_module_warn_runtime(const std::string& module_name, const char* fmt, const Args&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->warn(fmt::runtime(fmt), args...); } -template + +template void log_module_error_runtime(const std::string& module_name, const char* fmt, const Args&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->error(fmt::runtime(fmt), args...); } -template + +template void log_module_critical_runtime(const std::string& module_name, const char* fmt, const Args&... args) { auto module_logger = logger::instance().get_module_logger(module_name); module_logger->critical(fmt::runtime(fmt), args...); diff --git a/src/misc/type_util.h b/src/misc/type_util.h index 56e8b10..f0ab7a4 100644 --- a/src/misc/type_util.h +++ b/src/misc/type_util.h @@ -4,93 +4,93 @@ #include namespace type_util { - template - struct fixed_string { - char data[N]{}; - size_t length = N - 1; + template + struct fixed_string { + char data[N]{}; + size_t length = N - 1; - constexpr fixed_string(const char (&str)[N]) { - for (size_t i = 0; i < N; ++i) { - data[i] = str[i]; - } - } + constexpr fixed_string(const char (&str)[N]) { + for (size_t i = 0; i < N; ++i) { + data[i] = str[i]; + } + } - constexpr operator std::string_view() const { - return {data, length}; - } - }; + constexpr operator std::string_view() const { + return {data, length}; + } + }; - template - constexpr auto get_type_name_raw() { -#if defined(__clang__) - return fixed_string{__PRETTY_FUNCTION__}; -#elif defined(__GNUC__) - return fixed_string{__PRETTY_FUNCTION__}; -#elif defined(_MSC_VER) - return fixed_string{__FUNCSIG__}; -#else -#error "Compiler not supported" -#endif - } + template + constexpr auto get_type_name_raw() { + #if defined(__clang__) + return fixed_string{__PRETTY_FUNCTION__}; + #elif defined(__GNUC__) + return fixed_string{__PRETTY_FUNCTION__}; + #elif defined(_MSC_VER) + return fixed_string{__FUNCSIG__}; + #else + #error "Compiler not supported" + #endif + } - // 编译期解析并提取纯净的类型名 - template - constexpr auto get_type_name() { - constexpr std::string_view full_signature = get_type_name_raw(); - size_t start = 0; - size_t end = full_signature.length(); + // 编译期解析并提取纯净的类型名 + template + constexpr auto get_type_name() { + constexpr std::string_view full_signature = get_type_name_raw(); + size_t start = 0; + size_t end = full_signature.length(); -#if defined(__clang__) || defined(__GNUC__) - // 示例: "constexpr auto detail::get_type_name_raw() [with T = MyStruct]" - start = full_signature.find("T = ") + 4; // 定位到类型名开始 - end = full_signature.find_last_of(']'); // 定位到类型名结束 -#elif defined(_MSC_VER) - // 示例: "auto __cdecl detail::get_type_name_raw(void)" - start = full_signature.find('<') + 1; - // 移除 "struct ", "class " 等前缀 - std::string_view prefix_struct = "struct "; - std::string_view prefix_class = "class "; - if (full_signature.substr(start, prefix_struct.size()) == prefix_struct) { - start += prefix_struct.size(); - } - else if (full_signature.substr(start, prefix_class.size()) == prefix_class) { - start += prefix_class.size(); - } - end = full_signature.find_last_of('>'); -#endif + #if defined(__clang__) || defined(__GNUC__) + // 示例: "constexpr auto detail::get_type_name_raw() [with T = MyStruct]" + start = full_signature.find("T = ") + 4; // 定位到类型名开始 + end = full_signature.find_last_of(']'); // 定位到类型名结束 + #elif defined(_MSC_VER) + // 示例: "auto __cdecl detail::get_type_name_raw(void)" + start = full_signature.find('<') + 1; + // 移除 "struct ", "class " 等前缀 + std::string_view prefix_struct = "struct "; + std::string_view prefix_class = "class "; + if (full_signature.substr(start, prefix_struct.size()) == prefix_struct) { + start += prefix_struct.size(); + } + else if (full_signature.substr(start, prefix_class.size()) == prefix_class) { + start += prefix_class.size(); + } + end = full_signature.find_last_of('>'); + #endif - return full_signature.substr(start, end - start); - } + return full_signature.substr(start, end - start); + } - // 编译期 FNV-1a 哈希算法 (32位) - 直接在fixed_string上操作以支持MSVC - template - constexpr uint32_t fnv1a_hash_32_impl(const fixed_string& str, size_t index, uint32_t hash) { - constexpr uint32_t prime = 16777619U; - - if (index >= str.length) { - return hash; - } - - uint32_t new_hash = hash ^ static_cast(static_cast(str.data[index])); - new_hash *= prime; - - return fnv1a_hash_32_impl(str, index + 1, new_hash); - } - - template - constexpr uint32_t fnv1a_hash_32(const fixed_string& str) { - constexpr uint32_t basis = 2166136261U; - return fnv1a_hash_32_impl(str, 0, basis); - } + // 编译期 FNV-1a 哈希算法 (32位) - 直接在fixed_string上操作以支持MSVC + template + constexpr uint32_t fnv1a_hash_32_impl(const fixed_string& str, size_t index, uint32_t hash) { + constexpr uint32_t prime = 16777619U; + + if (index >= str.length) { + return hash; + } + + uint32_t new_hash = hash ^ static_cast(static_cast(str.data[index])); + new_hash *= prime; + + return fnv1a_hash_32_impl(str, index + 1, new_hash); + } + + template + constexpr uint32_t fnv1a_hash_32(const fixed_string& str) { + constexpr uint32_t basis = 2166136261U; + return fnv1a_hash_32_impl(str, 0, basis); + } } // 最终对外暴露的接口 template struct alicho_type_id { - // 获取原始类型签名 - static constexpr auto raw_name = type_util::get_type_name_raw(); - // 获取编译期计算的唯一哈希ID (32位) - 直接在fixed_string上计算 - static constexpr uint32_t hash = type_util::fnv1a_hash_32(raw_name); + // 获取原始类型签名 + static constexpr auto raw_name = type_util::get_type_name_raw(); + // 获取编译期计算的唯一哈希ID (32位) - 直接在fixed_string上计算 + static constexpr uint32_t hash = type_util::fnv1a_hash_32(raw_name); }; // 一个更便捷的变量模板 diff --git a/src/misc/windows/thread_tool.cpp b/src/misc/windows/thread_tool.cpp index f674d02..96a0fb5 100644 --- a/src/misc/windows/thread_tool.cpp +++ b/src/misc/windows/thread_tool.cpp @@ -3,25 +3,23 @@ #include #include "logger.h" -bool thread_set_affinity(boost::thread& thread, int core_id) -{ - const auto mask = 1LL << core_id; // 1LL 确保是64位 - const auto result = SetThreadAffinityMask(thread.native_handle(), mask); - if (result == 0) { - // 错误处理 - log_module_error(THREAD_TOOL_LOG_MODULE, "无法将线程亲和性设置为核心{}, 调用 SetThreadAffinityMask 时出错: {}", core_id, - GetLastError()); - } - return result != 0; +bool thread_set_affinity(boost::thread& thread, int core_id) { + const auto mask = 1LL << core_id; // 1LL 确保是64位 + const auto result = SetThreadAffinityMask(thread.native_handle(), mask); + if (result == 0) { + // 错误处理 + log_module_error(THREAD_TOOL_LOG_MODULE, "无法将线程亲和性设置为核心{}, 调用 SetThreadAffinityMask 时出错: {}", core_id, + GetLastError()); + } + return result != 0; } -bool thread_set_name(boost::thread& thread, const char* name) -{ - // Windows 上设置线程名称的标准方法是使用 SetThreadDescription - const auto hr = SetThreadDescription(thread.native_handle(), std::wstring(name, name + strlen(name)).c_str()); - if (FAILED(hr)) { - log_module_error(THREAD_TOOL_LOG_MODULE, "无法将线程名称设置为 {}, 调用 SetThreadDescription 时出错: 0x{:X}", name, hr); - return false; - } - return true; +bool thread_set_name(boost::thread& thread, const char* name) { + // Windows 上设置线程名称的标准方法是使用 SetThreadDescription + const auto hr = SetThreadDescription(thread.native_handle(), std::wstring(name, name + strlen(name)).c_str()); + if (FAILED(hr)) { + log_module_error(THREAD_TOOL_LOG_MODULE, "无法将线程名称设置为 {}, 调用 SetThreadDescription 时出错: 0x{:X}", name, hr); + return false; + } + return true; } diff --git a/src/network/rpc/engine_rpc.h b/src/network/rpc/engine_rpc.h index a35aa42..0250a77 100644 --- a/src/network/rpc/engine_rpc.h +++ b/src/network/rpc/engine_rpc.h @@ -3,8 +3,8 @@ #include namespace engine_rpc { - struct log_message_t { - uint32_t level; - std::string message; - }; + struct log_message_t { + uint32_t level; + std::string message; + }; } diff --git a/src/network/rpc/host_rpc.h b/src/network/rpc/host_rpc.h index 8ff4dd0..32aa56b 100644 --- a/src/network/rpc/host_rpc.h +++ b/src/network/rpc/host_rpc.h @@ -2,7 +2,7 @@ #include namespace host_rpc { - struct setup_t { - std::string shm_name; - }; -} \ No newline at end of file + struct setup_t { + std::string shm_name; + }; +} diff --git a/src/network/shm/interprocess_synchronization.cpp b/src/network/shm/interprocess_synchronization.cpp index 5035864..56474c1 100644 --- a/src/network/shm/interprocess_synchronization.cpp +++ b/src/network/shm/interprocess_synchronization.cpp @@ -61,7 +61,7 @@ interprocess_synchronization::lock_mutex(const std::string& name, std::chrono::m try { if (timeout.count() > 0) { const auto abs_time = boost::posix_time::microsec_clock::universal_time() + - boost::posix_time::milliseconds(timeout.count()); + boost::posix_time::milliseconds(timeout.count()); if (it->second->timed_lock(abs_time)) { return shared_memory_error::SUCCESS; } return shared_memory_error::SYNCHRONIZATION_FAILED; // 超时 @@ -147,8 +147,8 @@ shared_memory_error interprocess_synchronization::open_condition(const std::stri } shared_memory_error interprocess_synchronization::wait_condition(const std::string& condition_name, - const std::string& mutex_name, - std::chrono::milliseconds timeout) { + const std::string& mutex_name, + std::chrono::milliseconds timeout) { std::lock_guard lock(local_mutex_); auto cond_it = conditions_.find(condition_name); @@ -160,7 +160,7 @@ shared_memory_error interprocess_synchronization::wait_condition(const std::stri try { if (timeout.count() > 0) { const auto abs_time = boost::posix_time::microsec_clock::universal_time() + - boost::posix_time::milliseconds(timeout.count()); + boost::posix_time::milliseconds(timeout.count()); scoped_lock named_lock(*mutex_it->second); if (cond_it->second->timed_wait(named_lock, abs_time)) { return shared_memory_error::SUCCESS; } @@ -172,9 +172,9 @@ shared_memory_error interprocess_synchronization::wait_condition(const std::stri } catch (const boost::interprocess::interprocess_exception& e) { log_module_error(INTERPROCESS_SYNCHRONIZATION_LOG_MODULE, - "等待条件变量失败 '%s': %s", - condition_name.c_str(), - e.what()); + "等待条件变量失败 '%s': %s", + condition_name.c_str(), + e.what()); return shared_memory_error::SYNCHRONIZATION_FAILED; } } @@ -229,8 +229,8 @@ interprocess_synchronization::create_semaphore(const std::string& name, unsigned boost::interprocess::named_semaphore::remove(name.c_str()); auto semaphore = std::make_unique(boost::interprocess::create_only, - name.c_str(), - initial_count); + name.c_str(), + initial_count); semaphores_[name] = std::move(semaphore); log_module_debug(INTERPROCESS_SYNCHRONIZATION_LOG_MODULE, "创建信号量: %s, 初始计数: %u", name.c_str(), initial_count); @@ -261,7 +261,7 @@ shared_memory_error interprocess_synchronization::open_semaphore(const std::stri } shared_memory_error interprocess_synchronization::wait_semaphore(const std::string& name, - std::chrono::milliseconds timeout) { + std::chrono::milliseconds timeout) { std::lock_guard lock(local_mutex_); auto it = semaphores_.find(name); @@ -270,7 +270,7 @@ shared_memory_error interprocess_synchronization::wait_semaphore(const std::stri try { if (timeout.count() > 0) { const auto abs_time = boost::posix_time::microsec_clock::universal_time() + - boost::posix_time::milliseconds(timeout.count()); + boost::posix_time::milliseconds(timeout.count()); if (it->second->timed_wait(abs_time)) { return shared_memory_error::SUCCESS; } return shared_memory_error::SYNCHRONIZATION_FAILED; // 超时 @@ -319,7 +319,7 @@ shared_memory_error interprocess_synchronization::remove_semaphore(const std::st } interprocess_synchronization::scoped_mutex_lock::scoped_mutex_lock(interprocess_synchronization& sync, - const std::string& mutex_name) : sync_(sync), + const std::string& mutex_name) : sync_(sync), mutex_name_(mutex_name), locked_(false) { auto result = sync_.lock_mutex(mutex_name_); diff --git a/src/network/shm/interprocess_synchronization.h b/src/network/shm/interprocess_synchronization.h index 0255161..acc3e61 100644 --- a/src/network/shm/interprocess_synchronization.h +++ b/src/network/shm/interprocess_synchronization.h @@ -9,55 +9,57 @@ class interprocess_synchronization { public: - using interprocess_mutex = boost::interprocess::named_mutex; - using interprocess_condition = boost::interprocess::named_condition; - using interprocess_semaphore = boost::interprocess::named_semaphore; - using scoped_lock = boost::interprocess::scoped_lock; - - explicit interprocess_synchronization(const shared_memory_config& config); - ~interprocess_synchronization(); - - // 互斥量操作 - shared_memory_error create_mutex(const std::string& name); - shared_memory_error open_mutex(const std::string& name); - shared_memory_error lock_mutex(const std::string& name, std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)); - shared_memory_error unlock_mutex(const std::string& name); - shared_memory_error remove_mutex(const std::string& name); - - // 条件变量操作 - shared_memory_error create_condition(const std::string& name); - shared_memory_error open_condition(const std::string& name); - shared_memory_error wait_condition(const std::string& condition_name, const std::string& mutex_name, - std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)); - shared_memory_error notify_condition(const std::string& name, bool notify_all = false); - shared_memory_error remove_condition(const std::string& name); - - // 信号量操作 - shared_memory_error create_semaphore(const std::string& name, unsigned int initial_count); - shared_memory_error open_semaphore(const std::string& name); - shared_memory_error wait_semaphore(const std::string& name, std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)); - shared_memory_error post_semaphore(const std::string& name); - shared_memory_error remove_semaphore(const std::string& name); - - // RAII锁 - class scoped_mutex_lock { - public: - scoped_mutex_lock(interprocess_synchronization& sync, const std::string& mutex_name); - ~scoped_mutex_lock(); - - [[nodiscard]] auto is_locked() const { return locked_; } - - private: - interprocess_synchronization& sync_; - std::string mutex_name_; - bool locked_; - }; - + using interprocess_mutex = boost::interprocess::named_mutex; + using interprocess_condition = boost::interprocess::named_condition; + using interprocess_semaphore = boost::interprocess::named_semaphore; + using scoped_lock = boost::interprocess::scoped_lock; + + explicit interprocess_synchronization(const shared_memory_config& config); + ~interprocess_synchronization(); + + // 互斥量操作 + shared_memory_error create_mutex(const std::string& name); + shared_memory_error open_mutex(const std::string& name); + shared_memory_error lock_mutex(const std::string& name, + std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)); + shared_memory_error unlock_mutex(const std::string& name); + shared_memory_error remove_mutex(const std::string& name); + + // 条件变量操作 + shared_memory_error create_condition(const std::string& name); + shared_memory_error open_condition(const std::string& name); + shared_memory_error wait_condition(const std::string& condition_name, const std::string& mutex_name, + std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)); + shared_memory_error notify_condition(const std::string& name, bool notify_all = false); + shared_memory_error remove_condition(const std::string& name); + + // 信号量操作 + shared_memory_error create_semaphore(const std::string& name, unsigned int initial_count); + shared_memory_error open_semaphore(const std::string& name); + shared_memory_error wait_semaphore(const std::string& name, + std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)); + shared_memory_error post_semaphore(const std::string& name); + shared_memory_error remove_semaphore(const std::string& name); + + // RAII锁 + class scoped_mutex_lock { + public: + scoped_mutex_lock(interprocess_synchronization& sync, const std::string& mutex_name); + ~scoped_mutex_lock(); + + [[nodiscard]] auto is_locked() const { return locked_; } + + private: + interprocess_synchronization& sync_; + std::string mutex_name_; + bool locked_; + }; + private: - shared_memory_config config_; - std::unordered_map> mutexes_; - std::unordered_map> conditions_; - std::unordered_map> semaphores_; - - std::mutex local_mutex_; // 保护本地数据结构 + shared_memory_config config_; + std::unordered_map> mutexes_; + std::unordered_map> conditions_; + std::unordered_map> semaphores_; + + std::mutex local_mutex_; // 保护本地数据结构 }; diff --git a/src/network/shm/lock_free_ring_buffer.h b/src/network/shm/lock_free_ring_buffer.h index ad86725..bca5568 100644 --- a/src/network/shm/lock_free_ring_buffer.h +++ b/src/network/shm/lock_free_ring_buffer.h @@ -7,7 +7,7 @@ #define LOCK_FREE_RING_BUFFER_LOG_MODULE "lock-free ring buffer" -template +template class lock_free_ring_buffer { public: static_assert(std::is_trivially_copyable_v, "T 必须是平凡可拷贝类型"); @@ -16,15 +16,15 @@ public: ~lock_free_ring_buffer(); // 禁止拷贝,允许移动 - lock_free_ring_buffer(const lock_free_ring_buffer&) = delete; + lock_free_ring_buffer(const lock_free_ring_buffer&) = delete; lock_free_ring_buffer& operator=(const lock_free_ring_buffer&) = delete; - lock_free_ring_buffer(lock_free_ring_buffer&&) = default; - lock_free_ring_buffer& operator=(lock_free_ring_buffer&&) = default; + lock_free_ring_buffer(lock_free_ring_buffer&&) = default; + lock_free_ring_buffer& operator=(lock_free_ring_buffer&&) = default; // 生产者接口 auto try_push(const T& item) -> bool; auto try_push(T&& item) -> bool; - template + template auto try_emplace(Args&&... args) -> bool; // 消费者接口 @@ -41,18 +41,19 @@ public: // 批量操作 auto try_push_batch(const T* items, size_t count) -> size_t; auto try_pop_batch(T* items, size_t count) -> size_t; + private: struct buffer_data { - std::atomic head{ 0 }; - std::atomic tail{ 0 }; + std::atomic head{0}; + std::atomic tail{0}; size_t capacity{}; alignas(64) T data[1]; }; buffer_data* buffer_; - size_t capacity_; - std::string name_; - bool creator_; + size_t capacity_; + std::string name_; + bool creator_; [[nodiscard]] auto next_index(size_t index) const -> size_t { return (index + 1) % (capacity_ + 1); // +1 是为了区分满和空的状态 @@ -66,7 +67,7 @@ private: } }; -template +template lock_free_ring_buffer::lock_free_ring_buffer(const std::string& name, size_t capacity) : capacity_(capacity), name_(name), creator_(false) { @@ -79,7 +80,7 @@ lock_free_ring_buffer::lock_free_ring_buffer(const std::string& name, size_t throw std::runtime_error("无法分配共享内存"); } - buffer_ = new (raw_memory) buffer_data(); + buffer_ = new(raw_memory) buffer_data(); buffer_->capacity = capacity_; creator_ = true; @@ -94,7 +95,7 @@ lock_free_ring_buffer::lock_free_ring_buffer(const std::string& name, size_t } } -template +template lock_free_ring_buffer::~lock_free_ring_buffer() { if (creator_) { shm_deallocate_raw(name_); @@ -102,9 +103,9 @@ lock_free_ring_buffer::~lock_free_ring_buffer() { } } -template +template auto lock_free_ring_buffer::try_push(const T& item) -> bool { - auto head = buffer_->head.load(std::memory_order_relaxed); + auto head = buffer_->head.load(std::memory_order_relaxed); auto next_head = next_index(head); if (next_head == buffer_->tail.load(std::memory_order_acquire)) { @@ -116,9 +117,9 @@ auto lock_free_ring_buffer::try_push(const T& item) -> bool { return true; } -template +template auto lock_free_ring_buffer::try_push(T&& item) -> bool { - auto head = buffer_->head.load(std::memory_order_relaxed); + auto head = buffer_->head.load(std::memory_order_relaxed); auto next_head = next_index(head); if (next_head == buffer_->tail.load(std::memory_order_acquire)) { @@ -131,23 +132,23 @@ auto lock_free_ring_buffer::try_push(T&& item) -> bool { return true; } -template -template +template +template auto lock_free_ring_buffer::try_emplace(Args&&... args) -> bool { - auto head = buffer_->head.load(std::memory_order_relaxed); + auto head = buffer_->head.load(std::memory_order_relaxed); auto next_head = next_index(head); if (next_head == buffer_->tail.load(std::memory_order_acquire)) { return false; // 缓冲区满 } - new (&buffer_->data[head]) T(std::forward(args)...); + new(&buffer_->data[head]) T(std::forward(args)...); buffer_->head.store(next_head, std::memory_order_release); return true; } -template +template auto lock_free_ring_buffer::try_pop(T& item) -> bool { auto tail = buffer_->tail.load(std::memory_order_relaxed); @@ -160,44 +161,44 @@ auto lock_free_ring_buffer::try_pop(T& item) -> bool { return true; } -template +template auto lock_free_ring_buffer::try_peek(T& item) const -> bool { auto tail = buffer_->tail.load(std::memory_order_relaxed); if (tail == buffer_->head.load(std::memory_order_acquire)) { - return false; // 缓冲区空 + return false; // 缓冲区空 } item = buffer_->data[tail]; return true; } -template +template auto lock_free_ring_buffer::empty() const -> bool { return buffer_->tail.load(std::memory_order_acquire) == - buffer_->head.load(std::memory_order_acquire); + buffer_->head.load(std::memory_order_acquire); } -template +template auto lock_free_ring_buffer::full() const -> bool { auto head = buffer_->head.load(std::memory_order_acquire); auto tail = buffer_->tail.load(std::memory_order_acquire); return next_index(head) == tail; } -template +template auto lock_free_ring_buffer::size() const -> size_t { auto head = buffer_->head.load(std::memory_order_acquire); auto tail = buffer_->tail.load(std::memory_order_acquire); return distance(tail, head); } -template +template auto lock_free_ring_buffer::available_space() const -> size_t { return capacity_ - size(); } -template +template auto lock_free_ring_buffer::try_push_batch(const T* items, size_t count) -> size_t { if (!items || count == 0) { return 0; @@ -214,7 +215,7 @@ auto lock_free_ring_buffer::try_push_batch(const T* items, size_t count) -> s return pushed; } -template +template auto lock_free_ring_buffer::try_pop_batch(T* items, size_t count) -> size_t { if (!items || count == 0) { return 0; diff --git a/src/network/shm/shared_memory_manager.cpp b/src/network/shm/shared_memory_manager.cpp index 7bfaeba..2bee1d3 100644 --- a/src/network/shm/shared_memory_manager.cpp +++ b/src/network/shm/shared_memory_manager.cpp @@ -8,7 +8,8 @@ auto shared_memory_manager::init(const shared_memory_config& config) -> shared_m return shared_memory_error::SUCCESS; config_ = config; - log_module_info(SHARED_MEMORY_MANAGER_LOG_MODULE, "初始化共享内存段: {}, 大小: {} 字节", config_.segment_name, config_.segment_size); + log_module_info(SHARED_MEMORY_MANAGER_LOG_MODULE, "初始化共享内存段: {}, 大小: {} 字节", config_.segment_name, + config_.segment_size); auto result = create_or_open_segment(); if (result != shared_memory_error::SUCCESS) { @@ -86,7 +87,7 @@ auto shared_memory_manager::create_or_open_segment() -> shared_memory_error { try { try { segment_ = std::make_unique(boost::interprocess::open_only, - config_.segment_name.c_str()); + config_.segment_name.c_str()); creator_ = false; log_module_debug(SHARED_MEMORY_MANAGER_LOG_MODULE, "成功打开现有的共享内存段: {}", config_.segment_name); @@ -107,8 +108,8 @@ auto shared_memory_manager::create_or_open_segment() -> shared_memory_error { // 尝试创建新的共享内存段 segment_ = std::make_unique(boost::interprocess::create_only, - config_.segment_name.c_str(), - config_.segment_size); + config_.segment_name.c_str(), + config_.segment_size); creator_ = true; log_module_debug(SHARED_MEMORY_MANAGER_LOG_MODULE, "成功创建新的共享内存段: {}", config_.segment_name); @@ -119,7 +120,8 @@ auto shared_memory_manager::create_or_open_segment() -> shared_memory_error { catch (const boost::interprocess::interprocess_exception& e) { log_module_error(SHARED_MEMORY_MANAGER_LOG_MODULE, "创建或打开共享 memory 段时出错: {}", e.what()); return shared_memory_error::CREATION_FAILED; - } catch (const std::exception& e) { + } + catch (const std::exception& e) { log_module_error(SHARED_MEMORY_MANAGER_LOG_MODULE, "创建或打开共享 memory 段时发生异常: {}", e.what()); return shared_memory_error::CREATION_FAILED; } @@ -130,7 +132,12 @@ auto shared_memory_manager::create_or_open_segment() -> shared_memory_error { void shared_memory_manager::cleanup() { segment_.reset(); - if (creator_ && config_.remove_on_destroy) { + if (creator_ && config_ + + + . + remove_on_destroy + ) { try { if (boost::interprocess::shared_memory_object::remove(config_.segment_name.c_str())) { log_module_info(SHARED_MEMORY_MANAGER_LOG_MODULE, "已移除共享内存段: {}", config_.segment_name); diff --git a/src/network/shm/shared_memory_manager.h b/src/network/shm/shared_memory_manager.h index 1c583c0..e9ae3bb 100644 --- a/src/network/shm/shared_memory_manager.h +++ b/src/network/shm/shared_memory_manager.h @@ -56,13 +56,13 @@ public: auto is_initialized() const { return initialized_; } - template + template auto allocate(const std::string& name) -> T*; - template + template auto find(const std::string& name) -> T*; - template + template auto deallocate(const std::string& name) -> bool; auto allocate_raw(size_t size, const std::string& name) -> void*; @@ -89,7 +89,7 @@ private: void cleanup(); }; -template +template auto shared_memory_manager::allocate(const std::string& name) -> T* { if (!segment_) { log_module_error(SHARED_MEMORY_MANAGER_LOG_MODULE, "尝试在未初始化的共享内存段中分配对象: {}", name); @@ -103,7 +103,7 @@ auto shared_memory_manager::allocate(const std::string& name) -> T* { } } -template +template auto shared_memory_manager::find(const std::string& name) -> T* { if (!segment_) { log_module_error(SHARED_MEMORY_MANAGER_LOG_MODULE, "尝试在未初始化的共享内存段中查找对象: {}", name); @@ -121,7 +121,7 @@ auto shared_memory_manager::find(const std::string& name) -> T* { } } -template +template auto shared_memory_manager::deallocate(const std::string& name) -> bool { if (!segment_) { log_module_error(SHARED_MEMORY_MANAGER_LOG_MODULE, "尝试在未初始化的共享内存段中释放对象: {}", name); @@ -141,25 +141,31 @@ auto shared_memory_manager::deallocate(const std::string& name) -> bool { inline auto shm_allocate_raw(const std::string& name, size_t size) { return shared_memory_manager::instance().allocate_raw(size, name); } + inline auto shm_find_raw(const std::string& name) { return shared_memory_manager::instance().find_raw(name); } + inline auto shm_deallocate_raw(const std::string& name) { return shared_memory_manager::instance().deallocate_raw(name); } -template + +template auto shm_find_raw(const std::string& name) { return static_cast(shared_memory_manager::instance().find_raw(name)); } -template + +template auto shm_allocate(const std::string& name) { return shared_memory_manager::instance().allocate(name); } -template + +template auto shm_find(const std::string& name) { return shared_memory_manager::instance().find(name); } -template + +template auto shm_deallocate(const std::string& name) { return shared_memory_manager::instance().deallocate(name); } diff --git a/src/network/shm/triple_buffer.h b/src/network/shm/triple_buffer.h index 09402ec..55d9787 100644 --- a/src/network/shm/triple_buffer.h +++ b/src/network/shm/triple_buffer.h @@ -8,7 +8,7 @@ #define TRIPLE_BUFFER_LOG_MODULE "triple buffer" // 无锁三缓冲区 -template +template class triple_buffer { public: static_assert(std::is_trivially_copyable_v, "T 必须是平凡可拷贝类型"); @@ -17,10 +17,10 @@ public: ~triple_buffer(); // 禁止拷贝,允许移动 - triple_buffer(const triple_buffer&) = delete; + triple_buffer(const triple_buffer&) = delete; triple_buffer& operator=(const triple_buffer&) = delete; - triple_buffer(triple_buffer&&) = default; - triple_buffer& operator=(triple_buffer&&) = default; + triple_buffer(triple_buffer&&) = default; + triple_buffer& operator=(triple_buffer&&) = default; // 生产者接口 auto get_write_buffer() -> T*; @@ -34,27 +34,27 @@ public: // 状态查询 [[nodiscard]] auto has_new_data() const -> bool; [[nodiscard]] auto pending_writes() const -> size_t; + private: struct buffer_data { - std::atomic write_index{ 0 }; // 写缓冲区索引 - std::atomic read_index{ 0 }; // 读缓冲区索引 - std::atomic available_index{ 1 }; // 可用缓冲区索引 - std::atomic new_data{ false }; // 是否有新数据可读 - alignas(64) T buffers[3]; // 三个缓冲区 + std::atomic write_index{0}; // 写缓冲区索引 + std::atomic read_index{0}; // 读缓冲区索引 + std::atomic available_index{1}; // 可用缓冲区索引 + std::atomic new_data{false}; // 是否有新数据可读 + alignas(64) T buffers[3]; // 三个缓冲区 }; buffer_data* buffer_; - std::string name_; - bool creator_; + std::string name_; + bool creator_; int32_t current_write_buffer_ = -1; - int32_t current_read_buffer_ = -1; + int32_t current_read_buffer_ = -1; }; -template +template triple_buffer::triple_buffer(const std::string& name) : name_(name), - creator_(false) { - + creator_(false) { buffer_ = shm_find(name_); if (!buffer_) { @@ -65,12 +65,13 @@ triple_buffer::triple_buffer(const std::string& name) : name_(name), creator_ = true; log_module_debug(TRIPLE_BUFFER_LOG_MODULE, "创建无锁三缓冲区: {}", name_); - } else { + } + else { log_module_debug(TRIPLE_BUFFER_LOG_MODULE, "打开现有无锁三缓冲区: {}", name_); } } -template +template triple_buffer::~triple_buffer() { if (creator_) { shm_deallocate(name_); @@ -79,7 +80,7 @@ triple_buffer::~triple_buffer() { } } -template +template auto triple_buffer::get_write_buffer() -> T* { if (current_write_buffer_ >= 0) { return nullptr; @@ -89,10 +90,10 @@ auto triple_buffer::get_write_buffer() -> T* { return &buffer_->buffers[current_write_buffer_]; } -template +template void triple_buffer::commit_write() { if (current_write_buffer_ < 0) { - return; // 没有活跃的写操作 + return; // 没有活跃的写操作 } // 交换写缓冲区和可用缓冲区 @@ -106,12 +107,12 @@ void triple_buffer::commit_write() { current_write_buffer_ = -1; } -template +template void triple_buffer::discard_write() { current_write_buffer_ = -1; } -template +template auto triple_buffer::get_read_buffer() -> const T* { if (current_read_buffer_ >= 0) { // 继续使用当前读缓冲区 @@ -119,17 +120,17 @@ auto triple_buffer::get_read_buffer() -> const T* { } if (!buffer_->new_data.load(std::memory_order_acquire)) { - return nullptr; // 没有新数据 + return nullptr; // 没有新数据 } current_read_buffer_ = buffer_->write_index.load(std::memory_order_acquire); return &buffer_->buffers[current_read_buffer_]; } -template +template void triple_buffer::commit_read() { if (current_read_buffer_ < 0) { - return; // 没有活跃的读操作 + return; // 没有活跃的读操作 } // 交换读缓冲区和可用缓冲区 @@ -141,12 +142,12 @@ void triple_buffer::commit_read() { current_read_buffer_ = -1; } -template +template auto triple_buffer::has_new_data() const -> bool { return buffer_->new_data.load(std::memory_order_acquire); } -template +template auto triple_buffer::pending_writes() const -> size_t { return current_write_buffer_ >= 0 ? 1 : 0; } diff --git a/src/network/transport/zmq_client.cpp b/src/network/transport/zmq_client.cpp index 1a984d1..619da8b 100644 --- a/src/network/transport/zmq_client.cpp +++ b/src/network/transport/zmq_client.cpp @@ -5,94 +5,98 @@ #define ZMQ_CLIENT_LOG_MODULE "zmq_client" void zmq_client::init(uint32_t client_id) { - if (state_ == zmq_client::state::CONNECTED) { - if (client_id_ != client_id) { - // 客户端ID不同,需要断开重连 - log_module_info(ZMQ_CLIENT_LOG_MODULE, "客户端ID变更({} -> {}),断开重连", client_id_, client_id); - disconnect(); - } else { - // ID相同,已连接,直接返回 - log_module_warn(ZMQ_CLIENT_LOG_MODULE, "客户端已连接,ID={}", client_id_); - return; - } - } - - state_ = zmq_client::state::CONNECTING; - client_id_ = client_id; - - try { - socket_ = zmq::socket_t(context_, zmq::socket_type::dealer); - socket_.set(zmq::sockopt::routing_id, zmq::const_buffer(&client_id_, sizeof(client_id_))); - socket_.connect(ZMQ_SERVER_ADDRESS); - - state_ = zmq_client::state::CONNECTED; - log_module_info(ZMQ_CLIENT_LOG_MODULE, "客户端 {} 已连接", client_id_); - } catch (const zmq::error_t& e) { - state_ = zmq_client::state::FAILED; - log_module_error(ZMQ_CLIENT_LOG_MODULE, "连接失败: {}", e.what()); - throw; - } + if (state_ == zmq_client::state::CONNECTED) { + if (client_id_ != client_id) { + // 客户端ID不同,需要断开重连 + log_module_info(ZMQ_CLIENT_LOG_MODULE, "客户端ID变更({} -> {}),断开重连", client_id_, client_id); + disconnect(); + } + else { + // ID相同,已连接,直接返回 + log_module_warn(ZMQ_CLIENT_LOG_MODULE, "客户端已连接,ID={}", client_id_); + return; + } + } + + state_ = zmq_client::state::CONNECTING; + client_id_ = client_id; + + try { + socket_ = zmq::socket_t(context_, zmq::socket_type::dealer); + socket_.set(zmq::sockopt::routing_id, zmq::const_buffer(&client_id_, sizeof(client_id_))); + socket_.connect(ZMQ_SERVER_ADDRESS); + + state_ = zmq_client::state::CONNECTED; + log_module_info(ZMQ_CLIENT_LOG_MODULE, "客户端 {} 已连接", client_id_); + } + catch (const zmq::error_t& e) { + state_ = zmq_client::state::FAILED; + log_module_error(ZMQ_CLIENT_LOG_MODULE, "连接失败: {}", e.what()); + throw; + } } void zmq_client::recv() { - zmq::message_t msg; - - try { - auto result = socket_.recv(msg, zmq::recv_flags::none); - - if (!result || *result == 0) { - log_module_error(ZMQ_CLIENT_LOG_MODULE, "接收消息失败"); - return; - } - - auto pack = zmq_message_pack::deserialize(msg); - log_module_debug(ZMQ_CLIENT_LOG_MODULE, "收到消息,func_id: {}, payload大小: {}", - pack.func_id, pack.payload.size()); - - zmq_client_processor::instance().process(pack.func_id, pack.payload.data(), pack.payload.size()); - - } catch (const std::exception& e) { - log_module_error(ZMQ_CLIENT_LOG_MODULE, "处理消息异常: {}", e.what()); - } + zmq::message_t msg; + + try { + auto result = socket_.recv(msg, zmq::recv_flags::none); + + if (!result || *result == 0) { + log_module_error(ZMQ_CLIENT_LOG_MODULE, "接收消息失败"); + return; + } + + auto pack = zmq_message_pack::deserialize(msg); + log_module_debug(ZMQ_CLIENT_LOG_MODULE, "收到消息,func_id: {}, payload大小: {}", + pack.func_id, pack.payload.size()); + + zmq_client_processor::instance().process(pack.func_id, pack.payload.data(), pack.payload.size()); + } + catch (const std::exception& e) { + log_module_error(ZMQ_CLIENT_LOG_MODULE, "处理消息异常: {}", e.what()); + } } void zmq_client::disconnect() { - if (state_ == zmq_client::state::DISCONNECTED) { - return; - } - - try { - socket_.close(); - state_ = zmq_client::state::DISCONNECTED; - log_module_info(ZMQ_CLIENT_LOG_MODULE, "客户端 {} 已断开连接", client_id_); - } catch (const zmq::error_t& e) { - log_module_error(ZMQ_CLIENT_LOG_MODULE, "断开连接异常: {}", e.what()); - } + if (state_ == zmq_client::state::DISCONNECTED) { + return; + } + + try { + socket_.close(); + state_ = zmq_client::state::DISCONNECTED; + log_module_info(ZMQ_CLIENT_LOG_MODULE, "客户端 {} 已断开连接", client_id_); + } + catch (const zmq::error_t& e) { + log_module_error(ZMQ_CLIENT_LOG_MODULE, "断开连接异常: {}", e.what()); + } } bool zmq_client::reconnect() { - log_module_info(ZMQ_CLIENT_LOG_MODULE, "尝试重连客户端 {}", client_id_); - - disconnect(); - - for (uint32_t attempt = 1; attempt <= MAX_RECONNECT_ATTEMPTS; ++attempt) { - state_ = zmq_client::state::RECONNECTING; - log_module_info(ZMQ_CLIENT_LOG_MODULE, "重连尝试 {}/{}", attempt, MAX_RECONNECT_ATTEMPTS); - - try { - init(client_id_); - log_module_info(ZMQ_CLIENT_LOG_MODULE, "重连成功"); - return true; - } catch (const std::exception& e) { - log_module_error(ZMQ_CLIENT_LOG_MODULE, "重连失败: {}", e.what()); - - if (attempt < MAX_RECONNECT_ATTEMPTS) { - std::this_thread::sleep_for(std::chrono::milliseconds(RECONNECT_DELAY_MS)); - } - } - } - - state_ = zmq_client::state::FAILED; - log_module_error(ZMQ_CLIENT_LOG_MODULE, "重连失败,已达最大尝试次数"); - return false; + log_module_info(ZMQ_CLIENT_LOG_MODULE, "尝试重连客户端 {}", client_id_); + + disconnect(); + + for (uint32_t attempt = 1; attempt <= MAX_RECONNECT_ATTEMPTS; ++attempt) { + state_ = zmq_client::state::RECONNECTING; + log_module_info(ZMQ_CLIENT_LOG_MODULE, "重连尝试 {}/{}", attempt, MAX_RECONNECT_ATTEMPTS); + + try { + init(client_id_); + log_module_info(ZMQ_CLIENT_LOG_MODULE, "重连成功"); + return true; + } + catch (const std::exception& e) { + log_module_error(ZMQ_CLIENT_LOG_MODULE, "重连失败: {}", e.what()); + + if (attempt < MAX_RECONNECT_ATTEMPTS) { + std::this_thread::sleep_for(std::chrono::milliseconds(RECONNECT_DELAY_MS)); + } + } + } + + state_ = zmq_client::state::FAILED; + log_module_error(ZMQ_CLIENT_LOG_MODULE, "重连失败,已达最大尝试次数"); + return false; } diff --git a/src/network/transport/zmq_client.h b/src/network/transport/zmq_client.h index 01658c8..e3196df 100644 --- a/src/network/transport/zmq_client.h +++ b/src/network/transport/zmq_client.h @@ -9,74 +9,78 @@ #define ZMQ_CLIENT_LOG_MODULE "zmq_client" class zmq_client : public lazy_singleton { - friend class lazy_singleton; + friend class lazy_singleton; + public: - enum class state { - DISCONNECTED, - CONNECTING, - CONNECTED, - RECONNECTING, - FAILED - }; - - ~zmq_client() { - try { - if (state_ == state::CONNECTED) { - socket_.set(zmq::sockopt::linger, 0); // 立即关闭,不等待 - socket_.close(); - state_ = state::DISCONNECTED; - } - } catch (...) { - // 析构函数中不抛出异常,也不记录日志(logger可能已析构) - } - } - - void init(uint32_t client_id); + enum class state { + DISCONNECTED, + CONNECTING, + CONNECTED, + RECONNECTING, + FAILED + }; - template - void send(const T& message) { - if (!is_connected()) { - log_module_error(ZMQ_CLIENT_LOG_MODULE, "客户端未连接,尝试重连"); - if (!reconnect()) { - log_module_error(ZMQ_CLIENT_LOG_MODULE, "重连失败,无法发送消息"); - return; - } - } - - try { - auto buf = zmq_message_pack::create(message); - zmq::message_t msg(buf.data(), buf.size()); - auto result = socket_.send(msg, zmq::send_flags::none); - - if (!result) { - log_module_error(ZMQ_CLIENT_LOG_MODULE, "发送消息失败"); - } else { - log_module_debug(ZMQ_CLIENT_LOG_MODULE, "成功发送消息"); - } - } catch (const zmq::error_t& e) { - log_module_error(ZMQ_CLIENT_LOG_MODULE, "发送消息异常: {}", e.what()); - // 发送失败可能是连接断开,标记状态 - state_ = state::FAILED; - } - } + ~zmq_client() { + try { + if (state_ == state::CONNECTED) { + socket_.set(zmq::sockopt::linger, 0); // 立即关闭,不等待 + socket_.close(); + state_ = state::DISCONNECTED; + } + } + catch (...) { + // 析构函数中不抛出异常,也不记录日志(logger可能已析构) + } + } + + void init(uint32_t client_id); + + template + void send(const T& message) { + if (!is_connected()) { + log_module_error(ZMQ_CLIENT_LOG_MODULE, "客户端未连接,尝试重连"); + if (!reconnect()) { + log_module_error(ZMQ_CLIENT_LOG_MODULE, "重连失败,无法发送消息"); + return; + } + } + + try { + auto buf = zmq_message_pack::create(message); + zmq::message_t msg(buf.data(), buf.size()); + auto result = socket_.send(msg, zmq::send_flags::none); + + if (!result) { + log_module_error(ZMQ_CLIENT_LOG_MODULE, "发送消息失败"); + } + else { + log_module_debug(ZMQ_CLIENT_LOG_MODULE, "成功发送消息"); + } + } + catch (const zmq::error_t& e) { + log_module_error(ZMQ_CLIENT_LOG_MODULE, "发送消息异常: {}", e.what()); + // 发送失败可能是连接断开,标记状态 + state_ = state::FAILED; + } + } + + void recv(); + + // 连接管理 + void disconnect(); + bool reconnect(); + + // 状态查询 + state get_state() const { return state_; } + bool is_connected() const { return state_ == state::CONNECTED; } - void recv(); - - // 连接管理 - void disconnect(); - bool reconnect(); - - // 状态查询 - state get_state() const { return state_; } - bool is_connected() const { return state_ == state::CONNECTED; } - private: - zmq::socket_t socket_; - zmq::context_t context_; - uint32_t client_id_; - state state_ = state::DISCONNECTED; - - // 重连配置 - static constexpr uint32_t MAX_RECONNECT_ATTEMPTS = 3; - static constexpr uint32_t RECONNECT_DELAY_MS = 1000; + zmq::socket_t socket_; + zmq::context_t context_; + uint32_t client_id_; + state state_ = state::DISCONNECTED; + + // 重连配置 + static constexpr uint32_t MAX_RECONNECT_ATTEMPTS = 3; + static constexpr uint32_t RECONNECT_DELAY_MS = 1000; }; diff --git a/src/network/transport/zmq_client_processor.h b/src/network/transport/zmq_client_processor.h index 1d26c99..49c33e9 100644 --- a/src/network/transport/zmq_client_processor.h +++ b/src/network/transport/zmq_client_processor.h @@ -8,48 +8,58 @@ #define ZMQ_CLIENT_PROCESSOR_LOG_MODULE "zmq_client_processor" class zmq_client_processor : public lazy_singleton { - friend class lazy_singleton; + friend class lazy_singleton; public: - void register_processor(uint32_t func_id, std::function processor) { - processors_[func_id] = std::move(processor); - } + void register_processor(uint32_t func_id, std::function processor) { + processors_[func_id] = std::move(processor); + } + + void process(uint32_t func_id, const void* payload, size_t size) { + const auto it = processors_.find(func_id); + if (it == processors_.end()) { + log_module_error(ZMQ_CLIENT_PROCESSOR_LOG_MODULE, "未找到func_id {} 的处理器", func_id); + return; + } + + try { + it->second(payload, size); + } + catch (const std::exception& e) { + log_module_error(ZMQ_CLIENT_PROCESSOR_LOG_MODULE, "处理器执行异常: {}", e.what()); + } + } - void process(uint32_t func_id, const void* payload, size_t size) { - const auto it = processors_.find(func_id); - if (it == processors_.end()) { - log_module_error(ZMQ_CLIENT_PROCESSOR_LOG_MODULE, "未找到func_id {} 的处理器", func_id); - return; - } - - try { - it->second(payload, size); - } catch (const std::exception& e) { - log_module_error(ZMQ_CLIENT_PROCESSOR_LOG_MODULE, "处理器执行异常: {}", e.what()); - } - } private: - std::unordered_map> processors_; + std::unordered_map> processors_; }; template class zmq_client_register { public: - zmq_client_register(auto processor) { - zmq_client_processor::instance().register_processor(alicho_type_id_v, [processor](const void* payload, size_t size) { - try { - auto result = struct_pack::deserialize(static_cast(payload), size); - if (result.has_value()) { - processor(result.value()); - } - else { - log_module_error(ZMQ_CLIENT_PROCESSOR_LOG_MODULE, "反序列化失败: error_code={}", static_cast(result.error())); - } - } catch (const std::exception& e) { - log_module_error(ZMQ_CLIENT_PROCESSOR_LOG_MODULE, "反序列化失败: {}", e.what()); - } - }); - } + zmq_client_register(auto processor) { + zmq_client_processor::instance().register_processor(alicho_type_id_v, + [processor](const void* payload, size_t size) { + try { + auto result = struct_pack::deserialize( + static_cast(payload), size); + if (result.has_value()) { + processor(result.value()); + } + else { + log_module_error( + ZMQ_CLIENT_PROCESSOR_LOG_MODULE, + "反序列化失败: error_code={}", + static_cast(result.error())); + } + } + catch (const std::exception& e) { + log_module_error( + ZMQ_CLIENT_PROCESSOR_LOG_MODULE, "反序列化失败: {}", + e.what()); + } + }); + } }; #define ZMQ_CLIENT_REGISTER_PROCESSOR(data_type) \ diff --git a/src/network/transport/zmq_server.cpp b/src/network/transport/zmq_server.cpp index 95ddc01..553a905 100644 --- a/src/network/transport/zmq_server.cpp +++ b/src/network/transport/zmq_server.cpp @@ -3,70 +3,72 @@ #include "zmq_server_processor.h" void zmq_server::init() { - if (state_ == zmq_server::state::RUNNING) { - log_module_warn(ZMQ_SERVER_LOG_MODULE, "服务器已在运行中"); - return; - } - - try { - socket_ = zmq::socket_t(context_, zmq::socket_type::router); - socket_.bind(ZMQ_SERVER_ADDRESS); - state_ = zmq_server::state::RUNNING; - log_module_info(ZMQ_SERVER_LOG_MODULE, "ZMQ服务器已启动"); - } catch (const zmq::error_t& e) { - state_ = zmq_server::state::FAILED; - log_module_error(ZMQ_SERVER_LOG_MODULE, "启动失败: {}", e.what()); - throw; - } + if (state_ == zmq_server::state::RUNNING) { + log_module_warn(ZMQ_SERVER_LOG_MODULE, "服务器已在运行中"); + return; + } + + try { + socket_ = zmq::socket_t(context_, zmq::socket_type::router); + socket_.bind(ZMQ_SERVER_ADDRESS); + state_ = zmq_server::state::RUNNING; + log_module_info(ZMQ_SERVER_LOG_MODULE, "ZMQ服务器已启动"); + } + catch (const zmq::error_t& e) { + state_ = zmq_server::state::FAILED; + log_module_error(ZMQ_SERVER_LOG_MODULE, "启动失败: {}", e.what()); + throw; + } } void zmq_server::recv() { - try { - // --- 接收并存储身份 --- - zmq::message_t id_msg; - auto result = socket_.recv(id_msg, zmq::recv_flags::none); + try { + // --- 接收并存储身份 --- + zmq::message_t id_msg; + auto result = socket_.recv(id_msg, zmq::recv_flags::none); - if (!result || *result == 0) { - log_module_error(ZMQ_SERVER_LOG_MODULE, "接收客户端ID失败"); - return; - } + if (!result || *result == 0) { + log_module_error(ZMQ_SERVER_LOG_MODULE, "接收客户端ID失败"); + return; + } - zmq::message_t content_msg; - result = socket_.recv(content_msg, zmq::recv_flags::none); + zmq::message_t content_msg; + result = socket_.recv(content_msg, zmq::recv_flags::none); - if (!result || *result == 0) { - log_module_error(ZMQ_SERVER_LOG_MODULE, "接收消息内容失败"); - return; - } + if (!result || *result == 0) { + log_module_error(ZMQ_SERVER_LOG_MODULE, "接收消息内容失败"); + return; + } - if (id_msg.size() != sizeof(uint32_t)) { - log_module_error(ZMQ_SERVER_LOG_MODULE, "无效的客户端ID大小: {}", id_msg.size()); - return; - } - auto client_id = *static_cast(id_msg.data()); - log_module_debug(ZMQ_SERVER_LOG_MODULE, "接收到来自客户端 {} 的消息", client_id); - clients_[client_id] = std::move(id_msg); + if (id_msg.size() != sizeof(uint32_t)) { + log_module_error(ZMQ_SERVER_LOG_MODULE, "无效的客户端ID大小: {}", id_msg.size()); + return; + } + auto client_id = *static_cast(id_msg.data()); + log_module_debug(ZMQ_SERVER_LOG_MODULE, "接收到来自客户端 {} 的消息", client_id); + clients_[client_id] = std::move(id_msg); - if (content_msg.empty()) - return; - // 处理 content_msg ... - auto pack = zmq_message_pack::deserialize(content_msg); - log_module_debug(ZMQ_SERVER_LOG_MODULE, "收到消息,func_id: {}, payload大小: {}", - pack.func_id, pack.payload.size()); - zmq_server_processor::instance().process(pack.func_id, client_id, pack.payload.data(), pack.payload.size()); - } catch (const std::exception& e) { - log_module_error(ZMQ_SERVER_LOG_MODULE, "处理消息异常: {}", e.what()); - } + if (content_msg.empty()) + return; + // 处理 content_msg ... + auto pack = zmq_message_pack::deserialize(content_msg); + log_module_debug(ZMQ_SERVER_LOG_MODULE, "收到消息,func_id: {}, payload大小: {}", + pack.func_id, pack.payload.size()); + zmq_server_processor::instance().process(pack.func_id, client_id, pack.payload.data(), pack.payload.size()); + } + catch (const std::exception& e) { + log_module_error(ZMQ_SERVER_LOG_MODULE, "处理消息异常: {}", e.what()); + } } void zmq_server::remove_client(uint32_t client_id) { - auto it = clients_.find(client_id); - if (it != clients_.end()) { - clients_.erase(it); - log_module_info(ZMQ_SERVER_LOG_MODULE, "移除客户端 {}", client_id); - } + auto it = clients_.find(client_id); + if (it != clients_.end()) { + clients_.erase(it); + log_module_info(ZMQ_SERVER_LOG_MODULE, "移除客户端 {}", client_id); + } } bool zmq_server::has_client(uint32_t client_id) const { - return clients_.contains(client_id); + return clients_.contains(client_id); } diff --git a/src/network/transport/zmq_server.h b/src/network/transport/zmq_server.h index 08d65fd..51979ce 100644 --- a/src/network/transport/zmq_server.h +++ b/src/network/transport/zmq_server.h @@ -10,66 +10,70 @@ #define ZMQ_SERVER_LOG_MODULE "zmq server" class zmq_server : public lazy_singleton { - friend class lazy_singleton; + friend class lazy_singleton; + public: - enum class state { - STOPPED, - RUNNING, - FAILED - }; - - ~zmq_server() { - if (state_ == state::RUNNING) { - try { - socket_.set(zmq::sockopt::linger, 0); // 立即关闭,不等待 - socket_.close(); - state_ = state::STOPPED; - } catch (...) { - // 析构函数中不抛出异常,也不记录日志(logger可能已析构) - } - } - } - - void init(); + enum class state { + STOPPED, + RUNNING, + FAILED + }; - template - void send(uint32_t client_id, const T& data) { - const auto& it = clients_.find(client_id); - if (it == clients_.end()) { - log_module_error(ZMQ_SERVER_LOG_MODULE, "尝试向未知客户端 {} 发送消息", client_id); - return; - } - - try { - auto buf = zmq_message_pack::create(data); - zmq::message_t msg(buf.data(), buf.size()); - auto result1 = socket_.send(it->second, zmq::send_flags::sndmore); - auto result2 = socket_.send(msg, zmq::send_flags::none); - - if (!result1 || !result2) { - log_module_error(ZMQ_SERVER_LOG_MODULE, "向客户端 {} 发送消息失败", client_id); - } else { - log_module_debug(ZMQ_SERVER_LOG_MODULE, "成功向客户端 {} 发送消息", client_id); - } - } catch (const zmq::error_t& e) { - log_module_error(ZMQ_SERVER_LOG_MODULE, "发送消息异常: {}", e.what()); - } - } + ~zmq_server() { + if (state_ == state::RUNNING) { + try { + socket_.set(zmq::sockopt::linger, 0); // 立即关闭,不等待 + socket_.close(); + state_ = state::STOPPED; + } + catch (...) { + // 析构函数中不抛出异常,也不记录日志(logger可能已析构) + } + } + } + + void init(); + + template + void send(uint32_t client_id, const T& data) { + const auto& it = clients_.find(client_id); + if (it == clients_.end()) { + log_module_error(ZMQ_SERVER_LOG_MODULE, "尝试向未知客户端 {} 发送消息", client_id); + return; + } + + try { + auto buf = zmq_message_pack::create(data); + zmq::message_t msg(buf.data(), buf.size()); + auto result1 = socket_.send(it->second, zmq::send_flags::sndmore); + auto result2 = socket_.send(msg, zmq::send_flags::none); + + if (!result1 || !result2) { + log_module_error(ZMQ_SERVER_LOG_MODULE, "向客户端 {} 发送消息失败", client_id); + } + else { + log_module_debug(ZMQ_SERVER_LOG_MODULE, "成功向客户端 {} 发送消息", client_id); + } + } + catch (const zmq::error_t& e) { + log_module_error(ZMQ_SERVER_LOG_MODULE, "发送消息异常: {}", e.what()); + } + } + + void recv(); + + // 状态管理 + state get_state() const { return state_; } + bool is_running() const { return state_ == state::RUNNING; } + + // 客户端管理 + void remove_client(uint32_t client_id); + bool has_client(uint32_t client_id) const; + size_t client_count() const { return clients_.size(); } - void recv(); - - // 状态管理 - state get_state() const { return state_; } - bool is_running() const { return state_ == state::RUNNING; } - - // 客户端管理 - void remove_client(uint32_t client_id); - bool has_client(uint32_t client_id) const; - size_t client_count() const { return clients_.size(); } - private: - std::unordered_map clients_; - zmq::socket_t socket_; - zmq::context_t context_; - state state_ = state::STOPPED; + std::unordered_map clients_; + zmq::socket_t socket_; + zmq::context_t context_; + state state_ = state::STOPPED; }; diff --git a/src/network/transport/zmq_server_processor.h b/src/network/transport/zmq_server_processor.h index c48ceec..98f1788 100644 --- a/src/network/transport/zmq_server_processor.h +++ b/src/network/transport/zmq_server_processor.h @@ -9,48 +9,60 @@ #define ZMQ_SERVER_PROCESSOR_LOG_MODULE "zmq_server_processor" class zmq_server_processor : public lazy_singleton { - friend class lazy_singleton; + friend class lazy_singleton; public: - void register_processor(uint32_t func_id, std::function processor) { - processors_[func_id] = std::move(processor); - } + void register_processor(uint32_t func_id, std::function processor) { + processors_[func_id] = std::move(processor); + } + + void process(uint32_t func_id, uint32_t client_id, const void* payload, size_t size) { + const auto it = processors_.find(func_id); + if (it == processors_.end()) { + log_module_error(ZMQ_SERVER_PROCESSOR_LOG_MODULE, "未找到func_id {} 的处理器", func_id); + return; + } + + try { + it->second(client_id, payload, size); + } + catch (const std::exception& e) { + log_module_error(ZMQ_SERVER_PROCESSOR_LOG_MODULE, "处理器执行异常: {}", e.what()); + } + } - void process(uint32_t func_id, uint32_t client_id, const void* payload, size_t size) { - const auto it = processors_.find(func_id); - if (it == processors_.end()) { - log_module_error(ZMQ_SERVER_PROCESSOR_LOG_MODULE, "未找到func_id {} 的处理器", func_id); - return; - } - - try { - it->second(client_id, payload, size); - } catch (const std::exception& e) { - log_module_error(ZMQ_SERVER_PROCESSOR_LOG_MODULE, "处理器执行异常: {}", e.what()); - } - } private: - std::unordered_map> processors_; + std::unordered_map> processors_; }; template class zmq_server_register { public: - zmq_server_register(auto processor) { - zmq_server_processor::instance().register_processor(alicho_type_id_v, [processor](uint32_t client_id, const void* payload, size_t size) { - try { - // struct_pack::deserialize返回std::expected,需要检查结果 - auto result = struct_pack::deserialize((const char*)payload, size); - if (result.has_value()) { - processor(client_id, result.value()); - } else { - log_module_error(ZMQ_SERVER_PROCESSOR_LOG_MODULE, "反序列化失败: error_code={}", static_cast(result.error())); - } - } catch (const std::exception& e) { - log_module_error(ZMQ_SERVER_PROCESSOR_LOG_MODULE, "反序列化异常: {}", e.what()); - } - }); - } + zmq_server_register(auto processor) { + zmq_server_processor::instance().register_processor(alicho_type_id_v, + [processor](uint32_t client_id, const void* payload, + size_t size) { + try { + // struct_pack::deserialize返回std::expected,需要检查结果 + auto result = struct_pack::deserialize( + (const char*)payload, size); + if (result.has_value()) { + processor(client_id, result.value()); + } + else { + log_module_error( + ZMQ_SERVER_PROCESSOR_LOG_MODULE, + "反序列化失败: error_code={}", + static_cast(result.error())); + } + } + catch (const std::exception& e) { + log_module_error( + ZMQ_SERVER_PROCESSOR_LOG_MODULE, "反序列化异常: {}", + e.what()); + } + }); + } }; #define ZMQ_SERVER_REGISTER_PROCESSOR(data_type) \ diff --git a/src/network/transport/zmq_util.h b/src/network/transport/zmq_util.h index 72dcf82..f0b3de9 100644 --- a/src/network/transport/zmq_util.h +++ b/src/network/transport/zmq_util.h @@ -8,22 +8,23 @@ #endif struct zmq_message_pack { - uint32_t func_id; // 用于标识处理函数 - std::vector payload; // 实际数据负载 + uint32_t func_id; // 用于标识处理函数 + std::vector payload; // 实际数据负载 - template - static auto create(const T& data) { - zmq_message_pack pack; - pack.func_id = alicho_type_id_v; - pack.payload = struct_pack::serialize(data); - return struct_pack::serialize(pack); - } + template + static auto create(const T& data) { + zmq_message_pack pack; + pack.func_id = alicho_type_id_v; + pack.payload = struct_pack::serialize(data); + return struct_pack::serialize(pack); + } - static auto deserialize(const zmq::message_t& pack) { - const auto& result = struct_pack::deserialize(static_cast(pack.data()), pack.size()); - if (!result.has_value()) { - throw std::runtime_error("无法反序列化 zmq_message_pack,错误代码=" + std::to_string(result.error())); - } - return result.value(); - } + static auto deserialize(const zmq::message_t& pack) { + const auto& result = struct_pack::deserialize(static_cast(pack.data()), + pack.size()); + if (!result.has_value()) { + throw std::runtime_error("无法反序列化 zmq_message_pack,错误代码=" + std::to_string(result.error())); + } + return result.value(); + } }; diff --git a/src/simd/aligned_allocator.h b/src/simd/aligned_allocator.h index e58f936..8096df2 100644 --- a/src/simd/aligned_allocator.h +++ b/src/simd/aligned_allocator.h @@ -16,77 +16,79 @@ inline auto aligned_malloc(size_t size, size_t alignment) -> void* { return nullptr; } -#if ALICHO_PLATFORM_WINDOWS + #if ALICHO_PLATFORM_WINDOWS return _aligned_malloc(size, alignment); -#elif ALICHO_PLATFORM_POSIX || ALICHO_PLATFORM_UNIX + #elif ALICHO_PLATFORM_POSIX || ALICHO_PLATFORM_UNIX void* ptr = nullptr; if (posix_memalign(&ptr, alignment, size) != 0) { return nullptr; } return ptr; -#else + #else // 回退实现:手动对齐 // 分配额外空间来存储原始指针和进行对齐 size_t total_size = size + alignment + sizeof(void*); - void* raw_ptr = std::malloc(total_size); + void* raw_ptr = std::malloc(total_size); if (!raw_ptr) { return nullptr; } // 计算对齐后的地址 - uintptr_t raw_addr = reinterpret_cast(raw_ptr); + uintptr_t raw_addr = reinterpret_cast(raw_ptr); uintptr_t aligned_addr = (raw_addr + sizeof(void*) + alignment - 1) & ~(alignment - 1); - void* aligned_ptr = reinterpret_cast(aligned_addr); + void* aligned_ptr = reinterpret_cast(aligned_addr); // 在对齐地址前存储原始指针 (reinterpret_cast(aligned_ptr))[-1] = raw_ptr; return aligned_ptr; -#endif + #endif } inline void aligned_free(void* ptr) { if (!ptr) { return; } -#if ALICHO_PLATFORM_WINDOWS + #if ALICHO_PLATFORM_WINDOWS _aligned_free(ptr); -#elif ALICHO_PLATFORM_POSIX || ALICHO_PLATFORM_UNIX + #elif ALICHO_PLATFORM_POSIX || ALICHO_PLATFORM_UNIX std::free(ptr); -#else + #else // 回退实现:获取原始指针并释放 void* raw_ptr = (reinterpret_cast(ptr))[-1]; std::free(raw_ptr); -#endif + #endif } // 对齐分配器模板类 -template +template class aligned_allocator { public: - using value_type = T; - using pointer = T*; - using const_pointer = const T*; - using reference = T&; + using value_type = T; + using pointer = T*; + using const_pointer = const T*; + using reference = T&; using const_reference = const T&; - using size_type = std::size_t; + using size_type = std::size_t; using difference_type = std::ptrdiff_t; - template + template struct rebind { using other = aligned_allocator; }; aligned_allocator() noexcept = default; - template - aligned_allocator(const aligned_allocator&) noexcept {} + template + aligned_allocator(const aligned_allocator&) noexcept { + } auto allocate(size_type n) -> pointer { - if (n == 0) return nullptr; + if (n == 0) + return nullptr; size_type size = n * sizeof(T); - void* ptr = aligned_malloc(size, Alignment); + void* ptr = aligned_malloc(size, Alignment); if (!ptr) { throw std::bad_alloc(); @@ -99,12 +101,12 @@ public: aligned_free(p); } - template + template void construct(U* p, Args&&... args) { ::new(static_cast(p)) U(std::forward(args)...); } - template + template void destroy(U* p) { p->~U(); } @@ -114,27 +116,27 @@ public: } }; -template +template bool operator==(const aligned_allocator&, const aligned_allocator&) noexcept { return A1 == A2; } -template +template bool operator!=(const aligned_allocator&, const aligned_allocator&) noexcept { return A1 != A2; } // 类型别名,方便使用不同对齐方式的分配器 -template +template using sse_aligned_allocator = aligned_allocator; -template +template using avx_aligned_allocator = aligned_allocator; -template +template using avx512_aligned_allocator = aligned_allocator; -template +template using cache_aligned_allocator = aligned_allocator; -template +template auto is_aligned(void* ptr) -> bool { return (reinterpret_cast(ptr) % Alignment) == 0; } @@ -147,14 +149,14 @@ inline auto align_size(size_t size, size_t alignment) -> size_t { return (size + alignment - 1) & ~(alignment - 1); } -template +template auto align_pointer(void* ptr) -> void* { - const auto addr = reinterpret_cast(ptr); + const auto addr = reinterpret_cast(ptr); const auto aligned_addr = (addr + Alignment - 1) & ~(Alignment - 1); return reinterpret_cast(aligned_addr); } -template +template class aligned_buffer { public: aligned_buffer() = default; @@ -167,20 +169,20 @@ public: if (data_) { deallocate(); } - + if (new_size == 0) { data_ = nullptr; size_ = 0; return; } - + data_ = static_cast(aligned_malloc(new_size * sizeof(T), Alignment)); if (!data_) { throw std::bad_alloc(); } - + size_ = new_size; - + // 对于非POD类型,需要构造对象 if constexpr (!std::is_trivially_constructible_v) { for (size_t i = 0; i < size_; ++i) { @@ -214,7 +216,7 @@ public: auto size() const noexcept { return size_; } auto empty() const noexcept { return size_ == 0; } - auto& operator[](size_t index) noexcept { return data_[index]; } + auto& operator[](size_t index) noexcept { return data_[index]; } const auto& operator[](size_t index) const noexcept { return data_[index]; } auto begin() noexcept { return data_; } @@ -225,8 +227,9 @@ public: [[nodiscard]] auto is_properly_aligned() const noexcept -> bool { return is_aligned(data_); } + private: - T* data_ = nullptr; + T* data_ = nullptr; size_t size_ = 0; }; diff --git a/src/simd/audio_processing/arm_simd_audio_processing_func.cpp b/src/simd/audio_processing/arm_simd_audio_processing_func.cpp index 4dbfc85..457b1ae 100644 --- a/src/simd/audio_processing/arm_simd_audio_processing_func.cpp +++ b/src/simd/audio_processing/arm_simd_audio_processing_func.cpp @@ -8,16 +8,15 @@ #include "aligned_allocator.h" namespace arm_simd_audio_processing_func { - // 基础音频混合函数实现 (NEON版本) void mix_audio_neon(const float* src1, const float* src2, float* dst, size_t num_samples) { ASSERT_ALIGNED(src1, ALIGNMENT_NEON); ASSERT_ALIGNED(src2, ALIGNMENT_NEON); ASSERT_ALIGNED(dst, ALIGNMENT_NEON); - constexpr size_t simd_width = 4; // NEON每次处理4个float + constexpr size_t simd_width = 4; // NEON每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; + size_t i = 0; // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -26,18 +25,18 @@ namespace arm_simd_audio_processing_func { float32x4_t a1 = vld1q_f32(&src1[i + 4]); float32x4_t a2 = vld1q_f32(&src1[i + 8]); float32x4_t a3 = vld1q_f32(&src1[i + 12]); - + float32x4_t b0 = vld1q_f32(&src2[i]); float32x4_t b1 = vld1q_f32(&src2[i + 4]); float32x4_t b2 = vld1q_f32(&src2[i + 8]); float32x4_t b3 = vld1q_f32(&src2[i + 12]); - + // 并行计算 float32x4_t result0 = vaddq_f32(a0, b0); float32x4_t result1 = vaddq_f32(a1, b1); float32x4_t result2 = vaddq_f32(a2, b2); float32x4_t result3 = vaddq_f32(a3, b3); - + // 存储结果 vst1q_f32(&dst[i], result0); vst1q_f32(&dst[i + 4], result1); @@ -47,8 +46,8 @@ namespace arm_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - float32x4_t a = vld1q_f32(&src1[i]); - float32x4_t b = vld1q_f32(&src2[i]); + float32x4_t a = vld1q_f32(&src1[i]); + float32x4_t b = vld1q_f32(&src2[i]); float32x4_t result = vaddq_f32(a, b); vst1q_f32(&dst[i], result); } @@ -64,9 +63,9 @@ namespace arm_simd_audio_processing_func { ASSERT_ALIGNED(src, ALIGNMENT_NEON); ASSERT_ALIGNED(dst, ALIGNMENT_NEON); - constexpr size_t simd_width = 4; // NEON每次处理4个float + constexpr size_t simd_width = 4; // NEON每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; + size_t i = 0; float32x4_t gain_vec = vdupq_n_f32(gain); @@ -77,13 +76,13 @@ namespace arm_simd_audio_processing_func { float32x4_t a1 = vld1q_f32(&src[i + 4]); float32x4_t a2 = vld1q_f32(&src[i + 8]); float32x4_t a3 = vld1q_f32(&src[i + 12]); - + // 并行计算增益应用 float32x4_t result0 = vmulq_f32(a0, gain_vec); float32x4_t result1 = vmulq_f32(a1, gain_vec); float32x4_t result2 = vmulq_f32(a2, gain_vec); float32x4_t result3 = vmulq_f32(a3, gain_vec); - + // 存储结果 vst1q_f32(&dst[i], result0); vst1q_f32(&dst[i + 4], result1); @@ -93,7 +92,7 @@ namespace arm_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - float32x4_t a = vld1q_f32(&src[i]); + float32x4_t a = vld1q_f32(&src[i]); float32x4_t result = vmulq_f32(a, gain_vec); vst1q_f32(&dst[i], result); } @@ -108,13 +107,13 @@ namespace arm_simd_audio_processing_func { float calculate_rms_neon(const float* src, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_NEON); - constexpr size_t simd_width = 4; // NEON每次处理4个float + constexpr size_t simd_width = 4; // NEON每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - float32x4_t sum_squares0 = vdupq_n_f32(0.0f); - float32x4_t sum_squares1 = vdupq_n_f32(0.0f); - float32x4_t sum_squares2 = vdupq_n_f32(0.0f); - float32x4_t sum_squares3 = vdupq_n_f32(0.0f); + size_t i = 0; + float32x4_t sum_squares0 = vdupq_n_f32(0.0f); + float32x4_t sum_squares1 = vdupq_n_f32(0.0f); + float32x4_t sum_squares2 = vdupq_n_f32(0.0f); + float32x4_t sum_squares3 = vdupq_n_f32(0.0f); // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -123,13 +122,13 @@ namespace arm_simd_audio_processing_func { const float32x4_t a1 = vld1q_f32(&src[i + 4]); const float32x4_t a2 = vld1q_f32(&src[i + 8]); const float32x4_t a3 = vld1q_f32(&src[i + 12]); - + // 并行计算平方 const float32x4_t squared0 = vmulq_f32(a0, a0); const float32x4_t squared1 = vmulq_f32(a1, a1); const float32x4_t squared2 = vmulq_f32(a2, a2); const float32x4_t squared3 = vmulq_f32(a3, a3); - + // 累加到各自的累加器 sum_squares0 = vaddq_f32(sum_squares0, squared0); sum_squares1 = vaddq_f32(sum_squares1, squared1); @@ -139,25 +138,25 @@ namespace arm_simd_audio_processing_func { // 合并4个累加器 float32x4_t sum_squares = vaddq_f32(vaddq_f32(sum_squares0, sum_squares1), - vaddq_f32(sum_squares2, sum_squares3)); + vaddq_f32(sum_squares2, sum_squares3)); // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - const float32x4_t a = vld1q_f32(&src[i]); + const float32x4_t a = vld1q_f32(&src[i]); const float32x4_t squared = vmulq_f32(a, a); - sum_squares = vaddq_f32(sum_squares, squared); + sum_squares = vaddq_f32(sum_squares, squared); } // 水平归约操作 - float32x2_t sum_pair = vadd_f32(vget_low_f32(sum_squares), vget_high_f32(sum_squares)); + float32x2_t sum_pair = vadd_f32(vget_low_f32(sum_squares), vget_high_f32(sum_squares)); float32x2_t sum_final = vpadd_f32(sum_pair, sum_pair); - double total_sum = vget_lane_f32(sum_final, 0); + double total_sum = vget_lane_f32(sum_final, 0); // 处理剩余的标量样本 for (; i < num_samples; ++i) { total_sum += static_cast(src[i]) * static_cast(src[i]); } - + return static_cast(std::sqrt(total_sum / static_cast(num_samples))); } @@ -165,13 +164,13 @@ namespace arm_simd_audio_processing_func { float calculate_peak_neon(const float* src, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_NEON); - constexpr size_t simd_width = 4; // NEON每次处理4个float + constexpr size_t simd_width = 4; // NEON每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - float32x4_t peak_vec0 = vdupq_n_f32(0.0f); - float32x4_t peak_vec1 = vdupq_n_f32(0.0f); - float32x4_t peak_vec2 = vdupq_n_f32(0.0f); - float32x4_t peak_vec3 = vdupq_n_f32(0.0f); + size_t i = 0; + float32x4_t peak_vec0 = vdupq_n_f32(0.0f); + float32x4_t peak_vec1 = vdupq_n_f32(0.0f); + float32x4_t peak_vec2 = vdupq_n_f32(0.0f); + float32x4_t peak_vec3 = vdupq_n_f32(0.0f); // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -180,13 +179,13 @@ namespace arm_simd_audio_processing_func { const float32x4_t a1 = vld1q_f32(&src[i + 4]); const float32x4_t a2 = vld1q_f32(&src[i + 8]); const float32x4_t a3 = vld1q_f32(&src[i + 12]); - + // 并行计算绝对值 const float32x4_t abs_a0 = vabsq_f32(a0); const float32x4_t abs_a1 = vabsq_f32(a1); const float32x4_t abs_a2 = vabsq_f32(a2); const float32x4_t abs_a3 = vabsq_f32(a3); - + // 更新各自的峰值向量 peak_vec0 = vmaxq_f32(peak_vec0, abs_a0); peak_vec1 = vmaxq_f32(peak_vec1, abs_a1); @@ -200,15 +199,15 @@ namespace arm_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - const float32x4_t a = vld1q_f32(&src[i]); + const float32x4_t a = vld1q_f32(&src[i]); const float32x4_t abs_a = vabsq_f32(a); - peak_vec = vmaxq_f32(peak_vec, abs_a); + peak_vec = vmaxq_f32(peak_vec, abs_a); } // 水平最大值归约操作 - float32x2_t max_pair = vmax_f32(vget_low_f32(peak_vec), vget_high_f32(peak_vec)); + float32x2_t max_pair = vmax_f32(vget_low_f32(peak_vec), vget_high_f32(peak_vec)); float32x2_t max_final = vpmax_f32(max_pair, max_pair); - float peak = vget_lane_f32(max_final, 0); + float peak = vget_lane_f32(max_final, 0); // 处理剩余的标量样本 for (; i < num_samples; ++i) { @@ -217,7 +216,7 @@ namespace arm_simd_audio_processing_func { peak = abs_sample; } } - + return peak; } @@ -225,23 +224,23 @@ namespace arm_simd_audio_processing_func { void normalize_audio_neon(const float* src, float* dst, float target_peak, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_NEON); ASSERT_ALIGNED(dst, ALIGNMENT_NEON); - + // 边界情况处理 if (num_samples == 0 || target_peak <= 0.0f) { return; } - + // 计算当前音频的峰值 const float current_peak = calculate_peak_neon(src, num_samples); - + // 如果当前峰值过小,设置为静音 if (current_peak < 1e-10f) { // 使用NEON优化的零填充 - constexpr size_t simd_width = 4; // NEON每次处理4个float + constexpr size_t simd_width = 4; // NEON每次处理4个float constexpr size_t unroll_factor = 4; - float32x4_t zero_vec = vdupq_n_f32(0.0f); - size_t i = 0; - + float32x4_t zero_vec = vdupq_n_f32(0.0f); + size_t i = 0; + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { vst1q_f32(&dst[i], zero_vec); @@ -249,22 +248,22 @@ namespace arm_simd_audio_processing_func { vst1q_f32(&dst[i + 8], zero_vec); vst1q_f32(&dst[i + 12], zero_vec); } - + // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { vst1q_f32(&dst[i], zero_vec); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { dst[i] = 0.0f; } return; } - + // 计算归一化增益因子 const float gain_factor = target_peak / current_peak; - + // 使用现有的apply_gain_neon函数应用增益 apply_gain_neon(src, dst, gain_factor, num_samples); } @@ -273,73 +272,73 @@ namespace arm_simd_audio_processing_func { void stereo_to_mono_neon(const float* stereo_src, float* mono_dst, size_t num_stereo_samples) { ASSERT_ALIGNED(stereo_src, ALIGNMENT_NEON); ASSERT_ALIGNED(mono_dst, ALIGNMENT_NEON); - + // 边界情况处理 if (num_stereo_samples == 0) { return; } - - constexpr size_t simd_width = 4; // NEON每次处理4个float - constexpr size_t unroll_factor = 2; // 2路循环展开 - const float32x4_t half_vec = vdupq_n_f32(0.5f); // 用于取平均值 - size_t stereo_idx = 0; // 立体声索引 - size_t mono_idx = 0; // 单声道索引 - + + constexpr size_t simd_width = 4; // NEON每次处理4个float + constexpr size_t unroll_factor = 2; // 2路循环展开 + const float32x4_t half_vec = vdupq_n_f32(0.5f); // 用于取平均值 + size_t stereo_idx = 0; // 立体声索引 + size_t mono_idx = 0; // 单声道索引 + // 向量化处理(2路循环展开) // 每次处理4个单声道样本,需要读取8个立体声样本 for (; stereo_idx + simd_width * 2 * unroll_factor <= num_stereo_samples * 2; - stereo_idx += simd_width * 2 * unroll_factor, mono_idx += simd_width * unroll_factor) { - + stereo_idx += simd_width * 2 * unroll_factor, mono_idx += simd_width * unroll_factor) { // 加载8个立体声样本对 (16个float) - float32x4_t stereo0 = vld1q_f32(&stereo_src[stereo_idx]); // [L0, R0, L1, R1] - float32x4_t stereo1 = vld1q_f32(&stereo_src[stereo_idx + 4]); // [L2, R2, L3, R3] - float32x4_t stereo2 = vld1q_f32(&stereo_src[stereo_idx + 8]); // [L4, R4, L5, R5] + float32x4_t stereo0 = vld1q_f32(&stereo_src[stereo_idx]); // [L0, R0, L1, R1] + float32x4_t stereo1 = vld1q_f32(&stereo_src[stereo_idx + 4]); // [L2, R2, L3, R3] + float32x4_t stereo2 = vld1q_f32(&stereo_src[stereo_idx + 8]); // [L4, R4, L5, R5] float32x4_t stereo3 = vld1q_f32(&stereo_src[stereo_idx + 12]); // [L6, R6, L7, R7] - + // 分离左右声道 float32x4x2_t deinterleave0 = vuzpq_f32(stereo0, stereo1); // left=[L0,L1,L2,L3], right=[R0,R1,R2,R3] float32x4x2_t deinterleave1 = vuzpq_f32(stereo2, stereo3); // left=[L4,L5,L6,L7], right=[R4,R5,R6,R7] - + // 计算单声道 = (左声道 + 右声道) / 2 float32x4_t mono0 = vmulq_f32(vaddq_f32(deinterleave0.val[0], deinterleave0.val[1]), half_vec); float32x4_t mono1 = vmulq_f32(vaddq_f32(deinterleave1.val[0], deinterleave1.val[1]), half_vec); - + // 存储结果 vst1q_f32(&mono_dst[mono_idx], mono0); vst1q_f32(&mono_dst[mono_idx + 4], mono1); } - + // 处理剩余的样本对(标量处理) for (size_t i = stereo_idx / 2; i < num_stereo_samples; ++i) { - const float left = stereo_src[i * 2]; + const float left = stereo_src[i * 2]; const float right = stereo_src[i * 2 + 1]; - mono_dst[i] = (left + right) * 0.5f; + mono_dst[i] = (left + right) * 0.5f; } } // 音频限幅函数实现 (NEON版本) - void limit_audio_neon(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples) { + void limit_audio_neon(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_NEON); ASSERT_ALIGNED(dst, ALIGNMENT_NEON); - + // 边界情况处理 if (num_samples == 0 || threshold <= 0.0f) { return; } - - constexpr size_t simd_width = 4; // NEON每次处理4个float - constexpr size_t unroll_factor = 4; // 4路循环展开 - constexpr float release_time = 0.05f; // 释放时间常数(秒) - const float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 - + + constexpr size_t simd_width = 4; // NEON每次处理4个float + constexpr size_t unroll_factor = 4; // 4路循环展开 + constexpr float release_time = 0.05f; // 释放时间常数(秒) + const float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 + // 初始化限幅器状态 float current_gain = limiter_state != nullptr ? *limiter_state : 1.0f; - + // 阈值向量 const float32x4_t threshold_vec = vdupq_n_f32(threshold); - + size_t i = 0; - + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { // 加载4个向量 @@ -347,71 +346,73 @@ namespace arm_simd_audio_processing_func { float32x4_t a1 = vld1q_f32(&src[i + 4]); float32x4_t a2 = vld1q_f32(&src[i + 8]); float32x4_t a3 = vld1q_f32(&src[i + 12]); - + // 计算绝对值 float32x4_t abs_a0 = vabsq_f32(a0); float32x4_t abs_a1 = vabsq_f32(a1); float32x4_t abs_a2 = vabsq_f32(a2); float32x4_t abs_a3 = vabsq_f32(a3); - + // 找出最大值 - float32x4_t max_abs = vmaxq_f32(vmaxq_f32(abs_a0, abs_a1), - vmaxq_f32(abs_a2, abs_a3)); - + float32x4_t max_abs = vmaxq_f32(vmaxq_f32(abs_a0, abs_a1), + vmaxq_f32(abs_a2, abs_a3)); + // 水平最大值归约 - float32x2_t max_pair = vmax_f32(vget_low_f32(max_abs), vget_high_f32(max_abs)); - float32x2_t max_final = vpmax_f32(max_pair, max_pair); - float max_sample = vget_lane_f32(max_final, 0); - + float32x2_t max_pair = vmax_f32(vget_low_f32(max_abs), vget_high_f32(max_abs)); + float32x2_t max_final = vpmax_f32(max_pair, max_pair); + float max_sample = vget_lane_f32(max_final, 0); + // 计算需要的增益以限制幅度 float target_gain = max_sample > threshold ? threshold / max_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 将当前增益转换为向量 float32x4_t gain_vec = vdupq_n_f32(current_gain); - + // 应用增益 float32x4_t result0 = vmulq_f32(a0, gain_vec); float32x4_t result1 = vmulq_f32(a1, gain_vec); float32x4_t result2 = vmulq_f32(a2, gain_vec); float32x4_t result3 = vmulq_f32(a3, gain_vec); - + // 存储结果 vst1q_f32(&dst[i], result0); vst1q_f32(&dst[i + 4], result1); vst1q_f32(&dst[i + 8], result2); vst1q_f32(&dst[i + 12], result3); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { - float sample = src[i]; + float sample = src[i]; float abs_sample = std::fabs(sample); - + // 计算需要的增益以限制幅度 float target_gain = abs_sample > threshold ? threshold / abs_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 应用增益 dst[i] = sample * current_gain; } - + // 更新限幅器状态 if (limiter_state != nullptr) { *limiter_state = current_gain; @@ -419,161 +420,184 @@ namespace arm_simd_audio_processing_func { } // 音频淡入淡出函数实现 (NEON版本) - void fade_audio_neon(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples) { + void fade_audio_neon(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_NEON); ASSERT_ALIGNED(dst, ALIGNMENT_NEON); - + // 边界情况处理 if (num_samples == 0) { return; } - - constexpr size_t simd_width = 4; // NEON每次处理4个float + + constexpr size_t simd_width = 4; // NEON每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开 - size_t i = 0; - + size_t i = 0; + // 处理淡入部分 if (fade_in_samples > 0) { const float fade_in_step = 1.0f / static_cast(fade_in_samples); - + // 向量化处理淡入(4路循环展开) - for (; i + simd_width * unroll_factor <= std::min(fade_in_samples, num_samples); i += simd_width * unroll_factor) { + for (; i + simd_width * unroll_factor <= std::min(fade_in_samples, num_samples); i += simd_width * + unroll_factor) { // 计算当前样本的淡入系数 - float32x4_t gain0 = {i * fade_in_step, (i + 1) * fade_in_step, (i + 2) * fade_in_step, (i + 3) * fade_in_step}; - float32x4_t gain1 = {(i + 4) * fade_in_step, (i + 5) * fade_in_step, (i + 6) * fade_in_step, (i + 7) * fade_in_step}; - float32x4_t gain2 = {(i + 8) * fade_in_step, (i + 9) * fade_in_step, (i + 10) * fade_in_step, (i + 11) * fade_in_step}; - float32x4_t gain3 = {(i + 12) * fade_in_step, (i + 13) * fade_in_step, (i + 14) * fade_in_step, (i + 15) * fade_in_step}; - + float32x4_t gain0 = { + i * fade_in_step, (i + 1) * fade_in_step, (i + 2) * fade_in_step, (i + 3) * fade_in_step + }; + float32x4_t gain1 = { + (i + 4) * fade_in_step, (i + 5) * fade_in_step, (i + 6) * fade_in_step, (i + 7) * fade_in_step + }; + float32x4_t gain2 = { + (i + 8) * fade_in_step, (i + 9) * fade_in_step, (i + 10) * fade_in_step, (i + 11) * fade_in_step + }; + float32x4_t gain3 = { + (i + 12) * fade_in_step, (i + 13) * fade_in_step, (i + 14) * fade_in_step, (i + 15) * fade_in_step + }; + // 加载音频样本 float32x4_t a0 = vld1q_f32(&src[i]); float32x4_t a1 = vld1q_f32(&src[i + 4]); float32x4_t a2 = vld1q_f32(&src[i + 8]); float32x4_t a3 = vld1q_f32(&src[i + 12]); - + // 应用淡入增益 float32x4_t result0 = vmulq_f32(a0, gain0); float32x4_t result1 = vmulq_f32(a1, gain1); float32x4_t result2 = vmulq_f32(a2, gain2); float32x4_t result3 = vmulq_f32(a3, gain3); - + // 存储结果 vst1q_f32(&dst[i], result0); vst1q_f32(&dst[i + 4], result1); vst1q_f32(&dst[i + 8], result2); vst1q_f32(&dst[i + 12], result3); } - + // 处理剩余的淡入样本(标量处理) for (; i < std::min(fade_in_samples, num_samples); ++i) { const float gain = static_cast(i) / static_cast(fade_in_samples); - dst[i] = src[i] * gain; + dst[i] = src[i] * gain; } } - + // 处理中间部分(无淡入淡出,直接复制) const size_t middle_start = fade_in_samples; - const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; - + const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; + if (middle_end > middle_start) { // 使用NEON优化的直接复制 - for (size_t j = middle_start; j + simd_width * unroll_factor <= middle_end; j += simd_width * unroll_factor) { + for (size_t j = middle_start; j + simd_width * unroll_factor <= middle_end; j += simd_width * + unroll_factor) { float32x4_t a0 = vld1q_f32(&src[j]); float32x4_t a1 = vld1q_f32(&src[j + 4]); float32x4_t a2 = vld1q_f32(&src[j + 8]); float32x4_t a3 = vld1q_f32(&src[j + 12]); - + vst1q_f32(&dst[j], a0); vst1q_f32(&dst[j + 4], a1); vst1q_f32(&dst[j + 8], a2); vst1q_f32(&dst[j + 12], a3); } - + // 处理剩余的中间样本(标量处理) - for (size_t j = middle_start + ((middle_end - middle_start) / (simd_width * unroll_factor)) * (simd_width * unroll_factor); j < middle_end; ++j) { + for (size_t j = middle_start + ((middle_end - middle_start) / (simd_width * unroll_factor)) * (simd_width * + unroll_factor); j < middle_end; ++j) { dst[j] = src[j]; } } - + // 处理淡出部分 if (fade_out_samples > 0 && num_samples > fade_out_samples) { const size_t fade_out_start = num_samples - fade_out_samples; - const float fade_out_step = 1.0f / static_cast(fade_out_samples); - + const float fade_out_step = 1.0f / static_cast(fade_out_samples); + // 向量化处理淡出(4路循环展开) - for (size_t j = fade_out_start; j + simd_width * unroll_factor <= num_samples; j += simd_width * unroll_factor) { + for (size_t j = fade_out_start; j + simd_width * unroll_factor <= num_samples; j += simd_width * + unroll_factor) { // 计算当前样本的淡出系数(从1递减到0) const size_t fade_out_offset = j - fade_out_start; - float32x4_t gain0 = {1.0f - fade_out_offset * fade_out_step, 1.0f - (fade_out_offset + 1) * fade_out_step, - 1.0f - (fade_out_offset + 2) * fade_out_step, 1.0f - (fade_out_offset + 3) * fade_out_step}; - float32x4_t gain1 = {1.0f - (fade_out_offset + 4) * fade_out_step, 1.0f - (fade_out_offset + 5) * fade_out_step, - 1.0f - (fade_out_offset + 6) * fade_out_step, 1.0f - (fade_out_offset + 7) * fade_out_step}; - float32x4_t gain2 = {1.0f - (fade_out_offset + 8) * fade_out_step, 1.0f - (fade_out_offset + 9) * fade_out_step, - 1.0f - (fade_out_offset + 10) * fade_out_step, 1.0f - (fade_out_offset + 11) * fade_out_step}; - float32x4_t gain3 = {1.0f - (fade_out_offset + 12) * fade_out_step, 1.0f - (fade_out_offset + 13) * fade_out_step, - 1.0f - (fade_out_offset + 14) * fade_out_step, 1.0f - (fade_out_offset + 15) * fade_out_step}; - + float32x4_t gain0 = { + 1.0f - fade_out_offset * fade_out_step, 1.0f - (fade_out_offset + 1) * fade_out_step, + 1.0f - (fade_out_offset + 2) * fade_out_step, 1.0f - (fade_out_offset + 3) * fade_out_step + }; + float32x4_t gain1 = { + 1.0f - (fade_out_offset + 4) * fade_out_step, 1.0f - (fade_out_offset + 5) * fade_out_step, + 1.0f - (fade_out_offset + 6) * fade_out_step, 1.0f - (fade_out_offset + 7) * fade_out_step + }; + float32x4_t gain2 = { + 1.0f - (fade_out_offset + 8) * fade_out_step, 1.0f - (fade_out_offset + 9) * fade_out_step, + 1.0f - (fade_out_offset + 10) * fade_out_step, 1.0f - (fade_out_offset + 11) * fade_out_step + }; + float32x4_t gain3 = { + 1.0f - (fade_out_offset + 12) * fade_out_step, 1.0f - (fade_out_offset + 13) * fade_out_step, + 1.0f - (fade_out_offset + 14) * fade_out_step, 1.0f - (fade_out_offset + 15) * fade_out_step + }; + // 加载音频样本 float32x4_t a0 = vld1q_f32(&src[j]); float32x4_t a1 = vld1q_f32(&src[j + 4]); float32x4_t a2 = vld1q_f32(&src[j + 8]); float32x4_t a3 = vld1q_f32(&src[j + 12]); - + // 应用淡出增益 float32x4_t result0 = vmulq_f32(a0, gain0); float32x4_t result1 = vmulq_f32(a1, gain1); float32x4_t result2 = vmulq_f32(a2, gain2); float32x4_t result3 = vmulq_f32(a3, gain3); - + // 存储结果 vst1q_f32(&dst[j], result0); vst1q_f32(&dst[j + 4], result1); vst1q_f32(&dst[j + 8], result2); vst1q_f32(&dst[j + 12], result3); } - + // 处理剩余的淡出样本(标量处理) - for (size_t j = fade_out_start + ((fade_out_samples / (simd_width * unroll_factor)) * (simd_width * unroll_factor)); j < num_samples; ++j) { + for (size_t j = fade_out_start + ((fade_out_samples / (simd_width * unroll_factor)) * (simd_width * + unroll_factor)); j < num_samples; ++j) { const size_t fade_out_offset = j - fade_out_start; - const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); + const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); dst[j] = src[j] * gain; } } } // 简单均衡器函数实现 (NEON版本) - void simple_eq_neon(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples) { + void simple_eq_neon(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_NEON); ASSERT_ALIGNED(dst, ALIGNMENT_NEON); - + // 边界情况处理 if (num_samples == 0) { return; } - - constexpr size_t simd_width = 4; // NEON每次处理4个float + + constexpr size_t simd_width = 4; // NEON每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开 - + // 简化的频率分割系数 - constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 - constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 - constexpr float mid_factor = 0.7f; // 中频保持系数 - + constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 + constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 + constexpr float mid_factor = 0.7f; // 中频保持系数 + // 初始化EQ状态 - float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; + float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; float high_state = eq_state != nullptr ? eq_state[1] : 0.0f; - + // 创建增益向量 - const float32x4_t low_gain_vec = vdupq_n_f32(low_gain); - const float32x4_t mid_gain_vec = vdupq_n_f32(mid_gain); - const float32x4_t high_gain_vec = vdupq_n_f32(high_gain); - const float32x4_t low_cutoff_vec = vdupq_n_f32(low_cutoff); - const float32x4_t high_cutoff_vec = vdupq_n_f32(high_cutoff); - const float32x4_t mid_factor_vec = vdupq_n_f32(mid_factor); - const float32x4_t one_minus_low_cutoff_vec = vdupq_n_f32(1.0f - low_cutoff); + const float32x4_t low_gain_vec = vdupq_n_f32(low_gain); + const float32x4_t mid_gain_vec = vdupq_n_f32(mid_gain); + const float32x4_t high_gain_vec = vdupq_n_f32(high_gain); + const float32x4_t low_cutoff_vec = vdupq_n_f32(low_cutoff); + const float32x4_t high_cutoff_vec = vdupq_n_f32(high_cutoff); + const float32x4_t mid_factor_vec = vdupq_n_f32(mid_factor); + const float32x4_t one_minus_low_cutoff_vec = vdupq_n_f32(1.0f - low_cutoff); const float32x4_t one_minus_high_cutoff_vec = vdupq_n_f32(1.0f - high_cutoff); - + size_t i = 0; - + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { // 加载4个向量 @@ -581,70 +605,74 @@ namespace arm_simd_audio_processing_func { float32x4_t input1 = vld1q_f32(&src[i + 4]); float32x4_t input2 = vld1q_f32(&src[i + 8]); float32x4_t input3 = vld1q_f32(&src[i + 12]); - + // 简化的低通滤波器实现(一阶IIR) float32x4_t low_state_vec = vdupq_n_f32(low_state); float32x4_t low0 = vmlaq_f32(vmulq_f32(low_state_vec, one_minus_low_cutoff_vec), input0, low_cutoff_vec); float32x4_t low1 = vmlaq_f32(vmulq_f32(low0, one_minus_low_cutoff_vec), input1, low_cutoff_vec); float32x4_t low2 = vmlaq_f32(vmulq_f32(low1, one_minus_low_cutoff_vec), input2, low_cutoff_vec); float32x4_t low3 = vmlaq_f32(vmulq_f32(low2, one_minus_low_cutoff_vec), input3, low_cutoff_vec); - + // 简化的高通滤波器实现 float32x4_t high0 = vsubq_f32(input0, low0); float32x4_t high1 = vsubq_f32(input1, low1); float32x4_t high2 = vsubq_f32(input2, low2); float32x4_t high3 = vsubq_f32(input3, low3); - + // 进一步高频处理 float32x4_t high_state_vec = vdupq_n_f32(high_state); high0 = vmlaq_f32(vmulq_f32(high_state_vec, one_minus_high_cutoff_vec), high0, high_cutoff_vec); high1 = vmlaq_f32(vmulq_f32(high0, one_minus_high_cutoff_vec), high1, high_cutoff_vec); high2 = vmlaq_f32(vmulq_f32(high1, one_minus_high_cutoff_vec), high2, high_cutoff_vec); high3 = vmlaq_f32(vmulq_f32(high2, one_minus_high_cutoff_vec), high3, high_cutoff_vec); - + // 中频:原始信号减去低频和高频 float32x4_t mid0 = vmulq_f32(vsubq_f32(vsubq_f32(input0, low0), high0), mid_factor_vec); float32x4_t mid1 = vmulq_f32(vsubq_f32(vsubq_f32(input1, low1), high1), mid_factor_vec); float32x4_t mid2 = vmulq_f32(vsubq_f32(vsubq_f32(input2, low2), high2), mid_factor_vec); float32x4_t mid3 = vmulq_f32(vsubq_f32(vsubq_f32(input3, low3), high3), mid_factor_vec); - + // 应用增益并混合 - float32x4_t result0 = vmlaq_f32(vmlaq_f32(vmulq_f32(high0, high_gain_vec), mid0, mid_gain_vec), low0, low_gain_vec); - float32x4_t result1 = vmlaq_f32(vmlaq_f32(vmulq_f32(high1, high_gain_vec), mid1, mid_gain_vec), low1, low_gain_vec); - float32x4_t result2 = vmlaq_f32(vmlaq_f32(vmulq_f32(high2, high_gain_vec), mid2, mid_gain_vec), low2, low_gain_vec); - float32x4_t result3 = vmlaq_f32(vmlaq_f32(vmulq_f32(high3, high_gain_vec), mid3, mid_gain_vec), low3, low_gain_vec); - + float32x4_t result0 = vmlaq_f32(vmlaq_f32(vmulq_f32(high0, high_gain_vec), mid0, mid_gain_vec), low0, + low_gain_vec); + float32x4_t result1 = vmlaq_f32(vmlaq_f32(vmulq_f32(high1, high_gain_vec), mid1, mid_gain_vec), low1, + low_gain_vec); + float32x4_t result2 = vmlaq_f32(vmlaq_f32(vmulq_f32(high2, high_gain_vec), mid2, mid_gain_vec), low2, + low_gain_vec); + float32x4_t result3 = vmlaq_f32(vmlaq_f32(vmulq_f32(high3, high_gain_vec), mid3, mid_gain_vec), low3, + low_gain_vec); + // 存储结果 vst1q_f32(&dst[i], result0); vst1q_f32(&dst[i + 4], result1); vst1q_f32(&dst[i + 8], result2); vst1q_f32(&dst[i + 12], result3); - + // 更新状态(使用最后一个元素) - low_state = vgetq_lane_f32(low3, 3); + low_state = vgetq_lane_f32(low3, 3); high_state = vgetq_lane_f32(high3, 3); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { float input = src[i]; - + // 低通滤波器 float low_output = low_cutoff * input + (1.0f - low_cutoff) * low_state; - low_state = low_output; - + low_state = low_output; + // 高通滤波器 - float high_input = input - low_output; + float high_input = input - low_output; float high_output = high_cutoff * high_input + (1.0f - high_cutoff) * high_state; - high_state = high_output; - + high_state = high_output; + // 中频 float mid_output = (input - low_output - high_output) * mid_factor; - + // 混合并应用增益 dst[i] = low_output * low_gain + mid_output * mid_gain + high_output * high_gain; } - + // 更新EQ状态 if (eq_state != nullptr) { eq_state[0] = low_state; diff --git a/src/simd/audio_processing/arm_simd_audio_processing_func.h b/src/simd/audio_processing/arm_simd_audio_processing_func.h index 18a84e8..fae1f27 100644 --- a/src/simd/audio_processing/arm_simd_audio_processing_func.h +++ b/src/simd/audio_processing/arm_simd_audio_processing_func.h @@ -3,8 +3,8 @@ #if ALICHO_PLATFORM_ARM namespace arm_simd_audio_processing_func { // 原有的4个基础音处理函数 - void mix_audio_neon(const float* src1, const float* src2, float* dst, size_t num_samples); - void apply_gain_neon(const float* src, float* dst, float gain, size_t num_samples); + void mix_audio_neon(const float* src1, const float* src2, float* dst, size_t num_samples); + void apply_gain_neon(const float* src, float* dst, float gain, size_t num_samples); float calculate_rms_neon(const float* src, size_t num_samples); float calculate_peak_neon(const float* src, size_t num_samples); @@ -15,12 +15,15 @@ namespace arm_simd_audio_processing_func { void stereo_to_mono_neon(const float* stereo_src, float* mono_dst, size_t num_stereo_samples); // 音频限幅:将超过阈值的样本限制在指定范围内 - void limit_audio_neon(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples); + void limit_audio_neon(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples); // 音频淡入淡出:应用线性淡入淡出效果 - void fade_audio_neon(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples); + void fade_audio_neon(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, + size_t num_samples); // 简单均衡器:简单的三段均衡器(低频、中频、高频增益) - void simple_eq_neon(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples); + void simple_eq_neon(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, + size_t num_samples); } #endif diff --git a/src/simd/audio_processing/scalar_audio_processing_func.cpp b/src/simd/audio_processing/scalar_audio_processing_func.cpp index 419bf95..fb722dd 100644 --- a/src/simd/audio_processing/scalar_audio_processing_func.cpp +++ b/src/simd/audio_processing/scalar_audio_processing_func.cpp @@ -34,16 +34,17 @@ namespace scalar_audio_processing_func { } return peak; } + // 音频归一化函数实现 void normalize_audio(const float* src, float* dst, float target_peak, size_t num_samples) { // 边界情况处理 if (num_samples == 0 || target_peak <= 0.0f) { return; } - + // 计算当前音频的峰值 const float current_peak = calculate_peak(src, num_samples); - + // 如果当前峰值过小,设置为静音 if (current_peak < 1e-10f) { for (size_t i = 0; i < num_samples; ++i) { @@ -51,144 +52,147 @@ namespace scalar_audio_processing_func { } return; } - + // 计算归一化增益因子 const float gain_factor = target_peak / current_peak; - + // 应用增益 for (size_t i = 0; i < num_samples; ++i) { dst[i] = src[i] * gain_factor; } } - + // 立体声到单声道转换函数实现 void stereo_to_mono(const float* stereo_src, float* mono_dst, size_t num_stereo_samples) { // 边界情况处理 if (num_stereo_samples == 0) { return; } - + // 对每个立体声样本对(左声道,右声道),计算其平均值作为单声道样本 for (size_t i = 0; i < num_stereo_samples; ++i) { - const float left = stereo_src[i * 2]; + const float left = stereo_src[i * 2]; const float right = stereo_src[i * 2 + 1]; - mono_dst[i] = (left + right) * 0.5f; + mono_dst[i] = (left + right) * 0.5f; } } - + // 音频限幅函数实现 - void limit_audio(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples) { + void limit_audio(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples) { // 边界情况处理 if (num_samples == 0 || threshold <= 0.0f) { return; } - + // 释放时间参数(与SIMD实现保持一致) - constexpr float release_time = 0.05f; // 释放时间常数(秒) - const float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 - + constexpr float release_time = 0.05f; // 释放时间常数(秒) + const float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 + // 初始化限幅器状态 float current_gain = limiter_state != nullptr ? *limiter_state : 1.0f; - + // 处理每个样本 for (size_t i = 0; i < num_samples; ++i) { - float sample = src[i]; + float sample = src[i]; float abs_sample = std::fabs(sample); - + // 计算需要的增益以限制幅度 float target_gain = abs_sample > threshold ? threshold / abs_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 应用增益 dst[i] = sample * current_gain; } - + // 更新限幅器状态 if (limiter_state != nullptr) { *limiter_state = current_gain; } } - + // 音频淡入淡出函数实现 void fade_audio(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples) { // 边界情况处理 if (num_samples == 0) { return; } - + // 处理淡入部分 if (fade_in_samples > 0) { for (size_t i = 0; i < std::min(fade_in_samples, num_samples); ++i) { const float gain = static_cast(i) / static_cast(fade_in_samples); - dst[i] = src[i] * gain; + dst[i] = src[i] * gain; } } - + // 处理中间部分(无淡入淡出,直接复制) const size_t middle_start = fade_in_samples; - const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; - + const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; + if (middle_end > middle_start) { for (size_t i = middle_start; i < middle_end; ++i) { dst[i] = src[i]; } } - + // 处理淡出部分 if (fade_out_samples > 0 && num_samples > fade_out_samples) { const size_t fade_out_start = num_samples - fade_out_samples - 1; for (size_t i = fade_out_start; i < num_samples; ++i) { const size_t fade_out_offset = i - fade_out_start; - const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); + const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); dst[i] = src[i] * gain; } } } - + // 简单均衡器函数实现 - void simple_eq(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples) { + void simple_eq(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, + size_t num_samples) { // 边界情况处理 if (num_samples == 0) { return; } - + // 简化的频率分割系数(与SIMD实现保持一致) - constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 - constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 - constexpr float mid_factor = 0.7f; // 中频保持系数 - + constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 + constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 + constexpr float mid_factor = 0.7f; // 中频保持系数 + // 初始化EQ状态 - float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; + float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; float high_state = eq_state != nullptr ? eq_state[1] : 0.0f; - + // 处理每个样本 for (size_t i = 0; i < num_samples; ++i) { float input = src[i]; - + // 低通滤波器 float low_output = low_cutoff * input + (1.0f - low_cutoff) * low_state; - low_state = low_output; - + low_state = low_output; + // 高通滤波器 - float high_input = input - low_output; + float high_input = input - low_output; float high_output = high_cutoff * high_input + (1.0f - high_cutoff) * high_state; - high_state = high_output; - + high_state = high_output; + // 中频 float mid_output = (input - low_output - high_output) * mid_factor; - + // 混合并应用增益 dst[i] = low_output * low_gain + mid_output * mid_gain + high_output * high_gain; } - + // 更新EQ状态 if (eq_state != nullptr) { eq_state[0] = low_state; diff --git a/src/simd/audio_processing/scalar_audio_processing_func.h b/src/simd/audio_processing/scalar_audio_processing_func.h index 99493bf..f9ad154 100644 --- a/src/simd/audio_processing/scalar_audio_processing_func.h +++ b/src/simd/audio_processing/scalar_audio_processing_func.h @@ -3,63 +3,73 @@ #include namespace scalar_audio_processing_func { - void mix_audio(const float* src1, const float* src2, float* dst, size_t num_samples); - void apply_gain(const float* src, float* dst, float gain, size_t num_samples); + void mix_audio(const float* src1, const float* src2, float* dst, size_t num_samples); + void apply_gain(const float* src, float* dst, float gain, size_t num_samples); float calculate_rms(const float* src, size_t num_samples); float calculate_peak(const float* src, size_t num_samples); - + // 音频归一化:根据峰值将音频归一化到指定范围 void normalize_audio(const float* src, float* dst, float target_peak, size_t num_samples); - + // 立体声到单声道转换:将立体声(双通道)音频转换为单声道 void stereo_to_mono(const float* stereo_src, float* mono_dst, size_t num_stereo_samples); - + // 音频限幅:将超过阈值的样本限制在指定范围内 - void limit_audio(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples); - + void limit_audio(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples); + // 音频淡入淡出:应用线性淡入淡出效果 void fade_audio(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples); - + // 简单均衡器:简单的三段均衡器(低频、中频、高频增益) - void simple_eq(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples); + void simple_eq(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, + size_t num_samples); // 从f32转到有符号类型 - template + template void convert_f32_to_signed(const float* src, T* dst, size_t num_samples) { for (size_t i = 0; i < num_samples; ++i) { float sample = src[i]; - if (sample > 1.0f) sample = 1.0f; - if (sample < -1.0f) sample = -1.0f; + if (sample > 1.0f) + sample = 1.0f; + if (sample < -1.0f) + sample = -1.0f; dst[i] = static_cast(sample * static_cast(std::numeric_limits::max())); } } + // 从f32转到无符号类型 - template + template void convert_f32_to_unsigned(const float* src, T* dst, size_t num_samples) { for (size_t i = 0; i < num_samples; ++i) { float sample = src[i]; - if (sample > 1.0f) sample = 1.0f; - if (sample < -1.0f) sample = -1.0f; + if (sample > 1.0f) + sample = 1.0f; + if (sample < -1.0f) + sample = -1.0f; // 将[-1.0, 1.0]映射到[0, max] dst[i] = static_cast((sample + 1.0f) * 0.5f * static_cast(std::numeric_limits::max())); } } + // 从有符号类型转到f32 - template + template void convert_signed_to_f32(const T* src, float* dst, size_t num_samples) { for (size_t i = 0; i < num_samples; ++i) { dst[i] = static_cast(src[i]) / static_cast(std::numeric_limits::max()); } } + // 从无符号类型转到f32 - template + template void convert_unsigned_to_f32(const T* src, float* dst, size_t num_samples) { for (size_t i = 0; i < num_samples; ++i) { dst[i] = (static_cast(src[i]) / static_cast(std::numeric_limits::max())) * 2.0f - 1.0f; } } + // 直接转换不进行缩放 - template + template void convert_direct(const From* src, To* dst, size_t num_samples) { for (size_t i = 0; i < num_samples; ++i) { dst[i] = static_cast(src[i]); diff --git a/src/simd/audio_processing/x86_simd_audio_processing_func.cpp b/src/simd/audio_processing/x86_simd_audio_processing_func.cpp index b7308ac..87f2a8e 100644 --- a/src/simd/audio_processing/x86_simd_audio_processing_func.cpp +++ b/src/simd/audio_processing/x86_simd_audio_processing_func.cpp @@ -12,9 +12,9 @@ namespace x86_simd_audio_processing_func { ASSERT_ALIGNED(src2, ALIGNMENT_SSE); ASSERT_ALIGNED(dst, ALIGNMENT_SSE); - constexpr size_t simd_width = 4; // SSE每次处理4个float + constexpr size_t simd_width = 4; // SSE每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; + size_t i = 0; // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -23,18 +23,18 @@ namespace x86_simd_audio_processing_func { auto a1 = _mm_load_ps(&src1[i + 4]); auto a2 = _mm_load_ps(&src1[i + 8]); auto a3 = _mm_load_ps(&src1[i + 12]); - + auto b0 = _mm_load_ps(&src2[i]); auto b1 = _mm_load_ps(&src2[i + 4]); auto b2 = _mm_load_ps(&src2[i + 8]); auto b3 = _mm_load_ps(&src2[i + 12]); - + // 并行计算 auto result0 = _mm_add_ps(a0, b0); auto result1 = _mm_add_ps(a1, b1); auto result2 = _mm_add_ps(a2, b2); auto result3 = _mm_add_ps(a3, b3); - + // 存储结果 _mm_store_ps(&dst[i], result0); _mm_store_ps(&dst[i + 4], result1); @@ -44,8 +44,8 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm_load_ps(&src1[i]); - auto b = _mm_load_ps(&src2[i]); + auto a = _mm_load_ps(&src1[i]); + auto b = _mm_load_ps(&src2[i]); auto result = _mm_add_ps(a, b); _mm_store_ps(&dst[i], result); } @@ -61,9 +61,9 @@ namespace x86_simd_audio_processing_func { ASSERT_ALIGNED(src2, ALIGNMENT_AVX); ASSERT_ALIGNED(dst, ALIGNMENT_AVX); - constexpr size_t simd_width = 8; // AVX每次处理8个float + constexpr size_t simd_width = 8; // AVX每次处理8个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; + size_t i = 0; // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -72,18 +72,18 @@ namespace x86_simd_audio_processing_func { auto a1 = _mm256_load_ps(&src1[i + 8]); auto a2 = _mm256_load_ps(&src1[i + 16]); auto a3 = _mm256_load_ps(&src1[i + 24]); - + auto b0 = _mm256_load_ps(&src2[i]); auto b1 = _mm256_load_ps(&src2[i + 8]); auto b2 = _mm256_load_ps(&src2[i + 16]); auto b3 = _mm256_load_ps(&src2[i + 24]); - + // 并行计算 auto result0 = _mm256_add_ps(a0, b0); auto result1 = _mm256_add_ps(a1, b1); auto result2 = _mm256_add_ps(a2, b2); auto result3 = _mm256_add_ps(a3, b3); - + // 存储结果 _mm256_store_ps(&dst[i], result0); _mm256_store_ps(&dst[i + 8], result1); @@ -93,8 +93,8 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm256_load_ps(&src1[i]); - auto b = _mm256_load_ps(&src2[i]); + auto a = _mm256_load_ps(&src1[i]); + auto b = _mm256_load_ps(&src2[i]); auto result = _mm256_add_ps(a, b); _mm256_store_ps(&dst[i], result); } @@ -110,9 +110,9 @@ namespace x86_simd_audio_processing_func { ASSERT_ALIGNED(src2, ALIGNMENT_AVX512); ASSERT_ALIGNED(dst, ALIGNMENT_AVX512); - constexpr size_t simd_width = 16; // AVX-512每次处理16个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; + constexpr size_t simd_width = 16; // AVX-512每次处理16个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + size_t i = 0; // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -121,18 +121,18 @@ namespace x86_simd_audio_processing_func { const auto a1 = _mm512_load_ps(&src1[i + 16]); const auto a2 = _mm512_load_ps(&src1[i + 32]); const auto a3 = _mm512_load_ps(&src1[i + 48]); - + const auto b0 = _mm512_load_ps(&src2[i]); const auto b1 = _mm512_load_ps(&src2[i + 16]); const auto b2 = _mm512_load_ps(&src2[i + 32]); const auto b3 = _mm512_load_ps(&src2[i + 48]); - + // 并行计算 const auto result0 = _mm512_add_ps(a0, b0); const auto result1 = _mm512_add_ps(a1, b1); const auto result2 = _mm512_add_ps(a2, b2); const auto result3 = _mm512_add_ps(a3, b3); - + // 存储结果 _mm512_store_ps(&dst[i], result0); _mm512_store_ps(&dst[i + 16], result1); @@ -142,8 +142,8 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm512_load_ps(&src1[i]); - auto b = _mm512_load_ps(&src2[i]); + auto a = _mm512_load_ps(&src1[i]); + auto b = _mm512_load_ps(&src2[i]); auto result = _mm512_add_ps(a, b); _mm512_store_ps(&dst[i], result); } @@ -158,9 +158,9 @@ namespace x86_simd_audio_processing_func { ASSERT_ALIGNED(src, ALIGNMENT_SSE); ASSERT_ALIGNED(dst, ALIGNMENT_SSE); - constexpr size_t simd_width = 4; // SSE每次处理4个float + constexpr size_t simd_width = 4; // SSE每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; + size_t i = 0; auto gain_vec = _mm_set1_ps(gain); @@ -171,13 +171,13 @@ namespace x86_simd_audio_processing_func { auto a1 = _mm_load_ps(&src[i + 4]); auto a2 = _mm_load_ps(&src[i + 8]); auto a3 = _mm_load_ps(&src[i + 12]); - + // 并行计算增益应用 auto result0 = _mm_mul_ps(a0, gain_vec); auto result1 = _mm_mul_ps(a1, gain_vec); auto result2 = _mm_mul_ps(a2, gain_vec); auto result3 = _mm_mul_ps(a3, gain_vec); - + // 存储结果 _mm_store_ps(&dst[i], result0); _mm_store_ps(&dst[i + 4], result1); @@ -187,7 +187,7 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm_load_ps(&src[i]); + auto a = _mm_load_ps(&src[i]); auto result = _mm_mul_ps(a, gain_vec); _mm_store_ps(&dst[i], result); } @@ -202,9 +202,9 @@ namespace x86_simd_audio_processing_func { ASSERT_ALIGNED(src, ALIGNMENT_AVX); ASSERT_ALIGNED(dst, ALIGNMENT_AVX); - constexpr size_t simd_width = 8; // AVX每次处理8个float + constexpr size_t simd_width = 8; // AVX每次处理8个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; + size_t i = 0; auto gain_vec = _mm256_set1_ps(gain); @@ -215,13 +215,13 @@ namespace x86_simd_audio_processing_func { auto a1 = _mm256_load_ps(&src[i + 8]); auto a2 = _mm256_load_ps(&src[i + 16]); auto a3 = _mm256_load_ps(&src[i + 24]); - + // 并行计算增益应用 auto result0 = _mm256_mul_ps(a0, gain_vec); auto result1 = _mm256_mul_ps(a1, gain_vec); auto result2 = _mm256_mul_ps(a2, gain_vec); auto result3 = _mm256_mul_ps(a3, gain_vec); - + // 存储结果 _mm256_store_ps(&dst[i], result0); _mm256_store_ps(&dst[i + 8], result1); @@ -231,7 +231,7 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm256_load_ps(&src[i]); + auto a = _mm256_load_ps(&src[i]); auto result = _mm256_mul_ps(a, gain_vec); _mm256_store_ps(&dst[i], result); } @@ -246,9 +246,9 @@ namespace x86_simd_audio_processing_func { ASSERT_ALIGNED(src, ALIGNMENT_AVX512); ASSERT_ALIGNED(dst, ALIGNMENT_AVX512); - constexpr size_t simd_width = 16; // AVX-512每次处理16个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; + constexpr size_t simd_width = 16; // AVX-512每次处理16个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + size_t i = 0; auto gain_vec = _mm512_set1_ps(gain); @@ -259,13 +259,13 @@ namespace x86_simd_audio_processing_func { auto a1 = _mm512_load_ps(&src[i + 16]); auto a2 = _mm512_load_ps(&src[i + 32]); auto a3 = _mm512_load_ps(&src[i + 48]); - + // 并行计算增益应用 auto result0 = _mm512_mul_ps(a0, gain_vec); auto result1 = _mm512_mul_ps(a1, gain_vec); auto result2 = _mm512_mul_ps(a2, gain_vec); auto result3 = _mm512_mul_ps(a3, gain_vec); - + // 存储结果 _mm512_store_ps(&dst[i], result0); _mm512_store_ps(&dst[i + 16], result1); @@ -275,7 +275,7 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm512_load_ps(&src[i]); + auto a = _mm512_load_ps(&src[i]); auto result = _mm512_mul_ps(a, gain_vec); _mm512_store_ps(&dst[i], result); } @@ -289,13 +289,13 @@ namespace x86_simd_audio_processing_func { float calculate_rms_sse(const float* src, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_SSE); - constexpr size_t simd_width = 4; // SSE每次处理4个float + constexpr size_t simd_width = 4; // SSE每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - auto sum_squares0 = _mm_setzero_ps(); - auto sum_squares1 = _mm_setzero_ps(); - auto sum_squares2 = _mm_setzero_ps(); - auto sum_squares3 = _mm_setzero_ps(); + size_t i = 0; + auto sum_squares0 = _mm_setzero_ps(); + auto sum_squares1 = _mm_setzero_ps(); + auto sum_squares2 = _mm_setzero_ps(); + auto sum_squares3 = _mm_setzero_ps(); // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -304,13 +304,13 @@ namespace x86_simd_audio_processing_func { const auto a1 = _mm_load_ps(&src[i + 4]); const auto a2 = _mm_load_ps(&src[i + 8]); const auto a3 = _mm_load_ps(&src[i + 12]); - + // 并行计算平方 const auto squared0 = _mm_mul_ps(a0, a0); const auto squared1 = _mm_mul_ps(a1, a1); const auto squared2 = _mm_mul_ps(a2, a2); const auto squared3 = _mm_mul_ps(a3, a3); - + // 累加到各自的累加器 sum_squares0 = _mm_add_ps(sum_squares0, squared0); sum_squares1 = _mm_add_ps(sum_squares1, squared1); @@ -324,35 +324,35 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - const auto a = _mm_load_ps(&src[i]); + const auto a = _mm_load_ps(&src[i]); const auto squared = _mm_mul_ps(a, a); - sum_squares = _mm_add_ps(sum_squares, squared); + sum_squares = _mm_add_ps(sum_squares, squared); } // **关键优化:高效的SSE水平归约操作** // 使用hadd指令进行水平加法,避免内存存储+循环 - auto hadd1 = _mm_hadd_ps(sum_squares, sum_squares); // [a+b, c+d, a+b, c+d] - auto hadd2 = _mm_hadd_ps(hadd1, hadd1); // [a+b+c+d, *, a+b+c+d, *] + auto hadd1 = _mm_hadd_ps(sum_squares, sum_squares); // [a+b, c+d, a+b, c+d] + auto hadd2 = _mm_hadd_ps(hadd1, hadd1); // [a+b+c+d, *, a+b+c+d, *] double total_sum = _mm_cvtss_f32(hadd2); // 处理剩余的标量样本 for (; i < num_samples; ++i) { total_sum += static_cast(src[i]) * static_cast(src[i]); } - + return static_cast(std::sqrt(total_sum / static_cast(num_samples))); } float calculate_rms_avx(const float* src, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX); - constexpr size_t simd_width = 8; // AVX每次处理8个float + constexpr size_t simd_width = 8; // AVX每次处理8个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - auto sum_squares0 = _mm256_setzero_ps(); - auto sum_squares1 = _mm256_setzero_ps(); - auto sum_squares2 = _mm256_setzero_ps(); - auto sum_squares3 = _mm256_setzero_ps(); + size_t i = 0; + auto sum_squares0 = _mm256_setzero_ps(); + auto sum_squares1 = _mm256_setzero_ps(); + auto sum_squares2 = _mm256_setzero_ps(); + auto sum_squares3 = _mm256_setzero_ps(); // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -361,13 +361,13 @@ namespace x86_simd_audio_processing_func { const auto a1 = _mm256_load_ps(&src[i + 8]); const auto a2 = _mm256_load_ps(&src[i + 16]); const auto a3 = _mm256_load_ps(&src[i + 24]); - + // 并行计算平方 const auto squared0 = _mm256_mul_ps(a0, a0); const auto squared1 = _mm256_mul_ps(a1, a1); const auto squared2 = _mm256_mul_ps(a2, a2); const auto squared3 = _mm256_mul_ps(a3, a3); - + // 累加到各自的累加器 sum_squares0 = _mm256_add_ps(sum_squares0, squared0); sum_squares1 = _mm256_add_ps(sum_squares1, squared1); @@ -381,40 +381,40 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - const auto a = _mm256_load_ps(&src[i]); + const auto a = _mm256_load_ps(&src[i]); const auto squared = _mm256_mul_ps(a, a); - sum_squares = _mm256_add_ps(sum_squares, squared); + sum_squares = _mm256_add_ps(sum_squares, squared); } // **关键优化:高效的AVX水平归约操作** // 使用hadd + extract指令避免内存存储+循环 auto hadd1 = _mm256_hadd_ps(sum_squares, sum_squares); auto hadd2 = _mm256_hadd_ps(hadd1, hadd1); - + // 提取高低128位并相加 - auto low = _mm256_extractf128_ps(hadd2, 0); - auto high = _mm256_extractf128_ps(hadd2, 1); - auto final_sum = _mm_add_ps(low, high); + auto low = _mm256_extractf128_ps(hadd2, 0); + auto high = _mm256_extractf128_ps(hadd2, 1); + auto final_sum = _mm_add_ps(low, high); double total_sum = _mm_cvtss_f32(final_sum); // 处理剩余的标量样本 for (; i < num_samples; ++i) { total_sum += static_cast(src[i]) * static_cast(src[i]); } - + return static_cast(std::sqrt(total_sum / static_cast(num_samples))); } float calculate_rms_avx512(const float* src, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX512); - constexpr size_t simd_width = 16; // AVX-512每次处理16个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - auto sum_squares0 = _mm512_setzero_ps(); - auto sum_squares1 = _mm512_setzero_ps(); - auto sum_squares2 = _mm512_setzero_ps(); - auto sum_squares3 = _mm512_setzero_ps(); + constexpr size_t simd_width = 16; // AVX-512每次处理16个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + size_t i = 0; + auto sum_squares0 = _mm512_setzero_ps(); + auto sum_squares1 = _mm512_setzero_ps(); + auto sum_squares2 = _mm512_setzero_ps(); + auto sum_squares3 = _mm512_setzero_ps(); // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -423,13 +423,13 @@ namespace x86_simd_audio_processing_func { const auto a1 = _mm512_load_ps(&src[i + 16]); const auto a2 = _mm512_load_ps(&src[i + 32]); const auto a3 = _mm512_load_ps(&src[i + 48]); - + // 并行计算平方 const auto squared0 = _mm512_mul_ps(a0, a0); const auto squared1 = _mm512_mul_ps(a1, a1); const auto squared2 = _mm512_mul_ps(a2, a2); const auto squared3 = _mm512_mul_ps(a3, a3); - + // 累加到各自的累加器 sum_squares0 = _mm512_add_ps(sum_squares0, squared0); sum_squares1 = _mm512_add_ps(sum_squares1, squared1); @@ -443,9 +443,9 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - const auto a = _mm512_load_ps(&src[i]); + const auto a = _mm512_load_ps(&src[i]); const auto squared = _mm512_mul_ps(a, a); - sum_squares = _mm512_add_ps(sum_squares, squared); + sum_squares = _mm512_add_ps(sum_squares, squared); } // **关键优化:高效的AVX-512水平归约操作** @@ -456,20 +456,20 @@ namespace x86_simd_audio_processing_func { for (; i < num_samples; ++i) { total_sum += static_cast(src[i]) * static_cast(src[i]); } - + return static_cast(std::sqrt(total_sum / static_cast(num_samples))); } float calculate_peak_sse(const float* src, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_SSE); - constexpr size_t simd_width = 4; // SSE每次处理4个float + constexpr size_t simd_width = 4; // SSE每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - auto peak_vec0 = _mm_setzero_ps(); - auto peak_vec1 = _mm_setzero_ps(); - auto peak_vec2 = _mm_setzero_ps(); - auto peak_vec3 = _mm_setzero_ps(); + size_t i = 0; + auto peak_vec0 = _mm_setzero_ps(); + auto peak_vec1 = _mm_setzero_ps(); + auto peak_vec2 = _mm_setzero_ps(); + auto peak_vec3 = _mm_setzero_ps(); // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -478,13 +478,13 @@ namespace x86_simd_audio_processing_func { const auto a1 = _mm_load_ps(&src[i + 4]); const auto a2 = _mm_load_ps(&src[i + 8]); const auto a3 = _mm_load_ps(&src[i + 12]); - + // 并行计算绝对值 const auto abs_a0 = _mm_andnot_ps(_mm_set1_ps(-0.0f), a0); const auto abs_a1 = _mm_andnot_ps(_mm_set1_ps(-0.0f), a1); const auto abs_a2 = _mm_andnot_ps(_mm_set1_ps(-0.0f), a2); const auto abs_a3 = _mm_andnot_ps(_mm_set1_ps(-0.0f), a3); - + // 更新各自的峰值向量 peak_vec0 = _mm_max_ps(peak_vec0, abs_a0); peak_vec1 = _mm_max_ps(peak_vec1, abs_a1); @@ -498,17 +498,17 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - const auto a = _mm_load_ps(&src[i]); + const auto a = _mm_load_ps(&src[i]); const auto abs_a = _mm_andnot_ps(_mm_set1_ps(-0.0f), a); - peak_vec = _mm_max_ps(peak_vec, abs_a); + peak_vec = _mm_max_ps(peak_vec, abs_a); } // **关键优化:高效的SSE水平最大值归约操作** // 使用shuffle指令序列避免内存存储+循环 - auto temp1 = _mm_shuffle_ps(peak_vec, peak_vec, _MM_SHUFFLE(2, 3, 0, 1)); // [y, x, w, z] - auto max1 = _mm_max_ps(peak_vec, temp1); // [max(x,y), max(x,y), max(z,w), max(z,w)] - auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); // [max(z,w), max(z,w), max(x,y), max(x,y)] - auto final_max = _mm_max_ps(max1, temp2); // [final_max, *, *, *] + auto temp1 = _mm_shuffle_ps(peak_vec, peak_vec, _MM_SHUFFLE(2, 3, 0, 1)); // [y, x, w, z] + auto max1 = _mm_max_ps(peak_vec, temp1); // [max(x,y), max(x,y), max(z,w), max(z,w)] + auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); // [max(z,w), max(z,w), max(x,y), max(x,y)] + auto final_max = _mm_max_ps(max1, temp2); // [final_max, *, *, *] float peak = _mm_cvtss_f32(final_max); // 处理剩余的标量样本 @@ -518,20 +518,20 @@ namespace x86_simd_audio_processing_func { peak = abs_sample; } } - + return peak; } float calculate_peak_avx(const float* src, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX); - constexpr size_t simd_width = 8; // AVX每次处理8个float + constexpr size_t simd_width = 8; // AVX每次处理8个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - auto peak_vec0 = _mm256_setzero_ps(); - auto peak_vec1 = _mm256_setzero_ps(); - auto peak_vec2 = _mm256_setzero_ps(); - auto peak_vec3 = _mm256_setzero_ps(); + size_t i = 0; + auto peak_vec0 = _mm256_setzero_ps(); + auto peak_vec1 = _mm256_setzero_ps(); + auto peak_vec2 = _mm256_setzero_ps(); + auto peak_vec3 = _mm256_setzero_ps(); // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -540,13 +540,13 @@ namespace x86_simd_audio_processing_func { const auto a1 = _mm256_load_ps(&src[i + 8]); const auto a2 = _mm256_load_ps(&src[i + 16]); const auto a3 = _mm256_load_ps(&src[i + 24]); - + // 并行计算绝对值 const auto abs_a0 = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a0); const auto abs_a1 = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a1); const auto abs_a2 = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a2); const auto abs_a3 = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a3); - + // 更新各自的峰值向量 peak_vec0 = _mm256_max_ps(peak_vec0, abs_a0); peak_vec1 = _mm256_max_ps(peak_vec1, abs_a1); @@ -560,23 +560,23 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - const auto a = _mm256_load_ps(&src[i]); + const auto a = _mm256_load_ps(&src[i]); const auto abs_a = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a); - peak_vec = _mm256_max_ps(peak_vec, abs_a); + peak_vec = _mm256_max_ps(peak_vec, abs_a); } // **关键优化:高效的AVX水平最大值归约操作** // 提取高低128位并求最大值,然后使用SSE水平最大值 - auto low = _mm256_extractf128_ps(peak_vec, 0); - auto high = _mm256_extractf128_ps(peak_vec, 1); + auto low = _mm256_extractf128_ps(peak_vec, 0); + auto high = _mm256_extractf128_ps(peak_vec, 1); auto max_lane = _mm_max_ps(low, high); - + // 在128位向量内进行水平最大值操作 - auto temp1 = _mm_shuffle_ps(max_lane, max_lane, _MM_SHUFFLE(2, 3, 0, 1)); - auto max1 = _mm_max_ps(max_lane, temp1); - auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); - auto final_max = _mm_max_ps(max1, temp2); - float peak = _mm_cvtss_f32(final_max); + auto temp1 = _mm_shuffle_ps(max_lane, max_lane, _MM_SHUFFLE(2, 3, 0, 1)); + auto max1 = _mm_max_ps(max_lane, temp1); + auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); + auto final_max = _mm_max_ps(max1, temp2); + float peak = _mm_cvtss_f32(final_max); // 处理剩余的标量样本 for (; i < num_samples; ++i) { @@ -585,20 +585,20 @@ namespace x86_simd_audio_processing_func { peak = abs_sample; } } - + return peak; } float calculate_peak_avx512(const float* src, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX512); - constexpr size_t simd_width = 16; // AVX-512每次处理16个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - auto peak_vec0 = _mm512_setzero_ps(); - auto peak_vec1 = _mm512_setzero_ps(); - auto peak_vec2 = _mm512_setzero_ps(); - auto peak_vec3 = _mm512_setzero_ps(); + constexpr size_t simd_width = 16; // AVX-512每次处理16个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + size_t i = 0; + auto peak_vec0 = _mm512_setzero_ps(); + auto peak_vec1 = _mm512_setzero_ps(); + auto peak_vec2 = _mm512_setzero_ps(); + auto peak_vec3 = _mm512_setzero_ps(); // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { @@ -607,13 +607,13 @@ namespace x86_simd_audio_processing_func { const auto a1 = _mm512_load_ps(&src[i + 16]); const auto a2 = _mm512_load_ps(&src[i + 32]); const auto a3 = _mm512_load_ps(&src[i + 48]); - + // 并行计算绝对值 const auto abs_a0 = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), a0); const auto abs_a1 = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), a1); const auto abs_a2 = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), a2); const auto abs_a3 = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), a3); - + // 更新各自的峰值向量 peak_vec0 = _mm512_max_ps(peak_vec0, abs_a0); peak_vec1 = _mm512_max_ps(peak_vec1, abs_a1); @@ -627,9 +627,9 @@ namespace x86_simd_audio_processing_func { // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - const auto a = _mm512_load_ps(&src[i]); + const auto a = _mm512_load_ps(&src[i]); const auto abs_a = _mm512_andnot_ps(_mm512_set1_ps(-0.0f), a); - peak_vec = _mm512_max_ps(peak_vec, abs_a); + peak_vec = _mm512_max_ps(peak_vec, abs_a); } // **关键优化:高效的AVX-512水平最大值归约操作** @@ -643,31 +643,31 @@ namespace x86_simd_audio_processing_func { peak = abs_sample; } } - + return peak; } - + // 音频归一化函数实现 (SSE版本) void normalize_audio_sse(const float* src, float* dst, float target_peak, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_SSE); ASSERT_ALIGNED(dst, ALIGNMENT_SSE); - + // 边界情况处理 if (num_samples == 0 || target_peak <= 0.0f) { return; } - + // 计算当前音频的峰值 const float current_peak = calculate_peak_sse(src, num_samples); - + // 如果当前峰值过小,设置为静音或均匀值 if (current_peak < 1e-10f) { // 使用SSE优化的零填充 - constexpr size_t simd_width = 4; // SSE每次处理4个float + constexpr size_t simd_width = 4; // SSE每次处理4个float constexpr size_t unroll_factor = 4; - auto zero_vec = _mm_setzero_ps(); - size_t i = 0; - + auto zero_vec = _mm_setzero_ps(); + size_t i = 0; + // 向量化处理(4路环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { _mm_store_ps(&dst[i], zero_vec); @@ -675,47 +675,47 @@ namespace x86_simd_audio_processing_func { _mm_store_ps(&dst[i + 8], zero_vec); _mm_store_ps(&dst[i + 12], zero_vec); } - + // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { _mm_store_ps(&dst[i], zero_vec); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { dst[i] = 0.0f; } return; } - + // 计算归一化增益因子 const float gain_factor = target_peak / current_peak; - + // 使用现有的apply_gain_sse函数应用增益 apply_gain_sse(src, dst, gain_factor, num_samples); } - + // 音频归一化函数实现 (AVX版本) void normalize_audio_avx(const float* src, float* dst, float target_peak, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX); ASSERT_ALIGNED(dst, ALIGNMENT_AVX); - + // 边界情况处理 if (num_samples == 0 || target_peak <= 0.0f) { return; } - + // 计算当前音频的峰值 const float current_peak = calculate_peak_avx(src, num_samples); - + // 如果当前峰值过小,设置为静音或均匀值 if (current_peak < 1e-10f) { // 使用AVX优化的零填充 - constexpr size_t simd_width = 8; // AVX每次处理8个float + constexpr size_t simd_width = 8; // AVX每次处理8个float constexpr size_t unroll_factor = 4; - auto zero_vec = _mm256_setzero_ps(); - size_t i = 0; - + auto zero_vec = _mm256_setzero_ps(); + size_t i = 0; + // 向量化处理(4路环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { _mm256_store_ps(&dst[i], zero_vec); @@ -723,47 +723,47 @@ namespace x86_simd_audio_processing_func { _mm256_store_ps(&dst[i + 16], zero_vec); _mm256_store_ps(&dst[i + 24], zero_vec); } - + // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { _mm256_store_ps(&dst[i], zero_vec); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { dst[i] = 0.0f; } return; } - + // 计算归一化增益因子 const float gain_factor = target_peak / current_peak; - + // 使用现有的apply_gain_avx函数应用增益 apply_gain_avx(src, dst, gain_factor, num_samples); } - + // 音频归一化函数实现 (AVX512版本) void normalize_audio_avx512(const float* src, float* dst, float target_peak, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX512); ASSERT_ALIGNED(dst, ALIGNMENT_AVX512); - + // 边界情况处理 if (num_samples == 0 || target_peak <= 0.0f) { return; } - + // 计算当前音频的峰值 const float current_peak = calculate_peak_avx512(src, num_samples); - + // 如果当前峰值过小,设置为静音或均匀值 if (current_peak < 1e-10f) { // 使用AVX512优化的零填充 - constexpr size_t simd_width = 16; // AVX512每次处理16个float + constexpr size_t simd_width = 16; // AVX512每次处理16个float constexpr size_t unroll_factor = 4; - auto zero_vec = _mm512_setzero_ps(); - size_t i = 0; - + auto zero_vec = _mm512_setzero_ps(); + size_t i = 0; + // 向量化处理(4路环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { _mm512_store_ps(&dst[i], zero_vec); @@ -771,249 +771,248 @@ namespace x86_simd_audio_processing_func { _mm512_store_ps(&dst[i + 32], zero_vec); _mm512_store_ps(&dst[i + 48], zero_vec); } - + // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { _mm512_store_ps(&dst[i], zero_vec); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { dst[i] = 0.0f; } return; } - + // 计算归一化增益因子 const float gain_factor = target_peak / current_peak; - + // 使用现有的apply_gain_avx512函数应用增益 apply_gain_avx512(src, dst, gain_factor, num_samples); } - + // 立体声到单声道转换函数实现 (SSE版本) void stereo_to_mono_sse(const float* stereo_src, float* mono_dst, size_t num_stereo_samples) { ASSERT_ALIGNED(stereo_src, ALIGNMENT_SSE); ASSERT_ALIGNED(mono_dst, ALIGNMENT_SSE); - + // 边界情况处理 if (num_stereo_samples == 0) { return; } - - constexpr size_t simd_width = 4; // SSE每次处理4个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - const auto half_vec = _mm_set1_ps(0.5f); // 用于取平均值 - size_t stereo_idx = 0; // 立体声索引(以样本对为单位) - size_t mono_idx = 0; // 单声道索引 - + + constexpr size_t simd_width = 4; // SSE每次处理4个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + const auto half_vec = _mm_set1_ps(0.5f); // 用于取平均值 + size_t stereo_idx = 0; // 立体声索引(以样本对为单位) + size_t mono_idx = 0; // 单声道索引 + // 向量化处理(4路循环展开) // 每次处理4个单声道样本,需要读取8个立体声样本 for (; stereo_idx + simd_width * 2 * unroll_factor <= num_stereo_samples * 2; - stereo_idx += simd_width * 2 * unroll_factor, mono_idx += simd_width * unroll_factor) { - + stereo_idx += simd_width * 2 * unroll_factor, mono_idx += simd_width * unroll_factor) { // 加载8个立体声样本对 (16个float) - auto stereo0 = _mm_load_ps(&stereo_src[stereo_idx]); // [L0, R0, L1, R1] - auto stereo1 = _mm_load_ps(&stereo_src[stereo_idx + 4]); // [L2, R2, L3, R3] - auto stereo2 = _mm_load_ps(&stereo_src[stereo_idx + 8]); // [L4, R4, L5, R5] + auto stereo0 = _mm_load_ps(&stereo_src[stereo_idx]); // [L0, R0, L1, R1] + auto stereo1 = _mm_load_ps(&stereo_src[stereo_idx + 4]); // [L2, R2, L3, R3] + auto stereo2 = _mm_load_ps(&stereo_src[stereo_idx + 8]); // [L4, R4, L5, R5] auto stereo3 = _mm_load_ps(&stereo_src[stereo_idx + 12]); // [L6, R6, L7, R7] auto stereo4 = _mm_load_ps(&stereo_src[stereo_idx + 16]); // [L8, R8, L9, R9] auto stereo5 = _mm_load_ps(&stereo_src[stereo_idx + 20]); // [L10, R10, L11, R11] auto stereo6 = _mm_load_ps(&stereo_src[stereo_idx + 24]); // [L12, R12, L13, R13] auto stereo7 = _mm_load_ps(&stereo_src[stereo_idx + 28]); // [L14, R14, L15, R15] - + // 分离左右声道 - auto left0 = _mm_shuffle_ps(stereo0, stereo1, _MM_SHUFFLE(2, 0, 2, 0)); // [L0, L1, L2, L3] + auto left0 = _mm_shuffle_ps(stereo0, stereo1, _MM_SHUFFLE(2, 0, 2, 0)); // [L0, L1, L2, L3] auto right0 = _mm_shuffle_ps(stereo0, stereo1, _MM_SHUFFLE(3, 1, 3, 1)); // [R0, R1, R2, R3] - auto left1 = _mm_shuffle_ps(stereo2, stereo3, _MM_SHUFFLE(2, 0, 2, 0)); // [L4, L5, L6, L7] + auto left1 = _mm_shuffle_ps(stereo2, stereo3, _MM_SHUFFLE(2, 0, 2, 0)); // [L4, L5, L6, L7] auto right1 = _mm_shuffle_ps(stereo2, stereo3, _MM_SHUFFLE(3, 1, 3, 1)); // [R4, R5, R6, R7] - auto left2 = _mm_shuffle_ps(stereo4, stereo5, _MM_SHUFFLE(2, 0, 2, 0)); // [L8, L9, L10, L11] + auto left2 = _mm_shuffle_ps(stereo4, stereo5, _MM_SHUFFLE(2, 0, 2, 0)); // [L8, L9, L10, L11] auto right2 = _mm_shuffle_ps(stereo4, stereo5, _MM_SHUFFLE(3, 1, 3, 1)); // [R8, R9, R10, R11] - auto left3 = _mm_shuffle_ps(stereo6, stereo7, _MM_SHUFFLE(2, 0, 2, 0)); // [L12, L13, L14, L15] + auto left3 = _mm_shuffle_ps(stereo6, stereo7, _MM_SHUFFLE(2, 0, 2, 0)); // [L12, L13, L14, L15] auto right3 = _mm_shuffle_ps(stereo6, stereo7, _MM_SHUFFLE(3, 1, 3, 1)); // [R12, R13, R14, R15] - + // 计算单声道 = (左声道 + 右声道) / 2 auto mono0 = _mm_mul_ps(_mm_add_ps(left0, right0), half_vec); auto mono1 = _mm_mul_ps(_mm_add_ps(left1, right1), half_vec); auto mono2 = _mm_mul_ps(_mm_add_ps(left2, right2), half_vec); auto mono3 = _mm_mul_ps(_mm_add_ps(left3, right3), half_vec); - + // 存储结果 _mm_store_ps(&mono_dst[mono_idx], mono0); _mm_store_ps(&mono_dst[mono_idx + 4], mono1); _mm_store_ps(&mono_dst[mono_idx + 8], mono2); _mm_store_ps(&mono_dst[mono_idx + 12], mono3); } - + // 处理剩余的样本对(标量处理) for (size_t i = stereo_idx / 2; i < num_stereo_samples; ++i) { - const float left = stereo_src[i * 2]; + const float left = stereo_src[i * 2]; const float right = stereo_src[i * 2 + 1]; - mono_dst[i] = (left + right) * 0.5f; + mono_dst[i] = (left + right) * 0.5f; } } - + // 立体声到单声道转换函数实现 (AVX版本) void stereo_to_mono_avx(const float* stereo_src, float* mono_dst, size_t num_stereo_samples) { ASSERT_ALIGNED(stereo_src, ALIGNMENT_AVX); ASSERT_ALIGNED(mono_dst, ALIGNMENT_AVX); - + // 边界情况处理 if (num_stereo_samples == 0) { return; } - - constexpr size_t simd_width = 8; // AVX每次处理8个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - const auto half_vec = _mm256_set1_ps(0.5f); // 用于取平均值 - size_t stereo_idx = 0; // 立体声索引(以样本对为单位) - size_t mono_idx = 0; // 单声道索引 - + + constexpr size_t simd_width = 8; // AVX每次处理8个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + const auto half_vec = _mm256_set1_ps(0.5f); // 用于取平均值 + size_t stereo_idx = 0; // 立体声索引(以样本对为单位) + size_t mono_idx = 0; // 单声道索引 + // 向量化处理(4路循环展开) // 每次处理8个单声道样本,需要读取16个立体声样本 for (; stereo_idx + simd_width * 2 * unroll_factor <= num_stereo_samples * 2; - stereo_idx += simd_width * 2 * unroll_factor, mono_idx += simd_width * unroll_factor) { - + stereo_idx += simd_width * 2 * unroll_factor, mono_idx += simd_width * unroll_factor) { // 加载16个立体声样本对 (32个float) - auto stereo0 = _mm256_load_ps(&stereo_src[stereo_idx]); // [L0,R0,L1,R1,L2,R2,L3,R3] - auto stereo1 = _mm256_load_ps(&stereo_src[stereo_idx + 8]); // [L4,R4,L5,R5,L6,R6,L7,R7] + auto stereo0 = _mm256_load_ps(&stereo_src[stereo_idx]); // [L0,R0,L1,R1,L2,R2,L3,R3] + auto stereo1 = _mm256_load_ps(&stereo_src[stereo_idx + 8]); // [L4,R4,L5,R5,L6,R6,L7,R7] auto stereo2 = _mm256_load_ps(&stereo_src[stereo_idx + 16]); // [L8,R8,L9,R9,L10,R10,L11,R11] auto stereo3 = _mm256_load_ps(&stereo_src[stereo_idx + 24]); // [L12,R12,L13,R13,L14,R14,L15,R15] auto stereo4 = _mm256_load_ps(&stereo_src[stereo_idx + 32]); // [L16,R16,L17,R17,L18,R18,L19,R19] auto stereo5 = _mm256_load_ps(&stereo_src[stereo_idx + 40]); // [L20,R20,L21,R21,L22,R22,L23,R23] auto stereo6 = _mm256_load_ps(&stereo_src[stereo_idx + 48]); // [L24,R24,L25,R25,L26,R26,L27,R27] auto stereo7 = _mm256_load_ps(&stereo_src[stereo_idx + 56]); // [L28,R28,L29,R29,L30,R30,L31,R31] - + // 分离左右声道(使用shuffle和blend) - auto left0 = _mm256_shuffle_ps(stereo0, stereo1, _MM_SHUFFLE(2, 0, 2, 0)); // [L0,L1,L4,L5,L2,L3,L6,L7] + auto left0 = _mm256_shuffle_ps(stereo0, stereo1, _MM_SHUFFLE(2, 0, 2, 0)); // [L0,L1,L4,L5,L2,L3,L6,L7] auto right0 = _mm256_shuffle_ps(stereo0, stereo1, _MM_SHUFFLE(3, 1, 3, 1)); // [R0,R1,R4,R5,R2,R3,R6,R7] - auto left1 = _mm256_shuffle_ps(stereo2, stereo3, _MM_SHUFFLE(2, 0, 2, 0)); + auto left1 = _mm256_shuffle_ps(stereo2, stereo3, _MM_SHUFFLE(2, 0, 2, 0)); auto right1 = _mm256_shuffle_ps(stereo2, stereo3, _MM_SHUFFLE(3, 1, 3, 1)); - auto left2 = _mm256_shuffle_ps(stereo4, stereo5, _MM_SHUFFLE(2, 0, 2, 0)); + auto left2 = _mm256_shuffle_ps(stereo4, stereo5, _MM_SHUFFLE(2, 0, 2, 0)); auto right2 = _mm256_shuffle_ps(stereo4, stereo5, _MM_SHUFFLE(3, 1, 3, 1)); - auto left3 = _mm256_shuffle_ps(stereo6, stereo7, _MM_SHUFFLE(2, 0, 2, 0)); + auto left3 = _mm256_shuffle_ps(stereo6, stereo7, _MM_SHUFFLE(2, 0, 2, 0)); auto right3 = _mm256_shuffle_ps(stereo6, stereo7, _MM_SHUFFLE(3, 1, 3, 1)); - + // 重新排列以获得正确的顺序 - left0 = _mm256_permute2f128_ps(left0, left0, 0x01); // [L2,L3,L6,L7,L0,L1,L4,L5] + left0 = _mm256_permute2f128_ps(left0, left0, 0x01); // [L2,L3,L6,L7,L0,L1,L4,L5] right0 = _mm256_permute2f128_ps(right0, right0, 0x01); // [R2,R3,R6,R7,R0,R1,R4,R5] - left1 = _mm256_permute2f128_ps(left1, left1, 0x01); + left1 = _mm256_permute2f128_ps(left1, left1, 0x01); right1 = _mm256_permute2f128_ps(right1, right1, 0x01); - left2 = _mm256_permute2f128_ps(left2, left2, 0x01); + left2 = _mm256_permute2f128_ps(left2, left2, 0x01); right2 = _mm256_permute2f128_ps(right2, right2, 0x01); - left3 = _mm256_permute2f128_ps(left3, left3, 0x01); + left3 = _mm256_permute2f128_ps(left3, left3, 0x01); right3 = _mm256_permute2f128_ps(right3, right3, 0x01); - + // 计算单声道 = (左声道 + 右声道) / 2 auto mono0 = _mm256_mul_ps(_mm256_add_ps(left0, right0), half_vec); auto mono1 = _mm256_mul_ps(_mm256_add_ps(left1, right1), half_vec); auto mono2 = _mm256_mul_ps(_mm256_add_ps(left2, right2), half_vec); auto mono3 = _mm256_mul_ps(_mm256_add_ps(left3, right3), half_vec); - + // 存储结果 _mm256_store_ps(&mono_dst[mono_idx], mono0); _mm256_store_ps(&mono_dst[mono_idx + 8], mono1); _mm256_store_ps(&mono_dst[mono_idx + 16], mono2); _mm256_store_ps(&mono_dst[mono_idx + 24], mono3); } - + // 处理剩余的样本对(标量处理) for (size_t i = stereo_idx / 2; i < num_stereo_samples; ++i) { - const float left = stereo_src[i * 2]; + const float left = stereo_src[i * 2]; const float right = stereo_src[i * 2 + 1]; - mono_dst[i] = (left + right) * 0.5f; + mono_dst[i] = (left + right) * 0.5f; } } - + // 立体声到单声道转换函数实现 (AVX512版本) void stereo_to_mono_avx512(const float* stereo_src, float* mono_dst, size_t num_stereo_samples) { ASSERT_ALIGNED(stereo_src, ALIGNMENT_AVX512); ASSERT_ALIGNED(mono_dst, ALIGNMENT_AVX512); - + // 边界情况处理 if (num_stereo_samples == 0) { return; } - - constexpr size_t simd_width = 16; // AVX512每次处理16个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - const auto half_vec = _mm512_set1_ps(0.5f); // 用于取平均值 - size_t stereo_idx = 0; // 立体声索引(以样本对为单位) - size_t mono_idx = 0; // 单声道索引 - + + constexpr size_t simd_width = 16; // AVX512每次处理16个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + const auto half_vec = _mm512_set1_ps(0.5f); // 用于取平均值 + size_t stereo_idx = 0; // 立体声索引(以样本对为单位) + size_t mono_idx = 0; // 单声道索引 + // 向量化处理(4路循环展开) // 每次处理16个单声道样本,需要读取32个立体声样本 for (; stereo_idx + simd_width * 2 * unroll_factor <= num_stereo_samples * 2; - stereo_idx += simd_width * 2 * unroll_factor, mono_idx += simd_width * unroll_factor) { - + stereo_idx += simd_width * 2 * unroll_factor, mono_idx += simd_width * unroll_factor) { // 加载32个立体声样本对 (64个float) - auto stereo0 = _mm512_load_ps(&stereo_src[stereo_idx]); // 16个交错样本 - auto stereo1 = _mm512_load_ps(&stereo_src[stereo_idx + 16]); // 16个交错样本 - auto stereo2 = _mm512_load_ps(&stereo_src[stereo_idx + 32]); // 16个交错样本 - auto stereo3 = _mm512_load_ps(&stereo_src[stereo_idx + 48]); // 16个交错样本 - auto stereo4 = _mm512_load_ps(&stereo_src[stereo_idx + 64]); // 16个交错样本 - auto stereo5 = _mm512_load_ps(&stereo_src[stereo_idx + 80]); // 16个交错样本 - auto stereo6 = _mm512_load_ps(&stereo_src[stereo_idx + 96]); // 16个交错样本 + auto stereo0 = _mm512_load_ps(&stereo_src[stereo_idx]); // 16个交错样本 + auto stereo1 = _mm512_load_ps(&stereo_src[stereo_idx + 16]); // 16个交错样本 + auto stereo2 = _mm512_load_ps(&stereo_src[stereo_idx + 32]); // 16个交错样本 + auto stereo3 = _mm512_load_ps(&stereo_src[stereo_idx + 48]); // 16个交错样本 + auto stereo4 = _mm512_load_ps(&stereo_src[stereo_idx + 64]); // 16个交错样本 + auto stereo5 = _mm512_load_ps(&stereo_src[stereo_idx + 80]); // 16个交错样本 + auto stereo6 = _mm512_load_ps(&stereo_src[stereo_idx + 96]); // 16个交错样本 auto stereo7 = _mm512_load_ps(&stereo_src[stereo_idx + 112]); // 16个交错样本 - + // 使用AVX512的交替shuffle来分离左右声道 const auto even_mask = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0); - const auto odd_mask = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); - - auto left0 = _mm512_permutex2var_ps(stereo0, even_mask, stereo1); + const auto odd_mask = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1); + + auto left0 = _mm512_permutex2var_ps(stereo0, even_mask, stereo1); auto right0 = _mm512_permutex2var_ps(stereo0, odd_mask, stereo1); - auto left1 = _mm512_permutex2var_ps(stereo2, even_mask, stereo3); + auto left1 = _mm512_permutex2var_ps(stereo2, even_mask, stereo3); auto right1 = _mm512_permutex2var_ps(stereo2, odd_mask, stereo3); - auto left2 = _mm512_permutex2var_ps(stereo4, even_mask, stereo5); + auto left2 = _mm512_permutex2var_ps(stereo4, even_mask, stereo5); auto right2 = _mm512_permutex2var_ps(stereo4, odd_mask, stereo5); - auto left3 = _mm512_permutex2var_ps(stereo6, even_mask, stereo7); + auto left3 = _mm512_permutex2var_ps(stereo6, even_mask, stereo7); auto right3 = _mm512_permutex2var_ps(stereo6, odd_mask, stereo7); - + // 计算单声道 = (左声道 + 右声道) / 2 auto mono0 = _mm512_mul_ps(_mm512_add_ps(left0, right0), half_vec); auto mono1 = _mm512_mul_ps(_mm512_add_ps(left1, right1), half_vec); auto mono2 = _mm512_mul_ps(_mm512_add_ps(left2, right2), half_vec); auto mono3 = _mm512_mul_ps(_mm512_add_ps(left3, right3), half_vec); - + // 存储结果 _mm512_store_ps(&mono_dst[mono_idx], mono0); _mm512_store_ps(&mono_dst[mono_idx + 16], mono1); _mm512_store_ps(&mono_dst[mono_idx + 32], mono2); _mm512_store_ps(&mono_dst[mono_idx + 48], mono3); } - + // 处理剩余的样本对(标量处理) for (size_t i = stereo_idx / 2; i < num_stereo_samples; ++i) { - const float left = stereo_src[i * 2]; + const float left = stereo_src[i * 2]; const float right = stereo_src[i * 2 + 1]; - mono_dst[i] = (left + right) * 0.5f; + mono_dst[i] = (left + right) * 0.5f; } } + // 音频限幅函数实现 (SSE版本) - void limit_audio_sse(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples) { + void limit_audio_sse(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_SSE); ASSERT_ALIGNED(dst, ALIGNMENT_SSE); - + // 边界情况处理 if (num_samples == 0 || threshold <= 0.0f) { return; } - - constexpr size_t simd_width = 4; // SSE每次处理4个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - constexpr float release_time = 0.05f; // 释放时间常数(秒),可根据实际需求调整 - float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 - + + constexpr size_t simd_width = 4; // SSE每次处理4个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + constexpr float release_time = 0.05f; // 释放时间常数(秒),可根据实际需求调整 + float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 + // 初始化限幅器状态,如果是首次调用,则从1.0开始 float current_gain = limiter_state != nullptr ? *limiter_state : 1.0f; - + // 阈值和释放系数向量 - const auto threshold_vec = _mm_set1_ps(threshold); + const auto threshold_vec = _mm_set1_ps(threshold); const auto release_coeff_vec = _mm_set1_ps(release_coeff); - const auto one_vec = _mm_set1_ps(1.0f); - + const auto one_vec = _mm_set1_ps(1.0f); + size_t i = 0; - + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { // 加载4个向量 @@ -1021,141 +1020,146 @@ namespace x86_simd_audio_processing_func { auto a1 = _mm_load_ps(&src[i + 4]); auto a2 = _mm_load_ps(&src[i + 8]); auto a3 = _mm_load_ps(&src[i + 12]); - + // 计算绝对值 auto abs_a0 = _mm_andnot_ps(_mm_set1_ps(-0.0f), a0); auto abs_a1 = _mm_andnot_ps(_mm_set1_ps(-0.0f), a1); auto abs_a2 = _mm_andnot_ps(_mm_set1_ps(-0.0f), a2); auto abs_a3 = _mm_andnot_ps(_mm_set1_ps(-0.0f), a3); - + // 找出最大值 - auto max_abs = _mm_max_ps(_mm_max_ps(abs_a0, abs_a1), + auto max_abs = _mm_max_ps(_mm_max_ps(abs_a0, abs_a1), _mm_max_ps(abs_a2, abs_a3)); - + // 水平最大值归约 auto temp1 = _mm_shuffle_ps(max_abs, max_abs, _MM_SHUFFLE(2, 3, 0, 1)); // [y, x, w, z] - auto max1 = _mm_max_ps(max_abs, temp1); // [max(x,y), max(x,y), max(z,w), max(z,w)] - auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); // [max(z,w), max(z,w), max(x,y), max(x,y)] + auto max1 = _mm_max_ps(max_abs, temp1); // [max(x,y), max(x,y), max(z,w), max(z,w)] + auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); + // [max(z,w), max(z,w), max(x,y), max(x,y)] auto final_max = _mm_max_ps(max1, temp2); // [final_max, *, *, *] - + // 提取水平最大值 float max_sample = _mm_cvtss_f32(final_max); - + // 计算需要的增益以限制幅度 float target_gain = max_sample > threshold ? threshold / max_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 将当前增益转换为向量 auto gain_vec = _mm_set1_ps(current_gain); - + // 应用增益 auto result0 = _mm_mul_ps(a0, gain_vec); auto result1 = _mm_mul_ps(a1, gain_vec); auto result2 = _mm_mul_ps(a2, gain_vec); auto result3 = _mm_mul_ps(a3, gain_vec); - + // 存储结果 _mm_store_ps(&dst[i], result0); _mm_store_ps(&dst[i + 4], result1); _mm_store_ps(&dst[i + 8], result2); _mm_store_ps(&dst[i + 12], result3); } - + // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm_load_ps(&src[i]); + auto a = _mm_load_ps(&src[i]); auto abs_a = _mm_andnot_ps(_mm_set1_ps(-0.0f), a); - + // 找出最大值 auto max_abs = abs_a; - + // 水平最大值归约 - auto temp1 = _mm_shuffle_ps(max_abs, max_abs, _MM_SHUFFLE(2, 3, 0, 1)); - auto max1 = _mm_max_ps(max_abs, temp1); - auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); + auto temp1 = _mm_shuffle_ps(max_abs, max_abs, _MM_SHUFFLE(2, 3, 0, 1)); + auto max1 = _mm_max_ps(max_abs, temp1); + auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); auto final_max = _mm_max_ps(max1, temp2); - + // 提取水平最大值 float max_sample = _mm_cvtss_f32(final_max); - + // 计算需要的增益以限制幅度 float target_gain = max_sample > threshold ? threshold / max_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 应用增益 auto gain_vec = _mm_set1_ps(current_gain); - auto result = _mm_mul_ps(a, gain_vec); + auto result = _mm_mul_ps(a, gain_vec); _mm_store_ps(&dst[i], result); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { - float sample = src[i]; + float sample = src[i]; float abs_sample = std::fabs(sample); - + // 计算需要的增益以限制幅度 float target_gain = abs_sample > threshold ? threshold / abs_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 应用增益 dst[i] = sample * current_gain; } - + // 更新限幅器状态 if (limiter_state != nullptr) { *limiter_state = current_gain; } } - + // 音频限幅函数实现 (AVX版本) - void limit_audio_avx(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples) { + void limit_audio_avx(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX); ASSERT_ALIGNED(dst, ALIGNMENT_AVX); - + // 边界情况处理 if (num_samples == 0 || threshold <= 0.0f) { return; } - - constexpr size_t simd_width = 8; // AVX每次处理8个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - constexpr float release_time = 0.05f; // 释放时间常数(秒),可根据实际需求调整 - float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 - + + constexpr size_t simd_width = 8; // AVX每次处理8个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + constexpr float release_time = 0.05f; // 释放时间常数(秒),可根据实际需求调整 + float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 + // 初始化限幅器状态,如果是首次调用,则从1.0开始 float current_gain = limiter_state != nullptr ? *limiter_state : 1.0f; - + // 阈值和释放系数向量 - const auto threshold_vec = _mm256_set1_ps(threshold); + const auto threshold_vec = _mm256_set1_ps(threshold); const auto release_coeff_vec = _mm256_set1_ps(release_coeff); - const auto one_vec = _mm256_set1_ps(1.0f); - + const auto one_vec = _mm256_set1_ps(1.0f); + size_t i = 0; - + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { // 加载4个向量 @@ -1163,149 +1167,153 @@ namespace x86_simd_audio_processing_func { auto a1 = _mm256_load_ps(&src[i + 8]); auto a2 = _mm256_load_ps(&src[i + 16]); auto a3 = _mm256_load_ps(&src[i + 24]); - + // 计算绝对值 auto abs_a0 = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a0); auto abs_a1 = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a1); auto abs_a2 = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a2); auto abs_a3 = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a3); - + // 找出最大值 - auto max_abs = _mm256_max_ps(_mm256_max_ps(abs_a0, abs_a1), + auto max_abs = _mm256_max_ps(_mm256_max_ps(abs_a0, abs_a1), _mm256_max_ps(abs_a2, abs_a3)); - + // AVX水平最大值归约 // 提取高低128位并求最大值 - auto high = _mm256_extractf128_ps(max_abs, 1); - auto low = _mm256_extractf128_ps(max_abs, 0); + auto high = _mm256_extractf128_ps(max_abs, 1); + auto low = _mm256_extractf128_ps(max_abs, 0); auto max_lane = _mm_max_ps(high, low); - + // 在128位向量内进行水平最大值操作 - auto temp1 = _mm_shuffle_ps(max_lane, max_lane, _MM_SHUFFLE(2, 3, 0, 1)); - auto max1 = _mm_max_ps(max_lane, temp1); - auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); + auto temp1 = _mm_shuffle_ps(max_lane, max_lane, _MM_SHUFFLE(2, 3, 0, 1)); + auto max1 = _mm_max_ps(max_lane, temp1); + auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); auto final_max = _mm_max_ps(max1, temp2); - + // 提取水平最大值 float max_sample = _mm_cvtss_f32(final_max); - + // 计算需要的增益以限制幅度 float target_gain = max_sample > threshold ? threshold / max_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 将当前增益转换为向量 auto gain_vec = _mm256_set1_ps(current_gain); - + // 应用增益 auto result0 = _mm256_mul_ps(a0, gain_vec); auto result1 = _mm256_mul_ps(a1, gain_vec); auto result2 = _mm256_mul_ps(a2, gain_vec); auto result3 = _mm256_mul_ps(a3, gain_vec); - + // 存储结果 _mm256_store_ps(&dst[i], result0); _mm256_store_ps(&dst[i + 8], result1); _mm256_store_ps(&dst[i + 16], result2); _mm256_store_ps(&dst[i + 24], result3); } - + // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm256_load_ps(&src[i]); + auto a = _mm256_load_ps(&src[i]); auto abs_a = _mm256_andnot_ps(_mm256_set1_ps(-0.0f), a); - + // 提取高低128位并求最大值 - auto high = _mm256_extractf128_ps(abs_a, 1); - auto low = _mm256_extractf128_ps(abs_a, 0); + auto high = _mm256_extractf128_ps(abs_a, 1); + auto low = _mm256_extractf128_ps(abs_a, 0); auto max_lane = _mm_max_ps(high, low); - + // 在128位向量内进行水平最大值操作 - auto temp1 = _mm_shuffle_ps(max_lane, max_lane, _MM_SHUFFLE(2, 3, 0, 1)); - auto max1 = _mm_max_ps(max_lane, temp1); - auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); + auto temp1 = _mm_shuffle_ps(max_lane, max_lane, _MM_SHUFFLE(2, 3, 0, 1)); + auto max1 = _mm_max_ps(max_lane, temp1); + auto temp2 = _mm_shuffle_ps(max1, max1, _MM_SHUFFLE(1, 0, 3, 2)); auto final_max = _mm_max_ps(max1, temp2); - + // 提取水平最大值 float max_sample = _mm_cvtss_f32(final_max); - + // 计算需要的增益以限制幅度 float target_gain = max_sample > threshold ? threshold / max_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 应用增益 auto gain_vec = _mm256_set1_ps(current_gain); - auto result = _mm256_mul_ps(a, gain_vec); + auto result = _mm256_mul_ps(a, gain_vec); _mm256_store_ps(&dst[i], result); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { - float sample = src[i]; + float sample = src[i]; float abs_sample = std::fabs(sample); - + // 计算需要的增益以限制幅度 float target_gain = abs_sample > threshold ? threshold / abs_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 应用增益 dst[i] = sample * current_gain; } - + // 更新限幅器状态 if (limiter_state != nullptr) { *limiter_state = current_gain; } } - + // 音频限幅函数实现 (AVX512版本) - void limit_audio_avx512(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples) { + void limit_audio_avx512(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX512); ASSERT_ALIGNED(dst, ALIGNMENT_AVX512); - + // 边界情况处理 if (num_samples == 0 || threshold <= 0.0f) { return; } - - constexpr size_t simd_width = 16; // AVX512每次处理16个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - constexpr float release_time = 0.05f; // 释放时间常数(秒),可根据实际需求调整 - float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 - + + constexpr size_t simd_width = 16; // AVX512每次处理16个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + constexpr float release_time = 0.05f; // 释放时间常数(秒),可根据实际需求调整 + float release_coeff = std::exp(-1.0f / (release_time * sample_rate)); // 释放系数 + // 初始化限幅器状态,如果是首次调用,则从1.0开始 float current_gain = limiter_state != nullptr ? *limiter_state : 1.0f; - + // 阈值和释放系数向量 - const auto threshold_vec = _mm512_set1_ps(threshold); + const auto threshold_vec = _mm512_set1_ps(threshold); const auto release_coeff_vec = _mm512_set1_ps(release_coeff); - const auto one_vec = _mm512_set1_ps(1.0f); - + const auto one_vec = _mm512_set1_ps(1.0f); + size_t i = 0; - + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { // 加载4个向量 @@ -1313,551 +1321,658 @@ namespace x86_simd_audio_processing_func { auto a1 = _mm512_load_ps(&src[i + 16]); auto a2 = _mm512_load_ps(&src[i + 32]); auto a3 = _mm512_load_ps(&src[i + 48]); - + // 计算绝对值 auto abs_a0 = _mm512_abs_ps(a0); // AVX512提供直接的绝对值指令 auto abs_a1 = _mm512_abs_ps(a1); auto abs_a2 = _mm512_abs_ps(a2); auto abs_a3 = _mm512_abs_ps(a3); - + // 找出最大值 - auto max_abs = _mm512_max_ps(_mm512_max_ps(abs_a0, abs_a1), - _mm512_max_ps(abs_a2, abs_a3)); - + auto max_abs = _mm512_max_ps(_mm512_max_ps(abs_a0, abs_a1), + _mm512_max_ps(abs_a2, abs_a3)); + // AVX512水平最大值归约 float max_sample = _mm512_reduce_max_ps(max_abs); // 使用专用reduce指令 - + // 计算需要的增益以限制幅度 float target_gain = max_sample > threshold ? threshold / max_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 将当前增益转换为向量 auto gain_vec = _mm512_set1_ps(current_gain); - + // 应用增益 auto result0 = _mm512_mul_ps(a0, gain_vec); auto result1 = _mm512_mul_ps(a1, gain_vec); auto result2 = _mm512_mul_ps(a2, gain_vec); auto result3 = _mm512_mul_ps(a3, gain_vec); - + // 存储结果 _mm512_store_ps(&dst[i], result0); _mm512_store_ps(&dst[i + 16], result1); _mm512_store_ps(&dst[i + 32], result2); _mm512_store_ps(&dst[i + 48], result3); } - + // 处理剩余的向量(单次处理) for (; i + simd_width <= num_samples; i += simd_width) { - auto a = _mm512_load_ps(&src[i]); + auto a = _mm512_load_ps(&src[i]); auto abs_a = _mm512_abs_ps(a); - + // 水平最大值归约 float max_sample = _mm512_reduce_max_ps(abs_a); - + // 计算需要的增益以限制幅度 float target_gain = max_sample > threshold ? threshold / max_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 应用增益 auto gain_vec = _mm512_set1_ps(current_gain); - auto result = _mm512_mul_ps(a, gain_vec); + auto result = _mm512_mul_ps(a, gain_vec); _mm512_store_ps(&dst[i], result); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { - float sample = src[i]; + float sample = src[i]; float abs_sample = std::fabs(sample); - + // 计算需要的增益以限制幅度 float target_gain = abs_sample > threshold ? threshold / abs_sample : 1.0f; - + // 平滑增益变化(使用包络跟随器) if (target_gain < current_gain) { // 立即攻击 current_gain = target_gain; - } else { + } + else { // 缓慢释放 current_gain = target_gain + (current_gain - target_gain) * release_coeff; } - + // 应用增益 dst[i] = sample * current_gain; } - + // 更新限幅器状态 if (limiter_state != nullptr) { *limiter_state = current_gain; } } + // 音频淡入淡出函数实现 (SSE版本) - void fade_audio_sse(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples) { + void fade_audio_sse(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_SSE); ASSERT_ALIGNED(dst, ALIGNMENT_SSE); - + // 边界情况处理 if (num_samples == 0) { return; } - - constexpr size_t simd_width = 4; // SSE每次处理4个float + + constexpr size_t simd_width = 4; // SSE每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - + size_t i = 0; + // 处理淡入部分 if (fade_in_samples > 0) { - const float fade_in_step = 1.0f / static_cast(fade_in_samples); - const auto fade_in_step_vec = _mm_set1_ps(fade_in_step); - + const float fade_in_step = 1.0f / static_cast(fade_in_samples); + const auto fade_in_step_vec = _mm_set1_ps(fade_in_step); + // 向量化处理淡入(4路循环展开) - for (; i + simd_width * unroll_factor <= std::min(fade_in_samples, num_samples); i += simd_width * unroll_factor) { + for (; i + simd_width * unroll_factor <= std::min(fade_in_samples, num_samples); i += simd_width * + unroll_factor) { // 计算当前样本的淡入系数 - auto gain0 = _mm_set_ps((i + 3) * fade_in_step, (i + 2) * fade_in_step, (i + 1) * fade_in_step, i * fade_in_step); - auto gain1 = _mm_set_ps((i + 7) * fade_in_step, (i + 6) * fade_in_step, (i + 5) * fade_in_step, (i + 4) * fade_in_step); - auto gain2 = _mm_set_ps((i + 11) * fade_in_step, (i + 10) * fade_in_step, (i + 9) * fade_in_step, (i + 8) * fade_in_step); - auto gain3 = _mm_set_ps((i + 15) * fade_in_step, (i + 14) * fade_in_step, (i + 13) * fade_in_step, (i + 12) * fade_in_step); - + auto gain0 = _mm_set_ps((i + 3) * fade_in_step, (i + 2) * fade_in_step, (i + 1) * fade_in_step, + i * fade_in_step); + auto gain1 = _mm_set_ps((i + 7) * fade_in_step, (i + 6) * fade_in_step, (i + 5) * fade_in_step, + (i + 4) * fade_in_step); + auto gain2 = _mm_set_ps((i + 11) * fade_in_step, (i + 10) * fade_in_step, (i + 9) * fade_in_step, + (i + 8) * fade_in_step); + auto gain3 = _mm_set_ps((i + 15) * fade_in_step, (i + 14) * fade_in_step, (i + 13) * fade_in_step, + (i + 12) * fade_in_step); + // 加载音频样本 auto a0 = _mm_load_ps(&src[i]); auto a1 = _mm_load_ps(&src[i + 4]); auto a2 = _mm_load_ps(&src[i + 8]); auto a3 = _mm_load_ps(&src[i + 12]); - + // 应用淡入增益 auto result0 = _mm_mul_ps(a0, gain0); auto result1 = _mm_mul_ps(a1, gain1); auto result2 = _mm_mul_ps(a2, gain2); auto result3 = _mm_mul_ps(a3, gain3); - + // 存储结果 _mm_store_ps(&dst[i], result0); _mm_store_ps(&dst[i + 4], result1); _mm_store_ps(&dst[i + 8], result2); _mm_store_ps(&dst[i + 12], result3); } - + // 处理剩余的淡入样本(标量处理) for (; i < std::min(fade_in_samples, num_samples); ++i) { const float gain = static_cast(i) / static_cast(fade_in_samples); - dst[i] = src[i] * gain; + dst[i] = src[i] * gain; } } - + // 处理中间部分(无淡入淡出,直接复制) const size_t middle_start = fade_in_samples; - const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; - + const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; + if (middle_end > middle_start) { // 使用SSE优化的直接复制 - for (size_t j = middle_start; j + simd_width * unroll_factor <= middle_end; j += simd_width * unroll_factor) { + for (size_t j = middle_start; j + simd_width * unroll_factor <= middle_end; j += simd_width * + unroll_factor) { auto a0 = _mm_load_ps(&src[j]); auto a1 = _mm_load_ps(&src[j + 4]); auto a2 = _mm_load_ps(&src[j + 8]); auto a3 = _mm_load_ps(&src[j + 12]); - + _mm_store_ps(&dst[j], a0); _mm_store_ps(&dst[j + 4], a1); _mm_store_ps(&dst[j + 8], a2); _mm_store_ps(&dst[j + 12], a3); } - + // 处理剩余的中间样本(标量处理) - for (size_t j = middle_start + ((middle_end - middle_start) / (simd_width * unroll_factor)) * (simd_width * unroll_factor); j < middle_end; ++j) { + for (size_t j = middle_start + ((middle_end - middle_start) / (simd_width * unroll_factor)) * (simd_width * + unroll_factor); j < middle_end; ++j) { dst[j] = src[j]; } } - + // 处理淡出部分 if (fade_out_samples > 0 && num_samples > fade_out_samples) { const size_t fade_out_start = num_samples - fade_out_samples; - const float fade_out_step = 1.0f / static_cast(fade_out_samples); - + const float fade_out_step = 1.0f / static_cast(fade_out_samples); + // 向量化处理淡出(4路循环展开) - for (size_t j = fade_out_start; j + simd_width * unroll_factor <= num_samples; j += simd_width * unroll_factor) { + for (size_t j = fade_out_start; j + simd_width * unroll_factor <= num_samples; j += simd_width * + unroll_factor) { // 计算当前样本的淡出系数(从1递减到0) const size_t fade_out_offset = j - fade_out_start; - auto gain0 = _mm_set_ps(1.0f - (fade_out_offset + 3) * fade_out_step, 1.0f - (fade_out_offset + 2) * fade_out_step, - 1.0f - (fade_out_offset + 1) * fade_out_step, 1.0f - fade_out_offset * fade_out_step); - auto gain1 = _mm_set_ps(1.0f - (fade_out_offset + 7) * fade_out_step, 1.0f - (fade_out_offset + 6) * fade_out_step, - 1.0f - (fade_out_offset + 5) * fade_out_step, 1.0f - (fade_out_offset + 4) * fade_out_step); - auto gain2 = _mm_set_ps(1.0f - (fade_out_offset + 11) * fade_out_step, 1.0f - (fade_out_offset + 10) * fade_out_step, - 1.0f - (fade_out_offset + 9) * fade_out_step, 1.0f - (fade_out_offset + 8) * fade_out_step); - auto gain3 = _mm_set_ps(1.0f - (fade_out_offset + 15) * fade_out_step, 1.0f - (fade_out_offset + 14) * fade_out_step, - 1.0f - (fade_out_offset + 13) * fade_out_step, 1.0f - (fade_out_offset + 12) * fade_out_step); - + auto gain0 = _mm_set_ps(1.0f - (fade_out_offset + 3) * fade_out_step, + 1.0f - (fade_out_offset + 2) * fade_out_step, + 1.0f - (fade_out_offset + 1) * fade_out_step, + 1.0f - fade_out_offset * fade_out_step); + auto gain1 = _mm_set_ps(1.0f - (fade_out_offset + 7) * fade_out_step, + 1.0f - (fade_out_offset + 6) * fade_out_step, + 1.0f - (fade_out_offset + 5) * fade_out_step, + 1.0f - (fade_out_offset + 4) * fade_out_step); + auto gain2 = _mm_set_ps(1.0f - (fade_out_offset + 11) * fade_out_step, + 1.0f - (fade_out_offset + 10) * fade_out_step, + 1.0f - (fade_out_offset + 9) * fade_out_step, + 1.0f - (fade_out_offset + 8) * fade_out_step); + auto gain3 = _mm_set_ps(1.0f - (fade_out_offset + 15) * fade_out_step, + 1.0f - (fade_out_offset + 14) * fade_out_step, + 1.0f - (fade_out_offset + 13) * fade_out_step, + 1.0f - (fade_out_offset + 12) * fade_out_step); + // 加载音频样本 auto a0 = _mm_load_ps(&src[j]); auto a1 = _mm_load_ps(&src[j + 4]); auto a2 = _mm_load_ps(&src[j + 8]); auto a3 = _mm_load_ps(&src[j + 12]); - + // 应用淡出增益 auto result0 = _mm_mul_ps(a0, gain0); auto result1 = _mm_mul_ps(a1, gain1); auto result2 = _mm_mul_ps(a2, gain2); auto result3 = _mm_mul_ps(a3, gain3); - + // 存储结果 _mm_store_ps(&dst[j], result0); _mm_store_ps(&dst[j + 4], result1); _mm_store_ps(&dst[j + 8], result2); _mm_store_ps(&dst[j + 12], result3); } - + // 处理剩余的淡出样本(标量处理) - for (size_t j = fade_out_start + ((fade_out_samples / (simd_width * unroll_factor)) * (simd_width * unroll_factor)); j < num_samples; ++j) { + for (size_t j = fade_out_start + ((fade_out_samples / (simd_width * unroll_factor)) * (simd_width * + unroll_factor)); j < num_samples; ++j) { const size_t fade_out_offset = j - fade_out_start; - const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); + const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); dst[j] = src[j] * gain; } } } - + // 音频淡入淡出函数实现 (AVX版本) - void fade_audio_avx(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples) { + void fade_audio_avx(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX); ASSERT_ALIGNED(dst, ALIGNMENT_AVX); - + // 边界情况处理 if (num_samples == 0) { return; } - - constexpr size_t simd_width = 8; // AVX每次处理8个float + + constexpr size_t simd_width = 8; // AVX每次处理8个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - + size_t i = 0; + // 处理淡入部分 if (fade_in_samples > 0) { const float fade_in_step = 1.0f / static_cast(fade_in_samples); - + // 向量化处理淡入(4路循环展开) - for (; i + simd_width * unroll_factor <= std::min(fade_in_samples, num_samples); i += simd_width * unroll_factor) { + for (; i + simd_width * unroll_factor <= std::min(fade_in_samples, num_samples); i += simd_width * + unroll_factor) { // 计算当前样本的淡入系数 - auto gain0 = _mm256_set_ps((i + 7) * fade_in_step, (i + 6) * fade_in_step, (i + 5) * fade_in_step, (i + 4) * fade_in_step, - (i + 3) * fade_in_step, (i + 2) * fade_in_step, (i + 1) * fade_in_step, i * fade_in_step); - auto gain1 = _mm256_set_ps((i + 15) * fade_in_step, (i + 14) * fade_in_step, (i + 13) * fade_in_step, (i + 12) * fade_in_step, - (i + 11) * fade_in_step, (i + 10) * fade_in_step, (i + 9) * fade_in_step, (i + 8) * fade_in_step); - auto gain2 = _mm256_set_ps((i + 23) * fade_in_step, (i + 22) * fade_in_step, (i + 21) * fade_in_step, (i + 20) * fade_in_step, - (i + 19) * fade_in_step, (i + 18) * fade_in_step, (i + 17) * fade_in_step, (i + 16) * fade_in_step); - auto gain3 = _mm256_set_ps((i + 31) * fade_in_step, (i + 30) * fade_in_step, (i + 29) * fade_in_step, (i + 28) * fade_in_step, - (i + 27) * fade_in_step, (i + 26) * fade_in_step, (i + 25) * fade_in_step, (i + 24) * fade_in_step); - + auto gain0 = _mm256_set_ps((i + 7) * fade_in_step, (i + 6) * fade_in_step, (i + 5) * fade_in_step, + (i + 4) * fade_in_step, + (i + 3) * fade_in_step, (i + 2) * fade_in_step, (i + 1) * fade_in_step, + i * fade_in_step); + auto gain1 = _mm256_set_ps((i + 15) * fade_in_step, (i + 14) * fade_in_step, (i + 13) * fade_in_step, + (i + 12) * fade_in_step, + (i + 11) * fade_in_step, (i + 10) * fade_in_step, (i + 9) * fade_in_step, + (i + 8) * fade_in_step); + auto gain2 = _mm256_set_ps((i + 23) * fade_in_step, (i + 22) * fade_in_step, (i + 21) * fade_in_step, + (i + 20) * fade_in_step, + (i + 19) * fade_in_step, (i + 18) * fade_in_step, (i + 17) * fade_in_step, + (i + 16) * fade_in_step); + auto gain3 = _mm256_set_ps((i + 31) * fade_in_step, (i + 30) * fade_in_step, (i + 29) * fade_in_step, + (i + 28) * fade_in_step, + (i + 27) * fade_in_step, (i + 26) * fade_in_step, (i + 25) * fade_in_step, + (i + 24) * fade_in_step); + // 加载音频样本 auto a0 = _mm256_load_ps(&src[i]); auto a1 = _mm256_load_ps(&src[i + 8]); auto a2 = _mm256_load_ps(&src[i + 16]); auto a3 = _mm256_load_ps(&src[i + 24]); - + // 应用淡入增益 auto result0 = _mm256_mul_ps(a0, gain0); auto result1 = _mm256_mul_ps(a1, gain1); auto result2 = _mm256_mul_ps(a2, gain2); auto result3 = _mm256_mul_ps(a3, gain3); - + // 存储结果 _mm256_store_ps(&dst[i], result0); _mm256_store_ps(&dst[i + 8], result1); _mm256_store_ps(&dst[i + 16], result2); _mm256_store_ps(&dst[i + 24], result3); } - + // 处理剩余的淡入样本(标量处理) for (; i < std::min(fade_in_samples, num_samples); ++i) { const float gain = static_cast(i) / static_cast(fade_in_samples); - dst[i] = src[i] * gain; + dst[i] = src[i] * gain; } } - + // 处理中间部分(无淡入淡出,直接复制) const size_t middle_start = fade_in_samples; - const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; - + const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; + if (middle_end > middle_start) { // 使用AVX优化的直接复制 - for (size_t j = middle_start; j + simd_width * unroll_factor <= middle_end; j += simd_width * unroll_factor) { + for (size_t j = middle_start; j + simd_width * unroll_factor <= middle_end; j += simd_width * + unroll_factor) { auto a0 = _mm256_load_ps(&src[j]); auto a1 = _mm256_load_ps(&src[j + 8]); auto a2 = _mm256_load_ps(&src[j + 16]); auto a3 = _mm256_load_ps(&src[j + 24]); - + _mm256_store_ps(&dst[j], a0); _mm256_store_ps(&dst[j + 8], a1); _mm256_store_ps(&dst[j + 16], a2); _mm256_store_ps(&dst[j + 24], a3); } - + // 处理剩余的中间样本(标量处理) - for (size_t j = middle_start + ((middle_end - middle_start) / (simd_width * unroll_factor)) * (simd_width * unroll_factor); j < middle_end; ++j) { + for (size_t j = middle_start + ((middle_end - middle_start) / (simd_width * unroll_factor)) * (simd_width * + unroll_factor); j < middle_end; ++j) { dst[j] = src[j]; } } - + // 处理淡出部分 if (fade_out_samples > 0 && num_samples > fade_out_samples) { const size_t fade_out_start = num_samples - fade_out_samples; - const float fade_out_step = 1.0f / static_cast(fade_out_samples); - + const float fade_out_step = 1.0f / static_cast(fade_out_samples); + // 向量化处理淡出(4路循环展开) - for (size_t j = fade_out_start; j + simd_width * unroll_factor <= num_samples; j += simd_width * unroll_factor) { + for (size_t j = fade_out_start; j + simd_width * unroll_factor <= num_samples; j += simd_width * + unroll_factor) { // 计算当前样本的淡出系数(从1递减到0) const size_t fade_out_offset = j - fade_out_start; - auto gain0 = _mm256_set_ps(1.0f - (fade_out_offset + 7) * fade_out_step, 1.0f - (fade_out_offset + 6) * fade_out_step, - 1.0f - (fade_out_offset + 5) * fade_out_step, 1.0f - (fade_out_offset + 4) * fade_out_step, - 1.0f - (fade_out_offset + 3) * fade_out_step, 1.0f - (fade_out_offset + 2) * fade_out_step, - 1.0f - (fade_out_offset + 1) * fade_out_step, 1.0f - fade_out_offset * fade_out_step); - auto gain1 = _mm256_set_ps(1.0f - (fade_out_offset + 15) * fade_out_step, 1.0f - (fade_out_offset + 14) * fade_out_step, - 1.0f - (fade_out_offset + 13) * fade_out_step, 1.0f - (fade_out_offset + 12) * fade_out_step, - 1.0f - (fade_out_offset + 11) * fade_out_step, 1.0f - (fade_out_offset + 10) * fade_out_step, - 1.0f - (fade_out_offset + 9) * fade_out_step, 1.0f - (fade_out_offset + 8) * fade_out_step); - auto gain2 = _mm256_set_ps(1.0f - (fade_out_offset + 23) * fade_out_step, 1.0f - (fade_out_offset + 22) * fade_out_step, - 1.0f - (fade_out_offset + 21) * fade_out_step, 1.0f - (fade_out_offset + 20) * fade_out_step, - 1.0f - (fade_out_offset + 19) * fade_out_step, 1.0f - (fade_out_offset + 18) * fade_out_step, - 1.0f - (fade_out_offset + 17) * fade_out_step, 1.0f - (fade_out_offset + 16) * fade_out_step); - auto gain3 = _mm256_set_ps(1.0f - (fade_out_offset + 31) * fade_out_step, 1.0f - (fade_out_offset + 30) * fade_out_step, - 1.0f - (fade_out_offset + 29) * fade_out_step, 1.0f - (fade_out_offset + 28) * fade_out_step, - 1.0f - (fade_out_offset + 27) * fade_out_step, 1.0f - (fade_out_offset + 26) * fade_out_step, - 1.0f - (fade_out_offset + 25) * fade_out_step, 1.0f - (fade_out_offset + 24) * fade_out_step); - + auto gain0 = _mm256_set_ps(1.0f - (fade_out_offset + 7) * fade_out_step, + 1.0f - (fade_out_offset + 6) * fade_out_step, + 1.0f - (fade_out_offset + 5) * fade_out_step, + 1.0f - (fade_out_offset + 4) * fade_out_step, + 1.0f - (fade_out_offset + 3) * fade_out_step, + 1.0f - (fade_out_offset + 2) * fade_out_step, + 1.0f - (fade_out_offset + 1) * fade_out_step, + 1.0f - fade_out_offset * fade_out_step); + auto gain1 = _mm256_set_ps(1.0f - (fade_out_offset + 15) * fade_out_step, + 1.0f - (fade_out_offset + 14) * fade_out_step, + 1.0f - (fade_out_offset + 13) * fade_out_step, + 1.0f - (fade_out_offset + 12) * fade_out_step, + 1.0f - (fade_out_offset + 11) * fade_out_step, + 1.0f - (fade_out_offset + 10) * fade_out_step, + 1.0f - (fade_out_offset + 9) * fade_out_step, + 1.0f - (fade_out_offset + 8) * fade_out_step); + auto gain2 = _mm256_set_ps(1.0f - (fade_out_offset + 23) * fade_out_step, + 1.0f - (fade_out_offset + 22) * fade_out_step, + 1.0f - (fade_out_offset + 21) * fade_out_step, + 1.0f - (fade_out_offset + 20) * fade_out_step, + 1.0f - (fade_out_offset + 19) * fade_out_step, + 1.0f - (fade_out_offset + 18) * fade_out_step, + 1.0f - (fade_out_offset + 17) * fade_out_step, + 1.0f - (fade_out_offset + 16) * fade_out_step); + auto gain3 = _mm256_set_ps(1.0f - (fade_out_offset + 31) * fade_out_step, + 1.0f - (fade_out_offset + 30) * fade_out_step, + 1.0f - (fade_out_offset + 29) * fade_out_step, + 1.0f - (fade_out_offset + 28) * fade_out_step, + 1.0f - (fade_out_offset + 27) * fade_out_step, + 1.0f - (fade_out_offset + 26) * fade_out_step, + 1.0f - (fade_out_offset + 25) * fade_out_step, + 1.0f - (fade_out_offset + 24) * fade_out_step); + // 加载音频样本 auto a0 = _mm256_load_ps(&src[j]); auto a1 = _mm256_load_ps(&src[j + 8]); auto a2 = _mm256_load_ps(&src[j + 16]); auto a3 = _mm256_load_ps(&src[j + 24]); - + // 应用淡出增益 auto result0 = _mm256_mul_ps(a0, gain0); auto result1 = _mm256_mul_ps(a1, gain1); auto result2 = _mm256_mul_ps(a2, gain2); auto result3 = _mm256_mul_ps(a3, gain3); - + // 存储结果 _mm256_store_ps(&dst[j], result0); _mm256_store_ps(&dst[j + 8], result1); _mm256_store_ps(&dst[j + 16], result2); _mm256_store_ps(&dst[j + 24], result3); } - + // 处理剩余的淡出样本(标量处理) - for (size_t j = fade_out_start + ((fade_out_samples / (simd_width * unroll_factor)) * (simd_width * unroll_factor)); j < num_samples; ++j) { + for (size_t j = fade_out_start + ((fade_out_samples / (simd_width * unroll_factor)) * (simd_width * + unroll_factor)); j < num_samples; ++j) { const size_t fade_out_offset = j - fade_out_start; - const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); + const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); dst[j] = src[j] * gain; } } } - + // 音频淡入淡出函数实现 (AVX512版本) - void fade_audio_avx512(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples) { + void fade_audio_avx512(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX512); ASSERT_ALIGNED(dst, ALIGNMENT_AVX512); - + // 边界情况处理 if (num_samples == 0) { return; } - - constexpr size_t simd_width = 16; // AVX512每次处理16个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - size_t i = 0; - + + constexpr size_t simd_width = 16; // AVX512每次处理16个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + size_t i = 0; + // 处理淡入部分 if (fade_in_samples > 0) { const float fade_in_step = 1.0f / static_cast(fade_in_samples); - + // 向量化处理淡入(4路循环展开) - for (; i + simd_width * unroll_factor <= std::min(fade_in_samples, num_samples); i += simd_width * unroll_factor) { + for (; i + simd_width * unroll_factor <= std::min(fade_in_samples, num_samples); i += simd_width * + unroll_factor) { // 计算当前样本的淡入系数 - auto gain0 = _mm512_set_ps((i + 15) * fade_in_step, (i + 14) * fade_in_step, (i + 13) * fade_in_step, (i + 12) * fade_in_step, - (i + 11) * fade_in_step, (i + 10) * fade_in_step, (i + 9) * fade_in_step, (i + 8) * fade_in_step, - (i + 7) * fade_in_step, (i + 6) * fade_in_step, (i + 5) * fade_in_step, (i + 4) * fade_in_step, - (i + 3) * fade_in_step, (i + 2) * fade_in_step, (i + 1) * fade_in_step, i * fade_in_step); - auto gain1 = _mm512_set_ps((i + 31) * fade_in_step, (i + 30) * fade_in_step, (i + 29) * fade_in_step, (i + 28) * fade_in_step, - (i + 27) * fade_in_step, (i + 26) * fade_in_step, (i + 25) * fade_in_step, (i + 24) * fade_in_step, - (i + 23) * fade_in_step, (i + 22) * fade_in_step, (i + 21) * fade_in_step, (i + 20) * fade_in_step, - (i + 19) * fade_in_step, (i + 18) * fade_in_step, (i + 17) * fade_in_step, (i + 16) * fade_in_step); - auto gain2 = _mm512_set_ps((i + 47) * fade_in_step, (i + 46) * fade_in_step, (i + 45) * fade_in_step, (i + 44) * fade_in_step, - (i + 43) * fade_in_step, (i + 42) * fade_in_step, (i + 41) * fade_in_step, (i + 40) * fade_in_step, - (i + 39) * fade_in_step, (i + 38) * fade_in_step, (i + 37) * fade_in_step, (i + 36) * fade_in_step, - (i + 35) * fade_in_step, (i + 34) * fade_in_step, (i + 33) * fade_in_step, (i + 32) * fade_in_step); - auto gain3 = _mm512_set_ps((i + 63) * fade_in_step, (i + 62) * fade_in_step, (i + 61) * fade_in_step, (i + 60) * fade_in_step, - (i + 59) * fade_in_step, (i + 58) * fade_in_step, (i + 57) * fade_in_step, (i + 56) * fade_in_step, - (i + 55) * fade_in_step, (i + 54) * fade_in_step, (i + 53) * fade_in_step, (i + 52) * fade_in_step, - (i + 51) * fade_in_step, (i + 50) * fade_in_step, (i + 49) * fade_in_step, (i + 48) * fade_in_step); - + auto gain0 = _mm512_set_ps((i + 15) * fade_in_step, (i + 14) * fade_in_step, (i + 13) * fade_in_step, + (i + 12) * fade_in_step, + (i + 11) * fade_in_step, (i + 10) * fade_in_step, (i + 9) * fade_in_step, + (i + 8) * fade_in_step, + (i + 7) * fade_in_step, (i + 6) * fade_in_step, (i + 5) * fade_in_step, + (i + 4) * fade_in_step, + (i + 3) * fade_in_step, (i + 2) * fade_in_step, (i + 1) * fade_in_step, + i * fade_in_step); + auto gain1 = _mm512_set_ps((i + 31) * fade_in_step, (i + 30) * fade_in_step, (i + 29) * fade_in_step, + (i + 28) * fade_in_step, + (i + 27) * fade_in_step, (i + 26) * fade_in_step, (i + 25) * fade_in_step, + (i + 24) * fade_in_step, + (i + 23) * fade_in_step, (i + 22) * fade_in_step, (i + 21) * fade_in_step, + (i + 20) * fade_in_step, + (i + 19) * fade_in_step, (i + 18) * fade_in_step, (i + 17) * fade_in_step, + (i + 16) * fade_in_step); + auto gain2 = _mm512_set_ps((i + 47) * fade_in_step, (i + 46) * fade_in_step, (i + 45) * fade_in_step, + (i + 44) * fade_in_step, + (i + 43) * fade_in_step, (i + 42) * fade_in_step, (i + 41) * fade_in_step, + (i + 40) * fade_in_step, + (i + 39) * fade_in_step, (i + 38) * fade_in_step, (i + 37) * fade_in_step, + (i + 36) * fade_in_step, + (i + 35) * fade_in_step, (i + 34) * fade_in_step, (i + 33) * fade_in_step, + (i + 32) * fade_in_step); + auto gain3 = _mm512_set_ps((i + 63) * fade_in_step, (i + 62) * fade_in_step, (i + 61) * fade_in_step, + (i + 60) * fade_in_step, + (i + 59) * fade_in_step, (i + 58) * fade_in_step, (i + 57) * fade_in_step, + (i + 56) * fade_in_step, + (i + 55) * fade_in_step, (i + 54) * fade_in_step, (i + 53) * fade_in_step, + (i + 52) * fade_in_step, + (i + 51) * fade_in_step, (i + 50) * fade_in_step, (i + 49) * fade_in_step, + (i + 48) * fade_in_step); + // 加载音频样本 auto a0 = _mm512_load_ps(&src[i]); auto a1 = _mm512_load_ps(&src[i + 16]); auto a2 = _mm512_load_ps(&src[i + 32]); auto a3 = _mm512_load_ps(&src[i + 48]); - + // 应用淡入增益 auto result0 = _mm512_mul_ps(a0, gain0); auto result1 = _mm512_mul_ps(a1, gain1); auto result2 = _mm512_mul_ps(a2, gain2); auto result3 = _mm512_mul_ps(a3, gain3); - + // 存储结果 _mm512_store_ps(&dst[i], result0); _mm512_store_ps(&dst[i + 16], result1); _mm512_store_ps(&dst[i + 32], result2); _mm512_store_ps(&dst[i + 48], result3); } - + // 处理剩余的淡入样本(标量处理) for (; i < std::min(fade_in_samples, num_samples); ++i) { const float gain = static_cast(i) / static_cast(fade_in_samples); - dst[i] = src[i] * gain; + dst[i] = src[i] * gain; } } - + // 处理中间部分(无淡入淡出,直接复制) const size_t middle_start = fade_in_samples; - const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; - + const size_t middle_end = num_samples > fade_out_samples ? num_samples - fade_out_samples : 0; + if (middle_end > middle_start) { // 使用AVX512优化的直接复制 - for (size_t j = middle_start; j + simd_width * unroll_factor <= middle_end; j += simd_width * unroll_factor) { + for (size_t j = middle_start; j + simd_width * unroll_factor <= middle_end; j += simd_width * + unroll_factor) { auto a0 = _mm512_load_ps(&src[j]); auto a1 = _mm512_load_ps(&src[j + 16]); auto a2 = _mm512_load_ps(&src[j + 32]); auto a3 = _mm512_load_ps(&src[j + 48]); - + _mm512_store_ps(&dst[j], a0); _mm512_store_ps(&dst[j + 16], a1); _mm512_store_ps(&dst[j + 32], a2); _mm512_store_ps(&dst[j + 48], a3); } - + // 处理剩余的中间样本(标量处理) - for (size_t j = middle_start + ((middle_end - middle_start) / (simd_width * unroll_factor)) * (simd_width * unroll_factor); j < middle_end; ++j) { + for (size_t j = middle_start + ((middle_end - middle_start) / (simd_width * unroll_factor)) * (simd_width * + unroll_factor); j < middle_end; ++j) { dst[j] = src[j]; } } - + // 处理淡出部分 if (fade_out_samples > 0 && num_samples > fade_out_samples) { const size_t fade_out_start = num_samples - fade_out_samples; - const float fade_out_step = 1.0f / static_cast(fade_out_samples); - + const float fade_out_step = 1.0f / static_cast(fade_out_samples); + // 向量化处理淡出(4路循环展开) - for (size_t j = fade_out_start; j + simd_width * unroll_factor <= num_samples; j += simd_width * unroll_factor) { + for (size_t j = fade_out_start; j + simd_width * unroll_factor <= num_samples; j += simd_width * + unroll_factor) { // 计算当前样本的淡出系数(从1递减到0) const size_t fade_out_offset = j - fade_out_start; - auto gain0 = _mm512_set_ps(1.0f - (fade_out_offset + 15) * fade_out_step, 1.0f - (fade_out_offset + 14) * fade_out_step, - 1.0f - (fade_out_offset + 13) * fade_out_step, 1.0f - (fade_out_offset + 12) * fade_out_step, - 1.0f - (fade_out_offset + 11) * fade_out_step, 1.0f - (fade_out_offset + 10) * fade_out_step, - 1.0f - (fade_out_offset + 9) * fade_out_step, 1.0f - (fade_out_offset + 8) * fade_out_step, - 1.0f - (fade_out_offset + 7) * fade_out_step, 1.0f - (fade_out_offset + 6) * fade_out_step, - 1.0f - (fade_out_offset + 5) * fade_out_step, 1.0f - (fade_out_offset + 4) * fade_out_step, - 1.0f - (fade_out_offset + 3) * fade_out_step, 1.0f - (fade_out_offset + 2) * fade_out_step, - 1.0f - (fade_out_offset + 1) * fade_out_step, 1.0f - fade_out_offset * fade_out_step); - auto gain1 = _mm512_set_ps(1.0f - (fade_out_offset + 31) * fade_out_step, 1.0f - (fade_out_offset + 30) * fade_out_step, - 1.0f - (fade_out_offset + 29) * fade_out_step, 1.0f - (fade_out_offset + 28) * fade_out_step, - 1.0f - (fade_out_offset + 27) * fade_out_step, 1.0f - (fade_out_offset + 26) * fade_out_step, - 1.0f - (fade_out_offset + 25) * fade_out_step, 1.0f - (fade_out_offset + 24) * fade_out_step, - 1.0f - (fade_out_offset + 23) * fade_out_step, 1.0f - (fade_out_offset + 22) * fade_out_step, - 1.0f - (fade_out_offset + 21) * fade_out_step, 1.0f - (fade_out_offset + 20) * fade_out_step, - 1.0f - (fade_out_offset + 19) * fade_out_step, 1.0f - (fade_out_offset + 18) * fade_out_step, - 1.0f - (fade_out_offset + 17) * fade_out_step, 1.0f - (fade_out_offset + 16) * fade_out_step); - auto gain2 = _mm512_set_ps(1.0f - (fade_out_offset + 47) * fade_out_step, 1.0f - (fade_out_offset + 46) * fade_out_step, - 1.0f - (fade_out_offset + 45) * fade_out_step, 1.0f - (fade_out_offset + 44) * fade_out_step, - 1.0f - (fade_out_offset + 43) * fade_out_step, 1.0f - (fade_out_offset + 42) * fade_out_step, - 1.0f - (fade_out_offset + 41) * fade_out_step, 1.0f - (fade_out_offset + 40) * fade_out_step, - 1.0f - (fade_out_offset + 39) * fade_out_step, 1.0f - (fade_out_offset + 38) * fade_out_step, - 1.0f - (fade_out_offset + 37) * fade_out_step, 1.0f - (fade_out_offset + 36) * fade_out_step, - 1.0f - (fade_out_offset + 35) * fade_out_step, 1.0f - (fade_out_offset + 34) * fade_out_step, - 1.0f - (fade_out_offset + 33) * fade_out_step, 1.0f - (fade_out_offset + 32) * fade_out_step); - auto gain3 = _mm512_set_ps(1.0f - (fade_out_offset + 63) * fade_out_step, 1.0f - (fade_out_offset + 62) * fade_out_step, - 1.0f - (fade_out_offset + 61) * fade_out_step, 1.0f - (fade_out_offset + 60) * fade_out_step, - 1.0f - (fade_out_offset + 59) * fade_out_step, 1.0f - (fade_out_offset + 58) * fade_out_step, - 1.0f - (fade_out_offset + 57) * fade_out_step, 1.0f - (fade_out_offset + 56) * fade_out_step, - 1.0f - (fade_out_offset + 55) * fade_out_step, 1.0f - (fade_out_offset + 54) * fade_out_step, - 1.0f - (fade_out_offset + 53) * fade_out_step, 1.0f - (fade_out_offset + 52) * fade_out_step, - 1.0f - (fade_out_offset + 51) * fade_out_step, 1.0f - (fade_out_offset + 50) * fade_out_step, - 1.0f - (fade_out_offset + 49) * fade_out_step, 1.0f - (fade_out_offset + 48) * fade_out_step); - + auto gain0 = _mm512_set_ps(1.0f - (fade_out_offset + 15) * fade_out_step, + 1.0f - (fade_out_offset + 14) * fade_out_step, + 1.0f - (fade_out_offset + 13) * fade_out_step, + 1.0f - (fade_out_offset + 12) * fade_out_step, + 1.0f - (fade_out_offset + 11) * fade_out_step, + 1.0f - (fade_out_offset + 10) * fade_out_step, + 1.0f - (fade_out_offset + 9) * fade_out_step, + 1.0f - (fade_out_offset + 8) * fade_out_step, + 1.0f - (fade_out_offset + 7) * fade_out_step, + 1.0f - (fade_out_offset + 6) * fade_out_step, + 1.0f - (fade_out_offset + 5) * fade_out_step, + 1.0f - (fade_out_offset + 4) * fade_out_step, + 1.0f - (fade_out_offset + 3) * fade_out_step, + 1.0f - (fade_out_offset + 2) * fade_out_step, + 1.0f - (fade_out_offset + 1) * fade_out_step, + 1.0f - fade_out_offset * fade_out_step); + auto gain1 = _mm512_set_ps(1.0f - (fade_out_offset + 31) * fade_out_step, + 1.0f - (fade_out_offset + 30) * fade_out_step, + 1.0f - (fade_out_offset + 29) * fade_out_step, + 1.0f - (fade_out_offset + 28) * fade_out_step, + 1.0f - (fade_out_offset + 27) * fade_out_step, + 1.0f - (fade_out_offset + 26) * fade_out_step, + 1.0f - (fade_out_offset + 25) * fade_out_step, + 1.0f - (fade_out_offset + 24) * fade_out_step, + 1.0f - (fade_out_offset + 23) * fade_out_step, + 1.0f - (fade_out_offset + 22) * fade_out_step, + 1.0f - (fade_out_offset + 21) * fade_out_step, + 1.0f - (fade_out_offset + 20) * fade_out_step, + 1.0f - (fade_out_offset + 19) * fade_out_step, + 1.0f - (fade_out_offset + 18) * fade_out_step, + 1.0f - (fade_out_offset + 17) * fade_out_step, + 1.0f - (fade_out_offset + 16) * fade_out_step); + auto gain2 = _mm512_set_ps(1.0f - (fade_out_offset + 47) * fade_out_step, + 1.0f - (fade_out_offset + 46) * fade_out_step, + 1.0f - (fade_out_offset + 45) * fade_out_step, + 1.0f - (fade_out_offset + 44) * fade_out_step, + 1.0f - (fade_out_offset + 43) * fade_out_step, + 1.0f - (fade_out_offset + 42) * fade_out_step, + 1.0f - (fade_out_offset + 41) * fade_out_step, + 1.0f - (fade_out_offset + 40) * fade_out_step, + 1.0f - (fade_out_offset + 39) * fade_out_step, + 1.0f - (fade_out_offset + 38) * fade_out_step, + 1.0f - (fade_out_offset + 37) * fade_out_step, + 1.0f - (fade_out_offset + 36) * fade_out_step, + 1.0f - (fade_out_offset + 35) * fade_out_step, + 1.0f - (fade_out_offset + 34) * fade_out_step, + 1.0f - (fade_out_offset + 33) * fade_out_step, + 1.0f - (fade_out_offset + 32) * fade_out_step); + auto gain3 = _mm512_set_ps(1.0f - (fade_out_offset + 63) * fade_out_step, + 1.0f - (fade_out_offset + 62) * fade_out_step, + 1.0f - (fade_out_offset + 61) * fade_out_step, + 1.0f - (fade_out_offset + 60) * fade_out_step, + 1.0f - (fade_out_offset + 59) * fade_out_step, + 1.0f - (fade_out_offset + 58) * fade_out_step, + 1.0f - (fade_out_offset + 57) * fade_out_step, + 1.0f - (fade_out_offset + 56) * fade_out_step, + 1.0f - (fade_out_offset + 55) * fade_out_step, + 1.0f - (fade_out_offset + 54) * fade_out_step, + 1.0f - (fade_out_offset + 53) * fade_out_step, + 1.0f - (fade_out_offset + 52) * fade_out_step, + 1.0f - (fade_out_offset + 51) * fade_out_step, + 1.0f - (fade_out_offset + 50) * fade_out_step, + 1.0f - (fade_out_offset + 49) * fade_out_step, + 1.0f - (fade_out_offset + 48) * fade_out_step); + // 加载音频样本 auto a0 = _mm512_load_ps(&src[j]); auto a1 = _mm512_load_ps(&src[j + 16]); auto a2 = _mm512_load_ps(&src[j + 32]); auto a3 = _mm512_load_ps(&src[j + 48]); - + // 应用淡出增益 auto result0 = _mm512_mul_ps(a0, gain0); auto result1 = _mm512_mul_ps(a1, gain1); auto result2 = _mm512_mul_ps(a2, gain2); auto result3 = _mm512_mul_ps(a3, gain3); - + // 存储结果 _mm512_store_ps(&dst[j], result0); _mm512_store_ps(&dst[j + 16], result1); _mm512_store_ps(&dst[j + 32], result2); _mm512_store_ps(&dst[j + 48], result3); } - + // 处理剩余的淡出样本(标量处理) - for (size_t j = fade_out_start + ((fade_out_samples / (simd_width * unroll_factor)) * (simd_width * unroll_factor)); j < num_samples; ++j) { + for (size_t j = fade_out_start + ((fade_out_samples / (simd_width * unroll_factor)) * (simd_width * + unroll_factor)); j < num_samples; ++j) { const size_t fade_out_offset = j - fade_out_start; - const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); + const float gain = 1.0f - static_cast(fade_out_offset) / static_cast(fade_out_samples); dst[j] = src[j] * gain; } } } - + // 简单均衡器函数实现 (SSE版本) - void simple_eq_sse(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples) { + void simple_eq_sse(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_SSE); ASSERT_ALIGNED(dst, ALIGNMENT_SSE); - + // 边界情况处理 if (num_samples == 0) { return; } - - constexpr size_t simd_width = 4; // SSE每次处理4个float + + constexpr size_t simd_width = 4; // SSE每次处理4个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - + // 简化的频率分割系数 // 低频cutoff约为500Hz,高频cutoff约为5000Hz (假设44.1kHz采样率) - constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 - constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 - constexpr float mid_factor = 0.7f; // 中频保持系数 - + constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 + constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 + constexpr float mid_factor = 0.7f; // 中频保持系数 + // 初始化EQ状态 - float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; + float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; float high_state = eq_state != nullptr ? eq_state[1] : 0.0f; - + // 创建增益向量 - const auto low_gain_vec = _mm_set1_ps(low_gain); - const auto mid_gain_vec = _mm_set1_ps(mid_gain); - const auto high_gain_vec = _mm_set1_ps(high_gain); - const auto low_cutoff_vec = _mm_set1_ps(low_cutoff); - const auto high_cutoff_vec = _mm_set1_ps(high_cutoff); - const auto mid_factor_vec = _mm_set1_ps(mid_factor); - const auto one_minus_low_cutoff_vec = _mm_set1_ps(1.0f - low_cutoff); + const auto low_gain_vec = _mm_set1_ps(low_gain); + const auto mid_gain_vec = _mm_set1_ps(mid_gain); + const auto high_gain_vec = _mm_set1_ps(high_gain); + const auto low_cutoff_vec = _mm_set1_ps(low_cutoff); + const auto high_cutoff_vec = _mm_set1_ps(high_cutoff); + const auto mid_factor_vec = _mm_set1_ps(mid_factor); + const auto one_minus_low_cutoff_vec = _mm_set1_ps(1.0f - low_cutoff); const auto one_minus_high_cutoff_vec = _mm_set1_ps(1.0f - high_cutoff); - + size_t i = 0; - + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { // 加载4个向量 @@ -1865,113 +1980,120 @@ namespace x86_simd_audio_processing_func { auto input1 = _mm_load_ps(&src[i + 4]); auto input2 = _mm_load_ps(&src[i + 8]); auto input3 = _mm_load_ps(&src[i + 12]); - + // 简化的低通滤波器实现(一阶IIR) // low_output = low_cutoff * input + (1 - low_cutoff) * low_state auto low_state_vec = _mm_set1_ps(low_state); - auto low0 = _mm_add_ps(_mm_mul_ps(input0, low_cutoff_vec), _mm_mul_ps(low_state_vec, one_minus_low_cutoff_vec)); + auto low0 = _mm_add_ps(_mm_mul_ps(input0, low_cutoff_vec), + _mm_mul_ps(low_state_vec, one_minus_low_cutoff_vec)); auto low1 = _mm_add_ps(_mm_mul_ps(input1, low_cutoff_vec), _mm_mul_ps(low0, one_minus_low_cutoff_vec)); auto low2 = _mm_add_ps(_mm_mul_ps(input2, low_cutoff_vec), _mm_mul_ps(low1, one_minus_low_cutoff_vec)); auto low3 = _mm_add_ps(_mm_mul_ps(input3, low_cutoff_vec), _mm_mul_ps(low2, one_minus_low_cutoff_vec)); - + // 简化的高通滤波器实现(输入减去低通) // high_output = input - low_output auto high0 = _mm_sub_ps(input0, low0); auto high1 = _mm_sub_ps(input1, low1); auto high2 = _mm_sub_ps(input2, low2); auto high3 = _mm_sub_ps(input3, low3); - + // 进一步高频处理 auto high_state_vec = _mm_set1_ps(high_state); - high0 = _mm_add_ps(_mm_mul_ps(high0, high_cutoff_vec), _mm_mul_ps(high_state_vec, one_minus_high_cutoff_vec)); + high0 = _mm_add_ps(_mm_mul_ps(high0, high_cutoff_vec), + _mm_mul_ps(high_state_vec, one_minus_high_cutoff_vec)); high1 = _mm_add_ps(_mm_mul_ps(high1, high_cutoff_vec), _mm_mul_ps(high0, one_minus_high_cutoff_vec)); high2 = _mm_add_ps(_mm_mul_ps(high2, high_cutoff_vec), _mm_mul_ps(high1, one_minus_high_cutoff_vec)); high3 = _mm_add_ps(_mm_mul_ps(high3, high_cutoff_vec), _mm_mul_ps(high2, one_minus_high_cutoff_vec)); - + // 中频:原始信号减去低频和高频 auto mid0 = _mm_mul_ps(_mm_sub_ps(_mm_sub_ps(input0, low0), high0), mid_factor_vec); auto mid1 = _mm_mul_ps(_mm_sub_ps(_mm_sub_ps(input1, low1), high1), mid_factor_vec); auto mid2 = _mm_mul_ps(_mm_sub_ps(_mm_sub_ps(input2, low2), high2), mid_factor_vec); auto mid3 = _mm_mul_ps(_mm_sub_ps(_mm_sub_ps(input3, low3), high3), mid_factor_vec); - + // 应用增益并混合 - auto result0 = _mm_add_ps(_mm_add_ps(_mm_mul_ps(low0, low_gain_vec), _mm_mul_ps(mid0, mid_gain_vec)), _mm_mul_ps(high0, high_gain_vec)); - auto result1 = _mm_add_ps(_mm_add_ps(_mm_mul_ps(low1, low_gain_vec), _mm_mul_ps(mid1, mid_gain_vec)), _mm_mul_ps(high1, high_gain_vec)); - auto result2 = _mm_add_ps(_mm_add_ps(_mm_mul_ps(low2, low_gain_vec), _mm_mul_ps(mid2, mid_gain_vec)), _mm_mul_ps(high2, high_gain_vec)); - auto result3 = _mm_add_ps(_mm_add_ps(_mm_mul_ps(low3, low_gain_vec), _mm_mul_ps(mid3, mid_gain_vec)), _mm_mul_ps(high3, high_gain_vec)); - + auto result0 = _mm_add_ps(_mm_add_ps(_mm_mul_ps(low0, low_gain_vec), _mm_mul_ps(mid0, mid_gain_vec)), + _mm_mul_ps(high0, high_gain_vec)); + auto result1 = _mm_add_ps(_mm_add_ps(_mm_mul_ps(low1, low_gain_vec), _mm_mul_ps(mid1, mid_gain_vec)), + _mm_mul_ps(high1, high_gain_vec)); + auto result2 = _mm_add_ps(_mm_add_ps(_mm_mul_ps(low2, low_gain_vec), _mm_mul_ps(mid2, mid_gain_vec)), + _mm_mul_ps(high2, high_gain_vec)); + auto result3 = _mm_add_ps(_mm_add_ps(_mm_mul_ps(low3, low_gain_vec), _mm_mul_ps(mid3, mid_gain_vec)), + _mm_mul_ps(high3, high_gain_vec)); + // 存储结果 _mm_store_ps(&dst[i], result0); _mm_store_ps(&dst[i + 4], result1); _mm_store_ps(&dst[i + 8], result2); _mm_store_ps(&dst[i + 12], result3); - + // 更新状态(使用最后一个元素) - low_state = _mm_cvtss_f32(_mm_shuffle_ps(low3, low3, _MM_SHUFFLE(3, 3, 3, 3))); + low_state = _mm_cvtss_f32(_mm_shuffle_ps(low3, low3, _MM_SHUFFLE(3, 3, 3, 3))); high_state = _mm_cvtss_f32(_mm_shuffle_ps(high3, high3, _MM_SHUFFLE(3, 3, 3, 3))); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { float input = src[i]; - + // 低通滤波器 float low_output = low_cutoff * input + (1.0f - low_cutoff) * low_state; - low_state = low_output; - + low_state = low_output; + // 高通滤波器 - float high_input = input - low_output; + float high_input = input - low_output; float high_output = high_cutoff * high_input + (1.0f - high_cutoff) * high_state; - high_state = high_output; - + high_state = high_output; + // 中频 float mid_output = (input - low_output - high_output) * mid_factor; - + // 混合并应用增益 dst[i] = low_output * low_gain + mid_output * mid_gain + high_output * high_gain; } - + // 更新EQ状态 if (eq_state != nullptr) { eq_state[0] = low_state; eq_state[1] = high_state; } } - + // 简单均衡器函数实现 (AVX版本) - void simple_eq_avx(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples) { + void simple_eq_avx(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, + size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX); ASSERT_ALIGNED(dst, ALIGNMENT_AVX); - + // 边界情况处理 if (num_samples == 0) { return; } - - constexpr size_t simd_width = 8; // AVX每次处理8个float + + constexpr size_t simd_width = 8; // AVX每次处理8个float constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - + // 简化的频率分割系数 - constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 - constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 - constexpr float mid_factor = 0.7f; // 中频保持系数 - + constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 + constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 + constexpr float mid_factor = 0.7f; // 中频保持系数 + // 初始化EQ状态 - float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; + float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; float high_state = eq_state != nullptr ? eq_state[1] : 0.0f; - + // 创建增益向量 - const auto low_gain_vec = _mm256_set1_ps(low_gain); - const auto mid_gain_vec = _mm256_set1_ps(mid_gain); - const auto high_gain_vec = _mm256_set1_ps(high_gain); - const auto low_cutoff_vec = _mm256_set1_ps(low_cutoff); - const auto high_cutoff_vec = _mm256_set1_ps(high_cutoff); - const auto mid_factor_vec = _mm256_set1_ps(mid_factor); - const auto one_minus_low_cutoff_vec = _mm256_set1_ps(1.0f - low_cutoff); + const auto low_gain_vec = _mm256_set1_ps(low_gain); + const auto mid_gain_vec = _mm256_set1_ps(mid_gain); + const auto high_gain_vec = _mm256_set1_ps(high_gain); + const auto low_cutoff_vec = _mm256_set1_ps(low_cutoff); + const auto high_cutoff_vec = _mm256_set1_ps(high_cutoff); + const auto mid_factor_vec = _mm256_set1_ps(mid_factor); + const auto one_minus_low_cutoff_vec = _mm256_set1_ps(1.0f - low_cutoff); const auto one_minus_high_cutoff_vec = _mm256_set1_ps(1.0f - high_cutoff); - + size_t i = 0; - + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { // 加载4个向量 @@ -1979,113 +2101,130 @@ namespace x86_simd_audio_processing_func { auto input1 = _mm256_load_ps(&src[i + 8]); auto input2 = _mm256_load_ps(&src[i + 16]); auto input3 = _mm256_load_ps(&src[i + 24]); - + // 简化的低通滤波器实现(一阶IIR) auto low_state_vec = _mm256_set1_ps(low_state); - auto low0 = _mm256_add_ps(_mm256_mul_ps(input0, low_cutoff_vec), _mm256_mul_ps(low_state_vec, one_minus_low_cutoff_vec)); - auto low1 = _mm256_add_ps(_mm256_mul_ps(input1, low_cutoff_vec), _mm256_mul_ps(low0, one_minus_low_cutoff_vec)); - auto low2 = _mm256_add_ps(_mm256_mul_ps(input2, low_cutoff_vec), _mm256_mul_ps(low1, one_minus_low_cutoff_vec)); - auto low3 = _mm256_add_ps(_mm256_mul_ps(input3, low_cutoff_vec), _mm256_mul_ps(low2, one_minus_low_cutoff_vec)); - + auto low0 = _mm256_add_ps(_mm256_mul_ps(input0, low_cutoff_vec), + _mm256_mul_ps(low_state_vec, one_minus_low_cutoff_vec)); + auto low1 = _mm256_add_ps(_mm256_mul_ps(input1, low_cutoff_vec), + _mm256_mul_ps(low0, one_minus_low_cutoff_vec)); + auto low2 = _mm256_add_ps(_mm256_mul_ps(input2, low_cutoff_vec), + _mm256_mul_ps(low1, one_minus_low_cutoff_vec)); + auto low3 = _mm256_add_ps(_mm256_mul_ps(input3, low_cutoff_vec), + _mm256_mul_ps(low2, one_minus_low_cutoff_vec)); + // 简化的高通滤波器实现 auto high0 = _mm256_sub_ps(input0, low0); auto high1 = _mm256_sub_ps(input1, low1); auto high2 = _mm256_sub_ps(input2, low2); auto high3 = _mm256_sub_ps(input3, low3); - + // 进一步高频处理 auto high_state_vec = _mm256_set1_ps(high_state); - high0 = _mm256_add_ps(_mm256_mul_ps(high0, high_cutoff_vec), _mm256_mul_ps(high_state_vec, one_minus_high_cutoff_vec)); - high1 = _mm256_add_ps(_mm256_mul_ps(high1, high_cutoff_vec), _mm256_mul_ps(high0, one_minus_high_cutoff_vec)); - high2 = _mm256_add_ps(_mm256_mul_ps(high2, high_cutoff_vec), _mm256_mul_ps(high1, one_minus_high_cutoff_vec)); - high3 = _mm256_add_ps(_mm256_mul_ps(high3, high_cutoff_vec), _mm256_mul_ps(high2, one_minus_high_cutoff_vec)); - + high0 = _mm256_add_ps(_mm256_mul_ps(high0, high_cutoff_vec), + _mm256_mul_ps(high_state_vec, one_minus_high_cutoff_vec)); + high1 = _mm256_add_ps(_mm256_mul_ps(high1, high_cutoff_vec), + _mm256_mul_ps(high0, one_minus_high_cutoff_vec)); + high2 = _mm256_add_ps(_mm256_mul_ps(high2, high_cutoff_vec), + _mm256_mul_ps(high1, one_minus_high_cutoff_vec)); + high3 = _mm256_add_ps(_mm256_mul_ps(high3, high_cutoff_vec), + _mm256_mul_ps(high2, one_minus_high_cutoff_vec)); + // 中频:原始信号减去低频和高频 auto mid0 = _mm256_mul_ps(_mm256_sub_ps(_mm256_sub_ps(input0, low0), high0), mid_factor_vec); auto mid1 = _mm256_mul_ps(_mm256_sub_ps(_mm256_sub_ps(input1, low1), high1), mid_factor_vec); auto mid2 = _mm256_mul_ps(_mm256_sub_ps(_mm256_sub_ps(input2, low2), high2), mid_factor_vec); auto mid3 = _mm256_mul_ps(_mm256_sub_ps(_mm256_sub_ps(input3, low3), high3), mid_factor_vec); - + // 应用增益并混合 - auto result0 = _mm256_add_ps(_mm256_add_ps(_mm256_mul_ps(low0, low_gain_vec), _mm256_mul_ps(mid0, mid_gain_vec)), _mm256_mul_ps(high0, high_gain_vec)); - auto result1 = _mm256_add_ps(_mm256_add_ps(_mm256_mul_ps(low1, low_gain_vec), _mm256_mul_ps(mid1, mid_gain_vec)), _mm256_mul_ps(high1, high_gain_vec)); - auto result2 = _mm256_add_ps(_mm256_add_ps(_mm256_mul_ps(low2, low_gain_vec), _mm256_mul_ps(mid2, mid_gain_vec)), _mm256_mul_ps(high2, high_gain_vec)); - auto result3 = _mm256_add_ps(_mm256_add_ps(_mm256_mul_ps(low3, low_gain_vec), _mm256_mul_ps(mid3, mid_gain_vec)), _mm256_mul_ps(high3, high_gain_vec)); - + auto result0 = _mm256_add_ps( + _mm256_add_ps(_mm256_mul_ps(low0, low_gain_vec), _mm256_mul_ps(mid0, mid_gain_vec)), + _mm256_mul_ps(high0, high_gain_vec)); + auto result1 = _mm256_add_ps( + _mm256_add_ps(_mm256_mul_ps(low1, low_gain_vec), _mm256_mul_ps(mid1, mid_gain_vec)), + _mm256_mul_ps(high1, high_gain_vec)); + auto result2 = _mm256_add_ps( + _mm256_add_ps(_mm256_mul_ps(low2, low_gain_vec), _mm256_mul_ps(mid2, mid_gain_vec)), + _mm256_mul_ps(high2, high_gain_vec)); + auto result3 = _mm256_add_ps( + _mm256_add_ps(_mm256_mul_ps(low3, low_gain_vec), _mm256_mul_ps(mid3, mid_gain_vec)), + _mm256_mul_ps(high3, high_gain_vec)); + // 存储结果 _mm256_store_ps(&dst[i], result0); _mm256_store_ps(&dst[i + 8], result1); _mm256_store_ps(&dst[i + 16], result2); _mm256_store_ps(&dst[i + 24], result3); - + // 更新状态(使用最后一个元素) - auto low_temp = _mm256_extractf128_ps(low3, 1); - low_state = _mm_cvtss_f32(_mm_shuffle_ps(low_temp, low_temp, _MM_SHUFFLE(3, 3, 3, 3))); + auto low_temp = _mm256_extractf128_ps(low3, 1); + low_state = _mm_cvtss_f32(_mm_shuffle_ps(low_temp, low_temp, _MM_SHUFFLE(3, 3, 3, 3))); auto high_temp = _mm256_extractf128_ps(high3, 1); - high_state = _mm_cvtss_f32(_mm_shuffle_ps(high_temp, high_temp, _MM_SHUFFLE(3, 3, 3, 3))); + high_state = _mm_cvtss_f32(_mm_shuffle_ps(high_temp, high_temp, _MM_SHUFFLE(3, 3, 3, 3))); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { float input = src[i]; - + // 低通滤波器 float low_output = low_cutoff * input + (1.0f - low_cutoff) * low_state; - low_state = low_output; - + low_state = low_output; + // 高通滤波器 - float high_input = input - low_output; + float high_input = input - low_output; float high_output = high_cutoff * high_input + (1.0f - high_cutoff) * high_state; - high_state = high_output; - + high_state = high_output; + // 中频 float mid_output = (input - low_output - high_output) * mid_factor; - + // 混合并应用增益 dst[i] = low_output * low_gain + mid_output * mid_gain + high_output * high_gain; } - + // 更新EQ状态 if (eq_state != nullptr) { eq_state[0] = low_state; eq_state[1] = high_state; } } - + // 简单均衡器函数实现 (AVX512版本) - void simple_eq_avx512(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples) { + void simple_eq_avx512(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, + float* eq_state, size_t num_samples) { ASSERT_ALIGNED(src, ALIGNMENT_AVX512); ASSERT_ALIGNED(dst, ALIGNMENT_AVX512); - + // 边界情况处理 if (num_samples == 0) { return; } - - constexpr size_t simd_width = 16; // AVX512每次处理16个float - constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 - + + constexpr size_t simd_width = 16; // AVX512每次处理16个float + constexpr size_t unroll_factor = 4; // 4路循环展开提高指令级并行性 + // 简化的频率分割系数 - constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 - constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 - constexpr float mid_factor = 0.7f; // 中频保持系数 - + constexpr float low_cutoff = 0.02f; // 简化的低通滤波器系数 + constexpr float high_cutoff = 0.1f; // 简化的高通滤波器系数 + constexpr float mid_factor = 0.7f; // 中频保持系数 + // 初始化EQ状态 - float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; + float low_state = eq_state != nullptr ? eq_state[0] : 0.0f; float high_state = eq_state != nullptr ? eq_state[1] : 0.0f; - + // 创建增益向量 - const auto low_gain_vec = _mm512_set1_ps(low_gain); - const auto mid_gain_vec = _mm512_set1_ps(mid_gain); - const auto high_gain_vec = _mm512_set1_ps(high_gain); - const auto low_cutoff_vec = _mm512_set1_ps(low_cutoff); - const auto high_cutoff_vec = _mm512_set1_ps(high_cutoff); - const auto mid_factor_vec = _mm512_set1_ps(mid_factor); - const auto one_minus_low_cutoff_vec = _mm512_set1_ps(1.0f - low_cutoff); + const auto low_gain_vec = _mm512_set1_ps(low_gain); + const auto mid_gain_vec = _mm512_set1_ps(mid_gain); + const auto high_gain_vec = _mm512_set1_ps(high_gain); + const auto low_cutoff_vec = _mm512_set1_ps(low_cutoff); + const auto high_cutoff_vec = _mm512_set1_ps(high_cutoff); + const auto mid_factor_vec = _mm512_set1_ps(mid_factor); + const auto one_minus_low_cutoff_vec = _mm512_set1_ps(1.0f - low_cutoff); const auto one_minus_high_cutoff_vec = _mm512_set1_ps(1.0f - high_cutoff); - + size_t i = 0; - + // 向量化处理(4路循环展开) for (; i + simd_width * unroll_factor <= num_samples; i += simd_width * unroll_factor) { // 加载4个向量 @@ -2093,72 +2232,76 @@ namespace x86_simd_audio_processing_func { auto input1 = _mm512_load_ps(&src[i + 16]); auto input2 = _mm512_load_ps(&src[i + 32]); auto input3 = _mm512_load_ps(&src[i + 48]); - + // 简化的低通滤波器实现(一阶IIR) auto low_state_vec = _mm512_set1_ps(low_state); auto low0 = _mm512_fmadd_ps(input0, low_cutoff_vec, _mm512_mul_ps(low_state_vec, one_minus_low_cutoff_vec)); auto low1 = _mm512_fmadd_ps(input1, low_cutoff_vec, _mm512_mul_ps(low0, one_minus_low_cutoff_vec)); auto low2 = _mm512_fmadd_ps(input2, low_cutoff_vec, _mm512_mul_ps(low1, one_minus_low_cutoff_vec)); auto low3 = _mm512_fmadd_ps(input3, low_cutoff_vec, _mm512_mul_ps(low2, one_minus_low_cutoff_vec)); - + // 简化的高通滤波器实现 auto high0 = _mm512_sub_ps(input0, low0); auto high1 = _mm512_sub_ps(input1, low1); auto high2 = _mm512_sub_ps(input2, low2); auto high3 = _mm512_sub_ps(input3, low3); - + // 进一步高频处理 auto high_state_vec = _mm512_set1_ps(high_state); high0 = _mm512_fmadd_ps(high0, high_cutoff_vec, _mm512_mul_ps(high_state_vec, one_minus_high_cutoff_vec)); high1 = _mm512_fmadd_ps(high1, high_cutoff_vec, _mm512_mul_ps(high0, one_minus_high_cutoff_vec)); high2 = _mm512_fmadd_ps(high2, high_cutoff_vec, _mm512_mul_ps(high1, one_minus_high_cutoff_vec)); high3 = _mm512_fmadd_ps(high3, high_cutoff_vec, _mm512_mul_ps(high2, one_minus_high_cutoff_vec)); - + // 中频:原始信号减去低频和高频 auto mid0 = _mm512_mul_ps(_mm512_sub_ps(_mm512_sub_ps(input0, low0), high0), mid_factor_vec); auto mid1 = _mm512_mul_ps(_mm512_sub_ps(_mm512_sub_ps(input1, low1), high1), mid_factor_vec); auto mid2 = _mm512_mul_ps(_mm512_sub_ps(_mm512_sub_ps(input2, low2), high2), mid_factor_vec); auto mid3 = _mm512_mul_ps(_mm512_sub_ps(_mm512_sub_ps(input3, low3), high3), mid_factor_vec); - + // 应用增益并混合 (使用FMA指令优化) - auto result0 = _mm512_fmadd_ps(low0, low_gain_vec, _mm512_fmadd_ps(mid0, mid_gain_vec, _mm512_mul_ps(high0, high_gain_vec))); - auto result1 = _mm512_fmadd_ps(low1, low_gain_vec, _mm512_fmadd_ps(mid1, mid_gain_vec, _mm512_mul_ps(high1, high_gain_vec))); - auto result2 = _mm512_fmadd_ps(low2, low_gain_vec, _mm512_fmadd_ps(mid2, mid_gain_vec, _mm512_mul_ps(high2, high_gain_vec))); - auto result3 = _mm512_fmadd_ps(low3, low_gain_vec, _mm512_fmadd_ps(mid3, mid_gain_vec, _mm512_mul_ps(high3, high_gain_vec))); - + auto result0 = _mm512_fmadd_ps(low0, low_gain_vec, + _mm512_fmadd_ps(mid0, mid_gain_vec, _mm512_mul_ps(high0, high_gain_vec))); + auto result1 = _mm512_fmadd_ps(low1, low_gain_vec, + _mm512_fmadd_ps(mid1, mid_gain_vec, _mm512_mul_ps(high1, high_gain_vec))); + auto result2 = _mm512_fmadd_ps(low2, low_gain_vec, + _mm512_fmadd_ps(mid2, mid_gain_vec, _mm512_mul_ps(high2, high_gain_vec))); + auto result3 = _mm512_fmadd_ps(low3, low_gain_vec, + _mm512_fmadd_ps(mid3, mid_gain_vec, _mm512_mul_ps(high3, high_gain_vec))); + // 存储结果 _mm512_store_ps(&dst[i], result0); _mm512_store_ps(&dst[i + 16], result1); _mm512_store_ps(&dst[i + 32], result2); _mm512_store_ps(&dst[i + 48], result3); - + // 更新状态(使用最后一个元素) - __m128 low_temp = _mm512_extractf32x4_ps(low3, 3); - low_state = _mm_cvtss_f32(low_temp); + __m128 low_temp = _mm512_extractf32x4_ps(low3, 3); + low_state = _mm_cvtss_f32(low_temp); __m128 high_temp = _mm512_extractf32x4_ps(high3, 3); - high_state = _mm_cvtss_f32(high_temp); + high_state = _mm_cvtss_f32(high_temp); } - + // 处理剩余的标量样本 for (; i < num_samples; ++i) { float input = src[i]; - + // 低通滤波器 float low_output = low_cutoff * input + (1.0f - low_cutoff) * low_state; - low_state = low_output; - + low_state = low_output; + // 高通滤波器 - float high_input = input - low_output; + float high_input = input - low_output; float high_output = high_cutoff * high_input + (1.0f - high_cutoff) * high_state; - high_state = high_output; - + high_state = high_output; + // 中频 float mid_output = (input - low_output - high_output) * mid_factor; - + // 混合并应用增益 dst[i] = low_output * low_gain + mid_output * mid_gain + high_output * high_gain; } - + // 更新EQ状态 if (eq_state != nullptr) { eq_state[0] = low_state; diff --git a/src/simd/audio_processing/x86_simd_audio_processing_func.h b/src/simd/audio_processing/x86_simd_audio_processing_func.h index 4acfa70..afbb840 100644 --- a/src/simd/audio_processing/x86_simd_audio_processing_func.h +++ b/src/simd/audio_processing/x86_simd_audio_processing_func.h @@ -29,18 +29,27 @@ namespace x86_simd_audio_processing_func { void stereo_to_mono_avx512(const float* stereo_src, float* mono_dst, size_t num_stereo_samples); // 音频限幅:将超过阈值的样本限制在指定范围内 - void limit_audio_sse(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples); - void limit_audio_avx(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples); - void limit_audio_avx512(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, size_t num_samples); + void limit_audio_sse(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples); + void limit_audio_avx(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples); + void limit_audio_avx512(const float* src, float* dst, float threshold, float* limiter_state, float sample_rate, + size_t num_samples); // 音频淡入淡出:应用线性淡入淡出效果 - void fade_audio_sse(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples); - void fade_audio_avx(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples); - void fade_audio_avx512(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, size_t num_samples); + void fade_audio_sse(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, + size_t num_samples); + void fade_audio_avx(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, + size_t num_samples); + void fade_audio_avx512(const float* src, float* dst, size_t fade_in_samples, size_t fade_out_samples, + size_t num_samples); // 简单均衡器:简单的三段均衡器(低频、中频、高频增益) - void simple_eq_sse(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples); - void simple_eq_avx(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples); - void simple_eq_avx512(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, size_t num_samples); + void simple_eq_sse(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, + size_t num_samples); + void simple_eq_avx(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, float* eq_state, + size_t num_samples); + void simple_eq_avx512(const float* src, float* dst, float low_gain, float mid_gain, float high_gain, + float* eq_state, size_t num_samples); } #endif diff --git a/src/simd/cpu_features.cpp b/src/simd/cpu_features.cpp index 50e4aa7..4ddb9f5 100644 --- a/src/simd/cpu_features.cpp +++ b/src/simd/cpu_features.cpp @@ -13,17 +13,17 @@ #include #elif ALICHO_PLATFORM_UNIX || ALICHO_PLATFORM_POSIX #if ALICHO_PLATFORM_X86 - #include +#include #endif - #include - #include +#include +#include #if ALICHO_PLATFORM_APPLE - #include +#include #endif #endif #if ALICHO_PLATFORM_ARM && defined(__ARM_NEON) - #include +#include #endif auto cpu_info::features_string() const -> std::string { @@ -35,29 +35,29 @@ auto cpu_info::features_string() const -> std::string { }; static const FeatureNamePair feature_names[] = { - { cpu_feature::SSE, "SSE" }, - { cpu_feature::SSE2, "SSE2" }, - { cpu_feature::SSE3, "SSE3" }, - { cpu_feature::SSSE3, "SSSE3" }, - { cpu_feature::SSE4_1, "SSE4.1" }, - { cpu_feature::SSE4_2, "SSE4.2" }, - { cpu_feature::AVX, "AVX" }, - { cpu_feature::AVX2, "AVX2" }, - { cpu_feature::FMA, "FMA" }, - { cpu_feature::AVX512F, "AVX512F" }, - { cpu_feature::AVX512VL, "AVX512VL" }, - { cpu_feature::AVX512BW, "AVX512BW" }, - { cpu_feature::AVX512DQ, "AVX512DQ" }, - { cpu_feature::AVX512IFMA, "AVX512IFMA" }, - { cpu_feature::AVX512VBMI, "AVX512VBMI" }, - { cpu_feature::NEON, "NEON" }, - { cpu_feature::NEON_FP16, "NEON_FP16" }, - { cpu_feature::VECTOR, "VECTOR" }, - { cpu_feature::POPCNT, "POPCNT" }, - { cpu_feature::BMI1, "BMI1" } + {cpu_feature::SSE, "SSE"}, + {cpu_feature::SSE2, "SSE2"}, + {cpu_feature::SSE3, "SSE3"}, + {cpu_feature::SSSE3, "SSSE3"}, + {cpu_feature::SSE4_1, "SSE4.1"}, + {cpu_feature::SSE4_2, "SSE4.2"}, + {cpu_feature::AVX, "AVX"}, + {cpu_feature::AVX2, "AVX2"}, + {cpu_feature::FMA, "FMA"}, + {cpu_feature::AVX512F, "AVX512F"}, + {cpu_feature::AVX512VL, "AVX512VL"}, + {cpu_feature::AVX512BW, "AVX512BW"}, + {cpu_feature::AVX512DQ, "AVX512DQ"}, + {cpu_feature::AVX512IFMA, "AVX512IFMA"}, + {cpu_feature::AVX512VBMI, "AVX512VBMI"}, + {cpu_feature::NEON, "NEON"}, + {cpu_feature::NEON_FP16, "NEON_FP16"}, + {cpu_feature::VECTOR, "VECTOR"}, + {cpu_feature::POPCNT, "POPCNT"}, + {cpu_feature::BMI1, "BMI1"} }; - for (const auto& pair: feature_names) { + for (const auto& pair : feature_names) { if (supports(pair.feature)) { if (!result.empty()) { result += ", "; } result += pair.name; @@ -74,12 +74,12 @@ auto cpu_feature_detector::recommended_simd_level() const noexcept -> simd_level // 基本策略: 选择最高支持的SIMD级别 switch (info_.max_simd_level) { case simd_level::AVX512: -#if ALICHO_PLATFORM_WINDOWS + #if ALICHO_PLATFORM_WINDOWS if (info_.vendor.find("Intel") != std::string::npos && ( - info_.brand.find("Xeon") != std::string::npos || info_.brand.find("i9") != std::string::npos)) { + info_.brand.find("Xeon") != std::string::npos || info_.brand.find("i9") != std::string::npos)) { return simd_level::AVX512; // 高端Intel CPU上使用AVX-512 } -#endif + #endif // 所有的AMD CPU上都开启AVX-512支持 if (info_.vendor.find("AMD") != std::string::npos) { return simd_level::AVX512; } return simd_level::AVX2; // 其他情况回退到AVX2以确保兼容性 @@ -106,13 +106,13 @@ void cpu_feature_detector::print_info() const { } cpu_feature_detector::cpu_feature_detector() { -#if ALICHO_PLATFORM_WINDOWS + #if ALICHO_PLATFORM_WINDOWS SYSTEM_INFO sys_info; GetSystemInfo(&sys_info); info_.logical_cores = sys_info.dwNumberOfProcessors; -#else + #else info_.logical_cores = static_cast(std::thread::hardware_concurrency()); -#endif + #endif info_.physical_cores = info_.logical_cores; // 先填写默认值,实际检测在detect_features中进行 @@ -120,19 +120,19 @@ cpu_feature_detector::cpu_feature_detector() { } void cpu_feature_detector::detect_features() { -#if ALICHO_PLATFORM_X86 + #if ALICHO_PLATFORM_X86 detect_x86_features(); -#elif ALICHO_PLATFORM_ARM + #elif ALICHO_PLATFORM_ARM detect_arm_features(); -#else + #else info_.max_simd_level = simd_level::NONE; -#endif + #endif } void cpu_feature_detector::detect_x86_features() { -#if ALICHO_PLATFORM_X86 - char vendor[13] = { 0 }; - char brand[49] = { 0 }; + #if ALICHO_PLATFORM_X86 + char vendor[13] = {0}; + char brand[49] = {0}; // 获取制造商和品牌字符串 get_vendor_string(vendor); @@ -184,7 +184,7 @@ void cpu_feature_detector::detect_x86_features() { else { info_.max_simd_level = simd_level::NONE; } // 获取物理核心数 - 平台特定实现 -#if ALICHO_PLATFORM_WINDOWS + #if ALICHO_PLATFORM_WINDOWS // Windows特定的物理心数检测 DWORD length = 0; PSYSTEM_LOGICAL_PROCESSOR_INFORMATION buffer = nullptr; @@ -192,7 +192,7 @@ void cpu_feature_detector::detect_x86_features() { // 获取所需缓冲区大小 GetLogicalProcessorInformation(buffer, &length); if (GetLastError() == ERROR_INSUFFICIENT_BUFFER) { - buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION) malloc(length); + buffer = (PSYSTEM_LOGICAL_PROCESSOR_INFORMATION)malloc(length); if (buffer) { if (GetLogicalProcessorInformation(buffer, &length)) { DWORD processorCoreCount = 0; @@ -210,119 +210,123 @@ void cpu_feature_detector::detect_x86_features() { free(buffer); } } -#elif ALICHO_PLATFORM_MACOS - // macOS特定的物理核数检测 - int nm[2]; - size_t len = 4; - uint32_t count; + #elif ALICHO_PLATFORM_MACOS + // macOS特定的物理核数检测 + int nm[2]; + size_t len = 4; + uint32_t count; - nm[0] = CTL_HW; - nm[1] = HW_AVAILCPU; - sysctlbyname("hw.physicalcpu", &count, &len, NULL, 0); - if (count > 0) { - info_.physical_cores = static_cast(count); - } else { - // 回退到逻辑核心数 - info_.physical_cores = info_.logical_cores; - } -#elif ALICHO_PLATFORM_LINUX - // Linux特定的物理核数检测 - // 尝试读取/proc/cpuinfo - FILE* fp = fopen("/proc/cpuinfo", "r"); - if (fp) { - char buffer[1024]; - std::unordered_set physical_ids; + nm[0] = CTL_HW; + nm[1] = HW_AVAILCPU; + sysctlbyname("hw.physicalcpu", &count, &len, NULL, 0); + if (count > 0) { + info_.physical_cores = static_cast(count); + } + else { + // 回退到逻辑核心数 + info_.physical_cores = info_.logical_cores; + } + #elif ALICHO_PLATFORM_LINUX + // Linux特定的物理核数检测 + // 尝试读取/proc/cpuinfo + FILE* fp = fopen("/proc/cpuinfo", "r"); + if (fp) { + char buffer[1024]; + std::unordered_set physical_ids; - while (fgets(buffer, sizeof(buffer), fp)) { - if (strncmp(buffer, "physical id", 11) == 0) { - char* id_pos = strchr(buffer, ':'); - if (id_pos) { - physical_ids.insert(std::string(id_pos + 1)); - } - } - } - fclose(fp); + while (fgets(buffer, sizeof(buffer), fp)) { + if (strncmp(buffer, "physical id", 11) == 0) { + char* id_pos = strchr(buffer, ':'); + if (id_pos) { + physical_ids.insert(std::string(id_pos + 1)); + } + } + } + fclose(fp); - if (!physical_ids.empty()) { - info_.physical_cores = static_cast(physical_ids.size()); - } - } -#endif -#endif + if (!physical_ids.empty()) { + info_.physical_cores = static_cast(physical_ids.size()); + } + } + #endif + #endif } void cpu_feature_detector::detect_arm_features() { -#if ALICHO_PLATFORM_ARM + #if ALICHO_PLATFORM_ARM // 检测NEON支持 -#if defined(__ARM_NEON) || defined(__ARM_NEON__) - info_.features |= static_cast(cpu_feature::NEON); - info_.max_simd_level = simd_level::NEON; -#endif + #if defined(__ARM_NEON) || defined(__ARM_NEON__) + info_.features |= static_cast(cpu_feature::NEON); + info_.max_simd_level = simd_level::NEON; + #endif // 检测FP16支持(如果可用) -#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) - info_.features |= static_cast(cpu_feature::NEON_FP16); - info_.max_simd_level = simd_level::NEON_FP16; -#endif + #if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) + info_.features |= static_cast(cpu_feature::NEON_FP16); + info_.max_simd_level = simd_level::NEON_FP16; + #endif -#if ALICHO_PLATFORM_ANDROID - info_.vendor = "ARM"; - if (auto* fp = fopen("/proc/cpuinfo", "r")) { - char line[1024]; - char line[1024]; - while (fgets(line, sizeof(line), fp)) { - if (strncmp(line, "Hardware", 8) == 0 || - strncmp(line, "model name", 10) == 0) { - char* sep = strchr(line, ':'); - if (sep) { - // 去除前导空格 - char* value = sep + 1; - while (*value && isspace(*value)) value++; + #if ALICHO_PLATFORM_ANDROID + info_.vendor = "ARM"; + if (auto* fp = fopen("/proc/cpuinfo", "r")) { + char line[1024]; + char line[1024]; + while (fgets(line, sizeof(line), fp)) { + if (strncmp(line, "Hardware", 8) == 0 || + strncmp(line, "model name", 10) == 0) { + char* sep = strchr(line, ':'); + if (sep) { + // 去除前导空格 + char* value = sep + 1; + while (*value && isspace(*value)) + value++; - // 去除末尾换行符 - char* end = value; - while (*end && *end != '\n') end++; - *end = '\0'; + // 去除末尾换行符 + char* end = value; + while (*end && *end != '\n') + end++; + *end = '\0'; - info_.brand = value; - break; - } + info_.brand = value; + break; } } - fclose(fp); - } -#elif ALICHO_PLATFORM_APPLE - // iOS/macOS ARM (Apple Silicon) - info_.vendor = "Apple"; - char brand[128]; - size_t size = sizeof(brand); - if (sysctlbyname("machdep.cpu.brand_string", &brand, &size, NULL, 0) == 0) { - info_.brand = brand; } + fclose(fp); + } + #elif ALICHO_PLATFORM_APPLE + // iOS/macOS ARM (Apple Silicon) + info_.vendor = "Apple"; + char brand[128]; + size_t size = sizeof(brand); + if (sysctlbyname("machdep.cpu.brand_string", &brand, &size, NULL, 0) == 0) { + info_.brand = brand; + } - // 获取物理核心数 (通过sysctl) - int nm[2]; - size_t len = 4; - uint32_t count; + // 获取物理核心数 (通过sysctl) + int nm[2]; + size_t len = 4; + uint32_t count; - sysctlbyname("hw.physicalcpu", &count, &len, NULL, 0); - if (count > 0) { - info_.physical_cores = static_cast(count); - } else { - // 回退到逻辑核心数 - info_.physical_cores = info_.logical_cores; - } -#else - // 其他ARM平台 - info_.vendor = "ARM"; - info_.brand = "Unknown ARM Processor"; -#endif + sysctlbyname("hw.physicalcpu", &count, &len, NULL, 0); + if (count > 0) { + info_.physical_cores = static_cast(count); + } + else { + // 回退到逻辑核心数 + info_.physical_cores = info_.logical_cores; + } + #else + // 其他ARM平台 + info_.vendor = "ARM"; + info_.brand = "Unknown ARM Processor"; + #endif // 获取物理核心数 (Linux/Unix) -#if ALICHO_PLATFORM_LINUX && !ALICHO_PLATFORM_ANDROID - info_.physical_cores = static_cast(sysconf(_SC_NPROCESSORS_CONF)); -#endif -#endif + #if ALICHO_PLATFORM_LINUX && !ALICHO_PLATFORM_ANDROID + info_.physical_cores = static_cast(sysconf(_SC_NPROCESSORS_CONF)); + #endif + #endif } #if ALICHO_PLATFORM_X86 @@ -344,7 +348,7 @@ auto cpu_feature_detector::cpuid(uint32_t function_id, uint32_t subfunction_id) #else auto cpu_feature_detector::cpuid(uint32_t function_id, uint32_t subfunction_id) -> cpu_id_result { cpu_id_result result; - uint32_t regs[4]; + uint32_t regs[4]; __cpuid_count(function_id, subfunction_id, regs[0], regs[1], regs[2], regs[3]); diff --git a/src/simd/cpu_features.h b/src/simd/cpu_features.h index 6522183..e64fa03 100644 --- a/src/simd/cpu_features.h +++ b/src/simd/cpu_features.h @@ -91,7 +91,7 @@ protected: void detect_arm_features(); -#if ALICHO_PLATFORM_X86 + #if ALICHO_PLATFORM_X86 struct cpu_id_result { uint32_t eax, ebx, ecx, edx; }; @@ -101,7 +101,7 @@ protected: static void get_vendor_string(char* vendor_string); static void get_brand_string(char* brand_string); -#endif + #endif cpu_info info_; }; diff --git a/src/simd/simd_func_dispatcher.cpp b/src/simd/simd_func_dispatcher.cpp index d4b1b89..5326ad6 100644 --- a/src/simd/simd_func_dispatcher.cpp +++ b/src/simd/simd_func_dispatcher.cpp @@ -3,7 +3,7 @@ void simd_func_dispatcher::print_registry_status() const { printf("Registered SIMD Functions:\n"); - for (const auto& pair: func_registry_) { + for (const auto& pair : func_registry_) { const auto& func_name = pair.first; const auto& holder = pair.second; @@ -11,7 +11,7 @@ void simd_func_dispatcher::print_registry_status() const { if (holder->has_implementation()) { auto versions = holder->get_available_versions(); printf(" Available Versions: "); - for (const auto& version: versions) { + for (const auto& version : versions) { switch (version) { case simd_func_version::SCALAR: printf("SCALAR "); diff --git a/src/simd/simd_func_dispatcher.h b/src/simd/simd_func_dispatcher.h index 22fc477..13b452e 100644 --- a/src/simd/simd_func_dispatcher.h +++ b/src/simd/simd_func_dispatcher.h @@ -48,10 +48,10 @@ constexpr auto simd_level_to_version(simd_level level) { return simd_func_version::SCALAR; } -template +template class multi_version_func; -template +template class multi_version_func { public: using func_type = std::function; @@ -61,7 +61,7 @@ public: void register_version(simd_func_version version, func_type func) { functions_[static_cast(version)] = std::move(func); - best_func_ = get_best_func(); // 更新最佳函数 + best_func_ = get_best_func(); // 更新最佳函数 } const auto& get_best_func() const { @@ -101,7 +101,7 @@ public: private: func_arr functions_{}; - func_type best_func_{ nullptr }; + func_type best_func_{nullptr}; }; class simd_func_dispatcher : public lazy_singleton { @@ -109,28 +109,28 @@ public: friend class lazy_singleton; // 注册函数(通过函数名) - template + template void register_function(const std::string& func_name, - simd_func_version version, - std::function func) { + simd_func_version version, + std::function func) { auto& holder = get_or_create_func(func_name); holder.register_version(version, std::move(func)); } // 获取函数 - template + template const auto& get_function(const std::string& func_name) const { const auto& it = func_registry_.find(func_name); if (it == func_registry_.end()) { throw std::runtime_error("函数 '" + func_name + "' 未注册"); } - + auto* holder = static_cast*>(it->second.get()); return holder->func; } // 调用函数 - template + template auto call_function(const std::string& func_name, args&&... in_args) const { const auto& func = get_function(func_name); return func(std::forward(in_args)...); @@ -139,7 +139,7 @@ public: // 列出所有已经注册的函数 [[nodiscard]] auto list_functions() const -> std::vector { std::vector func_names; - for (const auto& pair: func_registry_) { func_names.push_back(pair.first); } + for (const auto& pair : func_registry_) { func_names.push_back(pair.first); } return func_names; } @@ -157,7 +157,7 @@ private: }; // 具体的函数持有者模板 - template + template struct func_holder : func_holder_base { multi_version_func func; @@ -169,7 +169,7 @@ private: }; // 获取或创建函数持有者(仅用于注册) - template + template auto& get_or_create_func(const std::string& func_name) { const auto& it = func_registry_.find(func_name); if (it != func_registry_.end()) { @@ -194,7 +194,7 @@ private: #define CALL_SIMD_FUNCTION(func_signature, func_name, ...) \ simd_func_dispatcher::instance().call_function(func_name, __VA_ARGS__); -template +template class simd_auto_register { public: simd_auto_register(const std::string& func_name, simd_func_version version, std::function func) { diff --git a/tests/network/test_zmq_rpc.cpp b/tests/network/test_zmq_rpc.cpp index 588a26d..ecb6019 100644 --- a/tests/network/test_zmq_rpc.cpp +++ b/tests/network/test_zmq_rpc.cpp @@ -21,60 +21,60 @@ // 用于测试的简单计数器,跟踪处理器调用 struct test_counter { - std::atomic count{0}; - std::mutex mutex; - std::condition_variable cv; - - void increment() { - count++; - cv.notify_all(); - } - - bool wait_for(int expected, std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { - std::unique_lock lock(mutex); - return cv.wait_for(lock, timeout, [this, expected]() { - return count >= expected; - }); - } - - void reset() { - count = 0; - } + std::atomic count{0}; + std::mutex mutex; + std::condition_variable cv; + + void increment() { + count++; + cv.notify_all(); + } + + bool wait_for(int expected, std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) { + std::unique_lock lock(mutex); + return cv.wait_for(lock, timeout, [this, expected]() { + return count >= expected; + }); + } + + void reset() { + count = 0; + } }; // 测试用的数据结构 struct test_message_t { - uint32_t id; - std::string content; + uint32_t id; + std::string content; }; struct test_response_t { - uint32_t request_id; - bool success; - std::string result; + uint32_t request_id; + bool success; + std::string result; }; // 全局计数器用于验证处理器调用 inline test_counter g_server_counter; inline test_counter g_client_counter; -inline std::string g_last_received_content; -inline std::mutex g_content_mutex; +inline std::string g_last_received_content; +inline std::mutex g_content_mutex; // 注册测试消息处理器 ZMQ_SERVER_REGISTER_PROCESSOR(test_message_t) { - g_server_counter.increment(); - { - std::lock_guard lock(g_content_mutex); - g_last_received_content = data.content; - } + g_server_counter.increment(); + { + std::lock_guard lock(g_content_mutex); + g_last_received_content = data.content; + } } ZMQ_CLIENT_REGISTER_PROCESSOR(test_response_t) { - g_client_counter.increment(); - { - std::lock_guard lock(g_content_mutex); - g_last_received_content = data.result; - } + g_client_counter.increment(); + { + std::lock_guard lock(g_content_mutex); + g_last_received_content = data.result; + } } // ============================================================================ @@ -83,229 +83,229 @@ ZMQ_CLIENT_REGISTER_PROCESSOR(test_response_t) { class ZmqRpcTest : public ::testing::Test { protected: - void SetUp() override { - // 重置计数器 - g_server_counter.reset(); - g_client_counter.reset(); - g_last_received_content.clear(); - - // 清理客户端状态,确保每个测试独立 - auto& client = zmq_client::instance(); - if (client.is_connected()) { - client.disconnect(); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - } - } - - void TearDown() override { - // 清理可能未关闭的连接 - auto& client = zmq_client::instance(); - if (client.is_connected()) { - client.disconnect(); - } - } + void SetUp() override { + // 重置计数器 + g_server_counter.reset(); + g_client_counter.reset(); + g_last_received_content.clear(); + + // 清理客户端状态,确保每个测试独立 + auto& client = zmq_client::instance(); + if (client.is_connected()) { + client.disconnect(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + } + } + + void TearDown() override { + // 清理可能未关闭的连接 + auto& client = zmq_client::instance(); + if (client.is_connected()) { + client.disconnect(); + } + } }; /** * @brief 测试服务器初始化成功 */ TEST_F(ZmqRpcTest, ServerInitialization) { - auto& server = zmq_server::instance(); - - ASSERT_NO_THROW({ - server.init(); - }); - - EXPECT_TRUE(server.is_running()); - EXPECT_EQ(server.get_state(), zmq_server::state::RUNNING); - EXPECT_EQ(server.client_count(), 0); + auto& server = zmq_server::instance(); + + ASSERT_NO_THROW({ + server.init(); + }); + + EXPECT_TRUE(server.is_running()); + EXPECT_EQ(server.get_state(), zmq_server::state::RUNNING); + EXPECT_EQ(server.client_count(), 0); } /** * @brief 测试客户端连接成功 */ TEST_F(ZmqRpcTest, ClientConnection) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 12345; - - ASSERT_NO_THROW({ - client.init(client_id); - }); - - EXPECT_TRUE(client.is_connected()); - EXPECT_EQ(client.get_state(), zmq_client::state::CONNECTED); - - // 给一点时间让连接建立 - std::this_thread::sleep_for(std::chrono::milliseconds(100)); + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 12345; + + ASSERT_NO_THROW({ + client.init(client_id); + }); + + EXPECT_TRUE(client.is_connected()); + EXPECT_EQ(client.get_state(), zmq_client::state::CONNECTED); + + // 给一点时间让连接建立 + std::this_thread::sleep_for(std::chrono::milliseconds(100)); } /** * @brief 测试客户端向服务器发送消息 */ TEST_F(ZmqRpcTest, ClientToServerMessage) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 10001; - client.init(client_id); - - // 等待连接建立 - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - // 发送测试消息 - test_message_t msg; - msg.id = 1; - msg.content = "Hello from client"; - - client.send(msg); - - // 在服务器端接收消息(需要在单独的线程中运行) - std::thread server_thread([&server]() { - server.recv(); - }); + auto& server = zmq_server::instance(); + server.init(); - // 等待处理器被调用 - ASSERT_TRUE(g_server_counter.wait_for(1)); - - // 验证消息内容 - { - std::lock_guard lock(g_content_mutex); - EXPECT_EQ(g_last_received_content, "Hello from client"); - } - - server_thread.join(); - - // 验证客户端已在服务器注册 - EXPECT_TRUE(server.has_client(client_id)); + auto& client = zmq_client::instance(); + const uint32_t client_id = 10001; + client.init(client_id); + + // 等待连接建立 + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // 发送测试消息 + test_message_t msg; + msg.id = 1; + msg.content = "Hello from client"; + + client.send(msg); + + // 在服务器端接收消息(需要在单独的线程中运行) + std::thread server_thread([&server]() { + server.recv(); + }); + + // 等待处理器被调用 + ASSERT_TRUE(g_server_counter.wait_for(1)); + + // 验证消息内容 + { + std::lock_guard lock(g_content_mutex); + EXPECT_EQ(g_last_received_content, "Hello from client"); + } + + server_thread.join(); + + // 验证客户端已在服务器注册 + EXPECT_TRUE(server.has_client(client_id)); } /** * @brief 测试服务器向客户端发送消息 */ TEST_F(ZmqRpcTest, ServerToClientMessage) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 10002; - client.init(client_id); - - // 先让客户端向服务器发送消息以建立连接 - test_message_t hello_msg; - hello_msg.id = 1; - hello_msg.content = "Hello"; - client.send(hello_msg); - - std::thread server_thread1([&server]() { - server.recv(); - }); - server_thread1.join(); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - // 现在服务器向客户端发送消息 - test_response_t response; - response.request_id = 1; - response.success = true; - response.result = "Response from server"; - - server.send(client_id, response); - - // 客户端接收消息 - std::thread client_thread([&client]() { - client.recv(); - }); - - // 等待客户端处理器被调用 - ASSERT_TRUE(g_client_counter.wait_for(1)); - - // 验证消息内容 - { - std::lock_guard lock(g_content_mutex); - EXPECT_EQ(g_last_received_content, "Response from server"); - } - - client_thread.join(); + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 10002; + client.init(client_id); + + // 先让客户端向服务器发送消息以建立连接 + test_message_t hello_msg; + hello_msg.id = 1; + hello_msg.content = "Hello"; + client.send(hello_msg); + + std::thread server_thread1([&server]() { + server.recv(); + }); + server_thread1.join(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // 现在服务器向客户端发送消息 + test_response_t response; + response.request_id = 1; + response.success = true; + response.result = "Response from server"; + + server.send(client_id, response); + + // 客户端接收消息 + std::thread client_thread([&client]() { + client.recv(); + }); + + // 等待客户端处理器被调用 + ASSERT_TRUE(g_client_counter.wait_for(1)); + + // 验证消息内容 + { + std::lock_guard lock(g_content_mutex); + EXPECT_EQ(g_last_received_content, "Response from server"); + } + + client_thread.join(); } /** * @brief 测试多个客户端连接 */ TEST_F(ZmqRpcTest, MultipleClients) { - auto& server = zmq_server::instance(); - server.init(); - - // 注意:由于 zmq_client 是单例,这里我们模拟多客户端的场景 - // 实际项目中,每个客户端应该是独立的实例 - - const uint32_t client_id_1 = 20001; - const uint32_t client_id_2 = 20002; - - // 第一个客户端 - { - auto& client1 = zmq_client::instance(); - client1.init(client_id_1); - - test_message_t msg1; - msg1.id = 1; - msg1.content = "Client 1"; - client1.send(msg1); - - std::thread([&server]() { server.recv(); }).join(); - } - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - // 第二个客户端(需要断开并重新连接,因为是单例) - { - auto& client2 = zmq_client::instance(); - client2.disconnect(); - std::this_thread::sleep_for(std::chrono::milliseconds(50)); - - client2.init(client_id_2); - - test_message_t msg2; - msg2.id = 2; - msg2.content = "Client 2"; - client2.send(msg2); - - std::thread([&server]() { server.recv(); }).join(); - } - - // 验证两个客户端都已注册 - EXPECT_TRUE(server.has_client(client_id_1)); - EXPECT_TRUE(server.has_client(client_id_2)); - EXPECT_GE(server.client_count(), 1); // 至少有一个客户端 + auto& server = zmq_server::instance(); + server.init(); + + // 注意:由于 zmq_client 是单例,这里我们模拟多客户端的场景 + // 实际项目中,每个客户端应该是独立的实例 + + const uint32_t client_id_1 = 20001; + const uint32_t client_id_2 = 20002; + + // 第一个客户端 + { + auto& client1 = zmq_client::instance(); + client1.init(client_id_1); + + test_message_t msg1; + msg1.id = 1; + msg1.content = "Client 1"; + client1.send(msg1); + + std::thread([&server]() { server.recv(); }).join(); + } + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // 第二个客户端(需要断开并重新连接,因为是单例) + { + auto& client2 = zmq_client::instance(); + client2.disconnect(); + std::this_thread::sleep_for(std::chrono::milliseconds(50)); + + client2.init(client_id_2); + + test_message_t msg2; + msg2.id = 2; + msg2.content = "Client 2"; + client2.send(msg2); + + std::thread([&server]() { server.recv(); }).join(); + } + + // 验证两个客户端都已注册 + EXPECT_TRUE(server.has_client(client_id_1)); + EXPECT_TRUE(server.has_client(client_id_2)); + EXPECT_GE(server.client_count(), 1); // 至少有一个客户端 } /** * @brief 测试客户端移除 */ TEST_F(ZmqRpcTest, ClientRemoval) { - auto& server = zmq_server::instance(); - server.init(); - - const uint32_t client_id = 30001; - auto& client = zmq_client::instance(); - client.init(client_id); - - // 建立连接 - test_message_t msg; - msg.id = 1; - msg.content = "Test"; - client.send(msg); - - std::thread([&server]() { server.recv(); }).join(); - - EXPECT_TRUE(server.has_client(client_id)); - - // 移除客户端 - server.remove_client(client_id); - EXPECT_FALSE(server.has_client(client_id)); + auto& server = zmq_server::instance(); + server.init(); + + const uint32_t client_id = 30001; + auto& client = zmq_client::instance(); + client.init(client_id); + + // 建立连接 + test_message_t msg; + msg.id = 1; + msg.content = "Test"; + client.send(msg); + + std::thread([&server]() { server.recv(); }).join(); + + EXPECT_TRUE(server.has_client(client_id)); + + // 移除客户端 + server.remove_client(client_id); + EXPECT_FALSE(server.has_client(client_id)); } // ============================================================================ @@ -316,68 +316,69 @@ TEST_F(ZmqRpcTest, ClientRemoval) { * @brief 测试 engine_rpc::log_message_t 处理 */ TEST_F(ZmqRpcTest, EngineRpcLogMessage) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 40001; - client.init(client_id); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - // 发送日志消息 - engine_rpc::log_message_t log_msg; - log_msg.level = 2; // INFO 级别 - log_msg.message = "Test log message from sandbox_host"; - - client.send(log_msg); - - // 服务器接收并处理 - std::thread([&server]() { - server.recv(); - }).join(); - - // 验证消息被处理(通过日志输出,这里我们只验证不抛出异常) - SUCCEED(); + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 40001; + client.init(client_id); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // 发送日志消息 + engine_rpc::log_message_t log_msg; + log_msg.level = 2; // INFO 级别 + log_msg.message = "Test log message from sandbox_host"; + + client.send(log_msg); + + // 服务器接收并处理 + std::thread([&server]() { + server.recv(); + }).join(); + + // 验证消息被处理(通过日志输出,这里我们只验证不抛出异常) + SUCCEED(); } /** * @brief 测试 host_rpc::setup_t 处理 */ TEST_F(ZmqRpcTest, HostRpcSetup) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 40002; - client.init(client_id); - - // 建立连接 - test_message_t hello; - hello.id = 1; - hello.content = "Hello"; - client.send(hello); - std::thread([&server]() { server.recv(); }).join(); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - // 服务器向客户端发送 setup 消息 - host_rpc::setup_t setup; - setup.shm_name = "test_shared_memory_segment"; - - server.send(client_id, setup); - - // 客户端接收并处理(注意:setup_t 处理器会尝试初始化共享内存) - // 这可能会失败,因为共享内存可能不存在,但我们主要测试消息传输 - std::thread([&client]() { - try { - client.recv(); - } catch (const std::exception&) { - // 预期可能失败,因为共享内存不存在 - } - }).join(); - - SUCCEED(); + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 40002; + client.init(client_id); + + // 建立连接 + test_message_t hello; + hello.id = 1; + hello.content = "Hello"; + client.send(hello); + std::thread([&server]() { server.recv(); }).join(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // 服务器向客户端发送 setup 消息 + host_rpc::setup_t setup; + setup.shm_name = "test_shared_memory_segment"; + + server.send(client_id, setup); + + // 客户端接收并处理(注意:setup_t 处理器会尝试初始化共享内存) + // 这可能会失败,因为共享内存可能不存在,但我们主要测试消息传输 + std::thread([&client]() { + try { + client.recv(); + } + catch (const std::exception&) { + // 预期可能失败,因为共享内存不存在 + } + }).join(); + + SUCCEED(); } // ============================================================================ @@ -388,102 +389,102 @@ TEST_F(ZmqRpcTest, HostRpcSetup) { * @brief 测试向未连接的客户端发送消息 */ TEST_F(ZmqRpcTest, SendToNonExistentClient) { - auto& server = zmq_server::instance(); - server.init(); - - const uint32_t fake_client_id = 99999; - - // 尝试向不存在的客户端发送消息(应该记录错误但不崩溃) - test_response_t response; - response.request_id = 1; - response.success = false; - response.result = "Error"; - - ASSERT_NO_THROW({ - server.send(fake_client_id, response); - }); - - EXPECT_FALSE(server.has_client(fake_client_id)); + auto& server = zmq_server::instance(); + server.init(); + + const uint32_t fake_client_id = 99999; + + // 尝试向不存在的客户端发送消息(应该记录错误但不崩溃) + test_response_t response; + response.request_id = 1; + response.success = false; + response.result = "Error"; + + ASSERT_NO_THROW({ + server.send(fake_client_id, response); + }); + + EXPECT_FALSE(server.has_client(fake_client_id)); } /** * @brief 测试未注册的 func_id 处理 */ TEST_F(ZmqRpcTest, UnregisteredFuncId) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 50001; - client.init(client_id); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - // 手动构造一个具有未注册 func_id 的消息 - struct unknown_message_t { - int data; - }; - - unknown_message_t unknown_msg; - unknown_msg.data = 12345; - - // 发送消息 - client.send(unknown_msg); - - // 服务器接收(应该记录错误但不崩溃) - ASSERT_NO_THROW({ - std::thread([&server]() { - server.recv(); - }).join(); - }); + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 50001; + client.init(client_id); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + // 手动构造一个具有未注册 func_id 的消息 + struct unknown_message_t { + int data; + }; + + unknown_message_t unknown_msg; + unknown_msg.data = 12345; + + // 发送消息 + client.send(unknown_msg); + + // 服务器接收(应该记录错误但不崩溃) + ASSERT_NO_THROW({ + std::thread([&server]() { + server.recv(); + }).join(); + }); } /** * @brief 测试客户端断开连接 */ TEST_F(ZmqRpcTest, ClientDisconnect) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 60001; - client.init(client_id); - - EXPECT_TRUE(client.is_connected()); - - // 断开连接 - client.disconnect(); - - EXPECT_FALSE(client.is_connected()); - EXPECT_EQ(client.get_state(), zmq_client::state::DISCONNECTED); + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 60001; + client.init(client_id); + + EXPECT_TRUE(client.is_connected()); + + // 断开连接 + client.disconnect(); + + EXPECT_FALSE(client.is_connected()); + EXPECT_EQ(client.get_state(), zmq_client::state::DISCONNECTED); } /** * @brief 测试客户端重连 */ TEST_F(ZmqRpcTest, ClientReconnect) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 60002; - - // 初始连接 - client.init(client_id); - EXPECT_TRUE(client.is_connected()); - - // 断开 - client.disconnect(); - EXPECT_FALSE(client.is_connected()); - - // 重连 - bool reconnect_success = client.reconnect(); - - if (reconnect_success) { - EXPECT_TRUE(client.is_connected()); - EXPECT_EQ(client.get_state(), zmq_client::state::CONNECTED); - } - // 如果重连失败也可以接受,取决于具体实现 + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 60002; + + // 初始连接 + client.init(client_id); + EXPECT_TRUE(client.is_connected()); + + // 断开 + client.disconnect(); + EXPECT_FALSE(client.is_connected()); + + // 重连 + bool reconnect_success = client.reconnect(); + + if (reconnect_success) { + EXPECT_TRUE(client.is_connected()); + EXPECT_EQ(client.get_state(), zmq_client::state::CONNECTED); + } + // 如果重连失败也可以接受,取决于具体实现 } // ============================================================================ @@ -494,48 +495,48 @@ TEST_F(ZmqRpcTest, ClientReconnect) { * @brief 测试消息序列化和反序列化 */ TEST_F(ZmqRpcTest, MessageSerialization) { - // 测试 zmq_message_pack 的创建 - test_message_t original_msg; - original_msg.id = 12345; - original_msg.content = "Test serialization"; - - // 序列化 - auto serialized = zmq_message_pack::create(original_msg); - - EXPECT_GT(serialized.size(), 0); - - // 反序列化外层包装 - zmq::message_t zmq_msg(serialized.data(), serialized.size()); - auto pack = zmq_message_pack::deserialize(zmq_msg); - - EXPECT_GT(pack.func_id, 0); - EXPECT_GT(pack.payload.size(), 0); - - // 反序列化内部数据 - auto result = struct_pack::deserialize( - pack.payload.data(), pack.payload.size()); - - ASSERT_TRUE(result.has_value()); - EXPECT_EQ(result.value().id, original_msg.id); - EXPECT_EQ(result.value().content, original_msg.content); + // 测试 zmq_message_pack 的创建 + test_message_t original_msg; + original_msg.id = 12345; + original_msg.content = "Test serialization"; + + // 序列化 + auto serialized = zmq_message_pack::create(original_msg); + + EXPECT_GT(serialized.size(), 0); + + // 反序列化外层包装 + zmq::message_t zmq_msg(serialized.data(), serialized.size()); + auto pack = zmq_message_pack::deserialize(zmq_msg); + + EXPECT_GT(pack.func_id, 0); + EXPECT_GT(pack.payload.size(), 0); + + // 反序列化内部数据 + auto result = struct_pack::deserialize( + pack.payload.data(), pack.payload.size()); + + ASSERT_TRUE(result.has_value()); + EXPECT_EQ(result.value().id, original_msg.id); + EXPECT_EQ(result.value().content, original_msg.content); } /** * @brief 测试不同类型消息的 type_id */ TEST_F(ZmqRpcTest, TypeIdUniqueness) { - // 验证不同类型有不同的 type_id - uint32_t log_msg_id = alicho_type_id_v; - uint32_t setup_id = alicho_type_id_v; - uint32_t test_msg_id = alicho_type_id_v; - - EXPECT_NE(log_msg_id, setup_id); - EXPECT_NE(log_msg_id, test_msg_id); - EXPECT_NE(setup_id, test_msg_id); - - // 同一类型应该有相同的 type_id - uint32_t log_msg_id_2 = alicho_type_id_v; - EXPECT_EQ(log_msg_id, log_msg_id_2); + // 验证不同类型有不同的 type_id + uint32_t log_msg_id = alicho_type_id_v; + uint32_t setup_id = alicho_type_id_v; + uint32_t test_msg_id = alicho_type_id_v; + + EXPECT_NE(log_msg_id, setup_id); + EXPECT_NE(log_msg_id, test_msg_id); + EXPECT_NE(setup_id, test_msg_id); + + // 同一类型应该有相同的 type_id + uint32_t log_msg_id_2 = alicho_type_id_v; + EXPECT_EQ(log_msg_id, log_msg_id_2); } // ============================================================================ @@ -546,162 +547,162 @@ TEST_F(ZmqRpcTest, TypeIdUniqueness) { * @brief 测试消息吞吐量 */ TEST_F(ZmqRpcTest, MessageThroughput) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 70001; - client.init(client_id); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - const int num_messages = 100; - - // [诊断日志] 记录开始时间戳(纳秒精度) - auto start_time = std::chrono::high_resolution_clock::now(); - auto start_ns = std::chrono::time_point_cast(start_time).time_since_epoch().count(); - std::cout << "[诊断] 测试开始时间戳(纳秒): " << start_ns << std::endl; - - // 发送多条消息 - for (int i = 0; i < num_messages; ++i) { - test_message_t msg; - msg.id = i; - msg.content = "Message " + std::to_string(i); - client.send(msg); - } - - // [诊断日志] 发送完成时间 - auto send_complete_time = std::chrono::high_resolution_clock::now(); - auto send_duration_us = std::chrono::duration_cast( - send_complete_time - start_time).count(); - std::cout << "[诊断] 发送 " << num_messages << " 条消息耗时: " - << send_duration_us << " 微秒" << std::endl; - - // 在服务器端接收所有消息 - std::thread server_thread([&server, num_messages]() { - auto thread_start = std::chrono::high_resolution_clock::now(); - for (int i = 0; i < num_messages; ++i) { - server.recv(); - } - auto thread_end = std::chrono::high_resolution_clock::now(); - auto thread_duration_us = std::chrono::duration_cast( - thread_end - thread_start).count(); - std::cout << "[诊断] 服务器线程接收耗时: " << thread_duration_us << " 微秒" << std::endl; - }); - - // 等待所有消息被处理 - auto wait_start = std::chrono::high_resolution_clock::now(); - ASSERT_TRUE(g_server_counter.wait_for(num_messages, - std::chrono::seconds(10))); - auto wait_end = std::chrono::high_resolution_clock::now(); - auto wait_duration_us = std::chrono::duration_cast( - wait_end - wait_start).count(); - std::cout << "[诊断] 等待处理完成耗时: " << wait_duration_us << " 微秒" << std::endl; - - auto end_time = std::chrono::high_resolution_clock::now(); - auto end_ns = std::chrono::time_point_cast(end_time).time_since_epoch().count(); - std::cout << "[诊断] 测试结束时间戳(纳秒): " << end_ns << std::endl; - std::cout << "[诊断] 总时间差(纳秒): " << (end_ns - start_ns) << std::endl; - - // 使用微秒精度计算 - auto duration_us = std::chrono::duration_cast( - end_time - start_time).count(); - auto duration_ms = std::chrono::duration_cast( - end_time - start_time).count(); - - std::cout << "[诊断] 微秒精度耗时: " << duration_us << " 微秒" << std::endl; - std::cout << "[诊断] 毫秒精度耗时: " << duration_ms << " 毫秒(截断)" << std::endl; - - server_thread.join(); - - std::cout << "处理 " << num_messages << " 条消息耗时: " - << duration_ms << " ms" << std::endl; - - // 使用微秒精度计算吞吐量 - if (duration_us > 0) { - std::cout << "平均吞吐量: " - << (num_messages * 1000000.0 / duration_us) << " msg/s" << std::endl; - } else { - std::cout << "平均吞吐量: 耗时太短无法测量" << std::endl; - } - - // 修改断言:使用微秒精度,期望至少大于0微秒 - EXPECT_GT(duration_us, 0); + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 70001; + client.init(client_id); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + const int num_messages = 100; + + // [诊断日志] 记录开始时间戳(纳秒精度) + auto start_time = std::chrono::high_resolution_clock::now(); + auto start_ns = std::chrono::time_point_cast(start_time).time_since_epoch().count(); + std::cout << "[诊断] 测试开始时间戳(纳秒): " << start_ns << std::endl; + + // 发送多条消息 + for (int i = 0; i < num_messages; ++i) { + test_message_t msg; + msg.id = i; + msg.content = "Message " + std::to_string(i); + client.send(msg); + } + + // [诊断日志] 发送完成时间 + auto send_complete_time = std::chrono::high_resolution_clock::now(); + auto send_duration_us = std::chrono::duration_cast( + send_complete_time - start_time).count(); + std::cout << "[诊断] 发送 " << num_messages << " 条消息耗时: " + << send_duration_us << " 微秒" << std::endl; + + // 在服务器端接收所有消息 + std::thread server_thread([&server, num_messages]() { + auto thread_start = std::chrono::high_resolution_clock::now(); + for (int i = 0; i < num_messages; ++i) { + server.recv(); + } + auto thread_end = std::chrono::high_resolution_clock::now(); + auto thread_duration_us = std::chrono::duration_cast( + thread_end - thread_start).count(); + std::cout << "[诊断] 服务器线程接收耗时: " << thread_duration_us << " 微秒" << std::endl; + }); + + // 等待所有消息被处理 + auto wait_start = std::chrono::high_resolution_clock::now(); + ASSERT_TRUE(g_server_counter.wait_for(num_messages, + std::chrono::seconds(10))); + auto wait_end = std::chrono::high_resolution_clock::now(); + auto wait_duration_us = std::chrono::duration_cast( + wait_end - wait_start).count(); + std::cout << "[诊断] 等待处理完成耗时: " << wait_duration_us << " 微秒" << std::endl; + + auto end_time = std::chrono::high_resolution_clock::now(); + auto end_ns = std::chrono::time_point_cast(end_time).time_since_epoch().count(); + std::cout << "[诊断] 测试结束时间戳(纳秒): " << end_ns << std::endl; + std::cout << "[诊断] 总时间差(纳秒): " << (end_ns - start_ns) << std::endl; + + // 使用微秒精度计算 + auto duration_us = std::chrono::duration_cast( + end_time - start_time).count(); + auto duration_ms = std::chrono::duration_cast( + end_time - start_time).count(); + + std::cout << "[诊断] 微秒精度耗时: " << duration_us << " 微秒" << std::endl; + std::cout << "[诊断] 毫秒精度耗时: " << duration_ms << " 毫秒(截断)" << std::endl; + + server_thread.join(); + + std::cout << "处理 " << num_messages << " 条消息耗时: " << duration_ms << " ms" << std::endl; + + // 使用微秒精度计算吞吐量 + if (duration_us > 0) { + std::cout << "平均吞吐量: " + << (num_messages * 1000000.0 / duration_us) << " msg/s" << std::endl; + } + else { + std::cout << "平均吞吐量: 耗时太短无法测量" << std::endl; + } + + // 修改断言:使用微秒精度,期望至少大于0微秒 + EXPECT_GT(duration_us, 0); } /** * @brief 测试往返延迟 (Round-Trip Time) */ TEST_F(ZmqRpcTest, RoundTripLatency) { - auto& server = zmq_server::instance(); - server.init(); - - auto& client = zmq_client::instance(); - const uint32_t client_id = 70002; - client.init(client_id); - - // 建立连接 - test_message_t hello; - hello.id = 0; - hello.content = "Init"; - client.send(hello); - std::thread([&server]() { server.recv(); }).join(); - - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - - const int num_iterations = 10; - std::vector latencies; - - for (int i = 0; i < num_iterations; ++i) { - g_client_counter.reset(); - - auto start = std::chrono::high_resolution_clock::now(); - - // 客户端发送消息 - test_message_t request; - request.id = i; - request.content = "Ping"; - client.send(request); - - // 服务器接收并回复 - std::thread server_thread([&server, client_id, i]() { - server.recv(); - - test_response_t response; - response.request_id = i; - response.success = true; - response.result = "Pong"; - server.send(client_id, response); - }); - - // 客户端接收回复 - std::thread client_thread([&client]() { - client.recv(); - }); - - // 等待响应 - ASSERT_TRUE(g_client_counter.wait_for(1, - std::chrono::seconds(5))); - - auto end = std::chrono::high_resolution_clock::now(); - double latency = std::chrono::duration(end - start).count(); - latencies.push_back(latency); - - server_thread.join(); - client_thread.join(); - } - - // 计算平均延迟 - double total_latency = 0; - for (double lat : latencies) { - total_latency += lat; - } - double avg_latency = total_latency / num_iterations; - - std::cout << "平均往返延迟: " << avg_latency << " ms" << std::endl; - - EXPECT_GT(avg_latency, 0); - EXPECT_LT(avg_latency, 1000); // 应该小于1秒 + auto& server = zmq_server::instance(); + server.init(); + + auto& client = zmq_client::instance(); + const uint32_t client_id = 70002; + client.init(client_id); + + // 建立连接 + test_message_t hello; + hello.id = 0; + hello.content = "Init"; + client.send(hello); + std::thread([&server]() { server.recv(); }).join(); + + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + + const int num_iterations = 10; + std::vector latencies; + + for (int i = 0; i < num_iterations; ++i) { + g_client_counter.reset(); + + auto start = std::chrono::high_resolution_clock::now(); + + // 客户端发送消息 + test_message_t request; + request.id = i; + request.content = "Ping"; + client.send(request); + + // 服务器接收并回复 + std::thread server_thread([&server, client_id, i]() { + server.recv(); + + test_response_t response; + response.request_id = i; + response.success = true; + response.result = "Pong"; + server.send(client_id, response); + }); + + // 客户端接收回复 + std::thread client_thread([&client]() { + client.recv(); + }); + + // 等待响应 + ASSERT_TRUE(g_client_counter.wait_for(1, + std::chrono::seconds(5))); + + auto end = std::chrono::high_resolution_clock::now(); + double latency = std::chrono::duration(end - start).count(); + latencies.push_back(latency); + + server_thread.join(); + client_thread.join(); + } + + // 计算平均延迟 + double total_latency = 0; + for (double lat : latencies) { + total_latency += lat; + } + double avg_latency = total_latency / num_iterations; + + std::cout << "平均往返延迟: " << avg_latency << " ms" << std::endl; + + EXPECT_GT(avg_latency, 0); + EXPECT_LT(avg_latency, 1000); // 应该小于1秒 } // ============================================================================ @@ -709,6 +710,6 @@ TEST_F(ZmqRpcTest, RoundTripLatency) { // ============================================================================ int main(int argc, char** argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); -} \ No newline at end of file + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/tests/shm/helpers/custom_assertions.h b/tests/shm/helpers/custom_assertions.h index d69327d..6364562 100644 --- a/tests/shm/helpers/custom_assertions.h +++ b/tests/shm/helpers/custom_assertions.h @@ -12,158 +12,157 @@ */ namespace shm_test { + // ============================================================================ + // 共享内存操作断言 + // ============================================================================ -// ============================================================================ -// 共享内存操作断言 -// ============================================================================ - -/** - * @brief 验证共享内存操作成功(非致命) - * - * 检查共享内存错误码是否为 SUCCESS - * - * @param expr 返回 shared_memory_error 的表达式 - * - * @example - * EXPECT_SHM_SUCCESS(manager.init(config)); - */ -#define EXPECT_SHM_SUCCESS(expr) \ + /** + * @brief 验证共享内存操作成功(非致命) + * + * 检查共享内存错误码是否为 SUCCESS + * + * @param expr 返回 shared_memory_error 的表达式 + * + * @example + * EXPECT_SHM_SUCCESS(manager.init(config)); + */ + #define EXPECT_SHM_SUCCESS(expr) \ EXPECT_EQ(expr, shared_memory_error::SUCCESS) \ << "共享内存操作失败: " #expr -/** - * @brief 断言共享内存操作成功(致命) - * - * 检查共享内存错误码是否为 SUCCESS,失败时终止测试 - * - * @param expr 返回 shared_memory_error 的表达式 - * - * @example - * ASSERT_SHM_SUCCESS(manager.init(config)); - */ -#define ASSERT_SHM_SUCCESS(expr) \ + /** + * @brief 断言共享内存操作成功(致命) + * + * 检查共享内存错误码是否为 SUCCESS,失败时终止测试 + * + * @param expr 返回 shared_memory_error 的表达式 + * + * @example + * ASSERT_SHM_SUCCESS(manager.init(config)); + */ + #define ASSERT_SHM_SUCCESS(expr) \ ASSERT_EQ(expr, shared_memory_error::SUCCESS) \ << "共享内存操作失败: " #expr -/** - * @brief 验证共享内存操作失败(非致命) - * - * @param expr 返回 shared_memory_error 的表达式 - * @param expected_error 期望的错误码 - */ -#define EXPECT_SHM_ERROR(expr, expected_error) \ + /** + * @brief 验证共享内存操作失败(非致命) + * + * @param expr 返回 shared_memory_error 的表达式 + * @param expected_error 期望的错误码 + */ + #define EXPECT_SHM_ERROR(expr, expected_error) \ EXPECT_EQ(expr, expected_error) \ << "共享内存错误码不匹配: " #expr -/** - * @brief 验证共享内存指针非空(非致命) - * - * @param ptr 共享内存指针 - */ -#define EXPECT_SHM_NOT_NULL(ptr) \ + /** + * @brief 验证共享内存指针非空(非致命) + * + * @param ptr 共享内存指针 + */ + #define EXPECT_SHM_NOT_NULL(ptr) \ EXPECT_NE(ptr, nullptr) \ << "共享内存指针为空: " #ptr -/** - * @brief 断言共享内存指针非空(致命) - * - * @param ptr 共享内存指针 - */ -#define ASSERT_SHM_NOT_NULL(ptr) \ + /** + * @brief 断言共享内存指针非空(致命) + * + * @param ptr 共享内存指针 + */ + #define ASSERT_SHM_NOT_NULL(ptr) \ ASSERT_NE(ptr, nullptr) \ << "共享内存指针为空: " #ptr -// ============================================================================ -// 缓冲区状态断言 -// ============================================================================ + // ============================================================================ + // 缓冲区状态断言 + // ============================================================================ -/** - * @brief 验证缓冲区为空(非致命) - * - * 要求缓冲区对象有 empty() 或 is_empty() 方法 - * - * @param buffer 缓冲区对象 - * - * @example - * EXPECT_BUFFER_EMPTY(ring_buffer); - */ -#define EXPECT_BUFFER_EMPTY(buffer) \ + /** + * @brief 验证缓冲区为空(非致命) + * + * 要求缓冲区对象有 empty() 或 is_empty() 方法 + * + * @param buffer 缓冲区对象 + * + * @example + * EXPECT_BUFFER_EMPTY(ring_buffer); + */ + #define EXPECT_BUFFER_EMPTY(buffer) \ EXPECT_TRUE((buffer).empty()) \ << "缓冲区非空: " #buffer \ << ", 大小: " << (buffer).size() -/** - * @brief 验证缓冲区已满(非致命) - * - * 要求缓冲区对象有 full() 或 is_full() 方法 - * - * @param buffer 缓冲区对象 - * - * @example - * EXPECT_BUFFER_FULL(ring_buffer); - */ -#define EXPECT_BUFFER_FULL(buffer) \ + /** + * @brief 验证缓冲区已满(非致命) + * + * 要求缓冲区对象有 full() 或 is_full() 方法 + * + * @param buffer 缓冲区对象 + * + * @example + * EXPECT_BUFFER_FULL(ring_buffer); + */ + #define EXPECT_BUFFER_FULL(buffer) \ EXPECT_TRUE((buffer).full()) \ << "缓冲区未满: " #buffer \ << ", 大小: " << (buffer).size() \ << ", 容量: " << (buffer).capacity() -/** - * @brief 验证缓冲区大小(非致命) - * - * @param buffer 缓冲区对象 - * @param expected_size 期望的大小 - * - * @example - * EXPECT_BUFFER_SIZE(ring_buffer, 10); - */ -#define EXPECT_BUFFER_SIZE(buffer, expected_size) \ + /** + * @brief 验证缓冲区大小(非致命) + * + * @param buffer 缓冲区对象 + * @param expected_size 期望的大小 + * + * @example + * EXPECT_BUFFER_SIZE(ring_buffer, 10); + */ + #define EXPECT_BUFFER_SIZE(buffer, expected_size) \ EXPECT_EQ(buffer.size(), static_cast(expected_size)) \ << "缓冲区大小不匹配: " #buffer -/** - * @brief 断言缓冲区大小(致命) - * - * @param buffer 缓冲区对象 - * @param expected_size 期望的大小 - */ -#define ASSERT_BUFFER_SIZE(buffer, expected_size) \ + /** + * @brief 断言缓冲区大小(致命) + * + * @param buffer 缓冲区对象 + * @param expected_size 期望的大小 + */ + #define ASSERT_BUFFER_SIZE(buffer, expected_size) \ ASSERT_EQ(buffer.size(), static_cast(expected_size)) \ << "缓冲区大小不匹配: " #buffer -/** - * @brief 验证缓冲区非空(非致命) - * - * @param buffer 缓冲区对象 - */ -#define EXPECT_BUFFER_NOT_EMPTY(buffer) \ + /** + * @brief 验证缓冲区非空(非致命) + * + * @param buffer 缓冲区对象 + */ + #define EXPECT_BUFFER_NOT_EMPTY(buffer) \ EXPECT_FALSE((buffer).empty()) \ << "缓冲区为空: " #buffer -/** - * @brief 验证缓冲区容量 - * - * @param buffer 缓冲区对象 - * @param expected_capacity 期望的容量 - */ -#define EXPECT_BUFFER_CAPACITY(buffer, expected_capacity) \ + /** + * @brief 验证缓冲区容量 + * + * @param buffer 缓冲区对象 + * @param expected_capacity 期望的容量 + */ + #define EXPECT_BUFFER_CAPACITY(buffer, expected_capacity) \ EXPECT_EQ(buffer.capacity(), static_cast(expected_capacity)) \ << "缓冲区容量不匹配: " #buffer -// ============================================================================ -// 性能断言 -// ============================================================================ + // ============================================================================ + // 性能断言 + // ============================================================================ -/** - * @brief 验证吞吐量大于阈值(非致命) - * - * @param meter ThroughputMeter 对象 - * @param threshold 吞吐量阈值(ops/sec) - * - * @example - * EXPECT_THROUGHPUT_GT(meter, 1000.0); // 期望 > 1000 ops/sec - */ -#define EXPECT_THROUGHPUT_GT(meter, threshold) \ + /** + * @brief 验证吞吐量大于阈值(非致命) + * + * @param meter ThroughputMeter 对象 + * @param threshold 吞吐量阈值(ops/sec) + * + * @example + * EXPECT_THROUGHPUT_GT(meter, 1000.0); // 期望 > 1000 ops/sec + */ + #define EXPECT_THROUGHPUT_GT(meter, threshold) \ do { \ double actual_throughput = meter.get_throughput(); \ EXPECT_GT(actual_throughput, threshold) \ @@ -172,13 +171,13 @@ namespace shm_test { << ", 阈值: " << threshold << " ops/sec"; \ } while(0) -/** - * @brief 验证吞吐量大于等于阈值(非致命) - * - * @param meter ThroughputMeter 对象 - * @param threshold 吞吐量阈值(ops/sec) - */ -#define EXPECT_THROUGHPUT_GE(meter, threshold) \ + /** + * @brief 验证吞吐量大于等于阈值(非致命) + * + * @param meter ThroughputMeter 对象 + * @param threshold 吞吐量阈值(ops/sec) + */ + #define EXPECT_THROUGHPUT_GE(meter, threshold) \ do { \ double actual_throughput = meter.get_throughput(); \ EXPECT_GE(actual_throughput, threshold) \ @@ -187,18 +186,18 @@ namespace shm_test { << ", 阈值: " << threshold << " ops/sec"; \ } while(0) -/** - * @brief 验证延迟小于阈值(非致命) - * - * 检查平均延迟是否小于阈值 - * - * @param stats Statistics 对象 - * @param threshold_ns 延迟阈值(纳秒) - * - * @example - * EXPECT_LATENCY_LT(stats, 1000000.0); // 期望 < 1ms - */ -#define EXPECT_LATENCY_LT(stats, threshold_ns) \ + /** + * @brief 验证延迟小于阈值(非致命) + * + * 检查平均延迟是否小于阈值 + * + * @param stats Statistics 对象 + * @param threshold_ns 延迟阈值(纳秒) + * + * @example + * EXPECT_LATENCY_LT(stats, 1000000.0); // 期望 < 1ms + */ + #define EXPECT_LATENCY_LT(stats, threshold_ns) \ do { \ double avg_latency = stats.avg(); \ EXPECT_LT(avg_latency, threshold_ns) \ @@ -207,13 +206,13 @@ namespace shm_test { << ", 阈值: " << (threshold_ns / 1000000.0) << " ms"; \ } while(0) -/** - * @brief 验证延迟小于等于阈值(非致命) - * - * @param stats Statistics 对象 - * @param threshold_ns 延迟阈值(纳秒) - */ -#define EXPECT_LATENCY_LE(stats, threshold_ns) \ + /** + * @brief 验证延迟小于等于阈值(非致命) + * + * @param stats Statistics 对象 + * @param threshold_ns 延迟阈值(纳秒) + */ + #define EXPECT_LATENCY_LE(stats, threshold_ns) \ do { \ double avg_latency = stats.avg(); \ EXPECT_LE(avg_latency, threshold_ns) \ @@ -222,13 +221,13 @@ namespace shm_test { << ", 阈值: " << (threshold_ns / 1000000.0) << " ms"; \ } while(0) -/** - * @brief 验证 P95 延迟小于阈值(非致命) - * - * @param stats Statistics 对象 - * @param threshold_ns 延迟阈值(纳秒) - */ -#define EXPECT_P95_LATENCY_LT(stats, threshold_ns) \ + /** + * @brief 验证 P95 延迟小于阈值(非致命) + * + * @param stats Statistics 对象 + * @param threshold_ns 延迟阈值(纳秒) + */ + #define EXPECT_P95_LATENCY_LT(stats, threshold_ns) \ do { \ double p95_latency = stats.p95(); \ EXPECT_LT(p95_latency, threshold_ns) \ @@ -237,13 +236,13 @@ namespace shm_test { << ", 阈值: " << (threshold_ns / 1000000.0) << " ms"; \ } while(0) -/** - * @brief 验证 P99 延迟小于阈值(非致命) - * - * @param stats Statistics 对象 - * @param threshold_ns 延迟阈值(纳秒) - */ -#define EXPECT_P99_LATENCY_LT(stats, threshold_ns) \ + /** + * @brief 验证 P99 延迟小于阈值(非致命) + * + * @param stats Statistics 对象 + * @param threshold_ns 延迟阈值(纳秒) + */ + #define EXPECT_P99_LATENCY_LT(stats, threshold_ns) \ do { \ double p99_latency = stats.p99(); \ EXPECT_LT(p99_latency, threshold_ns) \ @@ -252,29 +251,29 @@ namespace shm_test { << ", 阈值: " << (threshold_ns / 1000000.0) << " ms"; \ } while(0) -// ============================================================================ -// 数据验证断言 -// ============================================================================ + // ============================================================================ + // 数据验证断言 + // ============================================================================ -/** - * @brief 验证数据序列正确(非致命) - * - * @param generator DataGenerator 对象 - * @param data 数据向量 - * @param start 起始值 - * @param step 步长 - */ -#define EXPECT_SEQUENCE_VALID(generator, data, start, step) \ + /** + * @brief 验证数据序列正确(非致命) + * + * @param generator DataGenerator 对象 + * @param data 数据向量 + * @param start 起始值 + * @param step 步长 + */ + #define EXPECT_SEQUENCE_VALID(generator, data, start, step) \ EXPECT_TRUE(generator.verify_sequence(data, start, step)) \ << "数据序列验证失败: " #data -/** - * @brief 验证校验和正确(非致命) - * - * @param generator DataGenerator 对象 - * @param data 带校验和的数据向量 - */ -#define EXPECT_CHECKSUMS_VALID(generator, data) \ + /** + * @brief 验证校验和正确(非致命) + * + * @param generator DataGenerator 对象 + * @param data 带校验和的数据向量 + */ + #define EXPECT_CHECKSUMS_VALID(generator, data) \ do { \ auto result = generator.verify_checksums(data); \ EXPECT_EQ(result.first, result.second) \ @@ -283,43 +282,43 @@ namespace shm_test { << ", 总数: " << result.second; \ } while(0) -/** - * @brief 验证无重复数据(非致命) - * - * @param generator DataGenerator 对象 - * @param data 数据向量 - */ -#define EXPECT_NO_DUPLICATES(generator, data) \ + /** + * @brief 验证无重复数据(非致命) + * + * @param generator DataGenerator 对象 + * @param data 数据向量 + */ + #define EXPECT_NO_DUPLICATES(generator, data) \ EXPECT_TRUE(generator.verify_no_duplicates(data)) \ << "检测到重复数据: " #data -/** - * @brief 验证两个数据集相同(非致命) - * - * @param generator DataGenerator 对象 - * @param data1 第一个数据集 - * @param data2 第二个数据集 - */ -#define EXPECT_DATA_EQUAL(generator, data1, data2) \ + /** + * @brief 验证两个数据集相同(非致命) + * + * @param generator DataGenerator 对象 + * @param data1 第一个数据集 + * @param data2 第二个数据集 + */ + #define EXPECT_DATA_EQUAL(generator, data1, data2) \ EXPECT_TRUE(generator.compare_data(data1, data2)) \ << "数据集不相等: " #data1 " vs " #data2 -// ============================================================================ -// 时间断言 -// ============================================================================ + // ============================================================================ + // 时间断言 + // ============================================================================ -/** - * @brief 验证操作在指定时间内完成(非致命) - * - * @param operation 要执行的操作(lambda 或函数) - * @param timeout_ms 超时时间(毫秒) - * - * @example - * EXPECT_COMPLETES_WITHIN([&]() { - * buffer.push(data); - * }, 100); // 期望在 100ms 内完成 - */ -#define EXPECT_COMPLETES_WITHIN(operation, timeout_ms) \ + /** + * @brief 验证操作在指定时间内完成(非致命) + * + * @param operation 要执行的操作(lambda 或函数) + * @param timeout_ms 超时时间(毫秒) + * + * @example + * EXPECT_COMPLETES_WITHIN([&]() { + * buffer.push(data); + * }, 100); // 期望在 100ms 内完成 + */ + #define EXPECT_COMPLETES_WITHIN(operation, timeout_ms) \ do { \ shm_test::ScopedTimer timer("", false); \ operation; \ @@ -331,15 +330,15 @@ namespace shm_test { << ", 限制: " << timeout_ms << " ms"; \ } while(0) -/** - * @brief 验证操作至少花费指定时间(非致命) - * - * 用于验证阻塞操作确实发生了阻塞 - * - * @param operation 要执行的操作 - * @param min_time_ms 最小时间(毫秒) - */ -#define EXPECT_TAKES_AT_LEAST(operation, min_time_ms) \ + /** + * @brief 验证操作至少花费指定时间(非致命) + * + * 用于验证阻塞操作确实发生了阻塞 + * + * @param operation 要执行的操作 + * @param min_time_ms 最小时间(毫秒) + */ + #define EXPECT_TAKES_AT_LEAST(operation, min_time_ms) \ do { \ shm_test::ScopedTimer timer("", false); \ operation; \ @@ -351,18 +350,18 @@ namespace shm_test { << ", 最小: " << min_time_ms << " ms"; \ } while(0) -// ============================================================================ -// 条件等待断言 -// ============================================================================ + // ============================================================================ + // 条件等待断言 + // ============================================================================ -/** - * @brief 等待条件满足(带超时) - * - * @param condition 条件表达式(返回 bool) - * @param timeout_ms 超时时间(毫秒) - * @param check_interval_ms 检查间隔(毫秒) - */ -#define EXPECT_CONDITION_EVENTUALLY(condition, timeout_ms, check_interval_ms) \ + /** + * @brief 等待条件满足(带超时) + * + * @param condition 条件表达式(返回 bool) + * @param timeout_ms 超时时间(毫秒) + * @param check_interval_ms 检查间隔(毫秒) + */ + #define EXPECT_CONDITION_EVENTUALLY(condition, timeout_ms, check_interval_ms) \ do { \ auto start = std::chrono::steady_clock::now(); \ bool condition_met = false; \ @@ -378,7 +377,6 @@ namespace shm_test { << "条件超时未满足: " #condition \ << ", 超时: " << timeout_ms << " ms"; \ } while(0) - } // namespace shm_test // ============================================================================ @@ -388,4 +386,4 @@ namespace shm_test { // 将常用的断言宏导出到全局命名空间(可选) // 如果不想污染全局命名空间,可以注释掉这部分 -using shm_test::TimeUnit; \ No newline at end of file +using shm_test::TimeUnit; diff --git a/tests/shm/helpers/data_generator.h b/tests/shm/helpers/data_generator.h index 3854ff4..6cdd7bd 100644 --- a/tests/shm/helpers/data_generator.h +++ b/tests/shm/helpers/data_generator.h @@ -11,386 +11,385 @@ #include "logger.h" namespace shm_test { + /** + * @brief 带校验和的数据结构 + * @tparam T 数据类型 + */ + template + struct ChecksumData { + T value; ///< 实际数据值 + uint64_t checksum; ///< 校验和(简单求和或 hash) + uint64_t sequence; ///< 序列号(用于验证顺序) -/** - * @brief 带校验和的数据结构 - * @tparam T 数据类型 - */ -template -struct ChecksumData { - T value; ///< 实际数据值 - uint64_t checksum; ///< 校验和(简单求和或 hash) - uint64_t sequence; ///< 序列号(用于验证顺序) + /** + * @brief 计算校验和 + * @note 使用简单的 FNV-1a hash 算法 + */ + void calculate_checksum() { + // FNV-1a hash + constexpr uint64_t FNV_OFFSET = 14695981039346656037ULL; + constexpr uint64_t FNV_PRIME = 1099511628211ULL; - /** - * @brief 计算校验和 - * @note 使用简单的 FNV-1a hash 算法 - */ - void calculate_checksum() { - // FNV-1a hash - constexpr uint64_t FNV_OFFSET = 14695981039346656037ULL; - constexpr uint64_t FNV_PRIME = 1099511628211ULL; - - checksum = FNV_OFFSET; - const uint8_t* bytes = reinterpret_cast(&value); - for (size_t i = 0; i < sizeof(T); ++i) { - checksum ^= bytes[i]; - checksum *= FNV_PRIME; - } - - // 混入序列号 - const uint8_t* seq_bytes = reinterpret_cast(&sequence); - for (size_t i = 0; i < sizeof(sequence); ++i) { - checksum ^= seq_bytes[i]; - checksum *= FNV_PRIME; - } - } + checksum = FNV_OFFSET; + const uint8_t* bytes = reinterpret_cast(&value); + for (size_t i = 0; i < sizeof(T); ++i) { + checksum ^= bytes[i]; + checksum *= FNV_PRIME; + } - /** - * @brief 验证校验和是否正确 - */ - bool verify_checksum() const { - ChecksumData temp = *this; - temp.calculate_checksum(); - return temp.checksum == checksum; - } -}; + // 混入序列号 + const uint8_t* seq_bytes = reinterpret_cast(&sequence); + for (size_t i = 0; i < sizeof(sequence); ++i) { + checksum ^= seq_bytes[i]; + checksum *= FNV_PRIME; + } + } -/** - * @brief 数据生成器模板类 - * - * 提供各种数据生成和验证功能: - * - 生成顺序数据(用于验证FIFO顺序) - * - 生成随机数据 - * - 生成带校验和的数据 - * - 验证数据顺序和完整性 - * - * @tparam T 数据类型 - * - * @example - * DataGenerator gen; - * auto seq_data = gen.generate_sequence(100); // 生成 0-99 - * auto random_data = gen.generate_random(100); // 生成 100 个随机数 - * auto checksum_data = gen.generate_with_checksum(100); // 带校验和的数据 - */ -template -class DataGenerator { -public: - /** - * @brief 构造函数 - * @param seed 随机数种子(默认使用随机设备) - */ - explicit DataGenerator(uint32_t seed = std::random_device{}()) - : gen_(seed) { - log_module_debug("DataGenerator", "创建数据生成器,种子: {}", seed); - } + /** + * @brief 验证校验和是否正确 + */ + bool verify_checksum() const { + ChecksumData temp = *this; + temp.calculate_checksum(); + return temp.checksum == checksum; + } + }; - /** - * @brief 生成顺序数据 - * @param count 数据数量 - * @param start_value 起始值 - * @param step 步长 - * @return 顺序数据向量 - */ - std::vector generate_sequence( - size_t count, - T start_value = T{0}, - T step = T{1} - ) { - std::vector data(count); - T current = start_value; - - for (size_t i = 0; i < count; ++i) { - data[i] = current; - current = static_cast(current + step); - } - - log_module_debug("DataGenerator", "生成顺序数据: {} 个元素", count); - return data; - } + /** + * @brief 数据生成器模板类 + * + * 提供各种数据生成和验证功能: + * - 生成顺序数据(用于验证FIFO顺序) + * - 生成随机数据 + * - 生成带校验和的数据 + * - 验证数据顺序和完整性 + * + * @tparam T 数据类型 + * + * @example + * DataGenerator gen; + * auto seq_data = gen.generate_sequence(100); // 生成 0-99 + * auto random_data = gen.generate_random(100); // 生成 100 个随机数 + * auto checksum_data = gen.generate_with_checksum(100); // 带校验和的数据 + */ + template + class DataGenerator { + public: + /** + * @brief 构造函数 + * @param seed 随机数种子(默认使用随机设备) + */ + explicit DataGenerator(uint32_t seed = std::random_device{}()) : gen_(seed) { + log_module_debug("DataGenerator", "创建数据生成器,种子: {}", seed); + } - /** - * @brief 生成随机数据 - * @param count 数据数量 - * @param min_value 最小值 - * @param max_value 最大值 - * @return 随机数据向量 - */ - std::vector generate_random( - size_t count, - T min_value = std::numeric_limits::min(), - T max_value = std::numeric_limits::max() - ) { - std::vector data(count); - - if constexpr (std::is_integral_v) { - std::uniform_int_distribution dist(min_value, max_value); - for (size_t i = 0; i < count; ++i) { - data[i] = dist(gen_); - } - } else if constexpr (std::is_floating_point_v) { - std::uniform_real_distribution dist(min_value, max_value); - for (size_t i = 0; i < count; ++i) { - data[i] = dist(gen_); - } - } else { - // 对于其他类型,使用默认构造 - for (size_t i = 0; i < count; ++i) { - data[i] = T{}; - } - } - - log_module_debug("DataGenerator", "生成随机数据: {} 个元素", count); - return data; - } + /** + * @brief 生成顺序数据 + * @param count 数据数量 + * @param start_value 起始值 + * @param step 步长 + * @return 顺序数据向量 + */ + std::vector generate_sequence( + size_t count, + T start_value = T{0}, + T step = T{1} + ) { + std::vector data(count); + T current = start_value; - /** - * @brief 生成带校验和的数据 - * @param count 数据数量 - * @param start_sequence 起始序列号 - * @return 带校验和的数据向量 - */ - std::vector> generate_with_checksum( - size_t count, - uint64_t start_sequence = 0 - ) { - std::vector> data(count); - auto values = generate_sequence(count); - - for (size_t i = 0; i < count; ++i) { - data[i].value = values[i]; - data[i].sequence = start_sequence + i; - data[i].calculate_checksum(); - } - - log_module_debug("DataGenerator", "生成带校验和数据: {} 个元素", count); - return data; - } + for (size_t i = 0; i < count; ++i) { + data[i] = current; + current = static_cast(current + step); + } - /** - * @brief 生成带校验和的随机数据 - * @param count 数据数量 - * @param start_sequence 起始序列号 - * @param min_value 最小值 - * @param max_value 最大值 - * @return 带校验和的随机数据向量 - */ - std::vector> generate_random_with_checksum( - size_t count, - uint64_t start_sequence = 0, - T min_value = std::numeric_limits::min(), - T max_value = std::numeric_limits::max() - ) { - std::vector> data(count); - auto values = generate_random(count, min_value, max_value); - - for (size_t i = 0; i < count; ++i) { - data[i].value = values[i]; - data[i].sequence = start_sequence + i; - data[i].calculate_checksum(); - } - - log_module_debug("DataGenerator", "生成带校验和随机数据: {} 个元素", count); - return data; - } + log_module_debug("DataGenerator", "生成顺序数据: {} 个元素", count); + return data; + } - /** - * @brief 验证数据顺序完整性 - * @param data 待验证的数据 - * @param expected_start 期望的起始值 - * @param expected_step 期望的步长 - * @return true 如果数据顺序正确 - */ - bool verify_sequence( - const std::vector& data, - T expected_start = T{0}, - T expected_step = T{1} - ) { - if (data.empty()) { - log_module_warn("DataGenerator", "验证空数据序列"); - return true; - } + /** + * @brief 生成随机数据 + * @param count 数据数量 + * @param min_value 最小值 + * @param max_value 最大值 + * @return 随机数据向量 + */ + std::vector generate_random( + size_t count, + T min_value = std::numeric_limits::min(), + T max_value = std::numeric_limits::max() + ) { + std::vector data(count); - T expected = expected_start; - for (size_t i = 0; i < data.size(); ++i) { - if (data[i] != expected) { - log_module_error("DataGenerator", - "序列验证失败: 索引 {} 期望 {} 实际 {}", - i, expected, data[i]); - return false; - } - expected = static_cast(expected + expected_step); - } - - log_module_debug("DataGenerator", "序列验证成功: {} 个元素", data.size()); - return true; - } + if constexpr (std::is_integral_v) { + std::uniform_int_distribution dist(min_value, max_value); + for (size_t i = 0; i < count; ++i) { + data[i] = dist(gen_); + } + } + else if constexpr (std::is_floating_point_v) { + std::uniform_real_distribution dist(min_value, max_value); + for (size_t i = 0; i < count; ++i) { + data[i] = dist(gen_); + } + } + else { + // 对于其他类型,使用默认构造 + for (size_t i = 0; i < count; ++i) { + data[i] = T{}; + } + } - /** - * @brief 验证带校验和数据的完整性 - * @param data 待验证的数据 - * @param verify_sequence_order 是否验证序列顺序 - * @param expected_start_seq 期望的起始序列号 - * @return 验证结果(成功数量,总数量) - */ - std::pair verify_checksums( - const std::vector>& data, - bool verify_sequence_order = true, - uint64_t expected_start_seq = 0 - ) { - if (data.empty()) { - log_module_warn("DataGenerator", "验证空校验和数据"); - return {0, 0}; - } + log_module_debug("DataGenerator", "生成随机数据: {} 个元素", count); + return data; + } - size_t success_count = 0; - uint64_t expected_seq = expected_start_seq; + /** + * @brief 生成带校验和的数据 + * @param count 数据数量 + * @param start_sequence 起始序列号 + * @return 带校验和的数据向量 + */ + std::vector> generate_with_checksum( + size_t count, + uint64_t start_sequence = 0 + ) { + std::vector> data(count); + auto values = generate_sequence(count); - for (size_t i = 0; i < data.size(); ++i) { - bool valid = true; + for (size_t i = 0; i < count; ++i) { + data[i].value = values[i]; + data[i].sequence = start_sequence + i; + data[i].calculate_checksum(); + } - // 验证校验和 - if (!data[i].verify_checksum()) { - log_module_error("DataGenerator", - "校验和验证失败: 索引 {} 序列号 {}", - i, data[i].sequence); - valid = false; - } + log_module_debug("DataGenerator", "生成带校验和数据: {} 个元素", count); + return data; + } - // 验证序列顺序 - if (verify_sequence_order && data[i].sequence != expected_seq) { - log_module_error("DataGenerator", - "序列号验证失败: 索引 {} 期望 {} 实际 {}", - i, expected_seq, data[i].sequence); - valid = false; - } + /** + * @brief 生成带校验和的随机数据 + * @param count 数据数量 + * @param start_sequence 起始序列号 + * @param min_value 最小值 + * @param max_value 最大值 + * @return 带校验和的随机数据向量 + */ + std::vector> generate_random_with_checksum( + size_t count, + uint64_t start_sequence = 0, + T min_value = std::numeric_limits::min(), + T max_value = std::numeric_limits::max() + ) { + std::vector> data(count); + auto values = generate_random(count, min_value, max_value); - if (valid) { - ++success_count; - } + for (size_t i = 0; i < count; ++i) { + data[i].value = values[i]; + data[i].sequence = start_sequence + i; + data[i].calculate_checksum(); + } - ++expected_seq; - } + log_module_debug("DataGenerator", "生成带校验和随机数据: {} 个元素", count); + return data; + } - log_module_debug("DataGenerator", - "校验和验证完成: {}/{} 成功", - success_count, data.size()); + /** + * @brief 验证数据顺序完整性 + * @param data 待验证的数据 + * @param expected_start 期望的起始值 + * @param expected_step 期望的步长 + * @return true 如果数据顺序正确 + */ + bool verify_sequence( + const std::vector& data, + T expected_start = T{0}, + T expected_step = T{1} + ) { + if (data.empty()) { + log_module_warn("DataGenerator", "验证空数据序列"); + return true; + } - return {success_count, data.size()}; - } + T expected = expected_start; + for (size_t i = 0; i < data.size(); ++i) { + if (data[i] != expected) { + log_module_error("DataGenerator", + "序列验证失败: 索引 {} 期望 {} 实际 {}", + i, expected, data[i]); + return false; + } + expected = static_cast(expected + expected_step); + } - /** - * @brief 验证数据是否包含重复项 - * @param data 待验证的数据 - * @return true 如果没有重复项 - */ - bool verify_no_duplicates(const std::vector& data) { - if (data.size() <= 1) { - return true; - } + log_module_debug("DataGenerator", "序列验证成功: {} 个元素", data.size()); + return true; + } - std::vector sorted = data; - std::sort(sorted.begin(), sorted.end()); - - auto it = std::adjacent_find(sorted.begin(), sorted.end()); - if (it != sorted.end()) { - log_module_error("DataGenerator", "发现重复数据值"); - return false; - } + /** + * @brief 验证带校验和数据的完整性 + * @param data 待验证的数据 + * @param verify_sequence_order 是否验证序列顺序 + * @param expected_start_seq 期望的起始序列号 + * @return 验证结果(成功数量,总数量) + */ + std::pair verify_checksums( + const std::vector>& data, + bool verify_sequence_order = true, + uint64_t expected_start_seq = 0 + ) { + if (data.empty()) { + log_module_warn("DataGenerator", "验证空校验和数据"); + return {0, 0}; + } - log_module_debug("DataGenerator", "无重复数据验证成功"); - return true; - } + size_t success_count = 0; + uint64_t expected_seq = expected_start_seq; - /** - * @brief 创建具有特定模式的数据 - * @param count 数据数量 - * @param pattern_func 模式函数(索引 -> 值) - * @return 按模式生成的数据向量 - */ - std::vector generate_pattern( - size_t count, - std::function pattern_func - ) { - std::vector data(count); - for (size_t i = 0; i < count; ++i) { - data[i] = pattern_func(i); - } - - log_module_debug("DataGenerator", "生成模式数据: {} 个元素", count); - return data; - } + for (size_t i = 0; i < data.size(); ++i) { + bool valid = true; - /** - * @brief 填充数据到缓冲区 - * @param buffer 目标缓冲区 - * @param data 源数据 - * @param offset 起始偏移 - */ - template - void fill_buffer(Buffer& buffer, const std::vector& data, size_t offset = 0) { - for (size_t i = 0; i < data.size(); ++i) { - buffer[offset + i] = data[i]; - } - } + // 验证校验和 + if (!data[i].verify_checksum()) { + log_module_error("DataGenerator", + "校验和验证失败: 索引 {} 序列号 {}", + i, data[i].sequence); + valid = false; + } - /** - * @brief 比较两个数据集 - * @param data1 第一个数据集 - * @param data2 第二个数据集 - * @return true 如果数据相同 - */ - bool compare_data( - const std::vector& data1, - const std::vector& data2 - ) { - if (data1.size() != data2.size()) { - log_module_error("DataGenerator", - "数据大小不匹配: {} vs {}", - data1.size(), data2.size()); - return false; - } + // 验证序列顺序 + if (verify_sequence_order && data[i].sequence != expected_seq) { + log_module_error("DataGenerator", + "序列号验证失败: 索引 {} 期望 {} 实际 {}", + i, expected_seq, data[i].sequence); + valid = false; + } - for (size_t i = 0; i < data1.size(); ++i) { - if (data1[i] != data2[i]) { - log_module_error("DataGenerator", - "数据不匹配: 索引 {} 值 {} vs {}", - i, data1[i], data2[i]); - return false; - } - } + if (valid) { + ++success_count; + } - log_module_debug("DataGenerator", "数据比较成功: {} 个元素", data1.size()); - return true; - } + ++expected_seq; + } -private: - std::mt19937 gen_; ///< 随机数生成器 -}; + log_module_debug("DataGenerator", + "校验和验证完成: {}/{} 成功", + success_count, data.size()); -/** - * @brief 特化的字节数据生成器 - */ -using ByteDataGenerator = DataGenerator; + return {success_count, data.size()}; + } -/** - * @brief 特化的整数数据生成器 - */ -using IntDataGenerator = DataGenerator; + /** + * @brief 验证数据是否包含重复项 + * @param data 待验证的数据 + * @return true 如果没有重复项 + */ + bool verify_no_duplicates(const std::vector& data) { + if (data.size() <= 1) { + return true; + } -/** - * @brief 特化的长整数数据生成器 - */ -using LongDataGenerator = DataGenerator; + std::vector sorted = data; + std::sort(sorted.begin(), sorted.end()); -/** - * @brief 特化的浮点数据生成器 - */ -using FloatDataGenerator = DataGenerator; + auto it = std::adjacent_find(sorted.begin(), sorted.end()); + if (it != sorted.end()) { + log_module_error("DataGenerator", "发现重复数据值"); + return false; + } -/** - * @brief 特化的双精度浮点数据生成器 - */ -using DoubleDataGenerator = DataGenerator; + log_module_debug("DataGenerator", "无重复数据验证成功"); + return true; + } -} // namespace shm_test \ No newline at end of file + /** + * @brief 创建具有特定模式的数据 + * @param count 数据数量 + * @param pattern_func 模式函数(索引 -> 值) + * @return 按模式生成的数据向量 + */ + std::vector generate_pattern( + size_t count, + std::function pattern_func + ) { + std::vector data(count); + for (size_t i = 0; i < count; ++i) { + data[i] = pattern_func(i); + } + + log_module_debug("DataGenerator", "生成模式数据: {} 个元素", count); + return data; + } + + /** + * @brief 填充数据到缓冲区 + * @param buffer 目标缓冲区 + * @param data 源数据 + * @param offset 起始偏移 + */ + template + void fill_buffer(Buffer& buffer, const std::vector& data, size_t offset = 0) { + for (size_t i = 0; i < data.size(); ++i) { + buffer[offset + i] = data[i]; + } + } + + /** + * @brief 比较两个数据集 + * @param data1 第一个数据集 + * @param data2 第二个数据集 + * @return true 如果数据相同 + */ + bool compare_data( + const std::vector& data1, + const std::vector& data2 + ) { + if (data1.size() != data2.size()) { + log_module_error("DataGenerator", + "数据大小不匹配: {} vs {}", + data1.size(), data2.size()); + return false; + } + + for (size_t i = 0; i < data1.size(); ++i) { + if (data1[i] != data2[i]) { + log_module_error("DataGenerator", + "数据不匹配: 索引 {} 值 {} vs {}", + i, data1[i], data2[i]); + return false; + } + } + + log_module_debug("DataGenerator", "数据比较成功: {} 个元素", data1.size()); + return true; + } + + private: + std::mt19937 gen_; ///< 随机数生成器 + }; + + /** + * @brief 特化的字节数据生成器 + */ + using ByteDataGenerator = DataGenerator; + + /** + * @brief 特化的整数数据生成器 + */ + using IntDataGenerator = DataGenerator; + + /** + * @brief 特化的长整数数据生成器 + */ + using LongDataGenerator = DataGenerator; + + /** + * @brief 特化的浮点数据生成器 + */ + using FloatDataGenerator = DataGenerator; + + /** + * @brief 特化的双精度浮点数据生成器 + */ + using DoubleDataGenerator = DataGenerator; +} // namespace shm_test diff --git a/tests/shm/helpers/leak_detector.h b/tests/shm/helpers/leak_detector.h index 2c33a93..6ee2643 100644 --- a/tests/shm/helpers/leak_detector.h +++ b/tests/shm/helpers/leak_detector.h @@ -10,397 +10,398 @@ #include "logger.h" namespace shm_test { + /** + * @brief 共享内存段信息 + */ + struct SharedMemorySegmentInfo { + std::string name; ///< 段名称 + size_t size{0}; ///< 段大小(如果可获取) + bool is_test_segment{false}; ///< 是否是测试相关的段 -/** - * @brief 共享内存段信息 - */ -struct SharedMemorySegmentInfo { - std::string name; ///< 段名称 - size_t size{0}; ///< 段大小(如果可获取) - bool is_test_segment{false}; ///< 是否是测试相关的段 - - /** - * @brief 判断是否是测试段(基于名称前缀) - */ - static bool is_test_segment_name(const std::string& name) { - // 测试段通常以 "test_" 开头 - return name.find("test_") == 0; - } -}; + /** + * @brief 判断是否是测试段(基于名称前缀) + */ + static bool is_test_segment_name(const std::string& name) { + // 测试段通常以 "test_" 开头 + return name.find("test_") == 0; + } + }; -/** - * @brief 共享内存泄漏检测器 - * - * 功能: - * - 记录测试前的共享内存段快照 - * - 测试后检查新增的共享内存段(可能的泄漏) - * - 提供泄漏报告 - * - 强制清理测试相关的共享内存段 - * - * 使用方法: - * @code - * TEST(ShmTest, NoLeak) { - * ShmLeakDetector detector; - * detector.snapshot_before(); - * - * // 执行测试... - * - * EXPECT_FALSE(detector.check_leaks()); - * - * // 或在测试结束时: - * detector.cleanup_test_segments(); - * } - * @endcode - */ -class ShmLeakDetector { -public: - /** - * @brief 构造函数 - * @param auto_cleanup 析构时是否自动清理测试段 - */ - explicit ShmLeakDetector(bool auto_cleanup = false) - : auto_cleanup_(auto_cleanup) { - log_module_debug("ShmLeakDetector", "创建泄漏检测器"); - } + /** + * @brief 共享内存泄漏检测器 + * + * 功能: + * - 记录测试前的共享内存段快照 + * - 测试后检查新增的共享内存段(可能的泄漏) + * - 提供泄漏报告 + * - 强制清理测试相关的共享内存段 + * + * 使用方法: + * @code + * TEST(ShmTest, NoLeak) { + * ShmLeakDetector detector; + * detector.snapshot_before(); + * + * // 执行测试... + * + * EXPECT_FALSE(detector.check_leaks()); + * + * // 或在测试结束时: + * detector.cleanup_test_segments(); + * } + * @endcode + */ + class ShmLeakDetector { + public: + /** + * @brief 构造函数 + * @param auto_cleanup 析构时是否自动清理测试段 + */ + explicit ShmLeakDetector(bool auto_cleanup = false) : auto_cleanup_(auto_cleanup) { + log_module_debug("ShmLeakDetector", "创建泄漏检测器"); + } - /** - * @brief 析构函数 - */ - ~ShmLeakDetector() { - if (auto_cleanup_) { - cleanup_test_segments(); - } - } + /** + * @brief 析构函数 + */ + ~ShmLeakDetector() { + if (auto_cleanup_) { + cleanup_test_segments(); + } + } - // 禁止拷贝 - ShmLeakDetector(const ShmLeakDetector&) = delete; - ShmLeakDetector& operator=(const ShmLeakDetector&) = delete; + // 禁止拷贝 + ShmLeakDetector(const ShmLeakDetector&) = delete; + ShmLeakDetector& operator=(const ShmLeakDetector&) = delete; - // 允许移动 - ShmLeakDetector(ShmLeakDetector&&) = default; - ShmLeakDetector& operator=(ShmLeakDetector&&) = default; + // 允许移动 + ShmLeakDetector(ShmLeakDetector&&) = default; + ShmLeakDetector& operator=(ShmLeakDetector&&) = default; - /** - * @brief 记录当前共享内存段快照 - * - * 应在测试开始前调用,记录已存在的共享内存段 - */ - void snapshot_before() { - before_snapshot_ = get_current_segments(); - log_module_info("ShmLeakDetector", - "记录测试前快照: {} 个共享内存段", - before_snapshot_.size()); - - // 调试输出 - if (!before_snapshot_.empty()) { - log_module_debug("ShmLeakDetector", "测试前存在的段:"); - for (const auto& seg : before_snapshot_) { - log_module_debug("ShmLeakDetector", " - {}", seg.name); - } - } - } + /** + * @brief 记录当前共享内存段快照 + * + * 应在测试开始前调用,记录已存在的共享内存段 + */ + void snapshot_before() { + before_snapshot_ = get_current_segments(); + log_module_info("ShmLeakDetector", + "记录测试前快照: {} 个共享内存段", + before_snapshot_.size()); - /** - * @brief 检查是否有泄漏 - * - * 比较当前共享内存段与快照,查找新增的段 - * - * @return true 如果检测到泄漏 - */ - bool check_leaks() { - auto current_segments = get_current_segments(); - - // 找出新增的段(可能的泄漏) - leaked_segments_.clear(); - - for (const auto& current : current_segments) { - bool found_in_before = false; - for (const auto& before : before_snapshot_) { - if (current.name == before.name) { - found_in_before = true; - break; - } - } - - if (!found_in_before) { - leaked_segments_.push_back(current); - } - } - - bool has_leaks = !leaked_segments_.empty(); - - if (has_leaks) { - log_module_warn("ShmLeakDetector", - "检测到 {} 个可能的泄漏段", - leaked_segments_.size()); - } else { - log_module_info("ShmLeakDetector", "未检测到泄漏"); - } - - return has_leaks; - } + // 调试输出 + if (!before_snapshot_.empty()) { + log_module_debug("ShmLeakDetector", "测试前存在的段:"); + for (const auto& seg : before_snapshot_) { + log_module_debug("ShmLeakDetector", " - {}", seg.name); + } + } + } - /** - * @brief 打印泄漏报告 - * - * 输出详细的泄漏信息 - */ - void print_leak_report() const { - if (leaked_segments_.empty()) { - log_info("=== 共享内存泄漏报告 ==="); - log_info("未检测到泄漏"); - return; - } + /** + * @brief 检查是否有泄漏 + * + * 比较当前共享内存段与快照,查找新增的段 + * + * @return true 如果检测到泄漏 + */ + bool check_leaks() { + auto current_segments = get_current_segments(); - log_warn("=== 共享内存泄漏报告 ==="); - log_warn("检测到 {} 个可能的泄漏段:", leaked_segments_.size()); - - for (size_t i = 0; i < leaked_segments_.size(); ++i) { - const auto& seg = leaked_segments_[i]; - log_warn(" [{}] 名称: {}", i + 1, seg.name); - if (seg.size > 0) { - log_warn(" 大小: {} 字节", seg.size); - } - log_warn(" 测试段: {}", seg.is_test_segment ? "是" : "否"); - } - - log_warn("建议: 在测试结束时调用 cleanup_test_segments() 清理"); - } + // 找出新增的段(可能的泄漏) + leaked_segments_.clear(); - /** - * @brief 获取泄漏的段列表 - */ - const std::vector& get_leaked_segments() const { - return leaked_segments_; - } + for (const auto& current : current_segments) { + bool found_in_before = false; + for (const auto& before : before_snapshot_) { + if (current.name == before.name) { + found_in_before = true; + break; + } + } - /** - * @brief 强制清理所有测试相关的共享内存段 - * - * 删除所有以 "test_" 开头的共享内存段 - * - * @return 清理的段数量 - */ - size_t cleanup_test_segments() { - auto current_segments = get_current_segments(); - size_t cleaned_count = 0; + if (!found_in_before) { + leaked_segments_.push_back(current); + } + } - log_module_info("ShmLeakDetector", - "开始清理测试相关的共享内存段..."); + bool has_leaks = !leaked_segments_.empty(); - for (const auto& seg : current_segments) { - if (seg.is_test_segment) { - if (remove_segment(seg.name)) { - ++cleaned_count; - log_module_debug("ShmLeakDetector", - "已清理: {}", seg.name); - } else { - log_module_warn("ShmLeakDetector", - "清理失败: {}", seg.name); - } - } - } + if (has_leaks) { + log_module_warn("ShmLeakDetector", + "检测到 {} 个可能的泄漏段", + leaked_segments_.size()); + } + else { + log_module_info("ShmLeakDetector", "未检测到泄漏"); + } - log_module_info("ShmLeakDetector", - "清理完成: {} 个段", cleaned_count); + return has_leaks; + } - return cleaned_count; - } + /** + * @brief 打印泄漏报告 + * + * 输出详细的泄漏信息 + */ + void print_leak_report() const { + if (leaked_segments_.empty()) { + log_info("=== 共享内存泄漏报告 ==="); + log_info("未检测到泄漏"); + return; + } - /** - * @brief 清理指定名称的共享内存段 - * - * @param segment_name 段名称 - * @return true 如果成功清理 - */ - bool cleanup_segment(const std::string& segment_name) { - if (remove_segment(segment_name)) { - log_module_info("ShmLeakDetector", - "已清理共享内存段: {}", segment_name); - return true; - } else { - log_module_warn("ShmLeakDetector", - "清理共享内存段失败: {}", segment_name); - return false; - } - } + log_warn("=== 共享内存泄漏报告 ==="); + log_warn("检测到 {} 个可能的泄漏段:", leaked_segments_.size()); - /** - * @brief 获取当前所有共享内存段的数量 - */ - size_t get_current_segment_count() const { - return get_current_segments().size(); - } + for (size_t i = 0; i < leaked_segments_.size(); ++i) { + const auto& seg = leaked_segments_[i]; + log_warn(" [{}] 名称: {}", i + 1, seg.name); + if (seg.size > 0) { + log_warn(" 大小: {} 字节", seg.size); + } + log_warn(" 测试段: {}", seg.is_test_segment ? "是" : "否"); + } - /** - * @brief 重置检测器 - */ - void reset() { - before_snapshot_.clear(); - leaked_segments_.clear(); - log_module_debug("ShmLeakDetector", "检测器已重置"); - } + log_warn("建议: 在测试结束时调用 cleanup_test_segments() 清理"); + } -private: - /** - * @brief 获取当前系统中的所有共享内存段 - * - * 注意:Boost.Interprocess 没有直接提供枚举所有共享内存段的 API - * 这个实现可能需要平台特定的方法 - * - * @return 共享内存段信息列表 - */ - std::vector get_current_segments() const { - std::vector segments; + /** + * @brief 获取泄漏的段列表 + */ + const std::vector& get_leaked_segments() const { + return leaked_segments_; + } - // TODO: 平台特定的实现 - // - // Windows: - // - 使用 Windows API 枚举全局命名对象 - // - 或解析系统对象目录 - // - // Linux: - // - 读取 /dev/shm/ 目录(POSIX 共享内存) - // - 或使用 ipcs 命令输出 - // - // macOS: - // - 类似 Linux,但路径可能不同 - // - // 当前简化实现: - // 由于 Boost.Interprocess 的限制,我们无法直接枚举所有段 - // 在实际使用中,需要: - // 1. 使用平台特定 API - // 2. 或者在测试框架中手动跟踪所有创建的段 + /** + * @brief 强制清理所有测试相关的共享内存段 + * + * 删除所有以 "test_" 开头的共享内存段 + * + * @return 清理的段数量 + */ + size_t cleanup_test_segments() { + auto current_segments = get_current_segments(); + size_t cleaned_count = 0; -#ifdef _WIN32 - // Windows: 需要使用 WinAPI - log_module_warn("ShmLeakDetector", - "Windows 平台的共享内存枚举尚未实现"); -#elif defined(__linux__) - // Linux: 读取 /dev/shm - segments = enumerate_linux_shm(); -#elif defined(__APPLE__) - // macOS: 类似 Linux - log_module_warn("ShmLeakDetector", - "macOS 平台的共享内存枚举尚未实现"); -#else - log_module_warn("ShmLeakDetector", - "当前平台的共享内存枚举尚未实现"); -#endif + log_module_info("ShmLeakDetector", + "开始清理测试相关的共享内存段..."); - return segments; - } + for (const auto& seg : current_segments) { + if (seg.is_test_segment) { + if (remove_segment(seg.name)) { + ++cleaned_count; + log_module_debug("ShmLeakDetector", + "已清理: {}", seg.name); + } + else { + log_module_warn("ShmLeakDetector", + "清理失败: {}", seg.name); + } + } + } -#ifdef __linux__ - /** - * @brief Linux 平台:枚举 /dev/shm 中的共享内存段 - */ - std::vector enumerate_linux_shm() const { - std::vector segments; - - try { - // /dev/shm 是 POSIX 共享内存的挂载点 - const std::string shm_dir = "/dev/shm/"; - - // 注意:这需要文件系统支持 - // 可以使用 std::filesystem 或 boost::filesystem - - // 简化实现:由于我们在测试环境中,假设主要通过 - // ShmTestEnvironment 来跟踪段,这里返回空列表 - // 实际应用中可以使用 filesystem API - - log_module_debug("ShmLeakDetector", - "Linux /dev/shm 枚举功能需要 filesystem 支持"); - - } catch (const std::exception& e) { - log_module_error("ShmLeakDetector", - "枚举 Linux 共享内存段失败: {}", e.what()); - } - - return segments; - } -#endif + log_module_info("ShmLeakDetector", + "清理完成: {} 个段", cleaned_count); - /** - * @brief 删除指定的共享内存段 - * - * @param name 段名称 - * @return true 如果成功删除 - */ - bool remove_segment(const std::string& name) const { - try { - using namespace boost::interprocess; - return shared_memory_object::remove(name.c_str()); - } catch (const std::exception& e) { - log_module_error("ShmLeakDetector", - "删除共享内存段失败 {}: {}", name, e.what()); - return false; - } - } + return cleaned_count; + } - bool auto_cleanup_; ///< 是否自动清理 - std::vector before_snapshot_; ///< 测试前快照 - std::vector leaked_segments_; ///< 泄漏的段 -}; + /** + * @brief 清理指定名称的共享内存段 + * + * @param segment_name 段名称 + * @return true 如果成功清理 + */ + bool cleanup_segment(const std::string& segment_name) { + if (remove_segment(segment_name)) { + log_module_info("ShmLeakDetector", + "已清理共享内存段: {}", segment_name); + return true; + } + else { + log_module_warn("ShmLeakDetector", + "清理共享内存段失败: {}", segment_name); + return false; + } + } -/** - * @brief RAII 风格的泄漏检测器包装 - * - * 自动在构造时记录快照,析构时检查泄漏并清理 - * - * @example - * TEST(ShmTest, AutoLeak) { - * ScopedLeakDetector detector; - * // 执行测试... - * // 析构时自动检查并清理 - * } - */ -class ScopedLeakDetector { -public: - /** - * @brief 构造函数 - 记录快照 - * @param auto_cleanup 是否自动清理 - * @param fail_on_leak 如果检测到泄漏是否断言失败 - */ - explicit ScopedLeakDetector( - bool auto_cleanup = true, - bool fail_on_leak = true - ) : detector_(auto_cleanup), fail_on_leak_(fail_on_leak) { - detector_.snapshot_before(); - } + /** + * @brief 获取当前所有共享内存段的数量 + */ + size_t get_current_segment_count() const { + return get_current_segments().size(); + } - /** - * @brief 析构函数 - 检查泄漏 - */ - ~ScopedLeakDetector() { - bool has_leaks = detector_.check_leaks(); - - if (has_leaks) { - detector_.print_leak_report(); - - if (fail_on_leak_) { - // 使用 Google Test 的断言 - ADD_FAILURE() << "检测到共享内存泄漏!"; - } - } - } + /** + * @brief 重置检测器 + */ + void reset() { + before_snapshot_.clear(); + leaked_segments_.clear(); + log_module_debug("ShmLeakDetector", "检测器已重置"); + } - // 禁止拷贝和移动 - ScopedLeakDetector(const ScopedLeakDetector&) = delete; - ScopedLeakDetector& operator=(const ScopedLeakDetector&) = delete; - ScopedLeakDetector(ScopedLeakDetector&&) = delete; - ScopedLeakDetector& operator=(ScopedLeakDetector&&) = delete; + private: + /** + * @brief 获取当前系统中的所有共享内存段 + * + * 注意:Boost.Interprocess 没有直接提供枚举所有共享内存段的 API + * 这个实现可能需要平台特定的方法 + * + * @return 共享内存段信息列表 + */ + std::vector get_current_segments() const { + std::vector segments; - /** - * @brief 获取内部检测器 - */ - ShmLeakDetector& get_detector() { - return detector_; - } + // TODO: 平台特定的实现 + // + // Windows: + // - 使用 Windows API 枚举全局命名对象 + // - 或解析系统对象目录 + // + // Linux: + // - 读取 /dev/shm/ 目录(POSIX 共享内存) + // - 或使用 ipcs 命令输出 + // + // macOS: + // - 类似 Linux,但路径可能不同 + // + // 当前简化实现: + // 由于 Boost.Interprocess 的限制,我们无法直接枚举所有段 + // 在实际使用中,需要: + // 1. 使用平台特定 API + // 2. 或者在测试框架中手动跟踪所有创建的段 -private: - ShmLeakDetector detector_; - bool fail_on_leak_; -}; + #ifdef _WIN32 + // Windows: 需要使用 WinAPI + log_module_warn("ShmLeakDetector", + "Windows 平台的共享内存枚举尚未实现"); + #elif defined(__linux__) + // Linux: 读取 /dev/shm + segments = enumerate_linux_shm(); + #elif defined(__APPLE__) + // macOS: 类似 Linux + log_module_warn("ShmLeakDetector", + "macOS 平台的共享内存枚举尚未实现"); + #else + log_module_warn("ShmLeakDetector", + "当前平台的共享内存枚举尚未实现"); + #endif -} // namespace shm_test \ No newline at end of file + return segments; + } + + #ifdef __linux__ + /** + * @brief Linux 平台:枚举 /dev/shm 中的共享内存段 + */ + std::vector enumerate_linux_shm() const { + std::vector segments; + + try { + // /dev/shm 是 POSIX 共享内存的挂载点 + const std::string shm_dir = "/dev/shm/"; + + // 注意:这需要文件系统支持 + // 可以使用 std::filesystem 或 boost::filesystem + + // 简化实现:由于我们在测试环境中,假设主要通过 + // ShmTestEnvironment 来跟踪段,这里返回空列表 + // 实际应用中可以使用 filesystem API + + log_module_debug("ShmLeakDetector", + "Linux /dev/shm 枚举功能需要 filesystem 支持"); + } + catch (const std::exception& e) { + log_module_error("ShmLeakDetector", + "枚举 Linux 共享内存段失败: {}", e.what()); + } + + return segments; + } + #endif + + /** + * @brief 删除指定的共享内存段 + * + * @param name 段名称 + * @return true 如果成功删除 + */ + bool remove_segment(const std::string& name) const { + try { + using namespace boost::interprocess; + return shared_memory_object::remove(name.c_str()); + } + catch (const std::exception& e) { + log_module_error("ShmLeakDetector", + "删除共享内存段失败 {}: {}", name, e.what()); + return false; + } + } + + bool auto_cleanup_; ///< 是否自动清理 + std::vector before_snapshot_; ///< 测试前快照 + std::vector leaked_segments_; ///< 泄漏的段 + }; + + /** + * @brief RAII 风格的泄漏检测器包装 + * + * 自动在构造时记录快照,析构时检查泄漏并清理 + * + * @example + * TEST(ShmTest, AutoLeak) { + * ScopedLeakDetector detector; + * // 执行测试... + * // 析构时自动检查并清理 + * } + */ + class ScopedLeakDetector { + public: + /** + * @brief 构造函数 - 记录快照 + * @param auto_cleanup 是否自动清理 + * @param fail_on_leak 如果检测到泄漏是否断言失败 + */ + explicit ScopedLeakDetector( + bool auto_cleanup = true, + bool fail_on_leak = true + ) : detector_(auto_cleanup), fail_on_leak_(fail_on_leak) { + detector_.snapshot_before(); + } + + /** + * @brief 析构函数 - 检查泄漏 + */ + ~ScopedLeakDetector() { + bool has_leaks = detector_.check_leaks(); + + if (has_leaks) { + detector_.print_leak_report(); + + if (fail_on_leak_) { + // 使用 Google Test 的断言 + ADD_FAILURE() << "检测到共享内存泄漏!"; + } + } + } + + // 禁止拷贝和移动 + ScopedLeakDetector(const ScopedLeakDetector&) = delete; + ScopedLeakDetector& operator=(const ScopedLeakDetector&) = delete; + ScopedLeakDetector(ScopedLeakDetector&&) = delete; + ScopedLeakDetector& operator=(ScopedLeakDetector&&) = delete; + + /** + * @brief 获取内部检测器 + */ + ShmLeakDetector& get_detector() { + return detector_; + } + + private: + ShmLeakDetector detector_; + bool fail_on_leak_; + }; +} // namespace shm_test diff --git a/tests/shm/helpers/multiprocess_harness.h b/tests/shm/helpers/multiprocess_harness.h index 1fbc953..527a193 100644 --- a/tests/shm/helpers/multiprocess_harness.h +++ b/tests/shm/helpers/multiprocess_harness.h @@ -11,345 +11,342 @@ // 平台检测 #ifdef _WIN32 - #define SHM_TEST_WINDOWS +#define SHM_TEST_WINDOWS #elif defined(__unix__) || defined(__APPLE__) - #define SHM_TEST_UNIX +#define SHM_TEST_UNIX #endif namespace shm_test { + /** + * @brief 进程配置结构 + * + * 定义子进程的启动参数和运行配置 + */ + struct ProcessConfig { + std::string process_name; ///< 进程名称(用于日志) + std::function entry_point; ///< 进程入口函数(返回退出码) + std::chrono::seconds timeout{30}; ///< 进程超时时间 + bool capture_output{true}; ///< 是否捕获进程输出 -/** - * @brief 进程配置结构 - * - * 定义子进程的启动参数和运行配置 - */ -struct ProcessConfig { - std::string process_name; ///< 进程名称(用于日志) - std::function entry_point; ///< 进程入口函数(返回退出码) - std::chrono::seconds timeout{30}; ///< 进程超时时间 - bool capture_output{true}; ///< 是否捕获进程输出 - - /** - * @brief 构造函数 - * @param name 进程名称 - * @param func 入口函数 - */ - ProcessConfig(std::string name, std::function func) - : process_name(std::move(name)), entry_point(std::move(func)) {} -}; + /** + * @brief 构造函数 + * @param name 进程名称 + * @param func 入口函数 + */ + ProcessConfig(std::string name, std::function func) : process_name(std::move(name)), + entry_point(std::move(func)) { + } + }; -/** - * @brief 进程执行结果 - */ -struct ProcessResult { - int exit_code{-1}; ///< 退出码 - bool success{false}; ///< 是否成功(exit_code == 0) - bool timeout{false}; ///< 是否超时 - std::string output; ///< 进程输出(如果启用捕获) - std::string error_message; ///< 错误消息 - double elapsed_seconds{0.0}; ///< 运行时间(秒) -}; + /** + * @brief 进程执行结果 + */ + struct ProcessResult { + int exit_code{-1}; ///< 退出码 + bool success{false}; ///< 是否成功(exit_code == 0) + bool timeout{false}; ///< 是否超时 + std::string output; ///< 进程输出(如果启用捕获) + std::string error_message; ///< 错误消息 + double elapsed_seconds{0.0}; ///< 运行时间(秒) + }; -/** - * @brief 多进程测试框架类 - * - * 提供跨平台的多进程测试支持。 - * - * TODO: 完整实现待后续完成 - * - * 平台差异: - * - Windows: 使用 CreateProcess API 或 std::system - * - Unix/Linux: 使用 fork() + exec() 或 posix_spawn - * - macOS: 类似 Unix,但需要注意全限制 - * - * 当前限制: - * - 接口已定义,具体实现留待后续 - * - 跨平台实现较复杂,需要处理进程间通信、同步等问题 - * - 考虑使用 Boost.Process 库简化实现 - * - * @example - * // 示例用法(实现后): - * MultiProcessTestHarness harness; - * - * ProcessConfig writer("writer", []() { - * // 写入进程逻辑 - * return 0; - * }); - * - * ProcessConfig reader("reader", []() { - * // 读取进程逻辑 - * return 0; - * }); - * - * auto results = harness.spawn_processes({writer, reader}); - * harness.wait_all(results); - * - * for (const auto& result : results) { - * EXPECT_TRUE(result.success); - * } - */ -class MultiProcessTestHarness { -public: - /** - * @brief 构造函数 - */ - MultiProcessTestHarness() { - log_module_warn("MultiProcessTestHarness", - "多进测试框架尚未完全实现,当前仅为接口声明"); - } + /** + * @brief 多进程测试框架类 + * + * 提供跨平台的多进程测试支持。 + * + * TODO: 完整实现待后续完成 + * + * 平台差异: + * - Windows: 使用 CreateProcess API 或 std::system + * - Unix/Linux: 使用 fork() + exec() 或 posix_spawn + * - macOS: 类似 Unix,但需要注意全限制 + * + * 当前限制: + * - 接口已定义,具体实现留待后续 + * - 跨平台实现较复杂,需要处理进程间通信、同步等问题 + * - 考虑使用 Boost.Process 库简化实现 + * + * @example + * // 示例用法(实现后): + * MultiProcessTestHarness harness; + * + * ProcessConfig writer("writer", []() { + * // 写入进程逻辑 + * return 0; + * }); + * + * ProcessConfig reader("reader", []() { + * // 读取进程逻辑 + * return 0; + * }); + * + * auto results = harness.spawn_processes({writer, reader}); + * harness.wait_all(results); + * + * for (const auto& result : results) { + * EXPECT_TRUE(result.success); + * } + */ + class MultiProcessTestHarness { + public: + /** + * @brief 构造函数 + */ + MultiProcessTestHarness() { + log_module_warn("MultiProcessTestHarness", + "多进测试框架尚未完全实现,当前仅为接口声明"); + } - /** - * @brief 析构函数 - 确保所有子进程被清理 - */ - ~MultiProcessTestHarness() { - cleanup(); - } + /** + * @brief 析构函数 - 确保所有子进程被清理 + */ + ~MultiProcessTestHarness() { + cleanup(); + } - // 禁止拷贝和移动 - MultiProcessTestHarness(const MultiProcessTestHarness&) = delete; - MultiProcessTestHarness& operator=(const MultiProcessTestHarness&) = delete; - MultiProcessTestHarness(MultiProcessTestHarness&&) = delete; - MultiProcessTestHarness& operator=(MultiProcessTestHarness&&) = delete; + // 禁止拷贝和移动 + MultiProcessTestHarness(const MultiProcessTestHarness&) = delete; + MultiProcessTestHarness& operator=(const MultiProcessTestHarness&) = delete; + MultiProcessTestHarness(MultiProcessTestHarness&&) = delete; + MultiProcessTestHarness& operator=(MultiProcessTestHarness&&) = delete; - /** - * @brief 启动多个子进程 - * - * TODO: 实现细节 - * - Windows: 使用 CreateProcess 或序列化函数对象 - * - Unix: 使用 fork() + 函数调用 - * - 需要考虑进程间如何传递 std::function - * - * @param configs 进程配置列表 - * @return 进程结果列表(异步) - */ - std::vector spawn_processes( - const std::vector& configs - ) { - log_module_error("MultiProcessTestHarness", - "spawn_processes 尚未实现"); - - // TODO: 实现 - // 1. 遍历 configs - // 2. 为每个配置启动子进程 - // 3. 返回进程句柄/PID 的包装 - - std::vector results; - for (const auto& config : configs) { - ProcessResult result; - result.exit_code = -1; - result.success = false; - result.error_message = "未实现"; - results.push_back(result); - } - - return results; - } + /** + * @brief 启动多个子进程 + * + * TODO: 实现细节 + * - Windows: 使用 CreateProcess 或序列化函数对象 + * - Unix: 使用 fork() + 函数调用 + * - 需要考虑进程间如何传递 std::function + * + * @param configs 进程配置列表 + * @return 进程结果列表(异步) + */ + std::vector spawn_processes( + const std::vector& configs + ) { + log_module_error("MultiProcessTestHarness", + "spawn_processes 尚未实现"); - /** - * @brief 等待所有进程完成 - * - * TODO: 实现细节 - * - 等待所有子进程退出 - * - 收集退出码和输出 - * - 处理超时情况 - * - * @param results 进程结果引用(会被更新) - * @return true 如果所有进程成功完成 - */ - bool wait_all(std::vector& results) { - log_module_error("MultiProcessTestHarness", - "wait_all 尚未实现"); - - // TODO: 实现 - // 1. 等待所有进程 - // 2. 更新 results - // 3. 检查超时 - - return false; - } + // TODO: 实现 + // 1. 遍历 configs + // 2. 为每个配置启动子进程 + // 3. 返回进程句柄/PID 的包装 - /** - * @brief 终止所有子进程 - * - * TODO: 实现细节 - * - Windows: TerminateProcess - * - Unix: kill(pid, SIGTERM) 或 SIGKILL - */ - void terminate_all() { - log_module_error("MultiProcessTestHarness", - "terminate_all 尚未实现"); - - // TODO: 实现 - } + std::vector results; + for (const auto& config : configs) { + ProcessResult result; + result.exit_code = -1; + result.success = false; + result.error_message = "未实现"; + results.push_back(result); + } - /** - * @brief 清理资源 - */ - void cleanup() { - // TODO: 实现 - // 确保所有子进程被终止和清理 - } + return results; + } -private: - // TODO: 添加成员变量存储进程句柄/PID -#ifdef SHM_TEST_WINDOWS - // Windows 特定成员 - // std::vector process_handles_; -#elif defined(SHM_TEST_UNIX) - // Unix 特定成员 - // std::vector process_pids_; -#endif -}; + /** + * @brief 等待所有进程完成 + * + * TODO: 实现细节 + * - 等待所有子进程退出 + * - 收集退出码和输出 + * - 处理超时情况 + * + * @param results 进程结果引用(会被更新) + * @return true 如果所有进程成功完成 + */ + bool wait_all(std::vector& results) { + log_module_error("MultiProcessTestHarness", + "wait_all 尚未实现"); -/** - * @brief 进程间同步屏障 - * - * 用于多进程测试中的同步点,确保所有进程到达某个点后再继续。 - * - * TODO: 完整实现待后续完成 - * - * 实现方案: - * - 使用 Boost.Interprocess 的 named_semaphore 或 named_condition - * - 或使用共享内存 + 原子计数器 - * - * @example - * // 示例用法(实现后): - * Barrier barrier("test_barrier", 3); // 3 个进程 - * - * // 在每个进程中: - * barrier.wait(); // 阻塞直到所有 3 个进程都到达 - */ -class Barrier { -public: - /** - * @brief 构造函数 - * @param name 屏障名称(用于跨进程识别) - * @param count 需要等待的进程数量 - */ - Barrier(std::string name, size_t count) - : name_(std::move(name)), expected_count_(count) { - - log_module_warn("Barrier", - "进程屏障尚未完实现:{}, 预期进程数: {}", - name_, expected_count_); - - // TODO: 实现 - // 1. 创建命名的同步对象(semaphore/condition) - // 2. 初始化共享计数器 - } + // TODO: 实现 + // 1. 等待所有进程 + // 2. 更新 results + // 3. 检查超时 - /** - * @brief 析构函数 - 清理同步对象 - */ - ~Barrier() { - cleanup(); - } + return false; + } - // 禁止拷贝和移动 - Barrier(const Barrier&) = delete; - Barrier& operator=(const Barrier&) = delete; - Barrier(Barrier&&) = delete; - Barrier& operator=(Barrier&&) = delete; + /** + * @brief 终止所有子进程 + * + * TODO: 实现细节 + * - Windows: TerminateProcess + * - Unix: kill(pid, SIGTERM) 或 SIGKILL + */ + void terminate_all() { + log_module_error("MultiProcessTestHarness", + "terminate_all 尚未实现"); - /** - * @brief 等待所有进程到达屏障 - * - * TODO: 实现细节 - * - 原子递增到达计数 - * - 如果未达到预期数量,阻塞等待 - * - 最后一个到达的进程唤醒所有等待者 - * - * @param timeout_ms 超时时间(毫秒),0 表示无限等待 - * @return true 如果所有进程都到达,false 如果超时 - */ - bool wait(uint32_t timeout_ms = 0) { - log_module_error("Barrier", - "wait 尚未实现:{}", name_); - - // TODO: 实现 - // 1. 原子递增计数器 - // 2. 检查是否达到 expected_count_ - // 3. 如果未达到,等待信号 - // 4. 如果达到,发送信号给所有等待者 - - return false; - } + // TODO: 实现 + } - /** - * @brief 重置屏障(用于重复使用) - * - * TODO: 实现 - */ - void reset() { - log_module_error("Barrier", - "reset 尚未实现:{}", name_); - - // TODO: 实现 - // 重置计数器为 0 - } + /** + * @brief 清理资源 + */ + void cleanup() { + // TODO: 实现 + // 确保所有子进程被终止和清理 + } - /** - * @brief 获取当到达的进程数 - * - * TODO: 实现 - */ - size_t get_current_count() const { - // TODO: 实现 - return 0; - } + private: + // TODO: 添加成员变量存储进程句柄/PID + #ifdef SHM_TEST_WINDOWS + // Windows 特定成员 + // std::vector process_handles_; + #elif defined(SHM_TEST_UNIX) + // Unix 特定成员 + // std::vector process_pids_; + #endif + }; -private: - /** - * @brief 清理同步对象 - */ - void cleanup() { - // TODO: 实现 - // 删除命名的同步对象 - } + /** + * @brief 进程间同步屏障 + * + * 用于多进程测试中的同步点,确保所有进程到达某个点后再继续。 + * + * TODO: 完整实现待后续完成 + * + * 实现方案: + * - 使用 Boost.Interprocess 的 named_semaphore 或 named_condition + * - 或使用共享内存 + 原子计数器 + * + * @example + * // 示例用法(实现后): + * Barrier barrier("test_barrier", 3); // 3 个进程 + * + * // 在每个进程中: + * barrier.wait(); // 阻塞直到所有 3 个进程都到达 + */ + class Barrier { + public: + /** + * @brief 构造函数 + * @param name 屏障名称(用于跨进程识别) + * @param count 需要等待的进程数量 + */ + Barrier(std::string name, size_t count) : name_(std::move(name)), expected_count_(count) { + log_module_warn("Barrier", + "进程屏障尚未完实现:{}, 预期进程数: {}", + name_, expected_count_); - std::string name_; ///< 屏障名称 - size_t expected_count_; ///< 预期进程数量 - - // TODO: 添加同步对象成员 - // - Boost::interprocess::named_semaphore - // - 或共享内存 + 原子计数器 -}; + // TODO: 实现 + // 1. 创建命名的同步对象(semaphore/condition) + // 2. 初始化共享计数器 + } -/** - * @brief 辅助函数:简化单进程测试 - * - * TODO: 实现 - * - * @param name 进程名称 - * @param func 进程函数 - * @param timeout 超时时间 - * @return 进程结果 - */ -inline ProcessResult run_in_process( - const std::string& name, - std::function func, - std::chrono::seconds timeout = std::chrono::seconds(30) -) { - log_module_error("run_in_process", "尚未实现"); - - ProcessResult result; - result.exit_code = -1; - result.success = false; - result.error_message = "未实现"; - - // TODO: 实现 - // MultiProcessTestHarness harness; - // ProcessConfig config(name, func); - // config.timeout = timeout; - // auto results = harness.spawn_processes({config}); - // harness.wait_all(results); - // return results[0]; - - return result; -} + /** + * @brief 析构函数 - 清理同步对象 + */ + ~Barrier() { + cleanup(); + } + // 禁止拷贝和移动 + Barrier(const Barrier&) = delete; + Barrier& operator=(const Barrier&) = delete; + Barrier(Barrier&&) = delete; + Barrier& operator=(Barrier&&) = delete; + + /** + * @brief 等待所有进程到达屏障 + * + * TODO: 实现细节 + * - 原子递增到达计数 + * - 如果未达到预期数量,阻塞等待 + * - 最后一个到达的进程唤醒所有等待者 + * + * @param timeout_ms 超时时间(毫秒),0 表示无限等待 + * @return true 如果所有进程都到达,false 如果超时 + */ + bool wait(uint32_t timeout_ms = 0) { + log_module_error("Barrier", + "wait 尚未实现:{}", name_); + + // TODO: 实现 + // 1. 原子递增计数器 + // 2. 检查是否达到 expected_count_ + // 3. 如果未达到,等待信号 + // 4. 如果达到,发送信号给所有等待者 + + return false; + } + + /** + * @brief 重置屏障(用于重复使用) + * + * TODO: 实现 + */ + void reset() { + log_module_error("Barrier", + "reset 尚未实现:{}", name_); + + // TODO: 实现 + // 重置计数器为 0 + } + + /** + * @brief 获取当到达的进程数 + * + * TODO: 实现 + */ + size_t get_current_count() const { + // TODO: 实现 + return 0; + } + + private: + /** + * @brief 清理同步对象 + */ + void cleanup() { + // TODO: 实现 + // 删除命名的同步对象 + } + + std::string name_; ///< 屏障名称 + size_t expected_count_; ///< 预期进程数量 + + // TODO: 添加同步对象成员 + // - Boost::interprocess::named_semaphore + // - 或共享内存 + 原子计数器 + }; + + /** + * @brief 辅助函数:简化单进程测试 + * + * TODO: 实现 + * + * @param name 进程名称 + * @param func 进程函数 + * @param timeout 超时时间 + * @return 进程结果 + */ + inline ProcessResult run_in_process( + const std::string& name, + std::function func, + std::chrono::seconds timeout = std::chrono::seconds(30) + ) { + log_module_error("run_in_process", "尚未实现"); + + ProcessResult result; + result.exit_code = -1; + result.success = false; + result.error_message = "未实现"; + + // TODO: 实现 + // MultiProcessTestHarness harness; + // ProcessConfig config(name, func); + // config.timeout = timeout; + // auto results = harness.spawn_processes({config}); + // harness.wait_all(results); + // return results[0]; + + return result; + } } // namespace shm_test // 清理平台宏 #undef SHM_TEST_WINDOWS -#undef SHM_TEST_UNIX \ No newline at end of file +#undef SHM_TEST_UNIX diff --git a/tests/shm/helpers/performance_timer.h b/tests/shm/helpers/performance_timer.h index 67d62ae..c43b5de 100644 --- a/tests/shm/helpers/performance_timer.h +++ b/tests/shm/helpers/performance_timer.h @@ -12,509 +12,513 @@ #include "logger.h" namespace shm_test { + /** + * @brief 时间单位枚举 + */ + enum class TimeUnit { + NANOSECONDS, + MICROSECONDS, + MILLISECONDS, + SECONDS + }; -/** - * @brief 时间单位枚举 - */ -enum class TimeUnit { - NANOSECONDS, - MICROSECONDS, - MILLISECONDS, - SECONDS -}; + /** + * @brief 性能统计数据 + * + * 包含延迟分布的各种统计指标 + */ + class Statistics { + public: + Statistics() = default; -/** - * @brief 性能统计数据 - * - * 包含延迟分布的各种统计指标 - */ -class Statistics { -public: - Statistics() = default; + /** + * @brief 从延迟样本构造统计数据 + * @param latencies 延迟样本(纳秒) + */ + explicit Statistics(std::vector latencies) : latencies_(std::move(latencies)) { + calculate(); + } - /** - * @brief 从延迟样本构造统计数据 - * @param latencies 延迟样本(纳秒) - */ - explicit Statistics(std::vector latencies) - : latencies_(std::move(latencies)) { - calculate(); - } + /** + * @brief 添加延迟样本 + * @param latency_ns 延迟(纳秒) + */ + void add_sample(double latency_ns) { + latencies_.push_back(latency_ns); + } - /** - * @brief 添加延迟样本 - * @param latency_ns 延迟(纳秒) - */ - void add_sample(double latency_ns) { - latencies_.push_back(latency_ns); - } + /** + * @brief 计算所有统计指标 + */ + void calculate() { + if (latencies_.empty()) { + min_ = max_ = avg_ = p50_ = p95_ = p99_ = 0.0; + stddev_ = 0.0; + return; + } - /** - * @brief 计算所有统计指标 - */ - void calculate() { - if (latencies_.empty()) { - min_ = max_ = avg_ = p50_ = p95_ = p99_ = 0.0; - stddev_ = 0.0; - return; - } + // 排序以计算百分位数 + std::sort(latencies_.begin(), latencies_.end()); - // 排序以计算百分位数 - std::sort(latencies_.begin(), latencies_.end()); + // 最小值和最大值 + min_ = latencies_.front(); + max_ = latencies_.back(); - // 最小值和最大值 - min_ = latencies_.front(); - max_ = latencies_.back(); + // 平均值 + avg_ = std::accumulate(latencies_.begin(), latencies_.end(), 0.0) + / latencies_.size(); - // 平均值 - avg_ = std::accumulate(latencies_.begin(), latencies_.end(), 0.0) - / latencies_.size(); + // 标准差 + double sq_sum = std::accumulate(latencies_.begin(), latencies_.end(), 0.0, + [this](double sum, double val) { + double diff = val - avg_; + return sum + diff * diff; + }); + stddev_ = std::sqrt(sq_sum / latencies_.size()); - // 标准差 - double sq_sum = std::accumulate(latencies_.begin(), latencies_.end(), 0.0, - [this](double sum, double val) { - double diff = val - avg_; - return sum + diff * diff; - }); - stddev_ = std::sqrt(sq_sum / latencies_.size()); + // 百分位数 + p50_ = percentile(50.0); + p95_ = percentile(95.0); + p99_ = percentile(99.0); + } - // 百分位数 - p50_ = percentile(50.0); - p95_ = percentile(95.0); - p99_ = percentile(99.0); - } + /** + * @brief 获取指定百分位数的值 + * @param p 百分比 (0-100) + * @return 百分位数值(纳秒) + */ + double percentile(double p) const { + if (latencies_.empty()) { + return 0.0; + } - /** - * @brief 获取指定百分位数的值 - * @param p 百分比 (0-100) - * @return 百分位数值(纳秒) - */ - double percentile(double p) const { - if (latencies_.empty()) { - return 0.0; - } + if (p <= 0.0) + return latencies_.front(); + if (p >= 100.0) + return latencies_.back(); - if (p <= 0.0) return latencies_.front(); - if (p >= 100.0) return latencies_.back(); + double index = (p / 100.0) * (latencies_.size() - 1); + size_t lower = static_cast(std::floor(index)); + size_t upper = static_cast(std::ceil(index)); - double index = (p / 100.0) * (latencies_.size() - 1); - size_t lower = static_cast(std::floor(index)); - size_t upper = static_cast(std::ceil(index)); + if (lower == upper) { + return latencies_[lower]; + } - if (lower == upper) { - return latencies_[lower]; - } + double weight = index - lower; + return latencies_[lower] * (1.0 - weight) + latencies_[upper] * weight; + } - double weight = index - lower; - return latencies_[lower] * (1.0 - weight) + latencies_[upper] * weight; - } + /** + * @brief 转换时间单位 + * @param value_ns 纳秒值 + * @param unit 目标单位 + * @return 转换后的值 + */ + static double convert_time(double value_ns, TimeUnit unit) { + switch (unit) { + case TimeUnit::NANOSECONDS: + return value_ns; + case TimeUnit::MICROSECONDS: + return value_ns / 1000.0; + case TimeUnit::MILLISECONDS: + return value_ns / 1000000.0; + case TimeUnit::SECONDS: + return value_ns / 1000000000.0; + default: + return value_ns; + } + } - /** - * @brief 转换时间单位 - * @param value_ns 纳秒值 - * @param unit 目标单位 - * @return 转换后的值 - */ - static double convert_time(double value_ns, TimeUnit unit) { - switch (unit) { - case TimeUnit::NANOSECONDS: - return value_ns; - case TimeUnit::MICROSECONDS: - return value_ns / 1000.0; - case TimeUnit::MILLISECONDS: - return value_ns / 1000000.0; - case TimeUnit::SECONDS: - return value_ns / 1000000000.0; - default: - return value_ns; - } - } + /** + * @brief 获取时间单位的字符串表示 + */ + static const char* time_unit_string(TimeUnit unit) { + switch (unit) { + case TimeUnit::NANOSECONDS: + return "ns"; + case TimeUnit::MICROSECONDS: + return "μs"; + case TimeUnit::MILLISECONDS: + return "ms"; + case TimeUnit::SECONDS: + return "s"; + default: + return "ns"; + } + } - /** - * @brief 获取时间单位的字符串表示 - */ - static const char* time_unit_string(TimeUnit unit) { - switch (unit) { - case TimeUnit::NANOSECONDS: return "ns"; - case TimeUnit::MICROSECONDS: return "μs"; - case TimeUnit::MILLISECONDS: return "ms"; - case TimeUnit::SECONDS: return "s"; - default: return "ns"; - } - } + /** + * @brief 打印统计报告 + * @param unit 时间单位 + * @param label 标签(可选) + */ + void print_report(TimeUnit unit = TimeUnit::MICROSECONDS, + const std::string& label = "") const { + if (!label.empty()) { + log_info("=== {} 性能统计 ===", label); + } + else { + log_info("=== 性能统计 ==="); + } - /** - * @brief 打印统计报告 - * @param unit 时间单位 - * @param label 标签(可选) - */ - void print_report(TimeUnit unit = TimeUnit::MICROSECONDS, - const std::string& label = "") const { - if (!label.empty()) { - log_info("=== {} 性能统计 ===", label); - } else { - log_info("=== 性能统计 ==="); - } + const char* unit_str = time_unit_string(unit); - const char* unit_str = time_unit_string(unit); - - log_info("样本数量: {}", latencies_.size()); - log_info("最小值: {:.2f} {}", convert_time(min_, unit), unit_str); - log_info("最大值: {:.2f} {}", convert_time(max_, unit), unit_str); - log_info("平均值: {:.2f} {}", convert_time(avg_, unit), unit_str); - log_info("标准差: {:.2f} {}", convert_time(stddev_, unit), unit_str); - log_info("P50: {:.2f} {}", convert_time(p50_, unit), unit_str); - log_info("P95: {:.2f} {}", convert_time(p95_, unit), unit_str); - log_info("P99: {:.2f} {}", convert_time(p99_, unit), unit_str); - } + log_info("样本数量: {}", latencies_.size()); + log_info("最小值: {:.2f} {}", convert_time(min_, unit), unit_str); + log_info("最大值: {:.2f} {}", convert_time(max_, unit), unit_str); + log_info("平均值: {:.2f} {}", convert_time(avg_, unit), unit_str); + log_info("标准差: {:.2f} {}", convert_time(stddev_, unit), unit_str); + log_info("P50: {:.2f} {}", convert_time(p50_, unit), unit_str); + log_info("P95: {:.2f} {}", convert_time(p95_, unit), unit_str); + log_info("P99: {:.2f} {}", convert_time(p99_, unit), unit_str); + } - // Getters(纳秒) - double min() const { return min_; } - double max() const { return max_; } - double avg() const { return avg_; } - double stddev() const { return stddev_; } - double p50() const { return p50_; } - double p95() const { return p95_; } - double p99() const { return p99_; } - size_t sample_count() const { return latencies_.size(); } + // Getters(纳秒) + double min() const { return min_; } + double max() const { return max_; } + double avg() const { return avg_; } + double stddev() const { return stddev_; } + double p50() const { return p50_; } + double p95() const { return p95_; } + double p99() const { return p99_; } + size_t sample_count() const { return latencies_.size(); } - /** - * @brief 重置统计数据 - */ - void reset() { - latencies_.clear(); - min_ = max_ = avg_ = p50_ = p95_ = p99_ = 0.0; - stddev_ = 0.0; - } + /** + * @brief 重置统计数据 + */ + void reset() { + latencies_.clear(); + min_ = max_ = avg_ = p50_ = p95_ = p99_ = 0.0; + stddev_ = 0.0; + } -private: - std::vector latencies_; ///< 延迟本(纳秒) - double min_{0.0}; ///< 最小延迟 - double max_{0.0}; ///< 最大延迟 - double avg_{0.0}; ///< 平均延迟 - double stddev_{0.0}; ///< 标准差 - double p50_{0.0}; ///< 50% 百分位数(中位数) - double p95_{0.0}; ///< 95% 百分位数 - double p99_{0.0}; ///< 99% 百分位数 -}; + private: + std::vector latencies_; ///< 延迟本(纳秒) + double min_{0.0}; ///< 最小延迟 + double max_{0.0}; ///< 最大延迟 + double avg_{0.0}; ///< 平均延迟 + double stddev_{0.0}; ///< 标准差 + double p50_{0.0}; ///< 50% 百分位数(中位数) + double p95_{0.0}; ///< 95% 百分位数 + double p99_{0.0}; ///< 99% 百分位数 + }; -/** - * @brief RAII 作用域计时器 - * - * 在构造时开始计时,析构时自动记录耗时 - * - * @example - * { - * ScopedTimer timer("操作名称"); - * // 执行需要计时的操作... - * } // 析构时自动输出耗时 - */ -class ScopedTimer { -public: - using Clock = std::chrono::high_resolution_clock; - using TimePoint = Clock::time_point; - using Duration = std::chrono::nanoseconds; + /** + * @brief RAII 作用域计时器 + * + * 在构造时开始计时,析构时自动记录耗时 + * + * @example + * { + * ScopedTimer timer("操作名称"); + * // 执行需要计时的操作... + * } // 析构时自动输出耗时 + */ + class ScopedTimer { + public: + using Clock = std::chrono::high_resolution_clock; + using TimePoint = Clock::time_point; + using Duration = std::chrono::nanoseconds; - /** - * @brief 构造函数 - 开始计时 - * @param name 计时器名称 - * @param auto_print 析构时是否自动打印结果 - * @param unit 时间单位 - */ - explicit ScopedTimer( - std::string name = "ScopedTimer", - bool auto_print = true, - TimeUnit unit = TimeUnit::MILLISECONDS - ) : name_(std::move(name)), - auto_print_(auto_print), - unit_(unit), - start_(Clock::now()) { - } + /** + * @brief 构造函数 - 开始计时 + * @param name 计时器名称 + * @param auto_print 析构时是否自动打印结果 + * @param unit 时间单位 + */ + explicit ScopedTimer( + std::string name = "ScopedTimer", + bool auto_print = true, + TimeUnit unit = TimeUnit::MILLISECONDS + ) : name_(std::move(name)), + auto_print_(auto_print), + unit_(unit), + start_(Clock::now()) { + } - /** - * @brief 析构函数 - 停止计时并可选打印结果 - */ - ~ScopedTimer() { - if (!stopped_) { - stop(); - if (auto_print_) { - print(); - } - } - } + /** + * @brief 析构函数 - 停止计时并可选打印结果 + */ + ~ScopedTimer() { + if (!stopped_) { + stop(); + if (auto_print_) { + print(); + } + } + } - // 禁止拷贝和移动 - ScopedTimer(const ScopedTimer&) = delete; - ScopedTimer& operator=(const ScopedTimer&) = delete; - ScopedTimer(ScopedTimer&&) = delete; - ScopedTimer& operator=(ScopedTimer&&) = delete; + // 禁止拷贝和移动 + ScopedTimer(const ScopedTimer&) = delete; + ScopedTimer& operator=(const ScopedTimer&) = delete; + ScopedTimer(ScopedTimer&&) = delete; + ScopedTimer& operator=(ScopedTimer&&) = delete; - /** - * @brief 手动停止计时 - */ - void stop() { - if (!stopped_) { - end_ = Clock::now(); - stopped_ = true; - elapsed_ns_ = std::chrono::duration_cast(end_ - start_).count(); - } - } + /** + * @brief 手动停止计时 + */ + void stop() { + if (!stopped_) { + end_ = Clock::now(); + stopped_ = true; + elapsed_ns_ = std::chrono::duration_cast(end_ - start_).count(); + } + } - /** - * @brief 获取已经过的时间(纳秒) - */ - double elapsed_ns() const { - if (stopped_) { - return static_cast(elapsed_ns_); - } else { - auto now = Clock::now(); - return static_cast( - std::chrono::duration_cast(now - start_).count() - ); - } - } + /** + * @brief 获取已经过的时间(纳秒) + */ + double elapsed_ns() const { + if (stopped_) { + return static_cast(elapsed_ns_); + } + else { + auto now = Clock::now(); + return static_cast( + std::chrono::duration_cast(now - start_).count() + ); + } + } - /** - * @brief 获取已经过的时间(指定单位) - */ - double elapsed(TimeUnit unit = TimeUnit::MILLISECONDS) const { - return Statistics::convert_time(elapsed_ns(), unit); - } + /** + * @brief 获取已经过的时间(指定单位) + */ + double elapsed(TimeUnit unit = TimeUnit::MILLISECONDS) const { + return Statistics::convert_time(elapsed_ns(), unit); + } - /** - * @brief 打印耗时 - */ - void print() const { - const char* unit_str = Statistics::time_unit_string(unit_); - log_info("{}: {:.3f} {}", name_, elapsed(unit_), unit_str); - } + /** + * @brief 打印耗时 + */ + void print() const { + const char* unit_str = Statistics::time_unit_string(unit_); + log_info("{}: {:.3f} {}", name_, elapsed(unit_), unit_str); + } - /** - * @brief 重置计时器 - */ - void reset() { - start_ = Clock::now(); - stopped_ = false; - elapsed_ns_ = 0; - } + /** + * @brief 重置计时器 + */ + void reset() { + start_ = Clock::now(); + stopped_ = false; + elapsed_ns_ = 0; + } -private: - std::string name_; ///< 计时器名称 - bool auto_print_; ///< 是否自动打印 - TimeUnit unit_; ///< 时间单位 - TimePoint start_; ///< 开始时间 - TimePoint end_; ///< 结束时间 - bool stopped_{false}; ///< 是否已停止 - int64_t elapsed_ns_{0}; ///< 已过时间(纳秒) -}; + private: + std::string name_; ///< 计时器名称 + bool auto_print_; ///< 是否自动打印 + TimeUnit unit_; ///< 时间单位 + TimePoint start_; ///< 开始时间 + TimePoint end_; ///< 结束时间 + bool stopped_{false}; ///< 是否已停止 + int64_t elapsed_ns_{0}; ///< 已过时间(纳秒) + }; -/** - * @brief 吞吐量测量器 - * - * 用于测量操作的吞吐量(ops/sec) - * - * @example - * ThroughputMeter meter("写入操作"); - * meter.start(); - * for (int i = 0; i < 1000; ++i) { - * // 执行操作... - * meter.record_operation(); - * } - * meter.stop(); - * meter.print_report(); - */ -class ThroughputMeter { -public: - using Clock = std::chrono::high_resolution_clock; - using TimePoint = Clock::time_point; + /** + * @brief 吞吐量测量器 + * + * 用于测量操作的吞吐量(ops/sec) + * + * @example + * ThroughputMeter meter("写入操作"); + * meter.start(); + * for (int i = 0; i < 1000; ++i) { + * // 执行操作... + * meter.record_operation(); + * } + * meter.stop(); + * meter.print_report(); + */ + class ThroughputMeter { + public: + using Clock = std::chrono::high_resolution_clock; + using TimePoint = Clock::time_point; - /** - * @brief 构造函 - * @param name 测量器名 - */ - explicit ThroughputMeter(std::string name = "ThroughputMeter") - : name_(std::move(name)) { - } + /** + * @brief 构造函 + * @param name 测量器名 + */ + explicit ThroughputMeter(std::string name = "ThroughputMeter") : name_(std::move(name)) { + } - /** - * @brief 开始测量 - */ - void start() { - start_time_ = Clock::now(); - operation_count_ = 0; - running_ = true; - log_module_debug("ThroughputMeter", "开始测量: {}", name_); - } + /** + * @brief 开始测量 + */ + void start() { + start_time_ = Clock::now(); + operation_count_ = 0; + running_ = true; + log_module_debug("ThroughputMeter", "开始测量: {}", name_); + } - /** - * @brief 停止测量 - */ - void stop() { - if (!running_) { - return; - } + /** + * @brief 停止测量 + */ + void stop() { + if (!running_) { + return; + } - end_time_ = Clock::now(); - running_ = false; - - auto duration = std::chrono::duration_cast( - end_time_ - start_time_); - elapsed_seconds_ = duration.count() / 1000000000.0; - - if (elapsed_seconds_ > 0.0) { - throughput_ = operation_count_ / elapsed_seconds_; - } + end_time_ = Clock::now(); + running_ = false; - log_module_debug("ThroughputMeter", "停止测量: {}", name_); - } + auto duration = std::chrono::duration_cast( + end_time_ - start_time_); + elapsed_seconds_ = duration.count() / 1000000000.0; - /** - * @brief 记录一次操作 - * @param count 操作数量(默认1) - */ - void record_operation(size_t count = 1) { - operation_count_ += count; - } + if (elapsed_seconds_ > 0.0) { + throughput_ = operation_count_ / elapsed_seconds_; + } - /** - * @brief 获取当前吞吐量(ops/sec) - * @note 如果仍在运行,返回当前时刻的吞吐量 - */ - double get_throughput() const { - if (running_) { - auto now = Clock::now(); - auto duration = std::chrono::duration_cast( - now - start_time_); - double elapsed = duration.count() / 1000000000.0; - return (elapsed > 0.0) ? (operation_count_ / elapsed) : 0.0; - } - return throughput_; - } + log_module_debug("ThroughputMeter", "停止测量: {}", name_); + } - /** - * @brief 获取操作总数 - */ - size_t get_operation_count() const { - return operation_count_; - } + /** + * @brief 记录一次操作 + * @param count 操作数量(默认1) + */ + void record_operation(size_t count = 1) { + operation_count_ += count; + } - /** - * @brief 获取耗时(秒) - */ - double get_elapsed_seconds() const { - if (running_) { - auto now = Clock::now(); - auto duration = std::chrono::duration_cast( - now - start_time_); - return duration.count() / 1000000000.0; - } - return elapsed_seconds_; - } + /** + * @brief 获取当前吞吐量(ops/sec) + * @note 如果仍在运行,返回当前时刻的吞吐量 + */ + double get_throughput() const { + if (running_) { + auto now = Clock::now(); + auto duration = std::chrono::duration_cast( + now - start_time_); + double elapsed = duration.count() / 1000000000.0; + return (elapsed > 0.0) ? (operation_count_ / elapsed) : 0.0; + } + return throughput_; + } - /** - * @brief 打印吞吐量报告 - */ - void print_report() const { - log_info("=== {} 吞吐量报告 ===", name_); - log_info("总操作数: {}", operation_count_); - log_info("耗时: {:.3f} 秒", get_elapsed_seconds()); - log_info("吞吐量: {:.2f} ops/sec", get_throughput()); - - // 计算每个操作的平均延迟 - if (operation_count_ > 0) { - double avg_latency_ms = (get_elapsed_seconds() * 1000.0) / operation_count_; - log_info("平均延迟: {:.3f} ms/op", avg_latency_ms); - } - } + /** + * @brief 获取操作总数 + */ + size_t get_operation_count() const { + return operation_count_; + } - /** - * @brief 重置测量器 - */ - void reset() { - operation_count_ = 0; - throughput_ = 0.0; - elapsed_seconds_ = 0.0; - running_ = false; - } + /** + * @brief 获取耗时(秒) + */ + double get_elapsed_seconds() const { + if (running_) { + auto now = Clock::now(); + auto duration = std::chrono::duration_cast( + now - start_time_); + return duration.count() / 1000000000.0; + } + return elapsed_seconds_; + } -private: - std::string name_; ///< 测量器名称 - TimePoint start_time_; ///< 开始时间 - TimePoint end_time_; ///< 结束时间 - size_t operation_count_{0}; ///< 操作计数 - double throughput_{0.0}; ///< 吞吐量(ops/sec) - double elapsed_seconds_{0.0}; ///< 耗时(秒) - bool running_{false}; ///< 是否正在运行 -}; + /** + * @brief 打印吞吐量报告 + */ + void print_report() const { + log_info("=== {} 吞吐量报告 ===", name_); + log_info("总操作数: {}", operation_count_); + log_info("耗时: {:.3f} 秒", get_elapsed_seconds()); + log_info("吞吐量: {:.2f} ops/sec", get_throughput()); -/** - * @brief 延迟记录器 - * - * 用于批量记录操作延迟并生成统计报告 - * - * @example - * LatencyRecorder recorder("读取操作"); - * for (int i = 0; i < 1000; ++i) { - * ScopedTimer timer("", false); - * // 执行操作... - * timer.stop(); - * recorder.record(timer.elapsed_ns()); - * } - * auto stats = recorder.get_statistics(); - * stats.print_report(); - */ -class LatencyRecorder { -public: - /** - * @brief 构造函数 - * @param name 记录器名称 - */ - explicit LatencyRecorder(std::string name = "LatencyRecorder") - : name_(std::move(name)) { - } + // 计算每个操作的平均延迟 + if (operation_count_ > 0) { + double avg_latency_ms = (get_elapsed_seconds() * 1000.0) / operation_count_; + log_info("平均延迟: {:.3f} ms/op", avg_latency_ms); + } + } - /** - * @brief 记录一次延迟 - * @param latency_ns 延迟(纳秒) - */ - void record(double latency_ns) { - latencies_.push_back(latency_ns); - } + /** + * @brief 重置测量器 + */ + void reset() { + operation_count_ = 0; + throughput_ = 0.0; + elapsed_seconds_ = 0.0; + running_ = false; + } - /** - * @brief 获取统计数据 - */ - Statistics get_statistics() { - Statistics stats(latencies_); - return stats; - } + private: + std::string name_; ///< 测量器名称 + TimePoint start_time_; ///< 开始时间 + TimePoint end_time_; ///< 结束时间 + size_t operation_count_{0}; ///< 操作计数 + double throughput_{0.0}; ///< 吞吐量(ops/sec) + double elapsed_seconds_{0.0}; ///< 耗时(秒) + bool running_{false}; ///< 是否正在运行 + }; - /** - * @brief 打印统计报告 - * @param unit 时间单位 - */ - void print_report(TimeUnit unit = TimeUnit::MICROSECONDS) { - auto stats = get_statistics(); - stats.print_report(unit, name_); - } + /** + * @brief 延迟记录器 + * + * 用于批量记录操作延迟并生成统计报告 + * + * @example + * LatencyRecorder recorder("读取操作"); + * for (int i = 0; i < 1000; ++i) { + * ScopedTimer timer("", false); + * // 执行操作... + * timer.stop(); + * recorder.record(timer.elapsed_ns()); + * } + * auto stats = recorder.get_statistics(); + * stats.print_report(); + */ + class LatencyRecorder { + public: + /** + * @brief 构造函数 + * @param name 记录器名称 + */ + explicit LatencyRecorder(std::string name = "LatencyRecorder") : name_(std::move(name)) { + } - /** - * @brief 重置记录器 - */ - void reset() { - latencies_.clear(); - } + /** + * @brief 记录一次延迟 + * @param latency_ns 延迟(纳秒) + */ + void record(double latency_ns) { + latencies_.push_back(latency_ns); + } - /** - * @brief 获取记录的延迟样本数量 - */ - size_t sample_count() const { - return latencies_.size(); - } + /** + * @brief 获取统计数据 + */ + Statistics get_statistics() { + Statistics stats(latencies_); + return stats; + } -private: - std::string name_; ///< 记录器名称 - std::vector latencies_; ///< 延迟样本(纳秒) -}; + /** + * @brief 打印统计报告 + * @param unit 时间单位 + */ + void print_report(TimeUnit unit = TimeUnit::MICROSECONDS) { + auto stats = get_statistics(); + stats.print_report(unit, name_); + } -} // namespace shm_test \ No newline at end of file + /** + * @brief 重置记录器 + */ + void reset() { + latencies_.clear(); + } + + /** + * @brief 获取记录的延迟样本数量 + */ + size_t sample_count() const { + return latencies_.size(); + } + + private: + std::string name_; ///< 记录器名称 + std::vector latencies_; ///< 延迟样本(纳秒) + }; +} // namespace shm_test diff --git a/tests/shm/helpers/test_environment.h b/tests/shm/helpers/test_environment.h index 9c9fb7c..a24f3e7 100644 --- a/tests/shm/helpers/test_environment.h +++ b/tests/shm/helpers/test_environment.h @@ -13,250 +13,250 @@ #include "logger.h" namespace shm_test { + /** + * @brief 共享内存测试环境管理类 + * + * 提供以下功能: + * - 自动生成唯一的共享内存段名称 + * - 自动初始化和清理 shared_memory_manager + * - 跟踪和清理所有创建的共享内存段 + * - 提供临时目录路径管理 + * - RAII 模式确保异常安全 + * + * @example + * TEST(ShmTest, BasicUsage) { + * ShmTestEnvironment env("BasicUsage"); + * auto segment_name = env.generate_segment_name("buffer"); + * // 使用 segment_name 创建共享内存... + * // env 析构时自动清理所有资源 + * } + */ + class ShmTestEnvironment { + public: + /** + * @brief 构造函数 + * @param test_name 测试名称,用于生成唯一的共享内存段名称 + * @param segment_size 默认共享内存段大小(字节) + * @param auto_init 是否自动初始化 shared_memory_manager + */ + explicit ShmTestEnvironment( + const std::string& test_name, + size_t segment_size = 1024 * 1024, // 默认 1MB + bool auto_init = true + ) : test_name_(test_name), + default_segment_size_(segment_size), + initialized_(false) { + // 生成测试会话ID(时间戳 + 随机数) + auto now = std::chrono::system_clock::now(); + auto timestamp = std::chrono::duration_cast( + now.time_since_epoch()).count(); -/** - * @brief 共享内存测试环境管理类 - * - * 提供以下功能: - * - 自动生成唯一的共享内存段名称 - * - 自动初始化和清理 shared_memory_manager - * - 跟踪和清理所有创建的共享内存段 - * - 提供临时目录路径管理 - * - RAII 模式确保异常安全 - * - * @example - * TEST(ShmTest, BasicUsage) { - * ShmTestEnvironment env("BasicUsage"); - * auto segment_name = env.generate_segment_name("buffer"); - * // 使用 segment_name 创建共享内存... - * // env 析构时自动清理所有资源 - * } - */ -class ShmTestEnvironment { -public: - /** - * @brief 构造函数 - * @param test_name 测试名称,用于生成唯一的共享内存段名称 - * @param segment_size 默认共享内存段大小(字节) - * @param auto_init 是否自动初始化 shared_memory_manager - */ - explicit ShmTestEnvironment( - const std::string& test_name, - size_t segment_size = 1024 * 1024, // 默认 1MB - bool auto_init = true - ) : test_name_(test_name), - default_segment_size_(segment_size), - initialized_(false) { - - // 生成测试会话ID(时间戳 + 随机数) - auto now = std::chrono::system_clock::now(); - auto timestamp = std::chrono::duration_cast( - now.time_since_epoch()).count(); - - std::random_device rd; - std::mt19937 gen(rd()); - std::uniform_int_distribution<> dis(1000, 9999); - - session_id_ = std::to_string(timestamp) + "_" + std::to_string(dis(gen)); - - log_module_info("ShmTestEnv", "创建测试环境: {} (会话ID: {})", - test_name_, session_id_); - - if (auto_init) { - init(); - } - } + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis(1000, 9999); - /** - * @brief 析构函数 - 清理所有资源 - */ - ~ShmTestEnvironment() { - cleanup(); - } + session_id_ = std::to_string(timestamp) + "_" + std::to_string(dis(gen)); - // 禁止拷贝和移动 - ShmTestEnvironment(const ShmTestEnvironment&) = delete; - ShmTestEnvironment& operator=(const ShmTestEnvironment&) = delete; - ShmTestEnvironment(ShmTestEnvironment&&) = delete; - ShmTestEnvironment& operator=(ShmTestEnvironment&&) = delete; + log_module_info("ShmTestEnv", "创建测试环境: {} (会话ID: {})", + test_name_, session_id_); - /** - * @brief 初始化 shared_memory_manager - * @param segment_size 共享内存段大小,默认使用构造函数中的值 - * @return true 如果初始化成功 - */ - bool init(size_t segment_size = 0) { - if (initialized_) { - log_module_warn("ShmTestEnv", "测试环境已经初始化"); - return true; - } + if (auto_init) { + init(); + } + } - if (segment_size == 0) { - segment_size = default_segment_size_; - } + /** + * @brief 析构函数 - 清理所有资源 + */ + ~ShmTestEnvironment() { + cleanup(); + } - // 生成主共享内存段名称 - main_segment_name_ = generate_segment_name("main"); + // 禁止拷贝和移动 + ShmTestEnvironment(const ShmTestEnvironment&) = delete; + ShmTestEnvironment& operator=(const ShmTestEnvironment&) = delete; + ShmTestEnvironment(ShmTestEnvironment&&) = delete; + ShmTestEnvironment& operator=(ShmTestEnvironment&&) = delete; - shared_memory_config config{}; - config.segment_name = main_segment_name_; - config.segment_size = segment_size; - config.create_if_not_exists = true; - config.remove_on_destroy = true; - config.mutex_name = generate_segment_name("mutex"); - config.condition_name = generate_segment_name("cond"); - config.semaphore_name = generate_segment_name("sem"); + /** + * @brief 初始化 shared_memory_manager + * @param segment_size 共享内存段大小,默认使用构造函数中的值 + * @return true 如果初始化成功 + */ + bool init(size_t segment_size = 0) { + if (initialized_) { + log_module_warn("ShmTestEnv", "测试环境已经初始化"); + return true; + } - auto& manager = shared_memory_manager::instance(); - auto result = manager.init(config); + if (segment_size == 0) { + segment_size = default_segment_size_; + } - if (result == shared_memory_error::SUCCESS) { - initialized_ = true; - log_module_info("ShmTestEnv", "成功初始化共享内存管理器: {}", - main_segment_name_); - return true; - } else { - log_module_error("ShmTestEnv", "初始化共享内存管理器失败"); - return false; - } - } + // 生成主共享内存段名称 + main_segment_name_ = generate_segment_name("main"); - /** - * @brief 生成唯一的共享内存段名称 - * @param prefix 名称前缀 - * @return 格式为 "test____" 的唯一名称 - */ - std::string generate_segment_name(const std::string& prefix = "") { - std::string name = "test_" + test_name_; - if (!prefix.empty()) { - name += "_" + prefix; - } - name += "_" + session_id_ + "_" + std::to_string(segment_counter_++); - - // 清理Windows不支持的字符(路径分隔符等) - // Google Test 参数化测试会在测试名称中使用斜杠,如 "TestName/0" - // 这在 Windows 共享内存对象名称中是非法的 - std::replace(name.begin(), name.end(), '/', '_'); - std::replace(name.begin(), name.end(), '\\', '_'); - - // 记录生成的段名称以便清理 - tracked_segments_.push_back(name); - - log_module_debug("ShmTestEnv", "生成共享内存段名称: {}", name); - return name; - } + shared_memory_config config{}; + config.segment_name = main_segment_name_; + config.segment_size = segment_size; + config.create_if_not_exists = true; + config.remove_on_destroy = true; + config.mutex_name = generate_segment_name("mutex"); + config.condition_name = generate_segment_name("cond"); + config.semaphore_name = generate_segment_name("sem"); - /** - * @brief 获取临时目录路径 - * @return 临时目录的文件系统路径 - */ - std::filesystem::path get_temp_dir() const { - auto temp_dir = std::filesystem::temp_directory_path() / - ("shm_test_" + test_name_ + "_" + session_id_); - - // 确保目录存在 - std::filesystem::create_directories(temp_dir); - - return temp_dir; - } + auto& manager = shared_memory_manager::instance(); + auto result = manager.init(config); - /** - * @brief 检查 shared_memory_manager 是否已初始化 - */ - bool is_initialized() const { - return initialized_; - } + if (result == shared_memory_error::SUCCESS) { + initialized_ = true; + log_module_info("ShmTestEnv", "成功初始化共享内存管理器: {}", + main_segment_name_); + return true; + } + else { + log_module_error("ShmTestEnv", "初始化共享内存管理器失败"); + return false; + } + } - /** - * @brief 获取主共享内存段名称 - */ - const std::string& get_main_segment_name() const { - return main_segment_name_; - } + /** + * @brief 生成唯一的共享内存段名称 + * @param prefix 名称前缀 + * @return 格式为 "test____" 的唯一名称 + */ + std::string generate_segment_name(const std::string& prefix = "") { + std::string name = "test_" + test_name_; + if (!prefix.empty()) { + name += "_" + prefix; + } + name += "_" + session_id_ + "_" + std::to_string(segment_counter_++); - /** - * @brief 获取测试名称 - */ - const std::string& get_test_name() const { - return test_name_; - } + // 清理Windows不支持的字符(路径分隔符等) + // Google Test 参数化测试会在测试名称中使用斜杠,如 "TestName/0" + // 这在 Windows 共享内存对象名称中是非法的 + std::replace(name.begin(), name.end(), '/', '_'); + std::replace(name.begin(), name.end(), '\\', '_'); - /** - * @brief 获取会话ID - */ - const std::string& get_session_id() const { - return session_id_; - } + // 记录生成的段名称以便清理 + tracked_segments_.push_back(name); - /** - * @brief 获取所有跟踪的共享内存段名称 - */ - const std::vector& get_tracked_segments() const { - return tracked_segments_; - } + log_module_debug("ShmTestEnv", "生成共享内存段名称: {}", name); + return name; + } - /** - * @brief 手动添加需要跟踪的共享内存段 - * @param segment_name 共享内存段名称 - */ - void track_segment(const std::string& segment_name) { - tracked_segments_.push_back(segment_name); - log_module_debug("ShmTestEnv", "添加跟踪段: {}", segment_name); - } + /** + * @brief 获取临时目录路径 + * @return 临时目录的文件系统路径 + */ + std::filesystem::path get_temp_dir() const { + auto temp_dir = std::filesystem::temp_directory_path() / + ("shm_test_" + test_name_ + "_" + session_id_); - /** - * @brief 清理所有资源 - */ - void cleanup() { - if (!initialized_) { - return; - } + // 确保目录存在 + std::filesystem::create_directories(temp_dir); - log_module_info("ShmTestEnv", "开始清理测试环境: {}", test_name_); + return temp_dir; + } - // 关闭 shared_memory_manager - auto& manager = shared_memory_manager::instance(); - manager.shutdown(); + /** + * @brief 检查 shared_memory_manager 是否已初始化 + */ + bool is_initialized() const { + return initialized_; + } - // 清理所有跟踪的共享内存段 - for (const auto& segment_name : tracked_segments_) { - try { - using namespace boost::interprocess; - shared_memory_object::remove(segment_name.c_str()); - log_module_debug("ShmTestEnv", "已删除共享内存段: {}", segment_name); - } catch (const std::exception& e) { - log_module_warn("ShmTestEnv", "删除共享内存段失败 {}: {}", - segment_name, e.what()); - } - } + /** + * @brief 获取主共享内存段名称 + */ + const std::string& get_main_segment_name() const { + return main_segment_name_; + } - // 清理临时目录 - try { - auto temp_dir = get_temp_dir(); - if (std::filesystem::exists(temp_dir)) { - std::filesystem::remove_all(temp_dir); - log_module_debug("ShmTestEnv", "已删除临时目录: {}", - temp_dir.string()); - } - } catch (const std::exception& e) { - log_module_warn("ShmTestEnv", "删除临时目录失败: {}", e.what()); - } + /** + * @brief 获取测试名称 + */ + const std::string& get_test_name() const { + return test_name_; + } - initialized_ = false; - tracked_segments_.clear(); - - log_module_info("ShmTestEnv", "测试环境清理完成: {}", test_name_); - } + /** + * @brief 获取会话ID + */ + const std::string& get_session_id() const { + return session_id_; + } -private: - std::string test_name_; ///< 测试名称 - std::string session_id_; ///< 测试会话ID(时间戳+随机数) - std::string main_segment_name_; ///< 主共享内存段名称 - size_t default_segment_size_; ///< 默认共享内存段大小 - size_t segment_counter_{0}; ///< 共享内存段计数器 - bool initialized_; ///< 是否已初始化 - - std::vector tracked_segments_; ///< 跟踪的所有共享内存段 -}; + /** + * @brief 获取所有跟踪的共享内存段名称 + */ + const std::vector& get_tracked_segments() const { + return tracked_segments_; + } -} // namespace shm_test \ No newline at end of file + /** + * @brief 手动添加需要跟踪的共享内存段 + * @param segment_name 共享内存段名称 + */ + void track_segment(const std::string& segment_name) { + tracked_segments_.push_back(segment_name); + log_module_debug("ShmTestEnv", "添加跟踪段: {}", segment_name); + } + + /** + * @brief 清理所有资源 + */ + void cleanup() { + if (!initialized_) { + return; + } + + log_module_info("ShmTestEnv", "开始清理测试环境: {}", test_name_); + + // 关闭 shared_memory_manager + auto& manager = shared_memory_manager::instance(); + manager.shutdown(); + + // 清理所有跟踪的共享内存段 + for (const auto& segment_name : tracked_segments_) { + try { + using namespace boost::interprocess; + shared_memory_object::remove(segment_name.c_str()); + log_module_debug("ShmTestEnv", "已删除共享内存段: {}", segment_name); + } + catch (const std::exception& e) { + log_module_warn("ShmTestEnv", "删除共享内存段失败 {}: {}", + segment_name, e.what()); + } + } + + // 清理临时目录 + try { + auto temp_dir = get_temp_dir(); + if (std::filesystem::exists(temp_dir)) { + std::filesystem::remove_all(temp_dir); + log_module_debug("ShmTestEnv", "已删除临时目录: {}", + temp_dir.string()); + } + } + catch (const std::exception& e) { + log_module_warn("ShmTestEnv", "删除临时目录失败: {}", e.what()); + } + + initialized_ = false; + tracked_segments_.clear(); + + log_module_info("ShmTestEnv", "测试环境清理完成: {}", test_name_); + } + + private: + std::string test_name_; ///< 测试名称 + std::string session_id_; ///< 测试会话ID(时间戳+随机数) + std::string main_segment_name_; ///< 主共享内存段名称 + size_t default_segment_size_; ///< 默认共享内存段大小 + size_t segment_counter_{0}; ///< 共享内存段计数器 + bool initialized_; ///< 是否已初始化 + + std::vector tracked_segments_; ///< 跟踪的所有共享内存段 + }; +} // namespace shm_test diff --git a/tests/shm/test_interprocess_sync.cpp b/tests/shm/test_interprocess_sync.cpp index 49af068..a59195d 100644 --- a/tests/shm/test_interprocess_sync.cpp +++ b/tests/shm/test_interprocess_sync.cpp @@ -16,138 +16,138 @@ using namespace shm_test; */ class InterprocessMutexTest : public ::testing::Test { protected: - void SetUp() override { - // 创建测试环境 - env_ = std::make_unique( - ::testing::UnitTest::GetInstance()->current_test_info()->name() - ); - - // 创建配置 - shared_memory_config config{}; - config.segment_name = env_->generate_segment_name("main"); - config.segment_size = 1024 * 1024; - config.create_if_not_exists = true; - config.remove_on_destroy = true; - config.mutex_name = env_->generate_segment_name("mutex"); - config.condition_name = env_->generate_segment_name("cond"); - config.semaphore_name = env_->generate_segment_name("sem"); - - // 创建 interprocess_synchronization 实例 - sync_ = std::make_unique(config); - - // 生成唯一的互斥量名称 - mutex_name_ = env_->generate_segment_name("test_mutex"); - } - - void TearDown() override { - sync_.reset(); - env_.reset(); - } - - std::unique_ptr env_; - std::unique_ptr sync_; - std::string mutex_name_; + void SetUp() override { + // 创建测试环境 + env_ = std::make_unique( + ::testing::UnitTest::GetInstance()->current_test_info()->name() + ); + + // 创建配置 + shared_memory_config config{}; + config.segment_name = env_->generate_segment_name("main"); + config.segment_size = 1024 * 1024; + config.create_if_not_exists = true; + config.remove_on_destroy = true; + config.mutex_name = env_->generate_segment_name("mutex"); + config.condition_name = env_->generate_segment_name("cond"); + config.semaphore_name = env_->generate_segment_name("sem"); + + // 创建 interprocess_synchronization 实例 + sync_ = std::make_unique(config); + + // 生成唯一的互斥量名称 + mutex_name_ = env_->generate_segment_name("test_mutex"); + } + + void TearDown() override { + sync_.reset(); + env_.reset(); + } + + std::unique_ptr env_; + std::unique_ptr sync_; + std::string mutex_name_; }; /** * @brief 测试创建互斥量成功 */ TEST_F(InterprocessMutexTest, CreateMutex) { - auto result = sync_->create_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); + auto result = sync_->create_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试打开已存在的互斥量 */ TEST_F(InterprocessMutexTest, OpenMutex) { - // 先创建互斥量 - auto result = sync_->create_mutex(mutex_name_); - ASSERT_SHM_SUCCESS(result); - - // 打开已存在的互斥量 - result = sync_->open_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); + // 先创建互斥量 + auto result = sync_->create_mutex(mutex_name_); + ASSERT_SHM_SUCCESS(result); + + // 打开已存在的互斥量 + result = sync_->open_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试锁定和解锁互斥量 */ TEST_F(InterprocessMutexTest, LockUnlock) { - // 创建互斥量 - auto result = sync_->create_mutex(mutex_name_); - ASSERT_SHM_SUCCESS(result); - - // 锁定互斥量 - result = sync_->lock_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); - - // 解锁互斥量 - result = sync_->unlock_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); - - // 再次锁定和解锁 - result = sync_->lock_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); - - result = sync_->unlock_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); + // 创建互斥量 + auto result = sync_->create_mutex(mutex_name_); + ASSERT_SHM_SUCCESS(result); + + // 锁定互斥量 + result = sync_->lock_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); + + // 解锁互斥量 + result = sync_->unlock_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); + + // 再次锁定和解锁 + result = sync_->lock_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); + + result = sync_->unlock_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试带超时的锁定 */ TEST_F(InterprocessMutexTest, LockWithTimeout) { - // 创建互斥量 - auto result = sync_->create_mutex(mutex_name_); - ASSERT_SHM_SUCCESS(result); - - // 带超时的锁定(应该立即成功,因为互斥量未被占用) - result = sync_->lock_mutex(mutex_name_, std::chrono::milliseconds(100)); - EXPECT_SHM_SUCCESS(result); - - // 解锁 - result = sync_->unlock_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); + // 创建互斥量 + auto result = sync_->create_mutex(mutex_name_); + ASSERT_SHM_SUCCESS(result); + + // 带超时的锁定(应该立即成功,因为互斥量未被占用) + result = sync_->lock_mutex(mutex_name_, std::chrono::milliseconds(100)); + EXPECT_SHM_SUCCESS(result); + + // 解锁 + result = sync_->unlock_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试移除互斥量 */ TEST_F(InterprocessMutexTest, RemoveMutex) { - // 创建互斥量 - auto result = sync_->create_mutex(mutex_name_); - ASSERT_SHM_SUCCESS(result); - - // 移除互斥量 - result = sync_->remove_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); + // 创建互斥量 + auto result = sync_->create_mutex(mutex_name_); + ASSERT_SHM_SUCCESS(result); + + // 移除互斥量 + result = sync_->remove_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试 scoped_mutex_lock RAII 行为 */ TEST_F(InterprocessMutexTest, ScopedLock) { - // 创建互斥量 - auto result = sync_->create_mutex(mutex_name_); - ASSERT_SHM_SUCCESS(result); - - // 使用 scoped_mutex_lock - { - interprocess_synchronization::scoped_mutex_lock lock(*sync_, mutex_name_); - EXPECT_TRUE(lock.is_locked()); - - // 在这个作用域内,互斥量被锁定 - // lock 析构时会自动解锁 - } - - // 离开作用域后,互斥量应该已解锁 - // 我们应该能再次锁定 - result = sync_->lock_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); - - result = sync_->unlock_mutex(mutex_name_); - EXPECT_SHM_SUCCESS(result); + // 创建互斥量 + auto result = sync_->create_mutex(mutex_name_); + ASSERT_SHM_SUCCESS(result); + + // 使用 scoped_mutex_lock + { + interprocess_synchronization::scoped_mutex_lock lock(*sync_, mutex_name_); + EXPECT_TRUE(lock.is_locked()); + + // 在这个作用域内,互斥量被锁定 + // lock 析构时会自动解锁 + } + + // 离开作用域后,互斥量应该已解锁 + // 我们应该能再次锁定 + result = sync_->lock_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); + + result = sync_->unlock_mutex(mutex_name_); + EXPECT_SHM_SUCCESS(result); } // ============================================================================ @@ -159,118 +159,118 @@ TEST_F(InterprocessMutexTest, ScopedLock) { */ class InterprocessConditionTest : public ::testing::Test { protected: - void SetUp() override { - // 创建测试环境 - env_ = std::make_unique( - ::testing::UnitTest::GetInstance()->current_test_info()->name() - ); - - // 创建配置 - shared_memory_config config{}; - config.segment_name = env_->generate_segment_name("main"); - config.segment_size = 1024 * 1024; - config.create_if_not_exists = true; - config.remove_on_destroy = true; - config.mutex_name = env_->generate_segment_name("mutex"); - config.condition_name = env_->generate_segment_name("cond"); - config.semaphore_name = env_->generate_segment_name("sem"); - - // 创建 interprocess_synchronization 实例 - sync_ = std::make_unique(config); - - // 生成唯一的条件变量和互斥量名称 - condition_name_ = env_->generate_segment_name("test_cond"); - mutex_name_ = env_->generate_segment_name("test_mutex"); - } - - void TearDown() override { - sync_.reset(); - env_.reset(); - } - - std::unique_ptr env_; - std::unique_ptr sync_; - std::string condition_name_; - std::string mutex_name_; + void SetUp() override { + // 创建测试环境 + env_ = std::make_unique( + ::testing::UnitTest::GetInstance()->current_test_info()->name() + ); + + // 创建配置 + shared_memory_config config{}; + config.segment_name = env_->generate_segment_name("main"); + config.segment_size = 1024 * 1024; + config.create_if_not_exists = true; + config.remove_on_destroy = true; + config.mutex_name = env_->generate_segment_name("mutex"); + config.condition_name = env_->generate_segment_name("cond"); + config.semaphore_name = env_->generate_segment_name("sem"); + + // 创建 interprocess_synchronization 实例 + sync_ = std::make_unique(config); + + // 生成唯一的条件变量和互斥量名称 + condition_name_ = env_->generate_segment_name("test_cond"); + mutex_name_ = env_->generate_segment_name("test_mutex"); + } + + void TearDown() override { + sync_.reset(); + env_.reset(); + } + + std::unique_ptr env_; + std::unique_ptr sync_; + std::string condition_name_; + std::string mutex_name_; }; /** * @brief 测试创建条件变量成功 */ TEST_F(InterprocessConditionTest, CreateCondition) { - auto result = sync_->create_condition(condition_name_); - EXPECT_SHM_SUCCESS(result); + auto result = sync_->create_condition(condition_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试打开已存在的条件变量 */ TEST_F(InterprocessConditionTest, OpenCondition) { - // 先创建条件变量 - auto result = sync_->create_condition(condition_name_); - ASSERT_SHM_SUCCESS(result); - - // 打开已存在的条件变量 - result = sync_->open_condition(condition_name_); - EXPECT_SHM_SUCCESS(result); + // 先创建条件变量 + auto result = sync_->create_condition(condition_name_); + ASSERT_SHM_SUCCESS(result); + + // 打开已存在的条件变量 + result = sync_->open_condition(condition_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试等待和通知(使用超时避免阻塞) */ TEST_F(InterprocessConditionTest, WaitAndNotify) { - // 创建条件变量和互斥量 - auto result = sync_->create_condition(condition_name_); - ASSERT_SHM_SUCCESS(result); - - result = sync_->create_mutex(mutex_name_); - ASSERT_SHM_SUCCESS(result); - - // 等待条件变量(使用很短的超时,应该超时) - result = sync_->wait_condition(condition_name_, mutex_name_, - std::chrono::milliseconds(50)); - // 应该超时(返回错误或特定状态) - // 根据实现,可能返回 TIMEOUT 或其他错误码 - EXPECT_NE(result, shared_memory_error::SUCCESS); + // 创建条件变量和互斥量 + auto result = sync_->create_condition(condition_name_); + ASSERT_SHM_SUCCESS(result); + + result = sync_->create_mutex(mutex_name_); + ASSERT_SHM_SUCCESS(result); + + // 等待条件变量(使用很短的超时,应该超时) + result = sync_->wait_condition(condition_name_, mutex_name_, + std::chrono::milliseconds(50)); + // 应该超时(返回错误或特定状态) + // 根据实现,可能返回 TIMEOUT 或其他错误码 + EXPECT_NE(result, shared_memory_error::SUCCESS); } /** * @brief 测试 notify_one 行为 */ TEST_F(InterprocessConditionTest, NotifyOne) { - // 创建条件变量 - auto result = sync_->create_condition(condition_name_); - ASSERT_SHM_SUCCESS(result); - - // 通知一个等待者(即使没有等待者也应该成功) - result = sync_->notify_condition(condition_name_, false); - EXPECT_SHM_SUCCESS(result); + // 创建条件变量 + auto result = sync_->create_condition(condition_name_); + ASSERT_SHM_SUCCESS(result); + + // 通知一个等待者(即使没有等待者也应该成功) + result = sync_->notify_condition(condition_name_, false); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试 notify_all 行为 */ TEST_F(InterprocessConditionTest, NotifyAll) { - // 创建条件变量 - auto result = sync_->create_condition(condition_name_); - ASSERT_SHM_SUCCESS(result); - - // 通知所有等待者(即使没有等待者也应该成功) - result = sync_->notify_condition(condition_name_, true); - EXPECT_SHM_SUCCESS(result); + // 创建条件变量 + auto result = sync_->create_condition(condition_name_); + ASSERT_SHM_SUCCESS(result); + + // 通知所有等待者(即使没有等待者也应该成功) + result = sync_->notify_condition(condition_name_, true); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试移除条件变量 */ TEST_F(InterprocessConditionTest, RemoveCondition) { - // 创建条件变量 - auto result = sync_->create_condition(condition_name_); - ASSERT_SHM_SUCCESS(result); - - // 移除条件变量 - result = sync_->remove_condition(condition_name_); - EXPECT_SHM_SUCCESS(result); + // 创建条件变量 + auto result = sync_->create_condition(condition_name_); + ASSERT_SHM_SUCCESS(result); + + // 移除条件变量 + result = sync_->remove_condition(condition_name_); + EXPECT_SHM_SUCCESS(result); } // ============================================================================ @@ -282,169 +282,169 @@ TEST_F(InterprocessConditionTest, RemoveCondition) { */ class InterprocessSemaphoreTest : public ::testing::Test { protected: - void SetUp() override { - // 创建测试环境 - env_ = std::make_unique( - ::testing::UnitTest::GetInstance()->current_test_info()->name() - ); - - // 创建配置 - shared_memory_config config{}; - config.segment_name = env_->generate_segment_name("main"); - config.segment_size = 1024 * 1024; - config.create_if_not_exists = true; - config.remove_on_destroy = true; - config.mutex_name = env_->generate_segment_name("mutex"); - config.condition_name = env_->generate_segment_name("cond"); - config.semaphore_name = env_->generate_segment_name("sem"); - - // 创建 interprocess_synchronization 实例 - sync_ = std::make_unique(config); - - // 生成唯一的信号量名称 - semaphore_name_ = env_->generate_segment_name("test_sem"); - } - - void TearDown() override { - sync_.reset(); - env_.reset(); - } - - std::unique_ptr env_; - std::unique_ptr sync_; - std::string semaphore_name_; + void SetUp() override { + // 创建测试环境 + env_ = std::make_unique( + ::testing::UnitTest::GetInstance()->current_test_info()->name() + ); + + // 创建配置 + shared_memory_config config{}; + config.segment_name = env_->generate_segment_name("main"); + config.segment_size = 1024 * 1024; + config.create_if_not_exists = true; + config.remove_on_destroy = true; + config.mutex_name = env_->generate_segment_name("mutex"); + config.condition_name = env_->generate_segment_name("cond"); + config.semaphore_name = env_->generate_segment_name("sem"); + + // 创建 interprocess_synchronization 实例 + sync_ = std::make_unique(config); + + // 生成唯一的信号量名称 + semaphore_name_ = env_->generate_segment_name("test_sem"); + } + + void TearDown() override { + sync_.reset(); + env_.reset(); + } + + std::unique_ptr env_; + std::unique_ptr sync_; + std::string semaphore_name_; }; /** * @brief 测试创建信号量,验证初始计数 */ TEST_F(InterprocessSemaphoreTest, CreateSemaphore) { - // 创建初始计数为 3 的信号量 - auto result = sync_->create_semaphore(semaphore_name_, 3); - EXPECT_SHM_SUCCESS(result); + // 创建初始计数为 3 的信号量 + auto result = sync_->create_semaphore(semaphore_name_, 3); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试打开已存在的信号量 */ TEST_F(InterprocessSemaphoreTest, OpenSemaphore) { - // 先创建信号量 - auto result = sync_->create_semaphore(semaphore_name_, 1); - ASSERT_SHM_SUCCESS(result); - - // 打开已存在的信号量 - result = sync_->open_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); + // 先创建信号量 + auto result = sync_->create_semaphore(semaphore_name_, 1); + ASSERT_SHM_SUCCESS(result); + + // 打开已存在的信号量 + result = sync_->open_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试等待信号量(减少计数) */ TEST_F(InterprocessSemaphoreTest, WaitSemaphore) { - // 创建初始计数为 2 的信号量 - auto result = sync_->create_semaphore(semaphore_name_, 2); - ASSERT_SHM_SUCCESS(result); - - // 第一次 wait(计数变为 1) - result = sync_->wait_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); - - // 第二次 wait(计数变为 0) - result = sync_->wait_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); - - // 第三次 wait 应该阻塞,使用超时测试 - result = sync_->wait_semaphore(semaphore_name_, - std::chrono::milliseconds(50)); - // 应该超时 - EXPECT_NE(result, shared_memory_error::SUCCESS); + // 创建初始计数为 2 的信号量 + auto result = sync_->create_semaphore(semaphore_name_, 2); + ASSERT_SHM_SUCCESS(result); + + // 第一次 wait(计数变为 1) + result = sync_->wait_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); + + // 第二次 wait(计数变为 0) + result = sync_->wait_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); + + // 第三次 wait 应该阻塞,使用超时测试 + result = sync_->wait_semaphore(semaphore_name_, + std::chrono::milliseconds(50)); + // 应该超时 + EXPECT_NE(result, shared_memory_error::SUCCESS); } /** * @brief 测试释放信号量(增加计数) */ TEST_F(InterprocessSemaphoreTest, PostSemaphore) { - // 创建初始计数为 0 的信号量 - auto result = sync_->create_semaphore(semaphore_name_, 0); - ASSERT_SHM_SUCCESS(result); - - // wait 应该阻塞(超时) - result = sync_->wait_semaphore(semaphore_name_, - std::chrono::milliseconds(50)); - EXPECT_NE(result, shared_memory_error::SUCCESS); - - // post 增加计数 - result = sync_->post_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); - - // 现在 wait 应该成功 - result = sync_->wait_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); + // 创建初始计数为 0 的信号量 + auto result = sync_->create_semaphore(semaphore_name_, 0); + ASSERT_SHM_SUCCESS(result); + + // wait 应该阻塞(超时) + result = sync_->wait_semaphore(semaphore_name_, + std::chrono::milliseconds(50)); + EXPECT_NE(result, shared_memory_error::SUCCESS); + + // post 增加计数 + result = sync_->post_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); + + // 现在 wait 应该成功 + result = sync_->wait_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); } /** * @brief 测试带超时的等待 */ TEST_F(InterprocessSemaphoreTest, WaitWithTimeout) { - // 创建初始计数为 0 的信号量 - auto result = sync_->create_semaphore(semaphore_name_, 0); - ASSERT_SHM_SUCCESS(result); - - // 带超时的 wait(应该超时) - auto start = std::chrono::steady_clock::now(); - result = sync_->wait_semaphore(semaphore_name_, - std::chrono::milliseconds(100)); - auto end = std::chrono::steady_clock::now(); - - // 应该超时 - EXPECT_NE(result, shared_memory_error::SUCCESS); - - // 验证确实等待了一段时间 - auto elapsed = std::chrono::duration_cast(end - start); - EXPECT_GE(elapsed.count(), 50); // 至少等待了 50ms + // 创建初始计数为 0 的信号量 + auto result = sync_->create_semaphore(semaphore_name_, 0); + ASSERT_SHM_SUCCESS(result); + + // 带超时的 wait(应该超时) + auto start = std::chrono::steady_clock::now(); + result = sync_->wait_semaphore(semaphore_name_, + std::chrono::milliseconds(100)); + auto end = std::chrono::steady_clock::now(); + + // 应该超时 + EXPECT_NE(result, shared_memory_error::SUCCESS); + + // 验证确实等待了一段时间 + auto elapsed = std::chrono::duration_cast(end - start); + EXPECT_GE(elapsed.count(), 50); // 至少等待了 50ms } /** * @brief 测试多次 wait/post 验证计数正确 */ TEST_F(InterprocessSemaphoreTest, MultipleWaitPost) { - // 创建初始计数为 5 的信号量 - auto result = sync_->create_semaphore(semaphore_name_, 5); - ASSERT_SHM_SUCCESS(result); - - // wait 3 次(计数变为 2) - for (int i = 0; i < 3; ++i) { - result = sync_->wait_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); - } - - // post 2 次(计数变为 4) - for (int i = 0; i < 2; ++i) { - result = sync_->post_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); - } - - // 应该能再 wait 4 次 - for (int i = 0; i < 4; ++i) { - result = sync_->wait_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); - } - - // 第 5 次 wait 应该超时 - result = sync_->wait_semaphore(semaphore_name_, - std::chrono::milliseconds(50)); - EXPECT_NE(result, shared_memory_error::SUCCESS); + // 创建初始计数为 5 的信号量 + auto result = sync_->create_semaphore(semaphore_name_, 5); + ASSERT_SHM_SUCCESS(result); + + // wait 3 次(计数变为 2) + for (int i = 0; i < 3; ++i) { + result = sync_->wait_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); + } + + // post 2 次(计数变为 4) + for (int i = 0; i < 2; ++i) { + result = sync_->post_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); + } + + // 应该能再 wait 4 次 + for (int i = 0; i < 4; ++i) { + result = sync_->wait_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); + } + + // 第 5 次 wait 应该超时 + result = sync_->wait_semaphore(semaphore_name_, + std::chrono::milliseconds(50)); + EXPECT_NE(result, shared_memory_error::SUCCESS); } /** * @brief 测试移除信号量 */ TEST_F(InterprocessSemaphoreTest, RemoveSemaphore) { - // 创建信号量 - auto result = sync_->create_semaphore(semaphore_name_, 1); - ASSERT_SHM_SUCCESS(result); - - // 移除信号量 - result = sync_->remove_semaphore(semaphore_name_); - EXPECT_SHM_SUCCESS(result); -} \ No newline at end of file + // 创建信号量 + auto result = sync_->create_semaphore(semaphore_name_, 1); + ASSERT_SHM_SUCCESS(result); + + // 移除信号量 + result = sync_->remove_semaphore(semaphore_name_); + EXPECT_SHM_SUCCESS(result); +} diff --git a/tests/shm/test_lock_free_ring_buffer.cpp b/tests/shm/test_lock_free_ring_buffer.cpp index 0db7dcb..eb9000c 100644 --- a/tests/shm/test_lock_free_ring_buffer.cpp +++ b/tests/shm/test_lock_free_ring_buffer.cpp @@ -15,273 +15,273 @@ using namespace shm_test; */ class LockFreeRingBufferTest : public ::testing::Test { protected: - void SetUp() override { - // 创建测试环境 - env_ = std::make_unique( - ::testing::UnitTest::GetInstance()->current_test_info()->name() - ); - - // 生成 buffer 名称 - buffer_name_ = env_->generate_segment_name("ring_buffer"); - - // 创建容量为 100 的缓冲区 - buffer_ = std::make_unique>(buffer_name_, 100); - } - - void TearDown() override { - // 清理资源 - buffer_.reset(); - env_.reset(); - } - - std::unique_ptr env_; - std::string buffer_name_; - std::unique_ptr> buffer_; + void SetUp() override { + // 创建测试环境 + env_ = std::make_unique( + ::testing::UnitTest::GetInstance()->current_test_info()->name() + ); + + // 生成 buffer 名称 + buffer_name_ = env_->generate_segment_name("ring_buffer"); + + // 创建容量为 100 的缓冲区 + buffer_ = std::make_unique>(buffer_name_, 100); + } + + void TearDown() override { + // 清理资源 + buffer_.reset(); + env_.reset(); + } + + std::unique_ptr env_; + std::string buffer_name_; + std::unique_ptr> buffer_; }; /** * @brief 测试创建缓冲区 */ TEST_F(LockFreeRingBufferTest, CreateBuffer) { - EXPECT_EQ(buffer_->capacity(), 100u); - EXPECT_BUFFER_EMPTY(*buffer_); - EXPECT_EQ(buffer_->size(), 0u); - EXPECT_EQ(buffer_->available_space(), 100u); - EXPECT_FALSE(buffer_->full()); + EXPECT_EQ(buffer_->capacity(), 100u); + EXPECT_BUFFER_EMPTY(*buffer_); + EXPECT_EQ(buffer_->size(), 0u); + EXPECT_EQ(buffer_->available_space(), 100u); + EXPECT_FALSE(buffer_->full()); } /** * @brief 测试单次 push/pop,验证 FIFO */ TEST_F(LockFreeRingBufferTest, SinglePushPop) { - // Push 一个元素 - EXPECT_TRUE(buffer_->try_push(42)); - EXPECT_BUFFER_NOT_EMPTY(*buffer_); - EXPECT_EQ(buffer_->size(), 1u); - - // Pop 元素并验证 - int value = 0; - EXPECT_TRUE(buffer_->try_pop(value)); - EXPECT_EQ(value, 42); - EXPECT_BUFFER_EMPTY(*buffer_); + // Push 一个元素 + EXPECT_TRUE(buffer_->try_push(42)); + EXPECT_BUFFER_NOT_EMPTY(*buffer_); + EXPECT_EQ(buffer_->size(), 1u); + + // Pop 元素并验证 + int value = 0; + EXPECT_TRUE(buffer_->try_pop(value)); + EXPECT_EQ(value, 42); + EXPECT_BUFFER_EMPTY(*buffer_); } /** * @brief 测试多次 push/pop,验证顺序 */ TEST_F(LockFreeRingBufferTest, MultiplePushPop) { - DataGenerator gen; - auto test_data = gen.generate_sequence(10, 0, 1); // 0-9 - - // Push 多个元素 - for (int val : test_data) { - EXPECT_TRUE(buffer_->try_push(val)); - } - - EXPECT_EQ(buffer_->size(), 10u); - - // Pop 并验证顺序 - std::vector popped_data; - int value; - while (buffer_->try_pop(value)) { - popped_data.push_back(value); - } - - EXPECT_EQ(popped_data.size(), 10u); - EXPECT_SEQUENCE_VALID(gen, popped_data, 0, 1); - EXPECT_BUFFER_EMPTY(*buffer_); + DataGenerator gen; + auto test_data = gen.generate_sequence(10, 0, 1); // 0-9 + + // Push 多个元素 + for (int val : test_data) { + EXPECT_TRUE(buffer_->try_push(val)); + } + + EXPECT_EQ(buffer_->size(), 10u); + + // Pop 并验证顺序 + std::vector popped_data; + int value; + while (buffer_->try_pop(value)) { + popped_data.push_back(value); + } + + EXPECT_EQ(popped_data.size(), 10u); + EXPECT_SEQUENCE_VALID(gen, popped_data, 0, 1); + EXPECT_BUFFER_EMPTY(*buffer_); } /** * @brief 测试 peek 不移除元素 */ TEST_F(LockFreeRingBufferTest, TryPeek) { - // Push 一些元素 - EXPECT_TRUE(buffer_->try_push(10)); - EXPECT_TRUE(buffer_->try_push(20)); - EXPECT_TRUE(buffer_->try_push(30)); - - size_t size_before = buffer_->size(); - - // Peek 第一个元素 - int value = 0; - EXPECT_TRUE(buffer_->try_peek(value)); - EXPECT_EQ(value, 10); // 应该是第一个元素 - - // 验证 size 没有改变 - EXPECT_EQ(buffer_->size(), size_before); - - // 再次 peek,应该还是同一个元素 - EXPECT_TRUE(buffer_->try_peek(value)); - EXPECT_EQ(value, 10); - - // Pop 后 peek 应该返回下一个元素 - EXPECT_TRUE(buffer_->try_pop(value)); - EXPECT_EQ(value, 10); - - EXPECT_TRUE(buffer_->try_peek(value)); - EXPECT_EQ(value, 20); // 现在应该是 20 + // Push 一些元素 + EXPECT_TRUE(buffer_->try_push(10)); + EXPECT_TRUE(buffer_->try_push(20)); + EXPECT_TRUE(buffer_->try_push(30)); + + size_t size_before = buffer_->size(); + + // Peek 第一个元素 + int value = 0; + EXPECT_TRUE(buffer_->try_peek(value)); + EXPECT_EQ(value, 10); // 应该是第一个元素 + + // 验证 size 没有改变 + EXPECT_EQ(buffer_->size(), size_before); + + // 再次 peek,应该还是同一个元素 + EXPECT_TRUE(buffer_->try_peek(value)); + EXPECT_EQ(value, 10); + + // Pop 后 peek 应该返回下一个元素 + EXPECT_TRUE(buffer_->try_pop(value)); + EXPECT_EQ(value, 10); + + EXPECT_TRUE(buffer_->try_peek(value)); + EXPECT_EQ(value, 20); // 现在应该是 20 } /** * @brief 测试空缓冲区 pop 返回 false */ TEST_F(LockFreeRingBufferTest, EmptyBufferPop) { - EXPECT_BUFFER_EMPTY(*buffer_); - - int value = 0; - EXPECT_FALSE(buffer_->try_pop(value)); - - // 多次尝试 pop 都应该失败 - EXPECT_FALSE(buffer_->try_pop(value)); - EXPECT_FALSE(buffer_->try_pop(value)); - - // Peek 也应该失败 - EXPECT_FALSE(buffer_->try_peek(value)); + EXPECT_BUFFER_EMPTY(*buffer_); + + int value = 0; + EXPECT_FALSE(buffer_->try_pop(value)); + + // 多次尝试 pop 都应该失败 + EXPECT_FALSE(buffer_->try_pop(value)); + EXPECT_FALSE(buffer_->try_pop(value)); + + // Peek 也应该失败 + EXPECT_FALSE(buffer_->try_peek(value)); } /** * @brief 测试满缓冲区 push 返回 false */ TEST_F(LockFreeRingBufferTest, FullBufferPush) { - // 填满缓冲区 - size_t capacity = buffer_->capacity(); - for (size_t i = 0; i < capacity; ++i) { - EXPECT_TRUE(buffer_->try_push(static_cast(i))); - } - - EXPECT_BUFFER_FULL(*buffer_); - EXPECT_EQ(buffer_->size(), capacity); - EXPECT_EQ(buffer_->available_space(), 0u); - - // 尝试再 push 应该失败 - EXPECT_FALSE(buffer_->try_push(999)); - EXPECT_FALSE(buffer_->try_push(888)); - - // Pop 一个元素后应该能再 push - int value; - EXPECT_TRUE(buffer_->try_pop(value)); - EXPECT_TRUE(buffer_->try_push(777)); + // 填满缓冲区 + size_t capacity = buffer_->capacity(); + for (size_t i = 0; i < capacity; ++i) { + EXPECT_TRUE(buffer_->try_push(static_cast(i))); + } + + EXPECT_BUFFER_FULL(*buffer_); + EXPECT_EQ(buffer_->size(), capacity); + EXPECT_EQ(buffer_->available_space(), 0u); + + // 尝试再 push 应该失败 + EXPECT_FALSE(buffer_->try_push(999)); + EXPECT_FALSE(buffer_->try_push(888)); + + // Pop 一个元素后应该能再 push + int value; + EXPECT_TRUE(buffer_->try_pop(value)); + EXPECT_TRUE(buffer_->try_push(777)); } /** * @brief 测试批量 push */ TEST_F(LockFreeRingBufferTest, BatchPush) { - DataGenerator gen; - auto test_data = gen.generate_sequence(50, 100, 2); // 100, 102, 104, ... - - // 批量 push - size_t pushed = buffer_->try_push_batch(test_data.data(), test_data.size()); - EXPECT_EQ(pushed, 50u); - EXPECT_EQ(buffer_->size(), 50u); - - // 验证数据 - std::vector popped_data; - int value; - while (buffer_->try_pop(value)) { - popped_data.push_back(value); - } - - EXPECT_DATA_EQUAL(gen, test_data, popped_data); + DataGenerator gen; + auto test_data = gen.generate_sequence(50, 100, 2); // 100, 102, 104, ... + + // 批量 push + size_t pushed = buffer_->try_push_batch(test_data.data(), test_data.size()); + EXPECT_EQ(pushed, 50u); + EXPECT_EQ(buffer_->size(), 50u); + + // 验证数据 + std::vector popped_data; + int value; + while (buffer_->try_pop(value)) { + popped_data.push_back(value); + } + + EXPECT_DATA_EQUAL(gen, test_data, popped_data); } /** * @brief 测试批量 pop */ TEST_F(LockFreeRingBufferTest, BatchPop) { - DataGenerator gen; - auto test_data = gen.generate_sequence(30, 0, 5); // 0, 5, 10, 15, ... - - // Push 数据 - for (int val : test_data) { - EXPECT_TRUE(buffer_->try_push(val)); - } - - // 批量 pop - std::vector popped_data(30); - size_t popped = buffer_->try_pop_batch(popped_data.data(), 30); - - EXPECT_EQ(popped, 30u); - EXPECT_BUFFER_EMPTY(*buffer_); - - // 验证数据 - popped_data.resize(popped); - EXPECT_DATA_EQUAL(gen, test_data, popped_data); + DataGenerator gen; + auto test_data = gen.generate_sequence(30, 0, 5); // 0, 5, 10, 15, ... + + // Push 数据 + for (int val : test_data) { + EXPECT_TRUE(buffer_->try_push(val)); + } + + // 批量 pop + std::vector popped_data(30); + size_t popped = buffer_->try_pop_batch(popped_data.data(), 30); + + EXPECT_EQ(popped, 30u); + EXPECT_BUFFER_EMPTY(*buffer_); + + // 验证数据 + popped_data.resize(popped); + EXPECT_DATA_EQUAL(gen, test_data, popped_data); } /** * @brief 测试原地构造 */ TEST_F(LockFreeRingBufferTest, TryEmplace) { - // 使用 try_emplace 直接构造元素 - EXPECT_TRUE(buffer_->try_emplace(42)); - EXPECT_TRUE(buffer_->try_emplace(100)); - EXPECT_TRUE(buffer_->try_emplace(200)); - - EXPECT_EQ(buffer_->size(), 3u); - - // 验证值 - int value; - EXPECT_TRUE(buffer_->try_pop(value)); - EXPECT_EQ(value, 42); - - EXPECT_TRUE(buffer_->try_pop(value)); - EXPECT_EQ(value, 100); - - EXPECT_TRUE(buffer_->try_pop(value)); - EXPECT_EQ(value, 200); + // 使用 try_emplace 直接构造元素 + EXPECT_TRUE(buffer_->try_emplace(42)); + EXPECT_TRUE(buffer_->try_emplace(100)); + EXPECT_TRUE(buffer_->try_emplace(200)); + + EXPECT_EQ(buffer_->size(), 3u); + + // 验证值 + int value; + EXPECT_TRUE(buffer_->try_pop(value)); + EXPECT_EQ(value, 42); + + EXPECT_TRUE(buffer_->try_pop(value)); + EXPECT_EQ(value, 100); + + EXPECT_TRUE(buffer_->try_pop(value)); + EXPECT_EQ(value, 200); } /** * @brief 测试 size() 和 capacity() */ TEST_F(LockFreeRingBufferTest, SizeAndCapacity) { - EXPECT_EQ(buffer_->capacity(), 100u); - EXPECT_EQ(buffer_->size(), 0u); - - // 添加一些元素 - for (int i = 0; i < 25; ++i) { - EXPECT_TRUE(buffer_->try_push(i)); - } - - EXPECT_EQ(buffer_->size(), 25u); - EXPECT_EQ(buffer_->capacity(), 100u); // 容量不变 - - // 移除一些元素 - int value; - for (int i = 0; i < 10; ++i) { - EXPECT_TRUE(buffer_->try_pop(value)); - } - - EXPECT_EQ(buffer_->size(), 15u); - EXPECT_EQ(buffer_->capacity(), 100u); // 容量不变 + EXPECT_EQ(buffer_->capacity(), 100u); + EXPECT_EQ(buffer_->size(), 0u); + + // 添加一些元素 + for (int i = 0; i < 25; ++i) { + EXPECT_TRUE(buffer_->try_push(i)); + } + + EXPECT_EQ(buffer_->size(), 25u); + EXPECT_EQ(buffer_->capacity(), 100u); // 容量不变 + + // 移除一些元素 + int value; + for (int i = 0; i < 10; ++i) { + EXPECT_TRUE(buffer_->try_pop(value)); + } + + EXPECT_EQ(buffer_->size(), 15u); + EXPECT_EQ(buffer_->capacity(), 100u); // 容量不变 } /** * @brief 测试 available_space() */ TEST_F(LockFreeRingBufferTest, AvailableSpace) { - size_t capacity = buffer_->capacity(); - - EXPECT_EQ(buffer_->available_space(), capacity); - - // 添加元素 - for (size_t i = 0; i < 40; ++i) { - EXPECT_TRUE(buffer_->try_push(static_cast(i))); - } - - EXPECT_EQ(buffer_->available_space(), capacity - 40); - EXPECT_EQ(buffer_->size() + buffer_->available_space(), capacity); - - // 移除一些元素 - int value; - for (int i = 0; i < 15; ++i) { - EXPECT_TRUE(buffer_->try_pop(value)); - } - - EXPECT_EQ(buffer_->available_space(), capacity - 25); - EXPECT_EQ(buffer_->size() + buffer_->available_space(), capacity); + size_t capacity = buffer_->capacity(); + + EXPECT_EQ(buffer_->available_space(), capacity); + + // 添加元素 + for (size_t i = 0; i < 40; ++i) { + EXPECT_TRUE(buffer_->try_push(static_cast(i))); + } + + EXPECT_EQ(buffer_->available_space(), capacity - 40); + EXPECT_EQ(buffer_->size() + buffer_->available_space(), capacity); + + // 移除一些元素 + int value; + for (int i = 0; i < 15; ++i) { + EXPECT_TRUE(buffer_->try_pop(value)); + } + + EXPECT_EQ(buffer_->available_space(), capacity - 25); + EXPECT_EQ(buffer_->size() + buffer_->available_space(), capacity); } // ============================================================================ @@ -293,82 +293,82 @@ TEST_F(LockFreeRingBufferTest, AvailableSpace) { */ class LockFreeRingBufferParamTest : public ::testing::TestWithParam { protected: - void SetUp() override { - env_ = std::make_unique( - ::testing::UnitTest::GetInstance()->current_test_info()->name() - ); - - capacity_ = GetParam(); - buffer_name_ = env_->generate_segment_name("param_buffer"); - buffer_ = std::make_unique>(buffer_name_, capacity_); - } - - void TearDown() override { - buffer_.reset(); - env_.reset(); - } - - std::unique_ptr env_; - std::string buffer_name_; - std::unique_ptr> buffer_; - size_t capacity_; + void SetUp() override { + env_ = std::make_unique( + ::testing::UnitTest::GetInstance()->current_test_info()->name() + ); + + capacity_ = GetParam(); + buffer_name_ = env_->generate_segment_name("param_buffer"); + buffer_ = std::make_unique>(buffer_name_, capacity_); + } + + void TearDown() override { + buffer_.reset(); + env_.reset(); + } + + std::unique_ptr env_; + std::string buffer_name_; + std::unique_ptr> buffer_; + size_t capacity_; }; /** * @brief 测试不同容量的缓冲区基本操作 */ TEST_P(LockFreeRingBufferParamTest, DifferentCapacities) { - size_t capacity = GetParam(); - - // 验证容量 - EXPECT_EQ(buffer_->capacity(), capacity); - EXPECT_BUFFER_EMPTY(*buffer_); - - // 生成测试数据(不超过容量) - DataGenerator gen; - size_t test_count = std::min(capacity, size_t(50)); - auto test_data = gen.generate_sequence(test_count, 1, 1); - - // Push 数据 - for (int val : test_data) { - EXPECT_TRUE(buffer_->try_push(val)); - } - - EXPECT_EQ(buffer_->size(), test_count); - - // Pop 数据并验证 - std::vector popped_data; - int value; - while (buffer_->try_pop(value)) { - popped_data.push_back(value); - } - - EXPECT_EQ(popped_data.size(), test_count); - EXPECT_SEQUENCE_VALID(gen, popped_data, 1, 1); - EXPECT_BUFFER_EMPTY(*buffer_); - - // 测试填满缓冲区 - for (size_t i = 0; i < capacity; ++i) { - EXPECT_TRUE(buffer_->try_push(static_cast(i))); - } - - EXPECT_BUFFER_FULL(*buffer_); - EXPECT_EQ(buffer_->size(), capacity); - - // 满了之后无法 push - EXPECT_FALSE(buffer_->try_push(999)); - - // 清空 - while (buffer_->try_pop(value)) { - // 不断 pop - } - - EXPECT_BUFFER_EMPTY(*buffer_); + size_t capacity = GetParam(); + + // 验证容量 + EXPECT_EQ(buffer_->capacity(), capacity); + EXPECT_BUFFER_EMPTY(*buffer_); + + // 生成测试数据(不超过容量) + DataGenerator gen; + size_t test_count = std::min(capacity, size_t(50)); + auto test_data = gen.generate_sequence(test_count, 1, 1); + + // Push 数据 + for (int val : test_data) { + EXPECT_TRUE(buffer_->try_push(val)); + } + + EXPECT_EQ(buffer_->size(), test_count); + + // Pop 数据并验证 + std::vector popped_data; + int value; + while (buffer_->try_pop(value)) { + popped_data.push_back(value); + } + + EXPECT_EQ(popped_data.size(), test_count); + EXPECT_SEQUENCE_VALID(gen, popped_data, 1, 1); + EXPECT_BUFFER_EMPTY(*buffer_); + + // 测试填满缓冲区 + for (size_t i = 0; i < capacity; ++i) { + EXPECT_TRUE(buffer_->try_push(static_cast(i))); + } + + EXPECT_BUFFER_FULL(*buffer_); + EXPECT_EQ(buffer_->size(), capacity); + + // 满了之后无法 push + EXPECT_FALSE(buffer_->try_push(999)); + + // 清空 + while (buffer_->try_pop(value)) { + // 不断 pop + } + + EXPECT_BUFFER_EMPTY(*buffer_); } // 实例化参数化测试,测试容量 1, 10, 100, 1000 INSTANTIATE_TEST_SUITE_P( - CapacityTests, - LockFreeRingBufferParamTest, - ::testing::Values(1, 10, 100, 1000) -); \ No newline at end of file + CapacityTests, + LockFreeRingBufferParamTest, + ::testing::Values(1, 10, 100, 1000) +); diff --git a/tests/shm/test_shared_memory_manager.cpp b/tests/shm/test_shared_memory_manager.cpp index 57026a8..f658685 100644 --- a/tests/shm/test_shared_memory_manager.cpp +++ b/tests/shm/test_shared_memory_manager.cpp @@ -14,214 +14,214 @@ using namespace shm_test; * @brief 测试共享内存管理器初始化成功 */ TEST(SharedMemoryManagerTest, InitializeSuccess) { - ShmTestEnvironment env("InitializeSuccess"); - - auto& manager = shared_memory_manager::instance(); - EXPECT_TRUE(manager.is_initialized()); + ShmTestEnvironment env("InitializeSuccess"); + + auto& manager = shared_memory_manager::instance(); + EXPECT_TRUE(manager.is_initialized()); } /** * @brief 测试使用自定义配置初始化 */ TEST(SharedMemoryManagerTest, InitializeWithConfig) { - ShmTestEnvironment env("InitializeWithConfig", 0, false); // 不自动初始化 - - auto& manager = shared_memory_manager::instance(); - EXPECT_FALSE(manager.is_initialized()); - - // 使用自定义配置初始化 - shared_memory_config config{}; - config.segment_name = env.generate_segment_name("custom"); - config.segment_size = 2 * 1024 * 1024; // 2MB - config.create_if_not_exists = true; - config.remove_on_destroy = true; - config.mutex_name = env.generate_segment_name("mutex"); - config.condition_name = env.generate_segment_name("cond"); - config.semaphore_name = env.generate_segment_name("sem"); - - EXPECT_SHM_SUCCESS(manager.init(config)); - EXPECT_TRUE(manager.is_initialized()); + ShmTestEnvironment env("InitializeWithConfig", 0, false); // 不自动初始化 + + auto& manager = shared_memory_manager::instance(); + EXPECT_FALSE(manager.is_initialized()); + + // 使用自定义配置初始化 + shared_memory_config config{}; + config.segment_name = env.generate_segment_name("custom"); + config.segment_size = 2 * 1024 * 1024; // 2MB + config.create_if_not_exists = true; + config.remove_on_destroy = true; + config.mutex_name = env.generate_segment_name("mutex"); + config.condition_name = env.generate_segment_name("cond"); + config.semaphore_name = env.generate_segment_name("sem"); + + EXPECT_SHM_SUCCESS(manager.init(config)); + EXPECT_TRUE(manager.is_initialized()); } /** * @brief 测试分配对象后能成功查找 */ TEST(SharedMemoryManagerTest, AllocateAndFind) { - ShmTestEnvironment env("AllocateAndFind"); - - auto& manager = shared_memory_manager::instance(); - - // 分配 int 对象 - int* ptr = manager.allocate("test_int"); - ASSERT_SHM_NOT_NULL(ptr); - *ptr = 42; - - // 查找已分配的对象 - int* found = manager.find("test_int"); - ASSERT_SHM_NOT_NULL(found); - EXPECT_EQ(*found, 42); - EXPECT_EQ(ptr, found); // 应该是同一个对象 + ShmTestEnvironment env("AllocateAndFind"); + + auto& manager = shared_memory_manager::instance(); + + // 分配 int 对象 + int* ptr = manager.allocate("test_int"); + ASSERT_SHM_NOT_NULL(ptr); + *ptr = 42; + + // 查找已分配的对象 + int* found = manager.find("test_int"); + ASSERT_SHM_NOT_NULL(found); + EXPECT_EQ(*found, 42); + EXPECT_EQ(ptr, found); // 应该是同一个对象 } /** * @brief 测试分配原始内存 */ TEST(SharedMemoryManagerTest, AllocateRawMemory) { - ShmTestEnvironment env("AllocateRawMemory"); - - auto& manager = shared_memory_manager::instance(); - - // 分配 1024 字节原始内存 - size_t size = 1024; - void* raw_ptr = manager.allocate_raw(size, "raw_memory"); - ASSERT_SHM_NOT_NULL(raw_ptr); - - // 写入数据 - uint8_t* byte_ptr = static_cast(raw_ptr); - for (size_t i = 0; i < size; ++i) { - byte_ptr[i] = static_cast(i % 256); - } - - // 查找并验证 - void* found = manager.find_raw("raw_memory"); - ASSERT_SHM_NOT_NULL(found); - EXPECT_EQ(raw_ptr, found); - - uint8_t* found_bytes = static_cast(found); - for (size_t i = 0; i < size; ++i) { - EXPECT_EQ(found_bytes[i], static_cast(i % 256)); - } + ShmTestEnvironment env("AllocateRawMemory"); + + auto& manager = shared_memory_manager::instance(); + + // 分配 1024 字节原始内存 + size_t size = 1024; + void* raw_ptr = manager.allocate_raw(size, "raw_memory"); + ASSERT_SHM_NOT_NULL(raw_ptr); + + // 写入数据 + uint8_t* byte_ptr = static_cast(raw_ptr); + for (size_t i = 0; i < size; ++i) { + byte_ptr[i] = static_cast(i % 256); + } + + // 查找并验证 + void* found = manager.find_raw("raw_memory"); + ASSERT_SHM_NOT_NULL(found); + EXPECT_EQ(raw_ptr, found); + + uint8_t* found_bytes = static_cast(found); + for (size_t i = 0; i < size; ++i) { + EXPECT_EQ(found_bytes[i], static_cast(i % 256)); + } } /** * @brief 测试释放对象后无法查找 */ TEST(SharedMemoryManagerTest, DeallocateObject) { - ShmTestEnvironment env("DeallocateObject"); - - auto& manager = shared_memory_manager::instance(); - - // 分配对象 - double* ptr = manager.allocate("test_double"); - ASSERT_SHM_NOT_NULL(ptr); - *ptr = 3.14159; - - // 确认能找到 - double* found = manager.find("test_double"); - ASSERT_SHM_NOT_NULL(found); - EXPECT_DOUBLE_EQ(*found, 3.14159); - - // 释放对象 - EXPECT_TRUE(manager.deallocate("test_double")); - - // 释放后应该找不到 - double* not_found = manager.find("test_double"); - EXPECT_EQ(not_found, nullptr); + ShmTestEnvironment env("DeallocateObject"); + + auto& manager = shared_memory_manager::instance(); + + // 分配对象 + double* ptr = manager.allocate("test_double"); + ASSERT_SHM_NOT_NULL(ptr); + *ptr = 3.14159; + + // 确认能找到 + double* found = manager.find("test_double"); + ASSERT_SHM_NOT_NULL(found); + EXPECT_DOUBLE_EQ(*found, 3.14159); + + // 释放对象 + EXPECT_TRUE(manager.deallocate("test_double")); + + // 释放后应该找不到 + double* not_found = manager.find("test_double"); + EXPECT_EQ(not_found, nullptr); } /** * @brief 测试统计信息准确反映使用情况 */ TEST(SharedMemoryManagerTest, GetStatistics) { - ShmTestEnvironment env("GetStatistics"); - - auto& manager = shared_memory_manager::instance(); - - // 获取初始统计信息 - auto stats_before = manager.get_statistics(); - EXPECT_GT(stats_before.total_size, 0u); - EXPECT_GT(stats_before.free_size, 0u); - - // 分配一些对象 - int* ptr1 = manager.allocate("stats_int"); - double* ptr2 = manager.allocate("stats_double"); - ASSERT_SHM_NOT_NULL(ptr1); - ASSERT_SHM_NOT_NULL(ptr2); - - // 检查统计信息变化 - auto stats_after = manager.get_statistics(); - EXPECT_LT(stats_after.free_size, stats_before.free_size); // 空闲空间减少 - EXPECT_GT(stats_after.used_size, stats_before.used_size); // 已用空间增加 - EXPECT_GT(stats_after.num_allocations, stats_before.num_allocations); // 分配次数增加 + ShmTestEnvironment env("GetStatistics"); + + auto& manager = shared_memory_manager::instance(); + + // 获取初始统计信息 + auto stats_before = manager.get_statistics(); + EXPECT_GT(stats_before.total_size, 0u); + EXPECT_GT(stats_before.free_size, 0u); + + // 分配一些对象 + int* ptr1 = manager.allocate("stats_int"); + double* ptr2 = manager.allocate("stats_double"); + ASSERT_SHM_NOT_NULL(ptr1); + ASSERT_SHM_NOT_NULL(ptr2); + + // 检查统计信息变化 + auto stats_after = manager.get_statistics(); + EXPECT_LT(stats_after.free_size, stats_before.free_size); // 空闲空间减少 + EXPECT_GT(stats_after.used_size, stats_before.used_size); // 已用空间增加 + EXPECT_GT(stats_after.num_allocations, stats_before.num_allocations); // 分配次数增加 } /** * @brief 测试多次分配不同对象 */ TEST(SharedMemoryManagerTest, MultipleAllocations) { - ShmTestEnvironment env("MultipleAllocations"); - - auto& manager = shared_memory_manager::instance(); - - // 分配多个不同类型的对象 - int* int_ptr = manager.allocate("multi_int"); - float* float_ptr = manager.allocate("multi_float"); - double* double_ptr = manager.allocate("multi_double"); - - ASSERT_SHM_NOT_NULL(int_ptr); - ASSERT_SHM_NOT_NULL(float_ptr); - ASSERT_SHM_NOT_NULL(double_ptr); - - // 赋值 - *int_ptr = 100; - *float_ptr = 1.5f; - *double_ptr = 2.71828; - - // 验证所有对象都能找到且值正确 - int* found_int = manager.find("multi_int"); - float* found_float = manager.find("multi_float"); - double* found_double = manager.find("multi_double"); - - ASSERT_SHM_NOT_NULL(found_int); - ASSERT_SHM_NOT_NULL(found_float); - ASSERT_SHM_NOT_NULL(found_double); - - EXPECT_EQ(*found_int, 100); - EXPECT_FLOAT_EQ(*found_float, 1.5f); - EXPECT_DOUBLE_EQ(*found_double, 2.71828); + ShmTestEnvironment env("MultipleAllocations"); + + auto& manager = shared_memory_manager::instance(); + + // 分配多个不同类型的对象 + int* int_ptr = manager.allocate("multi_int"); + float* float_ptr = manager.allocate("multi_float"); + double* double_ptr = manager.allocate("multi_double"); + + ASSERT_SHM_NOT_NULL(int_ptr); + ASSERT_SHM_NOT_NULL(float_ptr); + ASSERT_SHM_NOT_NULL(double_ptr); + + // 赋值 + *int_ptr = 100; + *float_ptr = 1.5f; + *double_ptr = 2.71828; + + // 验证所有对象都能找到且值正确 + int* found_int = manager.find("multi_int"); + float* found_float = manager.find("multi_float"); + double* found_double = manager.find("multi_double"); + + ASSERT_SHM_NOT_NULL(found_int); + ASSERT_SHM_NOT_NULL(found_float); + ASSERT_SHM_NOT_NULL(found_double); + + EXPECT_EQ(*found_int, 100); + EXPECT_FLOAT_EQ(*found_float, 1.5f); + EXPECT_DOUBLE_EQ(*found_double, 2.71828); } /** * @brief 测试关闭后可重新初始化 */ TEST(SharedMemoryManagerTest, ShutdownAndReinit) { - ShmTestEnvironment env("ShutdownAndReinit", 0, false); // 不自动初始化 - - auto& manager = shared_memory_manager::instance(); - - // 第一次初始化 - shared_memory_config config1{}; - config1.segment_name = env.generate_segment_name("first"); - config1.segment_size = 1024 * 1024; - config1.create_if_not_exists = true; - config1.remove_on_destroy = true; - config1.mutex_name = env.generate_segment_name("mutex1"); - config1.condition_name = env.generate_segment_name("cond1"); - config1.semaphore_name = env.generate_segment_name("sem1"); - - EXPECT_SHM_SUCCESS(manager.init(config1)); - EXPECT_TRUE(manager.is_initialized()); - - // 分配对象 - int* ptr = manager.allocate("test_obj"); - ASSERT_SHM_NOT_NULL(ptr); - - // 关闭 - EXPECT_SHM_SUCCESS(manager.shutdown()); - EXPECT_FALSE(manager.is_initialized()); - - // 重新初始化 - shared_memory_config config2{}; - config2.segment_name = env.generate_segment_name("second"); - config2.segment_size = 2 * 1024 * 1024; - config2.create_if_not_exists = true; - config2.remove_on_destroy = true; - config2.mutex_name = env.generate_segment_name("mutex2"); - config2.condition_name = env.generate_segment_name("cond2"); - config2.semaphore_name = env.generate_segment_name("sem2"); - - EXPECT_SHM_SUCCESS(manager.init(config2)); - EXPECT_TRUE(manager.is_initialized()); + ShmTestEnvironment env("ShutdownAndReinit", 0, false); // 不自动初始化 + + auto& manager = shared_memory_manager::instance(); + + // 第一次初始化 + shared_memory_config config1{}; + config1.segment_name = env.generate_segment_name("first"); + config1.segment_size = 1024 * 1024; + config1.create_if_not_exists = true; + config1.remove_on_destroy = true; + config1.mutex_name = env.generate_segment_name("mutex1"); + config1.condition_name = env.generate_segment_name("cond1"); + config1.semaphore_name = env.generate_segment_name("sem1"); + + EXPECT_SHM_SUCCESS(manager.init(config1)); + EXPECT_TRUE(manager.is_initialized()); + + // 分配对象 + int* ptr = manager.allocate("test_obj"); + ASSERT_SHM_NOT_NULL(ptr); + + // 关闭 + EXPECT_SHM_SUCCESS(manager.shutdown()); + EXPECT_FALSE(manager.is_initialized()); + + // 重新初始化 + shared_memory_config config2{}; + config2.segment_name = env.generate_segment_name("second"); + config2.segment_size = 2 * 1024 * 1024; + config2.create_if_not_exists = true; + config2.remove_on_destroy = true; + config2.mutex_name = env.generate_segment_name("mutex2"); + config2.condition_name = env.generate_segment_name("cond2"); + config2.semaphore_name = env.generate_segment_name("sem2"); + + EXPECT_SHM_SUCCESS(manager.init(config2)); + EXPECT_TRUE(manager.is_initialized()); } // ============================================================================ @@ -233,144 +233,144 @@ TEST(SharedMemoryManagerTest, ShutdownAndReinit) { */ class SharedMemoryManagerFixture : public ::testing::Test { protected: - void SetUp() override { - // 创建测试环境 - env_ = std::make_unique( - ::testing::UnitTest::GetInstance()->current_test_info()->name() - ); - - // 生成测试用的共享内存段名称 - segment_name_ = env_->generate_segment_name("fixture"); - } - - void TearDown() override { - // 清理资源 - env_.reset(); - } - - std::unique_ptr env_; - std::string segment_name_; + void SetUp() override { + // 创建测试环境 + env_ = std::make_unique( + ::testing::UnitTest::GetInstance()->current_test_info()->name() + ); + + // 生成测试用的共享内存段名称 + segment_name_ = env_->generate_segment_name("fixture"); + } + + void TearDown() override { + // 清理资源 + env_.reset(); + } + + std::unique_ptr env_; + std::string segment_name_; }; /** * @brief 测试分配不同类型的对象 */ TEST_F(SharedMemoryManagerFixture, AllocateDifferentTypes) { - auto& manager = shared_memory_manager::instance(); - - // 定义测试结构体 - struct TestStruct { - int a; - double b; - char c; - }; - - // 分配基本类型 - int* int_ptr = manager.allocate("diff_int"); - double* double_ptr = manager.allocate("diff_double"); - TestStruct* struct_ptr = manager.allocate("diff_struct"); - - ASSERT_SHM_NOT_NULL(int_ptr); - ASSERT_SHM_NOT_NULL(double_ptr); - ASSERT_SHM_NOT_NULL(struct_ptr); - - // 赋值 - *int_ptr = 123; - *double_ptr = 456.789; - struct_ptr->a = 10; - struct_ptr->b = 20.5; - struct_ptr->c = 'X'; - - // 查找并验证 - int* found_int = manager.find("diff_int"); - double* found_double = manager.find("diff_double"); - TestStruct* found_struct = manager.find("diff_struct"); - - ASSERT_SHM_NOT_NULL(found_int); - ASSERT_SHM_NOT_NULL(found_double); - ASSERT_SHM_NOT_NULL(found_struct); - - EXPECT_EQ(*found_int, 123); - EXPECT_DOUBLE_EQ(*found_double, 456.789); - EXPECT_EQ(found_struct->a, 10); - EXPECT_DOUBLE_EQ(found_struct->b, 20.5); - EXPECT_EQ(found_struct->c, 'X'); + auto& manager = shared_memory_manager::instance(); + + // 定义测试结构体 + struct TestStruct { + int a; + double b; + char c; + }; + + // 分配基本类型 + int* int_ptr = manager.allocate("diff_int"); + double* double_ptr = manager.allocate("diff_double"); + TestStruct* struct_ptr = manager.allocate("diff_struct"); + + ASSERT_SHM_NOT_NULL(int_ptr); + ASSERT_SHM_NOT_NULL(double_ptr); + ASSERT_SHM_NOT_NULL(struct_ptr); + + // 赋值 + *int_ptr = 123; + *double_ptr = 456.789; + struct_ptr->a = 10; + struct_ptr->b = 20.5; + struct_ptr->c = 'X'; + + // 查找并验证 + int* found_int = manager.find("diff_int"); + double* found_double = manager.find("diff_double"); + TestStruct* found_struct = manager.find("diff_struct"); + + ASSERT_SHM_NOT_NULL(found_int); + ASSERT_SHM_NOT_NULL(found_double); + ASSERT_SHM_NOT_NULL(found_struct); + + EXPECT_EQ(*found_int, 123); + EXPECT_DOUBLE_EQ(*found_double, 456.789); + EXPECT_EQ(found_struct->a, 10); + EXPECT_DOUBLE_EQ(found_struct->b, 20.5); + EXPECT_EQ(found_struct->c, 'X'); } /** * @brief 测试查找不存在的对象返回 nullptr */ TEST_F(SharedMemoryManagerFixture, FindNonExistentObject) { - auto& manager = shared_memory_manager::instance(); - - // 查找不存在的对象 - int* not_found1 = manager.find("non_existent_1"); - double* not_found2 = manager.find("non_existent_2"); - void* not_found3 = manager.find_raw("non_existent_3"); - - EXPECT_EQ(not_found1, nullptr); - EXPECT_EQ(not_found2, nullptr); - EXPECT_EQ(not_found3, nullptr); + auto& manager = shared_memory_manager::instance(); + + // 查找不存在的对象 + int* not_found1 = manager.find("non_existent_1"); + double* not_found2 = manager.find("non_existent_2"); + void* not_found3 = manager.find_raw("non_existent_3"); + + EXPECT_EQ(not_found1, nullptr); + EXPECT_EQ(not_found2, nullptr); + EXPECT_EQ(not_found3, nullptr); } /** * @brief 测试重复分配同名对象的行为 */ TEST_F(SharedMemoryManagerFixture, AllocateSameName) { - auto& manager = shared_memory_manager::instance(); - - // 第一次分配 - int* ptr1 = manager.allocate("same_name"); - ASSERT_SHM_NOT_NULL(ptr1); - *ptr1 = 100; - - // 尝试再次分配同名对象(应该失败,因为名称已存在) - int* ptr2 = manager.allocate("same_name"); - EXPECT_EQ(ptr2, nullptr); // Boost.Interprocess 不允许重复名称 - - // 原对象仍然存在且值不变 - int* found = manager.find("same_name"); - ASSERT_SHM_NOT_NULL(found); - EXPECT_EQ(*found, 100); + auto& manager = shared_memory_manager::instance(); + + // 第一次分配 + int* ptr1 = manager.allocate("same_name"); + ASSERT_SHM_NOT_NULL(ptr1); + *ptr1 = 100; + + // 尝试再次分配同名对象(应该失败,因为名称已存在) + int* ptr2 = manager.allocate("same_name"); + EXPECT_EQ(ptr2, nullptr); // Boost.Interprocess 不允许重复名称 + + // 原对象仍然存在且值不变 + int* found = manager.find("same_name"); + ASSERT_SHM_NOT_NULL(found); + EXPECT_EQ(*found, 100); } /** * @brief 测试统计信息的准确性 */ TEST_F(SharedMemoryManagerFixture, StatisticsAccuracy) { - auto& manager = shared_memory_manager::instance(); - - // 获取初始统计信息 - auto initial_stats = manager.get_statistics(); - size_t initial_allocations = initial_stats.num_allocations; - size_t initial_used = initial_stats.used_size; - - // 分配几个对象 - int* ptr1 = manager.allocate("stats_1"); - int* ptr2 = manager.allocate("stats_2"); - int* ptr3 = manager.allocate("stats_3"); - - ASSERT_SHM_NOT_NULL(ptr1); - ASSERT_SHM_NOT_NULL(ptr2); - ASSERT_SHM_NOT_NULL(ptr3); - - // 检查分配后的统计信息 - auto after_alloc_stats = manager.get_statistics(); - EXPECT_GT(after_alloc_stats.num_allocations, initial_allocations); - EXPECT_GT(after_alloc_stats.used_size, initial_used); - - // 验证总大小不变 - EXPECT_EQ(after_alloc_stats.total_size, initial_stats.total_size); - - // 验证 used + free = total(考虑管理开销) - EXPECT_LE(after_alloc_stats.used_size + after_alloc_stats.free_size, - after_alloc_stats.total_size); - - // 释放一个对象 - EXPECT_TRUE(manager.deallocate("stats_2")); - - // 检查释放后的统计信息 - auto after_dealloc_stats = manager.get_statistics(); - EXPECT_LT(after_dealloc_stats.used_size, after_alloc_stats.used_size); - EXPECT_GT(after_dealloc_stats.free_size, after_alloc_stats.free_size); -} \ No newline at end of file + auto& manager = shared_memory_manager::instance(); + + // 获取初始统计信息 + auto initial_stats = manager.get_statistics(); + size_t initial_allocations = initial_stats.num_allocations; + size_t initial_used = initial_stats.used_size; + + // 分配几个对象 + int* ptr1 = manager.allocate("stats_1"); + int* ptr2 = manager.allocate("stats_2"); + int* ptr3 = manager.allocate("stats_3"); + + ASSERT_SHM_NOT_NULL(ptr1); + ASSERT_SHM_NOT_NULL(ptr2); + ASSERT_SHM_NOT_NULL(ptr3); + + // 检查分配后的统计信息 + auto after_alloc_stats = manager.get_statistics(); + EXPECT_GT(after_alloc_stats.num_allocations, initial_allocations); + EXPECT_GT(after_alloc_stats.used_size, initial_used); + + // 验证总大小不变 + EXPECT_EQ(after_alloc_stats.total_size, initial_stats.total_size); + + // 验证 used + free = total(考虑管理开销) + EXPECT_LE(after_alloc_stats.used_size + after_alloc_stats.free_size, + after_alloc_stats.total_size); + + // 释放一个对象 + EXPECT_TRUE(manager.deallocate("stats_2")); + + // 检查释放后的统计信息 + auto after_dealloc_stats = manager.get_statistics(); + EXPECT_LT(after_dealloc_stats.used_size, after_alloc_stats.used_size); + EXPECT_GT(after_dealloc_stats.free_size, after_alloc_stats.free_size); +} diff --git a/tests/shm/test_triple_buffer.cpp b/tests/shm/test_triple_buffer.cpp index fd11a53..1e4b56a 100644 --- a/tests/shm/test_triple_buffer.cpp +++ b/tests/shm/test_triple_buffer.cpp @@ -15,17 +15,19 @@ using namespace shm_test; * @brief 用于 triple_buffer 测试的数据结构 */ struct TestData { - int value; - uint64_t timestamp; - char padding[56]; // 确保是平凡可拷贝类型,总大小 72 字节(对齐友好) - - TestData() : value(0), timestamp(0), padding{} {} - - TestData(int v, uint64_t ts) : value(v), timestamp(ts), padding{} {} - - bool operator==(const TestData& other) const { - return value == other.value && timestamp == other.timestamp; - } + int value; + uint64_t timestamp; + char padding[56]; // 确保是平凡可拷贝类型,总大小 72 字节(对齐友好) + + TestData() : value(0), timestamp(0), padding{} { + } + + TestData(int v, uint64_t ts) : value(v), timestamp(ts), padding{} { + } + + bool operator==(const TestData& other) const { + return value == other.value && timestamp == other.timestamp; + } }; static_assert(std::is_trivially_copyable_v, "TestData 必须是平凡可拷贝类型"); @@ -34,22 +36,24 @@ static_assert(std::is_trivially_copyable_v, "TestData 必须是平凡 * @brief 中等大小的测试数据结构 */ struct MediumData { - int values[64]; // 256 字节 - - MediumData() : values{} {} - - explicit MediumData(int start_val) : values{} { - for (int i = 0; i < 64; ++i) { - values[i] = start_val + i; - } - } - - bool operator==(const MediumData& other) const { - for (int i = 0; i < 64; ++i) { - if (values[i] != other.values[i]) return false; - } - return true; - } + int values[64]; // 256 字节 + + MediumData() : values{} { + } + + explicit MediumData(int start_val) : values{} { + for (int i = 0; i < 64; ++i) { + values[i] = start_val + i; + } + } + + bool operator==(const MediumData& other) const { + for (int i = 0; i < 64; ++i) { + if (values[i] != other.values[i]) + return false; + } + return true; + } }; static_assert(std::is_trivially_copyable_v); @@ -58,22 +62,24 @@ static_assert(std::is_trivially_copyable_v); * @brief 大型测试数据结构 */ struct LargeData { - int values[256]; // 1024 字节 - - LargeData() : values{} {} - - explicit LargeData(int start_val) : values{} { - for (int i = 0; i < 256; ++i) { - values[i] = start_val + i; - } - } - - bool operator==(const LargeData& other) const { - for (int i = 0; i < 256; ++i) { - if (values[i] != other.values[i]) return false; - } - return true; - } + int values[256]; // 1024 字节 + + LargeData() : values{} { + } + + explicit LargeData(int start_val) : values{} { + for (int i = 0; i < 256; ++i) { + values[i] = start_val + i; + } + } + + bool operator==(const LargeData& other) const { + for (int i = 0; i < 256; ++i) { + if (values[i] != other.values[i]) + return false; + } + return true; + } }; static_assert(std::is_trivially_copyable_v); @@ -87,311 +93,311 @@ static_assert(std::is_trivially_copyable_v); */ class TripleBufferTest : public ::testing::Test { protected: - void SetUp() override { - // 创建测试环境 - env_ = std::make_unique( - ::testing::UnitTest::GetInstance()->current_test_info()->name() - ); - - // 生成缓冲区名称 - buffer_name_ = env_->generate_segment_name("triple_buf"); - - // 创建 triple_buffer - buffer_ = std::make_unique>(buffer_name_); - } - - void TearDown() override { - buffer_.reset(); - env_.reset(); - } - - std::unique_ptr env_; - std::string buffer_name_; - std::unique_ptr> buffer_; + void SetUp() override { + // 创建测试环境 + env_ = std::make_unique( + ::testing::UnitTest::GetInstance()->current_test_info()->name() + ); + + // 生成缓冲区名称 + buffer_name_ = env_->generate_segment_name("triple_buf"); + + // 创建 triple_buffer + buffer_ = std::make_unique>(buffer_name_); + } + + void TearDown() override { + buffer_.reset(); + env_.reset(); + } + + std::unique_ptr env_; + std::string buffer_name_; + std::unique_ptr> buffer_; }; /** * @brief 测试创建缓冲区,验证初始状态 */ TEST_F(TripleBufferTest, CreateBuffer) { - // 初始状态:没有新数据 - EXPECT_FALSE(buffer_->has_new_data()); - - // 初始状态:没有待提交的写操作 - EXPECT_EQ(buffer_->pending_writes(), 0u); + // 初始状态:没有新数据 + EXPECT_FALSE(buffer_->has_new_data()); + + // 初始状态:没有待提交的写操作 + EXPECT_EQ(buffer_->pending_writes(), 0u); } /** * @brief 测试写入并读取数据 */ TEST_F(TripleBufferTest, WriteAndRead) { - // 写入数据 - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - - write_buf->value = 42; - write_buf->timestamp = 12345; - - buffer_->commit_write(); - - // 验证有新数据 - EXPECT_TRUE(buffer_->has_new_data()); - - // 读取数据 - const TestData* read_buf = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf); - - EXPECT_EQ(read_buf->value, 42); - EXPECT_EQ(read_buf->timestamp, 12345u); - - buffer_->commit_read(); - - // 提交读取后,新数据标志应清除 - EXPECT_FALSE(buffer_->has_new_data()); + // 写入数据 + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + + write_buf->value = 42; + write_buf->timestamp = 12345; + + buffer_->commit_write(); + + // 验证有新数据 + EXPECT_TRUE(buffer_->has_new_data()); + + // 读取数据 + const TestData* read_buf = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf); + + EXPECT_EQ(read_buf->value, 42); + EXPECT_EQ(read_buf->timestamp, 12345u); + + buffer_->commit_read(); + + // 提交读取后,新数据标志应清除 + EXPECT_FALSE(buffer_->has_new_data()); } /** * @brief 测试获取写缓冲区 */ TEST_F(TripleBufferTest, GetWriteBuffer) { - // 第一次获取写缓冲区应该成功 - TestData* write_buf1 = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf1); - - // 在提交或丢弃之前,再次获取应该返回 nullptr - TestData* write_buf2 = buffer_->get_write_buffer(); - EXPECT_EQ(write_buf2, nullptr); - - // 提交后可以再次获取 - buffer_->commit_write(); - - TestData* write_buf3 = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf3); + // 第一次获取写缓冲区应该成功 + TestData* write_buf1 = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf1); + + // 在提交或丢弃之前,再次获取应该返回 nullptr + TestData* write_buf2 = buffer_->get_write_buffer(); + EXPECT_EQ(write_buf2, nullptr); + + // 提交后可以再次获取 + buffer_->commit_write(); + + TestData* write_buf3 = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf3); } /** * @brief 测试提交写入后 has_new_data() 返回 true */ TEST_F(TripleBufferTest, CommitWrite) { - EXPECT_FALSE(buffer_->has_new_data()); - - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 100; - - buffer_->commit_write(); - - // 提交后应该有新数据 - EXPECT_TRUE(buffer_->has_new_data()); + EXPECT_FALSE(buffer_->has_new_data()); + + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 100; + + buffer_->commit_write(); + + // 提交后应该有新数据 + EXPECT_TRUE(buffer_->has_new_data()); } /** * @brief 测试丢弃写入不影响读端 */ TEST_F(TripleBufferTest, DiscardWrite) { - // 先写入一个值 - TestData* write_buf1 = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf1); - write_buf1->value = 100; - buffer_->commit_write(); - - // 读取并提交 - const TestData* read_buf1 = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf1); - EXPECT_EQ(read_buf1->value, 100); - buffer_->commit_read(); - - // 获取写缓冲区但丢弃 - TestData* write_buf2 = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf2); - write_buf2->value = 999; // 写入新值 - buffer_->discard_write(); // 丢弃 - - // 不应该有新数据 - EXPECT_FALSE(buffer_->has_new_data()); - - // 读取应该返回 nullptr(因为没有数据) - const TestData* read_buf2 = buffer_->get_read_buffer(); - EXPECT_EQ(read_buf2, nullptr); + // 先写入一个值 + TestData* write_buf1 = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf1); + write_buf1->value = 100; + buffer_->commit_write(); + + // 读取并提交 + const TestData* read_buf1 = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf1); + EXPECT_EQ(read_buf1->value, 100); + buffer_->commit_read(); + + // 获取写缓冲区但丢弃 + TestData* write_buf2 = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf2); + write_buf2->value = 999; // 写入新值 + buffer_->discard_write(); // 丢弃 + + // 不应该有新数据 + EXPECT_FALSE(buffer_->has_new_data()); + + // 读取应该返回 nullptr(因为没有数据) + const TestData* read_buf2 = buffer_->get_read_buffer(); + EXPECT_EQ(read_buf2, nullptr); } /** * @brief 测试获取读缓冲区 */ TEST_F(TripleBufferTest, GetReadBuffer) { - // 没有数据时,获取读缓冲区应该返回 nullptr - const TestData* read_buf1 = buffer_->get_read_buffer(); - EXPECT_EQ(read_buf1, nullptr); - - // 写入数据 - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 42; - buffer_->commit_write(); - - // 现在应该能获取读缓冲区 - const TestData* read_buf2 = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf2); - EXPECT_EQ(read_buf2->value, 42); + // 没有数据时,获取读缓冲区应该返回 nullptr + const TestData* read_buf1 = buffer_->get_read_buffer(); + EXPECT_EQ(read_buf1, nullptr); + + // 写入数据 + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 42; + buffer_->commit_write(); + + // 现在应该能获取读缓冲区 + const TestData* read_buf2 = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf2); + EXPECT_EQ(read_buf2->value, 42); } /** * @brief 测试提交读取后清除新数据标志 */ TEST_F(TripleBufferTest, CommitRead) { - // 写入数据 - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 42; - buffer_->commit_write(); - - EXPECT_TRUE(buffer_->has_new_data()); - - // 读取数据 - const TestData* read_buf = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf); - - buffer_->commit_read(); - - // 提交读取后,新数据标志应清除 - EXPECT_FALSE(buffer_->has_new_data()); + // 写入数据 + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 42; + buffer_->commit_write(); + + EXPECT_TRUE(buffer_->has_new_data()); + + // 读取数据 + const TestData* read_buf = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf); + + buffer_->commit_read(); + + // 提交读取后,新数据标志应清除 + EXPECT_FALSE(buffer_->has_new_data()); } /** * @brief 测试 has_new_data() 状态 */ TEST_F(TripleBufferTest, HasNewData) { - // 初始状态:无新数据 - EXPECT_FALSE(buffer_->has_new_data()); - - // 写入并提交:有新数据 - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 1; - buffer_->commit_write(); - EXPECT_TRUE(buffer_->has_new_data()); - - // 读取并提交:无新数据 - const TestData* read_buf = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf); - buffer_->commit_read(); - EXPECT_FALSE(buffer_->has_new_data()); + // 初始状态:无新数据 + EXPECT_FALSE(buffer_->has_new_data()); + + // 写入并提交:有新数据 + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 1; + buffer_->commit_write(); + EXPECT_TRUE(buffer_->has_new_data()); + + // 读取并提交:无新数据 + const TestData* read_buf = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf); + buffer_->commit_read(); + EXPECT_FALSE(buffer_->has_new_data()); } /** * @brief 测试 pending_writes() 计数 */ TEST_F(TripleBufferTest, PendingWrites) { - // 初始状态:无待提交的写操作 - EXPECT_EQ(buffer_->pending_writes(), 0u); - - // 获取写缓冲区后:有1个待提交的写操作 - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - EXPECT_EQ(buffer_->pending_writes(), 1u); - - // 提交后:无待提交的写操作 - buffer_->commit_write(); - EXPECT_EQ(buffer_->pending_writes(), 0u); - - // 再次获取 - write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - EXPECT_EQ(buffer_->pending_writes(), 1u); - - // 丢弃后:无待提交的写操作 - buffer_->discard_write(); - EXPECT_EQ(buffer_->pending_writes(), 0u); + // 初始状态:无待提交的写操作 + EXPECT_EQ(buffer_->pending_writes(), 0u); + + // 获取写缓冲区后:有1个待提交的写操作 + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + EXPECT_EQ(buffer_->pending_writes(), 1u); + + // 提交后:无待提交的写操作 + buffer_->commit_write(); + EXPECT_EQ(buffer_->pending_writes(), 0u); + + // 再次获取 + write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + EXPECT_EQ(buffer_->pending_writes(), 1u); + + // 丢弃后:无待提交的写操作 + buffer_->discard_write(); + EXPECT_EQ(buffer_->pending_writes(), 0u); } /** * @brief 测试连续多次写入,读取最新数据 */ TEST_F(TripleBufferTest, MultipleWrites) { - // 第一次写入 - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 1; - buffer_->commit_write(); - - // 第二次写入(覆盖) - write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 2; - buffer_->commit_write(); - - // 第三次写入(覆盖) - write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 3; - buffer_->commit_write(); - - // 读取应该得到最新的值 - const TestData* read_buf = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf); - EXPECT_EQ(read_buf->value, 3); + // 第一次写入 + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 1; + buffer_->commit_write(); + + // 第二次写入(覆盖) + write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 2; + buffer_->commit_write(); + + // 第三次写入(覆盖) + write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 3; + buffer_->commit_write(); + + // 读取应该得到最新的值 + const TestData* read_buf = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf); + EXPECT_EQ(read_buf->value, 3); } /** * @brief 测试重复读取返回相同数据直到新写入 */ TEST_F(TripleBufferTest, RepeatedReads) { - // 写入数据 - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 42; - buffer_->commit_write(); - - // 第一次读取 - const TestData* read_buf1 = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf1); - EXPECT_EQ(read_buf1->value, 42); - - // 第二次读取(不提交第一次读取 - const TestData* read_buf2 = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf2); - EXPECT_EQ(read_buf2->value, 42); - EXPECT_EQ(read_buf1, read_buf2); // 应该是同一个缓冲区 - - // 提交读取 - buffer_->commit_read(); - - // 没有新数据时,读取返回 nullptr - const TestData* read_buf3 = buffer_->get_read_buffer(); - EXPECT_EQ(read_buf3, nullptr); - - // 写入新数据 - write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 100; - buffer_->commit_write(); - - // 现在应该能读取新数据 - const TestData* read_buf4 = buffer_->get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf4); - EXPECT_EQ(read_buf4->value, 100); + // 写入数据 + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 42; + buffer_->commit_write(); + + // 第一次读取 + const TestData* read_buf1 = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf1); + EXPECT_EQ(read_buf1->value, 42); + + // 第二次读取(不提交第一次读取 + const TestData* read_buf2 = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf2); + EXPECT_EQ(read_buf2->value, 42); + EXPECT_EQ(read_buf1, read_buf2); // 应该是同一个缓冲区 + + // 提交读取 + buffer_->commit_read(); + + // 没有新数据时,读取返回 nullptr + const TestData* read_buf3 = buffer_->get_read_buffer(); + EXPECT_EQ(read_buf3, nullptr); + + // 写入新数据 + write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 100; + buffer_->commit_write(); + + // 现在应该能读取新数据 + const TestData* read_buf4 = buffer_->get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf4); + EXPECT_EQ(read_buf4->value, 100); } /** * @brief 测试获取写缓冲区但不提交 */ TEST_F(TripleBufferTest, WriteWithoutCommit) { - // 获取写缓冲区 - TestData* write_buf = buffer_->get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - write_buf->value = 999; - - // 不提交,直接检查状态 - EXPECT_FALSE(buffer_->has_new_data()); - EXPECT_EQ(buffer_->pending_writes(), 1u); - - // 读取应该返回 nullptr - const TestData* read_buf = buffer_->get_read_buffer(); - EXPECT_EQ(read_buf, nullptr); - - // 丢弃写入 - buffer_->discard_write(); - EXPECT_EQ(buffer_->pending_writes(), 0u); + // 获取写缓冲区 + TestData* write_buf = buffer_->get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + write_buf->value = 999; + + // 不提交,直接检查状态 + EXPECT_FALSE(buffer_->has_new_data()); + EXPECT_EQ(buffer_->pending_writes(), 1u); + + // 读取应该返回 nullptr + const TestData* read_buf = buffer_->get_read_buffer(); + EXPECT_EQ(read_buf, nullptr); + + // 丢弃写入 + buffer_->discard_write(); + EXPECT_EQ(buffer_->pending_writes(), 0u); } // ============================================================================ @@ -403,87 +409,87 @@ TEST_F(TripleBufferTest, WriteWithoutCommit) { */ class TripleBufferParamTest : public ::testing::TestWithParam { protected: - void SetUp() override { - env_ = std::make_unique( - ::testing::UnitTest::GetInstance()->current_test_info()->name() - ); - } - - void TearDown() override { - env_.reset(); - } - - std::unique_ptr env_; + void SetUp() override { + env_ = std::make_unique( + ::testing::UnitTest::GetInstance()->current_test_info()->name() + ); + } + + void TearDown() override { + env_.reset(); + } + + std::unique_ptr env_; }; /** * @brief 测试不同大小数据结构的三缓冲区操作 */ TEST_P(TripleBufferParamTest, DifferentDataSizes) { - size_t data_size = GetParam(); - - if (data_size == sizeof(int)) { - // 小数据 - auto buffer_name = env_->generate_segment_name("small"); - triple_buffer buffer(buffer_name); - - // 写入 - int* write_buf = buffer.get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - *write_buf = 12345; - buffer.commit_write(); - - // 读取 - EXPECT_TRUE(buffer.has_new_data()); - const int* read_buf = buffer.get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf); - EXPECT_EQ(*read_buf, 12345); - buffer.commit_read(); - - } else if (data_size == sizeof(MediumData)) { - // 中等数据 - auto buffer_name = env_->generate_segment_name("medium"); - triple_buffer buffer(buffer_name); - - // 写入 - MediumData* write_buf = buffer.get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - *write_buf = MediumData(100); - buffer.commit_write(); - - // 读取 - EXPECT_TRUE(buffer.has_new_data()); - const MediumData* read_buf = buffer.get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf); - EXPECT_EQ(*read_buf, MediumData(100)); - buffer.commit_read(); - - } else if (data_size == sizeof(LargeData)) { - // 大数据 - auto buffer_name = env_->generate_segment_name("large"); - triple_buffer buffer(buffer_name); - - // 写入 - LargeData* write_buf = buffer.get_write_buffer(); - ASSERT_SHM_NOT_NULL(write_buf); - *write_buf = LargeData(1000); - buffer.commit_write(); - - // 读取 - EXPECT_TRUE(buffer.has_new_data()); - const LargeData* read_buf = buffer.get_read_buffer(); - ASSERT_SHM_NOT_NULL(read_buf); - EXPECT_EQ(*read_buf, LargeData(1000)); - buffer.commit_read(); - } + size_t data_size = GetParam(); + + if (data_size == sizeof(int)) { + // 小数据 + auto buffer_name = env_->generate_segment_name("small"); + triple_buffer buffer(buffer_name); + + // 写入 + int* write_buf = buffer.get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + *write_buf = 12345; + buffer.commit_write(); + + // 读取 + EXPECT_TRUE(buffer.has_new_data()); + const int* read_buf = buffer.get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf); + EXPECT_EQ(*read_buf, 12345); + buffer.commit_read(); + } + else if (data_size == sizeof(MediumData)) { + // 中等数据 + auto buffer_name = env_->generate_segment_name("medium"); + triple_buffer buffer(buffer_name); + + // 写入 + MediumData* write_buf = buffer.get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + *write_buf = MediumData(100); + buffer.commit_write(); + + // 读取 + EXPECT_TRUE(buffer.has_new_data()); + const MediumData* read_buf = buffer.get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf); + EXPECT_EQ(*read_buf, MediumData(100)); + buffer.commit_read(); + } + else if (data_size == sizeof(LargeData)) { + // 大数据 + auto buffer_name = env_->generate_segment_name("large"); + triple_buffer buffer(buffer_name); + + // 写入 + LargeData* write_buf = buffer.get_write_buffer(); + ASSERT_SHM_NOT_NULL(write_buf); + *write_buf = LargeData(1000); + buffer.commit_write(); + + // 读取 + EXPECT_TRUE(buffer.has_new_data()); + const LargeData* read_buf = buffer.get_read_buffer(); + ASSERT_SHM_NOT_NULL(read_buf); + EXPECT_EQ(*read_buf, LargeData(1000)); + buffer.commit_read(); + } } INSTANTIATE_TEST_SUITE_P( - DataSizes, - TripleBufferParamTest, - ::testing::Values( - sizeof(int), // 小数据:4 字节 - sizeof(MediumData), // 中等数据:256 字节 - sizeof(LargeData) // 大数据:1024 字节 - ) -); \ No newline at end of file + DataSizes, + TripleBufferParamTest, + ::testing::Values( + sizeof(int), // 小数据:4 字节 + sizeof(MediumData), // 中等数据:256 字节 + sizeof(LargeData) // 大数据:1024 字节 + ) +); diff --git a/tests/test_audio_processing_comprehensive.cpp b/tests/test_audio_processing_comprehensive.cpp index 7ca9bb1..15b5c5a 100644 --- a/tests/test_audio_processing_comprehensive.cpp +++ b/tests/test_audio_processing_comprehensive.cpp @@ -41,34 +41,36 @@ // 测试容差设置 constexpr float FLOAT_TOLERANCE = 1e-6f; -constexpr float RMS_TOLERANCE = 1e-5f; -constexpr float PEAK_TOLERANCE = 1e-6f; +constexpr float RMS_TOLERANCE = 1e-5f; +constexpr float PEAK_TOLERANCE = 1e-6f; // 性能测试设置 -constexpr size_t PERF_TEST_SIZE = 1024 * 1024; // 1M samples -constexpr int PERF_TEST_ITERATIONS = 100; +constexpr size_t PERF_TEST_SIZE = 1024 * 1024; // 1M samples +constexpr int PERF_TEST_ITERATIONS = 100; /** * 浮点数比较函数 */ bool float_equal(float a, float b, float tolerance = FLOAT_TOLERANCE) { - if (std::isnan(a) && std::isnan(b)) return true; - if (std::isinf(a) && std::isinf(b)) return (a > 0) == (b > 0); - return std::abs(a - b) <= tolerance; + if (std::isnan(a) && std::isnan(b)) + return true; + if (std::isinf(a) && std::isinf(b)) + return (a > 0) == (b > 0); + return std::abs(a - b) <= tolerance; } /** * 数组比较函数 */ bool arrays_equal(const float* arr1, const float* arr2, size_t size, float tolerance = FLOAT_TOLERANCE) { - for (size_t i = 0; i < size; ++i) { - if (!float_equal(arr1[i], arr2[i], tolerance)) { - std::cout << " 差异在位置 " << i << ": " << arr1[i] << " vs " << arr2[i] - << " (差值: " << std::abs(arr1[i] - arr2[i]) << ")" << std::endl; - return false; - } - } - return true; + for (size_t i = 0; i < size; ++i) { + if (!float_equal(arr1[i], arr2[i], tolerance)) { + std::cout << " 差异在位置 " << i << ": " << arr1[i] << " vs " << arr2[i] + << " (差值: " << std::abs(arr1[i] - arr2[i]) << ")" << std::endl; + return false; + } + } + return true; } /** @@ -76,75 +78,82 @@ bool arrays_equal(const float* arr1, const float* arr2, size_t size, float toler */ class AudioDataGenerator { private: - mutable std::mt19937 rng_{std::random_device{}()}; - + mutable std::mt19937 rng_{std::random_device{}()}; + public: - // 生成正弦波 - auto generate_sine_wave(size_t num_samples, float frequency = 440.0f, - float sample_rate = 44100.0f, float amplitude = 1.0f) const { - std::vector> data(num_samples); - for (size_t i = 0; i < num_samples; ++i) { - data[i] = amplitude * std::sin(2.0f * M_PI * frequency * i / sample_rate); - } - return data; - } - - // 生成白噪声 - auto generate_white_noise(size_t num_samples, float amplitude = 1.0f) const { - std::vector> data(num_samples); - std::uniform_real_distribution dist(-amplitude, amplitude); - for (size_t i = 0; i < num_samples; ++i) { - data[i] = dist(rng_); - } - return data; - } - - // 生成脉冲信号 - auto generate_impulse(size_t num_samples, size_t impulse_pos = 0, float amplitude = 1.0f) const { - std::vector> data(num_samples, 0.0f); - if (impulse_pos < num_samples) { - data[impulse_pos] = amplitude; - } - return data; - } - - // 生成直流信号 - std::vector generate_dc(size_t num_samples, float value = 1.0f) const { - return std::vector(num_samples, value); - } - - // 生成立体声测试数据 - std::vector generate_stereo_test_data(size_t num_stereo_samples) const { - std::vector data(num_stereo_samples * 2); - for (size_t i = 0; i < num_stereo_samples; ++i) { - data[i * 2] = std::sin(2.0f * M_PI * 440.0f * i / 44100.0f); // 左声道: 440Hz - data[i * 2 + 1] = std::sin(2.0f * M_PI * 880.0f * i / 44100.0f); // 右声道: 880Hz - } - return data; - } - - // 生成边界测试数据 - std::vector generate_boundary_data(size_t num_samples) const { - std::vector data; - data.reserve(num_samples); - - // 添加各种边界值 - if (num_samples > 0) data.push_back(0.0f); - if (num_samples > 1) data.push_back(1.0f); - if (num_samples > 2) data.push_back(-1.0f); - if (num_samples > 3) data.push_back(std::numeric_limits::min()); - if (num_samples > 4) data.push_back(std::numeric_limits::max()); - if (num_samples > 5) data.push_back(std::numeric_limits::epsilon()); - if (num_samples > 6) data.push_back(-std::numeric_limits::epsilon()); - - // 填充剩余位置 - std::uniform_real_distribution dist(-1.0f, 1.0f); - while (data.size() < num_samples) { - data.push_back(dist(rng_)); - } - - return data; - } + // 生成正弦波 + auto generate_sine_wave(size_t num_samples, float frequency = 440.0f, + float sample_rate = 44100.0f, float amplitude = 1.0f) const { + std::vector> data(num_samples); + for (size_t i = 0; i < num_samples; ++i) { + data[i] = amplitude * std::sin(2.0f * M_PI * frequency * i / sample_rate); + } + return data; + } + + // 生成白噪声 + auto generate_white_noise(size_t num_samples, float amplitude = 1.0f) const { + std::vector> data(num_samples); + std::uniform_real_distribution dist(-amplitude, amplitude); + for (size_t i = 0; i < num_samples; ++i) { + data[i] = dist(rng_); + } + return data; + } + + // 生成脉冲信号 + auto generate_impulse(size_t num_samples, size_t impulse_pos = 0, float amplitude = 1.0f) const { + std::vector> data(num_samples, 0.0f); + if (impulse_pos < num_samples) { + data[impulse_pos] = amplitude; + } + return data; + } + + // 生成直流信号 + std::vector generate_dc(size_t num_samples, float value = 1.0f) const { + return std::vector(num_samples, value); + } + + // 生成立体声测试数据 + std::vector generate_stereo_test_data(size_t num_stereo_samples) const { + std::vector data(num_stereo_samples * 2); + for (size_t i = 0; i < num_stereo_samples; ++i) { + data[i * 2] = std::sin(2.0f * M_PI * 440.0f * i / 44100.0f); // 左声道: 440Hz + data[i * 2 + 1] = std::sin(2.0f * M_PI * 880.0f * i / 44100.0f); // 右声道: 880Hz + } + return data; + } + + // 生成边界测试数据 + std::vector generate_boundary_data(size_t num_samples) const { + std::vector data; + data.reserve(num_samples); + + // 添加各种边界值 + if (num_samples > 0) + data.push_back(0.0f); + if (num_samples > 1) + data.push_back(1.0f); + if (num_samples > 2) + data.push_back(-1.0f); + if (num_samples > 3) + data.push_back(std::numeric_limits::min()); + if (num_samples > 4) + data.push_back(std::numeric_limits::max()); + if (num_samples > 5) + data.push_back(std::numeric_limits::epsilon()); + if (num_samples > 6) + data.push_back(-std::numeric_limits::epsilon()); + + // 填充剩余位置 + std::uniform_real_distribution dist(-1.0f, 1.0f); + while (data.size() < num_samples) { + data.push_back(dist(rng_)); + } + + return data; + } }; /** @@ -152,46 +161,46 @@ public: */ class PerformanceTester { public: - template - double measure_execution_time(Func&& func, int iterations = PERF_TEST_ITERATIONS) { - auto start = std::chrono::high_resolution_clock::now(); - - for (int i = 0; i < iterations; ++i) { - func(); - } - - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start); - - return duration.count() / 1e6 / iterations; // 返回平均毫秒数 - } - - void print_performance_comparison(const std::string& test_name, - double scalar_time, - double simd_time) { - double speedup = scalar_time / simd_time; - std::cout << "[PERF] " << test_name << std::endl; - std::cout << " 标量版本: " << std::fixed << std::setprecision(3) << scalar_time << "ms" << std::endl; - std::cout << " SIMD版本: " << std::fixed << std::setprecision(3) << simd_time << "ms" << std::endl; - std::cout << " 加速比: " << std::fixed << std::setprecision(2) << speedup << "x" << std::endl; - } + template + double measure_execution_time(Func&& func, int iterations = PERF_TEST_ITERATIONS) { + auto start = std::chrono::high_resolution_clock::now(); + + for (int i = 0; i < iterations; ++i) { + func(); + } + + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start); + + return duration.count() / 1e6 / iterations; // 返回平均毫秒数 + } + + void print_performance_comparison(const std::string& test_name, + double scalar_time, + double simd_time) { + double speedup = scalar_time / simd_time; + std::cout << "[PERF] " << test_name << std::endl; + std::cout << " 标量版本: " << std::fixed << std::setprecision(3) << scalar_time << "ms" << std::endl; + std::cout << " SIMD版本: " << std::fixed << std::setprecision(3) << simd_time << "ms" << std::endl; + std::cout << " 加速比: " << std::fixed << std::setprecision(2) << speedup << "x" << std::endl; + } }; // 音频处理函数测试类 class AudioProcessingTest : public ::testing::Test { protected: - void SetUp() override { - // 初始化测试环境 - audio_processing_registry::register_all_functions(); - } + void SetUp() override { + // 初始化测试环境 + audio_processing_registry::register_all_functions(); + } - void TearDown() override { - // 清理测试环境 - } + void TearDown() override { + // 清理测试环境 + } - // 全局实例 - AudioDataGenerator data_gen; - PerformanceTester perf_tester; + // 全局实例 + AudioDataGenerator data_gen; + PerformanceTester perf_tester; }; /** @@ -202,111 +211,111 @@ protected: // 测试 simd_audio_processing_registry 注册功能 TEST_F(AudioProcessingTest, RegistryRegistration) { - // 打印已注册的函数以供调试 - audio_processing_registry::print_available_functions(); + // 打印已注册的函数以供调试 + audio_processing_registry::print_available_functions(); - // 验证关键函数已注册 - auto& dispatcher = simd_func_dispatcher::instance(); + // 验证关键函数已注册 + auto& dispatcher = simd_func_dispatcher::instance(); - EXPECT_NO_THROW({ - dispatcher.get_function("mix_audio"); - }) << "函数 mix_audio 未正确注册"; + EXPECT_NO_THROW({ + dispatcher.get_function("mix_audio"); + }) << "函数 mix_audio 未正确注册"; - EXPECT_NO_THROW({ - dispatcher.get_function("apply_gain"); - }) << "函数 apply_gain 未正确注册"; + EXPECT_NO_THROW({ + dispatcher.get_function("apply_gain"); + }) << "函数 apply_gain 未正确注册"; - EXPECT_NO_THROW({ - dispatcher.get_function("calculate_rms"); - }) << "函数 calculate_rms 未正确注册"; + EXPECT_NO_THROW({ + dispatcher.get_function("calculate_rms"); + }) << "函数 calculate_rms 未正确注册"; - EXPECT_NO_THROW({ - dispatcher.get_function("calculate_peak"); - }) << "函数 calculate_peak 未正确注册"; + EXPECT_NO_THROW({ + dispatcher.get_function("calculate_peak"); + }) << "函数 calculate_peak 未正确注册"; - EXPECT_NO_THROW({ - dispatcher.get_function("normalize_audio"); - }) << "函数 normalize_audio 未正确注册"; + EXPECT_NO_THROW({ + dispatcher.get_function("normalize_audio"); + }) << "函数 normalize_audio 未正确注册"; - EXPECT_NO_THROW({ - dispatcher.get_function("stereo_to_mono"); - }) << "函数 stereo_to_mono 未正确注册"; + EXPECT_NO_THROW({ + dispatcher.get_function("stereo_to_mono"); + }) << "函数 stereo_to_mono 未正确注册"; } // 测试 mix_audio 函数 TEST_F(AudioProcessingTest, MixAudioBasic) { - const size_t num_samples = 16; - auto src1 = data_gen.generate_sine_wave(num_samples, 440.0f); - auto src2 = data_gen.generate_sine_wave(num_samples, 880.0f); - std::vector result(num_samples); - std::vector expected(num_samples); - - // 计算期望结果 - for (size_t i = 0; i < num_samples; ++i) { - expected[i] = src1[i] + src2[i]; - } - - // 测试标量版本 - scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), result.data(), num_samples); - - EXPECT_TRUE(arrays_equal(result.data(), expected.data(), num_samples)) + const size_t num_samples = 16; + auto src1 = data_gen.generate_sine_wave(num_samples, 440.0f); + auto src2 = data_gen.generate_sine_wave(num_samples, 880.0f); + std::vector result(num_samples); + std::vector expected(num_samples); + + // 计算期望结果 + for (size_t i = 0; i < num_samples; ++i) { + expected[i] = src1[i] + src2[i]; + } + + // 测试标量版本 + scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), result.data(), num_samples); + + EXPECT_TRUE(arrays_equal(result.data(), expected.data(), num_samples)) << "混合音频结果与期望不符"; } // 测试 apply_gain 函数 TEST_F(AudioProcessingTest, ApplyGainBasic) { - const size_t num_samples = 16; - const float gain = 0.5f; - auto src = data_gen.generate_sine_wave(num_samples); - std::vector result(num_samples); - std::vector expected(num_samples); - - // 计算期望结果 - for (size_t i = 0; i < num_samples; ++i) { - expected[i] = src[i] * gain; - } - - // 测试标量版本 - scalar_audio_processing_func::apply_gain(src.data(), result.data(), gain, num_samples); - - EXPECT_TRUE(arrays_equal(result.data(), expected.data(), num_samples)) + const size_t num_samples = 16; + const float gain = 0.5f; + auto src = data_gen.generate_sine_wave(num_samples); + std::vector result(num_samples); + std::vector expected(num_samples); + + // 计算期望结果 + for (size_t i = 0; i < num_samples; ++i) { + expected[i] = src[i] * gain; + } + + // 测试标量版本 + scalar_audio_processing_func::apply_gain(src.data(), result.data(), gain, num_samples); + + EXPECT_TRUE(arrays_equal(result.data(), expected.data(), num_samples)) << "增益应用结果与期望不符"; } // 测试 calculate_rms 函数 TEST_F(AudioProcessingTest, CalculateRmsBasic) { - const size_t num_samples = 1024; - auto src = data_gen.generate_sine_wave(num_samples); - - // 计算期望的RMS值 - double sum_squares = 0.0; - for (size_t i = 0; i < num_samples; ++i) { - sum_squares += src[i] * src[i]; - } - float expected_rms = std::sqrt(sum_squares / num_samples); - - // 测试标量版本 - float result_rms = scalar_audio_processing_func::calculate_rms(src.data(), num_samples); - - EXPECT_TRUE(float_equal(result_rms, expected_rms, RMS_TOLERANCE)) + const size_t num_samples = 1024; + auto src = data_gen.generate_sine_wave(num_samples); + + // 计算期望的RMS值 + double sum_squares = 0.0; + for (size_t i = 0; i < num_samples; ++i) { + sum_squares += src[i] * src[i]; + } + float expected_rms = std::sqrt(sum_squares / num_samples); + + // 测试标量版本 + float result_rms = scalar_audio_processing_func::calculate_rms(src.data(), num_samples); + + EXPECT_TRUE(float_equal(result_rms, expected_rms, RMS_TOLERANCE)) << "期望 RMS: " << expected_rms << ", 得到: " << result_rms; } // 测试 calculate_peak 函数 TEST_F(AudioProcessingTest, CalculatePeakBasic) { - const size_t num_samples = 1024; - auto src = data_gen.generate_boundary_data(num_samples); - - // 计算期望的峰值 - float expected_peak = 0.0f; - for (size_t i = 0; i < num_samples; ++i) { - expected_peak = std::max(expected_peak, std::abs(src[i])); - } - - // 测试标量版本 - float result_peak = scalar_audio_processing_func::calculate_peak(src.data(), num_samples); - - EXPECT_TRUE(float_equal(result_peak, expected_peak, PEAK_TOLERANCE)) + const size_t num_samples = 1024; + auto src = data_gen.generate_boundary_data(num_samples); + + // 计算期望的峰值 + float expected_peak = 0.0f; + for (size_t i = 0; i < num_samples; ++i) { + expected_peak = std::max(expected_peak, std::abs(src[i])); + } + + // 测试标量版本 + float result_peak = scalar_audio_processing_func::calculate_peak(src.data(), num_samples); + + EXPECT_TRUE(float_equal(result_peak, expected_peak, PEAK_TOLERANCE)) << "期望峰值: " << expected_peak << ", 得到: " << result_peak; } @@ -318,118 +327,120 @@ TEST_F(AudioProcessingTest, CalculatePeakBasic) { // 测试 normalize_audio 函数 TEST_F(AudioProcessingTest, NormalizeAudioBasic) { - const size_t num_samples = 1024; - const float target_peak = 0.8f; - auto src = data_gen.generate_sine_wave(num_samples, 440.0f, 44100.0f, 2.0f); // 超过1.0的幅度 - std::vector result(num_samples); - - // 测试标量版本 - scalar_audio_processing_func::normalize_audio(src.data(), result.data(), target_peak, num_samples); - - // 验证归一化后的峰值 - float actual_peak = scalar_audio_processing_func::calculate_peak(result.data(), num_samples); - - EXPECT_TRUE(float_equal(actual_peak, target_peak, PEAK_TOLERANCE)) + const size_t num_samples = 1024; + const float target_peak = 0.8f; + auto src = data_gen.generate_sine_wave(num_samples, 440.0f, 44100.0f, 2.0f); // 超过1.0的幅度 + std::vector result(num_samples); + + // 测试标量版本 + scalar_audio_processing_func::normalize_audio(src.data(), result.data(), target_peak, num_samples); + + // 验证归一化后的峰值 + float actual_peak = scalar_audio_processing_func::calculate_peak(result.data(), num_samples); + + EXPECT_TRUE(float_equal(actual_peak, target_peak, PEAK_TOLERANCE)) << "期望峰值: " << target_peak << ", 实际峰值: " << actual_peak; } // 测试 stereo_to_mono 函数 TEST_F(AudioProcessingTest, StereoToMonoBasic) { - const size_t num_stereo_samples = 512; - auto stereo_src = data_gen.generate_stereo_test_data(num_stereo_samples); - std::vector mono_result(num_stereo_samples); - std::vector expected_mono(num_stereo_samples); - - // 计算期望结果 - for (size_t i = 0; i < num_stereo_samples; ++i) { - expected_mono[i] = (stereo_src[i * 2] + stereo_src[i * 2 + 1]) * 0.5f; - } - - // 测试标量版本 - scalar_audio_processing_func::stereo_to_mono(stereo_src.data(), mono_result.data(), num_stereo_samples); - - EXPECT_TRUE(arrays_equal(mono_result.data(), expected_mono.data(), num_stereo_samples)) + const size_t num_stereo_samples = 512; + auto stereo_src = data_gen.generate_stereo_test_data(num_stereo_samples); + std::vector mono_result(num_stereo_samples); + std::vector expected_mono(num_stereo_samples); + + // 计算期望结果 + for (size_t i = 0; i < num_stereo_samples; ++i) { + expected_mono[i] = (stereo_src[i * 2] + stereo_src[i * 2 + 1]) * 0.5f; + } + + // 测试标量版本 + scalar_audio_processing_func::stereo_to_mono(stereo_src.data(), mono_result.data(), num_stereo_samples); + + EXPECT_TRUE(arrays_equal(mono_result.data(), expected_mono.data(), num_stereo_samples)) << "立体声转单声道果与期望不符"; } // 测试 limit_audio 函数 TEST_F(AudioProcessingTest, LimitAudioBasic) { - const size_t num_samples = 1024; - const float threshold = 0.5f; - auto src = data_gen.generate_sine_wave(num_samples, 440.0f, 44100.0f, 1.0f); - std::vector result(num_samples); - float limiter_state = 1.0f; - - // 测试标量版本 - scalar_audio_processing_func::limit_audio(src.data(), result.data(), threshold, &limiter_state, 44100.f, num_samples); - - // 验证没有样本超过阈值 - bool all_samples_limited = true; - size_t violation_index = 0; - float violation_value = 0.0f; - - for (size_t i = 0; i < num_samples; ++i) { - if (std::abs(result[i]) > threshold + FLOAT_TOLERANCE) { - all_samples_limited = false; - violation_index = i; - violation_value = result[i]; - break; - } - } - - EXPECT_TRUE(all_samples_limited) + const size_t num_samples = 1024; + const float threshold = 0.5f; + auto src = data_gen.generate_sine_wave(num_samples, 440.0f, 44100.0f, 1.0f); + std::vector result(num_samples); + float limiter_state = 1.0f; + + // 测试标量版本 + scalar_audio_processing_func::limit_audio(src.data(), result.data(), threshold, &limiter_state, 44100.f, + num_samples); + + // 验证没有样本超过阈值 + bool all_samples_limited = true; + size_t violation_index = 0; + float violation_value = 0.0f; + + for (size_t i = 0; i < num_samples; ++i) { + if (std::abs(result[i]) > threshold + FLOAT_TOLERANCE) { + all_samples_limited = false; + violation_index = i; + violation_value = result[i]; + break; + } + } + + EXPECT_TRUE(all_samples_limited) << "样本 " << violation_index << " 超过阈值: " << violation_value; } // 测试 fade_audio 函数 TEST_F(AudioProcessingTest, FadeAudioBasic) { - const size_t num_samples = 1024; - const size_t fade_in_samples = 100; - const size_t fade_out_samples = 100; - auto src = data_gen.generate_dc(num_samples, 1.0f); - std::vector result(num_samples); - - // 测试标量版本 - scalar_audio_processing_func::fade_audio(src.data(), result.data(), fade_in_samples, fade_out_samples, num_samples); - - // 验证淡入淡出效果 - - // 检查淡入部分 - EXPECT_FLOAT_EQ(result[0], 0.0f) + const size_t num_samples = 1024; + const size_t fade_in_samples = 100; + const size_t fade_out_samples = 100; + auto src = data_gen.generate_dc(num_samples, 1.0f); + std::vector result(num_samples); + + // 测试标量版本 + scalar_audio_processing_func::fade_audio(src.data(), result.data(), fade_in_samples, fade_out_samples, num_samples); + + // 验证淡入淡出效果 + + // 检查淡入部分 + EXPECT_FLOAT_EQ(result[0], 0.0f) << "淡入开始应为0,实际为: " << result[0]; - - // 检查中间部分(应该是1.0) - EXPECT_FLOAT_EQ(result[num_samples / 2], 1.0f) + + // 检查中间部分(应该是1.0) + EXPECT_FLOAT_EQ(result[num_samples / 2], 1.0f) << "中间部分应保持原始值1.0,实际为: " << result[num_samples / 2]; - - // 检查淡出部分 - EXPECT_TRUE(float_equal(result[num_samples - 1], 0.0f, FLOAT_TOLERANCE)) + + // 检查淡出部分 + EXPECT_TRUE(float_equal(result[num_samples - 1], 0.0f, FLOAT_TOLERANCE)) << "淡出结束应为0,实际为: " << result[num_samples - 1]; } // 测试 simple_eq 函数 TEST_F(AudioProcessingTest, SimpleEqBasic) { - const size_t num_samples = 1024; - const float low_gain = 1.2f; - const float mid_gain = 1.0f; - const float high_gain = 0.8f; - auto src = data_gen.generate_white_noise(num_samples, 0.5f); - std::vector result(num_samples); - std::vector eq_state(2, 0.0f); // 低通和高通滤波器状态 - - // 测试标量版本 - scalar_audio_processing_func::simple_eq(src.data(), result.data(), low_gain, mid_gain, high_gain, eq_state.data(), num_samples); - - // 基本验证:结果不应该全为零(除非输入全为零) - bool has_nonzero = false; - for (size_t i = 0; i < num_samples; ++i) { - if (result[i] != 0.0f) { - has_nonzero = true; - break; - } - } - - EXPECT_TRUE(has_nonzero) + const size_t num_samples = 1024; + const float low_gain = 1.2f; + const float mid_gain = 1.0f; + const float high_gain = 0.8f; + auto src = data_gen.generate_white_noise(num_samples, 0.5f); + std::vector result(num_samples); + std::vector eq_state(2, 0.0f); // 低通和高通滤波器状态 + + // 测试标量版本 + scalar_audio_processing_func::simple_eq(src.data(), result.data(), low_gain, mid_gain, high_gain, eq_state.data(), + num_samples); + + // 基本验证:结果不应该全为零(除非输入全为零) + bool has_nonzero = false; + for (size_t i = 0; i < num_samples; ++i) { + if (result[i] != 0.0f) { + has_nonzero = true; + break; + } + } + + EXPECT_TRUE(has_nonzero) << "EQ处理后的结果全为零,可能存在处理错误"; } @@ -441,87 +452,87 @@ TEST_F(AudioProcessingTest, SimpleEqBasic) { // 测试零长度输入 TEST_F(AudioProcessingTest, ZeroLengthInput) { - const size_t num_samples = 0; - std::vector dummy(1, 0.0f); - std::vector result(1, 0.0f); - float state = 1.0f; - - // 这些函数应该能安全处理零长度输入 - EXPECT_NO_THROW({ - scalar_audio_processing_func::mix_audio(dummy.data(), dummy.data(), result.data(), num_samples); - scalar_audio_processing_func::apply_gain(dummy.data(), result.data(), 1.0f, num_samples); - scalar_audio_processing_func::normalize_audio(dummy.data(), result.data(), 1.0f, num_samples); - scalar_audio_processing_func::stereo_to_mono(dummy.data(), result.data(), num_samples); - scalar_audio_processing_func::limit_audio(dummy.data(), result.data(), 1.0f, &state, 44100.f, num_samples); - scalar_audio_processing_func::fade_audio(dummy.data(), result.data(), 0, 0, num_samples); - scalar_audio_processing_func::simple_eq(dummy.data(), result.data(), 1.0f, 1.0f, 1.0f, &state, num_samples); - }); + const size_t num_samples = 0; + std::vector dummy(1, 0.0f); + std::vector result(1, 0.0f); + float state = 1.0f; + + // 这些函数应该能安全处理零长度输入 + EXPECT_NO_THROW({ + scalar_audio_processing_func::mix_audio(dummy.data(), dummy.data(), result.data(), num_samples); + scalar_audio_processing_func::apply_gain(dummy.data(), result.data(), 1.0f, num_samples); + scalar_audio_processing_func::normalize_audio(dummy.data(), result.data(), 1.0f, num_samples); + scalar_audio_processing_func::stereo_to_mono(dummy.data(), result.data(), num_samples); + scalar_audio_processing_func::limit_audio(dummy.data(), result.data(), 1.0f, &state, 44100.f, num_samples); + scalar_audio_processing_func::fade_audio(dummy.data(), result.data(), 0, 0, num_samples); + scalar_audio_processing_func::simple_eq(dummy.data(), result.data(), 1.0f, 1.0f, 1.0f, &state, num_samples); + }); } // 测试单样本输入 TEST_F(AudioProcessingTest, SingleSampleInput) { - const size_t num_samples = 1; - std::vector src1{0.5f}; - std::vector src2{0.3f}; - std::vector result(1); - float state = 1.0f; - - // 测试混合 - scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), result.data(), num_samples); - EXPECT_TRUE(float_equal(result[0], 0.8f)) + const size_t num_samples = 1; + std::vector src1{0.5f}; + std::vector src2{0.3f}; + std::vector result(1); + float state = 1.0f; + + // 测试混合 + scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), result.data(), num_samples); + EXPECT_TRUE(float_equal(result[0], 0.8f)) << "混合单样本: 期望0.8,实际" << result[0]; - - // 测试增益 - scalar_audio_processing_func::apply_gain(src1.data(), result.data(), 2.0f, num_samples); - EXPECT_TRUE(float_equal(result[0], 1.0f)) + + // 测试增益 + scalar_audio_processing_func::apply_gain(src1.data(), result.data(), 2.0f, num_samples); + EXPECT_TRUE(float_equal(result[0], 1.0f)) << "应用增益: 期望1.0,实际" << result[0]; - - // 测试RMS - float rms = scalar_audio_processing_func::calculate_rms(src1.data(), num_samples); - EXPECT_TRUE(float_equal(rms, 0.5f)) + + // 测试RMS + float rms = scalar_audio_processing_func::calculate_rms(src1.data(), num_samples); + EXPECT_TRUE(float_equal(rms, 0.5f)) << "单样本RMS: 期望0.5,实际" << rms; - - // 测试峰值 - float peak = scalar_audio_processing_func::calculate_peak(src1.data(), num_samples); - EXPECT_TRUE(float_equal(peak, 0.5f)) + + // 测试峰值 + float peak = scalar_audio_processing_func::calculate_peak(src1.data(), num_samples); + EXPECT_TRUE(float_equal(peak, 0.5f)) << "单样本峰值: 期望0.5,实际" << peak; } // 测试极值处理 TEST_F(AudioProcessingTest, ExtremeValues) { - const size_t num_samples = 8; - std::vector extreme_values = { - 0.0f, - 1.0f, - -1.0f, - std::numeric_limits::max(), - -std::numeric_limits::max(), - std::numeric_limits::min(), - std::numeric_limits::epsilon(), - -std::numeric_limits::epsilon() - }; - - std::vector result(num_samples); - - // 测试峰值计算对极值的处理 - float peak = scalar_audio_processing_func::calculate_peak(extreme_values.data(), num_samples); - EXPECT_EQ(peak, std::numeric_limits::max()) + const size_t num_samples = 8; + std::vector extreme_values = { + 0.0f, + 1.0f, + -1.0f, + std::numeric_limits::max(), + -std::numeric_limits::max(), + std::numeric_limits::min(), + std::numeric_limits::epsilon(), + -std::numeric_limits::epsilon() + }; + + std::vector result(num_samples); + + // 测试峰值计算对极值的处理 + float peak = scalar_audio_processing_func::calculate_peak(extreme_values.data(), num_samples); + EXPECT_EQ(peak, std::numeric_limits::max()) << "极值峰值检测失败,期望" << std::numeric_limits::max() << ",实际" << peak; - - // 测试增益对极值的处理 - scalar_audio_processing_func::apply_gain(extreme_values.data(), result.data(), 0.5f, num_samples); - bool all_finite = true; - size_t nan_inf_index = 0; - - for (size_t i = 0; i < num_samples; ++i) { - if (std::isnan(result[i]) || std::isinf(result[i])) { - all_finite = false; - nan_inf_index = i; - break; - } - } - - EXPECT_TRUE(all_finite) + + // 测试增益对极值的处理 + scalar_audio_processing_func::apply_gain(extreme_values.data(), result.data(), 0.5f, num_samples); + bool all_finite = true; + size_t nan_inf_index = 0; + + for (size_t i = 0; i < num_samples; ++i) { + if (std::isnan(result[i]) || std::isinf(result[i])) { + all_finite = false; + nan_inf_index = i; + break; + } + } + + EXPECT_TRUE(all_finite) << "增益处理后在位置" << nan_inf_index << "存在NaN或Inf"; } @@ -535,67 +546,67 @@ TEST_F(AudioProcessingTest, ExtremeValues) { // 测试x86 SIMD版本与标量版的一致性 TEST_F(AudioProcessingTest, X86SimdConsistency) { - const size_t num_samples = 1024; - auto src1 = data_gen.generate_sine_wave(num_samples, 440.0f); - auto src2 = data_gen.generate_sine_wave(num_samples, 880.0f); - auto stereo_src = data_gen.generate_stereo_test_data(num_samples); - - std::vector scalar_result(num_samples); - std::vector sse_result(num_samples); - std::vector avx_result(num_samples); - std::vector> avx512_result(num_samples); - - // 测试 mix_audio 一致性 - scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), scalar_result.data(), num_samples); - x86_simd_audio_processing_func::mix_audio_sse(src1.data(), src2.data(), sse_result.data(), num_samples); - x86_simd_audio_processing_func::mix_audio_avx(src1.data(), src2.data(), avx_result.data(), num_samples); - x86_simd_audio_processing_func::mix_audio_avx512(src1.data(), src2.data(), avx512_result.data(), num_samples); - - EXPECT_TRUE(arrays_equal(scalar_result.data(), sse_result.data(), num_samples)) + const size_t num_samples = 1024; + auto src1 = data_gen.generate_sine_wave(num_samples, 440.0f); + auto src2 = data_gen.generate_sine_wave(num_samples, 880.0f); + auto stereo_src = data_gen.generate_stereo_test_data(num_samples); + + std::vector scalar_result(num_samples); + std::vector sse_result(num_samples); + std::vector avx_result(num_samples); + std::vector> avx512_result(num_samples); + + // 测试 mix_audio 一致性 + scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), scalar_result.data(), num_samples); + x86_simd_audio_processing_func::mix_audio_sse(src1.data(), src2.data(), sse_result.data(), num_samples); + x86_simd_audio_processing_func::mix_audio_avx(src1.data(), src2.data(), avx_result.data(), num_samples); + x86_simd_audio_processing_func::mix_audio_avx512(src1.data(), src2.data(), avx512_result.data(), num_samples); + + EXPECT_TRUE(arrays_equal(scalar_result.data(), sse_result.data(), num_samples)) << "mix_audio SSE版本与标量版本不一致"; - EXPECT_TRUE(arrays_equal(scalar_result.data(), avx_result.data(), num_samples)) + EXPECT_TRUE(arrays_equal(scalar_result.data(), avx_result.data(), num_samples)) << "mix_audio AVX版本与标量版本不一致"; - EXPECT_TRUE(arrays_equal(scalar_result.data(), avx512_result.data(), num_samples)) + EXPECT_TRUE(arrays_equal(scalar_result.data(), avx512_result.data(), num_samples)) << "mix_audio AVX512版本与标量版本不一致"; - - // 测试 apply_gain 一致性 - const float gain = 0.75f; - scalar_audio_processing_func::apply_gain(src1.data(), scalar_result.data(), gain, num_samples); - x86_simd_audio_processing_func::apply_gain_sse(src1.data(), sse_result.data(), gain, num_samples); - x86_simd_audio_processing_func::apply_gain_avx(src1.data(), avx_result.data(), gain, num_samples); - x86_simd_audio_processing_func::apply_gain_avx512(src1.data(), avx512_result.data(), gain, num_samples); - - EXPECT_TRUE(arrays_equal(scalar_result.data(), sse_result.data(), num_samples)) + + // 测试 apply_gain 一致性 + const float gain = 0.75f; + scalar_audio_processing_func::apply_gain(src1.data(), scalar_result.data(), gain, num_samples); + x86_simd_audio_processing_func::apply_gain_sse(src1.data(), sse_result.data(), gain, num_samples); + x86_simd_audio_processing_func::apply_gain_avx(src1.data(), avx_result.data(), gain, num_samples); + x86_simd_audio_processing_func::apply_gain_avx512(src1.data(), avx512_result.data(), gain, num_samples); + + EXPECT_TRUE(arrays_equal(scalar_result.data(), sse_result.data(), num_samples)) << "apply_gain SSE版本与标量版本不一致"; - EXPECT_TRUE(arrays_equal(scalar_result.data(), avx_result.data(), num_samples)) + EXPECT_TRUE(arrays_equal(scalar_result.data(), avx_result.data(), num_samples)) << "apply_gain AVX版本与标量版本不一致"; - EXPECT_TRUE(arrays_equal(scalar_result.data(), avx512_result.data(), num_samples)) + EXPECT_TRUE(arrays_equal(scalar_result.data(), avx512_result.data(), num_samples)) << "apply_gain AVX512版本与标量版本不一致"; - - // 测试 calculate_rms 一致性 - float scalar_rms = scalar_audio_processing_func::calculate_rms(src1.data(), num_samples); - float sse_rms = x86_simd_audio_processing_func::calculate_rms_sse(src1.data(), num_samples); - float avx_rms = x86_simd_audio_processing_func::calculate_rms_avx(src1.data(), num_samples); - float avx512_rms = x86_simd_audio_processing_func::calculate_rms_avx512(src1.data(), num_samples); - - EXPECT_TRUE(float_equal(scalar_rms, sse_rms, RMS_TOLERANCE)) + + // 测试 calculate_rms 一致性 + float scalar_rms = scalar_audio_processing_func::calculate_rms(src1.data(), num_samples); + float sse_rms = x86_simd_audio_processing_func::calculate_rms_sse(src1.data(), num_samples); + float avx_rms = x86_simd_audio_processing_func::calculate_rms_avx(src1.data(), num_samples); + float avx512_rms = x86_simd_audio_processing_func::calculate_rms_avx512(src1.data(), num_samples); + + EXPECT_TRUE(float_equal(scalar_rms, sse_rms, RMS_TOLERANCE)) << "calculate_rms SSE版本与标量版本不一致: " << scalar_rms << " vs " << sse_rms; - EXPECT_TRUE(float_equal(scalar_rms, avx_rms, RMS_TOLERANCE)) + EXPECT_TRUE(float_equal(scalar_rms, avx_rms, RMS_TOLERANCE)) << "calculate_rms AVX版本与标量版本不一致: " << scalar_rms << " vs " << avx_rms; - EXPECT_TRUE(float_equal(scalar_rms, avx512_rms, RMS_TOLERANCE)) + EXPECT_TRUE(float_equal(scalar_rms, avx512_rms, RMS_TOLERANCE)) << "calculate_rms AVX512版本与标量版本不一致: " << scalar_rms << " vs " << avx512_rms; - - // 测试 calculate_peak 一致性 - float scalar_peak = scalar_audio_processing_func::calculate_peak(src1.data(), num_samples); - float sse_peak = x86_simd_audio_processing_func::calculate_peak_sse(src1.data(), num_samples); - float avx_peak = x86_simd_audio_processing_func::calculate_peak_avx(src1.data(), num_samples); - float avx512_peak = x86_simd_audio_processing_func::calculate_peak_avx512(src1.data(), num_samples); - - EXPECT_TRUE(float_equal(scalar_peak, sse_peak, PEAK_TOLERANCE)) + + // 测试 calculate_peak 一致性 + float scalar_peak = scalar_audio_processing_func::calculate_peak(src1.data(), num_samples); + float sse_peak = x86_simd_audio_processing_func::calculate_peak_sse(src1.data(), num_samples); + float avx_peak = x86_simd_audio_processing_func::calculate_peak_avx(src1.data(), num_samples); + float avx512_peak = x86_simd_audio_processing_func::calculate_peak_avx512(src1.data(), num_samples); + + EXPECT_TRUE(float_equal(scalar_peak, sse_peak, PEAK_TOLERANCE)) << "calculate_peak SSE版本与标量版本不一致: " << scalar_peak << " vs " << sse_peak; - EXPECT_TRUE(float_equal(scalar_peak, avx_peak, PEAK_TOLERANCE)) + EXPECT_TRUE(float_equal(scalar_peak, avx_peak, PEAK_TOLERANCE)) << "calculate_peak AVX版本与标量版本不一致: " << scalar_peak << " vs " << avx_peak; - EXPECT_TRUE(float_equal(scalar_peak, avx512_peak, PEAK_TOLERANCE)) + EXPECT_TRUE(float_equal(scalar_peak, avx512_peak, PEAK_TOLERANCE)) << "calculate_peak AVX512版本与标量版本不一致: " << scalar_peak << " vs " << avx512_peak; } @@ -605,41 +616,41 @@ TEST_F(AudioProcessingTest, X86SimdConsistency) { // 测试ARM NEON版本与标量版的一致性 TEST_F(AudioProcessingTest, ArmSimdConsistency) { - const size_t num_samples = 1024; - auto src1 = data_gen.generate_sine_wave(num_samples, 440.0f); - auto src2 = data_gen.generate_sine_wave(num_samples, 880.0f); - - std::vector scalar_result(num_samples); - std::vector neon_result(num_samples); - - // 测试 mix_audio 一致性 - scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), scalar_result.data(), num_samples); - arm_simd_audio_processing_func::mix_audio_neon(src1.data(), src2.data(), neon_result.data(), num_samples); - - EXPECT_TRUE(arrays_equal(scalar_result.data(), neon_result.data(), num_samples)) - << "mix_audio NEON版本与标量版不一致"; - - // 测试 apply_gain 一致性 - const float gain = 0.75f; - scalar_audio_processing_func::apply_gain(src1.data(), scalar_result.data(), gain, num_samples); - arm_simd_audio_processing_func::apply_gain_neon(src1.data(), neon_result.data(), gain, num_samples); - - EXPECT_TRUE(arrays_equal(scalar_result.data(), neon_result.data(), num_samples)) - << "apply_gain NEON版本与标量版不一致"; - - // 测试 calculate_rms 一致性 - float scalar_rms = scalar_audio_processing_func::calculate_rms(src1.data(), num_samples); - float neon_rms = arm_simd_audio_processing_func::calculate_rms_neon(src1.data(), num_samples); - - EXPECT_TRUE(float_equal(scalar_rms, neon_rms, RMS_TOLERANCE)) - << "calculate_rms NEON版本与标量版不一致: " << scalar_rms << " vs " << neon_rms; - - // 测试 calculate_peak 一致性 - float scalar_peak = scalar_audio_processing_func::calculate_peak(src1.data(), num_samples); - float neon_peak = arm_simd_audio_processing_func::calculate_peak_neon(src1.data(), num_samples); - - EXPECT_TRUE(float_equal(scalar_peak, neon_peak, PEAK_TOLERANCE)) - << "calculate_peak NEON版本与标量版不一致: " << scalar_peak << " vs " << neon_peak; + const size_t num_samples = 1024; + auto src1 = data_gen.generate_sine_wave(num_samples, 440.0f); + auto src2 = data_gen.generate_sine_wave(num_samples, 880.0f); + + std::vector scalar_result(num_samples); + std::vector neon_result(num_samples); + + // 测试 mix_audio 一致性 + scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), scalar_result.data(), num_samples); + arm_simd_audio_processing_func::mix_audio_neon(src1.data(), src2.data(), neon_result.data(), num_samples); + + EXPECT_TRUE(arrays_equal(scalar_result.data(), neon_result.data(), num_samples)) + << "mix_audio NEON版本与标量版不一致"; + + // 测试 apply_gain 一致性 + const float gain = 0.75f; + scalar_audio_processing_func::apply_gain(src1.data(), scalar_result.data(), gain, num_samples); + arm_simd_audio_processing_func::apply_gain_neon(src1.data(), neon_result.data(), gain, num_samples); + + EXPECT_TRUE(arrays_equal(scalar_result.data(), neon_result.data(), num_samples)) + << "apply_gain NEON版本与标量版不一致"; + + // 测试 calculate_rms 一致性 + float scalar_rms = scalar_audio_processing_func::calculate_rms(src1.data(), num_samples); + float neon_rms = arm_simd_audio_processing_func::calculate_rms_neon(src1.data(), num_samples); + + EXPECT_TRUE(float_equal(scalar_rms, neon_rms, RMS_TOLERANCE)) + << "calculate_rms NEON版本与标量版不一致: " << scalar_rms << " vs " << neon_rms; + + // 测试 calculate_peak 一致性 + float scalar_peak = scalar_audio_processing_func::calculate_peak(src1.data(), num_samples); + float neon_peak = arm_simd_audio_processing_func::calculate_peak_neon(src1.data(), num_samples); + + EXPECT_TRUE(float_equal(scalar_peak, neon_peak, PEAK_TOLERANCE)) + << "calculate_peak NEON版本与标量版不一致: " << scalar_peak << " vs " << neon_peak; } #endif // ALICHO_PLATFORM_ARM @@ -652,135 +663,135 @@ TEST_F(AudioProcessingTest, ArmSimdConsistency) { // 性能测试运行为 TEST_F 测试,但不使用 EXPECT/ASSERT TEST_F(AudioProcessingTest, PerformanceTests) { - std::cout << "\n=== 性能测试 ===" << std::endl; - - // 生成大量测试数据 - auto src1 = data_gen.generate_sine_wave(PERF_TEST_SIZE, 440.0f); - auto src2 = data_gen.generate_sine_wave(PERF_TEST_SIZE, 880.0f); - std::vector result(PERF_TEST_SIZE); - - // 性能测试:mix_audio - { - double scalar_time = perf_tester.measure_execution_time([&]() { - scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); - }); - -#if ALICHO_PLATFORM_X86 - double sse_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::mix_audio_sse(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("mix_audio (SSE vs Scalar)", scalar_time, sse_time); - - double avx_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::mix_audio_avx(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("mix_audio (AVX vs Scalar)", scalar_time, avx_time); - - double avx512_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::mix_audio_avx512(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("mix_audio (AVX512 vs Scalar)", scalar_time, avx512_time); -#endif + std::cout << "\n=== 性能测试 ===" << std::endl; -#if ALICHO_PLATFORM_ARM - double neon_time = perf_tester.measure_execution_time([&]() { - arm_simd_audio_processing_func::mix_audio_neon(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("mix_audio (NEON vs Scalar)", scalar_time, neon_time); -#endif - } - - // 性能测试:apply_gain - { - const float gain = 0.5f; - double scalar_time = perf_tester.measure_execution_time([&]() { - scalar_audio_processing_func::apply_gain(src1.data(), result.data(), gain, PERF_TEST_SIZE); - }); - -#if ALICHO_PLATFORM_X86 - double sse_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::apply_gain_sse(src1.data(), result.data(), gain, PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("apply_gain (SSE vs Scalar)", scalar_time, sse_time); + // 生成大量测试数据 + auto src1 = data_gen.generate_sine_wave(PERF_TEST_SIZE, 440.0f); + auto src2 = data_gen.generate_sine_wave(PERF_TEST_SIZE, 880.0f); + std::vector result(PERF_TEST_SIZE); - double avx_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::apply_gain_avx(src1.data(), result.data(), gain, PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("apply_gain (AVX vs Scalar)", scalar_time, avx_time); + // 性能测试:mix_audio + { + double scalar_time = perf_tester.measure_execution_time([&]() { + scalar_audio_processing_func::mix_audio(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); + }); - double avx512_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::apply_gain_avx512(src1.data(), result.data(), gain, PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("apply_gain (AVX512 vs Scalar)", scalar_time, avx512_time); -#endif + #if ALICHO_PLATFORM_X86 + double sse_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::mix_audio_sse(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("mix_audio (SSE vs Scalar)", scalar_time, sse_time); -#if ALICHO_PLATFORM_ARM - double neon_time = perf_tester.measure_execution_time([&]() { - arm_simd_audio_processing_func::apply_gain_neon(src1.data(), result.data(), gain, PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("apply_gain (NEON vs Scalar)", scalar_time, neon_time); -#endif - } - - // 性能测试:calculate_rms - { - double scalar_time = perf_tester.measure_execution_time([&]() { - scalar_audio_processing_func::calculate_rms(src1.data(), PERF_TEST_SIZE); - }); - -#if ALICHO_PLATFORM_X86 - double sse_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::calculate_rms_sse(src1.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("calculate_rms (SSE vs Scalar)", scalar_time, sse_time); + double avx_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::mix_audio_avx(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("mix_audio (AVX vs Scalar)", scalar_time, avx_time); - double avx_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::calculate_rms_avx(src1.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("calculate_rms (AVX vs Scalar)", scalar_time, avx_time); + double avx512_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::mix_audio_avx512(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("mix_audio (AVX512 vs Scalar)", scalar_time, avx512_time); + #endif - double avx512_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::calculate_rms_avx512(src1.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("calculate_rms (AVX512 vs Scalar)", scalar_time, avx512_time); -#endif + #if ALICHO_PLATFORM_ARM + double neon_time = perf_tester.measure_execution_time([&]() { + arm_simd_audio_processing_func::mix_audio_neon(src1.data(), src2.data(), result.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("mix_audio (NEON vs Scalar)", scalar_time, neon_time); + #endif + } -#if ALICHO_PLATFORM_ARM - double neon_time = perf_tester.measure_execution_time([&]() { - arm_simd_audio_processing_func::calculate_rms_neon(src1.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("calculate_rms (NEON vs Scalar)", scalar_time, neon_time); -#endif - } + // 性能测试:apply_gain + { + const float gain = 0.5f; + double scalar_time = perf_tester.measure_execution_time([&]() { + scalar_audio_processing_func::apply_gain(src1.data(), result.data(), gain, PERF_TEST_SIZE); + }); - // 性能测试:calculate_peak - { - double scalar_time = perf_tester.measure_execution_time([&]() { - scalar_audio_processing_func::calculate_peak(src1.data(), PERF_TEST_SIZE); - }); + #if ALICHO_PLATFORM_X86 + double sse_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::apply_gain_sse(src1.data(), result.data(), gain, PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("apply_gain (SSE vs Scalar)", scalar_time, sse_time); -#if ALICHO_PLATFORM_X86 - double sse_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::calculate_peak_sse(src1.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("calculate_peak (SSE vs Scalar)", scalar_time, sse_time); + double avx_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::apply_gain_avx(src1.data(), result.data(), gain, PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("apply_gain (AVX vs Scalar)", scalar_time, avx_time); - double avx_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::calculate_peak_avx(src1.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("calculate_peak (AVX vs Scalar)", scalar_time, avx_time); + double avx512_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::apply_gain_avx512(src1.data(), result.data(), gain, PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("apply_gain (AVX512 vs Scalar)", scalar_time, avx512_time); + #endif - double avx512_time = perf_tester.measure_execution_time([&]() { - x86_simd_audio_processing_func::calculate_peak_avx512(src1.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("calculate_peak (AVX512 vs Scalar)", scalar_time, avx512_time); -#endif + #if ALICHO_PLATFORM_ARM + double neon_time = perf_tester.measure_execution_time([&]() { + arm_simd_audio_processing_func::apply_gain_neon(src1.data(), result.data(), gain, PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("apply_gain (NEON vs Scalar)", scalar_time, neon_time); + #endif + } -#if ALICHO_PLATFORM_ARM - double neon_time = perf_tester.measure_execution_time([&]() { - arm_simd_audio_processing_func::calculate_peak_neon(src1.data(), PERF_TEST_SIZE); - }); - perf_tester.print_performance_comparison("calculate_peak (NEON vs Scalar)", scalar_time, neon_time); -#endif - } -} \ No newline at end of file + // 性能测试:calculate_rms + { + double scalar_time = perf_tester.measure_execution_time([&]() { + scalar_audio_processing_func::calculate_rms(src1.data(), PERF_TEST_SIZE); + }); + + #if ALICHO_PLATFORM_X86 + double sse_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::calculate_rms_sse(src1.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("calculate_rms (SSE vs Scalar)", scalar_time, sse_time); + + double avx_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::calculate_rms_avx(src1.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("calculate_rms (AVX vs Scalar)", scalar_time, avx_time); + + double avx512_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::calculate_rms_avx512(src1.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("calculate_rms (AVX512 vs Scalar)", scalar_time, avx512_time); + #endif + + #if ALICHO_PLATFORM_ARM + double neon_time = perf_tester.measure_execution_time([&]() { + arm_simd_audio_processing_func::calculate_rms_neon(src1.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("calculate_rms (NEON vs Scalar)", scalar_time, neon_time); + #endif + } + + // 性能测试:calculate_peak + { + double scalar_time = perf_tester.measure_execution_time([&]() { + scalar_audio_processing_func::calculate_peak(src1.data(), PERF_TEST_SIZE); + }); + + #if ALICHO_PLATFORM_X86 + double sse_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::calculate_peak_sse(src1.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("calculate_peak (SSE vs Scalar)", scalar_time, sse_time); + + double avx_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::calculate_peak_avx(src1.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("calculate_peak (AVX vs Scalar)", scalar_time, avx_time); + + double avx512_time = perf_tester.measure_execution_time([&]() { + x86_simd_audio_processing_func::calculate_peak_avx512(src1.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("calculate_peak (AVX512 vs Scalar)", scalar_time, avx512_time); + #endif + + #if ALICHO_PLATFORM_ARM + double neon_time = perf_tester.measure_execution_time([&]() { + arm_simd_audio_processing_func::calculate_peak_neon(src1.data(), PERF_TEST_SIZE); + }); + perf_tester.print_performance_comparison("calculate_peak (NEON vs Scalar)", scalar_time, neon_time); + #endif + } +} diff --git a/tests/test_simd.cpp b/tests/test_simd.cpp index 11af894..a71af34 100644 --- a/tests/test_simd.cpp +++ b/tests/test_simd.cpp @@ -39,60 +39,61 @@ // 测试辅助函数 namespace simd_test_helpers { - // 简单的性能计时器 - class timer { - public: - timer() : start_(std::chrono::high_resolution_clock::now()) {} - - auto elapsed_ms() const -> double { - auto end = std::chrono::high_resolution_clock::now(); - auto duration = std::chrono::duration_cast(end - start_); - return duration.count() / 1000.0; - } - - private: - std::chrono::high_resolution_clock::time_point start_; - }; + // 简单的性能计时器 + class timer { + public: + timer() : start_(std::chrono::high_resolution_clock::now()) { + } - // 测试用的简单数学函数 - auto add_scalar(float a, float b) -> float { return a + b; } - auto add_sse(float a, float b) -> float { return a + b + 0.1f; } // 模拟SSE版本 - auto add_avx(float a, float b) -> float { return a + b + 0.2f; } // 模拟AVX版本 - - // 测试用的数组求和函数 - auto sum_array_scalar(const std::vector& arr) -> float { - float sum = 0.0f; - for (const auto& val : arr) { - sum += val; - } - return sum; - } - - auto sum_array_sse(const std::vector& arr) -> float { - // 模拟SSE实现 - return sum_array_scalar(arr) * 1.01f; - } - - auto sum_array_avx(const std::vector& arr) -> float { - // 模拟AVX实现 - return sum_array_scalar(arr) * 1.02f; - } + auto elapsed_ms() const -> double { + auto end = std::chrono::high_resolution_clock::now(); + auto duration = std::chrono::duration_cast(end - start_); + return duration.count() / 1000.0; + } - // 检查指针是否正确对齐 - template - auto is_properly_aligned(void* ptr) -> bool { - return (reinterpret_cast(ptr) % alignment) == 0; - } + private: + std::chrono::high_resolution_clock::time_point start_; + }; - // 生成测试数据 - auto generate_test_data(size_t size) -> std::vector { - std::vector data; - data.reserve(size); - for (size_t i = 0; i < size; ++i) { - data.push_back(static_cast(i) * 0.1f); - } - return data; - } + // 测试用的简单数学函数 + auto add_scalar(float a, float b) -> float { return a + b; } + auto add_sse(float a, float b) -> float { return a + b + 0.1f; } // 模拟SSE版本 + auto add_avx(float a, float b) -> float { return a + b + 0.2f; } // 模拟AVX版本 + + // 测试用的数组求和函数 + auto sum_array_scalar(const std::vector& arr) -> float { + float sum = 0.0f; + for (const auto& val : arr) { + sum += val; + } + return sum; + } + + auto sum_array_sse(const std::vector& arr) -> float { + // 模拟SSE实现 + return sum_array_scalar(arr) * 1.01f; + } + + auto sum_array_avx(const std::vector& arr) -> float { + // 模拟AVX实现 + return sum_array_scalar(arr) * 1.02f; + } + + // 检查指针是否正确对齐 + template + auto is_properly_aligned(void* ptr) -> bool { + return (reinterpret_cast(ptr) % alignment) == 0; + } + + // 生成测试数据 + auto generate_test_data(size_t size) -> std::vector { + std::vector data; + data.reserve(size); + for (size_t i = 0; i < size; ++i) { + data.push_back(static_cast(i) * 0.1f); + } + return data; + } } // ============================================================================= @@ -101,16 +102,16 @@ namespace simd_test_helpers { class simd_test : public ::testing::Test { protected: - void SetUp() override { - // 获取CPU信息用于后续测试 - cpu_info_ = &get_cpu_info(); - } + void SetUp() override { + // 获取CPU信息用于后续测试 + cpu_info_ = &get_cpu_info(); + } - void TearDown() override { - // 清理测试环境 - } + void TearDown() override { + // 清理测试环境 + } - const cpu_info* cpu_info_ = nullptr; + const cpu_info* cpu_info_ = nullptr; }; // ============================================================================= @@ -119,277 +120,277 @@ protected: // 基础功能测试 TEST_F(simd_test, CpuFeaturesTest_BasicDetection) { - ASSERT_NE(cpu_info_, nullptr); - - // 基本信息应该已填充 - EXPECT_FALSE(cpu_info_->vendor.empty()); - EXPECT_FALSE(cpu_info_->brand.empty()); - EXPECT_GT(cpu_info_->logical_cores, 0); - EXPECT_GT(cpu_info_->physical_cores, 0); - - // 特性字符串应该可以生成 - auto features_str = cpu_info_->features_string(); - EXPECT_TRUE(features_str.empty() || !features_str.empty()); // 总是为真,但测试调用成功 - - std::cout << "CPU Vendor: " << cpu_info_->vendor << std::endl; - std::cout << "CPU Brand: " << cpu_info_->brand << std::endl; - std::cout << "Logical Cores: " << cpu_info_->logical_cores << std::endl; - std::cout << "Physical Cores: " << cpu_info_->physical_cores << std::endl; - std::cout << "Features: " << features_str << std::endl; + ASSERT_NE(cpu_info_, nullptr); + + // 基本信息应该已填充 + EXPECT_FALSE(cpu_info_->vendor.empty()); + EXPECT_FALSE(cpu_info_->brand.empty()); + EXPECT_GT(cpu_info_->logical_cores, 0); + EXPECT_GT(cpu_info_->physical_cores, 0); + + // 特性字符串应该可以生成 + auto features_str = cpu_info_->features_string(); + EXPECT_TRUE(features_str.empty() || !features_str.empty()); // 总是为真,但测试调用成功 + + std::cout << "CPU Vendor: " << cpu_info_->vendor << std::endl; + std::cout << "CPU Brand: " << cpu_info_->brand << std::endl; + std::cout << "Logical Cores: " << cpu_info_->logical_cores << std::endl; + std::cout << "Physical Cores: " << cpu_info_->physical_cores << std::endl; + std::cout << "Features: " << features_str << std::endl; } TEST_F(simd_test, CpuFeaturesTest_SimdLevelDetection) { - auto max_level = get_max_simd_level(); - auto recommended_level = get_recommended_simd_level(); - - // SIMD级别应该在有效范围内 - EXPECT_GE(static_cast(max_level), static_cast(simd_level::NONE)); - EXPECT_LE(static_cast(max_level), static_cast(simd_level::NEON_FP16)); - - EXPECT_GE(static_cast(recommended_level), static_cast(simd_level::NONE)); - EXPECT_LE(static_cast(recommended_level), static_cast(simd_level::NEON_FP16)); - - // 推荐级别不应该超过最大级别 - EXPECT_LE(static_cast(recommended_level), static_cast(max_level)); - - std::cout << "Max SIMD Level: " << static_cast(max_level) << std::endl; - std::cout << "Recommended SIMD Level: " << static_cast(recommended_level) << std::endl; + auto max_level = get_max_simd_level(); + auto recommended_level = get_recommended_simd_level(); + + // SIMD级别应该在有效范围内 + EXPECT_GE(static_cast(max_level), static_cast(simd_level::NONE)); + EXPECT_LE(static_cast(max_level), static_cast(simd_level::NEON_FP16)); + + EXPECT_GE(static_cast(recommended_level), static_cast(simd_level::NONE)); + EXPECT_LE(static_cast(recommended_level), static_cast(simd_level::NEON_FP16)); + + // 推荐级别不应该超过最大级别 + EXPECT_LE(static_cast(recommended_level), static_cast(max_level)); + + std::cout << "Max SIMD Level: " << static_cast(max_level) << std::endl; + std::cout << "Recommended SIMD Level: " << static_cast(recommended_level) << std::endl; } TEST_F(simd_test, CpuFeaturesTest_GlobalFunctions) { - // 测试全局便利函数 - const auto& info = get_cpu_info(); - EXPECT_EQ(&info, cpu_info_); - - // 测试特性检查函数 - auto sse_supported = cpu_supports(cpu_feature::SSE); - auto sse2_supported = cpu_supports(cpu_feature::SSE2); - - // 如果支持SSE2,应该也支持SSE - if (sse2_supported) { - EXPECT_TRUE(sse_supported); - } - - // 测试级别检查 - auto detector = &cpu_feature_detector::instance(); - EXPECT_EQ(detector->max_simd_level(), info.max_simd_level); - - // 验证支持级别检查逻辑 - EXPECT_TRUE(detector->supports_at_least(simd_level::NONE)); - - if (info.max_simd_level >= simd_level::SSE) { - EXPECT_TRUE(detector->supports_at_least(simd_level::SSE)); - } + // 测试全局便利函数 + const auto& info = get_cpu_info(); + EXPECT_EQ(&info, cpu_info_); + + // 测试特性检查函数 + auto sse_supported = cpu_supports(cpu_feature::SSE); + auto sse2_supported = cpu_supports(cpu_feature::SSE2); + + // 如果支持SSE2,应该也支持SSE + if (sse2_supported) { + EXPECT_TRUE(sse_supported); + } + + // 测试级别检查 + auto detector = &cpu_feature_detector::instance(); + EXPECT_EQ(detector->max_simd_level(), info.max_simd_level); + + // 验证支持级别检查逻辑 + EXPECT_TRUE(detector->supports_at_least(simd_level::NONE)); + + if (info.max_simd_level >= simd_level::SSE) { + EXPECT_TRUE(detector->supports_at_least(simd_level::SSE)); + } } // 平台兼容性测试 TEST_F(simd_test, CpuFeaturesTest_X86PlatformSupport) { -#if ALICHO_PLATFORM_X86 - // 在x86平台上,至少应该支持SSE - EXPECT_TRUE(cpu_supports(cpu_feature::SSE) || cpu_supports(cpu_feature::SSE2)); - - // 检查常见的x86特性 - std::vector x86_features = { - cpu_feature::SSE, cpu_feature::SSE2, cpu_feature::SSE3, - cpu_feature::AVX, cpu_feature::AVX2, cpu_feature::FMA - }; - - bool has_any_x86_feature = false; - for (auto feature : x86_features) { - if (cpu_supports(feature)) { - has_any_x86_feature = true; - break; - } - } - EXPECT_TRUE(has_any_x86_feature); -#else - GTEST_SKIP() << "Not x86 platform"; -#endif + #if ALICHO_PLATFORM_X86 + // 在x86平台上,至少应该支持SSE + EXPECT_TRUE(cpu_supports(cpu_feature::SSE) || cpu_supports(cpu_feature::SSE2)); + + // 检查常见的x86特性 + std::vector x86_features = { + cpu_feature::SSE, cpu_feature::SSE2, cpu_feature::SSE3, + cpu_feature::AVX, cpu_feature::AVX2, cpu_feature::FMA + }; + + bool has_any_x86_feature = false; + for (auto feature : x86_features) { + if (cpu_supports(feature)) { + has_any_x86_feature = true; + break; + } + } + EXPECT_TRUE(has_any_x86_feature); + #else + GTEST_SKIP() << "Not x86 platform"; + #endif } TEST_F(simd_test, CpuFeaturesTest_ArmPlatformSupport) { -#if ALICHO_PLATFORM_ARM - // 在ARM平台上,可能支持NEON - bool has_neon = cpu_supports(cpu_feature::NEON); - bool has_neon_fp16 = cpu_supports(cpu_feature::NEON_FP16); - - // 如果支持FP16,应该也支持基础NEON - if (has_neon_fp16) { - EXPECT_TRUE(has_neon); - } - - // 检查SIMD级别 - auto max_level = get_max_simd_level(); - if (has_neon) { - EXPECT_GE(static_cast(max_level), static_cast(simd_level::NEON)); - } -#else - GTEST_SKIP() << "Not ARM platform"; -#endif + #if ALICHO_PLATFORM_ARM + // 在ARM平台上,可能支持NEON + bool has_neon = cpu_supports(cpu_feature::NEON); + bool has_neon_fp16 = cpu_supports(cpu_feature::NEON_FP16); + + // 如果支持FP16,应该也支持基础NEON + if (has_neon_fp16) { + EXPECT_TRUE(has_neon); + } + + // 检查SIMD级别 + auto max_level = get_max_simd_level(); + if (has_neon) { + EXPECT_GE(static_cast(max_level), static_cast(simd_level::NEON)); + } + #else + GTEST_SKIP() << "Not ARM platform"; + #endif } TEST_F(simd_test, CpuFeaturesTest_CrossPlatformConsistency) { - // 跨平台一致性检查 - auto detector = &cpu_feature_detector::instance(); - - // 单例应该总是返回相同的实例 - EXPECT_EQ(detector, &cpu_feature_detector::instance()); - - // 多次调用应该返回相同的结果 - auto level1 = get_max_simd_level(); - auto level2 = get_max_simd_level(); - EXPECT_EQ(level1, level2); - - auto recommended1 = get_recommended_simd_level(); - auto recommended2 = get_recommended_simd_level(); - EXPECT_EQ(recommended1, recommended2); - - // 特性检测应该一致 - auto sse_check1 = cpu_supports(cpu_feature::SSE); - auto sse_check2 = cpu_supports(cpu_feature::SSE); - EXPECT_EQ(sse_check1, sse_check2); + // 跨平台一致性检查 + auto detector = &cpu_feature_detector::instance(); + + // 单例应该总是返回相同的实例 + EXPECT_EQ(detector, &cpu_feature_detector::instance()); + + // 多次调用应该返回相同的结果 + auto level1 = get_max_simd_level(); + auto level2 = get_max_simd_level(); + EXPECT_EQ(level1, level2); + + auto recommended1 = get_recommended_simd_level(); + auto recommended2 = get_recommended_simd_level(); + EXPECT_EQ(recommended1, recommended2); + + // 特性检测应该一致 + auto sse_check1 = cpu_supports(cpu_feature::SSE); + auto sse_check2 = cpu_supports(cpu_feature::SSE); + EXPECT_EQ(sse_check1, sse_check2); } // SIMD级别推荐测试 TEST_F(simd_test, CpuFeaturesTest_SimdLevelRecommendation) { - auto max_level = get_max_simd_level(); - auto recommended_level = get_recommended_simd_level(); - - // 推荐算法的合理性检查 - switch (max_level) { - case simd_level::NONE: - EXPECT_EQ(recommended_level, simd_level::NONE); - break; - case simd_level::SSE: - case simd_level::SSE3: - case simd_level::SSE4: - case simd_level::AVX: - case simd_level::AVX2: - // 对于这些级别,推荐级别应该等于最大级别 - EXPECT_EQ(recommended_level, max_level); - break; - case simd_level::AVX512: - // AVX512可能会回退到AVX2以确保兼容性 - EXPECT_TRUE(recommended_level == simd_level::AVX512 || - recommended_level == simd_level::AVX2); - break; - case simd_level::NEON: - case simd_level::NEON_FP16: - EXPECT_EQ(recommended_level, max_level); - break; - } + auto max_level = get_max_simd_level(); + auto recommended_level = get_recommended_simd_level(); + + // 推荐算法的合理性检查 + switch (max_level) { + case simd_level::NONE: + EXPECT_EQ(recommended_level, simd_level::NONE); + break; + case simd_level::SSE: + case simd_level::SSE3: + case simd_level::SSE4: + case simd_level::AVX: + case simd_level::AVX2: + // 对于这些级别,推荐级别应该等于最大级别 + EXPECT_EQ(recommended_level, max_level); + break; + case simd_level::AVX512: + // AVX512可能会回退到AVX2以确保兼容性 + EXPECT_TRUE(recommended_level == simd_level::AVX512 || + recommended_level == simd_level::AVX2); + break; + case simd_level::NEON: + case simd_level::NEON_FP16: + EXPECT_EQ(recommended_level, max_level); + break; + } } TEST_F(simd_test, CpuFeaturesTest_PerformanceGuidedSelection) { - // 测试性能引导的SIMD级别选择 - auto recommended = get_recommended_simd_level(); - auto max_level = get_max_simd_level(); - - // 推荐级别应该考虑性能和兼容性 - EXPECT_LE(static_cast(recommended), static_cast(max_level)); - - // 在AVX512的情况下,验证特殊逻辑 - if (max_level == simd_level::AVX512) { - bool has_avx512f = cpu_supports(cpu_feature::AVX512F); - bool has_avx512vl = cpu_supports(cpu_feature::AVX512VL); - bool has_avx512bw = cpu_supports(cpu_feature::AVX512BW); - - if (has_avx512f && has_avx512vl && has_avx512bw) { - // 应该根据CPU供应商和型号决定 - if (cpu_info_->vendor.find("AMD") != std::string::npos) { - EXPECT_EQ(recommended, simd_level::AVX512); - } - // Intel的情况下可能会有特殊处理 - } - } + // 测试性能引导的SIMD级别选择 + auto recommended = get_recommended_simd_level(); + auto max_level = get_max_simd_level(); + + // 推荐级别应该考虑性能和兼容性 + EXPECT_LE(static_cast(recommended), static_cast(max_level)); + + // 在AVX512的情况下,验证特殊逻辑 + if (max_level == simd_level::AVX512) { + bool has_avx512f = cpu_supports(cpu_feature::AVX512F); + bool has_avx512vl = cpu_supports(cpu_feature::AVX512VL); + bool has_avx512bw = cpu_supports(cpu_feature::AVX512BW); + + if (has_avx512f && has_avx512vl && has_avx512bw) { + // 应该根据CPU供应商和型号决定 + if (cpu_info_->vendor.find("AMD") != std::string::npos) { + EXPECT_EQ(recommended, simd_level::AVX512); + } + // Intel的情况下可能会有特殊处理 + } + } } // 异常处理测试 TEST_F(simd_test, CpuFeaturesTest_InvalidFeatureHandling) { - // 测试无效特性值的处理 - // 由于cpu_feature是enum class,编译器会阻止大多数无效值 - - // 测试边界值 - 使用一个明确未定义的特性值 - auto invalid_feature = static_cast(0); // 0值通常不代表任何特性 - EXPECT_NO_THROW({ - bool result = cpu_supports(invalid_feature); - // 0值应该返回false - EXPECT_FALSE(result); - }); - - // 测试特性位掩码的正确性 - uint32_t all_features = cpu_info_->features; - for (int bit = 0; bit < 32; ++bit) { - auto feature = static_cast(1U << bit); - bool expected = (all_features & (1U << bit)) != 0; - bool actual = cpu_supports(feature); - EXPECT_EQ(expected, actual) << "Bit " << bit << " mismatch"; - } + // 测试无效特性值的处理 + // 由于cpu_feature是enum class,编译器会阻止大多数无效值 + + // 测试边界值 - 使用一个明确未定义的特性值 + auto invalid_feature = static_cast(0); // 0值通常不代表任何特性 + EXPECT_NO_THROW({ + bool result = cpu_supports(invalid_feature); + // 0值应该返回false + EXPECT_FALSE(result); + }); + + // 测试特性位掩码的正确性 + uint32_t all_features = cpu_info_->features; + for (int bit = 0; bit < 32; ++bit) { + auto feature = static_cast(1U << bit); + bool expected = (all_features & (1U << bit)) != 0; + bool actual = cpu_supports(feature); + EXPECT_EQ(expected, actual) << "Bit " << bit << " mismatch"; + } } TEST_F(simd_test, CpuFeaturesTest_ThreadSafety) { - // 测试多线程安全性 - const int num_threads = 4; - const int calls_per_thread = 100; - - std::vector threads; - std::vector results(num_threads * calls_per_thread); - - // 启动多个线程同时访问CPU特性检测 - for (int t = 0; t < num_threads; ++t) { - threads.emplace_back([&, t]() { - for (int i = 0; i < calls_per_thread; ++i) { - int idx = t * calls_per_thread + i; - - // 测试不同的API调用 - switch (i % 4) { - case 0: - results[idx] = cpu_supports(cpu_feature::SSE); - break; - case 1: - results[idx] = (get_max_simd_level() != simd_level::NONE); - break; - case 2: - results[idx] = (get_recommended_simd_level() != simd_level::NONE); - break; - case 3: - results[idx] = !get_cpu_info().vendor.empty(); - break; - } - } - }); - } - - // 等待所有线程完成 - for (auto& thread : threads) { - thread.join(); - } - - // 验证同一类型的调用返回相同结果 - bool sse_result = cpu_supports(cpu_feature::SSE); - auto max_level = get_max_simd_level(); - auto recommended_level = get_recommended_simd_level(); - bool has_vendor = !get_cpu_info().vendor.empty(); - - for (int i = 0; i < calls_per_thread; ++i) { - for (int t = 0; t < num_threads; ++t) { - int idx = t * calls_per_thread + i; - switch (i % 4) { - case 0: - EXPECT_EQ(results[idx], sse_result); - break; - case 1: - EXPECT_EQ(results[idx], (max_level != simd_level::NONE)); - break; - case 2: - EXPECT_EQ(results[idx], (recommended_level != simd_level::NONE)); - break; - case 3: - EXPECT_EQ(results[idx], has_vendor); - break; - } - } - } + // 测试多线程安全性 + const int num_threads = 4; + const int calls_per_thread = 100; + + std::vector threads; + std::vector results(num_threads * calls_per_thread); + + // 启动多个线程同时访问CPU特性检测 + for (int t = 0; t < num_threads; ++t) { + threads.emplace_back([&, t]() { + for (int i = 0; i < calls_per_thread; ++i) { + int idx = t * calls_per_thread + i; + + // 测试不同的API调用 + switch (i % 4) { + case 0: + results[idx] = cpu_supports(cpu_feature::SSE); + break; + case 1: + results[idx] = (get_max_simd_level() != simd_level::NONE); + break; + case 2: + results[idx] = (get_recommended_simd_level() != simd_level::NONE); + break; + case 3: + results[idx] = !get_cpu_info().vendor.empty(); + break; + } + } + }); + } + + // 等待所有线程完成 + for (auto& thread : threads) { + thread.join(); + } + + // 验证同一类型的调用返回相同结果 + bool sse_result = cpu_supports(cpu_feature::SSE); + auto max_level = get_max_simd_level(); + auto recommended_level = get_recommended_simd_level(); + bool has_vendor = !get_cpu_info().vendor.empty(); + + for (int i = 0; i < calls_per_thread; ++i) { + for (int t = 0; t < num_threads; ++t) { + int idx = t * calls_per_thread + i; + switch (i % 4) { + case 0: + EXPECT_EQ(results[idx], sse_result); + break; + case 1: + EXPECT_EQ(results[idx], (max_level != simd_level::NONE)); + break; + case 2: + EXPECT_EQ(results[idx], (recommended_level != simd_level::NONE)); + break; + case 3: + EXPECT_EQ(results[idx], has_vendor); + break; + } + } + } } // ============================================================================= @@ -398,259 +399,261 @@ TEST_F(simd_test, CpuFeaturesTest_ThreadSafety) { // 函数注册和查找 TEST_F(simd_test, SimdDispatcherTest_FunctionRegistration) { - auto& dispatcher = simd_func_dispatcher::instance(); - - // 注册测试函数 - std::function scalar_add = simd_test_helpers::add_scalar; - std::function sse_add = simd_test_helpers::add_sse; - std::function avx_add = simd_test_helpers::add_avx; - - EXPECT_NO_THROW({ - dispatcher.register_function("test_add", simd_func_version::SCALAR, scalar_add); - dispatcher.register_function("test_add", simd_func_version::SSE, sse_add); - dispatcher.register_function("test_add", simd_func_version::AVX, avx_add); - }); - - // 验证函数已注册 - auto func_list = dispatcher.list_functions(); - EXPECT_TRUE(std::find(func_list.begin(), func_list.end(), "test_add") != func_list.end()); + auto& dispatcher = simd_func_dispatcher::instance(); + + // 注册测试函数 + std::function scalar_add = simd_test_helpers::add_scalar; + std::function sse_add = simd_test_helpers::add_sse; + std::function avx_add = simd_test_helpers::add_avx; + + EXPECT_NO_THROW({ + dispatcher.register_function("test_add", simd_func_version::SCALAR, scalar_add); + dispatcher.register_function("test_add", simd_func_version::SSE, sse_add); + dispatcher.register_function("test_add", simd_func_version::AVX, avx_add); + }); + + // 验证函数已注册 + auto func_list = dispatcher.list_functions(); + EXPECT_TRUE(std::find(func_list.begin(), func_list.end(), "test_add") != func_list.end()); } TEST_F(simd_test, SimdDispatcherTest_FunctionLookup) { - auto& dispatcher = simd_func_dispatcher::instance(); - - // 查找已注册的函数 - EXPECT_NO_THROW({ - const auto& func = dispatcher.get_function("test_add"); - - // 函数应该可以调用 - float result = func(1.0f, 2.0f); - EXPECT_GT(result, 0.0f); // 结果应该是正数 - }); - - // 查找不存在的函数应该抛出异常 - EXPECT_THROW({ - const auto& nonexistent = dispatcher.get_function("nonexistent_func"); - }, std::runtime_error); + auto& dispatcher = simd_func_dispatcher::instance(); + + // 查找已注册的函数 + EXPECT_NO_THROW({ + const auto& func = dispatcher.get_function("test_add"); + + // 函数应该可以调用 + float result = func(1.0f, 2.0f); + EXPECT_GT(result, 0.0f); // 结果应该是正数 + }); + + // 查找不存在的函数应该抛出异常 + EXPECT_THROW({ + const auto& nonexistent = dispatcher.get_function("nonexistent_func"); + }, std::runtime_error); } TEST_F(simd_test, SimdDispatcherTest_MultiVersionManagement) { - auto& dispatcher = simd_func_dispatcher::instance(); - - // 创建一个新的测试函数 - const std::string func_name = "multi_version_test"; - - // 注册多个版本 - dispatcher.register_function&)>( - func_name, simd_func_version::SCALAR, simd_test_helpers::sum_array_scalar); - dispatcher.register_function&)>( - func_name, simd_func_version::SSE, simd_test_helpers::sum_array_sse); - dispatcher.register_function&)>( - func_name, simd_func_version::AVX, simd_test_helpers::sum_array_avx); - - // 获取函数并测试 - const auto& func = dispatcher.get_function&)>(func_name); - - auto test_data = simd_test_helpers::generate_test_data(100); - float result = func(test_data); - - // 结果应该大于纯标量计算的结果(因为模拟的SIMD版本会增加系数) - float scalar_result = simd_test_helpers::sum_array_scalar(test_data); - EXPECT_GE(result, scalar_result); - - std::cout << "Multi-version result: " << result << " (scalar: " << scalar_result << ")" << std::endl; + auto& dispatcher = simd_func_dispatcher::instance(); + + // 创建一个新的测试函数 + const std::string func_name = "multi_version_test"; + + // 注册多个版本 + dispatcher.register_function&)>( + func_name, simd_func_version::SCALAR, simd_test_helpers::sum_array_scalar); + dispatcher.register_function&)>( + func_name, simd_func_version::SSE, simd_test_helpers::sum_array_sse); + dispatcher.register_function&)>( + func_name, simd_func_version::AVX, simd_test_helpers::sum_array_avx); + + // 获取函数并测试 + const auto& func = dispatcher.get_function&)>(func_name); + + auto test_data = simd_test_helpers::generate_test_data(100); + float result = func(test_data); + + // 结果应该大于纯标量计算的结果(因为模拟的SIMD版本会增加系数) + float scalar_result = simd_test_helpers::sum_array_scalar(test_data); + EXPECT_GE(result, scalar_result); + + std::cout << "Multi-version result: " << result << " (scalar: " << scalar_result << ")" << std::endl; } // 自动分发机制 TEST_F(simd_test, SimdDispatcherTest_AutomaticDispatch) { - auto& dispatcher = simd_func_dispatcher::instance(); - - // 测试自动分发是否选择最佳版本 - const std::string func_name = "auto_dispatch_test"; - - // 只注册标量版本 - dispatcher.register_function( - func_name, simd_func_version::SCALAR, - [](int a, int b) { return a + b; }); - - // 根据当前系统支持,可能还会注册其他版本 - if (cpu_supports(cpu_feature::SSE)) { - dispatcher.register_function( - func_name, simd_func_version::SSE, - [](int a, int b) { return a + b + 1; }); // SSE版本加1标识 - } - - if (cpu_supports(cpu_feature::AVX)) { - dispatcher.register_function( - func_name, simd_func_version::AVX, - [](int a, int b) { return a + b + 2; }); // AVX版本加2标识 - } - - // 测试分发选择 - const auto& func = dispatcher.get_function(func_name); - int result = func(10, 20); - - // 验证选择了正确的版本 - if (cpu_supports(cpu_feature::AVX)) { - EXPECT_EQ(result, 32); // 10 + 20 + 2 - } else if (cpu_supports(cpu_feature::SSE)) { - EXPECT_EQ(result, 31); // 10 + 20 + 1 - } else { - EXPECT_EQ(result, 30); // 10 + 20 - } + auto& dispatcher = simd_func_dispatcher::instance(); + + // 测试自动分发是否选择最佳版本 + const std::string func_name = "auto_dispatch_test"; + + // 只注册标量版本 + dispatcher.register_function( + func_name, simd_func_version::SCALAR, + [](int a, int b) { return a + b; }); + + // 根据当前系统支持,可能还会注册其他版本 + if (cpu_supports(cpu_feature::SSE)) { + dispatcher.register_function( + func_name, simd_func_version::SSE, + [](int a, int b) { return a + b + 1; }); // SSE版本加1标识 + } + + if (cpu_supports(cpu_feature::AVX)) { + dispatcher.register_function( + func_name, simd_func_version::AVX, + [](int a, int b) { return a + b + 2; }); // AVX版本加2标识 + } + + // 测试分发选择 + const auto& func = dispatcher.get_function(func_name); + int result = func(10, 20); + + // 验证选择了正确的版本 + if (cpu_supports(cpu_feature::AVX)) { + EXPECT_EQ(result, 32); // 10 + 20 + 2 + } + else if (cpu_supports(cpu_feature::SSE)) { + EXPECT_EQ(result, 31); // 10 + 20 + 1 + } + else { + EXPECT_EQ(result, 30); // 10 + 20 + } } TEST_F(simd_test, SimdDispatcherTest_PriorityBasedSelection) { - // 测试基于优先级的版本选择 - auto recommended_level = get_recommended_simd_level(); - auto expected_version = simd_level_to_version(recommended_level); - - std::cout << "Recommended SIMD level: " << static_cast(recommended_level) << std::endl; - std::cout << "Expected version: " << static_cast(expected_version) << std::endl; - - // 验证级别转换函数 - EXPECT_GE(static_cast(expected_version), static_cast(simd_func_version::SCALAR)); - EXPECT_LE(static_cast(expected_version), static_cast(simd_func_version::VECTOR)); - - // 测试转换一致性 - switch (recommended_level) { - case simd_level::NONE: - EXPECT_EQ(expected_version, simd_func_version::SCALAR); - break; - case simd_level::SSE: - EXPECT_EQ(expected_version, simd_func_version::SSE); - break; - case simd_level::AVX: - EXPECT_EQ(expected_version, simd_func_version::AVX); - break; - case simd_level::AVX2: - EXPECT_EQ(expected_version, simd_func_version::AVX2); - break; - default: - // 其他情况也应该有对应的版本 - break; - } + // 测试基于优先级的版本选择 + auto recommended_level = get_recommended_simd_level(); + auto expected_version = simd_level_to_version(recommended_level); + + std::cout << "Recommended SIMD level: " << static_cast(recommended_level) << std::endl; + std::cout << "Expected version: " << static_cast(expected_version) << std::endl; + + // 验证级别转换函数 + EXPECT_GE(static_cast(expected_version), static_cast(simd_func_version::SCALAR)); + EXPECT_LE(static_cast(expected_version), static_cast(simd_func_version::VECTOR)); + + // 测试转换一致性 + switch (recommended_level) { + case simd_level::NONE: + EXPECT_EQ(expected_version, simd_func_version::SCALAR); + break; + case simd_level::SSE: + EXPECT_EQ(expected_version, simd_func_version::SSE); + break; + case simd_level::AVX: + EXPECT_EQ(expected_version, simd_func_version::AVX); + break; + case simd_level::AVX2: + EXPECT_EQ(expected_version, simd_func_version::AVX2); + break; + default: + // 其他情况也应该有对应的版本 + break; + } } TEST_F(simd_test, SimdDispatcherTest_VersionFallback) { - auto& dispatcher = simd_func_dispatcher::instance(); - const std::string func_name = "fallback_test"; - - // 只注册标量版本,测试回退机制 - dispatcher.register_function( - func_name, simd_func_version::SCALAR, - [](double x) { return x * 2.0; }); - - // 即使系统支持更高级的SIMD,也应该回退到标量版本 - const auto& func = dispatcher.get_function(func_name); - double result = func(3.14); - EXPECT_DOUBLE_EQ(result, 6.28); - - // 现在注册一个高级版本 - if (cpu_supports(cpu_feature::AVX)) { - dispatcher.register_function( - func_name, simd_func_version::AVX, - [](double x) { return x * 3.0; }); // 不同的计算以验证选择了正确版本 - - // 重新获取函数,应该选择AVX版本 - const auto& avx_func = dispatcher.get_function(func_name); - double avx_result = avx_func(3.14); - EXPECT_DOUBLE_EQ(avx_result, 9.42); - } + auto& dispatcher = simd_func_dispatcher::instance(); + const std::string func_name = "fallback_test"; + + // 只注册标量版本,测试回退机制 + dispatcher.register_function( + func_name, simd_func_version::SCALAR, + [](double x) { return x * 2.0; }); + + // 即使系统支持更高级的SIMD,也应该回退到标量版本 + const auto& func = dispatcher.get_function(func_name); + double result = func(3.14); + EXPECT_DOUBLE_EQ(result, 6.28); + + // 现在注册一个高级版本 + if (cpu_supports(cpu_feature::AVX)) { + dispatcher.register_function( + func_name, simd_func_version::AVX, + [](double x) { return x * 3.0; }); // 不同的计算以验证选择了正确版本 + + // 重新获取函数,应该选择AVX版本 + const auto& avx_func = dispatcher.get_function(func_name); + double avx_result = avx_func(3.14); + EXPECT_DOUBLE_EQ(avx_result, 9.42); + } } // 宏接口测试 TEST_F(simd_test, SimdDispatcherTest_MacroInterface) { - // 测试注册宏 - EXPECT_NO_THROW({ - std::function square_func = [](int x) { return x * x; }; - REGISTER_SIMD_FUNCTION("macro_test", simd_func_version::SCALAR, square_func); - }); - - // 测试获取宏 - EXPECT_NO_THROW({ - const auto& func = GET_SIMD_FUNCTION(int(int), "macro_test"); - int result = func(5); - EXPECT_EQ(result, 25); - }); - - // 测试调用宏 - EXPECT_NO_THROW({ - int result = CALL_SIMD_FUNCTION(int(int), "macro_test", 6); - EXPECT_EQ(result, 36); - }); - - // 测试字符串转换函数 - EXPECT_STREQ(simd_func_version_to_string(simd_func_version::SCALAR), "SCALAR"); - EXPECT_STREQ(simd_func_version_to_string(simd_func_version::SSE), "SSE"); - EXPECT_STREQ(simd_func_version_to_string(simd_func_version::AVX), "AVX"); - - EXPECT_EQ(string_to_simd_func_version("SCALAR"), simd_func_version::SCALAR); - EXPECT_EQ(string_to_simd_func_version("SSE"), simd_func_version::SSE); - EXPECT_EQ(string_to_simd_func_version("AVX"), simd_func_version::AVX); - EXPECT_EQ(string_to_simd_func_version("INVALID"), simd_func_version::SCALAR); // 默认回退 + // 测试注册宏 + EXPECT_NO_THROW({ + std::function square_func = [](int x) { return x * x; }; + REGISTER_SIMD_FUNCTION("macro_test", simd_func_version::SCALAR, square_func); + }); + + // 测试获取宏 + EXPECT_NO_THROW({ + const auto& func = GET_SIMD_FUNCTION(int(int), "macro_test"); + int result = func(5); + EXPECT_EQ(result, 25); + }); + + // 测试调用宏 + EXPECT_NO_THROW({ + int result = CALL_SIMD_FUNCTION(int(int), "macro_test", 6); + EXPECT_EQ(result, 36); + }); + + // 测试字符串转换函数 + EXPECT_STREQ(simd_func_version_to_string(simd_func_version::SCALAR), "SCALAR"); + EXPECT_STREQ(simd_func_version_to_string(simd_func_version::SSE), "SSE"); + EXPECT_STREQ(simd_func_version_to_string(simd_func_version::AVX), "AVX"); + + EXPECT_EQ(string_to_simd_func_version("SCALAR"), simd_func_version::SCALAR); + EXPECT_EQ(string_to_simd_func_version("SSE"), simd_func_version::SSE); + EXPECT_EQ(string_to_simd_func_version("AVX"), simd_func_version::AVX); + EXPECT_EQ(string_to_simd_func_version("INVALID"), simd_func_version::SCALAR); // 默认回退 } TEST_F(simd_test, SimdDispatcherTest_TypeSafety) { - auto& dispatcher = simd_func_dispatcher::instance(); - - // 注册不同类型的函数 - dispatcher.register_function("int_func", simd_func_version::SCALAR, - [](int x) { return x + 1; }); - dispatcher.register_function("float_func", simd_func_version::SCALAR, - [](float x) { return x + 1.0f; }); - - // 类型安全检查 - EXPECT_NO_THROW({ - const auto& int_func = dispatcher.get_function("int_func"); - int result = int_func(42); - EXPECT_EQ(result, 43); - }); - - EXPECT_NO_THROW({ - const auto& float_func = dispatcher.get_function("float_func"); - float result = float_func(3.14f); - EXPECT_FLOAT_EQ(result, 4.14f); - }); - - // 尝试用不同的类型获取同名函数会创建独立的函数持有者 - EXPECT_NO_THROW({ - // 这会创建一个新的double类型函数持有者,与int类型的是分离的 - const auto& double_func = dispatcher.get_function("int_func"); - // 这验证了类型安全性 - 不同类型的函数是分离的 - }); + auto& dispatcher = simd_func_dispatcher::instance(); + + // 注册不同类型的函数 + dispatcher.register_function("int_func", simd_func_version::SCALAR, + [](int x) { return x + 1; }); + dispatcher.register_function("float_func", simd_func_version::SCALAR, + [](float x) { return x + 1.0f; }); + + // 类型安全检查 + EXPECT_NO_THROW({ + const auto& int_func = dispatcher.get_function("int_func"); + int result = int_func(42); + EXPECT_EQ(result, 43); + }); + + EXPECT_NO_THROW({ + const auto& float_func = dispatcher.get_function("float_func"); + float result = float_func(3.14f); + EXPECT_FLOAT_EQ(result, 4.14f); + }); + + // 尝试用不同的类型获取同名函数会创建独立的函数持有者 + EXPECT_NO_THROW({ + // 这会创建一个新的double类型函数持有者,与int类型的是分离的 + const auto& double_func = dispatcher.get_function("int_func"); + // 这验证了类型安全性 - 不同类型的函数是分离的 + }); } // 错误处理 TEST_F(simd_test, SimdDispatcherTest_InvalidRegistration) { - auto& dispatcher = simd_func_dispatcher::instance(); - - // 测试重复注册相同版本 - EXPECT_NO_THROW({ - dispatcher.register_function("duplicate_test", simd_func_version::SCALAR, - []() { return 1; }); - dispatcher.register_function("duplicate_test", simd_func_version::SCALAR, - []() { return 2; }); // 覆盖前一个 - }); - - // 验证最后注册的版本生效 - const auto& func = dispatcher.get_function("duplicate_test"); - int result = func(); - EXPECT_EQ(result, 2); + auto& dispatcher = simd_func_dispatcher::instance(); + + // 测试重复注册相同版本 + EXPECT_NO_THROW({ + dispatcher.register_function("duplicate_test", simd_func_version::SCALAR, + []() { return 1; }); + dispatcher.register_function("duplicate_test", simd_func_version::SCALAR, + []() { return 2; }); // 覆盖前一个 + }); + + // 验证最后注册的版本生效 + const auto& func = dispatcher.get_function("duplicate_test"); + int result = func(); + EXPECT_EQ(result, 2); } TEST_F(simd_test, SimdDispatcherTest_MissingFunction) { - auto& dispatcher = simd_func_dispatcher::instance(); - - // 尝试获取未注册的函数应该抛出异常 - EXPECT_THROW({ - const auto& missing_func = dispatcher.get_function("nonexistent_function"); - }, std::runtime_error); - - // 尝试调用未注册的函数 - EXPECT_THROW({ - CALL_SIMD_FUNCTION(void(), "another_nonexistent_function"); - }, std::runtime_error); + auto& dispatcher = simd_func_dispatcher::instance(); + + // 尝试获取未注册的函数应该抛出异常 + EXPECT_THROW({ + const auto& missing_func = dispatcher.get_function("nonexistent_function"); + }, std::runtime_error); + + // 尝试调用未注册的函数 + EXPECT_THROW({ + CALL_SIMD_FUNCTION(void(), "another_nonexistent_function"); + }, std::runtime_error); } // ============================================================================= @@ -659,278 +662,280 @@ TEST_F(simd_test, SimdDispatcherTest_MissingFunction) { // 基础分配测试 TEST_F(simd_test, AlignedAllocatorTest_BasicAllocation) { - // 测试基本的对齐分配 - constexpr size_t alignment = ALIGNMENT_AVX; // 32字节对齐 - constexpr size_t size = 1024; - - void* ptr = aligned_malloc(size, alignment); - ASSERT_NE(ptr, nullptr); - EXPECT_TRUE(simd_test_helpers::is_properly_aligned(ptr)); - - // 写入数据验证可用性 - auto* data = static_cast(ptr); - for (size_t i = 0; i < size; ++i) { - data[i] = static_cast(i % 256); - } - - // 验证数据 - for (size_t i = 0; i < size; ++i) { - EXPECT_EQ(data[i], static_cast(i % 256)); - } - - aligned_free(ptr); + // 测试基本的对齐分配 + constexpr size_t alignment = ALIGNMENT_AVX; // 32字节对齐 + constexpr size_t size = 1024; + + void* ptr = aligned_malloc(size, alignment); + ASSERT_NE(ptr, nullptr); + EXPECT_TRUE(simd_test_helpers::is_properly_aligned(ptr)); + + // 写入数据验证可用性 + auto* data = static_cast(ptr); + for (size_t i = 0; i < size; ++i) { + data[i] = static_cast(i % 256); + } + + // 验证数据 + for (size_t i = 0; i < size; ++i) { + EXPECT_EQ(data[i], static_cast(i % 256)); + } + + aligned_free(ptr); } TEST_F(simd_test, AlignedAllocatorTest_VariousAlignments) { - // 测试不同的对齐要求 - std::vector alignments = { - ALIGNMENT_SSE, // 16字节 - ALIGNMENT_AVX, // 32字节 - ALIGNMENT_AVX512, // 64字节 - ALIGNMENT_CACHE // 64字节(缓存行) - }; - - constexpr size_t size = 256; - - for (auto alignment : alignments) { - void* ptr = aligned_malloc(size, alignment); - ASSERT_NE(ptr, nullptr) << "Failed to allocate with alignment " << alignment; - - EXPECT_TRUE(is_aligned(ptr, alignment)) + // 测试不同的对齐要求 + std::vector alignments = { + ALIGNMENT_SSE, // 16字节 + ALIGNMENT_AVX, // 32字节 + ALIGNMENT_AVX512, // 64字节 + ALIGNMENT_CACHE // 64字节(缓存行) + }; + + constexpr size_t size = 256; + + for (auto alignment : alignments) { + void* ptr = aligned_malloc(size, alignment); + ASSERT_NE(ptr, nullptr) << "Failed to allocate with alignment " << alignment; + + EXPECT_TRUE(is_aligned(ptr, alignment)) << "Pointer not properly aligned to " << alignment << " bytes"; - - // 验证可以写入数据 - std::memset(ptr, 0xAB, size); - - aligned_free(ptr); - } + + // 验证可以写入数据 + std::memset(ptr, 0xAB, size); + + aligned_free(ptr); + } } TEST_F(simd_test, AlignedAllocatorTest_LargeAllocations) { - // 测试大块内存分配 - std::vector sizes = { - 1024, // 1KB - 1024 * 64, // 64KB - 1024 * 1024 // 1MB - }; - - constexpr size_t alignment = ALIGNMENT_AVX; - - for (auto size : sizes) { - void* ptr = aligned_malloc(size, alignment); - ASSERT_NE(ptr, nullptr) << "Failed to allocate " << size << " bytes"; - - EXPECT_TRUE(simd_test_helpers::is_properly_aligned(ptr)); - - // 简单的读写测试 - auto* data = static_cast(ptr); - data[0] = 0x12345678; - data[size/sizeof(int) - 1] = 0x87654321; - - EXPECT_EQ(data[0], 0x12345678); - EXPECT_EQ(data[size/sizeof(int) - 1], 0x87654321); - - aligned_free(ptr); - } + // 测试大块内存分配 + std::vector sizes = { + 1024, // 1KB + 1024 * 64, // 64KB + 1024 * 1024 // 1MB + }; + + constexpr size_t alignment = ALIGNMENT_AVX; + + for (auto size : sizes) { + void* ptr = aligned_malloc(size, alignment); + ASSERT_NE(ptr, nullptr) << "Failed to allocate " << size << " bytes"; + + EXPECT_TRUE(simd_test_helpers::is_properly_aligned(ptr)); + + // 简单的读写测试 + auto* data = static_cast(ptr); + data[0] = 0x12345678; + data[size / sizeof(int) - 1] = 0x87654321; + + EXPECT_EQ(data[0], 0x12345678); + EXPECT_EQ(data[size/sizeof(int) - 1], 0x87654321); + + aligned_free(ptr); + } } // STL兼容性 TEST_F(simd_test, AlignedAllocatorTest_StlContainerCompat) { - // 测试STL容器兼容性(需要修复aligned_allocator中的错误) - using aligned_vector = std::vector>; - - EXPECT_NO_THROW({ - aligned_vector vec; - vec.reserve(100); - - for (int i = 0; i < 50; ++i) { - vec.push_back(static_cast(i)); - } - - EXPECT_EQ(vec.size(), 50); - EXPECT_GE(vec.capacity(), 50); - - // 验证对齐 - if (!vec.empty()) { - EXPECT_TRUE(simd_test_helpers::is_properly_aligned(vec.data())); - } - }); + // 测试STL容器兼容性(需要修复aligned_allocator中的错误) + using aligned_vector = std::vector>; + + EXPECT_NO_THROW({ + aligned_vector vec; + vec.reserve(100); + + for (int i = 0; i < 50; ++i) { + vec.push_back(static_cast(i)); + } + + EXPECT_EQ(vec.size(), 50); + EXPECT_GE(vec.capacity(), 50); + + // 验证对齐 + if (!vec.empty()) { + EXPECT_TRUE(simd_test_helpers::is_properly_aligned(vec.data())); + } + }); } TEST_F(simd_test, AlignedAllocatorTest_VectorOperations) { - using sse_vector = std::vector>; - using avx_vector = std::vector>; - - // SSE对齐的vector - sse_vector sse_vec(100, 3.14); - EXPECT_EQ(sse_vec.size(), 100); - EXPECT_TRUE(simd_test_helpers::is_properly_aligned(sse_vec.data())); - - // AVX对齐的vector - avx_vector avx_vec(200, 2.71f); - EXPECT_EQ(avx_vec.size(), 200); - EXPECT_TRUE(simd_test_helpers::is_properly_aligned(avx_vec.data())); - - // 测试resize操作 - sse_vec.resize(200); - EXPECT_EQ(sse_vec.size(), 200); - if (!sse_vec.empty()) { - EXPECT_TRUE(simd_test_helpers::is_properly_aligned(sse_vec.data())); - } + using sse_vector = std::vector>; + using avx_vector = std::vector>; + + // SSE对齐的vector + sse_vector sse_vec(100, 3.14); + EXPECT_EQ(sse_vec.size(), 100); + EXPECT_TRUE(simd_test_helpers::is_properly_aligned(sse_vec.data())); + + // AVX对齐的vector + avx_vector avx_vec(200, 2.71f); + EXPECT_EQ(avx_vec.size(), 200); + EXPECT_TRUE(simd_test_helpers::is_properly_aligned(avx_vec.data())); + + // 测试resize操作 + sse_vec.resize(200); + EXPECT_EQ(sse_vec.size(), 200); + if (!sse_vec.empty()) { + EXPECT_TRUE(simd_test_helpers::is_properly_aligned(sse_vec.data())); + } } TEST_F(simd_test, AlignedAllocatorTest_MemoryManagement) { - using cache_vector = std::vector>; - - // 测试内存管理 - { - cache_vector vec(1000); - std::iota(vec.begin(), vec.end(), 0); - - EXPECT_TRUE(simd_test_helpers::is_properly_aligned(vec.data())); - - // 验证数据正确性 - for (size_t i = 0; i < vec.size(); ++i) { - EXPECT_EQ(vec[i], static_cast(i)); - } - } // vector销毁,测试析构函数 - - // 测试移动语义 - cache_vector vec1(100, 42); - auto vec1_data = vec1.data(); - - cache_vector vec2 = std::move(vec1); - EXPECT_EQ(vec2.size(), 100); - EXPECT_EQ(vec2.data(), vec1_data); // 移动后数据指针应该相同 - EXPECT_TRUE(vec1.empty() || vec1.data() != vec1_data); // vec1应该被清空或数据被移走 + using cache_vector = std::vector>; + + // 测试内存管理 + { + cache_vector vec(1000); + std::iota(vec.begin(), vec.end(), 0); + + EXPECT_TRUE(simd_test_helpers::is_properly_aligned(vec.data())); + + // 验证数据正确性 + for (size_t i = 0; i < vec.size(); ++i) { + EXPECT_EQ(vec[i], static_cast(i)); + } + } // vector销毁,测试析构函数 + + // 测试移动语义 + cache_vector vec1(100, 42); + auto vec1_data = vec1.data(); + + cache_vector vec2 = std::move(vec1); + EXPECT_EQ(vec2.size(), 100); + EXPECT_EQ(vec2.data(), vec1_data); // 移动后数据指针应该相同 + EXPECT_TRUE(vec1.empty() || vec1.data() != vec1_data); // vec1应该被清空或数据被移走 } // 跨平台行为 TEST_F(simd_test, AlignedAllocatorTest_PlatformConsistency) { - // 测试跨平台的一致行为 - constexpr size_t alignment = 32; - constexpr size_t size = 1024; - - std::vector ptrs; - - // 分配多个内存块 - for (int i = 0; i < 10; ++i) { - void* ptr = aligned_malloc(size, alignment); - ASSERT_NE(ptr, nullptr); - EXPECT_TRUE(is_aligned(ptr, alignment)); - ptrs.push_back(ptr); - } - - // 验证所有指针都正确对齐 - for (auto ptr : ptrs) { - EXPECT_TRUE(is_aligned(ptr, alignment)); - - // 写入特定模式 - auto* data = static_cast(ptr); - for (size_t j = 0; j < size / sizeof(uint32_t); ++j) { - data[j] = static_cast(j * 0x12345678); - } - } - - // 验证数据完整性 - for (size_t i = 0; i < ptrs.size(); ++i) { - auto* data = static_cast(ptrs[i]); - for (size_t j = 0; j < size / sizeof(uint32_t); ++j) { - EXPECT_EQ(data[j], static_cast(j * 0x12345678)) + // 测试跨平台的一致行为 + constexpr size_t alignment = 32; + constexpr size_t size = 1024; + + std::vector ptrs; + + // 分配多个内存块 + for (int i = 0; i < 10; ++i) { + void* ptr = aligned_malloc(size, alignment); + ASSERT_NE(ptr, nullptr); + EXPECT_TRUE(is_aligned(ptr, alignment)); + ptrs.push_back(ptr); + } + + // 验证所有指针都正确对齐 + for (auto ptr : ptrs) { + EXPECT_TRUE(is_aligned(ptr, alignment)); + + // 写入特定模式 + auto* data = static_cast(ptr); + for (size_t j = 0; j < size / sizeof(uint32_t); ++j) { + data[j] = static_cast(j * 0x12345678); + } + } + + // 验证数据完整性 + for (size_t i = 0; i < ptrs.size(); ++i) { + auto* data = static_cast(ptrs[i]); + for (size_t j = 0; j < size / sizeof(uint32_t); ++j) { + EXPECT_EQ(data[j], static_cast(j * 0x12345678)) << "Data corruption at ptr " << i << ", index " << j; - } - } - - // 释放所有内存 - for (auto ptr : ptrs) { - aligned_free(ptr); - } + } + } + + // 释放所有内存 + for (auto ptr : ptrs) { + aligned_free(ptr); + } } TEST_F(simd_test, AlignedAllocatorTest_AlignmentVerification) { - // 测试对齐验证函数 - std::vector test_alignments = {1, 2, 4, 8, 16, 32, 64, 128}; - - for (auto alignment : test_alignments) { - // 测试2的幂次对齐 - if ((alignment & (alignment - 1)) == 0) { // 是2的幂 - void* ptr = aligned_malloc(256, alignment); - ASSERT_NE(ptr, nullptr); - EXPECT_TRUE(is_aligned(ptr, alignment)); - aligned_free(ptr); - } else { - // 非2的幂次应该返回nullptr - void* ptr = aligned_malloc(256, alignment); - EXPECT_EQ(ptr, nullptr); - } - } - - // 测试边界情况 - EXPECT_EQ(aligned_malloc(100, 0), nullptr); // 0对齐应该失败 - - // 测试align_size函数 - EXPECT_EQ(align_size(15, 16), 16); - EXPECT_EQ(align_size(16, 16), 16); - EXPECT_EQ(align_size(17, 16), 32); - EXPECT_EQ(align_size(31, 32), 32); - EXPECT_EQ(align_size(33, 32), 64); + // 测试对齐验证函数 + std::vector test_alignments = {1, 2, 4, 8, 16, 32, 64, 128}; + + for (auto alignment : test_alignments) { + // 测试2的幂次对齐 + if ((alignment & (alignment - 1)) == 0) { + // 是2的幂 + void* ptr = aligned_malloc(256, alignment); + ASSERT_NE(ptr, nullptr); + EXPECT_TRUE(is_aligned(ptr, alignment)); + aligned_free(ptr); + } + else { + // 非2的幂次应该返回nullptr + void* ptr = aligned_malloc(256, alignment); + EXPECT_EQ(ptr, nullptr); + } + } + + // 测试边界情况 + EXPECT_EQ(aligned_malloc(100, 0), nullptr); // 0对齐应该失败 + + // 测试align_size函数 + EXPECT_EQ(align_size(15, 16), 16); + EXPECT_EQ(align_size(16, 16), 16); + EXPECT_EQ(align_size(17, 16), 32); + EXPECT_EQ(align_size(31, 32), 32); + EXPECT_EQ(align_size(33, 32), 64); } TEST_F(simd_test, AlignedAllocatorTest_PerformanceCharacteristics) { - // 简单的性能特征测试 - constexpr size_t num_allocations = 1000; - constexpr size_t allocation_size = 1024; - - // 测试对齐分配的性能 - simd_test_helpers::timer timer; - - std::vector aligned_ptrs; - aligned_ptrs.reserve(num_allocations); - - // 分配阶段 - for (size_t i = 0; i < num_allocations; ++i) { - void* ptr = aligned_malloc(allocation_size, ALIGNMENT_AVX); - ASSERT_NE(ptr, nullptr); - aligned_ptrs.push_back(ptr); - } - - double allocation_time = timer.elapsed_ms(); - - // 访问测试 - simd_test_helpers::timer access_timer; - uint64_t checksum = 0; - - for (auto ptr : aligned_ptrs) { - auto* data = static_cast(ptr); - checksum += data[0]; // 简单访问测试 - } - - double access_time = access_timer.elapsed_ms(); - - // 释放阶段 - simd_test_helpers::timer free_timer; - - for (auto ptr : aligned_ptrs) { - aligned_free(ptr); - } - - double free_time = free_timer.elapsed_ms(); - - // 性能报告 - std::cout << "Aligned allocation performance:" << std::endl; - std::cout << " Allocations: " << num_allocations << " x " << allocation_size << " bytes" << std::endl; - std::cout << " Allocation time: " << allocation_time << " ms" << std::endl; - std::cout << " Access time: " << access_time << " ms" << std::endl; - std::cout << " Free time: " << free_time << " ms" << std::endl; - std::cout << " Avg allocation time: " << (allocation_time / num_allocations) << " ms" << std::endl; - - // 基本合理性检查 - EXPECT_GT(allocation_time, 0.0); - EXPECT_GT(access_time, 0.0); - EXPECT_GT(free_time, 0.0); - - // 避免编译器优化掉checksum计算 - EXPECT_GE(checksum, 0); // checksum可能为0,但应该不会是负数 + // 简单的性能特征测试 + constexpr size_t num_allocations = 1000; + constexpr size_t allocation_size = 1024; + + // 测试对齐分配的性能 + simd_test_helpers::timer timer; + + std::vector aligned_ptrs; + aligned_ptrs.reserve(num_allocations); + + // 分配阶段 + for (size_t i = 0; i < num_allocations; ++i) { + void* ptr = aligned_malloc(allocation_size, ALIGNMENT_AVX); + ASSERT_NE(ptr, nullptr); + aligned_ptrs.push_back(ptr); + } + + double allocation_time = timer.elapsed_ms(); + + // 访问测试 + simd_test_helpers::timer access_timer; + uint64_t checksum = 0; + + for (auto ptr : aligned_ptrs) { + auto* data = static_cast(ptr); + checksum += data[0]; // 简单访问测试 + } + + double access_time = access_timer.elapsed_ms(); + + // 释放阶段 + simd_test_helpers::timer free_timer; + + for (auto ptr : aligned_ptrs) { + aligned_free(ptr); + } + + double free_time = free_timer.elapsed_ms(); + + // 性能报告 + std::cout << "Aligned allocation performance:" << std::endl; + std::cout << " Allocations: " << num_allocations << " x " << allocation_size << " bytes" << std::endl; + std::cout << " Allocation time: " << allocation_time << " ms" << std::endl; + std::cout << " Access time: " << access_time << " ms" << std::endl; + std::cout << " Free time: " << free_time << " ms" << std::endl; + std::cout << " Avg allocation time: " << (allocation_time / num_allocations) << " ms" << std::endl; + + // 基本合理性检查 + EXPECT_GT(allocation_time, 0.0); + EXPECT_GT(access_time, 0.0); + EXPECT_GT(free_time, 0.0); + + // 避免编译器优化掉checksum计算 + EXPECT_GE(checksum, 0); // checksum可能为0,但应该不会是负数 } // ============================================================================= @@ -939,287 +944,287 @@ TEST_F(simd_test, AlignedAllocatorTest_PerformanceCharacteristics) { // 端到端集成测试 TEST_F(simd_test, SimdIntegrationTest_FullWorkflow) { - // 完整的SIMD工作流程测试:检测 -> 分发 -> 分配 -> 执行 - - // 1. CPU特性检测 - auto max_level = get_max_simd_level(); - auto recommended_level = get_recommended_simd_level(); - - std::cout << "Integration test - SIMD levels: max=" << static_cast(max_level) - << ", recommended=" << static_cast(recommended_level) << std::endl; - - // 2. 注册多版本函数 - auto& dispatcher = simd_func_dispatcher::instance(); - const std::string func_name = "integration_vector_sum"; - - // 使用对齐分配器的向量进行计算 - using aligned_float_vector = std::vector>; - - // 注册标量版本 - dispatcher.register_function( - func_name, simd_func_version::SCALAR, - [](const aligned_float_vector& vec) -> float { - float sum = 0.0f; - for (const auto& val : vec) { - sum += val; - } - return sum; - }); - - // 根据支持的特性注册优化版本 - if (cpu_supports(cpu_feature::SSE)) { - dispatcher.register_function( - func_name, simd_func_version::SSE, - [](const aligned_float_vector& vec) -> float { - // 模拟SSE优化(实际实现会使用SSE指令) - float sum = 0.0f; - for (const auto& val : vec) { - sum += val; - } - return sum * 1.001f; // 添加小的标识以区分版本 - }); - } - - if (cpu_supports(cpu_feature::AVX)) { - dispatcher.register_function( - func_name, simd_func_version::AVX, - [](const aligned_float_vector& vec) -> float { - // 模拟AVX优化 - float sum = 0.0f; - for (const auto& val : vec) { - sum += val; - } - return sum * 1.002f; // AVX版本标识 - }); - } - - // 3. 创建测试数据(使用对齐分配) - aligned_float_vector test_data(10000); - std::iota(test_data.begin(), test_data.end(), 1.0f); - - // 验证数据对齐 - EXPECT_TRUE(simd_test_helpers::is_properly_aligned(test_data.data())); - - // 4. 执行计算 - const auto& func = dispatcher.get_function(func_name); - float result = func(test_data); - - // 5. 验证结果 - float expected_base = 10000.0f * 10001.0f / 2.0f; // 等差数列求和 - EXPECT_GT(result, expected_base * 0.99f); // 允许一定的误差和版本差异 - EXPECT_LT(result, expected_base * 1.01f); - - std::cout << "Integration test result: " << result << " (expected ~" << expected_base << ")" << std::endl; + // 完整的SIMD工作流程测试:检测 -> 分发 -> 分配 -> 执行 + + // 1. CPU特性检测 + auto max_level = get_max_simd_level(); + auto recommended_level = get_recommended_simd_level(); + + std::cout << "Integration test - SIMD levels: max=" << static_cast(max_level) + << ", recommended=" << static_cast(recommended_level) << std::endl; + + // 2. 注册多版本函数 + auto& dispatcher = simd_func_dispatcher::instance(); + const std::string func_name = "integration_vector_sum"; + + // 使用对齐分配器的向量进行计算 + using aligned_float_vector = std::vector>; + + // 注册标量版本 + dispatcher.register_function( + func_name, simd_func_version::SCALAR, + [](const aligned_float_vector& vec) -> float { + float sum = 0.0f; + for (const auto& val : vec) { + sum += val; + } + return sum; + }); + + // 根据支持的特性注册优化版本 + if (cpu_supports(cpu_feature::SSE)) { + dispatcher.register_function( + func_name, simd_func_version::SSE, + [](const aligned_float_vector& vec) -> float { + // 模拟SSE优化(实际实现会使用SSE指令) + float sum = 0.0f; + for (const auto& val : vec) { + sum += val; + } + return sum * 1.001f; // 添加小的标识以区分版本 + }); + } + + if (cpu_supports(cpu_feature::AVX)) { + dispatcher.register_function( + func_name, simd_func_version::AVX, + [](const aligned_float_vector& vec) -> float { + // 模拟AVX优化 + float sum = 0.0f; + for (const auto& val : vec) { + sum += val; + } + return sum * 1.002f; // AVX版本标识 + }); + } + + // 3. 创建测试数据(使用对齐分配) + aligned_float_vector test_data(10000); + std::iota(test_data.begin(), test_data.end(), 1.0f); + + // 验证数据对齐 + EXPECT_TRUE(simd_test_helpers::is_properly_aligned(test_data.data())); + + // 4. 执行计算 + const auto& func = dispatcher.get_function(func_name); + float result = func(test_data); + + // 5. 验证结果 + float expected_base = 10000.0f * 10001.0f / 2.0f; // 等差数列求和 + EXPECT_GT(result, expected_base * 0.99f); // 允许一定的误差和版本差异 + EXPECT_LT(result, expected_base * 1.01f); + + std::cout << "Integration test result: " << result << " (expected ~" << expected_base << ")" << std::endl; } TEST_F(simd_test, SimdIntegrationTest_RealWorldScenarios) { - // 真实世界场景测试:图像处理、数值计算等 - - // 场景1:向量点积计算 - const size_t vector_size = 1024; - using aligned_vector = std::vector>; - - aligned_vector vec_a(vector_size), vec_b(vector_size); - - // 初始化向量 - for (size_t i = 0; i < vector_size; ++i) { - vec_a[i] = static_cast(i + 1); - vec_b[i] = static_cast((i + 1) * 2); - } - - // 注册点积函数 - auto& dispatcher = simd_func_dispatcher::instance(); - const std::string dot_product_name = "dot_product"; - - dispatcher.register_function( - dot_product_name, simd_func_version::SCALAR, - [](const aligned_vector& a, const aligned_vector& b) -> float { - float result = 0.0f; - for (size_t i = 0; i < a.size(); ++i) { - result += a[i] * b[i]; - } - return result; - }); - - // 执行点积计算 - float dot_result = CALL_SIMD_FUNCTION(float(const aligned_vector&, const aligned_vector&), - dot_product_name, vec_a, vec_b); - - // 验证结果(数学验证) - float expected = 0.0f; - for (size_t i = 0; i < vector_size; ++i) { - expected += vec_a[i] * vec_b[i]; - } - EXPECT_FLOAT_EQ(dot_result, expected); - - // 场景2:矩阵转置(简化版) - const size_t matrix_size = 64; // 64x64矩阵 - aligned_vector matrix(matrix_size * matrix_size); - aligned_vector transposed(matrix_size * matrix_size); - - // 初始化矩阵 - for (size_t i = 0; i < matrix_size; ++i) { - for (size_t j = 0; j < matrix_size; ++j) { - matrix[i * matrix_size + j] = static_cast(i * matrix_size + j); - } - } - - // 矩阵转置 - const std::string transpose_name = "matrix_transpose"; - dispatcher.register_function( - transpose_name, simd_func_version::SCALAR, - [](const aligned_vector& src, aligned_vector& dst, size_t size) { - for (size_t i = 0; i < size; ++i) { - for (size_t j = 0; j < size; ++j) { - dst[j * size + i] = src[i * size + j]; - } - } - }); - - CALL_SIMD_FUNCTION(void(const aligned_vector&, aligned_vector&, size_t), - transpose_name, matrix, transposed, matrix_size); - - // 验证转置结果 - for (size_t i = 0; i < matrix_size; ++i) { - for (size_t j = 0; j < matrix_size; ++j) { - EXPECT_FLOAT_EQ(transposed[j * matrix_size + i], matrix[i * matrix_size + j]); - } - } - - std::cout << "Real-world scenarios test completed successfully" << std::endl; + // 真实世界场景测试:图像处理、数值计算等 + + // 场景1:向量点积计算 + const size_t vector_size = 1024; + using aligned_vector = std::vector>; + + aligned_vector vec_a(vector_size), vec_b(vector_size); + + // 初始化向量 + for (size_t i = 0; i < vector_size; ++i) { + vec_a[i] = static_cast(i + 1); + vec_b[i] = static_cast((i + 1) * 2); + } + + // 注册点积函数 + auto& dispatcher = simd_func_dispatcher::instance(); + const std::string dot_product_name = "dot_product"; + + dispatcher.register_function( + dot_product_name, simd_func_version::SCALAR, + [](const aligned_vector& a, const aligned_vector& b) -> float { + float result = 0.0f; + for (size_t i = 0; i < a.size(); ++i) { + result += a[i] * b[i]; + } + return result; + }); + + // 执行点积计算 + float dot_result = CALL_SIMD_FUNCTION(float(const aligned_vector&, const aligned_vector&), + dot_product_name, vec_a, vec_b); + + // 验证结果(数学验证) + float expected = 0.0f; + for (size_t i = 0; i < vector_size; ++i) { + expected += vec_a[i] * vec_b[i]; + } + EXPECT_FLOAT_EQ(dot_result, expected); + + // 场景2:矩阵转置(简化版) + const size_t matrix_size = 64; // 64x64矩阵 + aligned_vector matrix(matrix_size * matrix_size); + aligned_vector transposed(matrix_size * matrix_size); + + // 初始化矩阵 + for (size_t i = 0; i < matrix_size; ++i) { + for (size_t j = 0; j < matrix_size; ++j) { + matrix[i * matrix_size + j] = static_cast(i * matrix_size + j); + } + } + + // 矩阵转置 + const std::string transpose_name = "matrix_transpose"; + dispatcher.register_function( + transpose_name, simd_func_version::SCALAR, + [](const aligned_vector& src, aligned_vector& dst, size_t size) { + for (size_t i = 0; i < size; ++i) { + for (size_t j = 0; j < size; ++j) { + dst[j * size + i] = src[i * size + j]; + } + } + }); + + CALL_SIMD_FUNCTION(void(const aligned_vector&, aligned_vector&, size_t), + transpose_name, matrix, transposed, matrix_size); + + // 验证转置结果 + for (size_t i = 0; i < matrix_size; ++i) { + for (size_t j = 0; j < matrix_size; ++j) { + EXPECT_FLOAT_EQ(transposed[j * matrix_size + i], matrix[i * matrix_size + j]); + } + } + + std::cout << "Real-world scenarios test completed successfully" << std::endl; } // 性能基准测试 TEST_F(simd_test, SimdPerformanceTest_AllocationSpeed) { - // 对齐分配性能基准测试 - - struct BenchmarkConfig { - size_t allocation_size; - size_t alignment; - size_t num_iterations; - std::string name; - }; - - std::vector configs = { - {1024, ALIGNMENT_SSE, 10000, "SSE-1KB"}, - {1024, ALIGNMENT_AVX, 10000, "AVX-1KB"}, - {1024, ALIGNMENT_AVX512, 10000, "AVX512-1KB"}, - {4096, ALIGNMENT_AVX, 5000, "AVX-4KB"}, - {16384, ALIGNMENT_AVX, 2000, "AVX-16KB"}, - {65536, ALIGNMENT_AVX, 1000, "AVX-64KB"} - }; - - std::cout << "\nAllocation Speed Benchmark:" << std::endl; - std::cout << "Config\t\tAlloc(ms)\tFree(ms)\tTotal(ms)" << std::endl; - - for (const auto& config : configs) { - std::vector ptrs; - ptrs.reserve(config.num_iterations); - - // 分配基准 - simd_test_helpers::timer alloc_timer; - for (size_t i = 0; i < config.num_iterations; ++i) { - void* ptr = aligned_malloc(config.allocation_size, config.alignment); - ASSERT_NE(ptr, nullptr); - ptrs.push_back(ptr); - } - double alloc_time = alloc_timer.elapsed_ms(); - - // 释放基准 - simd_test_helpers::timer free_timer; - for (auto ptr : ptrs) { - aligned_free(ptr); - } - double free_time = free_timer.elapsed_ms(); - - double total_time = alloc_time + free_time; - - std::cout << config.name << "\t\t" - << std::fixed << std::setprecision(2) - << alloc_time << "\t\t" - << free_time << "\t\t" - << total_time << std::endl; - - // 基本性能断言 - EXPECT_GT(alloc_time, 0.0); - EXPECT_GT(free_time, 0.0); - EXPECT_LT(alloc_time / config.num_iterations, 1.0); // 平均每次分配应该小于1ms - } + // 对齐分配性能基准测试 + + struct BenchmarkConfig { + size_t allocation_size; + size_t alignment; + size_t num_iterations; + std::string name; + }; + + std::vector configs = { + {1024, ALIGNMENT_SSE, 10000, "SSE-1KB"}, + {1024, ALIGNMENT_AVX, 10000, "AVX-1KB"}, + {1024, ALIGNMENT_AVX512, 10000, "AVX512-1KB"}, + {4096, ALIGNMENT_AVX, 5000, "AVX-4KB"}, + {16384, ALIGNMENT_AVX, 2000, "AVX-16KB"}, + {65536, ALIGNMENT_AVX, 1000, "AVX-64KB"} + }; + + std::cout << "\nAllocation Speed Benchmark:" << std::endl; + std::cout << "Config\t\tAlloc(ms)\tFree(ms)\tTotal(ms)" << std::endl; + + for (const auto& config : configs) { + std::vector ptrs; + ptrs.reserve(config.num_iterations); + + // 分配基准 + simd_test_helpers::timer alloc_timer; + for (size_t i = 0; i < config.num_iterations; ++i) { + void* ptr = aligned_malloc(config.allocation_size, config.alignment); + ASSERT_NE(ptr, nullptr); + ptrs.push_back(ptr); + } + double alloc_time = alloc_timer.elapsed_ms(); + + // 释放基准 + simd_test_helpers::timer free_timer; + for (auto ptr : ptrs) { + aligned_free(ptr); + } + double free_time = free_timer.elapsed_ms(); + + double total_time = alloc_time + free_time; + + std::cout << config.name << "\t\t" + << std::fixed << std::setprecision(2) + << alloc_time << "\t\t" + << free_time << "\t\t" + << total_time << std::endl; + + // 基本性能断言 + EXPECT_GT(alloc_time, 0.0); + EXPECT_GT(free_time, 0.0); + EXPECT_LT(alloc_time / config.num_iterations, 1.0); // 平均每次分配应该小于1ms + } } TEST_F(simd_test, SimdPerformanceTest_DispatchOverhead) { - // 函数分发开销基准测试 - - auto& dispatcher = simd_func_dispatcher::instance(); - const std::string bench_func_name = "dispatch_overhead_test"; - - // 注册一个简单的测试函数 - dispatcher.register_function( - bench_func_name, simd_func_version::SCALAR, - [](int x) { return x + 1; }); - - if (cpu_supports(cpu_feature::SSE)) { - dispatcher.register_function( - bench_func_name, simd_func_version::SSE, - [](int x) { return x + 2; }); - } - - const size_t num_calls = 1000000; // 100万次调用 - - // 基准1:直接函数调用 - auto direct_func = [](int x) { return x + 1; }; - - simd_test_helpers::timer direct_timer; - volatile int direct_result = 0; // volatile防止优化 - for (size_t i = 0; i < num_calls; ++i) { - direct_result += direct_func(static_cast(i)); - } - double direct_time = direct_timer.elapsed_ms(); - - // 基准2:通过分发器调用 - const auto& dispatched_func = dispatcher.get_function(bench_func_name); - - simd_test_helpers::timer dispatch_timer; - volatile int dispatch_result = 0; - for (size_t i = 0; i < num_calls; ++i) { - dispatch_result += dispatched_func(static_cast(i)); - } - double dispatch_time = dispatch_timer.elapsed_ms(); - - // 基准3:通过宏调用 - simd_test_helpers::timer macro_timer; - volatile int macro_result = 0; - for (size_t i = 0; i < num_calls; ++i) { - macro_result += CALL_SIMD_FUNCTION(int(int), bench_func_name, static_cast(i)); - } - double macro_time = macro_timer.elapsed_ms(); - - // 结果报告 - std::cout << "\nDispatch Overhead Benchmark (" << num_calls << " calls):" << std::endl; - std::cout << "Direct function: " << direct_time << " ms" << std::endl; - std::cout << "Dispatched function: " << dispatch_time << " ms" << std::endl; - std::cout << "Macro call: " << macro_time << " ms" << std::endl; - - double dispatch_overhead = (dispatch_time - direct_time) / direct_time * 100.0; - double macro_overhead = (macro_time - direct_time) / direct_time * 100.0; - - std::cout << "Dispatch overhead: " << std::fixed << std::setprecision(2) - << dispatch_overhead << "%" << std::endl; - std::cout << "Macro overhead: " << macro_overhead << "%" << std::endl; - - // 性能断言 - EXPECT_GT(direct_time, 0.0); - EXPECT_GT(dispatch_time, 0.0); - EXPECT_GT(macro_time, 0.0); - - // 分发开销应该在合理范围内(调整为更现实的阈值) - EXPECT_LT(dispatch_overhead, 1000.0); // 允许10倍开销 - EXPECT_LT(macro_overhead, 10000.0); // 宏调用开销更大 - - // 验证结果正确性(防止编译器优化掉计算) - EXPECT_GT(direct_result, 0); - EXPECT_GT(dispatch_result, 0); - EXPECT_GT(macro_result, 0); + // 函数分发开销基准测试 + + auto& dispatcher = simd_func_dispatcher::instance(); + const std::string bench_func_name = "dispatch_overhead_test"; + + // 注册一个简单的测试函数 + dispatcher.register_function( + bench_func_name, simd_func_version::SCALAR, + [](int x) { return x + 1; }); + + if (cpu_supports(cpu_feature::SSE)) { + dispatcher.register_function( + bench_func_name, simd_func_version::SSE, + [](int x) { return x + 2; }); + } + + const size_t num_calls = 1000000; // 100万次调用 + + // 基准1:直接函数调用 + auto direct_func = [](int x) { return x + 1; }; + + simd_test_helpers::timer direct_timer; + volatile int direct_result = 0; // volatile防止优化 + for (size_t i = 0; i < num_calls; ++i) { + direct_result += direct_func(static_cast(i)); + } + double direct_time = direct_timer.elapsed_ms(); + + // 基准2:通过分发器调用 + const auto& dispatched_func = dispatcher.get_function(bench_func_name); + + simd_test_helpers::timer dispatch_timer; + volatile int dispatch_result = 0; + for (size_t i = 0; i < num_calls; ++i) { + dispatch_result += dispatched_func(static_cast(i)); + } + double dispatch_time = dispatch_timer.elapsed_ms(); + + // 基准3:通过宏调用 + simd_test_helpers::timer macro_timer; + volatile int macro_result = 0; + for (size_t i = 0; i < num_calls; ++i) { + macro_result += CALL_SIMD_FUNCTION(int(int), bench_func_name, static_cast(i)); + } + double macro_time = macro_timer.elapsed_ms(); + + // 结果报告 + std::cout << "\nDispatch Overhead Benchmark (" << num_calls << " calls):" << std::endl; + std::cout << "Direct function: " << direct_time << " ms" << std::endl; + std::cout << "Dispatched function: " << dispatch_time << " ms" << std::endl; + std::cout << "Macro call: " << macro_time << " ms" << std::endl; + + double dispatch_overhead = (dispatch_time - direct_time) / direct_time * 100.0; + double macro_overhead = (macro_time - direct_time) / direct_time * 100.0; + + std::cout << "Dispatch overhead: " << std::fixed << std::setprecision(2) + << dispatch_overhead << "%" << std::endl; + std::cout << "Macro overhead: " << macro_overhead << "%" << std::endl; + + // 性能断言 + EXPECT_GT(direct_time, 0.0); + EXPECT_GT(dispatch_time, 0.0); + EXPECT_GT(macro_time, 0.0); + + // 分发开销应该在合理范围内(调整为更现实的阈值) + EXPECT_LT(dispatch_overhead, 1000.0); // 允许10倍开销 + EXPECT_LT(macro_overhead, 10000.0); // 宏调用开销更大 + + // 验证结果正确性(防止编译器优化掉计算) + EXPECT_GT(direct_result, 0); + EXPECT_GT(dispatch_result, 0); + EXPECT_GT(macro_result, 0); } // ============================================================================= @@ -1229,25 +1234,25 @@ TEST_F(simd_test, SimdPerformanceTest_DispatchOverhead) { // 在测试开始前打印系统信息 class SimdTestEnvironment : public ::testing::Environment { public: - void SetUp() override { - std::cout << "\n" << std::string(60, '=') << std::endl; - std::cout << "SIMD Test Suite - System Information" << std::endl; - std::cout << std::string(60, '=') << std::endl; - - cpu_feature_detector::instance().print_info(); - - std::cout << std::string(60, '=') << std::endl; - std::cout << "Starting SIMD tests..." << std::endl; - std::cout << std::string(60, '=') << std::endl; - } - - void TearDown() override { - std::cout << std::string(60, '=') << std::endl; - std::cout << "SIMD Test Suite completed." << std::endl; - std::cout << std::string(60, '=') << std::endl; - } + void SetUp() override { + std::cout << "\n" << std::string(60, '=') << std::endl; + std::cout << "SIMD Test Suite - System Information" << std::endl; + std::cout << std::string(60, '=') << std::endl; + + cpu_feature_detector::instance().print_info(); + + std::cout << std::string(60, '=') << std::endl; + std::cout << "Starting SIMD tests..." << std::endl; + std::cout << std::string(60, '=') << std::endl; + } + + void TearDown() override { + std::cout << std::string(60, '=') << std::endl; + std::cout << "SIMD Test Suite completed." << std::endl; + std::cout << std::string(60, '=') << std::endl; + } }; // 注册测试环境 -static ::testing::Environment* const simd_test_env = - ::testing::AddGlobalTestEnvironment(new SimdTestEnvironment); +static ::testing::Environment* const simd_test_env = + ::testing::AddGlobalTestEnvironment(new SimdTestEnvironment);