Files
mirage/tools/code_generator.py
daiqingshuang 5a8d62f841 Refactor Push Constants and Add Dual Stage Support
- Removed legacy push constant structures and functions for better clarity and maintainability.
- Introduced new `text_push_constants_t` structure for text rendering with optimized layout.
- Implemented dual stage push constant analysis to support separate layouts for vertex and fragment shaders.
- Added functions to generate push constant structures and fill functions based on shader reflection.
- Enhanced static checks for push constant layouts to ensure compatibility and correctness.
- Updated templates to accommodate new dual stage push constant generation.
- Added support detection for procedural vertex shaders based on push constant layout.
2025-12-25 21:04:39 +08:00

1149 lines
43 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
代码生成模块
负责生成C++代码包括结构体定义、SPIR-V数组、头文件和绑定信息。
使用Jinja2模板引擎代替字符串拼接。
"""
from __future__ import annotations
import struct
from typing import Dict, List
from .template_loader import get_renderer
from .type_mapping import (
calculate_std430_alignment,
calculate_std430_size,
spirv_type_to_cpp,
)
from .types import (
ArrayTypeInfo,
BufferInfo,
CombinedPushConstantInfo,
CompilationResult,
InstanceBufferInfo,
MemberInfo,
PushConstantInfo,
ShaderMetadata,
SPIRVReflection,
StagePushConstantInfo,
ToolError,
TypeInfo,
VertexAttribute,
VertexLayout,
)
from .spirv_parser import analyze_dual_stage_push_constants
from .constants import (
PUSH_CONSTANT_HEADER_SIZE,
PUSH_CONSTANT_CUSTOM_OFFSET,
PUSH_CONSTANT_CUSTOM_MAX_SIZE,
)
def check_supports_procedural_vertex_shader(reflection: SPIRVReflection) -> bool:
    """Tell whether a fragment shader can be paired with the procedural quad vertex shader.

    The ``custom_shader_quad_procedural`` vertex shader owns the push-constant
    Header (bytes 0-15: ``scale`` + ``translate``). The fragment shader is
    compatible when its push-constant block contains a member located exactly
    at ``PUSH_CONSTANT_CUSTOM_OFFSET``. That means the first member at or after
    the custom offset starts right where the Header ends, and the user data
    region begins at byte 16.

    Args:
        reflection: SPIR-V reflection data (normally of the fragment shader).

    Returns:
        True if the procedural vertex shader is supported, otherwise False.
    """
    push_constant = reflection.push_constant
    struct_type = push_constant.struct_type if push_constant else None
    if not struct_type or not struct_type.members:
        return False
    # The smallest offset among members at/after the custom offset equals the
    # custom offset exactly when some member sits right on it.
    return any(
        member.offset == PUSH_CONSTANT_CUSTOM_OFFSET
        for member in struct_type.members
    )
# ============ Structure Generation ============
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Render the C++ struct definitions for every buffer/uniform block.

    Each rendered struct is followed by one blank separator line.

    Args:
        reflection: SPIR-V reflection data providing ``buffers`` and ``types``.

    Returns:
        The concatenated struct definitions, or an empty string when the
        reflection lists no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""
    output_lines: List[str] = []
    for buffer in buffers:
        output_lines += generate_single_struct(buffer, reflection.types)
        output_lines.append("")
    return "\n".join(output_lines)
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render one buffer's C++ struct definition.

    The struct's std430 alignment is computed here and handed to the
    ``base/struct.jinja2`` template together with the buffer and type table.

    Args:
        buffer: Reflected buffer whose struct type is rendered.
        type_map: Reflection type table (``Dict[int, TypeInfo]``).

    Returns:
        The rendered struct, split into lines.
    """
    template_renderer = get_renderer()
    struct_alignment = calculate_std430_alignment(buffer.struct_type, type_map)
    rendered = template_renderer.render(
        'base/struct.jinja2',
        dict(buffer=buffer, type_map=type_map, alignment=struct_alignment),
    )
    return rendered.splitlines()
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render the C++ declaration of a single struct member.

    Args:
        member: Reflected struct member to declare.
        type_map: Reflection type table (``Dict[int, TypeInfo]``).

    Returns:
        The rendered ``base/struct_member.jinja2`` output, split into lines.
    """
    template_vars = dict(member=member, type_map=type_map)
    return get_renderer().render('base/struct_member.jinja2', template_vars).splitlines()
# ============ Push Constants Custom Structure Generation ============
def generate_push_constant_custom_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the C++ struct for the user-defined (Custom) push-constant area.

    Push-constant memory layout (128 bytes)::

        Header [0-15]:   scale + translate  - 16 bytes   (filled by the system)
        Custom [16-127]: user-defined data  - 112 bytes  (generated here)

    Only the members stored in ``push_constant.effect_members`` are emitted.
    That field now holds the Custom members, i.e. the ones past the Header.
    ``PUSH_CONSTANT_CUSTOM_OFFSET`` is passed to the template as the base
    offset, so generated offsets can be expressed relative to the Custom area.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name; the struct is named
            ``{PascalName}PushConstantCustom``.

    Returns:
        The generated C++ struct, or an empty string when there are no
        Custom members.

    Note:
        Must stay in sync with PUSH_CONSTANT_CUSTOM_OFFSET (16) and
        PUSH_CONSTANT_CUSTOM_MAX_SIZE (112) in push_constant_traits.h.
    """
    push_constant = reflection.push_constant
    custom_members = push_constant.effect_members if push_constant else None
    if not custom_members:
        return ""
    type_map = reflection.types

    def describe(member: MemberInfo) -> dict:
        # Per-member template data: std430 alignment/size plus the C++ type.
        resolved = member.resolved_type
        alignment = calculate_std430_alignment(resolved, type_map)
        size = calculate_std430_size(resolved, type_map)
        return {
            'name': member.name,
            'cpp_type': spirv_type_to_cpp(resolved, type_map),
            'offset': member.offset,
            'size': size,
            'alignment': alignment,
        }

    struct_name = _to_pascal_case(shader_name) + "PushConstantCustom"
    template_vars = {
        'struct_name': struct_name,
        'base_offset': PUSH_CONSTANT_CUSTOM_OFFSET,
        'members': [describe(member) for member in custom_members],
    }
    return get_renderer().render('push_constant/custom_struct.jinja2', template_vars)
