Files
mirage/tools/code_generator.py

644 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
代码生成模块
负责生成C++代码包括结构体定义、SPIR-V数组、头文件和绑定信息。
"""
from __future__ import annotations
import struct
from typing import Dict, List
from .type_mapping import (
calculate_std430_alignment,
calculate_std430_size,
spirv_type_to_cpp,
)
from .types import (
ArrayTypeInfo,
BufferInfo,
CompilationResult,
MemberInfo,
ShaderMetadata,
SPIRVReflection,
ToolError,
TypeInfo,
)
# ============ Structure Generation ============
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Generate the C++ struct definitions for every reflected buffer/uniform.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` and ``types`` are read.

    Returns:
        A banner comment followed by one struct definition (and a blank line)
        per buffer, joined with newlines; an empty string when there are no
        buffers.
    """
    if not reflection.buffers:
        return ""
    banner = [
        "// ============ Buffer Structures ============",
        "// Auto-generated from SPIR-V reflection",
        "",
    ]
    # Each struct is followed by an empty separator line.
    definitions = [
        text
        for buf in reflection.buffers
        for text in (*generate_single_struct(buf, reflection.types), "")
    ]
    return "\n".join(banner + definitions)
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Generate the C++ struct definition for one reflected buffer.

    The struct is named ``buffer.name``, is declared ``alignas`` the std430
    alignment of ``buffer.struct_type``, and contains one declaration per
    member as produced by ``generate_member_definition``.

    Args:
        buffer: Reflected buffer; ``struct_type`` supplies the members and
            ``binding``/``descriptor_set``/``descriptor_type`` go into the
            header comment.
        type_map: SPIR-V type id -> type info, used to resolve member types.

    Returns:
        The struct's source lines: header comments, the declaration, the
        member lines and the closing ``};``.
    """
    # Named ``struct_type`` (not ``struct``) so the module-level ``struct``
    # import is not shadowed inside this function.
    struct_type = buffer.struct_type
    alignment = calculate_std430_alignment(struct_type, type_map)

    lines = [
        "// Generated from SPIR-V reflection",
        f"// Binding: {buffer.binding}, Descriptor Set: {buffer.descriptor_set}",
        f"// Type: {buffer.descriptor_type}",
        f"struct alignas({alignment}) {buffer.name} {{",
    ]
    for member in struct_type.members:
        lines.extend(f" {line}" for line in generate_member_definition(member, type_map))
    lines.append("};")
    return lines
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Generate a layout comment plus the C++ declaration for one struct member.

    The member type is ``member.resolved_type`` when set, otherwise it is
    looked up in ``type_map`` by ``member.type_id``. If no type is found,
    alignment and size are reported as 4.

    Declarations:
      * runtime-sized arrays become a flexible array member ``T name[];``
      * members with std430 alignment > 4 get an explicit ``alignas``
      * everything else is a plain ``T name;``

    Args:
        member: Reflected member (``resolved_type``, ``type_id``, ``offset``,
            ``name`` are read).
        type_map: SPIR-V type id -> type info.

    Returns:
        Two lines: the ``// Offset: ..., Size: ..., Alignment: ...`` comment
        and the declaration.
    """
    resolved = member.resolved_type
    if resolved is None:
        resolved = type_map.get(member.type_id)

    cpp_type = spirv_type_to_cpp(resolved, type_map)
    if resolved:
        align = calculate_std430_alignment(resolved, type_map)
        byte_size = calculate_std430_size(resolved, type_map)
    else:
        align = byte_size = 4

    runtime_sized = isinstance(resolved, ArrayTypeInfo) and resolved.is_runtime
    size_text = "runtime" if runtime_sized else byte_size
    layout_comment = f"// Offset: {member.offset}, Size: {size_text}, Alignment: {align}"

    if runtime_sized:
        # NOTE(review): flexible array members are a compiler extension in C++,
        # and no alignas is emitted here even when align > 4 -- confirm the
        # element type's own alignment puts the array at the reflected offset.
        declaration = f"{cpp_type} {member.name}[]; // Runtime-sized array"
    elif align > 4:
        declaration = f"alignas({align}) {cpp_type} {member.name};"
    else:
        declaration = f"{cpp_type} {member.name};"
    return [layout_comment, declaration]
