Files
mirage/tools/code_generator.py

516 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
代码生成模块
负责生成C++代码包括结构体定义、SPIR-V数组、头文件和绑定信息。
使用Jinja2模板引擎代替字符串拼接。
"""
from __future__ import annotations
import struct
from typing import Dict, List
from .template_loader import get_renderer
from .type_mapping import (
calculate_std430_alignment,
calculate_std430_size,
spirv_type_to_cpp,
)
from .types import (
ArrayTypeInfo,
BufferInfo,
CompilationResult,
MemberInfo,
PushConstantInfo,
ShaderMetadata,
SPIRVReflection,
ToolError,
TypeInfo,
)
# ============ Structure Generation ============
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Render the C++ struct definitions for every buffer in *reflection*.

    Args:
        reflection: Reflection data. Each entry of ``reflection.buffers`` is
            rendered with ``generate_single_struct`` against
            ``reflection.types``.

    Returns:
        A banner comment followed by every struct definition, each struct
        followed by a blank line, all newline-joined. Returns an empty string
        when the reflection has no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    output = [
        "// ============ Buffer Structures ============",
        "// Auto-generated from SPIR-V reflection",
        "",
    ]
    for buf in buffers:
        output += generate_single_struct(buf, reflection.types)
        output.append("")  # blank separator line after every struct
    return "\n".join(output)
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render one buffer's C++ struct definition.

    The ``base/struct.jinja2`` template receives the buffer, the type table
    and the std430 alignment of ``buffer.struct_type``.

    Args:
        buffer: Buffer whose struct type is rendered.
        type_map: Reflection type table (annotated as keyed by ``int``).

    Returns:
        The rendered text split into lines.
    """
    renderer = get_renderer()
    struct_alignment = calculate_std430_alignment(buffer.struct_type, type_map)
    template_vars = {
        'buffer': buffer,
        'type_map': type_map,
        'alignment': struct_alignment,
    }
    return renderer.render('base/struct.jinja2', template_vars).splitlines()
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render the C++ declaration of a single struct member.

    Args:
        member: Member to render with ``base/struct_member.jinja2``.
        type_map: Reflection type table passed through to the template.

    Returns:
        The rendered text split into lines.
    """
    renderer = get_renderer()
    rendered = renderer.render(
        'base/struct_member.jinja2',
        {'member': member, 'type_map': type_map},
    )
    return rendered.splitlines()
# ============ Push Constants Effect Structure Generation ============
def generate_push_constant_effect_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the C++ struct for the effect members of the push constant block.

    Args:
        reflection: SPIR-V reflection data. ``reflection.push_constant`` and
            ``reflection.types`` are read.
        shader_name: Shader name. The struct is named
            ``shader_name.capitalize() + "PushConstantEffect"``.

    Returns:
        Code rendered from ``push_constant/effect_struct.jinja2``. Returns an
        empty string when there is no push constant block or it has no effect
        members.
    """
    pc = reflection.push_constant
    if not (pc and pc.effect_members):
        return ""

    struct_name = f"{shader_name.capitalize()}PushConstantEffect"
    types = reflection.types

    def describe(member) -> dict:
        # std430 layout facts plus the C++ spelling of the member's type.
        resolved = member.resolved_type
        align = calculate_std430_alignment(resolved, types)
        nbytes = calculate_std430_size(resolved, types)
        cpp = spirv_type_to_cpp(resolved, types)
        return {
            'name': member.name,
            'cpp_type': cpp,
            'offset': member.offset,
            'size': nbytes,
            'alignment': align,
        }

    member_rows = [describe(m) for m in pc.effect_members]

    renderer = get_renderer()
    return renderer.render(
        'push_constant/effect_struct.jinja2',
        {
            'struct_name': struct_name,
            'base_offset': pc.base_offset,
            'members': member_rows,
        },
    )
