615 lines
18 KiB
Python
615 lines
18 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
MIRAI 着色器绑定代码生成器
|
||
|
||
读取 SPIR-V 和反射 JSON,生成 C++ 绑定代码。
|
||
|
||
Usage:
|
||
python generate_shader_bindings.py \
|
||
--dir ./shader_intermediate \
|
||
--output generated/shader_name_bindings.hpp \
|
||
--name ShaderName
|
||
"""
|
||
|
||
import argparse
|
||
import json
|
||
import struct
|
||
import re
|
||
from pathlib import Path
|
||
from dataclasses import dataclass, field
|
||
from typing import List, Dict, Optional, Any
|
||
from datetime import datetime
|
||
|
||
from jinja2 import Environment, FileSystemLoader, select_autoescape
|
||
|
||
|
||
# ============================================================================
|
||
# 数据结构
|
||
# ============================================================================
|
||
|
||
@dataclass
class UniformMember:
    """One member of a uniform buffer or push-constant block, as reported by reflection."""

    # Member identifier as written in the shader source.
    name: str
    # Slang/HLSL type name (e.g. "float4", "float4x4").
    type: str
    # Byte offset of the member inside its enclosing block.
    offset: int
    # Size in bytes as reported by the reflection JSON.
    size: int
    # Element count; 1 means the member is not an array.
    array_size: int = 1
|
||
|
||
|
||
@dataclass
class UniformBuffer:
    """A uniform (constant) buffer binding together with its member layout."""

    # Buffer name from the shader source.
    name: str
    # Descriptor set index.
    set: int
    # Binding index within the descriptor set.
    binding: int
    # Total size in bytes as reported by reflection (0 if unknown).
    size: int
    # Members in reflection order.
    members: List[UniformMember] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class PushConstant:
    """A push-constant block and the shader stages that use it."""

    # Block name from the shader source.
    name: str
    # Size of the block in bytes.
    size: int
    # Vulkan shader-stage bitmask (VkShaderStageFlags); 0 when not provided.
    stage_flags: int = 0
    # Byte offset of the push-constant range (usually 0).
    offset: int = 0
    # Members in reflection order.
    members: List[UniformMember] = field(default_factory=list)
|
||
|
||
|
||
@dataclass
class SamplerBinding:
    """A sampler / texture binding slot."""

    # Resource name from the shader source.
    name: str
    # Descriptor set index.
    set: int
    # Binding index within the descriptor set.
    binding: int
    # Texture dimensionality string from reflection; defaults to "2D".
    dimension: str = "2D"
|
||
|
||
|
||
@dataclass
class ShaderStage:
    """A single compiled shader stage."""

    # Stage name, e.g. "vertex", "fragment", "compute".
    stage: str
    # Entry-point function name.
    entry_point: str
    # Raw SPIR-V module bytes (empty until loaded).
    spirv: bytes = b''
|
||
|
||
|
||
@dataclass
class ShaderReflection:
    """Reflection data for one shader (or the merge of several stages)."""

    # Shader name (derived from the reflection file name).
    name: str
    # Compiled stages; the parser in this file leaves this empty.
    stages: List[ShaderStage] = field(default_factory=list)
    # Uniform buffer bindings.
    uniform_buffers: List[UniformBuffer] = field(default_factory=list)
    # Push-constant blocks.
    push_constants: List[PushConstant] = field(default_factory=list)
    # Sampler / texture bindings.
    samplers: List[SamplerBinding] = field(default_factory=list)
|
||
|
||
|
||
# ============================================================================
|
||
# 类型映射
|
||
# ============================================================================
|
||
|
||
# Maps Slang/HLSL type names to the C++ types used in the generated structs.
# Unknown types are passed through unchanged by the callers (.get(t, t)).
# NOTE(review): under std140 a float3x3 / float2x2 occupies 16 bytes per column
# (48 / 32 bytes), whereas Eigen::Matrix3f / Matrix2f are tightly packed
# (36 / 16 bytes). Confirm the template pads these, or that such members are
# never used in uniform buffers.
SLANG_TO_CPP_TYPE = {
    "float": "float",
    **{f"float{n}": f"Eigen::Vector{n}f" for n in (2, 3, 4)},
    "int": "int32_t",
    **{f"int{n}": f"Eigen::Vector{n}i" for n in (2, 3, 4)},
    "uint": "uint32_t",
    **{f"uint{n}": f"Eigen::Matrix<uint32_t, {n}, 1>" for n in (2, 3, 4)},
    **{f"float{n}x{n}": f"Eigen::Matrix{n}f" for n in (2, 3, 4)},
    "bool": "uint32_t",  # shader bools occupy 4 bytes in buffer layouts
    "matrix": "Eigen::Matrix4f",
}
|
||
|
||
# std140 base alignment in bytes per Slang/HLSL type:
# scalar -> 4, 2-vector -> 8, 3- and 4-vectors -> 16; matrices are arrays of
# column vectors rounded up to vec4, hence 16. Callers default unknown types to 16.
STD140_ALIGNMENT = {
    **{
        f"{scalar}{width}": alignment
        for scalar in ("float", "int", "uint")
        for width, alignment in (("", 4), ("2", 8), ("3", 16), ("4", 16))
    },
    **{f"float{n}x{n}": 16 for n in (2, 3, 4)},
    "bool": 4,
    "matrix": 16,
}
|
||
|
||
|
||
# ============================================================================
|
||
# 解析器
|
||
# ============================================================================
|
||
|
||
def parse_reflection_json(json_path: Path) -> ShaderReflection:
    """Parse a Slang-generated reflection JSON file into a ShaderReflection.

    The shader name is the file name without its ``.reflect.json`` suffix
    (``basic.reflect.json`` -> ``basic``). The top-level arrays
    ``uniformBuffers``, ``pushConstants`` and ``samplers`` are read; each may be
    absent. ``stages`` is left empty because SPIR-V is loaded separately.

    Args:
        json_path: Path to the ``*.reflect.json`` file.

    Returns:
        The parsed reflection data.

    Raises:
        OSError: The file cannot be read.
        json.JSONDecodeError: The file is not valid JSON.
        KeyError: A required field (name/set/binding/size/type/offset) is missing.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # json_path.stem has already dropped ".json"; strip only a trailing
    # ".reflect". str.replace would also delete ".reflect" occurring inside
    # the shader name itself (e.g. "a.reflector").
    stem = json_path.stem
    suffix = '.reflect'
    name = stem[:-len(suffix)] if stem.endswith(suffix) else stem

    uniform_buffers = []
    for ub in data.get('uniformBuffers', []):
        members = _parse_members(ub.get('members', []))
        uniform_buffers.append(UniformBuffer(
            name=ub['name'],
            set=ub['set'],
            binding=ub['binding'],
            size=ub['size'],
            members=members,
        ))

    push_constants = []
    for pc in data.get('pushConstants', []):
        # Members are parsed exactly like uniform-buffer members, so array
        # members keep their 'arraySize' instead of silently becoming scalars.
        members = _parse_members(pc.get('members', []))
        push_constants.append(PushConstant(
            name=pc['name'],
            size=pc['size'],
            stage_flags=pc.get('stageFlags', 0),
            offset=pc.get('offset', 0),
            members=members,
        ))

    samplers = [
        SamplerBinding(
            name=s['name'],
            set=s['set'],
            binding=s['binding'],
            dimension=s.get('dimension', '2D'),
        )
        for s in data.get('samplers', [])
    ]

    return ShaderReflection(
        name=name,
        stages=[],
        uniform_buffers=uniform_buffers,
        push_constants=push_constants,
        samplers=samplers,
    )