# ============ SPIR-V Array Generation ============
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str) -> str:
    """Render a SPIR-V binary as a ``static constexpr uint32_t`` C++ array.

    Words are emitted as 8-digit hex literals, twelve per line, with no
    trailing comma after the last word. The SPIR-V magic number
    (``0x07230203``) in the first word selects the byte order. A module
    whose first four bytes spell the magic number in big-endian order is
    decoded big-endian. Everything else, including data without a valid
    magic number, is decoded little-endian.

    Args:
        spirv_data: Raw SPIR-V module bytes.
        symbol_name: C++ identifier for the generated array.

    Returns:
        The array definition, without a trailing newline.

    Raises:
        ToolError: if ``spirv_data`` is empty (a zero-length C++ array is
            ill-formed) or its length is not a multiple of 4.
    """
    spirv_magic = 0x07230203
    words_per_line = 12

    if not spirv_data:
        raise ToolError("SPIR-V data is empty")
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")

    word_count = len(spirv_data) // 4
    # SPIR-V modules may be stored in either byte order; the magic number in
    # word 0 identifies which one (SPIR-V spec, "Physical Layout").
    big_endian = struct.unpack_from(">I", spirv_data)[0] == spirv_magic
    byte_order = ">" if big_endian else "<"
    words = struct.unpack(f"{byte_order}{word_count}I", spirv_data)

    rows = [
        ", ".join(f"0x{word:08x}" for word in words[start : start + words_per_line])
        for start in range(0, word_count, words_per_line)
    ]
    lines = [f"static constexpr uint32_t {symbol_name}[] = {{"]
    lines.extend(f" {row}," for row in rows[:-1])
    lines.append(f" {rows[-1]}")  # last row: no trailing comma
    lines.append("};")
    return "\n".join(lines)