# 兼容性别名
def generate_push_constant_effect_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Deprecated alias of ``generate_push_constant_custom_struct``.

    Kept only for backward compatibility. It forwards both arguments unchanged
    and returns the Custom struct code. New code should call
    ``generate_push_constant_custom_struct`` directly.
    """
    return generate_push_constant_custom_struct(reflection=reflection, shader_name=shader_name)
def generate_push_constant_header_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the layout-info struct for the push-constant Header area.

    The Header (bytes 0-15) holds the NDC transform::

        scale:     offset 0, size 8  (vec2)
        translate: offset 8, size 8  (vec2)

    For each reflected Header member, the template receives:

    - its actual offset and size;
    - the expected offset and size;
    - whether it is valid.

    It also receives the expected member names, the names actually present,
    the Header's total size and whether it is the standard layout.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name, passed through as the template's ``struct_name``.

    Returns:
        The generated C++ code, or an empty string when there is no Header info.
    """
    push_constant = reflection.push_constant
    base_info = push_constant.base_info if push_constant else None
    if not base_info:
        return ""
    header_members = []
    for header_member in base_info.members:
        header_members.append(dict(
            name=header_member.name,
            offset=header_member.offset,
            size=header_member.size,
            expected_offset=header_member.expected_offset,
            expected_size=header_member.expected_size,
            is_valid=header_member.is_valid,
        ))
    layout_vars = {
        'struct_name': shader_name,
        'header_members': header_members,
        'expected_names': ['scale', 'translate'],
        'present_names': [entry['name'] for entry in header_members],
        'header_size': base_info.total_size,
        'is_standard': base_info.is_standard_layout,
    }
    return get_renderer().render('push_constant/base_layout.jinja2', layout_vars)
