Files
Alicho/tests/network/test_network_zmq_rpc.cpp

716 lines
19 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <exception>
#include <iostream>
#include <mutex>
#include <numeric>
#include <string>
#include <thread>
#include <vector>

#include <gtest/gtest.h>

#include "transport/zmq_server.h"
#include "transport/zmq_client.h"
#include "transport/zmq_server_processor.h"
#include "transport/zmq_client_processor.h"
#include "transport/zmq_util.h"
#include "rpc/engine_rpc.h"
#include "rpc/host_rpc.h"
#include "logger.h"
#include "thread_tool.h"
// ============================================================================
// 测试辅助工具
// ============================================================================
// Simple counter used by the tests to track how often a processor has been called.
/**
 * @brief Counter that a test thread can block on until processor callbacks
 *        (running on other threads) have fired a given number of times.
 */
struct test_counter {
    std::atomic<int> count{0};
    std::mutex mutex;
    std::condition_variable cv;

    /// Increments the counter and wakes every waiter.
    void increment() {
        {
            // The shared state must be modified while holding the mutex the
            // waiters use. Otherwise a waiter can evaluate its predicate, see the
            // old value, and only then block -- after notify_all() has already
            // fired -- so it sleeps until its timeout (a lost wakeup).
            std::lock_guard<std::mutex> lock(mutex);
            ++count;
        }
        cv.notify_all();
    }

    /**
     * @brief Blocks until count >= expected or the timeout elapses.
     * @param expected Minimum count to wait for.
     * @param timeout  Maximum time to wait (default 5 s).
     * @return true if the count reached @p expected, false on timeout.
     */
    bool wait_for(int expected, std::chrono::milliseconds timeout = std::chrono::milliseconds(5000)) {
        std::unique_lock<std::mutex> lock(mutex);
        return cv.wait_for(lock, timeout, [this, expected]() {
            return count >= expected;
        });
    }

    /// Resets the counter to zero (under the mutex, consistent with increment()).
    void reset() {
        std::lock_guard<std::mutex> lock(mutex);
        count = 0;
    }
};
// Data structures used only by these tests.
/// Test payload sent from the client to the server.
struct test_message_t {
    uint32_t id{0};       // caller-chosen message id; zero-initialized so it is never indeterminate
    std::string content;  // text stored into g_last_received_content by the server processor
};
/// Test payload sent from the server back to a client.
struct test_response_t {
    uint32_t request_id{0};  // id of the request being answered
    bool success{false};     // zero-initialized so a partially filled response is well-defined
    std::string result;      // text stored into g_last_received_content by the client processor
};
// Global state written by the test processors below and inspected by the tests.
// Mutex guarding g_last_received_content.
inline std::mutex g_content_mutex;
// Most recent text seen by either test processor (guarded by g_content_mutex).
inline std::string g_last_received_content;
// How many times the server-side / client-side test processors have been invoked.
inline test_counter g_server_counter;
inline test_counter g_client_counter;
// Server-side processor for test_message_t. `data` is the decoded message made
// available by the ZMQ_SERVER_REGISTER_PROCESSOR macro.
ZMQ_SERVER_REGISTER_PROCESSOR(test_message_t) {
    // Publish the content *before* bumping the counter: tests wake on the counter
    // and then read g_last_received_content, so the write must already be done.
    {
        std::lock_guard<std::mutex> lock(g_content_mutex);
        g_last_received_content = data.content;
    }
    g_server_counter.increment();
}
// Client-side processor for test_response_t. `data` is the decoded message made
// available by the ZMQ_CLIENT_REGISTER_PROCESSOR macro.
ZMQ_CLIENT_REGISTER_PROCESSOR(test_response_t) {
    // Store the result *before* bumping the counter so a test woken by the counter
    // reads this response's result rather than a stale value.
    {
        std::lock_guard<std::mutex> lock(g_content_mutex);
        g_last_received_content = data.result;
    }
    g_client_counter.increment();
}
// ============================================================================
// 基础功能测试
// ============================================================================
/**
 * @brief Fixture for all ZMQ RPC tests.
 *
 * Resets the global processor state and makes sure the singleton client is
 * disconnected at the start and end of every test.
 * NOTE(review): the zmq_server singleton is not reset here; every test calls
 * server.init() itself, presumably relying on init() tolerating repeated calls.
 */