def _parse_members(raw_members: List[Dict[str, Any]]) -> List["UniformMember"]:
    """Convert reflection member dicts into UniformMember objects ('arraySize' defaults to 1)."""
    return [
        UniformMember(
            name=m['name'],
            type=m['type'],
            offset=m['offset'],
            size=m['size'],
            array_size=m.get('arraySize', 1),
        )
        for m in raw_members
    ]
|
||
|
||
|
||
def merge_reflections(reflections: List[ShaderReflection]) -> ShaderReflection:
    """Merge per-stage reflection data into one ShaderReflection (set union).

    - Uniform buffers and samplers are de-duplicated by (set, binding); the
      first occurrence wins.
    - Push constants are de-duplicated by name. The first occurrence supplies
      size, offset and members, while ``stage_flags`` is the bitwise OR over
      all occurrences so the merged range covers every stage that declares it.
    - The merged name is taken from the first reflection; ``stages`` is empty.

    An empty list yields an empty reflection named "merged"; a single
    reflection is returned unchanged. The input objects are never mutated.
    """
    if not reflections:
        return ShaderReflection(name="merged")

    if len(reflections) == 1:
        return reflections[0]

    uniform_buffers_dict: Dict[tuple, UniformBuffer] = {}
    push_constants_dict: Dict[str, PushConstant] = {}
    samplers_dict: Dict[tuple, SamplerBinding] = {}

    for ref in reflections:
        for ub in ref.uniform_buffers:
            uniform_buffers_dict.setdefault((ub.set, ub.binding), ub)

        for pc in ref.push_constants:
            existing = push_constants_dict.get(pc.name)
            if existing is None:
                push_constants_dict[pc.name] = pc
            elif pc.stage_flags & ~existing.stage_flags:
                # A Vulkan push-constant range must list every stage that
                # accesses it. Keeping only the first stage's flags would drop
                # e.g. the fragment bit when vertex and fragment share a block.
                # Build a new object rather than mutating the caller's data.
                push_constants_dict[pc.name] = PushConstant(
                    name=existing.name,
                    size=existing.size,
                    stage_flags=existing.stage_flags | pc.stage_flags,
                    offset=existing.offset,
                    members=existing.members,
                )

        for s in ref.samplers:
            samplers_dict.setdefault((s.set, s.binding), s)

    return ShaderReflection(
        name=reflections[0].name,
        stages=[],
        uniform_buffers=list(uniform_buffers_dict.values()),
        push_constants=list(push_constants_dict.values()),
        samplers=list(samplers_dict.values()),
    )
