#include #include "simd_func_dispatcher.h" #include "simd_interface.h" #include "lib_handle.h" #include "cpu_features.h" #include "aligned_allocator.h" #include "simd_api.h" #include #include // 定义测试用的函数指针类型 using simd_func_ptr = void (*)(const float*, const float*, float*, size_t); /** * @brief 测试 SimdFuncDispatcher 是否能根据CPU特性自动加载动态库 * * 这个测试通过模拟 SimdFuncDispatcher 的核心加载逻辑来验证其功能。 * * 测试步骤: * 1. 检测当前CPU支持的最佳SIMD级别。 * 2. 根据SIMD级别和操作系统平台,构造出预期的动态库文件名。 * (这模拟了 SimdFuncDispatcher 构造函数中的逻辑) * 3. 使用 lib_handle 手动加载这个动态库。 * 4. 从加载的库中获取 "add" 和 "subtract" 函数的指针。 * 5. 验证获取到的函数指针是否有效(非空)。 * 6. 准备测试数据并调用函数指针,验证其功能正确性。 * * 这个测试间接验证了: * - SimdFuncDispatcher 的库选择逻辑是正确的。 * - 对应于当前CPU的SIMD动态库是存在的并且可以被加载。 * - 库中导出了正确的 "add" 和 "subtract" 函数。 */ TEST(SimdFuncDispatcherTest, ShouldLoadCorrectLibraryBasedOnCpuFeatures) { const auto& detector = cpu_feature_detector::instance(); std::string lib_name; // 模拟 SimdFuncDispatcher 的库选择逻辑 if (detector.supports(cpu_feature::AVX512F)) { #if ALICHO_PLATFORM_WINDOWS lib_name = "alicho_simd_avx512.dll"; #elif ALICHO_PLATFORM_LINUX lib_name = "./libalicho_simd_avx512.so"; #elif ALICHO_PLATFORM_APPLE lib_name = "./libalicho_simd_avx512.dylib"; #endif } else if (detector.supports(cpu_feature::AVX)) { #if ALICHO_PLATFORM_WINDOWS lib_name = "alicho_simd_avx.dll"; #elif ALICHO_PLATFORM_LINUX lib_name = "./libalicho_simd_avx.so"; #elif ALICHO_PLATFORM_APPLE lib_name = "./libalicho_simd_avx.dylib"; #endif } else if (detector.supports(cpu_feature::SSE)) { #if ALICHO_PLATFORM_WINDOWS lib_name = "alicho_simd_sse.dll"; #elif ALICHO_PLATFORM_LINUX lib_name = "./libalicho_simd_sse.so"; #elif ALICHO_PLATFORM_APPLE lib_name = "./libalicho_simd_sse.dylib"; #endif } else { #if ALICHO_PLATFORM_WINDOWS lib_name = "alicho_simd_scaler.dll"; #elif ALICHO_PLATFORM_LINUX lib_name = "./libalicho_simd_scaler.so"; #elif ALICHO_PLATFORM_APPLE lib_name = "./libalicho_simd_scaler.dylib"; #endif } ASSERT_FALSE(lib_name.empty()) << "Could not determine the SIMD library name for the current CPU."; lib_handle handle; ASSERT_TRUE(handle.open(lib_name)) << "Failed to open SIMD library: " << lib_name; auto fill_func = get_function_by_func_signature(handle, fill_buffer); auto mix_func = get_function_by_func_signature(handle, mix_audio); ASSERT_NE(fill_func, nullptr) << "Failed to load 'fill_buffer' function from " << lib_name; ASSERT_NE(mix_func, nullptr) << "Failed to load 'mix_audio' function from " << lib_name; // 准备测试数据 constexpr size_t num_samples = 1024; std::vector src1(num_samples); std::vector src2(num_samples); std::vector dst_fill(num_samples, 0.0f); std::vector dst_mix(num_samples, 0.0f); std::iota(src1.begin(), src1.end(), 0.0f); std::iota(src2.begin(), src2.end(), static_cast(num_samples)); // 调用加载的函数 fill_func(dst_fill.data(), 1.0f, num_samples); // 初始化为0 mix_func(src1.data(), src2.data(), dst_mix.data(), num_samples); // 执行加法 // 验证 fill_buffer 函数 for (size_t i = 0; i < num_samples; ++i) { ASSERT_EQ(dst_fill[i], 1.0f) << "fill_buffer function did not work correctly at index " << i; } // 验证 mix_audio 函数 for (size_t i = 0; i < num_samples; ++i) { ASSERT_EQ(dst_mix[i], src1[i] + src2[i]) << "mix_audio function did not work correctly at index " << i; } } // ==================================================================== // 测试 simd_func_dispatcher 和 simd_api 的功能 // ==================================================================== TEST(SimdDispatcher, AutoVersionSelection) { // 测试自动版本选择 auto& dispatcher = simd_func_dispatcher::instance(); auto version = dispatcher.get_active_version(); // 版本应该不是 COUNT(无效值) EXPECT_NE(version, simd_func_version::COUNT); // 所有函数指针都应该非空 EXPECT_NE(dispatcher.get_fill_buffer(), nullptr); EXPECT_NE(dispatcher.get_mix_audio(), nullptr); EXPECT_NE(dispatcher.get_apply_gain(), nullptr); EXPECT_NE(dispatcher.get_calculate_rms(), nullptr); EXPECT_NE(dispatcher.get_calculate_peak(), nullptr); // 打印当前使用的版本 auto version_str = simd::get_active_simd_version_string(); std::cout << "当前SIMD版本: " << version_str << std::endl; } TEST(SimdAPI, FillBuffer) { // 测试 fill_buffer API - 使用对齐的分配器 std::vector buffer(1024, 0.0f); simd::fill_buffer(buffer.data(), 1.0f, buffer.size()); // 验证所有元素都被填充为 1.0 for (size_t i = 0; i < buffer.size(); ++i) { EXPECT_FLOAT_EQ(buffer[i], 1.0f) << "索引 " << i << " 的值不正确"; } } TEST(SimdAPI, CalculateRMS) { // 测试 calculate_rms API - 使用对齐的分配器 std::vector buffer(1024, 1.0f); float rms = simd::calculate_rms(buffer.data(), buffer.size()); // 所有值为1.0的RMS应该是1.0 EXPECT_NEAR(rms, 1.0f, 0.001f); } TEST(SimdAPI, MixAudio) { // 测试 mix_audio API - 使用对齐的分配器 std::vector src1(1024, 1.0f); std::vector src2(1024, 2.0f); std::vector dst(1024, 0.0f); simd::mix_audio(src1.data(), src2.data(), dst.data(), dst.size()); // 验证混音结果 (1.0 + 2.0 = 3.0) for (size_t i = 0; i < dst.size(); ++i) { EXPECT_FLOAT_EQ(dst[i], 3.0f) << "索引 " << i << " 的混音结果不正确"; } } TEST(SimdAPI, ApplyGain) { // 测试 apply_gain API - 使用对齐的分配器 std::vector src(1024, 2.0f); std::vector dst(1024, 0.0f); simd::apply_gain(src.data(), dst.data(), 0.5f, dst.size()); // 验证增益应用 (2.0 * 0.5 = 1.0) for (size_t i = 0; i < dst.size(); ++i) { EXPECT_FLOAT_EQ(dst[i], 1.0f) << "索引 " << i << " 的增益结果不正确"; } }