def generate_push_constant_static_check(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render compile-time checks for the push-constant Header area.

    The ``sizeof`` check of the Custom struct has to appear after that struct's
    definition, so this function covers only the Header part. It always passes
    ``has_custom_members=False`` to the template.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name, passed through as the template's ``struct_name``.

    Returns:
        The generated static-check C++ code, or an empty string when there is
        no Header info.
    """
    push_constant = reflection.push_constant
    if not push_constant or not push_constant.base_info:
        return ""
    base_info = push_constant.base_info
    header_members = [
        {
            'name': m.name,
            'offset': m.offset,
            'size': m.size,
            'expected_offset': m.expected_offset,
            'expected_size': m.expected_size,
            'is_valid': m.is_valid,
        }
        for m in base_info.members
    ]
    # The former `expected_names` / `present_names` locals were never passed to
    # the static-check template, so they are not computed here.
    check_context = {
        'struct_name': shader_name,
        'header_members': header_members,
        'is_standard': base_info.is_standard_layout,
        'has_custom_members': False,  # Custom checks are emitted after the Custom struct
    }
    return get_renderer().render('push_constant/static_check.jinja2', check_context)
# 兼容性别名
def generate_push_constant_base_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Deprecated alias of ``generate_push_constant_header_layout``.

    Kept only for backward compatibility. It forwards both arguments unchanged
    and returns the Header layout code. New code should call
    ``generate_push_constant_header_layout`` directly.
    """
    return generate_push_constant_header_layout(reflection=reflection, shader_name=shader_name)
# ============ Dual Stage Push Constants Structure Generation ============
def _generate_stage_push_constant_struct(
    stage_info: StagePushConstantInfo,
    shader_name: str,
    type_map: Dict[int, TypeInfo]
) -> str:
    """Render the push-constant struct for a single shader stage.

    The struct name depends on ``stage_info.stage``:

    - ``"vertex"`` produces ``{PascalName}VertexPushConstant``;
    - any other value produces ``{PascalName}FragmentPushConstant``.

    ``stage_info.relative_offset`` is passed to the template as the base offset.

    Args:
        stage_info: Push-constant info for one stage (may be None).
        shader_name: Shader name used to derive the struct name.
        type_map: Reflection type table used to resolve member types.

    Returns:
        The generated C++ struct, or an empty string when the stage info is
        missing or has no members.
    """
    if not (stage_info and stage_info.has_members):
        return ""
    suffix = "Vertex" if stage_info.stage == "vertex" else "Fragment"
    struct_name = _to_pascal_case(shader_name) + suffix + "PushConstant"

    def describe(member: MemberInfo) -> dict:
        # Per-member template data: std430 alignment/size plus the C++ type.
        resolved = member.resolved_type
        alignment = calculate_std430_alignment(resolved, type_map)
        size = calculate_std430_size(resolved, type_map)
        return {
            'name': member.name,
            'cpp_type': spirv_type_to_cpp(resolved, type_map),
            'offset': member.offset,
            'size': size,
            'alignment': alignment,
        }

    member_entries = [describe(member) for member in stage_info.members]
    return get_renderer().render('push_constant/custom_struct.jinja2', {
        'struct_name': struct_name,
        'base_offset': stage_info.relative_offset,
        'members': member_entries,
    })
def _fill_function_lines(
    function_name: str,
    brief: str,
    extra_doc: List[str],
    data_doc: str,
    struct_name: str,
    stage_flags: str,
) -> List[str]:
    """Return the lines of one generated ``inline void <function_name>(...)`` push helper."""
    return [
        "/**",
        f" * @brief {brief}",
        " * ",
        *extra_doc,
        " * @param cmd 命令缓冲",
        " * @param layout 管线布局",
        f" * @param data {data_doc}",
        " */",
        f"inline void {function_name}(",
        " vk::CommandBuffer cmd,",
        " vk::PipelineLayout layout,",
        f" const {struct_name}& data",
        ") {",
        " cmd.pushConstants(",
        " layout,",
        f" {stage_flags},",
        " PUSH_CONSTANT_CUSTOM_OFFSET,",
        f" sizeof({struct_name}),",
        " &data",
        " );",
        "}",
    ]


def _generate_dual_stage_fill_function(
    combined_info: CombinedPushConstantInfo,
    shader_name: str,
    type_map: Dict[int, TypeInfo]
) -> str:
    """Render the C++ helper(s) that push the Custom push-constant data.

    Shared layout:
        A single ``fill_push_constants`` that pushes
        ``{PascalName}PushConstantCustom`` with vertex|fragment stage flags.

    Separate layouts:
        ``fill_vertex_push_constants`` and/or ``fill_fragment_push_constants``,
        one per stage that has Custom data, each pushing its own stage struct.

    Args:
        combined_info: Merged vertex/fragment push-constant analysis.
        shader_name: Shader name used to derive the struct names.
        type_map: Unused; kept so the signature stays unchanged.

    Returns:
        The generated C++ code; an empty string when neither branch emits a function.
    """
    # NOTE(review): every generated helper pushes at PUSH_CONSTANT_CUSTOM_OFFSET,
    # even though per-stage structs are rendered with stage_info.relative_offset
    # as their base. If separate stage ranges do not both start at the Custom
    # offset, the push offset may need to follow the stage — confirm.
    pascal_name = _to_pascal_case(shader_name)
    if combined_info.is_shared_layout:
        return "\n".join(_fill_function_lines(
            "fill_push_constants",
            "填充 Push Constants Custom 区域(共享布局)",
            [" * 顶点和片元着色器共享相同的 Custom 区域布局。", " * "],
            "Custom 区域数据",
            f"{pascal_name}PushConstantCustom",
            "vk::ShaderStageFlagBits::eVertex | vk::ShaderStageFlagBits::eFragment",
        ))

    output: List[str] = []
    if combined_info.has_vertex:
        output += _fill_function_lines(
            "fill_vertex_push_constants",
            "填充顶点着色器 Push Constants Custom 区域",
            [],
            "顶点着色器 Custom 区域数据",
            f"{pascal_name}VertexPushConstant",
            "vk::ShaderStageFlagBits::eVertex",
        )
        output.append("")
    if combined_info.has_fragment:
        output += _fill_function_lines(
            "fill_fragment_push_constants",
            "填充片元着色器 Push Constants Custom 区域",
            [],
            "片元着色器 Custom 区域数据",
            f"{pascal_name}FragmentPushConstant",
            "vk::ShaderStageFlagBits::eFragment",
        )
    return "\n".join(output)
def _generate_dual_stage_layout_constants(
    combined_info: CombinedPushConstantInfo,
    shader_name: str
) -> str:
    """Render the ``{PascalName}PushConstantLayout`` constants struct.

    The struct always exposes ``is_shared_layout``. It conditionally adds:

    - ``vertex_custom_size``, from ``vertex_info.total_size``, when the vertex
      stage has Custom data;
    - ``fragment_custom_size``, from ``fragment_info.total_size``, when the
      fragment stage has Custom data;
    - ``shared_member_count``, when there are shared members.

    Args:
        combined_info: Merged vertex/fragment push-constant analysis.
        shader_name: Shader name used to derive the struct name.

    Returns:
        The generated C++ code.
    """
    pascal_name = _to_pascal_case(shader_name)
    shared_flag = "true" if combined_info.is_shared_layout else "false"
    body = [
        "/**",
        f" * @brief {pascal_name} Push Constant 布局信息",
        " */",
        f"struct {pascal_name}PushConstantLayout {{",
        " /// 是否使用共享布局模式",
        f" static constexpr bool is_shared_layout = {shared_flag};",
        " ",
    ]
    if combined_info.has_vertex:
        body += [
            " /// 顶点着色器 Custom 区域大小",
            f" static constexpr std::uint32_t vertex_custom_size = {combined_info.vertex_info.total_size};",
        ]
    if combined_info.has_fragment:
        body += [
            " /// 片元着色器 Custom 区域大小",
            f" static constexpr std::uint32_t fragment_custom_size = {combined_info.fragment_info.total_size};",
        ]
    if combined_info.shared_members:
        body += [
            " ",
            " /// 共享的成员数量",
            f" static constexpr std::size_t shared_member_count = {len(combined_info.shared_members)};",
        ]
    body.append("};")
    return "\n".join(body)
def generate_dual_stage_push_constant_structs(
    vert_reflection: SPIRVReflection,
    frag_reflection: SPIRVReflection,
    shader_name: str
) -> str:
    """Render dual-stage push-constant structs, fill helpers and layout constants.

    Analyses the vertex and fragment push-constant layouts and emits:

    - ``{ShaderName}PushConstantCustom`` (shared layout), or
      ``{ShaderName}VertexPushConstant`` / ``{ShaderName}FragmentPushConstant``
      (separate layouts, one per stage that has Custom data);
    - ``fill_push_constants()`` or ``fill_vertex/fragment_push_constants()``;
    - the ``{ShaderName}PushConstantLayout`` info struct.

    Either reflection may be None. The other one then supplies the type table
    and, in shared mode, the Custom struct.

    Args:
        vert_reflection: Vertex-shader SPIR-V reflection data (may be None).
        frag_reflection: Fragment-shader SPIR-V reflection data (may be None).
        shader_name: Shader name used to derive the generated names.

    Returns:
        The generated C++ code, or an empty string when neither stage has
        Custom data.

    Example::

        vert_ref = parse_spirv_type_system(vert_spirv, "vertex")
        frag_ref = parse_spirv_type_system(frag_spirv, "fragment")
        code = generate_dual_stage_push_constant_structs(vert_ref, frag_ref, "my_shader")
    """
    combined_info = analyze_dual_stage_push_constants(vert_reflection, frag_reflection)
    if not combined_info.has_vertex and not combined_info.has_fragment:
        return ""

    pascal_name = _to_pascal_case(shader_name)
    # Vertex type table first, falling back to the fragment one; computed once
    # (previously recomputed for the fill function).
    type_map = vert_reflection.types if vert_reflection else frag_reflection.types
    sections = [
        f"// ============ {pascal_name} Push Constants (Dual Stage) ============",
        "",
    ]

    def append_section(code: str) -> None:
        # Non-empty sections are followed by one blank separator line.
        if code:
            sections.append(code)
            sections.append("")

    if combined_info.is_shared_layout:
        # Prefer the fragment reflection (usually the more complete one). Guard
        # against None, matching the None handling used for the type tables.
        if frag_reflection and frag_reflection.push_constant:
            source_reflection = frag_reflection
        else:
            source_reflection = vert_reflection
        append_section(generate_push_constant_custom_struct(source_reflection, shader_name))
    else:
        if combined_info.has_vertex:
            append_section(_generate_stage_push_constant_struct(
                combined_info.vertex_info,
                shader_name,
                type_map,
            ))
        if combined_info.has_fragment:
            append_section(_generate_stage_push_constant_struct(
                combined_info.fragment_info,
                shader_name,
                frag_reflection.types if frag_reflection else type_map,
            ))

    append_section(_generate_dual_stage_fill_function(combined_info, shader_name, type_map))
    append_section(_generate_dual_stage_layout_constants(combined_info, shader_name))
    return "\n".join(sections)
# ============ Instance Data Structure Generation ============
def generate_instance_data_struct(instance_buffer: InstanceBufferInfo, type_map: Dict[int, TypeInfo]) -> str:
    """Render the C++ per-instance data struct.

    The struct is built from the struct type of the SSBO marked
    ``@instance_buffer``. The template receives each member's offset, std430
    size/alignment and C++ type, the buffer's total size, and a fixed struct
    alignment of 16.

    Args:
        instance_buffer: Reflected instance-buffer info.
        type_map: Reflection type table used to resolve member types.

    Returns:
        The generated C++ struct code. Illustrative output::

            struct alignas(16) gradient_instance_data {
                Eigen::Vector4f rect;          // required: x, y, width, height
                Eigen::Vector4f start_color;   // start colour
                Eigen::Vector4f transform;     // rotation, anchor_x, anchor_y, scale
                Eigen::Vector4f custom_params;
            };
            static_assert(sizeof(gradient_instance_data) == 64);
            static_assert(offsetof(gradient_instance_data, rect) == 0);
    """
    def describe(member: MemberInfo) -> dict:
        # Per-member template data: std430 alignment/size plus the C++ type.
        resolved = member.resolved_type
        alignment = calculate_std430_alignment(resolved, type_map)
        size = calculate_std430_size(resolved, type_map)
        return {
            'name': member.name,
            'cpp_type': spirv_type_to_cpp(resolved, type_map),
            'offset': member.offset,
            'size': size,
            'alignment': alignment,
        }

    member_entries = [describe(member) for member in instance_buffer.members]
    template_renderer = get_renderer()
    return template_renderer.render('instance/instance_data_struct.jinja2', {
        'struct_name': instance_buffer.struct_type_name,
        'members': member_entries,
        'total_size': instance_buffer.total_size,
        'alignment': 16,  # instance data structs must be 16-byte aligned
    })
def generate_instance_data_info(reflection: SPIRVReflection, shader_name: str) -> dict:
    """Build the instancing-related template variables for the header template.

    Args:
        reflection: SPIR-V reflection data (may be None).
        shader_name: Shader name (currently unused here).

    Returns:
        A dict with ``has_instance_data``, ``instance_data_struct``,
        ``instance_type_name``, ``instance_buffer_binding``,
        ``instance_buffer_set`` and ``instance_data_size``. When there is no
        instance buffer, the defaults are False / '' / 'void' / 0 / 0 / 0.
    """
    instance_buffer = reflection.instance_buffer if reflection else None
    if not instance_buffer:
        return dict(
            has_instance_data=False,
            instance_data_struct='',
            instance_type_name='void',
            instance_buffer_binding=0,
            instance_buffer_set=0,
            instance_data_size=0,
        )
    struct_code = generate_instance_data_struct(instance_buffer, reflection.types)
    return dict(
        has_instance_data=True,
        instance_data_struct=struct_code,
        instance_type_name=instance_buffer.struct_type_name,
        instance_buffer_binding=instance_buffer.binding,
        instance_buffer_set=instance_buffer.set_number,
        instance_data_size=instance_buffer.total_size,
    )
# ============ Vertex Structure Generation ============
def generate_vertex_struct(vertex_layout: VertexLayout, shader_name: str) -> str:
    """Render the C++ vertex-input struct for a reflected vertex layout.

    Side effect: stores the chosen struct name (``{PascalName}Vertex``) in
    ``vertex_layout.struct_name``.

    Args:
        vertex_layout: Reflected vertex-input layout (may be None).
        shader_name: Shader name used to derive the struct name.

    Returns:
        The generated C++ struct, or an empty string when there is no layout
        or it has no attributes. Illustrative output::

            struct ImageShaderVertex {
                Eigen::Vector2f position;  // location = 0
                Eigen::Vector4f color;     // location = 1
                Eigen::Vector2f uv;        // location = 2
                static VkVertexInputBindingDescription get_binding_description();
                static std::array<VkVertexInputAttributeDescription, 3> get_attribute_descriptions();
            };
    """
    if not (vertex_layout and vertex_layout.attributes):
        return ""
    struct_name = _to_pascal_case(shader_name) + "Vertex"
    # Record the name on the layout so later consumers can refer to the type.
    vertex_layout.struct_name = struct_name
    attribute_entries = [
        dict(
            name=attribute.name,
            cpp_type=attribute.cpp_type,
            location=attribute.location,
            vk_format=attribute.vk_format,
            offset=attribute.offset,
            size=attribute.size,
        )
        for attribute in vertex_layout.attributes
    ]
    template_renderer = get_renderer()
    return template_renderer.render('vertex/vertex_struct.jinja2', {
        'struct_name': struct_name,
        'attributes': attribute_entries,
        'stride': vertex_layout.stride,
        'attribute_count': len(attribute_entries),
    })
def generate_vertex_layout_info(reflection: SPIRVReflection, shader_name: str) -> dict:
    """Build the vertex-input template variables for the header template.

    A layout without attributes is treated the same as a missing one.
    ``generate_vertex_struct`` emits no struct for it and never assigns
    ``struct_name``. Reporting ``has_vertex_layout=True`` would therefore point
    the template at a vertex type that was never generated.

    Args:
        reflection: SPIR-V reflection data (may be None).
        shader_name: Shader name used to derive the vertex struct name.

    Returns:
        A dict with ``has_vertex_layout``, ``vertex_struct``,
        ``vertex_type_name``, ``vertex_stride`` and ``vertex_attribute_count``.
        The defaults are False / '' / 'void' / 0 / 0.
    """
    vertex_layout = reflection.vertex_layout if reflection else None
    if not vertex_layout or not vertex_layout.attributes:
        return {
            'has_vertex_layout': False,
            'vertex_struct': '',
            'vertex_type_name': 'void',
            'vertex_stride': 0,
            'vertex_attribute_count': 0,
        }
    # Must run before reading struct_name: generate_vertex_struct assigns it.
    vertex_struct = generate_vertex_struct(vertex_layout, shader_name)
    return {
        'has_vertex_layout': True,
        'vertex_struct': vertex_struct,
        'vertex_type_name': vertex_layout.struct_name,
        'vertex_stride': vertex_layout.stride,
        'vertex_attribute_count': len(vertex_layout.attributes),
    }
# ============ SPIR-V Array Generation ============
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str) -> str:
    """Render SPIR-V bytecode as a C++ ``uint32_t`` array definition.

    The bytes are decoded as little-endian 32-bit words and formatted as
    ``0x%08x`` literals, twelve per row, for the ``base/spirv_array.jinja2``
    template.

    Args:
        spirv_data: Raw SPIR-V binary.
        symbol_name: Name of the generated C++ array.

    Returns:
        The rendered array definition.

    Raises:
        ToolError: If the byte length is not a multiple of 4.
    """
    if len(spirv_data) % 4:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")
    words_per_row = 12
    hex_words = [f"0x{word:08x}" for (word,) in struct.iter_unpack("<I", spirv_data)]
    rows = [
        hex_words[start:start + words_per_row]
        for start in range(0, len(hex_words), words_per_row)
    ]
    return get_renderer().render('base/spirv_array.jinja2', {
        'symbol_name': symbol_name,
        'chunks': rows,
    })