|
||
|
||
|
||
def load_spirv(spirv_path: Path) -> bytes:
    """Return the raw bytes of the SPIR-V binary at *spirv_path* (str or Path)."""
    return Path(spirv_path).read_bytes()
|
||
|
||
|
||
def detect_stage_from_filename(filename: str) -> Optional[str]:
    """Infer the shader stage from a file name, case-insensitively.

    A stage matches when its extension appears as a dotted component in the
    middle of the name (``x.vert.spv``) or at its end (``x.vert``). Extensions
    are tried in the fixed order vert, frag, comp, geom, tesc, tese; the first
    match wins. Returns None when no extension matches.
    """
    lowered = filename.lower()
    candidates = (
        ('vert', 'vertex'),
        ('frag', 'fragment'),
        ('comp', 'compute'),
        ('geom', 'geometry'),
        ('tesc', 'tessellation_control'),
        ('tese', 'tessellation_evaluation'),
    )
    for extension, stage_name in candidates:
        dotted = '.' + extension
        if (dotted + '.') in lowered or lowered.endswith(dotted):
            return stage_name
    return None
|
||
|
||
|
||
# ============================================================================
|
||
# 辅助函数
|
||
# ============================================================================
|
||
|
||
def to_snake_case(name: str) -> str:
    """Convert a CamelCase / mixedCase identifier to snake_case.

    Examples: ``CameraData`` -> ``camera_data``, ``ABCDef`` -> ``abc_def``,
    ``viewProj`` -> ``view_proj``, ``g_Matrix`` -> ``g_matrix``.
    Existing underscores are kept and never doubled.
    """
    # Already lowercase with underscores: nothing to convert.
    if name.islower() and '_' in name:
        return name

    # Split before a capitalised word, which also separates a run of capitals
    # from the word that follows it ("ABCDef" -> "ABC_Def"). The preceding
    # character must not itself be '_'; otherwise "g_Matrix" would become
    # "g__matrix", and identifiers containing "__" are reserved in C++.
    s1 = re.sub(r'([^_])([A-Z][a-z]+)', r'\1_\2', name)
    # Split between a lowercase letter/digit and a following capital.
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
||
|
||
|
||
def to_camel_case(name: str) -> str:
    """Convert a snake_case name to UpperCamelCase.

    Empty segments (from leading, trailing or repeated underscores) are
    dropped, and each remaining segment is passed through ``str.title``.
    """
    segments = filter(None, name.split('_'))
    return ''.join(map(str.title, segments))
