- Removed legacy push constant structures and functions for better clarity and maintainability. - Introduced new `text_push_constants_t` structure for text rendering with optimized layout. - Implemented dual-stage push constant analysis to support separate layouts for vertex and fragment shaders. - Added functions to generate push constant structures and fill functions based on shader reflection. - Enhanced static checks for push constant layouts to ensure compatibility and correctness. - Updated templates to accommodate the new dual-stage push constant generation. - Added detection of procedural vertex shader support based on the push constant layout.
1149 lines
43 KiB
Python
1149 lines
43 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
代码生成模块
|
||
|
||
负责生成C++代码,包括结构体定义、SPIR-V数组、头文件和绑定信息。
|
||
使用Jinja2模板引擎代替字符串拼接。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import struct
|
||
from typing import Dict, List
|
||
|
||
from .template_loader import get_renderer
|
||
from .type_mapping import (
|
||
calculate_std430_alignment,
|
||
calculate_std430_size,
|
||
spirv_type_to_cpp,
|
||
)
|
||
from .types import (
|
||
ArrayTypeInfo,
|
||
BufferInfo,
|
||
CombinedPushConstantInfo,
|
||
CompilationResult,
|
||
InstanceBufferInfo,
|
||
MemberInfo,
|
||
PushConstantInfo,
|
||
ShaderMetadata,
|
||
SPIRVReflection,
|
||
StagePushConstantInfo,
|
||
ToolError,
|
||
TypeInfo,
|
||
VertexAttribute,
|
||
VertexLayout,
|
||
)
|
||
from .spirv_parser import analyze_dual_stage_push_constants
|
||
from .constants import (
|
||
PUSH_CONSTANT_HEADER_SIZE,
|
||
PUSH_CONSTANT_CUSTOM_OFFSET,
|
||
PUSH_CONSTANT_CUSTOM_MAX_SIZE,
|
||
)
|
||
|
||
|
||
def check_supports_procedural_vertex_shader(reflection: SPIRVReflection) -> bool:
    """Report whether a shader can be paired with ``custom_shader_quad_procedural``.

    The procedural vertex shader owns the push-constant Header (bytes
    ``[0, 16)``: scale + translate). A fragment shader is compatible when its
    first user-defined (Custom) push-constant member starts exactly at
    ``PUSH_CONSTANT_CUSTOM_OFFSET``, i.e. immediately after the Header.

    Args:
        reflection: SPIR-V reflection data, normally of the fragment shader.

    Returns:
        True if the smallest member offset that is ``>= PUSH_CONSTANT_CUSTOM_OFFSET``
        equals ``PUSH_CONSTANT_CUSTOM_OFFSET``. False when there is no push
        constant block, no struct type, no members, or no member in the
        Custom region.
    """
    pc = reflection.push_constant
    members = pc.struct_type.members if pc and pc.struct_type else None
    if not members:
        return False

    # Offsets of the members that live in the Custom region (Header skipped).
    custom_offsets = [
        member.offset
        for member in members
        if member.offset >= PUSH_CONSTANT_CUSTOM_OFFSET
    ]
    # The Custom region must begin right where the Header ends.
    return bool(custom_offsets) and min(custom_offsets) == PUSH_CONSTANT_CUSTOM_OFFSET
|
||
|
||
|
||
# ============ Structure Generation ============
|
||
|
||
def generate_buffer_structures(reflection: SPIRVReflection) -> str:
    """Render C++ definitions for every buffer/uniform struct in ``reflection``.

    Each struct's rendered lines are followed by one empty line, and the whole
    sequence is joined with ``"\\n"``.

    Returns:
        The concatenated C++ code, or an empty string when there are no buffers.
    """
    buffers = reflection.buffers
    if not buffers:
        return ""

    return "\n".join(
        line
        for buffer in buffers
        for line in [*generate_single_struct(buffer, reflection.types), ""]
    )
