516 lines
18 KiB
Python
516 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
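Code generation module.

Generates C++ code, including struct definitions, SPIR-V arrays, header files,
and binding information.
Uses the Jinja2 template engine instead of string concatenation.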
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import struct
|
||
from typing import Dict, List
|
||
|
||
from .template_loader import get_renderer
|
||
from .type_mapping import (
|
||
calculate_std430_alignment,
|
||
calculate_std430_size,
|
||
spirv_type_to_cpp,
|
||
)
|
||
from .types import (
|
||
ArrayTypeInfo,
|
||
BufferInfo,
|
||
CompilationResult,
|
||
MemberInfo,
|
||
PushConstantInfo,
|
||
ShaderMetadata,
|
||
SPIRVReflection,
|
||
ToolError,
|
||
TypeInfo,
|
||
)
|
||
|
||
|
||
# ============ Structure Generation ============
|
||
|
||
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Build the C++ struct definitions for every reflected buffer/uniform.

    Args:
        reflection: SPIR-V reflection data; its ``buffers`` and ``types``
            attributes are read.

    Returns:
        A two-line comment banner, a blank line, then the lines from
        ``generate_single_struct`` for each buffer (each followed by a blank
        line), joined with newlines. An empty string when there are no
        buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    output = [
        "// ============ Buffer Structures ============",
        "// Auto-generated from SPIR-V reflection",
        "",
    ]

    for buffer_info in buffers:
        output += generate_single_struct(buffer_info, reflection.types)
        # Blank separator after each struct.
        output.append("")

    return "\n".join(output)
|
||
|
||
|
||
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render one buffer's struct definition.

    Args:
        buffer: Buffer description; ``buffer.struct_type`` is passed to
            ``calculate_std430_alignment`` and the whole object is passed to
            the template.
        type_map: Type table forwarded to the alignment helper and the
            template.

    Returns:
        The output of the ``base/struct.jinja2`` template, split into lines.
    """
    template_renderer = get_renderer()
    struct_alignment = calculate_std430_alignment(buffer.struct_type, type_map)

    rendered = template_renderer.render(
        'base/struct.jinja2',
        {
            'buffer': buffer,
            'type_map': type_map,
            'alignment': struct_alignment,
        },
    )
    return rendered.splitlines()
|
||
|
||
|
||
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render the definition of a single struct member.

    Args:
        member: Member description passed to the template as ``member``.
        type_map: Type table passed to the template as ``type_map``.

    Returns:
        The output of the ``base/struct_member.jinja2`` template, split into
        lines.
    """
    rendered = get_renderer().render(
        'base/struct_member.jinja2',
        {'member': member, 'type_map': type_map},
    )
    return rendered.splitlines()
|
||
|
||
|
||
# ============ Push Constants Effect Structure Generation ============
|
||
|
||
def generate_push_constant_effect_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the C++ struct for the effect part of the push constants.

    Args:
        reflection: SPIR-V reflection data; ``push_constant`` and ``types``
            are read.
        shader_name: Shader name; ``shader_name.capitalize()`` prefixes the
            generated struct name.

    Returns:
        The output of the ``push_constant/effect_struct.jinja2`` template, or
        an empty string when there is no push constant block or it has no
        effect members.
    """
    pc_block = reflection.push_constant
    if not (pc_block and pc_block.effect_members):
        return ""

    effect_struct_name = f"{shader_name.capitalize()}PushConstantEffect"
    types_by_id = reflection.types

    def describe(member):
        # Helpers are called in the order alignment, size, C++ type.
        align = calculate_std430_alignment(member.resolved_type, types_by_id)
        byte_size = calculate_std430_size(member.resolved_type, types_by_id)
        cpp_name = spirv_type_to_cpp(member.resolved_type, types_by_id)
        return {
            'name': member.name,
            'cpp_type': cpp_name,
            'offset': member.offset,
            'size': byte_size,
            'alignment': align,
        }

    member_rows = [describe(entry) for entry in pc_block.effect_members]

    return get_renderer().render(
        'push_constant/effect_struct.jinja2',
        {
            'struct_name': effect_struct_name,
            'base_offset': pc_block.base_offset,
            'members': member_rows,
        },
    )
|
||
|
||
|
||
def generate_push_constant_base_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the base push-constant layout struct and its static checks.

    Args:
        reflection: SPIR-V reflection data; ``push_constant.base_info`` is
            read.
        shader_name: Shader name, passed to both templates as
            ``struct_name``.

    Returns:
        The output of ``push_constant/base_layout.jinja2``, a blank line,
        then the output of ``push_constant/static_check.jinja2``. An empty
        string when there is no push constant block or no base info.
    """
    pc_block = reflection.push_constant
    if not (pc_block and pc_block.base_info):
        return ""

    base = pc_block.base_info

    # Plain-dict view of each base member for the templates.
    member_rows = []
    for entry in base.members:
        member_rows.append({
            'name': entry.name,
            'offset': entry.offset,
            'size': entry.size,
            'expected_offset': entry.expected_offset,
            'expected_size': entry.expected_size,
            'is_valid': entry.is_valid,
        })

    expected = ['projection', 'viewport_size', 'time', '_pad0']
    present = [entry.name for entry in base.members]

    template_renderer = get_renderer()

    # Layout struct.
    layout_part = template_renderer.render(
        'push_constant/base_layout.jinja2',
        {
            'struct_name': shader_name,
            'base_members': member_rows,
            'expected_names': expected,
            'present_names': present,
            'total_size': base.total_size,
            'is_standard': base.is_standard_layout,
        },
    )

    # Static check code.
    check_part = template_renderer.render(
        'push_constant/static_check.jinja2',
        {
            'struct_name': shader_name,
            'base_members': member_rows,
            'is_standard': base.is_standard_layout,
        },
    )

    return layout_part + "\n\n" + check_part
|
||
|
||
|
||
# ============ SPIR-V Array Generation ============
|
||
|
||
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str) -> str:
    """Convert a SPIR-V binary into a C++ ``uint32_t`` array definition.

    The bytes are decoded as little-endian 32-bit unsigned words, formatted
    as ``0x%08x`` literals, and grouped twelve per row for the template.

    Args:
        spirv_data: Raw SPIR-V bytes; the length must be a multiple of 4.
        symbol_name: Passed to the template as ``symbol_name``.

    Returns:
        The output of the ``base/spirv_array.jinja2`` template.

    Raises:
        ToolError: If ``len(spirv_data)`` is not a multiple of 4.
    """
    byte_count = len(spirv_data)
    if byte_count % 4 != 0:
        raise ToolError(f"SPIR-V data size {byte_count} is not a multiple of 4")

    words = struct.unpack(f"<{byte_count // 4}I", spirv_data)

    # Twelve hex literals per row.
    words_per_row = 12
    rows = [
        [f"0x{word:08x}" for word in words[start:start + words_per_row]]
        for start in range(0, len(words), words_per_row)
    ]

    return get_renderer().render(
        'base/spirv_array.jinja2',
        {'symbol_name': symbol_name, 'chunks': rows},
    )
|
||
|
||
|
||
# ============ Buffer Helper Function Generation ============
|
||
|
||
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Build the helper functions that operate on buffer data.

    For each buffer, the create, upload and download helper sections are
    emitted in that order, each followed by a blank line.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` is read.
        shader_name: Shader name, forwarded to the create-buffer generator.

    Returns:
        The banner comment plus all helper sections joined with newlines, or
        an empty string when there are no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    output = [
        "// ============ 缓冲区辅助函数 ============",
        "// 自动生成缓冲区数据管理的便利函数",
        "",
    ]

    for buffer_info in buffers:
        sections = (
            _generate_create_buffer_function(buffer_info, shader_name),  # create
            _generate_upload_function(buffer_info),  # upload
            _generate_download_function(buffer_info),  # download
        )
        for section in sections:
            output.extend(section)
            output.append("")

    return "\n".join(output)
|
||
|
||
|
||
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the convenience function that creates a buffer.

    Args:
        buffer: Buffer description, passed to the template as ``buffer``.
        shader_name: Shader name, passed to the template as ``shader_name``.

    Returns:
        The output of ``buffer_helpers/create_buffer.jinja2``, split into
        lines.
    """
    rendered = get_renderer().render(
        'buffer_helpers/create_buffer.jinja2',
        {'buffer': buffer, 'shader_name': shader_name},
    )
    return rendered.splitlines()
|
||
|
||
|
||
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
    """Render the function that uploads data to a buffer.

    Args:
        buffer: Buffer description, passed to the template as ``buffer``.

    Returns:
        The output of ``buffer_helpers/upload.jinja2``, split into lines.
    """
    rendered = get_renderer().render('buffer_helpers/upload.jinja2', {'buffer': buffer})
    return rendered.splitlines()
|
||
|
||
|
||
def _generate_download_function(buffer: BufferInfo) -> List[str]:
    """Render the function that downloads data from a buffer.

    Args:
        buffer: Buffer description, passed to the template as ``buffer``.

    Returns:
        The output of ``buffer_helpers/download.jinja2``, split into lines.
    """
    rendered = get_renderer().render('buffer_helpers/download.jinja2', {'buffer': buffer})
    return rendered.splitlines()
|
||
|
||
|
||
# ============ Typed Buffer Aliases Generation ============
|
||
|
||
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Build the type-safe buffer aliases and factory functions.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` is read.
        shader_name: Shader name, forwarded to the per-buffer generator.

    Returns:
        The banner comment followed by each buffer's alias/factory lines
        (each followed by a blank line), joined with newlines. An empty
        string when there are no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    output = [
        "// ============ Typed Buffer Aliases ============",
        "// Type aliases and factory functions for type-safe buffer wrappers",
        "",
    ]

    for buffer_info in buffers:
        output += _generate_buffer_type_alias_and_factory(buffer_info, shader_name)
        output.append("")

    return "\n".join(output)
|
||
|
||
|
||
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the type alias and factory function for one buffer.

    Args:
        buffer: Buffer description, passed to the template as ``buffer``.
        shader_name: Shader name, passed to the template as ``shader_name``.

    Returns:
        The output of ``typed_buffer/alias_and_factory.jinja2``, split into
        lines.
    """
    rendered = get_renderer().render(
        'typed_buffer/alias_and_factory.jinja2',
        {'buffer': buffer, 'shader_name': shader_name},
    )
    return rendered.splitlines()
|
||
|
||
|
||
# ============ Buffer Manager Class Generation ============
|
||
|
||
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Build the buffer manager class for a shader.

    Args:
        reflection: SPIR-V reflection data; ``buffers`` and ``shader_stage``
            are read.
        shader_name: Shader name; the class is named
            ``{shader_name}_buffer_manager``.

    Returns:
        The banner comment followed by the output of
        ``buffer_manager/manager_class.jinja2``, or an empty string when
        there are no buffers.
    """
    if not reflection.buffers:
        return ""

    manager_name = f"{shader_name}_buffer_manager"

    # The template is told whether this is a compute shader.
    compute_stage = reflection.shader_stage == "compute"

    class_body = get_renderer().render(
        'buffer_manager/manager_class.jinja2',
        {
            'class_name': manager_name,
            'buffers': reflection.buffers,
            'is_compute': compute_stage,
        },
    )

    return "\n".join([
        "// ============ Buffer Manager Class ============",
        "// Automatic lifecycle management for all shader buffers",
        "",
        class_body,
    ])
|
||
|
||
|
||
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成 buffers 结构体(已被模板替代,保留用于兼容性)"""
|
||
lines = []
|
||
lines.append(" // Container for all buffers")
|
||
lines.append(" struct buffers {")
|
||
for buffer in buffers:
|
||
var_name = buffer.variable_name
|
||
alias_name = f"{var_name}_buffer"
|
||
lines.append(f" {alias_name} {var_name};")
|
||
lines.append(" };")
|
||
return lines
|
||
|
||
|
||
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
|
||
"""生成管理器的工厂方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成初始化方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成绑定到 descriptor set 的方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
|
||
"""生成管理器的私有构造函数(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
# ============ Vulkan Type Conversion ============
|
||
|
||
def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a descriptor type name to its Vulkan C++ enum spelling.

    Args:
        descriptor_type: One of ``storage_buffer``, ``uniform_buffer``,
            ``sampled_image``, ``storage_image`` or ``sampler``.

    Returns:
        The matching ``vk::DescriptorType::e...`` string. Any other name
        falls back to ``vk::DescriptorType::eStorageBuffer``.
    """
    known_types = {
        "storage_buffer": "vk::DescriptorType::eStorageBuffer",
        "uniform_buffer": "vk::DescriptorType::eUniformBuffer",
        "sampled_image": "vk::DescriptorType::eSampledImage",
        "storage_image": "vk::DescriptorType::eStorageImage",
        "sampler": "vk::DescriptorType::eSampler",
    }
    try:
        return known_types[descriptor_type]
    except KeyError:
        # Unknown names are treated as storage buffers.
        return "vk::DescriptorType::eStorageBuffer"
|
||
|
||
|
||
def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert a list of shader stage names into a Vulkan C++ flags expression.

    Args:
        stages: Stage names (``vertex``, ``fragment``, ``compute``,
            ``geometry``, ``tess_control``, ``tess_eval``). Unrecognised
            names are skipped.

    Returns:
        The matching ``vk::ShaderStageFlagBits::e...`` values joined with
        ``" | "`` in input order. When ``stages`` is empty or contains no
        recognised name, the compute flag alone.
    """
    compute_flag = "vk::ShaderStageFlagBits::eCompute"
    cpp_flags = {
        "vertex": "vk::ShaderStageFlagBits::eVertex",
        "fragment": "vk::ShaderStageFlagBits::eFragment",
        "compute": compute_flag,
        "geometry": "vk::ShaderStageFlagBits::eGeometry",
        "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
        "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
    }

    resolved = []
    for stage in stages or ():
        flag = cpp_flags.get(stage)
        if flag is not None:
            resolved.append(flag)

    # Nothing given, or nothing recognised: default to compute.
    return " | ".join(resolved) if resolved else compute_flag
|
||
|
||
|
||
# ============ Header File Generation ============
|
||
|
||
def _to_pascal_case(name: str) -> str:
|
||
"""将 snake_case 转换为 PascalCase"""
|
||
return ''.join(word.capitalize() for word in name.split('_'))
|
||
|
||
|
||
def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Render the complete C++ header file for one shader.

    Collects every generated section (buffer structs, buffer helpers, typed
    buffer aliases, buffer manager, push-constant effect struct and base
    layout, SPIR-V arrays, binding descriptions) and passes them to the
    ``base/header.jinja2`` template.

    Args:
        metadata: Shader description. The fields read are ``name``,
            ``namespace``, ``shader_type``, ``reflection``, ``bindings`` and
            the ``generate_buffer_helpers`` / ``generate_typed_buffers`` /
            ``generate_buffer_manager`` switches.
        compilation_results: Compiled stages. Each contributes one SPIR-V
            array named ``{name}_{shader_type}_spirv``; the first entry
            supplies the entry point and the primary SPIR-V symbol.

    Returns:
        The rendered header text.

    Raises:
        ToolError: Propagated from ``spirv_to_cpp_array`` when a result's
            ``spirv_data`` length is not a multiple of 4.
    """
    reflection = metadata.reflection
    shader_name = metadata.name

    # Single source of truth for "reflection found buffers". Coerced to bool
    # so the template receives a real flag, like the other has_* values,
    # instead of None or the buffer list itself.
    has_buffers = bool(reflection and reflection.buffers)

    # Buffer struct definitions.
    buffer_structures = generate_buffer_structures(reflection) if has_buffers else ""

    # Buffer helper functions (opt-in).
    buffer_helpers = ""
    if metadata.generate_buffer_helpers and has_buffers:
        buffer_helpers = generate_buffer_helper_functions(reflection, shader_name)

    # Typed buffer aliases (opt-in).
    typed_buffers = ""
    if metadata.generate_typed_buffers and has_buffers:
        typed_buffers = generate_typed_buffer_aliases(reflection, shader_name)

    # Buffer manager class (opt-in).
    buffer_manager = ""
    if metadata.generate_buffer_manager and has_buffers:
        buffer_manager = generate_buffer_manager_class(reflection, shader_name)

    push_constant = reflection.push_constant if reflection else None

    # Push-constant effect struct. The type name is built once here and
    # reused for params_type_name below.
    has_effect_members = bool(push_constant and push_constant.effect_members)
    push_constant_effect_struct = ""
    push_constant_effect_type_name = "void"
    effect_type_name = ""
    if has_effect_members:
        effect_type_name = f"{shader_name.capitalize()}PushConstantEffect"
        push_constant_effect_struct = generate_push_constant_effect_struct(reflection, shader_name)
        push_constant_effect_type_name = effect_type_name

    # Push-constant base layout.
    push_constant_base_layout = ""
    if push_constant and push_constant.base_info:
        push_constant_base_layout = generate_push_constant_base_layout(reflection, shader_name)
    has_push_constant_base_layout = bool(push_constant_base_layout)

    # SPIR-V arrays, plus the symbol emitted for each stage type.
    compilation_results_with_arrays = []
    stage_symbols = {}
    for result in compilation_results:
        symbol_name = f"{shader_name}_{result.shader_type}_spirv"
        spirv_array = spirv_to_cpp_array(result.spirv_data, symbol_name)
        compilation_results_with_arrays.append({
            'shader_type': result.shader_type,
            'source_file': result.source_file,
            'spirv_array': spirv_array,
        })
        stage_symbols[result.shader_type] = symbol_name

    # Binding descriptions.
    bindings_data = [
        {
            'binding': binding.binding,
            'descriptor_type_cpp': vk_descriptor_type_to_cpp(binding.descriptor_type),
            'count': binding.count,
            'stage_flags': vk_shader_stage_to_cpp(binding.stages),
        }
        for binding in (metadata.bindings or ())
    ]

    # Parameter type name. Priority: first buffer's struct type > push
    # constant effect struct > void.
    params_type_name = "void"
    if has_buffers:
        params_type_name = reflection.buffers[0].struct_type.name or params_type_name
    elif has_effect_members:
        params_type_name = effect_type_name

    # The first compilation result defines the entry point and primary symbol.
    if compilation_results:
        primary = compilation_results[0]
        entry_point = primary.entry_point
        spirv_symbol_name = f"{shader_name}_{primary.shader_type}_spirv"
    else:
        entry_point = "main"
        spirv_symbol_name = ""

    context = {
        'namespace': metadata.namespace,
        'shader_name': shader_name,
        'metadata_struct_name': f"{shader_name}_metadata",
        'bindings_name': f"{shader_name}_bindings",
        'shader_spec_name': f"{shader_name}_shader_spec",
        'params_type_name': params_type_name,
        'entry_point': entry_point,
        'is_compute': metadata.shader_type == "compute",
        'spirv_symbol_name': spirv_symbol_name,
        'has_buffers': has_buffers,
        'generate_typed_buffers': metadata.generate_typed_buffers,
        'buffer_structures': buffer_structures,
        'buffer_helpers': buffer_helpers,
        'typed_buffers': typed_buffers,
        'buffer_manager': buffer_manager,
        'compilation_results': compilation_results_with_arrays,
        'bindings': bindings_data,
        # shader_spec variables
        'has_vert_shader': "vert" in stage_symbols,
        'has_frag_shader': "frag" in stage_symbols,
        'vert_spirv_symbol': stage_symbols.get("vert", ""),
        'frag_spirv_symbol': stage_symbols.get("frag", ""),
        # Push-constant effect struct variables
        'push_constant_effect_struct': push_constant_effect_struct,
        'push_constant_effect_type_name': push_constant_effect_type_name,
        'has_push_constant_effect': bool(push_constant_effect_struct),
        # Push-constant base layout variables
        'push_constant_base_layout': push_constant_base_layout,
        'has_push_constant_base_layout': has_push_constant_base_layout,
    }

    return get_renderer().render('base/header.jinja2', context)