Files
Alicho/src/communication/core/serializer.cpp
2025-10-24 18:17:58 +08:00

222 lines
7.9 KiB
C++

// ================================================================================================
// Audio Backend - 消息序列化器实现
// ================================================================================================
#include "serializer.h"
#include <cstring>
// 跨平台网络字节序转换
#ifdef _WIN32
#include <winsock2.h>
#pragma comment(lib, "ws2_32.lib")
#else
#include <arpa/inet.h>
#endif
namespace audio_backend::communication {
// 使用全局日志函数
using audio_backend::common::log_info;
using audio_backend::common::log_warn;
using audio_backend::common::log_err;
using audio_backend::common::log_debug;
namespace audio_backend::communication {
// ================================================================================================
// ProtobufSerializer 实现
// ================================================================================================
ProtobufSerializer::ProtobufSerializer() {
log_info("初始化Protobuf序列化器");
}
SerializationError ProtobufSerializer::serialize(const IMessage& message, std::vector<uint8_t>& output) {
const std::string& type_name = message.message_type();
// 查找序列化函数
auto it = serialize_funcs_.find(type_name);
if (it == serialize_funcs_.end()) {
log_err("未找到消息类型 '%s' 的序列化函数", type_name.c_str());
return SerializationError::UnsupportedType;
}
// 序列化消息内容
std::vector<uint8_t> message_data;
SerializationError result = it->second(message, message_data);
if (result != SerializationError::Success) {
log_err("序列化消息类型 '%s' 失败: %d", type_name.c_str(), static_cast<int>(result));
return result;
}
// 构建完整的消息,包含消息头
MessageHeader header;
header.magic = htonl(HEADER_MAGIC);
header.version = htons(CURRENT_VERSION);
header.type_length = htons(static_cast<uint16_t>(type_name.length()));
// 计算总大小
size_t total_size = HEADER_SIZE + type_name.length() + message_data.size();
output.reserve(total_size);
output.clear();
// 写入消息头
const uint8_t* header_bytes = reinterpret_cast<const uint8_t*>(&header);
output.insert(output.end(), header_bytes, header_bytes + HEADER_SIZE);
// 写入消息类型
output.insert(output.end(), type_name.begin(), type_name.end());
// 写入消息数据
output.insert(output.end(), message_data.begin(), message_data.end());
log_debug("成功序列化消息类型 '%s', 大小: %zu 字节", type_name.c_str(), output.size());
return SerializationError::Success;
}
SerializationError ProtobufSerializer::deserialize(const std::vector<uint8_t>& data, std::unique_ptr<IMessage>& output) {
if (data.size() < HEADER_SIZE) {
log_err("数据大小不足以包含消息头: %zu < %zu", data.size(), HEADER_SIZE);
return SerializationError::InvalidInput;
}
// 解析消息头
MessageHeader header;
std::memcpy(&header, data.data(), HEADER_SIZE);
// 验证魔术数字
if (ntohl(header.magic) != HEADER_MAGIC) {
log_err("无效的消息头魔术数字: 0x%08X", ntohl(header.magic));
return SerializationError::InvalidFormat;
}
// 检查版本
uint16_t version = ntohs(header.version);
if (version != CURRENT_VERSION) {
log_warn("消息版本不匹配: %d != %d", version, CURRENT_VERSION);
// 暂时继续处理,但可能需要版本兼容性处理
}
// 提取消息类型
uint16_t type_length = ntohs(header.type_length);
if (data.size() < HEADER_SIZE + type_length) {
log_err("数据大小不足以包含消息类型: %zu < %zu", data.size(), HEADER_SIZE + type_length);
return SerializationError::InvalidInput;
}
std::string message_type(data.begin() + HEADER_SIZE, data.begin() + HEADER_SIZE + type_length);
// 查找反序列化函数
auto it = deserialize_funcs_.find(message_type);
if (it == deserialize_funcs_.end()) {
log_err("未找到消类型 '%s' 的反序列化函数", message_type.c_str());
return SerializationError::UnsupportedType;
}
// 提取消息数据
size_t message_data_offset = HEADER_SIZE + type_length;
std::vector<uint8_t> message_data(data.begin() + message_data_offset, data.end());
// 反序列化消息
SerializationError result = it->second(message_data, output);
if (result != SerializationError::Success) {
log_err("反序列化消息类型 '%s' 失败: %d", message_type.c_str(), static_cast<int>(result));
return result;
}
log_debug("成功反序列化消息类型 '%s', 数据大小: %zu 字节", message_type.c_str(), message_data.size());
return SerializationError::Success;
}
bool ProtobufSerializer::supports_message_type(const std::string& message_type) const {
return serialize_funcs_.find(message_type) != serialize_funcs_.end() &&
deserialize_funcs_.find(message_type) != deserialize_funcs_.end();
}
void ProtobufSerializer::register_message_type(const std::string& type_name,
const SerializeFunction& serialize_func,
const DeserializeFunction& deserialize_func) {
serialize_funcs_[type_name] = serialize_func;
deserialize_funcs_[type_name] = deserialize_func;
log_debug("注册消息类型: '%s'", type_name.c_str());
}
std::string ProtobufSerializer::extract_message_type(const std::vector<uint8_t>& data) {
if (data.size() < HEADER_SIZE) {
return "";
}
MessageHeader header;
std::memcpy(&header, data.data(), HEADER_SIZE);
if (ntohl(header.magic) != HEADER_MAGIC) {
return "";
}
uint16_t type_length = ntohs(header.type_length);
if (data.size() < HEADER_SIZE + type_length) {
return "";
}
return std::string(data.begin() + HEADER_SIZE, data.begin() + HEADER_SIZE + type_length);
}
// ================================================================================================
// SerializerFactory 实现
// ================================================================================================
SerializerFactory& SerializerFactory::instance() {
static SerializerFactory instance;
return instance;
}
void SerializerFactory::register_serializer(const std::string& name, std::unique_ptr<ISerializer> serializer) {
if (!serializer) {
log_err("尝试注册空的序列化器: '%s'", name.c_str());
return;
}
serializers_[name] = std::move(serializer);
log_info("注册序列化器: '%s'", name.c_str());
}
ISerializer* SerializerFactory::get_serializer(const std::string& name) {
auto it = serializers_.find(name);
if (it != serializers_.end()) {
return it->second.get();
}
log_warn("未找到序列化器: '%s'", name.c_str());
return nullptr;
}
ISerializer* SerializerFactory::get_serializer_for_message_type(const std::string& message_type) {
for (const auto& pair : serializers_) {
if (pair.second->supports_message_type(message_type)) {
return pair.second.get();
}
}
log_warn("未找到支持消息类型 '%s' 的序列化器", message_type.c_str());
return nullptr;
}
ISerializer* SerializerFactory::detect_serializer(const std::vector<uint8_t>& data) {
if (data.size() < sizeof(ProtobufSerializer::MessageHeader)) {
return nullptr;
}
// 检查是否为Protobuf格式
ProtobufSerializer::MessageHeader header;
std::memcpy(&header, data.data(), sizeof(header));
if (ntohl(header.magic) == ProtobufSerializer::HEADER_MAGIC) {
return get_serializer("protobuf");
}
// 可以在这里添加其他格式的检测逻辑
log_warn("无法识别数据格式");
return nullptr;
}
} // namespace audio_backend::communication