Files
mirai/tools/generate_shader_bindings.py

615 lines
18 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
MIRAI 着色器绑定代码生成器
读取 SPIR-V 和反射 JSON，生成 C++ 绑定代码。
Usage:
python generate_shader_bindings.py \
--dir ./shader_intermediate \
--output generated/shader_name_bindings.hpp \
--name ShaderName
"""
import argparse
import json
import re
import struct
import sys
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List, Optional

from jinja2 import Environment, FileSystemLoader, select_autoescape
# ============================================================================
# 数据结构
# ============================================================================
@dataclass
class UniformMember:
    """One member of a uniform-buffer or push-constant block.

    Offsets and sizes are byte quantities taken from the reflection JSON.
    """

    name: str  # member name as written in the shader
    type: str  # shader type name, e.g. "float4" or "float4x4"
    offset: int  # byte offset from the start of the enclosing block
    size: int  # byte size reported by reflection
    array_size: int = 1  # element count; 1 means "not an array"
@dataclass
class UniformBuffer:
    """A uniform buffer bound at descriptor ``(set, binding)``."""

    name: str  # block name from reflection
    set: int  # descriptor set index
    binding: int  # binding index within the set
    size: int  # block size in bytes as reported (0 when unknown)
    # Each instance receives its own fresh list of members.
    members: List[UniformMember] = field(default_factory=list)
@dataclass
class PushConstant:
    """A push-constant block and the shader stages that use it."""

    name: str  # block name from reflection
    size: int  # block size in bytes
    stage_flags: int = 0  # Vulkan shader-stage flag bitmask
    offset: int = 0  # byte offset of the push-constant range (usually 0)
    # Each instance receives its own fresh list of members.
    members: List[UniformMember] = field(default_factory=list)
@dataclass
class SamplerBinding:
    """A sampler/texture resource bound at descriptor ``(set, binding)``."""

    name: str  # resource name from reflection
    set: int  # descriptor set index
    binding: int  # binding index within the set
    dimension: str = "2D"  # texture dimensionality label; "2D" when not reported
@dataclass
class ShaderStage:
    """One compiled shader stage: its kind, entry point and SPIR-V code."""

    stage: str  # stage kind such as "vertex", "fragment" or "compute"
    entry_point: str  # name of the entry-point function
    spirv: bytes = bytes()  # raw SPIR-V binary; empty until loaded
@dataclass
class ShaderReflection:
    """All reflection data gathered for one shader.

    Every list field defaults to a fresh empty list per instance.
    """

    name: str  # shader name, used for generated identifiers
    stages: List[ShaderStage] = field(default_factory=list)
    uniform_buffers: List[UniformBuffer] = field(default_factory=list)
    push_constants: List[PushConstant] = field(default_factory=list)
    samplers: List[SamplerBinding] = field(default_factory=list)
# ============================================================================
# 类型映射
# ============================================================================
# Mapping from Slang/HLSL type names to the C++ types emitted in generated
# structs. Types missing from this table are emitted verbatim by callers.
# NOTE(review): std140 pads each matrix column to 16 bytes, so float2x2 and
# float3x3 occupy 32 and 48 bytes while Eigen::Matrix2f / Matrix3f are 16 and
# 36 bytes; float3 (Eigen::Vector3f, 12 bytes) has 16-byte std140 alignment.
# Confirm the template pads these members accordingly.
SLANG_TO_CPP_TYPE = {
    # Scalars. A GLSL bool occupies 4 bytes, hence uint32_t.
    "float": "float",
    "int": "int32_t",
    "uint": "uint32_t",
    "bool": "uint32_t",
    # 2/3/4-component vectors.
    **{f"float{n}": f"Eigen::Vector{n}f" for n in (2, 3, 4)},
    **{f"int{n}": f"Eigen::Vector{n}i" for n in (2, 3, 4)},
    **{f"uint{n}": f"Eigen::Matrix<uint32_t, {n}, 1>" for n in (2, 3, 4)},
    # Square float matrices; the bare "matrix" keyword is treated as 4x4.
    **{f"float{n}x{n}": f"Eigen::Matrix{n}f" for n in (2, 3, 4)},
    "matrix": "Eigen::Matrix4f",
}
# std140 base alignment (in bytes) for each supported shader type.
STD140_ALIGNMENT = {
    # Scalars, including bool (stored as a 32-bit value).
    **dict.fromkeys(("float", "int", "uint", "bool"), 4),
    # Two-component vectors.
    **dict.fromkeys(("float2", "int2", "uint2"), 8),
    # Three- and four-component vectors both align like a vec4.
    **dict.fromkeys(("float3", "float4", "int3", "int4", "uint3", "uint4"), 16),
    # Matrices are laid out as arrays of column vectors, and std140 rounds
    # array-element alignment up to 16 bytes.
    **dict.fromkeys(("float2x2", "float3x3", "float4x4", "matrix"), 16),
}
# ============================================================================
# 解析器
# ============================================================================
def _parse_members(raw_members: List[Dict[str, Any]]) -> List[UniformMember]:
    """Convert a block's JSON ``members`` list into UniformMember objects."""
    return [
        UniformMember(
            name=m['name'],
            type=m['type'],
            offset=m['offset'],
            size=m['size'],
            array_size=m.get('arraySize', 1),
        )
        for m in raw_members
    ]