def generate_push_constant_base_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the push constant base-layout struct and its static checks.

    Two templates are rendered from ``reflection.push_constant.base_info``:
    ``push_constant/base_layout.jinja2`` (layout struct) and
    ``push_constant/static_check.jinja2`` (check code). Both receive
    ``shader_name`` as ``struct_name``.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Passed unchanged to both templates as ``struct_name``.

    Returns:
        The layout code and the check code separated by a blank line. Returns
        an empty string when there is no push constant block or no base part.
    """
    pc = reflection.push_constant
    if not (pc and pc.base_info):
        return ""

    info = pc.base_info
    row_fields = ('name', 'offset', 'size', 'expected_offset', 'expected_size', 'is_valid')
    rows = [{field: getattr(m, field) for field in row_fields} for m in info.members]
    # Member names the template compares against, in this fixed order.
    wanted = ['projection', 'viewport_size', 'time', '_pad0']
    found = [m.name for m in info.members]

    renderer = get_renderer()
    layout_code = renderer.render(
        'push_constant/base_layout.jinja2',
        {
            'struct_name': shader_name,
            'base_members': rows,
            'expected_names': wanted,
            'present_names': found,
            'total_size': info.total_size,
            'is_standard': info.is_standard_layout,
        },
    )
    check_code = renderer.render(
        'push_constant/static_check.jinja2',
        {
            'struct_name': shader_name,
            'base_members': rows,
            'is_standard': info.is_standard_layout,
        },
    )
    return "\n\n".join([layout_code, check_code])
# ============ SPIR-V Array Generation ============
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str, words_per_line: int = 12) -> str:
    """Convert a SPIR-V binary into a C++ ``uint32_t`` array definition.

    Words are formatted as 8-digit hex literals, ``words_per_line`` per row,
    and rendered through the ``base/spirv_array.jinja2`` template.

    SPIR-V may be stored in either byte order; the magic number in the first
    word identifies which. Data is decoded little-endian unless the magic
    number reads byte-swapped, in which case it is decoded big-endian. Either
    way, the emitted literals are the logical word values.

    Args:
        spirv_data: Raw SPIR-V bytes.
        symbol_name: C++ identifier of the generated array.
        words_per_line: Number of words per rendered row (default 12).

    Returns:
        The rendered C++ array code.

    Raises:
        ToolError: If the data is empty, if its length is not a multiple of 4,
            or if ``words_per_line`` is not positive.
    """
    if not spirv_data:
        # An empty initializer would emit a zero-length C++ array (ill-formed).
        raise ToolError("SPIR-V data is empty")
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")
    if words_per_line < 1:
        raise ToolError(f"words_per_line must be positive, got {words_per_line}")

    word_count = len(spirv_data) // 4
    byte_order = "<"
    # 0x03022307 is the SPIR-V magic 0x07230203 read with the wrong byte
    # order, i.e. the module was written big-endian.
    if struct.unpack_from("<I", spirv_data)[0] == 0x03022307:
        byte_order = ">"
    words = struct.unpack(f"{byte_order}{word_count}I", spirv_data)

    chunks = [
        [f"0x{word:08x}" for word in words[start:start + words_per_line]]
        for start in range(0, word_count, words_per_line)
    ]

    renderer = get_renderer()
    context = {
        'symbol_name': symbol_name,
        'chunks': chunks,
    }
    return renderer.render('base/spirv_array.jinja2', context)
# ============ 缓冲区辅助函数生成 ============
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the create/upload/download helper functions for every buffer.

    Args:
        reflection: Reflection data whose ``buffers`` get helpers.
        shader_name: Forwarded to the create-buffer helper template.

    Returns:
        A banner comment followed by, per buffer, the create, upload and
        download helpers, each followed by a blank line. All lines are
        newline-joined. Returns an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    sections = [
        "// ============ 缓冲区辅助函数 ============",
        "// 自动生成缓冲区数据管理的便利函数",
        "",
    ]
    for buf in reflection.buffers:
        generated = (
            _generate_create_buffer_function(buf, shader_name),
            _generate_upload_function(buf),
            _generate_download_function(buf),
        )
        for part in generated:
            sections.extend(part)
            sections.append("")
    return "\n".join(sections)
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the buffer-creation helper for *buffer*.

    Uses ``buffer_helpers/create_buffer.jinja2`` with ``buffer`` and
    ``shader_name`` in the template context.

    Returns:
        The rendered code split into lines.
    """
    renderer = get_renderer()
    template_vars = {'buffer': buffer, 'shader_name': shader_name}
    rendered = renderer.render('buffer_helpers/create_buffer.jinja2', template_vars)
    return rendered.splitlines()
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
    """Render the helper that uploads data into *buffer*.

    Uses ``buffer_helpers/upload.jinja2`` with ``buffer`` in the context.

    Returns:
        The rendered code split into lines.
    """
    renderer = get_renderer()
    rendered = renderer.render('buffer_helpers/upload.jinja2', {'buffer': buffer})
    return rendered.splitlines()