class ZmqRpcTest : public ::testing::Test {
protected:
    void SetUp() override {
        // Reset counters.
        g_server_counter.reset();
        g_client_counter.reset();
        {
            // Same lock the processors hold when writing this string.
            std::lock_guard<std::mutex> lock(g_content_mutex);
            g_last_received_content.clear();
        }
        // zmq_client is a singleton: drop any connection left over from a
        // previous test so each test starts independently.
        auto& client = zmq_client::instance();
        if (client.is_connected()) {
            client.disconnect();
            // Short pause after disconnecting before the test reconnects.
            std::this_thread::sleep_for(std::chrono::milliseconds(50));
        }
    }

    void TearDown() override {
        // Close a connection the test may have left open.
        auto& client = zmq_client::instance();
        if (client.is_connected()) {
            client.disconnect();
        }
    }
};
/**
 * @brief Server initialization: init() must not throw and must leave the server
 *        in the RUNNING state with no registered clients.
 *
 * NOTE(review): zmq_server is a singleton, so client_count() == 0 only holds if
 * no earlier test in this process has registered a client (e.g. under shuffling).
 */
TEST_F(ZmqRpcTest, ServerInitialization) {
    auto& srv = zmq_server::instance();
    ASSERT_NO_THROW(srv.init());

    EXPECT_TRUE(srv.is_running());
    EXPECT_EQ(zmq_server::state::RUNNING, srv.get_state());
    EXPECT_EQ(0, srv.client_count());
}
/**
 * @brief Client connection: init() with a client id must not throw and must
 *        leave the client connected.
 */
TEST_F(ZmqRpcTest, ClientConnection) {
    zmq_server::instance().init();

    constexpr uint32_t kClientId = 12345;
    auto& cli = zmq_client::instance();
    ASSERT_NO_THROW(cli.init(kClientId));

    EXPECT_TRUE(cli.is_connected());
    EXPECT_EQ(zmq_client::state::CONNECTED, cli.get_state());

    // Give the connection a moment to be established.
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
}
/**
 * @brief Client -> server: a test_message_t sent by the client reaches the server
 *        processor with its content intact, and the client id becomes known to
 *        the server.
 */
TEST_F(ZmqRpcTest, ClientToServerMessage) {
    auto& server = zmq_server::instance();
    server.init();
    auto& client = zmq_client::instance();
    const uint32_t client_id = 10001;
    client.init(client_id);
    // Wait for the connection to be established.
    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    // Send the test message.
    test_message_t msg;
    msg.id = 1;
    msg.content = "Hello from client";
    client.send(msg);

    // Receive on the server in a separate thread while this thread waits
    // (with a timeout) for the server processor to run.
    std::thread server_thread([&server]() {
        server.recv();
    });
    // NOTE(review): if this ASSERT fails the test returns with server_thread still
    // joinable, and std::thread's destructor calls std::terminate. recv() cannot be
    // cancelled from here, so avoiding that would need a recv-with-timeout API.
    ASSERT_TRUE(g_server_counter.wait_for(1));

    // Join *before* inspecting shared state: the processor may bump the counter
    // before it has stored the content, so reading right after the wait races
    // with it. Once recv() has returned the processor is done (assuming recv()
    // dispatches synchronously, which the other tests here also rely on).
    server_thread.join();
    {
        std::lock_guard<std::mutex> lock(g_content_mutex);
        EXPECT_EQ(g_last_received_content, "Hello from client");
    }

    // The client must now be registered on the server.
    EXPECT_TRUE(server.has_client(client_id));
}
/**
 * @brief Server -> client: after the client has introduced itself, a
 *        test_response_t sent by the server reaches the client processor with its
 *        result intact.
 */