def parse_reflection_json(json_path: Path) -> ShaderReflection:
    """Parse a Slang-generated reflection JSON file into a ShaderReflection.

    The top-level keys ``uniformBuffers``, ``pushConstants`` and ``samplers``
    are each optional and default to empty. Uniform-buffer and push-constant
    members are parsed identically, including the optional ``arraySize``.

    Args:
        json_path: Path to a ``<name>.reflect.json`` file.

    Returns:
        A ShaderReflection named after the file with its trailing
        ``.reflect.json`` removed (``basic.vert.reflect.json`` ->
        ``basic.vert``). ``stages`` is left empty; SPIR-V is loaded separately.

    Raises:
        OSError: If the file cannot be read.
        json.JSONDecodeError: If the file is not valid JSON.
        KeyError: If a required field (``name``, ``set``, ``binding``,
            ``size``, ...) is missing.
    """
    with open(json_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    # Path.stem already dropped ".json"; strip only the trailing ".reflect".
    # str.replace would also delete ".reflect" elsewhere in the name
    # ("a.reflector.reflect" -> "aor").
    name = json_path.stem
    if name.endswith('.reflect'):
        name = name[:-len('.reflect')]

    uniform_buffers = [
        UniformBuffer(
            name=ub['name'],
            set=ub['set'],
            binding=ub['binding'],
            size=ub['size'],
            members=_parse_members(ub.get('members', [])),
        )
        for ub in data.get('uniformBuffers', [])
    ]

    push_constants = [
        PushConstant(
            name=pc['name'],
            size=pc['size'],
            stage_flags=pc.get('stageFlags', 0),
            offset=pc.get('offset', 0),
            members=_parse_members(pc.get('members', [])),
        )
        for pc in data.get('pushConstants', [])
    ]

    samplers = [
        SamplerBinding(
            name=s['name'],
            set=s['set'],
            binding=s['binding'],
            dimension=s.get('dimension', '2D'),
        )
        for s in data.get('samplers', [])
    ]

    return ShaderReflection(
        name=name,
        stages=[],
        uniform_buffers=uniform_buffers,
        push_constants=push_constants,
        samplers=samplers,
    )
def merge_reflections(reflections: List[ShaderReflection]) -> ShaderReflection:
    """Merge several reflections (e.g. one per stage) into their union.

    * Uniform buffers and samplers are de-duplicated by ``(set, binding)``;
      the first occurrence wins.
    * Push constants are de-duplicated by name. A Vulkan push-constant range
      must list every stage that accesses it, so the ``stage_flags`` of
      same-named blocks are OR-ed together. The layout (size, offset,
      members) of the larger declaration is kept; for equal sizes it is the
      first one's.

    Input objects are never mutated.

    Args:
        reflections: Reflection data to merge, in priority order.

    Returns:
        An empty reflection named ``"merged"`` for an empty list; the sole
        element itself for a single-item list; otherwise a new
        ShaderReflection named after the first input with ``stages`` empty.
    """
    if not reflections:
        return ShaderReflection(name="merged")
    if len(reflections) == 1:
        return reflections[0]

    uniform_buffers: Dict[tuple, UniformBuffer] = {}
    samplers: Dict[tuple, SamplerBinding] = {}
    push_constants: Dict[str, PushConstant] = {}

    for ref in reflections:
        for ub in ref.uniform_buffers:
            uniform_buffers.setdefault((ub.set, ub.binding), ub)
        for s in ref.samplers:
            samplers.setdefault((s.set, s.binding), s)
        for pc in ref.push_constants:
            known = push_constants.get(pc.name)
            if known is None:
                push_constants[pc.name] = pc
                continue
            layout = pc if pc.size > known.size else known
            # Build a new object instead of mutating `known`, which still
            # belongs to one of the input reflections.
            push_constants[pc.name] = PushConstant(
                name=known.name,
                size=layout.size,
                stage_flags=known.stage_flags | pc.stage_flags,
                offset=layout.offset,
                members=list(layout.members),
            )

    return ShaderReflection(
        name=reflections[0].name,
        stages=[],
        uniform_buffers=list(uniform_buffers.values()),
        push_constants=list(push_constants.values()),
        samplers=list(samplers.values()),
    )
def load_spirv(spirv_path: Path) -> bytes:
    """Return the raw bytes of the SPIR-V binary at *spirv_path*."""
    return Path(spirv_path).read_bytes()
def detect_stage_from_filename(filename: str) -> Optional[str]:
    """Guess the shader stage from a tag in *filename* (case-insensitive).

    A tag such as ``.vert`` counts when followed by another dot
    (``basic.vert.spv``) or when it ends the name (``basic.vert``). Tags are
    checked in a fixed order, so the first listed tag that appears wins.
    Returns None when no tag is found.
    """
    lowered = filename.lower()
    ordered_tags = (
        ('.vert', 'vertex'),
        ('.frag', 'fragment'),
        ('.comp', 'compute'),
        ('.geom', 'geometry'),
        ('.tesc', 'tessellation_control'),
        ('.tese', 'tessellation_evaluation'),
    )
    for tag, stage_name in ordered_tags:
        if (tag + '.') in lowered or lowered.endswith(tag):
            return stage_name
    return None
# ============================================================================
# 辅助函数
# ============================================================================
def to_snake_case(name: str) -> str:
    """Convert a CamelCase / mixedCase identifier to snake_case.

    Acronyms stay together (``"ABCDef"`` -> ``"abc_def"``,
    ``"getHTTPResponse"`` -> ``"get_http_response"``), and no underscore is
    inserted next to an existing one (``"My_Buffer"`` -> ``"my_buffer"``).
    Names that are already lower case and contain an underscore are
    returned unchanged.
    """
    if name.islower() and '_' in name:
        return name
    # Separate a Capitalised word from what precedes it ("HTTPResponse" ->
    # "HTTP_Response"). A preceding underscore is excluded so that
    # "My_Buffer" does not turn into "My__Buffer".
    spaced = re.sub(r'([^_])([A-Z][a-z]+)', r'\1_\2', name)
    # Separate a lower-case letter or digit from a following capital
    # ("fooBar" -> "foo_Bar"), then lower-case everything.
    return re.sub(r'([a-z0-9])([A-Z])', r'\1_\2', spaced).lower()
def to_camel_case(name: str) -> str:
    """Convert a snake_case name to UpperCamelCase.

    Empty segments (from leading, trailing or repeated underscores) are
    dropped. Each segment goes through ``str.title``, so its remaining
    letters are lower-cased and a letter following a digit is capitalised
    (``"vec3data"`` -> ``"Vec3Data"``).
    """
    words = filter(None, name.split('_'))
    return ''.join(map(str.title, words))
def to_upper_snake_case(name: str) -> str:
    """Convert *name* to UPPER_SNAKE_CASE (to_snake_case, then upper-cased)."""
    snake = to_snake_case(name)
    return snake.upper()
def get_stage_suffix(stage: str) -> str:
    """Return the upper-case suffix used in generated names for *stage*.

    Known stages map to four-letter abbreviations (``vertex`` -> ``VERT``);
    any other stage name is upper-cased as-is.
    """
    known_suffixes = {
        'vertex': 'VERT',
        'fragment': 'FRAG',
        'compute': 'COMP',
        'geometry': 'GEOM',
        'tessellation_control': 'TESC',
        'tessellation_evaluation': 'TESE',
    }
    return known_suffixes[stage] if stage in known_suffixes else stage.upper()
def spirv_to_hex_lines(spirv: bytes, words_per_line: int = 8) -> List[str]:
    """Format a SPIR-V binary as lines of comma-terminated ``0x%08x`` words.

    Words are decoded little-endian unless the leading magic number
    (0x07230203) shows the module was written big-endian. In that case the
    words are decoded big-endian, so the emitted values are the logical
    SPIR-V words either way. Trailing bytes that do not complete a 32-bit
    word are ignored.

    Args:
        spirv: Raw SPIR-V bytes.
        words_per_line: Number of words per output line (default 8).

    Returns:
        Lines such as ``"0x07230203, 0x00010000,"``; an empty list when the
        input holds no complete word.

    Raises:
        ValueError: If ``words_per_line`` is less than 1.
    """
    if words_per_line < 1:
        raise ValueError("words_per_line must be at least 1")
    word_count = len(spirv) // 4
    if word_count == 0:
        return []

    spirv_magic = 0x07230203
    # A big-endian module's magic reads back correctly only with '>'.
    byte_order = '>' if struct.unpack_from('>I', spirv)[0] == spirv_magic else '<'
    words = struct.unpack_from(f'{byte_order}{word_count}I', spirv)

    return [
        ', '.join(f'0x{w:08x}' for w in words[i:i + words_per_line]) + ','
        for i in range(0, word_count, words_per_line)
    ]
def _member_context(member: UniformMember, with_alignment: bool) -> Dict[str, Any]:
    """Describe one block member for the template (name, C++ type, layout)."""
    cpp_type = SLANG_TO_CPP_TYPE.get(member.type, member.type)
    if member.array_size > 1:
        cpp_type = f"std::array<{cpp_type}, {member.array_size}>"
    ctx: Dict[str, Any] = {
        'name': to_snake_case(member.name),
        'type_name': member.type,
        'cpp_type': cpp_type,
    }
    if with_alignment:
        # Types missing from the table fall back to 16-byte alignment.
        ctx['alignment'] = STD140_ALIGNMENT.get(member.type, 16)
    ctx['offset'] = member.offset
    ctx['size'] = member.size
    return ctx


def _uniform_buffer_context(ub: UniformBuffer) -> Dict[str, Any]:
    """Describe one uniform buffer; its size is rounded up to 16 bytes."""
    snake = to_snake_case(ub.name)
    # When reflection reports size 0, fall back to the furthest member end.
    member_ends = [m.offset + m.size for m in ub.members]
    size = ub.size if ub.size > 0 else max([0] + member_ends)
    # Round up to a multiple of 16 bytes (vec4 alignment).
    size = (size + 15) & ~15
    return {
        'name': ub.name,
        'struct_name': snake,
        'snake_name': snake,
        'set': ub.set,
        'binding': ub.binding,
        'size': size,
        'members': [_member_context(m, with_alignment=True) for m in ub.members],
    }


def _push_constant_context(pc: PushConstant) -> Dict[str, Any]:
    """Describe one push-constant block for the template."""
    snake = to_snake_case(pc.name)
    return {
        'name': pc.name,
        'struct_name': snake,
        'snake_name': snake,
        'size': pc.size,
        'stage_flags': pc.stage_flags,
        'offset': pc.offset,
        'members': [_member_context(m, with_alignment=False) for m in pc.members],
    }


def prepare_template_context(
    shader_name: str,
    reflection: ShaderReflection,
    stages: Dict[str, bytes]
) -> Dict[str, Any]:
    """Build the Jinja2 context for the shader bindings template.

    Args:
        shader_name: User-facing shader name; drives generated identifiers
            and the include guard.
        reflection: Merged reflection data.
        stages: SPIR-V bytes keyed by stage name (e.g. ``"vertex"``).

    Returns:
        A dict with these keys:

        * ``shader_name``, ``snake_name``, ``upper_snake_name``, ``guard``,
          ``timestamp``
        * ``stages``: stage suffix -> ``{word_count, hex_lines}``
        * ``uniform_buffers``, ``push_constants``, ``samplers``
        * ``bindings``: per-resource set/binding constants, uniform buffers
          first, then samplers
        * ``uniform_buffer_count``, ``sampler_count``

        Members of array type (``array_size > 1``) become ``std::array`` for
        both uniform buffers and push constants.
    """
    upper_snake_name = to_upper_snake_case(shader_name)

    stages_data = {
        get_stage_suffix(stage): {
            'word_count': len(spirv) // 4 if spirv else 0,
            'hex_lines': spirv_to_hex_lines(spirv),
        }
        for stage, spirv in stages.items()
    }

    samplers_data = [
        {
            'name': s.name,
            'snake_name': to_snake_case(s.name),
            'set': s.set,
            'binding': s.binding,
            'dimension': s.dimension,
        }
        for s in reflection.samplers
    ]

    bindings_data = [
        {
            'const_name': to_upper_snake_case(res.name),
            'set': res.set,
            'binding': res.binding,
        }
        for res in [*reflection.uniform_buffers, *reflection.samplers]
    ]

    return {
        'shader_name': shader_name,
        'snake_name': to_snake_case(shader_name),
        'upper_snake_name': upper_snake_name,
        'guard': f"MIRAI_GENERATED_SHADER_{upper_snake_name}_HPP",
        'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
        'stages': stages_data,
        'uniform_buffers': [_uniform_buffer_context(ub) for ub in reflection.uniform_buffers],
        'push_constants': [_push_constant_context(pc) for pc in reflection.push_constants],
        'samplers': samplers_data,
        'bindings': bindings_data,
        'uniform_buffer_count': len(reflection.uniform_buffers),
        'sampler_count': len(reflection.samplers),
    }
# ============================================================================
# 代码生成器
# ============================================================================
def create_jinja_env(template_dir: Path) -> Environment:
    """Create the Jinja2 environment used to render the bindings template.

    Templates are loaded from *template_dir*. Block tags leave no stray
    newlines or leading whitespace (``trim_blocks`` / ``lstrip_blocks``), and
    a template's final newline is kept. Autoescaping is selected only for
    ``.html`` / ``.xml`` template names.
    """
    settings = {
        'loader': FileSystemLoader(template_dir),
        'autoescape': select_autoescape(['html', 'xml']),
        'trim_blocks': True,
        'lstrip_blocks': True,
        'keep_trailing_newline': True,
    }
    return Environment(**settings)
def generate_header(
    shader_name: str,
    reflection: ShaderReflection,
    stages: Dict[str, bytes],
    template_dir: Path
) -> str:
    """Render the C++ bindings header text for *shader_name*.

    Loads ``shader_bindings.hpp.j2`` from *template_dir* and renders it with
    the context from prepare_template_context.
    """
    template = create_jinja_env(template_dir).get_template('shader_bindings.hpp.j2')
    return template.render(prepare_template_context(shader_name, reflection, stages))
# ============================================================================
# 主函数
# ============================================================================
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser (--dir, --output, --name, --template-dir, --verbose)."""
    parser = argparse.ArgumentParser(
        description='MIRAI 着色器绑定代码生成器',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
python generate_shader_bindings.py \\
--dir ./shader_intermediate/basic \\
--output generated/basic_bindings.hpp \\
--name basic
"""
    )
    parser.add_argument(
        '--dir', '-d',
        required=True,
        help='包含 SPIR-V 和反射文件的目录'
    )
    parser.add_argument(
        '--output', '-o',
        required=True,
        help='输出的 C++ 头文件路径'
    )
    parser.add_argument(
        '--name', '-n',
        required=True,
        help='着色器名称 (用于生成类名和变量名)'
    )
    parser.add_argument(
        '--template-dir', '-t',
        default=None,
        help='Jinja2 模板目录 (默认: 脚本所在目录的 templates 子目录)'
    )
    parser.add_argument(
        '--verbose', '-v',
        action='store_true',
        help='输出详细信息'
    )
    return parser