# ============ 缓冲区辅助函数生成 ============
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the create/upload/download helper functions for every buffer.

    For each buffer, three pieces are rendered in this order: create, upload,
    download. Each piece is followed by one blank separator line.

    Args:
        reflection: SPIR-V reflection data providing ``buffers``.
        shader_name: Shader name forwarded to the create-function template.

    Returns:
        The concatenated helper code, or an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""
    output_lines: List[str] = []
    for buffer in reflection.buffers:
        rendered_pieces = (
            _generate_create_buffer_function(buffer, shader_name),
            _generate_upload_function(buffer),
            _generate_download_function(buffer),
        )
        for piece in rendered_pieces:
            output_lines.extend(piece)
            output_lines.append("")
    return "\n".join(output_lines)
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the convenience function that creates ``buffer``.

    Returns:
        The rendered ``buffer_helpers/create_buffer.jinja2`` output, split into lines.
    """
    template_vars = dict(buffer=buffer, shader_name=shader_name)
    return get_renderer().render('buffer_helpers/create_buffer.jinja2', template_vars).splitlines()
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
    """Render the function that uploads data into ``buffer``.

    Returns:
        The rendered ``buffer_helpers/upload.jinja2`` output, split into lines.
    """
    template_vars = dict(buffer=buffer)
    return get_renderer().render('buffer_helpers/upload.jinja2', template_vars).splitlines()
def _generate_download_function(buffer: BufferInfo) -> List[str]:
    """Render the function that downloads data from ``buffer``.

    Returns:
        The rendered ``buffer_helpers/download.jinja2`` output, split into lines.
    """
    template_vars = dict(buffer=buffer)
    return get_renderer().render('buffer_helpers/download.jinja2', template_vars).splitlines()
# ============ Typed Buffer Aliases Generation ============
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the type-safe buffer aliases and factory functions.

    Each buffer's alias/factory block is followed by one blank separator line.

    Args:
        reflection: SPIR-V reflection data providing ``buffers``.
        shader_name: Shader name forwarded to the template.

    Returns:
        The concatenated code, or an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""
    alias_lines: List[str] = []
    for buffer in reflection.buffers:
        alias_lines += [*_generate_buffer_type_alias_and_factory(buffer, shader_name), ""]
    return "\n".join(alias_lines)
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the type alias and factory function for one buffer.

    Returns:
        The rendered ``typed_buffer/alias_and_factory.jinja2`` output, split into lines.
    """
    template_vars = dict(buffer=buffer, shader_name=shader_name)
    return get_renderer().render('typed_buffer/alias_and_factory.jinja2', template_vars).splitlines()