# ============ Buffer Helper Functions Generation ============
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate create/upload/download C++ helpers for every reflected buffer.

    Args:
        reflection: SPIR-V reflection data; only ``buffers`` is read.
        shader_name: Forwarded to the create-function generator.

    Returns:
        A banner comment followed by, per buffer, the create, upload and
        download functions (each followed by a blank line); an empty string
        when there are no buffers.
    """
    if not reflection.buffers:
        return ""
    banner = [
        "// ============ Buffer Helper Functions ============",
        "// Auto-generated convenience functions for buffer data management",
        "",
    ]
    functions = [
        text
        for buf in reflection.buffers
        for section in (
            _generate_create_buffer_function(buf, shader_name),
            _generate_upload_function(buf),
            _generate_download_function(buf),
        )
        for text in (*section, "")
    ]
    return "\n".join(banner + functions)
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
"""生成创建 buffer 的便捷函数"""
lines = []
func_name = f"create_{buffer.variable_name}_buffer"
struct_name = buffer.name
lines.append(f"// Create buffer for {buffer.variable_name} (binding {buffer.binding})")
lines.append(f"inline auto {func_name}(")
lines.append(f" mirage::render::vulkan::resource_manager& rm,")
lines.append(f" size_t element_count")
lines.append(f") -> mirage::render::vulkan::expected<mirage::render::vulkan::resource_manager::buffer_resource>")
lines.append(f"{{")
# 确定 buffer 用途标志
usage_flags = "vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferDst | vk::BufferUsageFlagBits::eTransferSrc"
# 确定内存属性标志
memory_flags = "vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent"
lines.append(f" return rm.create_buffer(")
lines.append(f" element_count * sizeof({struct_name}),")
lines.append(f" {usage_flags},")
lines.append(f" {memory_flags}")
lines.append(f" );")
lines.append(f"}}")
return lines
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
"""生成上传数据到 buffer 的函数"""
lines = []
func_name = f"upload_{buffer.variable_name}"
struct_name = buffer.name
lines.append(f"// Upload data to {buffer.variable_name} buffer")
lines.append(f"inline void {func_name}(")
lines.append(f" vk::Device device,")
lines.append(f" const mirage::render::vulkan::resource_manager::buffer_resource& buffer,")
lines.append(f" std::span<const {struct_name}> data")
lines.append(f")")
lines.append(f"{{")
lines.append(f" if (data.size_bytes() > buffer.size) {{")
lines.append(f" throw std::runtime_error(\"Data size exceeds buffer capacity\");")
lines.append(f" }}")
lines.append(f" ")
lines.append(f" void* mapped = device.mapMemory(buffer.memory, 0, buffer.size);")
lines.append(f" std::memcpy(mapped, data.data(), data.size_bytes());")
lines.append(f" device.unmapMemory(buffer.memory);")
lines.append(f"}}")
return lines
def _generate_download_function(buffer: BufferInfo) -> List[str]:
"""生成从 buffer 下载数据的函数"""
lines = []
func_name = f"download_{buffer.variable_name}"
struct_name = buffer.name
lines.append(f"// Download data from {buffer.variable_name} buffer")
lines.append(f"inline void {func_name}(")
lines.append(f" vk::Device device,")
lines.append(f" const mirage::render::vulkan::resource_manager::buffer_resource& buffer,")
lines.append(f" std::span<{struct_name}> out_data")
lines.append(f")")
lines.append(f"{{")
lines.append(f" if (out_data.size_bytes() > buffer.size) {{")
lines.append(f" throw std::runtime_error(\"Output buffer size exceeds buffer capacity\");")
lines.append(f" }}")
lines.append(f" ")
lines.append(f" void* mapped = device.mapMemory(buffer.memory, 0, buffer.size);")
lines.append(f" std::memcpy(out_data.data(), mapped, out_data.size_bytes());")
lines.append(f" device.unmapMemory(buffer.memory);")
lines.append(f"}}")
return lines
# ============ Typed Buffer Aliases Generation ============
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate ``typed_buffer`` aliases and factory functions for every buffer.

    Args:
        reflection: SPIR-V reflection data; only ``buffers`` is read.
        shader_name: Forwarded to the per-buffer generator.

    Returns:
        A banner comment followed by each buffer's alias/factory lines (each
        group followed by a blank line); an empty string when there are no
        buffers.
    """
    if not reflection.buffers:
        return ""
    banner = [
        "// ============ Typed Buffer Aliases ============",
        "// Type aliases and factory functions for type-safe buffer wrappers",
        "",
    ]
    per_buffer = [
        text
        for buf in reflection.buffers
        for text in (*_generate_buffer_type_alias_and_factory(buf, shader_name), "")
    ]
    return "\n".join(banner + per_buffer)
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
"""为特定 buffer 生成类型别名和工厂函数"""
lines = []
struct_name = buffer.name
var_name = buffer.variable_name
alias_name = f"{var_name}_buffer"
factory_name = f"create_{var_name}_typed"
# 类型别名
lines.append(f"// Type alias for {var_name} typed buffer")
lines.append(f"using {alias_name} = mirage::render::vulkan::typed_buffer<{struct_name}>;")
lines.append("")
# 工厂函数
lines.append(f"// Create typed buffer for {var_name} (binding {buffer.binding})")
lines.append(f"inline auto {factory_name}(")
lines.append(f" vk::Device device,")
lines.append(f" mirage::render::vulkan::resource_manager& rm,")
lines.append(f" size_t count")
lines.append(f") -> mirage::render::vulkan::expected<{alias_name}>")
lines.append(f"{{")
lines.append(f" auto buffer_result = create_{var_name}_buffer(rm, count);")
lines.append(f" if (!buffer_result) {{")
lines.append(f" return std::unexpected(buffer_result.error());")
lines.append(f" }}")
lines.append(f" ")
lines.append(f" return {alias_name}(device, std::move(*buffer_result), count);")
lines.append(f"}}")
return lines
# ============ Buffer Manager Class Generation ============
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate ``<shader>_buffer_manager``, a C++ class holding one typed buffer per binding.

    Public section:
      * the nested ``buffers`` struct
      * a static ``create`` factory
      * ``initialize`` and ``bind_to_descriptor_set``
      * ``get_buffers`` accessors

    Private section:
      * the constructor
      * the ``device_`` and ``buffers_`` members

    The generated class refers to the ``<variable>_buffer`` aliases and
    ``create_<variable>_typed`` factories emitted by the typed-buffer generator.

    Args:
        reflection: SPIR-V reflection data; only ``buffers`` is read.
        shader_name: Prefix for the class name.

    Returns:
        The class source, or an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""
    buffers = reflection.buffers
    class_name = f"{shader_name}_buffer_manager"

    out = [
        "// ============ Buffer Manager Class ============",
        "// Automatic lifecycle management for all shader buffers",
        "",
        f"class {class_name} {{",
        "public:",
    ]
    out += _generate_buffers_struct(buffers)
    out.append("")
    out += _generate_manager_factory(buffers, class_name)
    out.append("")
    out += _generate_initialize_method(buffers)
    out.append("")
    out += _generate_bind_method(buffers)
    out.append("")
    out += [
        " // Access buffers",
        " auto& get_buffers() { return buffers_; }",
        " const auto& get_buffers() const { return buffers_; }",
        "",
        "private:",
    ]
    out += _generate_manager_constructor(buffers, class_name)
    out += [
        "",
        " vk::Device device_;",
        " buffers buffers_;",
        "};",
    ]
    return "\n".join(out)
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
"""生成 buffers 结构体"""
lines = []
lines.append(" // Container for all buffers")
lines.append(" struct buffers {")
for buffer in buffers:
var_name = buffer.variable_name
alias_name = f"{var_name}_buffer"
lines.append(f" {alias_name} {var_name};")
lines.append(" };")
return lines
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
"""生成管理器的工厂方法"""
lines = []
lines.append(" // Factory method to create buffer manager")
lines.append(" static auto create(")
lines.append(" vk::Device device,")
lines.append(" mirage::render::vulkan::resource_manager& rm,")
lines.append(" size_t element_count")
lines.append(f" ) -> mirage::render::vulkan::expected<{class_name}>")
lines.append(" {")
# 创建所有 buffers
for i, buffer in enumerate(buffers):
var_name = buffer.variable_name
factory_name = f"create_{var_name}_typed"
result_var = f"{var_name}_result"
lines.append(f" auto {result_var} = {factory_name}(device, rm, element_count);")
lines.append(f" if (!{result_var}) {{")
lines.append(f" return std::unexpected({result_var}.error());")
lines.append(f" }}")
if i < len(buffers) - 1:
lines.append("")
lines.append("")
lines.append(f" return {class_name}(")
lines.append(" device,")
for i, buffer in enumerate(buffers):
var_name = buffer.variable_name
result_var = f"{var_name}_result"
comma = "," if i < len(buffers) - 1 else ""
lines.append(f" std::move(*{result_var}){comma}")
lines.append(" );")
lines.append(" }")
return lines
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
"""生成初始化方法"""
lines = []
lines.append(" // Initialize all buffers with data")
lines.append(" void initialize(")
for i, buffer in enumerate(buffers):
var_name = buffer.variable_name
struct_name = buffer.name
comma = "," if i < len(buffers) - 1 else ""
lines.append(f" std::span<const {struct_name}> {var_name}_data{comma}")
lines.append(" ) {")
for buffer in buffers:
var_name = buffer.variable_name
lines.append(f" buffers_.{var_name}.upload({var_name}_data);")
lines.append(" }")
return lines
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
"""生成绑定到 descriptor set 的方法"""
lines = []
lines.append(" // Bind all buffers to a descriptor set")
lines.append(" void bind_to_descriptor_set(")
lines.append(" const mirage::render::vulkan::compute_pipeline& pipeline,")
lines.append(" vk::DescriptorSet descriptor_set")
lines.append(" ) const {")
for buffer in buffers:
var_name = buffer.variable_name
binding = buffer.binding
descriptor_type = "vk::DescriptorType::eStorageBuffer" if buffer.descriptor_type == "storage_buffer" else "vk::DescriptorType::eUniformBuffer"
lines.append(f" pipeline.bind_buffer(")
lines.append(f" descriptor_set,")
lines.append(f" {binding},")
lines.append(f" buffers_.{var_name}.get_buffer(),")
lines.append(f" {descriptor_type}")
lines.append(f" );")
lines.append(" }")
return lines
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
"""生成管理器的私有构造函数"""
lines = []
lines.append(f" {class_name}(")
lines.append(" vk::Device device,")
for i, buffer in enumerate(buffers):
var_name = buffer.variable_name
alias_name = f"{var_name}_buffer"
comma = "," if i < len(buffers) - 1 else ""
lines.append(f" {alias_name} {var_name}{comma}")
lines.append(" ) : device_(device),")
lines.append(" buffers_{")
for i, buffer in enumerate(buffers):
var_name = buffer.variable_name
comma = "," if i < len(buffers) - 1 else ""
lines.append(f" .{var_name} = std::move({var_name}){comma}")
lines.append(" } {}")
return lines
# ============ Vulkan Type Conversion ============
def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a reflected descriptor-type name to its ``vk::DescriptorType`` enumerator.

    Unknown names fall back to ``vk::DescriptorType::eStorageBuffer``.

    Args:
        descriptor_type: Name such as ``"storage_buffer"`` or ``"sampler"``.

    Returns:
        The fully-qualified C++ enumerator text.
    """
    enumerators = {
        "storage_buffer": "eStorageBuffer",
        "uniform_buffer": "eUniformBuffer",
        "sampled_image": "eSampledImage",
        "storage_image": "eStorageImage",
        "sampler": "eSampler",
    }
    return "vk::DescriptorType::" + enumerators.get(descriptor_type, "eStorageBuffer")
def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert shader-stage names into a ``vk::ShaderStageFlagBits`` OR-expression.

    Unrecognised names are skipped. When nothing recognised remains
    (including an empty or ``None`` list), the result is the compute stage.

    Args:
        stages: Stage names such as ``"vertex"``, ``"fragment"``, ``"compute"``.

    Returns:
        The flag bits joined with ``" | "``.
    """
    bits = {
        "vertex": "vk::ShaderStageFlagBits::eVertex",
        "fragment": "vk::ShaderStageFlagBits::eFragment",
        "compute": "vk::ShaderStageFlagBits::eCompute",
        "geometry": "vk::ShaderStageFlagBits::eGeometry",
        "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
        "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
    }
    selected = [bits[name] for name in (stages or ()) if name in bits]
    if not selected:
        return "vk::ShaderStageFlagBits::eCompute"
    return " | ".join(selected)