TEST_F(ZmqRpcTest, ServerToClientMessage) {
    auto& server = zmq_server::instance();
    server.init();
    auto& client = zmq_client::instance();
    const uint32_t client_id = 10002;
    client.init(client_id);

    // Let the client send a message first so the server knows this connection.
    test_message_t hello_msg;
    hello_msg.id = 1;
    hello_msg.content = "Hello";
    client.send(hello_msg);
    std::thread server_thread1([&server]() {
        server.recv();
    });
    server_thread1.join();
    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    // Now the server sends a message to the client.
    test_response_t response;
    response.request_id = 1;
    response.success = true;
    response.result = "Response from server";
    server.send(client_id, response);

    // Receive on the client in a separate thread.
    std::thread client_thread([&client]() {
        client.recv();
    });
    // NOTE(review): if this ASSERT fails the test returns with client_thread still
    // joinable, and std::thread's destructor calls std::terminate; recv() cannot be
    // cancelled from here.
    ASSERT_TRUE(g_client_counter.wait_for(1));

    // Join *before* reading the shared content so the client processor has
    // finished writing it (the counter may be bumped before the write).
    client_thread.join();
    {
        std::lock_guard<std::mutex> lock(g_content_mutex);
        EXPECT_EQ(g_last_received_content, "Response from server");
    }
}
/**
 * @brief Several client ids: both ids must be registered on the server.
 *
 * zmq_client is a singleton, so two clients are simulated by connecting the same
 * instance under two different ids, one after the other. In a real deployment
 * each client would be its own instance.
 */
TEST_F(ZmqRpcTest, MultipleClients) {
    auto& srv = zmq_server::instance();
    srv.init();

    constexpr uint32_t kFirstId = 20001;
    constexpr uint32_t kSecondId = 20002;
    auto& cli = zmq_client::instance();

    // Sends one test_message_t and lets the server consume it on a worker thread.
    const auto send_and_drain = [&srv, &cli](uint32_t msg_id, const char* text) {
        test_message_t outgoing;
        outgoing.id = msg_id;
        outgoing.content = text;
        cli.send(outgoing);
        std::thread([&srv]() { srv.recv(); }).join();
    };

    // First identity.
    cli.init(kFirstId);
    send_and_drain(1, "Client 1");
    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    // Second identity: the singleton has to disconnect and reconnect under a new id.
    cli.disconnect();
    std::this_thread::sleep_for(std::chrono::milliseconds(50));
    cli.init(kSecondId);
    send_and_drain(2, "Client 2");

    EXPECT_TRUE(srv.has_client(kFirstId));
    EXPECT_TRUE(srv.has_client(kSecondId));
    EXPECT_GE(srv.client_count(), 1);  // at least one client
}
/**
 * @brief Client removal: a registered client id is no longer reported after
 *        remove_client().
 */
TEST_F(ZmqRpcTest, ClientRemoval) {
    auto& srv = zmq_server::instance();
    srv.init();

    constexpr uint32_t kClientId = 30001;
    auto& cli = zmq_client::instance();
    cli.init(kClientId);

    // Send one message so the server learns about this client.
    test_message_t first_msg{1, "Test"};
    cli.send(first_msg);
    std::thread([&srv] { srv.recv(); }).join();
    EXPECT_TRUE(srv.has_client(kClientId));

    srv.remove_client(kClientId);
    EXPECT_FALSE(srv.has_client(kClientId));
}
// ============================================================================
// RPC 处理器测试
// ============================================================================
/**
 * @brief engine_rpc::log_message_t: the server receives and dispatches a log
 *        message without throwing. Its effect is only visible in the log output,
 *        so the test merely checks that it completes.
 */
TEST_F(ZmqRpcTest, EngineRpcLogMessage) {
    auto& srv = zmq_server::instance();
    srv.init();

    auto& cli = zmq_client::instance();
    constexpr uint32_t kClientId = 40001;
    cli.init(kClientId);
    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    engine_rpc::log_message_t log_msg;
    log_msg.level = 2;  // INFO level
    log_msg.message = "Test log message from sandbox_host";
    cli.send(log_msg);

    std::thread receiver([&srv] { srv.recv(); });
    receiver.join();

    SUCCEED();
}
/**
 * @brief host_rpc::setup_t: the server can deliver a setup message to a client.
 *
 * The client-side setup_t processor tries to initialize the shared memory named
 * in the message. That segment may not exist here, so a std::exception from
 * client.recv() is tolerated; the transport itself is what is being tested.
 */