|
||
|
||
|
||
def generate_single_struct(buffer: BufferInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render one buffer struct definition and return it as a list of lines.

    The ``base/struct.jinja2`` template receives the buffer, the type map and
    the std430 alignment of ``buffer.struct_type``.
    """
    return get_renderer().render(
        'base/struct.jinja2',
        {
            'buffer': buffer,
            'type_map': type_map,
            'alignment': calculate_std430_alignment(buffer.struct_type, type_map),
        },
    ).splitlines()
|
||
|
||
|
||
def generate_member_definition(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> List[str]:
    """Render a single struct member declaration and return it as a list of lines.

    Uses the ``base/struct_member.jinja2`` template with the member and type map.
    """
    return get_renderer().render(
        'base/struct_member.jinja2',
        dict(member=member, type_map=type_map),
    ).splitlines()
|
||
|
||
|
||
# ============ Push Constants Custom Structure Generation ============
|
||
|
||
|
||
def generate_push_constant_custom_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate the C++ struct for the user-defined ("Custom") push-constant region.

    Push-constant memory layout (128 bytes):

        Header [0-15]   : scale + translate, 16 bytes   (filled automatically)
        Custom [16-127] : user-defined data, 112 bytes  (generated here)

    Only the members in ``push_constant.effect_members`` are emitted; that list
    now holds the Custom members (scale/translate excluded). Members are passed
    to the template with their absolute offsets, together with
    ``PUSH_CONSTANT_CUSTOM_OFFSET`` as ``base_offset``.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name, used to build ``<PascalName>PushConstantCustom``.

    Returns:
        The rendered C++ struct, or an empty string if there is no Custom part.

    Note:
        Must stay consistent with ``PUSH_CONSTANT_CUSTOM_OFFSET`` (16) and
        ``PUSH_CONSTANT_CUSTOM_MAX_SIZE`` (112) in ``push_constant_traits.h``.
    """
    push_constant = reflection.push_constant
    if not push_constant or not push_constant.effect_members:
        return ""

    # Struct name kept for compatibility; the contents are the Custom members.
    struct_name = f"{_to_pascal_case(shader_name)}PushConstantCustom"
    types_by_id = reflection.types

    def describe(member):
        # std430 alignment and size, then the C++ spelling of the member type.
        alignment = calculate_std430_alignment(member.resolved_type, types_by_id)
        size = calculate_std430_size(member.resolved_type, types_by_id)
        cpp_type = spirv_type_to_cpp(member.resolved_type, types_by_id)
        return {
            'name': member.name,
            'cpp_type': cpp_type,
            'offset': member.offset,
            'size': size,
            'alignment': alignment,
        }

    member_rows = [describe(member) for member in push_constant.effect_members]

    renderer = get_renderer()
    return renderer.render(
        'push_constant/custom_struct.jinja2',
        {
            'struct_name': struct_name,
            'base_offset': PUSH_CONSTANT_CUSTOM_OFFSET,  # Custom region start
            'members': member_rows,
        },
    )
|
||
|
||
|
||
# 兼容性别名
|
||
def generate_push_constant_effect_struct(reflection: SPIRVReflection, shader_name: str) -> str:
    """Deprecated alias of ``generate_push_constant_custom_struct``.

    Kept only for backward compatibility; forwards its arguments unchanged.
    """
    return generate_push_constant_custom_struct(reflection=reflection, shader_name=shader_name)
|
||
|
||
|
||
def generate_push_constant_header_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate the layout-info struct for the push-constant Header.

    The Header occupies offsets 0-15 and holds the NDC transform:

        scale     (vec2): offset 0, size 8
        translate (vec2): offset 8, size 8

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name; passed to the template as ``struct_name`` as-is.

    Returns:
        The rendered C++ code, or an empty string when there is no Header info.
    """
    push_constant = reflection.push_constant
    if not push_constant or not push_constant.base_info:
        return ""

    base_info = push_constant.base_info
    exported_fields = ('name', 'offset', 'size', 'expected_offset', 'expected_size', 'is_valid')
    header_members = [
        {field: getattr(member, field) for field in exported_fields}
        for member in base_info.members
    ]
    present_names = [member.name for member in base_info.members]

    renderer = get_renderer()
    layout_context = {
        'struct_name': shader_name,
        'header_members': header_members,
        'expected_names': ['scale', 'translate'],
        'present_names': present_names,
        'header_size': base_info.total_size,
        'is_standard': base_info.is_standard_layout,
    }
    return renderer.render('push_constant/base_layout.jinja2', layout_context)
|
||
|
||
|
||
def generate_push_constant_static_check(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate the push-constant static-check code for the Header only.

    Checks on the Custom struct (such as its ``sizeof``) must appear after that
    struct is defined, so this function renders the Header checks only and
    passes ``has_custom_members=False`` to the template.

    Args:
        reflection: SPIR-V reflection data.
        shader_name: Shader name; passed to the template as ``struct_name``.

    Returns:
        The rendered static-check C++ code, or an empty string when there is no
        Header info.
    """
    push_constant = reflection.push_constant
    if not push_constant or not push_constant.base_info:
        return ""

    base_info = push_constant.base_info

    header_members = [
        {
            'name': m.name,
            'offset': m.offset,
            'size': m.size,
            'expected_offset': m.expected_offset,
            'expected_size': m.expected_size,
            'is_valid': m.is_valid,
        }
        for m in base_info.members
    ]

    renderer = get_renderer()

    check_context = {
        'struct_name': shader_name,
        'header_members': header_members,
        'is_standard': base_info.is_standard_layout,
        'has_custom_members': False,  # Custom checks are emitted elsewhere
    }

    return renderer.render('push_constant/static_check.jinja2', check_context)
|
||
|
||
|
||
# 兼容性别名
|
||
def generate_push_constant_base_layout(reflection: SPIRVReflection, shader_name: str) -> str:
    """Deprecated alias of ``generate_push_constant_header_layout``.

    Kept only for backward compatibility; forwards its arguments unchanged.
    """
    return generate_push_constant_header_layout(reflection=reflection, shader_name=shader_name)
|
||
|
||
|
||
# ============ Dual Stage Push Constants Structure Generation ============
|
||
|
||
def _generate_stage_push_constant_struct(
|
||
stage_info: StagePushConstantInfo,
|
||
shader_name: str,
|
||
type_map: Dict[int, TypeInfo]
|
||
) -> str:
|
||
"""生成单个阶段的 Push Constant 结构体
|
||
|
||
Args:
|
||
stage_info: 阶段 Push Constant 信息
|
||
shader_name: 着色器名称
|
||
type_map: 类型映射
|
||
|
||
Returns:
|
||
生成的 C++ 结构体代码
|
||
"""
|
||
if not stage_info or not stage_info.has_members:
|
||
return ""
|
||
|
||
# 生成结构体名称
|
||
stage_suffix = "Vertex" if stage_info.stage == "vertex" else "Fragment"
|
||
struct_name = f"{_to_pascal_case(shader_name)}{stage_suffix}PushConstant"
|
||
|
||
# 准备成员信息
|
||
members = []
|
||
for member in stage_info.members:
|
||
alignment = calculate_std430_alignment(member.resolved_type, type_map)
|
||
size = calculate_std430_size(member.resolved_type, type_map)
|
||
cpp_type = spirv_type_to_cpp(member.resolved_type, type_map)
|
||
|
||
members.append({
|
||
'name': member.name,
|
||
'cpp_type': cpp_type,
|
||
'offset': member.offset,
|
||
'size': size,
|
||
'alignment': alignment,
|
||
})
|
||
|
||
renderer = get_renderer()
|
||
context = {
|
||
'struct_name': struct_name,
|
||
'base_offset': stage_info.relative_offset,
|
||
'members': members,
|
||
}
|
||
|
||
return renderer.render('push_constant/custom_struct.jinja2', context)
|
||
|
||
|
||
def _generate_dual_stage_fill_function(
    combined_info: CombinedPushConstantInfo,
    shader_name: str,
    type_map: Dict[int, TypeInfo]
) -> str:
    """Generate the C++ helper(s) that upload the Custom push-constant region.

    - Shared layout: a single ``fill_push_constants`` that pushes
      ``<PascalName>PushConstantCustom`` to the vertex and fragment stages at once.
    - Separate layout: ``fill_vertex_push_constants`` and/or
      ``fill_fragment_push_constants`` for each stage that has data, pushing
      ``<PascalName>VertexPushConstant`` / ``<PascalName>FragmentPushConstant``.

    Every generated helper calls ``cmd.pushConstants`` at
    ``PUSH_CONSTANT_CUSTOM_OFFSET`` with ``sizeof(<struct>)`` bytes.
    NOTE(review): in separate mode each stage's struct is rendered elsewhere
    with ``base_offset=stage_info.relative_offset``, while the pushes here always
    start at PUSH_CONSTANT_CUSTOM_OFFSET. Confirm these agree, and that the
    per-stage ranges satisfy the vkCmdPushConstants stage-overlap rules.

    Args:
        combined_info: Merged dual-stage push-constant analysis.
        shader_name: Shader name (snake_case).
        type_map: Not used by this function; kept for signature compatibility.

    Returns:
        The generated C++ code. It is empty in separate mode when neither stage
        has data.
    """
    pascal_name = _to_pascal_case(shader_name)

    def fill_function(func_name, doc_lines, struct_name, stage_flags):
        # One Doxygen-commented inline function wrapping cmd.pushConstants(...).
        return [
            "/**",
            *doc_lines,
            " */",
            f"inline void {func_name}(",
            " vk::CommandBuffer cmd,",
            " vk::PipelineLayout layout,",
            f" const {struct_name}& data",
            ") {",
            " cmd.pushConstants(",
            " layout,",
            f" {stage_flags},",
            " PUSH_CONSTANT_CUSTOM_OFFSET,",
            f" sizeof({struct_name}),",
            " &data",
            " );",
            "}",
        ]

    if combined_info.is_shared_layout:
        # Shared mode: one function feeds both stages.
        shared = fill_function(
            "fill_push_constants",
            [
                " * @brief 填充 Push Constants Custom 区域(共享布局)",
                " * ",
                " * 顶点和片元着色器共享相同的 Custom 区域布局。",
                " * ",
                " * @param cmd 命令缓冲",
                " * @param layout 管线布局",
                " * @param data Custom 区域数据",
            ],
            f"{pascal_name}PushConstantCustom",
            "vk::ShaderStageFlagBits::eVertex | vk::ShaderStageFlagBits::eFragment",
        )
        return "\n".join(shared)

    # Separate mode: one function per stage that actually has data.
    output = []
    if combined_info.has_vertex:
        output += fill_function(
            "fill_vertex_push_constants",
            [
                " * @brief 填充顶点着色器 Push Constants Custom 区域",
                " * ",
                " * @param cmd 命令缓冲",
                " * @param layout 管线布局",
                " * @param data 顶点着色器 Custom 区域数据",
            ],
            f"{pascal_name}VertexPushConstant",
            "vk::ShaderStageFlagBits::eVertex",
        )
        output.append("")

    if combined_info.has_fragment:
        output += fill_function(
            "fill_fragment_push_constants",
            [
                " * @brief 填充片元着色器 Push Constants Custom 区域",
                " * ",
                " * @param cmd 命令缓冲",
                " * @param layout 管线布局",
                " * @param data 片元着色器 Custom 区域数据",
            ],
            f"{pascal_name}FragmentPushConstant",
            "vk::ShaderStageFlagBits::eFragment",
        )

    return "\n".join(output)
|
||
|
||
|
||
def _generate_dual_stage_layout_constants(
    combined_info: CombinedPushConstantInfo,
    shader_name: str
) -> str:
    """Generate the ``<PascalName>PushConstantLayout`` C++ struct of layout constants.

    The struct always contains ``is_shared_layout``. It adds
    ``vertex_custom_size`` / ``fragment_custom_size`` for the stages that have
    data, and ``shared_member_count`` when ``combined_info.shared_members`` is
    non-empty.

    Args:
        combined_info: Merged dual-stage push-constant analysis.
        shader_name: Shader name (snake_case).

    Returns:
        The generated C++ code.
    """
    pascal_name = _to_pascal_case(shader_name)
    shared_flag = 'true' if combined_info.is_shared_layout else 'false'

    body = [
        "/**",
        f" * @brief {pascal_name} Push Constant 布局信息",
        " */",
        f"struct {pascal_name}PushConstantLayout {{",
        " /// 是否使用共享布局模式",
        f" static constexpr bool is_shared_layout = {shared_flag};",
        " ",
    ]

    if combined_info.has_vertex:
        body += [
            " /// 顶点着色器 Custom 区域大小",
            f" static constexpr std::uint32_t vertex_custom_size = {combined_info.vertex_info.total_size};",
        ]

    if combined_info.has_fragment:
        body += [
            " /// 片元着色器 Custom 区域大小",
            f" static constexpr std::uint32_t fragment_custom_size = {combined_info.fragment_info.total_size};",
        ]

    if combined_info.shared_members:
        body += [
            " ",
            " /// 共享的成员数量",
            f" static constexpr std::size_t shared_member_count = {len(combined_info.shared_members)};",
        ]

    body.append("};")
    return "\n".join(body)
|
||
|
||
|
||
def generate_dual_stage_push_constant_structs(
    vert_reflection: SPIRVReflection,
    frag_reflection: SPIRVReflection,
    shader_name: str
) -> str:
    """Generate dual-stage push-constant structs, fill functions and layout constants.

    The vertex and fragment push-constant layouts are analysed together, and
    the output contains:

    - ``<ShaderName>VertexPushConstant`` (separate mode, vertex has data)
    - ``<ShaderName>FragmentPushConstant`` (separate mode, fragment has data)
    - ``<ShaderName>PushConstantCustom`` (shared mode)
    - ``fill_push_constants()`` or ``fill_vertex/fragment_push_constants()``
    - the ``<ShaderName>PushConstantLayout`` info struct

    Args:
        vert_reflection: SPIR-V reflection of the vertex shader (may be None).
        frag_reflection: SPIR-V reflection of the fragment shader (may be None).
        shader_name: Shader name (snake_case).

    Returns:
        The generated C++ code, or an empty string when neither stage has
        Custom data.

    Example:
        >>> vert_ref = parse_spirv_type_system(vert_spirv, "vertex")
        >>> frag_ref = parse_spirv_type_system(frag_spirv, "fragment")
        >>> code = generate_dual_stage_push_constant_structs(vert_ref, frag_ref, "my_shader")
    """
    combined_info = analyze_dual_stage_push_constants(vert_reflection, frag_reflection)

    # Nothing to emit when neither stage has Custom members.
    if not combined_info.has_vertex and not combined_info.has_fragment:
        return ""

    pascal_name = _to_pascal_case(shader_name)
    # Fallback type map: the vertex reflection's types if present, else the
    # fragment reflection's.
    type_map = vert_reflection.types if vert_reflection else frag_reflection.types

    sections = [
        f"// ============ {pascal_name} Push Constants (Dual Stage) ============",
        "",
    ]

    if combined_info.is_shared_layout:
        # Shared mode: one Custom struct serves both stages. Prefer the fragment
        # reflection (usually the more complete one). It is checked for None
        # like the other reflection accesses in this function, because
        # dereferencing a None fragment reflection would raise AttributeError.
        if frag_reflection and frag_reflection.push_constant:
            ref = frag_reflection
        else:
            ref = vert_reflection
        custom_struct = generate_push_constant_custom_struct(ref, shader_name)
        if custom_struct:
            sections.extend((custom_struct, ""))
    else:
        # Separate mode: an independent struct per stage.
        if combined_info.has_vertex:
            vert_struct = _generate_stage_push_constant_struct(
                combined_info.vertex_info,
                shader_name,
                type_map,
            )
            if vert_struct:
                sections.extend((vert_struct, ""))

        if combined_info.has_fragment:
            frag_types = frag_reflection.types if frag_reflection else type_map
            frag_struct = _generate_stage_push_constant_struct(
                combined_info.fragment_info,
                shader_name,
                frag_types,
            )
            if frag_struct:
                sections.extend((frag_struct, ""))

    fill_function = _generate_dual_stage_fill_function(combined_info, shader_name, type_map)
    if fill_function:
        sections.extend((fill_function, ""))

    layout_constants = _generate_dual_stage_layout_constants(combined_info, shader_name)
    if layout_constants:
        sections.extend((layout_constants, ""))

    return "\n".join(sections)
|
||
|
||
|
||
# ============ Instance Data Structure Generation ============
|
||
|
||
def generate_instance_data_struct(instance_buffer: InstanceBufferInfo, type_map: Dict[int, TypeInfo]) -> str:
    """Generate the C++ instance-data struct for an ``@instance_buffer`` SSBO.

    The struct is rendered from the SSBO's element struct type by the
    ``instance/instance_data_struct.jinja2`` template, with a fixed alignment
    of 16. The template also receives ``total_size``, presumably for its
    layout checks (TODO confirm against the template).

    Args:
        instance_buffer: Reflected instance-buffer info.
        type_map: SPIR-V type map used to resolve member types.

    Returns:
        The rendered C++ code. Illustrative shape::

            struct alignas(16) gradient_instance_data {
                Eigen::Vector4f rect;          // required: x, y, width, height
                Eigen::Vector4f start_color;
                Eigen::Vector4f transform;     // rotation, anchor_x, anchor_y, scale
                Eigen::Vector4f custom_params;
            };
            static_assert(sizeof(gradient_instance_data) == 64);
            static_assert(offsetof(gradient_instance_data, rect) == 0);
    """
    def describe(member):
        # std430 alignment and size, then the C++ spelling of the member type.
        resolved = member.resolved_type
        alignment = calculate_std430_alignment(resolved, type_map)
        size = calculate_std430_size(resolved, type_map)
        cpp_type = spirv_type_to_cpp(resolved, type_map)
        return {
            'name': member.name,
            'cpp_type': cpp_type,
            'offset': member.offset,
            'size': size,
            'alignment': alignment,
        }

    member_rows = [describe(member) for member in instance_buffer.members]

    renderer = get_renderer()
    return renderer.render(
        'instance/instance_data_struct.jinja2',
        {
            'struct_name': instance_buffer.struct_type_name,
            'members': member_rows,
            'total_size': instance_buffer.total_size,
            'alignment': 16,  # instance data structs are 16-byte aligned
        },
    )
|
||
|
||
|
||
def generate_instance_data_info(reflection: SPIRVReflection, shader_name: str) -> dict:
    """Build the template context entries related to instanced rendering.

    Args:
        reflection: SPIR-V reflection data (may be None).
        shader_name: Shader name (currently not used by this function).

    Returns:
        A dict with ``has_instance_data``, ``instance_data_struct``,
        ``instance_type_name``, ``instance_buffer_binding``,
        ``instance_buffer_set`` and ``instance_data_size``. When there is no
        instance buffer, neutral defaults are returned (``'void'`` type, zeros).
    """
    instance_buffer = reflection.instance_buffer if reflection else None
    if not instance_buffer:
        return {
            'has_instance_data': False,
            'instance_data_struct': '',
            'instance_type_name': 'void',
            'instance_buffer_binding': 0,
            'instance_buffer_set': 0,
            'instance_data_size': 0,
        }

    return {
        'has_instance_data': True,
        'instance_data_struct': generate_instance_data_struct(instance_buffer, reflection.types),
        'instance_type_name': instance_buffer.struct_type_name,
        'instance_buffer_binding': instance_buffer.binding,
        'instance_buffer_set': instance_buffer.set_number,
        'instance_data_size': instance_buffer.total_size,
    }
|
||
|
||
|
||
# ============ Vertex Structure Generation ============
|
||
|
||
def generate_vertex_struct(vertex_layout: VertexLayout, shader_name: str) -> str:
    """Generate the C++ vertex-input struct for a reflected vertex layout.

    The struct is named ``<PascalName>Vertex``. Side effect: that name is
    stored back into ``vertex_layout.struct_name``. The
    ``vertex/vertex_struct.jinja2`` template receives the attributes, the
    stride and the attribute count. Per the original example output, it
    produces a struct with one member per attribute plus static helpers
    returning the Vulkan binding/attribute descriptions, e.g.::

        struct ImageShaderVertex {
            Eigen::Vector2f position;  // location = 0
            Eigen::Vector4f color;     // location = 1
            Eigen::Vector2f uv;        // location = 2
            static VkVertexInputBindingDescription get_binding_description();
            static std::array<VkVertexInputAttributeDescription, 3> get_attribute_descriptions();
        };

    Args:
        vertex_layout: Reflected vertex layout (may be None).
        shader_name: Shader name (snake_case).

    Returns:
        The rendered C++ code, or an empty string when there is no layout or
        it has no attributes.
    """
    if not (vertex_layout and vertex_layout.attributes):
        return ""

    struct_name = _to_pascal_case(shader_name) + "Vertex"
    vertex_layout.struct_name = struct_name  # callers read this back

    attribute_rows = [
        {
            'name': attr.name,
            'cpp_type': attr.cpp_type,
            'location': attr.location,
            'vk_format': attr.vk_format,
            'offset': attr.offset,
            'size': attr.size,
        }
        for attr in vertex_layout.attributes
    ]

    renderer = get_renderer()
    return renderer.render(
        'vertex/vertex_struct.jinja2',
        {
            'struct_name': struct_name,
            'attributes': attribute_rows,
            'stride': vertex_layout.stride,
            'attribute_count': len(attribute_rows),
        },
    )
|
||
|
||
|
||
def generate_vertex_layout_info(reflection: SPIRVReflection, shader_name: str) -> dict:
    """Build the template context entries related to the vertex input layout.

    A layout counts as present only if it has at least one attribute.
    ``generate_vertex_struct`` returns an empty string for an attribute-less
    layout and does not assign ``struct_name`` in that case. Reporting such a
    layout as present would hand the template an empty struct and an
    unassigned type name.

    Args:
        reflection: SPIR-V reflection data (may be None).
        shader_name: Shader name (snake_case).

    Returns:
        A dict with ``has_vertex_layout``, ``vertex_struct``,
        ``vertex_type_name``, ``vertex_stride`` and ``vertex_attribute_count``.
        Neutral defaults (``'void'`` type, zeros) are returned when there is no
        usable layout.
    """
    vertex_layout = reflection.vertex_layout if reflection else None
    if not vertex_layout or not vertex_layout.attributes:
        return {
            'has_vertex_layout': False,
            'vertex_struct': '',
            'vertex_type_name': 'void',
            'vertex_stride': 0,
            'vertex_attribute_count': 0,
        }

    # generate_vertex_struct stores the generated struct name into
    # vertex_layout.struct_name, which is read below.
    vertex_struct = generate_vertex_struct(vertex_layout, shader_name)

    return {
        'has_vertex_layout': True,
        'vertex_struct': vertex_struct,
        'vertex_type_name': vertex_layout.struct_name,
        'vertex_stride': vertex_layout.stride,
        'vertex_attribute_count': len(vertex_layout.attributes),
    }
|
||
|
||
|
||
# ============ SPIR-V Array Generation ============
|
||
|
||
def spirv_to_cpp_array(spirv_data: bytes, symbol_name: str) -> str:
    """Render SPIR-V bytecode as a C++ ``uint32_t`` array definition.

    The bytes are decoded as little-endian 32-bit words and formatted as
    zero-padded 8-digit hex literals, 12 per row, for the
    ``base/spirv_array.jinja2`` template.

    Raises:
        ToolError: If the byte count is not a multiple of 4.
    """
    byte_count = len(spirv_data)
    if byte_count % 4:
        raise ToolError(f"SPIR-V data size {byte_count} is not a multiple of 4")

    words = struct.unpack(f"<{byte_count // 4}I", spirv_data)
    per_row = 12
    rows = [
        [f"0x{word:08x}" for word in words[start:start + per_row]]
        for start in range(0, len(words), per_row)
    ]

    renderer = get_renderer()
    return renderer.render(
        'base/spirv_array.jinja2',
        {'symbol_name': symbol_name, 'chunks': rows},
    )
|
||
|
||
|
||
# ============ 缓冲区辅助函数生成 ============
|
||
|
||
def generate_buffer_helper_functions(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate the create/upload/download C++ helpers for every buffer.

    For each buffer, the create, upload and download snippets are emitted in
    that order, each followed by an empty line.

    Returns:
        The concatenated C++ code, or an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    out = []
    for buf in reflection.buffers:
        snippets = (
            _generate_create_buffer_function(buf, shader_name),
            _generate_upload_function(buf),
            _generate_download_function(buf),
        )
        for snippet in snippets:
            out += snippet
            out += [""]
    return "\n".join(out)
|
||
|
||
|
||
def _generate_create_buffer_function(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the convenience "create buffer" helper for one buffer, as lines."""
    return get_renderer().render(
        'buffer_helpers/create_buffer.jinja2',
        dict(buffer=buffer, shader_name=shader_name),
    ).splitlines()
|
||
|
||
|
||
def _generate_upload_function(buffer: BufferInfo) -> List[str]:
    """Render the helper that uploads data into one buffer, as lines."""
    return get_renderer().render(
        'buffer_helpers/upload.jinja2',
        dict(buffer=buffer),
    ).splitlines()
|
||
|
||
|
||
def _generate_download_function(buffer: BufferInfo) -> List[str]:
    """Render the helper that downloads data from one buffer, as lines."""
    return get_renderer().render(
        'buffer_helpers/download.jinja2',
        dict(buffer=buffer),
    ).splitlines()
|
||
|
||
|
||
# ============ Typed Buffer Aliases Generation ============
|
||
|
||
def generate_typed_buffer_aliases(reflection: SPIRVReflection, shader_name: str) -> str:
    """Generate the type-safe buffer aliases and factory functions for all buffers.

    Each buffer's snippet is followed by an empty line.

    Returns:
        The concatenated C++ code, or an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    chunks = []
    for buf in reflection.buffers:
        chunks += _generate_buffer_type_alias_and_factory(buf, shader_name)
        chunks += [""]
    return "\n".join(chunks)
|
||
|
||
|
||
def _generate_buffer_type_alias_and_factory(buffer: BufferInfo, shader_name: str) -> List[str]:
    """Render the type alias and factory function for one buffer, as lines."""
    return get_renderer().render(
        'typed_buffer/alias_and_factory.jinja2',
        dict(buffer=buffer, shader_name=shader_name),
    ).splitlines()
|
||
|
||
|
||
# ============ Buffer Manager Class Generation ============
|
||
|
||
def generate_buffer_manager_class(reflection: SPIRVReflection, shader_name: str) -> str:
    """Render the ``<shader_name>_buffer_manager`` C++ class.

    The ``buffer_manager/manager_class.jinja2`` template receives every
    reflected buffer and ``is_compute``, which is True when the reflection's
    shader stage is ``"compute"``.

    Returns:
        The rendered class, or an empty string when there are no buffers.
    """
    if not reflection.buffers:
        return ""

    class_name = f"{shader_name}_buffer_manager"
    is_compute = reflection.shader_stage == "compute"

    renderer = get_renderer()
    rendered = renderer.render(
        'buffer_manager/manager_class.jinja2',
        {
            'class_name': class_name,
            'buffers': reflection.buffers,
            'is_compute': is_compute,
        },
    )
    # Joining a single-element list returns the rendered text as a plain str.
    return "\n".join([rendered])
|
||
|
||
|
||
def _generate_buffers_struct(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成 buffers 结构体(已被模板替代,保留用于兼容性)"""
|
||
lines = []
|
||
lines.append(" // Container for all buffers")
|
||
lines.append(" struct buffers {")
|
||
for buffer in buffers:
|
||
var_name = buffer.variable_name
|
||
alias_name = f"{var_name}_buffer"
|
||
lines.append(f" {alias_name} {var_name};")
|
||
lines.append(" };")
|
||
return lines
|
||
|
||
|
||
def _generate_manager_factory(buffers: List[BufferInfo], class_name: str) -> List[str]:
|
||
"""生成管理器的工厂方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_initialize_method(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成初始化方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_bind_method(buffers: List[BufferInfo]) -> List[str]:
|
||
"""生成绑定到 descriptor set 的方法(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
def _generate_manager_constructor(buffers: List[BufferInfo], class_name: str) -> List[str]:
|
||
"""生成管理器的私有构造函数(已被模板替代,保留用于兼容性)"""
|
||
# 此函数已被模板替代,保留签名以防外部调用
|
||
return []
|
||
|
||
|
||
# ============ Vulkan Type Conversion ============
|
||
|
||
def vk_descriptor_type_to_cpp(descriptor_type: str) -> str:
    """Map a reflected descriptor-type name to its Vulkan-Hpp ``vk::DescriptorType`` enumerator.

    Args:
        descriptor_type: snake_case descriptor kind, e.g. ``"uniform_buffer"``.

    Returns:
        The fully qualified enumerator spelling. Unknown names still fall back
        to ``vk::DescriptorType::eStorageBuffer`` for compatibility.
    """
    type_map = {
        "storage_buffer": "vk::DescriptorType::eStorageBuffer",
        "uniform_buffer": "vk::DescriptorType::eUniformBuffer",
        "sampled_image": "vk::DescriptorType::eSampledImage",
        "combined_image_sampler": "vk::DescriptorType::eCombinedImageSampler",
        "storage_image": "vk::DescriptorType::eStorageImage",
        "sampler": "vk::DescriptorType::eSampler",
        # Remaining core descriptor kinds. Without these entries they would
        # silently map to eStorageBuffer via the fallback below.
        # NOTE(review): the keys follow the snake_case convention above;
        # confirm they match what the reflection parser emits.
        "uniform_buffer_dynamic": "vk::DescriptorType::eUniformBufferDynamic",
        "storage_buffer_dynamic": "vk::DescriptorType::eStorageBufferDynamic",
        "uniform_texel_buffer": "vk::DescriptorType::eUniformTexelBuffer",
        "storage_texel_buffer": "vk::DescriptorType::eStorageTexelBuffer",
        "input_attachment": "vk::DescriptorType::eInputAttachment",
    }
    return type_map.get(descriptor_type, "vk::DescriptorType::eStorageBuffer")
|
||
|
||
|
||
def vk_shader_stage_to_cpp(stages: List[str]) -> str:
    """Convert shader stage names into a ``|``-joined ``vk::ShaderStageFlagBits`` expression.

    An empty or None ``stages`` is treated as ``["compute"]``. Unknown names
    are skipped, and if none remain the result is ``eCompute``.
    """
    compute_flag = "vk::ShaderStageFlagBits::eCompute"
    flag_for = {
        "vertex": "vk::ShaderStageFlagBits::eVertex",
        "fragment": "vk::ShaderStageFlagBits::eFragment",
        "compute": compute_flag,
        "geometry": "vk::ShaderStageFlagBits::eGeometry",
        "tess_control": "vk::ShaderStageFlagBits::eTessellationControl",
        "tess_eval": "vk::ShaderStageFlagBits::eTessellationEvaluation",
    }

    selected = [flag_for[name] for name in (stages or ["compute"]) if name in flag_for]
    return " | ".join(selected) if selected else compute_flag
|
||
|
||
|
||
# ============ Header File Generation ============
|
||
|
||
def _to_pascal_case(name: str) -> str:
|
||
"""将 snake_case 转换为 PascalCase"""
|
||
return ''.join(word.capitalize() for word in name.split('_'))
|
||
|
||
|
||
def _push_constant_custom_type_name(shader_name: str) -> str:
    """Return the C++ type name of a shader's user-defined push-constant struct."""
    return f"{_to_pascal_case(shader_name)}PushConstantCustom"


def _has_push_constant_effect_members(reflection: SPIRVReflection | None) -> bool:
    """Return True when reflection reports user-defined (effect) push-constant members."""
    return bool(
        reflection
        and reflection.push_constant
        and reflection.push_constant.effect_members
    )


def _has_push_constant_base_info(reflection: SPIRVReflection | None) -> bool:
    """Return True when reflection reports the push-constant header (base) layout."""
    return bool(
        reflection
        and reflection.push_constant
        and reflection.push_constant.base_info
    )


def _collect_push_constant_header(
    reflection: SPIRVReflection | None,
    shader_name: str,
) -> Dict[str, object]:
    """Build the template variables describing the push-constant header section.

    Returns a dict keyed by the template's ``push_constant_base_*`` names plus
    ``push_constant_static_check``. Without header reflection data, every
    entry is empty or False.
    """
    layout = ""
    member_names: List[str] = []
    member_offsets: List[int] = []
    static_check = ""

    if _has_push_constant_base_info(reflection):
        layout = generate_push_constant_header_layout(reflection, shader_name)
        # Header member names/offsets are consumed by the shader_spec template.
        for member in reflection.push_constant.base_info.members:
            member_names.append(member.name)
            member_offsets.append(member.offset)
        static_check = generate_push_constant_static_check(reflection, shader_name)

    return {
        'push_constant_base_layout': layout,
        'has_push_constant_base_layout': bool(layout),
        'push_constant_base_member_names': member_names,
        'push_constant_base_member_offsets': member_offsets,
        'push_constant_static_check': static_check,
    }


def _collect_spirv_stages(
    shader_name: str,
    compilation_results: List[CompilationResult],
) -> Dict[str, object]:
    """Embed every compiled stage as a C++ SPIR-V array and locate vert/frag symbols.

    Returns a dict with the template keys ``compilation_results``,
    ``has_vert_shader``, ``has_frag_shader``, ``vert_spirv_symbol`` and
    ``frag_spirv_symbol``.
    """
    arrays: List[Dict[str, object]] = []
    has_vert = False
    has_frag = False
    vert_symbol = ""
    frag_symbol = ""

    for result in compilation_results:
        symbol_name = f"{shader_name}_{result.shader_type}_spirv"
        arrays.append({
            'shader_type': result.shader_type,
            'source_file': result.source_file,
            'spirv_array': spirv_to_cpp_array(result.spirv_data, symbol_name),
        })
        # Remember which graphics stages exist for the shader_spec section.
        if result.shader_type == "vert":
            has_vert = True
            vert_symbol = symbol_name
        elif result.shader_type == "frag":
            has_frag = True
            frag_symbol = symbol_name

    return {
        'compilation_results': arrays,
        'has_vert_shader': has_vert,
        'has_frag_shader': has_frag,
        'vert_spirv_symbol': vert_symbol,
        'frag_spirv_symbol': frag_symbol,
    }


def _collect_bindings(bindings) -> List[Dict[str, object]]:
    """Flatten descriptor bindings into template dicts.

    The descriptor set is folded into the emitted binding number as
    ``set * 1000 + binding``.
    NOTE(review): a binding index >= 1000 would collide with the next set's
    range; nothing here guards against that.
    """
    if not bindings:
        return []
    return [
        {
            'binding': binding.descriptor_set * 1000 + binding.binding,
            'descriptor_type_cpp': vk_descriptor_type_to_cpp(binding.descriptor_type),
            'count': binding.count,
            'stage_flags': vk_shader_stage_to_cpp(binding.stages),
        }
        for binding in bindings
    ]


def generate_header(
    metadata: ShaderMetadata,
    compilation_results: List[CompilationResult],
) -> str:
    """Render the complete C++ header for one shader via ``base/header.jinja2``.

    Args:
        metadata: Shader description, including optional SPIR-V reflection and
            flags selecting which buffer helpers to emit.
        compilation_results: One entry per compiled stage. The first entry
            supplies ``entry_point`` and ``spirv_symbol_name``.

    Returns:
        The rendered header text.

    Notes:
        ``has_buffers`` is passed to the template as a real bool. It used to
        leak the buffer list or ``None``; truthiness is unchanged.
    """
    reflection = metadata.reflection
    name = metadata.name
    has_reflected_buffers = bool(reflection and reflection.buffers)

    # Buffer sections are emitted only when reflection found buffers.
    # Generation order matches the original.
    buffer_structures = (
        generate_buffer_structures(reflection) if has_reflected_buffers else ""
    )
    buffer_helpers = (
        generate_buffer_helper_functions(reflection, name)
        if metadata.generate_buffer_helpers and has_reflected_buffers
        else ""
    )
    typed_buffers = (
        generate_typed_buffer_aliases(reflection, name)
        if metadata.generate_typed_buffers and has_reflected_buffers
        else ""
    )
    buffer_manager = (
        generate_buffer_manager_class(reflection, name)
        if metadata.generate_buffer_manager and has_reflected_buffers
        else ""
    )

    # User-defined ("custom"/"effect") part of the push constants.
    has_custom_push_constant = _has_push_constant_effect_members(reflection)
    push_constant_custom_struct = (
        generate_push_constant_custom_struct(reflection, name)
        if has_custom_push_constant
        else ""
    )
    push_constant_custom_type_name = (
        _push_constant_custom_type_name(name) if has_custom_push_constant else "void"
    )

    # Header ("base") part of the push constants plus its static checks.
    header_info = _collect_push_constant_header(reflection, name)

    # Embedded SPIR-V arrays and vert/frag stage detection.
    stage_info = _collect_spirv_stages(name, compilation_results)

    bindings_data = _collect_bindings(metadata.bindings)

    # Parameter type priority: first buffer struct > custom push-constant struct > void.
    params_type_name = "void"
    if has_reflected_buffers:
        params_type_name = reflection.buffers[0].struct_type.name or params_type_name
    elif has_custom_push_constant:
        params_type_name = _push_constant_custom_type_name(name)

    instance_info = generate_instance_data_info(reflection, name)
    vertex_info = generate_vertex_layout_info(reflection, name)

    # The procedural vertex shader is usable only when the fragment shader's
    # first custom push-constant member starts right after the 16-byte header.
    supports_procedural_vertex_shader = (
        check_supports_procedural_vertex_shader(reflection) if reflection else False
    )

    renderer = get_renderer()
    context = {
        'namespace': metadata.namespace,
        'shader_name': name,
        'metadata_struct_name': f"{name}_metadata",
        'bindings_name': f"{name}_bindings",
        'shader_spec_name': f"{name}_shader_spec",
        'params_type_name': params_type_name,
        'entry_point': compilation_results[0].entry_point if compilation_results else "main",
        'is_compute': metadata.shader_type == "compute",
        'spirv_symbol_name': (
            f"{name}_{compilation_results[0].shader_type}_spirv" if compilation_results else ""
        ),
        'has_buffers': has_reflected_buffers,
        'generate_typed_buffers': metadata.generate_typed_buffers,
        'buffer_structures': buffer_structures,
        'buffer_helpers': buffer_helpers,
        'typed_buffers': typed_buffers,
        'buffer_manager': buffer_manager,
        'bindings': bindings_data,
        # compilation_results, has_vert/frag_shader, vert/frag_spirv_symbol
        **stage_info,
        # Custom push-constant struct (template still uses the "effect" names).
        'push_constant_effect_struct': push_constant_custom_struct,
        'push_constant_effect_type_name': push_constant_custom_type_name,
        'has_push_constant_effect': bool(push_constant_custom_struct),
        # push_constant_base_* layout variables and push_constant_static_check
        **header_info,
        'supports_procedural_vertex_shader': supports_procedural_vertex_shader,
        # Instancing and vertex-layout variables come last, as before, so they
        # keep precedence on any key collision.
        **instance_info,
        **vertex_info,
    }

    return renderer.render('base/header.jinja2', context)