# ============ Buffer Manager Class Generation ============
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the ``{shader_name}_buffer_manager`` class.

    The template receives the class name, all reflected buffers, and whether
    the shader stage is ``"compute"``.

    Args:
        reflection: SPIR-V reflection data providing ``buffers`` and ``shader_stage``.
        shader_name: Shader name used to build the class name.

    Returns:
        The rendered class, or an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""
    template_vars = {
        'class_name': f"{shader_name}_buffer_manager",
        'buffers': reflection.buffers,
        'is_compute': reflection.shader_stage == "compute",
    }
    return get_renderer().render('buffer_manager/manager_class.jinja2', template_vars)
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
    """Build the legacy ``struct buffers`` container lines.

    The manager template now covers this. The function is kept for
    compatibility and declares one ``<name>_buffer <name>;`` field per buffer.

    Args:
        buffers: Buffers whose ``variable_name`` values become the fields.

    Returns:
        The C++ lines of the container struct.
    """
    field_lines = [
        f" {buffer.variable_name}_buffer {buffer.variable_name};"
        for buffer in buffers
    ]
    return [" // Container for all buffers", " struct buffers {", *field_lines, " };"]
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
    """Legacy hook for the manager factory method, now produced by the template.

    The signature is kept in case external code still calls it; the arguments
    are ignored and a fresh empty list is always returned.
    """
    return list()
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
    """Legacy hook for the manager's initialize method, now produced by the template.

    The signature is kept in case external code still calls it; the argument
    is ignored and a fresh empty list is always returned.
    """
    return list()
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
    """Legacy hook for the descriptor-set bind method, now produced by the template.

    The signature is kept in case external code still calls it; the argument
    is ignored and a fresh empty list is always returned.
    """
    return list()
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
    """Legacy hook for the manager's private constructor, now produced by the template.

    The signature is kept in case external code still calls it; the arguments
    are ignored and a fresh empty list is always returned.
    """
    return list()