TEST_F(ZmqRpcTest, HostRpcSetup) {
    auto& srv = zmq_server::instance();
    srv.init();

    auto& cli = zmq_client::instance();
    constexpr uint32_t kClientId = 40002;
    cli.init(kClientId);

    // Introduce the client to the server with a first message.
    test_message_t hello{1, "Hello"};
    cli.send(hello);
    std::thread([&srv] { srv.recv(); }).join();
    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    host_rpc::setup_t setup;
    setup.shm_name = "test_shared_memory_segment";
    srv.send(kClientId, setup);

    std::thread receiver([&cli] {
        try {
            cli.recv();
        }
        catch (const std::exception&) {
            // May fail because the shared memory segment does not exist.
        }
    });
    receiver.join();

    SUCCEED();
}
// ============================================================================
// 异常处理测试
// ============================================================================
/**
 * @brief Sending to an unknown client id must not throw (the server should log
 *        the error instead of crashing), and the id stays unknown.
 */
TEST_F(ZmqRpcTest, SendToNonExistentClient) {
    auto& srv = zmq_server::instance();
    srv.init();

    constexpr uint32_t kUnknownClientId = 99999;
    test_response_t reply{1, false, "Error"};

    ASSERT_NO_THROW(srv.send(kUnknownClientId, reply));
    EXPECT_FALSE(srv.has_client(kUnknownClientId));
}
/**
 * @brief A message whose type has no registered processor must be received by
 *        the server without throwing (it should log an error, not crash).
 */
TEST_F(ZmqRpcTest, UnregisteredFuncId) {
    auto& server = zmq_server::instance();
    server.init();
    auto& client = zmq_client::instance();
    const uint32_t client_id = 50001;
    client.init(client_id);
    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    // A local type for which no processor is registered.
    struct unknown_message_t {
        int data;
    };
    unknown_message_t unknown_msg;
    unknown_msg.data = 12345;
    client.send(unknown_msg);

    // Wrapping the thread in ASSERT_NO_THROW (as before) cannot observe an
    // exception thrown *inside* the thread: it would escape the thread's entry
    // point and call std::terminate. Capture it there and rethrow it here instead.
    std::exception_ptr recv_error;
    std::thread([&server, &recv_error]() {
        try {
            server.recv();
        }
        catch (...) {
            recv_error = std::current_exception();
        }
    }).join();  // join() makes the write to recv_error visible to this thread

    ASSERT_NO_THROW({
        if (recv_error) {
            std::rethrow_exception(recv_error);
        }
    });
}
/**
 * @brief Disconnect: after disconnect() the client reports not connected and
 *        is in the DISCONNECTED state.
 */
TEST_F(ZmqRpcTest, ClientDisconnect) {
    zmq_server::instance().init();

    auto& cli = zmq_client::instance();
    constexpr uint32_t kClientId = 60001;
    cli.init(kClientId);
    EXPECT_TRUE(cli.is_connected());

    cli.disconnect();
    EXPECT_FALSE(cli.is_connected());
    EXPECT_EQ(zmq_client::state::DISCONNECTED, cli.get_state());
}
/**
 * @brief Reconnect: after disconnect(), a successful reconnect() leaves the
 *        client CONNECTED. A failed reconnect is tolerated because it depends on
 *        the implementation.
 */
TEST_F(ZmqRpcTest, ClientReconnect) {
    zmq_server::instance().init();

    auto& cli = zmq_client::instance();
    constexpr uint32_t kClientId = 60002;

    cli.init(kClientId);
    EXPECT_TRUE(cli.is_connected());

    cli.disconnect();
    EXPECT_FALSE(cli.is_connected());

    const bool reconnected = cli.reconnect();
    if (!reconnected) {
        return;  // acceptable: reconnect behavior is implementation-defined
    }
    EXPECT_TRUE(cli.is_connected());
    EXPECT_EQ(zmq_client::state::CONNECTED, cli.get_state());
}
// ============================================================================
// 消息序列化测试
// ============================================================================
/**
 * @brief Serialization round trip: zmq_message_pack::create() wraps a struct in
 *        an envelope carrying a func_id and payload, and the payload decodes back
 *        to the original struct via struct_pack.
 */