|
||
|
||
|
||
def to_upper_snake_case(name: str) -> str:
    """Return *name* in UPPER_SNAKE_CASE (snake_case conversion, then upper-cased)."""
    snake = to_snake_case(name)
    return snake.upper()
|
||
|
||
|
||
def get_stage_suffix(stage: str) -> str:
    """Return the short upper-case tag for a stage ('vertex' -> 'VERT').

    Stages without a known abbreviation are simply upper-cased.
    """
    abbreviations = {
        'vertex': 'VERT',
        'fragment': 'FRAG',
        'compute': 'COMP',
        'geometry': 'GEOM',
        'tessellation_control': 'TESC',
        'tessellation_evaluation': 'TESE',
    }
    if stage in abbreviations:
        return abbreviations[stage]
    return stage.upper()
|
||
|
||
|
||
def spirv_to_hex_lines(spirv: bytes) -> List[str]:
    """Format a SPIR-V module as C++ initializer lines of 32-bit hex words.

    Each line holds up to 8 words (``0x%08x``) separated by ", ", with a
    trailing comma. Trailing bytes that do not form a full 32-bit word are
    ignored. The module's byte order is detected from the SPIR-V magic number
    (0x07230203), so big-endian modules also produce the correct word values.
    Little-endian is assumed when the magic number is not recognised.

    Returns:
        The formatted lines; an empty list when *spirv* has fewer than 4 bytes.
    """
    word_count = len(spirv) // 4
    if word_count == 0:
        return []

    # Read as little-endian, the first word of a big-endian module shows up
    # byte-swapped as 0x03022307.
    byte_order = '>' if struct.unpack_from('<I', spirv)[0] == 0x03022307 else '<'
    words = struct.unpack_from(f'{byte_order}{word_count}I', spirv)

    return [
        ', '.join(f'0x{w:08x}' for w in words[i:i + 8]) + ','
        for i in range(0, word_count, 8)
    ]
|
||
|
||
|
||
def prepare_template_context(
    shader_name: str,
    reflection: ShaderReflection,
    stages: Dict[str, bytes]
) -> Dict[str, Any]:
    """Build the Jinja2 rendering context for the bindings header template.

    Args:
        shader_name: Name given on the command line. It is converted to
            snake_case / UPPER_SNAKE_CASE for identifiers and the include guard.
        reflection: Merged reflection data (uniform buffers, push constants,
            samplers).
        stages: SPIR-V bytes keyed by stage name ('vertex', 'fragment', ...).

    Returns:
        A dict with these keys: shader_name, snake_name, upper_snake_name,
        guard, timestamp, stages, uniform_buffers, push_constants, samplers,
        bindings, uniform_buffer_count and sampler_count.
    """
    upper_snake_name = to_upper_snake_case(shader_name)

    samplers_data = [
        {
            'name': s.name,
            'snake_name': to_snake_case(s.name),
            'set': s.set,
            'binding': s.binding,
            'dimension': s.dimension,
        }
        for s in reflection.samplers
    ]

    # Binding constants: all uniform buffers first, then all samplers.
    bindings_data = [
        {
            'const_name': to_upper_snake_case(res.name),
            'set': res.set,
            'binding': res.binding,
        }
        for res in (*reflection.uniform_buffers, *reflection.samplers)
    ]

    return {
        'shader_name': shader_name,
        'snake_name': to_snake_case(shader_name),
        'upper_snake_name': upper_snake_name,
        'guard': f"MIRAI_GENERATED_SHADER_{upper_snake_name}_HPP",
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'stages': _prepare_stages_data(stages),
        'uniform_buffers': [
            _prepare_uniform_buffer_data(ub) for ub in reflection.uniform_buffers
        ],
        'push_constants': [
            _prepare_push_constant_data(pc) for pc in reflection.push_constants
        ],
        'samplers': samplers_data,
        'bindings': bindings_data,
        'uniform_buffer_count': len(reflection.uniform_buffers),
        'sampler_count': len(reflection.samplers),
    }