# ============ Vulkan Type Conversion ============
# Reflected descriptor kind -> Vulkan-Hpp ``vk::DescriptorType`` enumerator.
_VK_DESCRIPTOR_TYPE_MAP: Dict[str, str] = {
    "storage_buffer": "vk::DescriptorType::eStorageBuffer",
    "uniform_buffer": "vk::DescriptorType::eUniformBuffer",
    "sampled_image": "vk::DescriptorType::eSampledImage",
    "combined_image_sampler": "vk::DescriptorType::eCombinedImageSampler",
    "storage_image": "vk::DescriptorType::eStorageImage",
    "sampler": "vk::DescriptorType::eSampler",
    # Previously missing: these silently fell back to eStorageBuffer.
    "uniform_texel_buffer": "vk::DescriptorType::eUniformTexelBuffer",
    "storage_texel_buffer": "vk::DescriptorType::eStorageTexelBuffer",
    "input_attachment": "vk::DescriptorType::eInputAttachment",
}


def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a reflected descriptor kind to its Vulkan-Hpp enumerator spelling.

    Args:
        descriptor_type: Snake-case descriptor kind, e.g. ``"uniform_buffer"``.

    Returns:
        The ``vk::DescriptorType::e...`` C++ expression. Unknown kinds keep the
        historical fallback ``vk::DescriptorType::eStorageBuffer``.
    """
    return _VK_DESCRIPTOR_TYPE_MAP.get(descriptor_type, "vk::DescriptorType::eStorageBuffer")
def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert a list of stage names into a Vulkan-Hpp stage-flag expression.

    Known names are mapped in input order and joined with ``" | "``. Unknown
    names are skipped. An empty/None list, or one with no known names, yields
    ``vk::ShaderStageFlagBits::eCompute``.

    Args:
        stages: Stage names such as ``"vertex"`` or ``"fragment"``.

    Returns:
        The C++ flag expression.
    """
    compute_flag = "vk::ShaderStageFlagBits::eCompute"
    known_stages = {
        "vertex": "vk::ShaderStageFlagBits::eVertex",
        "fragment": "vk::ShaderStageFlagBits::eFragment",
        "compute": compute_flag,
        "geometry": "vk::ShaderStageFlagBits::eGeometry",
        "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
        "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
    }
    selected = [known_stages[name] for name in (stages or ["compute"]) if name in known_stages]
    return " | ".join(selected) if selected else compute_flag