# ============ Header File Generation ============
def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Generate the complete C++ header for one shader.

    The header contains, in order:
      * ``#pragma once`` and includes
      * an optional namespace opening
      * reflection-derived code (buffer structs, helpers, typed-buffer
        aliases, buffer manager) when ``metadata.reflection`` has buffers
      * one SPIR-V ``uint32_t`` array per compilation result
      * a ``<name>_metadata`` struct
      * an optional ``<name>_bindings`` array
      * a ``create_<name>_pipeline`` factory for compute shaders

    Args:
        metadata: Shader description. Read fields: name, namespace, feature
            flags, reflection, bindings, shader_type.
        compilation_results: Compiled SPIR-V modules. The first entry
            supplies the entry point written to the metadata struct.

    Returns:
        The header source text.

    Raises:
        ToolError: if ``compilation_results`` is empty, or propagated from
            SPIR-V array generation for malformed SPIR-V data.
    """
    if not compilation_results:
        raise ToolError(f"No compilation results to generate a header for '{metadata.name}'")

    lines: List[str] = ["#pragma once", ""]
    lines.extend(_header_includes(metadata))
    lines.append("")
    if metadata.namespace:
        lines.append(f"namespace {metadata.namespace} {{")
        lines.append("")
    lines.extend(_header_reflection_code(metadata))
    lines.extend(_header_spirv_arrays(metadata, compilation_results))
    lines.extend(_header_metadata_struct(metadata, compilation_results[0]))
    lines.extend(_header_bindings_array(metadata))
    if metadata.shader_type == "compute":
        lines.extend(_header_compute_factory(metadata, compilation_results))
    if metadata.namespace:
        lines.append(f"}} // namespace {metadata.namespace}")
        lines.append("")
    return "\n".join(lines)


def _header_includes(metadata: ShaderMetadata) -> List[str]:
    """Return the ``#include`` lines the generated header depends on."""
    includes = [
        "#include <cstdint>",
        "#include <span>",
        "#include <array>",
        # The generated helpers/metadata use std::memcpy, std::runtime_error,
        # std::string_view and std::move; include their headers explicitly
        # rather than relying on transitive includes.
        "#include <cstring>",
        "#include <stdexcept>",
        "#include <string_view>",
        "#include <utility>",
        "#include <vulkan/vulkan.hpp>",
        "#include <Eigen/Eigen>",
    ]
    reflection = metadata.reflection
    if metadata.generate_typed_buffers and reflection and reflection.buffers:
        includes.append("#include \"vulkan/typed_buffer.h\"")
    return includes