def _generate_download_function(buffer: BufferInfo) -> List[str]:
    """Render the helper that downloads data from *buffer*.

    Uses ``buffer_helpers/download.jinja2`` with ``buffer`` in the context.

    Returns:
        The rendered code split into lines.
    """
    renderer = get_renderer()
    rendered = renderer.render('buffer_helpers/download.jinja2', {'buffer': buffer})
    return rendered.splitlines()
# ============ Typed Buffer Aliases Generation ============
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render type aliases and factory functions for every buffer.

    Args:
        reflection: Reflection data whose ``buffers`` are processed.
        shader_name: Forwarded to the per-buffer alias/factory template.

    Returns:
        A banner comment followed by each buffer's alias/factory code, each
        followed by a blank line, all newline-joined. Returns an empty string
        when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    out = [
        "// ============ Typed Buffer Aliases ============",
        "// Type aliases and factory functions for type-safe buffer wrappers",
        "",
    ]
    for buf in reflection.buffers:
        out += _generate_buffer_type_alias_and_factory(buf, shader_name)
        out.append("")
    return "\n".join(out)
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the type alias and factory function for one buffer.

    Uses ``typed_buffer/alias_and_factory.jinja2`` with ``buffer`` and
    ``shader_name`` in the context.

    Returns:
        The rendered code split into lines.
    """
    renderer = get_renderer()
    rendered = renderer.render(
        'typed_buffer/alias_and_factory.jinja2',
        {'buffer': buffer, 'shader_name': shader_name},
    )
    return rendered.splitlines()
# ============ Buffer Manager Class Generation ============
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the buffer manager class for all buffers of a shader.

    The class is named ``<shader_name>_buffer_manager``. The template is told
    whether ``reflection.shader_stage`` equals ``"compute"``.

    Args:
        reflection: Reflection data providing ``buffers`` and ``shader_stage``.
        shader_name: Prefix for the generated class name.

    Returns:
        A banner comment followed by the rendered class, newline-joined.
        Returns an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    manager_name = f"{shader_name}_buffer_manager"
    compute_stage = reflection.shader_stage == "compute"

    renderer = get_renderer()
    class_code = renderer.render(
        'buffer_manager/manager_class.jinja2',
        {
            'class_name': manager_name,
            'buffers': reflection.buffers,
            'is_compute': compute_stage,
        },
    )
    return "\n".join([
        "// ============ Buffer Manager Class ============",
        "// Automatic lifecycle management for all shader buffers",
        "",
        class_code,
    ])
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
"""生成 buffers 结构体(已被模板替代,保留用于兼容性)"""
lines = []
lines.append(" // Container for all buffers")
lines.append(" struct buffers {")
for buffer in buffers:
var_name = buffer.variable_name
alias_name = f"{var_name}_buffer"
lines.append(f" {alias_name} {var_name};")
lines.append(" };")
return lines
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
"""生成管理器的工厂方法(已被模板替代,保留用于兼容性)"""
# 此函数已被模板替代,保留签名以防外部调用
return []
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
"""生成初始化方法(已被模板替代,保留用于兼容性)"""
# 此函数已被模板替代,保留签名以防外部调用
return []
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
"""生成绑定到 descriptor set 的方法(已被模板替代,保留用于兼容性)"""
# 此函数已被模板替代,保留签名以防外部调用
return []
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
"""生成管理器的私有构造函数(已被模板替代,保留用于兼容性)"""
# 此函数已被模板替代,保留签名以防外部调用
return []
# ============ Vulkan Type Conversion ============
# snake_case descriptor type names (mirroring the VK_DESCRIPTOR_TYPE_* suffixes)
# mapped to their Vulkan-Hpp enumerators. Built once at import time.
_VK_DESCRIPTOR_TYPES = {
    "sampler": "vk::DescriptorType::eSampler",
    "combined_image_sampler": "vk::DescriptorType::eCombinedImageSampler",
    "sampled_image": "vk::DescriptorType::eSampledImage",
    "storage_image": "vk::DescriptorType::eStorageImage",
    "uniform_texel_buffer": "vk::DescriptorType::eUniformTexelBuffer",
    "storage_texel_buffer": "vk::DescriptorType::eStorageTexelBuffer",
    "uniform_buffer": "vk::DescriptorType::eUniformBuffer",
    "storage_buffer": "vk::DescriptorType::eStorageBuffer",
    "uniform_buffer_dynamic": "vk::DescriptorType::eUniformBufferDynamic",
    "storage_buffer_dynamic": "vk::DescriptorType::eStorageBufferDynamic",
    "input_attachment": "vk::DescriptorType::eInputAttachment",
}


def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Convert a descriptor type name into a Vulkan-Hpp enumerator expression.

    Args:
        descriptor_type: snake_case descriptor type name, e.g.
            ``"uniform_buffer"`` or ``"combined_image_sampler"``.

    Returns:
        The fully qualified ``vk::DescriptorType`` enumerator as C++ source
        text. Unknown names fall back to ``vk::DescriptorType::eStorageBuffer``,
        which preserves the historical behavior for callers relying on it.
    """
    return _VK_DESCRIPTOR_TYPES.get(descriptor_type, "vk::DescriptorType::eStorageBuffer")