def _prepare_stages_data(stages: Dict[str, bytes]) -> Dict[str, Dict[str, Any]]:
    """Map each stage suffix (VERT, FRAG, ...) to its SPIR-V word count and hex lines."""
    return {
        get_stage_suffix(stage): {
            'word_count': len(spirv) // 4 if spirv else 0,
            'hex_lines': spirv_to_hex_lines(spirv),
        }
        for stage, spirv in stages.items()
    }


def _prepare_member_data(member: "UniformMember", *, with_alignment: bool) -> Dict[str, Any]:
    """Describe one block member for the template; arrays become ``std::array<T, N>``."""
    cpp_type = SLANG_TO_CPP_TYPE.get(member.type, member.type)
    if member.array_size > 1:
        cpp_type = f"std::array<{cpp_type}, {member.array_size}>"
    data: Dict[str, Any] = {
        'name': to_snake_case(member.name),
        'type_name': member.type,
        'cpp_type': cpp_type,
        'offset': member.offset,
        'size': member.size,
    }
    if with_alignment:
        # std140 base alignment of the element type; unknown types default to 16.
        data['alignment'] = STD140_ALIGNMENT.get(member.type, 16)
    return data


def _prepare_uniform_buffer_data(ub: "UniformBuffer") -> Dict[str, Any]:
    """Describe one uniform buffer; its size is padded up to a multiple of 16 bytes."""
    # Fallback size when reflection reports 0: end of the furthest member.
    calculated_size = max([0, *(m.offset + m.size for m in ub.members)])
    final_size = ub.size if ub.size > 0 else calculated_size
    final_size = (final_size + 15) & ~15

    snake = to_snake_case(ub.name)
    return {
        'name': ub.name,
        'struct_name': snake,
        'snake_name': snake,
        'set': ub.set,
        'binding': ub.binding,
        'size': final_size,
        'members': [_prepare_member_data(m, with_alignment=True) for m in ub.members],
    }


def _prepare_push_constant_data(pc: "PushConstant") -> Dict[str, Any]:
    """Describe one push-constant block (size, stage flags, offset, members)."""
    snake = to_snake_case(pc.name)
    return {
        'name': pc.name,
        'struct_name': snake,
        'snake_name': snake,
        'size': pc.size,
        'stage_flags': pc.stage_flags,
        'offset': pc.offset,
        'members': [_prepare_member_data(m, with_alignment=False) for m in pc.members],
    }
|
||
|
||
|
||
# ============================================================================
|
||
# 代码生成器
|
||
# ============================================================================
|
||
|
||
def create_jinja_env(template_dir: Path) -> Environment:
    """Create the Jinja2 environment used to render the bindings header.

    Templates are loaded from *template_dir*. Autoescaping is enabled only for
    templates whose names end in .html/.xml, so the C++ header template is
    rendered verbatim. Block tags do not leave stray newlines or indentation,
    and a template's trailing newline is preserved.
    """
    loader = FileSystemLoader(template_dir)
    autoescape = select_autoescape(enabled_extensions=['html', 'xml'])
    env = Environment(
        loader=loader,
        autoescape=autoescape,
        trim_blocks=True,
        lstrip_blocks=True,
        keep_trailing_newline=True,
    )
    return env
|
||
|
||
|
||
def generate_header(
    shader_name: str,
    reflection: ShaderReflection,
    stages: Dict[str, bytes],
    template_dir: Path
) -> str:
    """Render ``shader_bindings.hpp.j2`` from *template_dir* and return the C++ header text."""
    template = create_jinja_env(template_dir).get_template('shader_bindings.hpp.j2')
    return template.render(prepare_template_context(shader_name, reflection, stages))