def _load_stages(spirv_files: List[Path], verbose: bool) -> Dict[str, bytes]:
    """Load each SPIR-V file, keyed by the stage detected from its file name.

    If no stage tag is found, the last dot-separated part of the stem is used
    (or ``'unknown'`` when the stem has no dot). When two files map to the
    same stage, the later one replaces the earlier one and a warning is
    printed to stderr.
    """
    stages: Dict[str, bytes] = {}
    for spirv_path in spirv_files:
        stage = detect_stage_from_filename(spirv_path.name)
        if stage is None:
            if verbose:
                print(f"Warning: Cannot detect shader stage from filename: {spirv_path.name}",
                      file=sys.stderr)
            stage = spirv_path.stem.split('.')[-1] if '.' in spirv_path.stem else 'unknown'
        if verbose:
            print(f"Loading SPIR-V ({stage}): {spirv_path}")
        if stage in stages:
            print(f"Warning: {spirv_path.name} maps to stage '{stage}' and replaces "
                  f"a previously loaded SPIR-V file", file=sys.stderr)
        stages[stage] = load_spirv(spirv_path)
    return stages


def main():
    """Command-line entry point: generate a C++ bindings header for one shader.

    Reads every ``*.reflect.json`` and ``*.spv`` file in ``--dir``, merges the
    reflection data, renders the ``shader_bindings.hpp.j2`` template and
    writes the result to ``--output``. Errors and warnings go to stderr.

    Returns:
        Process exit code. 0 on success, and also when there is nothing to
        generate (the shader directory is missing or holds no reflection
        files). 1 when the template directory does not exist.
    """
    args = _build_arg_parser().parse_args()

    # Use --template-dir when given; otherwise the "templates" folder beside this script.
    if args.template_dir:
        template_dir = Path(args.template_dir)
    else:
        template_dir = Path(__file__).parent / 'templates'
    # is_dir() rather than exists(): a regular file here would only fail later inside Jinja2.
    if not template_dir.is_dir():
        print(f"Error: Template directory not found: {template_dir}", file=sys.stderr)
        return 1

    shader_dir = Path(args.dir)
    if not shader_dir.exists():
        print(f"Info: Shader directory not found: {shader_dir}. No entry points, skipping.")
        return 0

    spirv_files = sorted(shader_dir.glob("*.spv"))
    reflect_files = sorted(shader_dir.glob("*.reflect.json"))
    if args.verbose:
        print(f"Found {len(spirv_files)} SPIR-V files")
        print(f"Found {len(reflect_files)} reflection files")

    # Without reflection data there is nothing to bind: a no-op, not an error.
    if not reflect_files:
        print(f"Info: No reflection files found in {shader_dir}. Skipping binding generation.")
        return 0

    all_reflections = []
    for reflection_path in reflect_files:
        if args.verbose:
            print(f"Parsing reflection: {reflection_path}")
        all_reflections.append(parse_reflection_json(reflection_path))
    # Union of all per-file reflection data.
    reflection = merge_reflections(all_reflections)

    stages = _load_stages(spirv_files, args.verbose)

    if args.verbose:
        print(f"Generating header: {args.output}")
        print(f"Using templates from: {template_dir}")
    header_content = generate_header(args.name, reflection, stages, template_dir)

    output_path = Path(args.output)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    output_path.write_text(header_content, encoding='utf-8')

    print(f"Generated: {output_path}")
    if args.verbose:
        print(f" - Stages: {list(stages.keys())}")
        print(f" - Uniform Buffers: {len(reflection.uniform_buffers)}")
        print(f" - Samplers: {len(reflection.samplers)}")
        print(f" - Push Constants: {len(reflection.push_constants)}")
    return 0
if __name__ == '__main__':
    # `exit` is added by the `site` module for interactive use and is absent
    # under `python -S`. SystemExit is what sys.exit raises and always exists.
    raise SystemExit(main())