TEST_F(ZmqRpcTest, MessageSerialization) {
    test_message_t source;
    source.id = 12345;
    source.content = "Test serialization";

    // Serialize into the outer envelope.
    auto wire = zmq_message_pack::create(source);
    EXPECT_GT(wire.size(), 0u);

    // Decode the envelope from a ZMQ frame.
    zmq::message_t frame(wire.data(), wire.size());
    auto envelope = zmq_message_pack::deserialize(frame);
    EXPECT_GT(envelope.func_id, 0);
    EXPECT_GT(envelope.payload.size(), 0u);

    // Decode the inner struct from the payload.
    auto decoded = struct_pack::deserialize<test_message_t>(
        envelope.payload.data(), envelope.payload.size());
    ASSERT_TRUE(decoded.has_value());
    const test_message_t& round_trip = decoded.value();
    EXPECT_EQ(round_trip.id, source.id);
    EXPECT_EQ(round_trip.content, source.content);
}
/**
 * @brief alicho_type_id_v: distinct message types get distinct ids, and the id
 *        of a given type is stable.
 */
TEST_F(ZmqRpcTest, TypeIdUniqueness) {
    const uint32_t log_id = alicho_type_id_v<engine_rpc::log_message_t>;
    const uint32_t setup_id = alicho_type_id_v<host_rpc::setup_t>;
    const uint32_t test_id = alicho_type_id_v<test_message_t>;

    // Pairwise distinct.
    EXPECT_NE(setup_id, log_id);
    EXPECT_NE(test_id, log_id);
    EXPECT_NE(test_id, setup_id);

    // Same type, same id.
    const uint32_t log_id_again = alicho_type_id_v<engine_rpc::log_message_t>;
    EXPECT_EQ(log_id_again, log_id);
}
// ============================================================================
// 性能测试
// ============================================================================
/**
 * @brief Throughput of client -> server messages.
 *
 * The client enqueues num_messages test_message_t packets. A server thread then
 * drains them, and the test waits until the server processor has run for each
 * one. Intervals are measured with std::chrono::steady_clock, which is monotonic.
 * high_resolution_clock may be an alias of system_clock and can jump when the
 * wall clock is adjusted.
 */
