540 lines
18 KiB
C++
540 lines
18 KiB
C++
// ================================================================================================
|
|
// Audio Backend - 通信组件线程安全性测试
|
|
// ================================================================================================
|
|
|
|
#include <gtest/gtest.h>
|
|
#include <gmock/gmock.h>
|
|
#include "communication/communication.h"
|
|
#include "tests/common/test_fixtures.h"
|
|
#include <thread>
|
|
#include <mutex>
|
|
#include <condition_variable>
|
|
#include <atomic>
|
|
#include <vector>
|
|
#include <random>
|
|
#include <chrono>
|
|
|
|
using namespace audio_backend;
|
|
using namespace audio_backend::communication;
|
|
using namespace std::chrono_literals;
|
|
|
|
// 线程安全测试消息类
|
|
class ThreadSafetyMessage : public Message {
|
|
public:
|
|
ThreadSafetyMessage(int id = 0)
|
|
: Message("ThreadSafetyMessage"), id_(id) {}
|
|
|
|
int id() const { return id_; }
|
|
void set_id(int id) { id_ = id; }
|
|
|
|
// 实现虚函数
|
|
size_t estimated_size() const override { return sizeof(*this); }
|
|
Priority priority() const override { return Priority::Normal; }
|
|
TransportChannel preferred_channel() const override { return TransportChannel::ZeroMQ; }
|
|
|
|
private:
|
|
int id_ = 0;
|
|
};
|
|
|
|
// 通信线程安全测试固定装置
|
|
class ThreadSafetyTest : public test::CommunicationTest {
|
|
protected:
|
|
void SetUp() override {
|
|
test::CommunicationTest::SetUp();
|
|
|
|
// 创建通信管理器配置
|
|
server_config_.process_name = "thread_safety_server";
|
|
server_config_.routing_strategy = RoutingStrategy::Auto;
|
|
server_config_.enable_zmq = true;
|
|
server_config_.enable_shm = true;
|
|
|
|
// ZeroMQ配置
|
|
ZmqConfig zmq_config;
|
|
zmq_config.endpoint = "tcp://127.0.0.1:5558";
|
|
zmq_config.socket_type = ZMQ_REP;
|
|
zmq_config.bind_instead_of_connect = true;
|
|
server_config_.zmq_configs.push_back(zmq_config);
|
|
|
|
// 共享内存配置
|
|
server_config_.shm_config.segment_name = "thread_safety_test";
|
|
server_config_.shm_config.segment_size = 1024 * 1024;
|
|
server_config_.shm_config.create_if_not_exists = true;
|
|
server_config_.shm_config.remove_on_destroy = true;
|
|
|
|
// 创建客户端配置
|
|
client_config_ = server_config_;
|
|
client_config_.process_name = "thread_safety_client";
|
|
|
|
// 修改客户端ZeroMQ配置
|
|
client_config_.zmq_configs[0].socket_type = ZMQ_REQ;
|
|
client_config_.zmq_configs[0].bind_instead_of_connect = false;
|
|
|
|
// 客户端不应该创建或销毁共享内存
|
|
client_config_.shm_config.create_if_not_exists = false;
|
|
client_config_.shm_config.remove_on_destroy = false;
|
|
|
|
// 注册消息类型
|
|
message_factory_.register_message<ThreadSafetyMessage>("ThreadSafetyMessage");
|
|
}
|
|
|
|
void TearDown() override {
|
|
// 清理资源
|
|
server_.reset();
|
|
clients_.clear();
|
|
|
|
test::CommunicationTest::TearDown();
|
|
}
|
|
|
|
// 创建服务器
|
|
void create_server() {
|
|
server_ = std::make_unique<CommunicationManager>(server_config_, message_factory_);
|
|
ASSERT_EQ(server_->initialize(), CommError::Success);
|
|
}
|
|
|
|
// 创建客户端
|
|
std::shared_ptr<CommunicationManager> create_client() {
|
|
auto client = std::make_shared<CommunicationManager>(client_config_, message_factory_);
|
|
ASSERT_EQ(client->initialize(), CommError::Success);
|
|
clients_.push_back(client);
|
|
return client;
|
|
}
|
|
|
|
protected:
|
|
CommunicationConfig server_config_;
|
|
CommunicationConfig client_config_;
|
|
MessageFactory message_factory_;
|
|
std::unique_ptr<CommunicationManager> server_;
|
|
std::vector<std::shared_ptr<CommunicationManager>> clients_;
|
|
};
|
|
|
|
// 测试多线程消息接收
|
|
TEST_F(ThreadSafetyTest, ConcurrentMessageHandling) {
|
|
// 创建服务器
|
|
create_server();
|
|
|
|
// 设置接收计数器
|
|
std::atomic<int> messages_received{0};
|
|
std::mutex mtx;
|
|
std::condition_variable cv;
|
|
|
|
// 设置服务器消息处理器
|
|
server_->register_message_handler("ThreadSafetyMessage", [&](std::unique_ptr<IMessage> message) {
|
|
auto* typed_msg = dynamic_cast<ThreadSafetyMessage*>(message.get());
|
|
if (typed_msg) {
|
|
// 模拟处理时间(增加并发可能性)
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(typed_msg->id() % 10));
|
|
|
|
// 创建并发送响应
|
|
auto response = std::make_unique<ThreadSafetyMessage>(typed_msg->id());
|
|
server_->send_message(*response);
|
|
|
|
// 更新计数器
|
|
messages_received++;
|
|
|
|
// 通知等待线程
|
|
cv.notify_all();
|
|
}
|
|
});
|
|
|
|
// 创建多个客户端和线程
|
|
const int num_clients = 10;
|
|
const int messages_per_client = 100;
|
|
std::atomic<int> responses_received{0};
|
|
|
|
// 启动多个客户端线程
|
|
std::vector<std::thread> client_threads;
|
|
for (int i = 0; i < num_clients; ++i) {
|
|
client_threads.emplace_back([this, i, messages_per_client, &responses_received]() {
|
|
// 创建客户端
|
|
auto client = create_client();
|
|
|
|
// 注册响应处理器
|
|
client->register_message_handler("ThreadSafetyMessage", [&](std::unique_ptr<IMessage> message) {
|
|
auto* typed_msg = dynamic_cast<ThreadSafetyMessage*>(message.get());
|
|
if (typed_msg) {
|
|
responses_received++;
|
|
}
|
|
});
|
|
|
|
// 随机数生成器,使消息发送更随机
|
|
std::random_device rd;
|
|
std::mt19937 gen(rd());
|
|
std::uniform_int_distribution<> delay_dist(1, 20);
|
|
|
|
// 发送多个消息
|
|
for (int j = 0; j < messages_per_client; ++j) {
|
|
// 创建消息
|
|
auto msg = std::make_unique<ThreadSafetyMessage>(i * 1000 + j);
|
|
|
|
// 发送消息
|
|
EXPECT_EQ(client->send_message(*msg), CommError::Success);
|
|
|
|
// 随机延迟,增加并发性
|
|
std::this_thread::sleep_for(std::chrono::milliseconds(delay_dist(gen)));
|
|
}
|
|
});
|
|
}
|
|
|
|
// 等待所有线程完成
|
|
for (auto& t : client_threads) {
|
|
t.join();
|
|
}
|
|
|
|
// 等待所有消息处理完成
|
|
{
|
|
std::unique_lock<std::mutex> lock(mtx);
|
|
EXPECT_TRUE(cv.wait_for(lock, 5s, [&] {
|
|
return messages_received.load() >= num_clients * messages_per_client;
|
|
}));
|
|
}
|
|
|
|
// 验证所有消息都被正确处理
|
|
EXPECT_EQ(messages_received.load(), num_clients * messages_per_client);
|
|
EXPECT_EQ(responses_received.load(), num_clients * messages_per_client);
|
|
|
|
// 检查服务器统计信息
|
|
const auto& stats = server_->get_statistics();
|
|
EXPECT_EQ(stats.total_messages_received.load(), num_clients * messages_per_client);
|
|
EXPECT_EQ(stats.total_messages_sent.load(), num_clients * messages_per_client);
|
|
}
|
|
|
|
// 测试多线程共享内存访问
|
|
TEST_F(ThreadSafetyTest, ConcurrentSharedMemoryAccess) {
|
|
// 初始化共享内存配置
|
|
ShmConfig config;
|
|
config.segment_name = "concurrent_shm_test";
|
|
config.segment_size = 1024 * 1024; // 1MB
|
|
config.create_if_not_exists = true;
|
|
config.remove_on_destroy = true;
|
|
|
|
// 创建共享内存管理器
|
|
auto shm_manager = std::make_unique<SharedMemoryManager>(config);
|
|
ASSERT_EQ(shm_manager->initialize(), ShmError::Success);
|
|
|
|
// 创建并发访问的环形缓冲区
|
|
const size_t buffer_capacity = 1024;
|
|
RingBuffer<int> ring_buffer(*shm_manager, "concurrent_buffer", buffer_capacity);
|
|
|
|
// 创建并发度
|
|
const int num_producers = 5;
|
|
const int num_consumers = 5;
|
|
const int items_per_producer = 1000;
|
|
|
|
// 生产者线程函数
|
|
auto producer_func = [&](int producer_id) {
|
|
for (int i = 0; i < items_per_producer; ++i) {
|
|
int value = producer_id * items_per_producer + i;
|
|
// 尝试写入直到成功
|
|
while (!ring_buffer.write(&value, 1)) {
|
|
std::this_thread::yield();
|
|
}
|
|
}
|
|
};
|
|
|
|
// 消费者线程函数
|
|
std::atomic<int> total_consumed{0};
|
|
std::vector<std::set<int>> consumed_values(num_consumers);
|
|
std::mutex consumed_mutex;
|
|
|
|
auto consumer_func = [&](int consumer_id) {
|
|
while (total_consumed.load() < num_producers * items_per_producer) {
|
|
int value;
|
|
if (ring_buffer.read(&value, 1)) {
|
|
{
|
|
std::lock_guard<std::mutex> lock(consumed_mutex);
|
|
consumed_values[consumer_id].insert(value);
|
|
}
|
|
total_consumed++;
|
|
} else {
|
|
std::this_thread::yield();
|
|
}
|
|
}
|
|
};
|
|
|
|
// 启动生产者线程
|
|
std::vector<std::thread> producer_threads;
|
|
for (int i = 0; i < num_producers; ++i) {
|
|
producer_threads.emplace_back(producer_func, i);
|
|
}
|
|
|
|
// 启动消费者线程
|
|
std::vector<std::thread> consumer_threads;
|
|
for (int i = 0; i < num_consumers; ++i) {
|
|
consumer_threads.emplace_back(consumer_func, i);
|
|
}
|
|
|
|
// 等待所有生产者完成
|
|
for (auto& t : producer_threads) {
|
|
t.join();
|
|
}
|
|
|
|
// 等待所有消费者完成
|
|
for (auto& t : consumer_threads) {
|
|
t.join();
|
|
}
|
|
|
|
// 验证所有项目都被消费
|
|
EXPECT_EQ(total_consumed.load(), num_producers * items_per_producer);
|
|
|
|
// 合并所有消费者的集合
|
|
std::set<int> all_consumed;
|
|
for (const auto& set : consumed_values) {
|
|
all_consumed.insert(set.begin(), set.end());
|
|
}
|
|
|
|
// 确认每个项目只被消费一次
|
|
EXPECT_EQ(all_consumed.size(), num_producers * items_per_producer);
|
|
|
|
// 验证生产的每个项目都被消费了
|
|
for (int p = 0; p < num_producers; ++p) {
|
|
for (int i = 0; i < items_per_producer; ++i) {
|
|
int value = p * items_per_producer + i;
|
|
EXPECT_TRUE(all_consumed.find(value) != all_consumed.end());
|
|
}
|
|
}
|
|
|
|
// 清理
|
|
shm_manager->shutdown();
|
|
}
|
|
|
|
// 测试三缓冲区线程安全性
|
|
TEST_F(ThreadSafetyTest, TripleBufferThreadSafety) {
|
|
// 初始化共享内存配置
|
|
ShmConfig config;
|
|
config.segment_name = "triple_buffer_test";
|
|
config.segment_size = 1024 * 1024; // 1MB
|
|
config.create_if_not_exists = true;
|
|
config.remove_on_destroy = true;
|
|
|
|
// 创建共享内存管理器
|
|
auto shm_manager = std::make_unique<SharedMemoryManager>(config);
|
|
ASSERT_EQ(shm_manager->initialize(), ShmError::Success);
|
|
|
|
// 创建三缓冲区
|
|
using AudioFrame = std::array<float, 512>;
|
|
TripleBuffer<AudioFrame> triple_buffer(*shm_manager, "audio_frames");
|
|
|
|
// 设置运行时间
|
|
const auto test_duration = 500ms;
|
|
|
|
// 设置线程同步变量
|
|
std::atomic<bool> running{true};
|
|
std::atomic<int> frames_written{0};
|
|
std::atomic<int> frames_read{0};
|
|
std::atomic<bool> data_corrupted{false};
|
|
|
|
// 生产者线程:持续写入递增的样本值
|
|
std::thread producer([&]() {
|
|
int frame_count = 0;
|
|
|
|
while (running.load()) {
|
|
// 获取写入缓冲区
|
|
auto* write_buffer = triple_buffer.get_write_buffer();
|
|
if (!write_buffer) continue;
|
|
|
|
// 写入递增的样本值
|
|
float base_value = static_cast<float>(frame_count) / 1000.0f;
|
|
for (size_t i = 0; i < write_buffer->size(); ++i) {
|
|
(*write_buffer)[i] = base_value + static_cast<float>(i) / 10000.0f;
|
|
}
|
|
|
|
// 提交写入
|
|
triple_buffer.commit_write();
|
|
frames_written++;
|
|
frame_count++;
|
|
|
|
// 模拟实际音频处理速率
|
|
std::this_thread::sleep_for(1ms);
|
|
}
|
|
});
|
|
|
|
// 消费者线程:读取并验证数据
|
|
std::thread consumer([&]() {
|
|
while (running.load()) {
|
|
// 如果有新数据
|
|
if (triple_buffer.has_new_data()) {
|
|
// 获取读取缓冲区
|
|
const auto* read_buffer = triple_buffer.get_read_buffer();
|
|
if (!read_buffer) continue;
|
|
|
|
// 验证数据连续性
|
|
float expected_base = -1.0f;
|
|
bool first_sample = true;
|
|
|
|
for (size_t i = 0; i < read_buffer->size(); ++i) {
|
|
float current = (*read_buffer)[i];
|
|
|
|
// 设置首个样本的基准值
|
|
if (first_sample) {
|
|
expected_base = current - static_cast<float>(i) / 10000.0f;
|
|
first_sample = false;
|
|
} else {
|
|
// 验证当前样本是否符合预期
|
|
float expected = expected_base + static_cast<float>(i) / 10000.0f;
|
|
if (std::abs(current - expected) > 0.0001f) {
|
|
// 数据不一致,可能是多线程访问冲突导致
|
|
data_corrupted = true;
|
|
}
|
|
}
|
|
}
|
|
|
|
// 提交读取
|
|
triple_buffer.commit_read();
|
|
frames_read++;
|
|
} else {
|
|
// 无新数据时让出CPU
|
|
std::this_thread::yield();
|
|
}
|
|
}
|
|
});
|
|
|
|
// 运行一段时间
|
|
std::this_thread::sleep_for(test_duration);
|
|
running = false;
|
|
|
|
// 等待线程结束
|
|
producer.join();
|
|
consumer.join();
|
|
|
|
// 验证结果
|
|
EXPECT_GT(frames_written.load(), 0);
|
|
EXPECT_GT(frames_read.load(), 0);
|
|
EXPECT_FALSE(data_corrupted.load()) << "三缓冲区数据访问冲突,可能存在线程安全问题";
|
|
|
|
// 清理
|
|
shm_manager->shutdown();
|
|
}
|
|
|
|
// 测试通信管理器统计信息的线程安全性
|
|
TEST_F(ThreadSafetyTest, CommunicationStatisticsThreadSafety) {
|
|
// 创建服务器
|
|
create_server();
|
|
|
|
// 创建统计信息压力测试线程
|
|
const int num_threads = 10;
|
|
const int updates_per_thread = 1000;
|
|
std::atomic<bool> running{true};
|
|
|
|
// 启动多个线程同时更新统计信息
|
|
std::vector<std::thread> threads;
|
|
for (int i = 0; i < num_threads; ++i) {
|
|
threads.emplace_back([&, i]() {
|
|
for (int j = 0; j < updates_per_thread && running.load(); ++j) {
|
|
// 随机选择一种统计信息更新
|
|
switch (j % 4) {
|
|
case 0:
|
|
server_->notify_message_sent("TestMessage", 100, "ZeroMQ");
|
|
break;
|
|
case 1:
|
|
server_->notify_message_received("TestMessage", 200, "ZeroMQ");
|
|
break;
|
|
case 2:
|
|
server_->get_statistics(); // 读取统计信息
|
|
break;
|
|
case 3:
|
|
if (j % 100 == 0) { // 偶尔重置统计信息
|
|
server_->reset_statistics();
|
|
}
|
|
break;
|
|
}
|
|
|
|
// 随机延迟
|
|
if (j % 10 == 0) {
|
|
std::this_thread::yield();
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
// 运行一段时间
|
|
std::this_thread::sleep_for(1s);
|
|
running = false;
|
|
|
|
// 等待所有线程完成
|
|
for (auto& t : threads) {
|
|
t.join();
|
|
}
|
|
|
|
// 检查最终统计信息
|
|
const auto& stats = server_->get_statistics();
|
|
|
|
// 由于重置操作,我们不能确切知道数字,但至少应该有些消息
|
|
EXPECT_GE(stats.total_messages_sent.load() + stats.total_messages_received.load(), 0);
|
|
|
|
// 检查统计信息的一致性
|
|
EXPECT_GE(stats.total_bytes_sent.load(), stats.total_messages_sent.load() * 100);
|
|
EXPECT_GE(stats.total_bytes_received.load(), stats.total_messages_received.load() * 200);
|
|
}
|
|
|
|
// 测试消息路由表的线程安全性
|
|
TEST_F(ThreadSafetyTest, RouteTableThreadSafety) {
|
|
// 创建服务器
|
|
create_server();
|
|
|
|
// 创建并发修改路由表的线程
|
|
const int num_threads = 5;
|
|
const int operations_per_thread = 500;
|
|
std::atomic<bool> running{true};
|
|
|
|
// 启动多个线程同时修改路由表
|
|
std::vector<std::thread> threads;
|
|
for (int i = 0; i < num_threads; ++i) {
|
|
threads.emplace_back([&, i]() {
|
|
std::random_device rd;
|
|
std::mt19937 gen(rd());
|
|
std::uniform_int_distribution<> op_dist(0, 2); // 添加、移除、获取操作
|
|
|
|
for (int j = 0; j < operations_per_thread && running.load(); ++j) {
|
|
// 构造唯一的消息类型
|
|
std::string message_type = "TestMessage" + std::to_string(i) + "_" + std::to_string(j);
|
|
|
|
// 随机选择操作
|
|
int operation = op_dist(gen);
|
|
switch (operation) {
|
|
case 0: { // 添加路由
|
|
MessageRoute route;
|
|
route.message_type = message_type;
|
|
route.destination = "tcp://localhost:5555";
|
|
route.strategy = RoutingStrategy::ZeroMQOnly;
|
|
server_->add_route(route);
|
|
break;
|
|
}
|
|
case 1: // 移除路由
|
|
server_->remove_route(message_type);
|
|
break;
|
|
case 2: // 获取所有路由
|
|
server_->get_routes();
|
|
break;
|
|
}
|
|
|
|
// 偶尔延迟
|
|
if (j % 10 == 0) {
|
|
std::this_thread::yield();
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
// 运行一段时间
|
|
std::this_thread::sleep_for(1s);
|
|
running = false;
|
|
|
|
// 等待所有线程完成
|
|
for (auto& t : threads) {
|
|
t.join();
|
|
}
|
|
|
|
// 获取最终路由表
|
|
auto routes = server_->get_routes();
|
|
|
|
// 检查路由表是否包含有效的路由
|
|
for (const auto& route : routes) {
|
|
EXPECT_FALSE(route.message_type.empty());
|
|
EXPECT_FALSE(route.destination.empty());
|
|
}
|
|
}
|
|
|
|
int main(int argc, char** argv) {
|
|
::testing::InitGoogleTest(&argc, argv);
|
|
return RUN_ALL_TESTS();
|
|
} |