222 lines
7.9 KiB
C++
222 lines
7.9 KiB
C++
// ================================================================================================
|
|
// Audio Backend - 消息序列化器实现
|
|
// ================================================================================================
|
|
|
|
#include "serializer.h"
|
|
#include <cstring>
|
|
|
|
// 跨平台网络字节序转换
|
|
#ifdef _WIN32
|
|
#include <winsock2.h>
|
|
#pragma comment(lib, "ws2_32.lib")
|
|
#else
|
|
#include <arpa/inet.h>
|
|
#endif
|
|
|
|
namespace audio_backend::communication {
|
|
|
|
// 使用全局日志函数
|
|
using audio_backend::common::log_info;
|
|
using audio_backend::common::log_warn;
|
|
using audio_backend::common::log_err;
|
|
using audio_backend::common::log_debug;
|
|
|
|
namespace audio_backend::communication {
|
|
|
|
// ================================================================================================
|
|
// ProtobufSerializer 实现
|
|
// ================================================================================================
|
|
|
|
ProtobufSerializer::ProtobufSerializer() {
|
|
log_info("初始化Protobuf序列化器");
|
|
}
|
|
|
|
SerializationError ProtobufSerializer::serialize(const IMessage& message, std::vector<uint8_t>& output) {
|
|
const std::string& type_name = message.message_type();
|
|
|
|
// 查找序列化函数
|
|
auto it = serialize_funcs_.find(type_name);
|
|
if (it == serialize_funcs_.end()) {
|
|
log_err("未找到消息类型 '%s' 的序列化函数", type_name.c_str());
|
|
return SerializationError::UnsupportedType;
|
|
}
|
|
|
|
// 序列化消息内容
|
|
std::vector<uint8_t> message_data;
|
|
SerializationError result = it->second(message, message_data);
|
|
if (result != SerializationError::Success) {
|
|
log_err("序列化消息类型 '%s' 失败: %d", type_name.c_str(), static_cast<int>(result));
|
|
return result;
|
|
}
|
|
|
|
// 构建完整的消息,包含消息头
|
|
MessageHeader header;
|
|
header.magic = htonl(HEADER_MAGIC);
|
|
header.version = htons(CURRENT_VERSION);
|
|
header.type_length = htons(static_cast<uint16_t>(type_name.length()));
|
|
|
|
// 计算总大小
|
|
size_t total_size = HEADER_SIZE + type_name.length() + message_data.size();
|
|
output.reserve(total_size);
|
|
output.clear();
|
|
|
|
// 写入消息头
|
|
const uint8_t* header_bytes = reinterpret_cast<const uint8_t*>(&header);
|
|
output.insert(output.end(), header_bytes, header_bytes + HEADER_SIZE);
|
|
|
|
// 写入消息类型
|
|
output.insert(output.end(), type_name.begin(), type_name.end());
|
|
|
|
// 写入消息数据
|
|
output.insert(output.end(), message_data.begin(), message_data.end());
|
|
|
|
log_debug("成功序列化消息类型 '%s', 大小: %zu 字节", type_name.c_str(), output.size());
|
|
return SerializationError::Success;
|
|
}
|
|
|
|
SerializationError ProtobufSerializer::deserialize(const std::vector<uint8_t>& data, std::unique_ptr<IMessage>& output) {
|
|
if (data.size() < HEADER_SIZE) {
|
|
log_err("数据大小不足以包含消息头: %zu < %zu", data.size(), HEADER_SIZE);
|
|
return SerializationError::InvalidInput;
|
|
}
|
|
|
|
// 解析消息头
|
|
MessageHeader header;
|
|
std::memcpy(&header, data.data(), HEADER_SIZE);
|
|
|
|
// 验证魔术数字
|
|
if (ntohl(header.magic) != HEADER_MAGIC) {
|
|
log_err("无效的消息头魔术数字: 0x%08X", ntohl(header.magic));
|
|
return SerializationError::InvalidFormat;
|
|
}
|
|
|
|
// 检查版本
|
|
uint16_t version = ntohs(header.version);
|
|
if (version != CURRENT_VERSION) {
|
|
log_warn("消息版本不匹配: %d != %d", version, CURRENT_VERSION);
|
|
// 暂时继续处理,但可能需要版本兼容性处理
|
|
}
|
|
|
|
// 提取消息类型
|
|
uint16_t type_length = ntohs(header.type_length);
|
|
if (data.size() < HEADER_SIZE + type_length) {
|
|
log_err("数据大小不足以包含消息类型: %zu < %zu", data.size(), HEADER_SIZE + type_length);
|
|
return SerializationError::InvalidInput;
|
|
}
|
|
|
|
std::string message_type(data.begin() + HEADER_SIZE, data.begin() + HEADER_SIZE + type_length);
|
|
|
|
// 查找反序列化函数
|
|
auto it = deserialize_funcs_.find(message_type);
|
|
if (it == deserialize_funcs_.end()) {
|
|
log_err("未找到消类型 '%s' 的反序列化函数", message_type.c_str());
|
|
return SerializationError::UnsupportedType;
|
|
}
|
|
|
|
// 提取消息数据
|
|
size_t message_data_offset = HEADER_SIZE + type_length;
|
|
std::vector<uint8_t> message_data(data.begin() + message_data_offset, data.end());
|
|
|
|
// 反序列化消息
|
|
SerializationError result = it->second(message_data, output);
|
|
if (result != SerializationError::Success) {
|
|
log_err("反序列化消息类型 '%s' 失败: %d", message_type.c_str(), static_cast<int>(result));
|
|
return result;
|
|
}
|
|
|
|
log_debug("成功反序列化消息类型 '%s', 数据大小: %zu 字节", message_type.c_str(), message_data.size());
|
|
return SerializationError::Success;
|
|
}
|
|
|
|
bool ProtobufSerializer::supports_message_type(const std::string& message_type) const {
|
|
return serialize_funcs_.find(message_type) != serialize_funcs_.end() &&
|
|
deserialize_funcs_.find(message_type) != deserialize_funcs_.end();
|
|
}
|
|
|
|
void ProtobufSerializer::register_message_type(const std::string& type_name,
|
|
const SerializeFunction& serialize_func,
|
|
const DeserializeFunction& deserialize_func) {
|
|
serialize_funcs_[type_name] = serialize_func;
|
|
deserialize_funcs_[type_name] = deserialize_func;
|
|
log_debug("注册消息类型: '%s'", type_name.c_str());
|
|
}
|
|
|
|
std::string ProtobufSerializer::extract_message_type(const std::vector<uint8_t>& data) {
|
|
if (data.size() < HEADER_SIZE) {
|
|
return "";
|
|
}
|
|
|
|
MessageHeader header;
|
|
std::memcpy(&header, data.data(), HEADER_SIZE);
|
|
|
|
if (ntohl(header.magic) != HEADER_MAGIC) {
|
|
return "";
|
|
}
|
|
|
|
uint16_t type_length = ntohs(header.type_length);
|
|
if (data.size() < HEADER_SIZE + type_length) {
|
|
return "";
|
|
}
|
|
|
|
return std::string(data.begin() + HEADER_SIZE, data.begin() + HEADER_SIZE + type_length);
|
|
}
|
|
|
|
// ================================================================================================
|
|
// SerializerFactory 实现
|
|
// ================================================================================================
|
|
|
|
SerializerFactory& SerializerFactory::instance() {
|
|
static SerializerFactory instance;
|
|
return instance;
|
|
}
|
|
|
|
void SerializerFactory::register_serializer(const std::string& name, std::unique_ptr<ISerializer> serializer) {
|
|
if (!serializer) {
|
|
log_err("尝试注册空的序列化器: '%s'", name.c_str());
|
|
return;
|
|
}
|
|
|
|
serializers_[name] = std::move(serializer);
|
|
log_info("注册序列化器: '%s'", name.c_str());
|
|
}
|
|
|
|
ISerializer* SerializerFactory::get_serializer(const std::string& name) {
|
|
auto it = serializers_.find(name);
|
|
if (it != serializers_.end()) {
|
|
return it->second.get();
|
|
}
|
|
|
|
log_warn("未找到序列化器: '%s'", name.c_str());
|
|
return nullptr;
|
|
}
|
|
|
|
ISerializer* SerializerFactory::get_serializer_for_message_type(const std::string& message_type) {
|
|
for (const auto& pair : serializers_) {
|
|
if (pair.second->supports_message_type(message_type)) {
|
|
return pair.second.get();
|
|
}
|
|
}
|
|
|
|
log_warn("未找到支持消息类型 '%s' 的序列化器", message_type.c_str());
|
|
return nullptr;
|
|
}
|
|
|
|
ISerializer* SerializerFactory::detect_serializer(const std::vector<uint8_t>& data) {
|
|
if (data.size() < sizeof(ProtobufSerializer::MessageHeader)) {
|
|
return nullptr;
|
|
}
|
|
|
|
// 检查是否为Protobuf格式
|
|
ProtobufSerializer::MessageHeader header;
|
|
std::memcpy(&header, data.data(), sizeof(header));
|
|
|
|
if (ntohl(header.magic) == ProtobufSerializer::HEADER_MAGIC) {
|
|
return get_serializer("protobuf");
|
|
}
|
|
|
|
// 可以在这里添加其他格式的检测逻辑
|
|
log_warn("无法识别数据格式");
|
|
return nullptr;
|
|
}
|
|
|
|
} // namespace audio_backend::communication
|