# Stage names (long names plus the short glslc-style names such as "vert"/"frag"
# that this module uses for compilation results) mapped to Vulkan-Hpp flag bits.
_VK_SHADER_STAGES = {
    "vertex": "vk::ShaderStageFlagBits::eVertex",
    "vert": "vk::ShaderStageFlagBits::eVertex",
    "fragment": "vk::ShaderStageFlagBits::eFragment",
    "frag": "vk::ShaderStageFlagBits::eFragment",
    "compute": "vk::ShaderStageFlagBits::eCompute",
    "comp": "vk::ShaderStageFlagBits::eCompute",
    "geometry": "vk::ShaderStageFlagBits::eGeometry",
    "geom": "vk::ShaderStageFlagBits::eGeometry",
    "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
    "tesc": "vk::ShaderStageFlagBits::eTessellationControl",
    "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
    "tese": "vk::ShaderStageFlagBits::eTessellationEvaluation",
}


def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert a list of shader stage names into a Vulkan-Hpp flags expression.

    Long names (``"vertex"``, ``"fragment"``, ``"tess_control"``, ...) and
    the short names ``"vert"``, ``"frag"``, ``"comp"``, ``"geom"``,
    ``"tesc"`` and ``"tese"`` are accepted. Unrecognized names are ignored.
    A stage requested more than once is emitted once, in first-seen order.

    Args:
        stages: Stage names; may be empty or ``None``.

    Returns:
        A ``" | "``-joined expression of ``vk::ShaderStageFlagBits`` values.
        Falls back to ``vk::ShaderStageFlagBits::eCompute`` when no stage is
        recognized, including an empty or ``None`` list.
    """
    # dict.fromkeys deduplicates while preserving first-occurrence order.
    flags = dict.fromkeys(
        _VK_SHADER_STAGES[stage] for stage in (stages or ()) if stage in _VK_SHADER_STAGES
    )
    return " | ".join(flags) or "vk::ShaderStageFlagBits::eCompute"
# ============ Header File Generation ============
def _to_pascal_case(name: str) -> str:
"""将 snake_case 转换为 PascalCase"""
return ''.join(word.capitalize() for word in name.split('_'))
def _effect_struct_type_name(shader_name: str) -> str:
    """Return the C++ name of the push-constant effect struct for *shader_name*.

    Uses the same ``shader_name.capitalize()`` + ``"PushConstantEffect"``
    formula as ``generate_push_constant_effect_struct``. The two must agree,
    or the header would reference an undeclared type.
    """
    return f"{shader_name.capitalize()}PushConstantEffect"


def _spirv_symbol_name(shader_name: str, shader_type: str) -> str:
    """Return the C++ identifier of the SPIR-V array for one shader stage."""
    return f"{shader_name}_{shader_type}_spirv"


def _render_spirv_arrays(
    shader_name: str,
    compilation_results: List[CompilationResult],
) -> tuple[List[dict], Dict[str, str]]:
    """Render a SPIR-V array for every compilation result.

    Returns:
        A tuple ``(entries, symbols)``.

        - ``entries`` holds one template dict per result (``shader_type``,
          ``source_file``, ``spirv_array``), in input order.
        - ``symbols`` maps each shader type to its array symbol. A later
          result overwrites an earlier one with the same type.
    """
    entries = []
    symbols: Dict[str, str] = {}
    for result in compilation_results:
        symbol_name = _spirv_symbol_name(shader_name, result.shader_type)
        entries.append({
            'shader_type': result.shader_type,
            'source_file': result.source_file,
            'spirv_array': spirv_to_cpp_array(result.spirv_data, symbol_name),
        })
        symbols[result.shader_type] = symbol_name
    return entries, symbols


def _binding_entries(bindings) -> List[dict]:
    """Convert descriptor bindings into template dicts.

    Returns ``[]`` for ``None`` or an empty sequence.
    """
    return [
        {
            'binding': binding.binding,
            'descriptor_type_cpp': vk_descriptor_type_to_cpp(binding.descriptor_type),
            'count': binding.count,
            'stage_flags': vk_shader_stage_to_cpp(binding.stages),
        }
        for binding in (bindings or ())
    ]


def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Generate the complete C++ header for one shader.

    Sections included, each only when applicable:

    - buffer structs, when the reflection has buffers;
    - buffer helpers, typed-buffer aliases and the buffer manager, when
      enabled in *metadata* and the reflection has buffers;
    - the push-constant effect struct and base layout, when the reflection's
      push constant block has those parts.

    One SPIR-V array is always rendered per compilation result. The result
    is rendered through ``base/header.jinja2``.

    Args:
        metadata: Shader metadata (name, namespace, shader type, bindings,
            feature switches and optional reflection).
        compilation_results: Compiled stages. The first one supplies the entry
            point and the primary SPIR-V symbol name.

    Returns:
        The rendered header text.
    """
    reflection = metadata.reflection
    shader_name = metadata.name
    # Always a real bool, so the template never receives the raw buffer list or None.
    has_buffers = bool(reflection and reflection.buffers)
    push_constant = reflection.push_constant if reflection else None
    has_effect_members = bool(push_constant and push_constant.effect_members)

    buffer_structures = generate_buffer_structures(reflection) if has_buffers else ""
    buffer_helpers = ""
    if metadata.generate_buffer_helpers and has_buffers:
        buffer_helpers = generate_buffer_helper_functions(reflection, shader_name)
    typed_buffers = ""
    if metadata.generate_typed_buffers and has_buffers:
        typed_buffers = generate_typed_buffer_aliases(reflection, shader_name)
    buffer_manager = ""
    if metadata.generate_buffer_manager and has_buffers:
        buffer_manager = generate_buffer_manager_class(reflection, shader_name)

    push_constant_effect_struct = ""
    push_constant_effect_type_name = "void"
    if has_effect_members:
        push_constant_effect_struct = generate_push_constant_effect_struct(reflection, shader_name)
        push_constant_effect_type_name = _effect_struct_type_name(shader_name)

    push_constant_base_layout = ""
    if push_constant and push_constant.base_info:
        push_constant_base_layout = generate_push_constant_base_layout(reflection, shader_name)

    compiled_stages, stage_symbols = _render_spirv_arrays(shader_name, compilation_results)
    bindings_data = _binding_entries(metadata.bindings)

    # Parameter type priority: first buffer's struct > push-constant effect struct > void.
    if has_buffers:
        params_type_name = reflection.buffers[0].struct_type.name or "void"
    elif has_effect_members:
        params_type_name = _effect_struct_type_name(shader_name)
    else:
        params_type_name = "void"

    first_result = compilation_results[0] if compilation_results else None

    context = {
        'namespace': metadata.namespace,
        'shader_name': shader_name,
        'metadata_struct_name': f"{shader_name}_metadata",
        'bindings_name': f"{shader_name}_bindings",
        'shader_spec_name': f"{shader_name}_shader_spec",
        'params_type_name': params_type_name,
        'entry_point': first_result.entry_point if first_result is not None else "main",
        'is_compute': metadata.shader_type == "compute",
        'spirv_symbol_name': (
            _spirv_symbol_name(shader_name, first_result.shader_type)
            if first_result is not None else ""
        ),
        'has_buffers': has_buffers,
        'generate_typed_buffers': metadata.generate_typed_buffers,
        'buffer_structures': buffer_structures,
        'buffer_helpers': buffer_helpers,
        'typed_buffers': typed_buffers,
        'buffer_manager': buffer_manager,
        'compilation_results': compiled_stages,
        'bindings': bindings_data,
        # shader_spec related variables
        'has_vert_shader': "vert" in stage_symbols,
        'has_frag_shader': "frag" in stage_symbols,
        'vert_spirv_symbol': stage_symbols.get("vert", ""),
        'frag_spirv_symbol': stage_symbols.get("frag", ""),
        # Push-constant effect struct variables
        'push_constant_effect_struct': push_constant_effect_struct,
        'push_constant_effect_type_name': push_constant_effect_type_name,
        'has_push_constant_effect': bool(push_constant_effect_struct),
        # Push-constant base layout variables
        'push_constant_base_layout': push_constant_base_layout,
        'has_push_constant_base_layout': bool(push_constant_base_layout),
    }
    return get_renderer().render('base/header.jinja2', context)