Files
Alicho/tests/unit/communication/thread_safety_test.cpp
2025-10-28 10:27:49 +08:00

540 lines
18 KiB
C++

// ================================================================================================
// Audio Backend - 通信组件线程安全性测试
// ================================================================================================
#include <gtest/gtest.h>
#include <gmock/gmock.h>
#include "communication/communication.h"
#include "tests/common/test_fixtures.h"
#include <thread>
#include <mutex>
#include <condition_variable>
#include <atomic>
#include <vector>
#include <random>
#include <chrono>
using namespace audio_backend;
using namespace audio_backend::communication;
using namespace std::chrono_literals;
// 线程安全测试消息类
class ThreadSafetyMessage : public Message {
public:
ThreadSafetyMessage(int id = 0)
: Message("ThreadSafetyMessage"), id_(id) {}
int id() const { return id_; }
void set_id(int id) { id_ = id; }
// 实现虚函数
size_t estimated_size() const override { return sizeof(*this); }
Priority priority() const override { return Priority::Normal; }
TransportChannel preferred_channel() const override { return TransportChannel::ZeroMQ; }
private:
int id_ = 0;
};
// 通信线程安全测试固定装置
class ThreadSafetyTest : public test::CommunicationTest {
protected:
void SetUp() override {
test::CommunicationTest::SetUp();
// 创建通信管理器配置
server_config_.process_name = "thread_safety_server";
server_config_.routing_strategy = RoutingStrategy::Auto;
server_config_.enable_zmq = true;
server_config_.enable_shm = true;
// ZeroMQ配置
ZmqConfig zmq_config;
zmq_config.endpoint = "tcp://127.0.0.1:5558";
zmq_config.socket_type = ZMQ_REP;
zmq_config.bind_instead_of_connect = true;
server_config_.zmq_configs.push_back(zmq_config);
// 共享内存配置
server_config_.shm_config.segment_name = "thread_safety_test";
server_config_.shm_config.segment_size = 1024 * 1024;
server_config_.shm_config.create_if_not_exists = true;
server_config_.shm_config.remove_on_destroy = true;
// 创建客户端配置
client_config_ = server_config_;
client_config_.process_name = "thread_safety_client";
// 修改客户端ZeroMQ配置
client_config_.zmq_configs[0].socket_type = ZMQ_REQ;
client_config_.zmq_configs[0].bind_instead_of_connect = false;
// 客户端不应该创建或销毁共享内存
client_config_.shm_config.create_if_not_exists = false;
client_config_.shm_config.remove_on_destroy = false;
// 注册消息类型
message_factory_.register_message<ThreadSafetyMessage>("ThreadSafetyMessage");
}
void TearDown() override {
// 清理资源
server_.reset();
clients_.clear();
test::CommunicationTest::TearDown();
}
// 创建服务器
void create_server() {
server_ = std::make_unique<CommunicationManager>(server_config_, message_factory_);
ASSERT_EQ(server_->initialize(), CommError::Success);
}
// 创建客户端
std::shared_ptr<CommunicationManager> create_client() {
auto client = std::make_shared<CommunicationManager>(client_config_, message_factory_);
ASSERT_EQ(client->initialize(), CommError::Success);
clients_.push_back(client);
return client;
}
protected:
CommunicationConfig server_config_;
CommunicationConfig client_config_;
MessageFactory message_factory_;
std::unique_ptr<CommunicationManager> server_;
std::vector<std::shared_ptr<CommunicationManager>> clients_;
};
// 测试多线程消息接收
TEST_F(ThreadSafetyTest, ConcurrentMessageHandling) {
// 创建服务器
create_server();
// 设置接收计数器
std::atomic<int> messages_received{0};
std::mutex mtx;
std::condition_variable cv;
// 设置服务器消息处理器
server_->register_message_handler("ThreadSafetyMessage", [&](std::unique_ptr<IMessage> message) {
auto* typed_msg = dynamic_cast<ThreadSafetyMessage*>(message.get());
if (typed_msg) {
// 模拟处理时间(增加并发可能性)
std::this_thread::sleep_for(std::chrono::milliseconds(typed_msg->id() % 10));
// 创建并发送响应
auto response = std::make_unique<ThreadSafetyMessage>(typed_msg->id());
server_->send_message(*response);
// 更新计数器
messages_received++;
// 通知等待线程
cv.notify_all();
}
});
// 创建多个客户端和线程
const int num_clients = 10;
const int messages_per_client = 100;
std::atomic<int> responses_received{0};
// 启动多个客户端线程
std::vector<std::thread> client_threads;
for (int i = 0; i < num_clients; ++i) {
client_threads.emplace_back([this, i, messages_per_client, &responses_received]() {
// 创建客户端
auto client = create_client();
// 注册响应处理器
client->register_message_handler("ThreadSafetyMessage", [&](std::unique_ptr<IMessage> message) {
auto* typed_msg = dynamic_cast<ThreadSafetyMessage*>(message.get());
if (typed_msg) {
responses_received++;
}
});
// 随机数生成器,使消息发送更随机
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> delay_dist(1, 20);
// 发送多个消息
for (int j = 0; j < messages_per_client; ++j) {
// 创建消息
auto msg = std::make_unique<ThreadSafetyMessage>(i * 1000 + j);
// 发送消息
EXPECT_EQ(client->send_message(*msg), CommError::Success);
// 随机延迟,增加并发性
std::this_thread::sleep_for(std::chrono::milliseconds(delay_dist(gen)));
}
});
}
// 等待所有线程完成
for (auto& t : client_threads) {
t.join();
}
// 等待所有消息处理完成
{
std::unique_lock<std::mutex> lock(mtx);
EXPECT_TRUE(cv.wait_for(lock, 5s, [&] {
return messages_received.load() >= num_clients * messages_per_client;
}));
}
// 验证所有消息都被正确处理
EXPECT_EQ(messages_received.load(), num_clients * messages_per_client);
EXPECT_EQ(responses_received.load(), num_clients * messages_per_client);
// 检查服务器统计信息
const auto& stats = server_->get_statistics();
EXPECT_EQ(stats.total_messages_received.load(), num_clients * messages_per_client);
EXPECT_EQ(stats.total_messages_sent.load(), num_clients * messages_per_client);
}
// 测试多线程共享内存访问
TEST_F(ThreadSafetyTest, ConcurrentSharedMemoryAccess) {
// 初始化共享内存配置
ShmConfig config;
config.segment_name = "concurrent_shm_test";
config.segment_size = 1024 * 1024; // 1MB
config.create_if_not_exists = true;
config.remove_on_destroy = true;
// 创建共享内存管理器
auto shm_manager = std::make_unique<SharedMemoryManager>(config);
ASSERT_EQ(shm_manager->initialize(), ShmError::Success);
// 创建并发访问的环形缓冲区
const size_t buffer_capacity = 1024;
RingBuffer<int> ring_buffer(*shm_manager, "concurrent_buffer", buffer_capacity);
// 创建并发度
const int num_producers = 5;
const int num_consumers = 5;
const int items_per_producer = 1000;
// 生产者线程函数
auto producer_func = [&](int producer_id) {
for (int i = 0; i < items_per_producer; ++i) {
int value = producer_id * items_per_producer + i;
// 尝试写入直到成功
while (!ring_buffer.write(&value, 1)) {
std::this_thread::yield();
}
}
};
// 消费者线程函数
std::atomic<int> total_consumed{0};
std::vector<std::set<int>> consumed_values(num_consumers);
std::mutex consumed_mutex;
auto consumer_func = [&](int consumer_id) {
while (total_consumed.load() < num_producers * items_per_producer) {
int value;
if (ring_buffer.read(&value, 1)) {
{
std::lock_guard<std::mutex> lock(consumed_mutex);
consumed_values[consumer_id].insert(value);
}
total_consumed++;
} else {
std::this_thread::yield();
}
}
};
// 启动生产者线程
std::vector<std::thread> producer_threads;
for (int i = 0; i < num_producers; ++i) {
producer_threads.emplace_back(producer_func, i);
}
// 启动消费者线程
std::vector<std::thread> consumer_threads;
for (int i = 0; i < num_consumers; ++i) {
consumer_threads.emplace_back(consumer_func, i);
}
// 等待所有生产者完成
for (auto& t : producer_threads) {
t.join();
}
// 等待所有消费者完成
for (auto& t : consumer_threads) {
t.join();
}
// 验证所有项目都被消费
EXPECT_EQ(total_consumed.load(), num_producers * items_per_producer);
// 合并所有消费者的集合
std::set<int> all_consumed;
for (const auto& set : consumed_values) {
all_consumed.insert(set.begin(), set.end());
}
// 确认每个项目只被消费一次
EXPECT_EQ(all_consumed.size(), num_producers * items_per_producer);
// 验证生产的每个项目都被消费了
for (int p = 0; p < num_producers; ++p) {
for (int i = 0; i < items_per_producer; ++i) {
int value = p * items_per_producer + i;
EXPECT_TRUE(all_consumed.find(value) != all_consumed.end());
}
}
// 清理
shm_manager->shutdown();
}
// 测试三缓冲区线程安全性
TEST_F(ThreadSafetyTest, TripleBufferThreadSafety) {
// 初始化共享内存配置
ShmConfig config;
config.segment_name = "triple_buffer_test";
config.segment_size = 1024 * 1024; // 1MB
config.create_if_not_exists = true;
config.remove_on_destroy = true;
// 创建共享内存管理器
auto shm_manager = std::make_unique<SharedMemoryManager>(config);
ASSERT_EQ(shm_manager->initialize(), ShmError::Success);
// 创建三缓冲区
using AudioFrame = std::array<float, 512>;
TripleBuffer<AudioFrame> triple_buffer(*shm_manager, "audio_frames");
// 设置运行时间
const auto test_duration = 500ms;
// 设置线程同步变量
std::atomic<bool> running{true};
std::atomic<int> frames_written{0};
std::atomic<int> frames_read{0};
std::atomic<bool> data_corrupted{false};
// 生产者线程:持续写入递增的样本值
std::thread producer([&]() {
int frame_count = 0;
while (running.load()) {
// 获取写入缓冲区
auto* write_buffer = triple_buffer.get_write_buffer();
if (!write_buffer) continue;
// 写入递增的样本值
float base_value = static_cast<float>(frame_count) / 1000.0f;
for (size_t i = 0; i < write_buffer->size(); ++i) {
(*write_buffer)[i] = base_value + static_cast<float>(i) / 10000.0f;
}
// 提交写入
triple_buffer.commit_write();
frames_written++;
frame_count++;
// 模拟实际音频处理速率
std::this_thread::sleep_for(1ms);
}
});
// 消费者线程:读取并验证数据
std::thread consumer([&]() {
while (running.load()) {
// 如果有新数据
if (triple_buffer.has_new_data()) {
// 获取读取缓冲区
const auto* read_buffer = triple_buffer.get_read_buffer();
if (!read_buffer) continue;
// 验证数据连续性
float expected_base = -1.0f;
bool first_sample = true;
for (size_t i = 0; i < read_buffer->size(); ++i) {
float current = (*read_buffer)[i];
// 设置首个样本的基准值
if (first_sample) {
expected_base = current - static_cast<float>(i) / 10000.0f;
first_sample = false;
} else {
// 验证当前样本是否符合预期
float expected = expected_base + static_cast<float>(i) / 10000.0f;
if (std::abs(current - expected) > 0.0001f) {
// 数据不一致,可能是多线程访问冲突导致
data_corrupted = true;
}
}
}
// 提交读取
triple_buffer.commit_read();
frames_read++;
} else {
// 无新数据时让出CPU
std::this_thread::yield();
}
}
});
// 运行一段时间
std::this_thread::sleep_for(test_duration);
running = false;
// 等待线程结束
producer.join();
consumer.join();
// 验证结果
EXPECT_GT(frames_written.load(), 0);
EXPECT_GT(frames_read.load(), 0);
EXPECT_FALSE(data_corrupted.load()) << "三缓冲区数据访问冲突,可能存在线程安全问题";
// 清理
shm_manager->shutdown();
}
// 测试通信管理器统计信息的线程安全性
TEST_F(ThreadSafetyTest, CommunicationStatisticsThreadSafety) {
// 创建服务器
create_server();
// 创建统计信息压力测试线程
const int num_threads = 10;
const int updates_per_thread = 1000;
std::atomic<bool> running{true};
// 启动多个线程同时更新统计信息
std::vector<std::thread> threads;
for (int i = 0; i < num_threads; ++i) {
threads.emplace_back([&, i]() {
for (int j = 0; j < updates_per_thread && running.load(); ++j) {
// 随机选择一种统计信息更新
switch (j % 4) {
case 0:
server_->notify_message_sent("TestMessage", 100, "ZeroMQ");
break;
case 1:
server_->notify_message_received("TestMessage", 200, "ZeroMQ");
break;
case 2:
server_->get_statistics(); // 读取统计信息
break;
case 3:
if (j % 100 == 0) { // 偶尔重置统计信息
server_->reset_statistics();
}
break;
}
// 随机延迟
if (j % 10 == 0) {
std::this_thread::yield();
}
}
});
}
// 运行一段时间
std::this_thread::sleep_for(1s);
running = false;
// 等待所有线程完成
for (auto& t : threads) {
t.join();
}
// 检查最终统计信息
const auto& stats = server_->get_statistics();
// 由于重置操作,我们不能确切知道数字,但至少应该有些消息
EXPECT_GE(stats.total_messages_sent.load() + stats.total_messages_received.load(), 0);
// 检查统计信息的一致性
EXPECT_GE(stats.total_bytes_sent.load(), stats.total_messages_sent.load() * 100);
EXPECT_GE(stats.total_bytes_received.load(), stats.total_messages_received.load() * 200);
}
// 测试消息路由表的线程安全性
TEST_F(ThreadSafetyTest, RouteTableThreadSafety) {
// 创建服务器
create_server();
// 创建并发修改路由表的线程
const int num_threads = 5;
const int operations_per_thread = 500;
std::atomic<bool> running{true};
// 启动多个线程同时修改路由表
std::vector<std::thread> threads;
for (int i = 0; i < num_threads; ++i) {
threads.emplace_back([&, i]() {
std::random_device rd;
std::mt19937 gen(rd());
std::uniform_int_distribution<> op_dist(0, 2); // 添加、移除、获取操作
for (int j = 0; j < operations_per_thread && running.load(); ++j) {
// 构造唯一的消息类型
std::string message_type = "TestMessage" + std::to_string(i) + "_" + std::to_string(j);
// 随机选择操作
int operation = op_dist(gen);
switch (operation) {
case 0: { // 添加路由
MessageRoute route;
route.message_type = message_type;
route.destination = "tcp://localhost:5555";
route.strategy = RoutingStrategy::ZeroMQOnly;
server_->add_route(route);
break;
}
case 1: // 移除路由
server_->remove_route(message_type);
break;
case 2: // 获取所有路由
server_->get_routes();
break;
}
// 偶尔延迟
if (j % 10 == 0) {
std::this_thread::yield();
}
}
});
}
// 运行一段时间
std::this_thread::sleep_for(1s);
running = false;
// 等待所有线程完成
for (auto& t : threads) {
t.join();
}
// 获取最终路由表
auto routes = server_->get_routes();
// 检查路由表是否包含有效的路由
for (const auto& route : routes) {
EXPECT_FALSE(route.message_type.empty());
EXPECT_FALSE(route.destination.empty());
}
}
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}