def _header_reflection_code(metadata: ShaderMetadata) -> List[str]:
    """Return the struct, helper, typed-buffer and manager sections derived from reflection."""
    reflection = metadata.reflection
    if not (reflection and reflection.buffers):
        return []
    sections = [generate_buffer_structures(reflection), ""]
    # NOTE(review): the buffer manager refers to the typed-buffer aliases, so
    # enabling generate_buffer_manager without generate_typed_buffers likely
    # yields C++ that does not compile -- confirm intended flag combinations.
    optional_generators = (
        (metadata.generate_buffer_helpers, generate_buffer_helper_functions),
        (metadata.generate_typed_buffers, generate_typed_buffer_aliases),
        (metadata.generate_buffer_manager, generate_buffer_manager_class),
    )
    for enabled, generator in optional_generators:
        if not enabled:
            continue
        code = generator(reflection, metadata.name)
        if code:
            sections.extend([code, ""])
    return sections


def _header_spirv_arrays(
    metadata: ShaderMetadata, compilation_results: List[CompilationResult]
) -> List[str]:
    """Return one commented ``uint32_t`` array per compiled shader stage."""
    lines: List[str] = []
    for result in compilation_results:
        symbol_name = f"{metadata.name}_{result.shader_type}_spirv"
        lines.append(f"// {result.shader_type.capitalize()} shader: {result.source_file.name}")
        lines.append("// Compiled with: glslc ... --target-env=vulkan1.2")
        lines.append(spirv_to_cpp_array(result.spirv_data, symbol_name))
        lines.append("")
    return lines