|
||
|
||
|
||
# ============================================================================
|
||
# 主函数
|
||
# ============================================================================
|
||
|
||
def main():
    """Command-line entry point: generate a C++ bindings header for one shader.

    Reads every ``*.reflect.json`` and ``*.spv`` file in ``--dir``, merges the
    reflection data, and renders the header to ``--output``.

    Returns:
        The process exit code: 0 on success or when there is nothing to
        generate (missing shader directory, no reflection files), and 1 when
        the template directory does not exist.
    """
    import sys

    args = _build_arg_parser().parse_args()

    # Resolve the template directory (default: templates/ next to this script).
    if args.template_dir:
        template_dir = Path(args.template_dir)
    else:
        template_dir = Path(__file__).parent / 'templates'

    if not template_dir.is_dir():
        print(f"Error: Template directory not found: {template_dir}", file=sys.stderr)
        return 1

    shader_dir = Path(args.dir)
    if not shader_dir.exists():
        # A missing directory means the shader has no entry points; not an error.
        print(f"Info: Shader directory not found: {shader_dir}. No entry points, skipping.")
        return 0

    spirv_files = sorted(shader_dir.glob("*.spv"))
    reflect_files = sorted(shader_dir.glob("*.reflect.json"))

    if args.verbose:
        print(f"Found {len(spirv_files)} SPIR-V files")
        print(f"Found {len(reflect_files)} reflection files")

    if not reflect_files:
        print(f"Info: No reflection files found in {shader_dir}. Skipping binding generation.")
        return 0

    all_reflections = []
    for reflection_path in reflect_files:
        if args.verbose:
            print(f"Parsing reflection: {reflection_path}")
        all_reflections.append(parse_reflection_json(reflection_path))

    # Union of all per-stage reflections.
    reflection = merge_reflections(all_reflections)

    stages = _load_stages(spirv_files, args.verbose)

    if args.verbose:
        print(f"Generating header: {args.output}")
        print(f"Using templates from: {template_dir}")

    header_content = generate_header(args.name, reflection, stages, template_dir)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write(header_content)

    print(f"Generated: {output_path}")

    if args.verbose:
        print(f"  - Stages: {list(stages.keys())}")
        print(f"  - Uniform Buffers: {len(reflection.uniform_buffers)}")
        print(f"  - Samplers: {len(reflection.samplers)}")
        print(f"  - Push Constants: {len(reflection.push_constants)}")

    return 0


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this generator."""
    parser = argparse.ArgumentParser(
        description='MIRAI 着色器绑定代码生成器',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
  python generate_shader_bindings.py \\
    --dir ./shader_intermediate/basic \\
    --output generated/basic_bindings.hpp \\
    --name basic
"""
    )
    parser.add_argument(
        '--dir', '-d',
        required=True,
        help='包含 SPIR-V 和反射文件的目录'
    )
    parser.add_argument(
        '--output', '-o',
        required=True,
        help='输出的 C++ 头文件路径'
    )
    parser.add_argument(
        '--name', '-n',
        required=True,
        help='着色器名称 (用于生成类名和变量名)'
    )
    parser.add_argument(
        '--template-dir', '-t',
        default=None,
        help='Jinja2 模板目录 (默认: 脚本所在目录的 templates 子目录)'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='输出详细信息'
    )
    return parser


def _load_stages(spirv_files: List[Path], verbose: bool) -> Dict[str, bytes]:
    """Load each SPIR-V file, keyed by the shader stage detected from its name."""
    import sys

    stages: Dict[str, bytes] = {}
    for spirv_path in spirv_files:
        stage = detect_stage_from_filename(spirv_path.name)
        if stage is None:
            if verbose:
                print(f"Warning: Cannot detect shader stage from filename: {spirv_path.name}")
            # Fall back to the last dotted component of the stem
            # ("foo.main.spv" -> "main"), or "unknown" when there is none.
            stage = spirv_path.stem.split('.')[-1] if '.' in spirv_path.stem else 'unknown'

        if verbose:
            print(f"Loading SPIR-V ({stage}): {spirv_path}")

        if stage in stages:
            # Previously the later file silently replaced the earlier one.
            print(
                f"Warning: Multiple SPIR-V files map to stage '{stage}'; "
                f"{spirv_path.name} replaces the previously loaded module.",
                file=sys.stderr,
            )
        stages[stage] = load_spirv(spirv_path)
    return stages
|
||
|
||
|
||
if __name__ == '__main__':
    # sys.exit rather than the site-provided exit(), which is meant for the
    # interactive interpreter and is missing when Python runs with -S.
    import sys

    sys.exit(main())
|