TEST_F(ZmqRpcTest, MessageThroughput) {
    using steady = std::chrono::steady_clock;
    using std::chrono::duration_cast;
    using std::chrono::microseconds;

    auto& server = zmq_server::instance();
    server.init();
    auto& client = zmq_client::instance();
    const uint32_t client_id = 70001;
    client.init(client_id);
    std::this_thread::sleep_for(std::chrono::milliseconds(100));
    const int num_messages = 100;

    // Send all messages.
    const auto start_time = steady::now();
    for (int i = 0; i < num_messages; ++i) {
        test_message_t msg;
        msg.id = static_cast<uint32_t>(i);
        msg.content = "Message " + std::to_string(i);
        client.send(msg);
    }
    const auto send_complete_time = steady::now();
    std::cout << "[诊断] 发送 " << num_messages << " 条消息耗时: "
              << duration_cast<microseconds>(send_complete_time - start_time).count()
              << " 微秒" << std::endl;

    // The receiver only records its elapsed time. It is printed from this thread
    // after join() so the two threads' diagnostic lines do not interleave.
    microseconds server_recv_time{0};
    std::thread server_thread([&server, &server_recv_time]() {
        const auto thread_start = steady::now();
        for (int i = 0; i < num_messages; ++i) {
            server.recv();
        }
        server_recv_time = duration_cast<microseconds>(steady::now() - thread_start);
    });

    const auto wait_start = steady::now();
    // NOTE(review): if this ASSERT fails the test returns with server_thread still
    // joinable, and std::thread's destructor calls std::terminate; recv() cannot be
    // cancelled from here.
    ASSERT_TRUE(g_server_counter.wait_for(num_messages, std::chrono::seconds(10)));
    const auto end_time = steady::now();
    server_thread.join();

    std::cout << "[诊断] 服务器线程接收耗时: " << server_recv_time.count() << " 微秒" << std::endl;
    std::cout << "[诊断] 等待处理完成耗时: "
              << duration_cast<microseconds>(end_time - wait_start).count()
              << " 微秒" << std::endl;

    const auto total = end_time - start_time;
    const auto duration_us = duration_cast<microseconds>(total).count();
    const auto duration_ms = duration_cast<std::chrono::milliseconds>(total).count();
    std::cout << "[诊断] 总时间差(纳秒): "
              << duration_cast<std::chrono::nanoseconds>(total).count() << std::endl;
    std::cout << "[诊断] 微秒精度耗时: " << duration_us << " 微秒" << std::endl;
    std::cout << "[诊断] 毫秒精度耗时: " << duration_ms << " 毫秒(截断)" << std::endl;
    std::cout << "处理 " << num_messages << " 条消息耗时: " << duration_ms << " 毫秒" << std::endl;

    // Throughput computed at microsecond resolution.
    if (duration_us > 0) {
        std::cout << "平均吞吐量: "
                  << (num_messages * 1000000.0 / duration_us) << " 消息/秒" << std::endl;
    }
    else {
        std::cout << "平均吞吐量: 耗时太短无法测量" << std::endl;
    }
    EXPECT_GT(duration_us, 0);
}
/**
 * @brief Round-trip time (RTT). The client sends "Ping", the server receives it
 *        and replies "Pong", and the client receives the reply. The result is
 *        averaged over num_iterations.
 *
 * Intervals use std::chrono::steady_clock (monotonic) rather than
 * high_resolution_clock, which may alias system_clock and jump.
 */
TEST_F(ZmqRpcTest, RoundTripLatency) {
    using steady = std::chrono::steady_clock;

    auto& server = zmq_server::instance();
    server.init();
    auto& client = zmq_client::instance();
    const uint32_t client_id = 70002;
    client.init(client_id);

    // Establish the connection: the first message introduces the client to the server.
    test_message_t hello;
    hello.id = 0;
    hello.content = "Init";
    client.send(hello);
    std::thread([&server]() { server.recv(); }).join();
    std::this_thread::sleep_for(std::chrono::milliseconds(100));

    const int num_iterations = 10;
    std::vector<double> latencies;
    latencies.reserve(num_iterations);
    for (int i = 0; i < num_iterations; ++i) {
        g_client_counter.reset();
        const auto start = steady::now();

        // Client sends the ping.
        test_message_t request;
        request.id = static_cast<uint32_t>(i);
        request.content = "Ping";
        client.send(request);

        // Server receives the ping and replies to this client.
        std::thread server_thread([&server, i]() {
            server.recv();
            test_response_t response;
            response.request_id = static_cast<uint32_t>(i);
            response.success = true;
            response.result = "Pong";
            server.send(client_id, response);
        });

        // Client receives the reply.
        std::thread client_thread([&client]() {
            client.recv();
        });

        // NOTE(review): if this ASSERT fails the test returns with both threads
        // still joinable, and std::thread's destructor calls std::terminate.
        ASSERT_TRUE(g_client_counter.wait_for(1, std::chrono::seconds(5)));
        const auto end = steady::now();
        latencies.push_back(std::chrono::duration<double, std::milli>(end - start).count());

        server_thread.join();
        client_thread.join();
    }

    // Average over the samples actually collected.
    const double avg_latency =
        std::accumulate(latencies.begin(), latencies.end(), 0.0) / static_cast<double>(latencies.size());
    std::cout << "平均往返延迟: " << avg_latency << " 毫秒" << std::endl;
    EXPECT_GT(avg_latency, 0);
    EXPECT_LT(avg_latency, 1000);  // should be below one second
}
// ============================================================================
// 测试主入口
// ============================================================================
int main(int argc, char** argv) {
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}