def _header_metadata_struct(metadata: ShaderMetadata, primary: CompilationResult) -> List[str]:
    """Return the ``<name>_metadata`` struct holding the shader name and entry point."""
    return [
        f"struct {metadata.name}_metadata {{",
        f' static constexpr std::string_view name = "{metadata.name}";',
        f' static constexpr std::string_view entry_point = "{primary.entry_point}";',
        "};",
        "",
    ]


def _header_bindings_array(metadata: ShaderMetadata) -> List[str]:
    """Return the ``<name>_bindings`` constexpr array, or nothing when there are no bindings."""
    if not metadata.bindings:
        return []
    info_type = "mirage::render::vulkan::compute_pipeline::binding_info"
    lines = [
        f"inline constexpr std::array<{info_type}, {len(metadata.bindings)}>",
        f"{metadata.name}_bindings = {{{{",
    ]
    for binding in metadata.bindings:
        descriptor_type = vk_descriptor_type_to_cpp(binding.descriptor_type)
        stage_flags = vk_shader_stage_to_cpp(binding.stages)
        lines.extend([
            f" {info_type}{{",
            f" .binding = {binding.binding},",
            f" .type = {descriptor_type},",
            f" .count = {binding.count},",
            f" .stage = {stage_flags},",
            " },",
        ])
    lines.extend(["}};", ""])
    return lines


def _header_compute_factory(
    metadata: ShaderMetadata, compilation_results: List[CompilationResult]
) -> List[str]:
    """Return the ``create_<name>_pipeline`` factory for a compute shader."""
    # Use the result that is actually the compute stage so the referenced
    # array symbol exists; fall back to the first result otherwise.
    compute = next(
        (r for r in compilation_results if r.shader_type == "compute"),
        compilation_results[0],
    )
    spirv_array = f"{metadata.name}_{compute.shader_type}_spirv"
    if metadata.bindings:
        bindings_arg = f" {metadata.name}_bindings"
    else:
        # Plain (non-f) string: "{}" reaches the C++ output verbatim and
        # value-initializes an empty span. Doubled braces would be emitted
        # literally as "{{}}", which can be read as a one-element initializer list.
        bindings_arg = " std::span<const mirage::render::vulkan::compute_pipeline::binding_info>{}"
    return [
        "// ============ Factory Function ============",
        f"inline auto create_{metadata.name}_pipeline(const mirage::render::vulkan::logical_device& device)",
        " -> mirage::render::vulkan::expected<mirage::render::vulkan::compute_pipeline>",
        "{",
        " return mirage::render::vulkan::compute_pipeline::create_from_spv(",
        " device,",
        f" std::span<const uint32_t>({spirv_array}),",
        bindings_arg,
        " );",
        "}",
        "",
    ]