# ============ Header File Generation ============
def _to_pascal_case(name: str) -> str:
    """Convert ``snake_case`` to ``PascalCase``.

    Each underscore-separated part goes through ``str.capitalize``: the first
    character is upper-cased and the rest lower-cased. Empty parts vanish.
    """
    return "".join(map(str.capitalize, name.split("_")))
def _build_stage_arrays_context(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> dict:
    """Render one SPIR-V array per compiled stage and record the vert/frag symbols.

    Returns:
        A dict with these keys:

        - ``compilation_results``: per stage, its ``shader_type``,
          ``source_file`` and ``spirv_array``;
        - ``has_vert_shader`` and ``has_frag_shader``;
        - ``vert_spirv_symbol`` and ``frag_spirv_symbol``.

        If several results share a stage type, the last one's symbol wins.
    """
    stage_context = {
        'compilation_results': [],
        'has_vert_shader': False,
        'has_frag_shader': False,
        'vert_spirv_symbol': "",
        'frag_spirv_symbol': "",
    }
    for result in compilation_results:
        symbol_name = f"{metadata.name}_{result.shader_type}_spirv"
        stage_context['compilation_results'].append({
            'shader_type': result.shader_type,
            'source_file': result.source_file,
            'spirv_array': spirv_to_cpp_array(result.spirv_data, symbol_name),
        })
        if result.shader_type == "vert":
            stage_context['has_vert_shader'] = True
            stage_context['vert_spirv_symbol'] = symbol_name
        elif result.shader_type == "frag":
            stage_context['has_frag_shader'] = True
            stage_context['frag_spirv_symbol'] = symbol_name
    return stage_context


def _build_bindings_context(metadata: ShaderMetadata) -> List[dict]:
    """Describe each descriptor binding for the header template.

    The emitted ``binding`` value folds the set in as
    ``descriptor_set * 1000 + binding``.
    """
    # NOTE(review): the set*1000 encoding only stays unambiguous while every
    # binding number is < 1000; larger values would alias the next set — confirm.
    return [
        {
            'binding': binding.descriptor_set * 1000 + binding.binding,
            'descriptor_type_cpp': vk_descriptor_type_to_cpp(binding.descriptor_type),
            'count': binding.count,
            'stage_flags': vk_shader_stage_to_cpp(binding.stages),
        }
        for binding in (metadata.bindings or [])
    ]


def _resolve_params_type_name(reflection: SPIRVReflection, shader_name: str) -> str:
    """Pick the parameter type, in priority order.

    1. The first buffer's struct name.
    2. Otherwise, the push-constant Custom struct.
    3. Otherwise, ``void``.
    """
    if reflection and reflection.buffers:
        return reflection.buffers[0].struct_type.name or "void"
    if reflection and reflection.push_constant and reflection.push_constant.effect_members:
        return f"{_to_pascal_case(shader_name)}PushConstantCustom"
    return "void"


def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Render the complete C++ header for one shader.

    Collects the following, then renders ``base/header.jinja2``:

    - buffer structs, helpers, typed aliases and the manager (as enabled by
      ``metadata``);
    - the push-constant Custom struct, Header layout and static checks;
    - the embedded SPIR-V arrays and descriptor bindings;
    - instancing and vertex-input info;
    - procedural-vertex-shader support.

    Args:
        metadata: Shader metadata including reflection data and generator options.
        compilation_results: Compiled stages. The first one supplies the entry
            point and ``spirv_symbol_name``.

    Returns:
        The rendered header text.
    """
    reflection = metadata.reflection
    shader_name = metadata.name
    # A real bool. The old `reflection and reflection.buffers` leaked None or
    # the buffer list into the template context, unlike the other has_* flags.
    has_buffers = bool(reflection and reflection.buffers)
    push_constant = reflection.push_constant if reflection else None

    # ---- Buffers ----
    buffer_structures = generate_buffer_structures(reflection) if has_buffers else ""
    buffer_helpers = (
        generate_buffer_helper_functions(reflection, shader_name)
        if metadata.generate_buffer_helpers and has_buffers else ""
    )
    typed_buffers = (
        generate_typed_buffer_aliases(reflection, shader_name)
        if metadata.generate_typed_buffers and has_buffers else ""
    )
    buffer_manager = (
        generate_buffer_manager_class(reflection, shader_name)
        if metadata.generate_buffer_manager and has_buffers else ""
    )

    # ---- Push constants: Custom area (exported under the legacy "effect" names) ----
    push_constant_custom_struct = ""
    push_constant_custom_type_name = "void"
    if push_constant and push_constant.effect_members:
        push_constant_custom_struct = generate_push_constant_custom_struct(reflection, shader_name)
        push_constant_custom_type_name = f"{_to_pascal_case(shader_name)}PushConstantCustom"

    # ---- Push constants: Header area (exported under the legacy "base" names) ----
    push_constant_header_layout = ""
    push_constant_static_check = ""
    header_member_names: List[str] = []
    header_member_offsets: List[int] = []
    if push_constant and push_constant.base_info:
        push_constant_header_layout = generate_push_constant_header_layout(reflection, shader_name)
        header_members = push_constant.base_info.members
        # Header member names/offsets are consumed by the shader_spec template.
        header_member_names = [member.name for member in header_members]
        header_member_offsets = [member.offset for member in header_members]
        push_constant_static_check = generate_push_constant_static_check(reflection, shader_name)

    stage_context = _build_stage_arrays_context(metadata, compilation_results)
    bindings_data = _build_bindings_context(metadata)
    params_type_name = _resolve_params_type_name(reflection, shader_name)
    instance_info = generate_instance_data_info(reflection, shader_name)
    vertex_info = generate_vertex_layout_info(reflection, shader_name)
    # Procedural vertex shader support: the first Custom push-constant member
    # of the (fragment) reflection must sit at offset 16.
    supports_procedural_vertex_shader = (
        check_supports_procedural_vertex_shader(reflection) if reflection else False
    )

    context = {
        'namespace': metadata.namespace,
        'shader_name': shader_name,
        'metadata_struct_name': f"{shader_name}_metadata",
        'bindings_name': f"{shader_name}_bindings",
        'shader_spec_name': f"{shader_name}_shader_spec",
        'params_type_name': params_type_name,
        'entry_point': compilation_results[0].entry_point if compilation_results else "main",
        'is_compute': metadata.shader_type == "compute",
        'spirv_symbol_name': (
            f"{shader_name}_{compilation_results[0].shader_type}_spirv"
            if compilation_results else ""
        ),
        'has_buffers': has_buffers,
        'generate_typed_buffers': metadata.generate_typed_buffers,
        'buffer_structures': buffer_structures,
        'buffer_helpers': buffer_helpers,
        'typed_buffers': typed_buffers,
        'buffer_manager': buffer_manager,
        'compilation_results': stage_context['compilation_results'],
        'bindings': bindings_data,
        # shader_spec variables
        'has_vert_shader': stage_context['has_vert_shader'],
        'has_frag_shader': stage_context['has_frag_shader'],
        'vert_spirv_symbol': stage_context['vert_spirv_symbol'],
        'frag_spirv_symbol': stage_context['frag_spirv_symbol'],
        # Push constants: Custom ("effect") struct
        'push_constant_effect_struct': push_constant_custom_struct,
        'push_constant_effect_type_name': push_constant_custom_type_name,
        'has_push_constant_effect': bool(push_constant_custom_struct),
        # Push constants: Header ("base") layout
        'push_constant_base_layout': push_constant_header_layout,
        'has_push_constant_base_layout': bool(push_constant_header_layout),
        'push_constant_base_member_names': header_member_names,
        'push_constant_base_member_offsets': header_member_offsets,
        # Push constants: Header static checks
        'push_constant_static_check': push_constant_static_check,
        'supports_procedural_vertex_shader': supports_procedural_vertex_shader,
        # Instanced rendering variables
        **instance_info,
        # Vertex layout variables
        **vertex_info,
    }
    return get_renderer().render('base/header.jinja2', context)