Files
mirage/tools/spirv_parser.py
daiqingshuang 5a8d62f841 Refactor Push Constants and Add Dual Stage Support
- Removed legacy push constant structures and functions for better clarity and maintainability.
- Introduced new `text_push_constants_t` structure for text rendering with optimized layout.
- Implemented dual stage push constant analysis to support separate layouts for vertex and fragment shaders.
- Added functions to generate push constant structures and fill functions based on shader reflection.
- Enhanced static checks for push constant layouts to ensure compatibility and correctness.
- Updated templates to accommodate new dual stage push constant generation.
- Added support detection for procedural vertex shaders based on push constant layout.
2025-12-25 21:04:39 +08:00

1338 lines
47 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
SPIR-V解析模块
负责解析SPIR-V二进制文件提取类型系统、变量、装饰和反射信息。
"""
from __future__ import annotations
import struct
from typing import Any, Dict, List, Optional, Tuple
from .constants import (
DECORATION_ARRAY_STRIDE,
DECORATION_BINDING,
DECORATION_BLOCK,
DECORATION_BUFFER_BLOCK,
DECORATION_COL_MAJOR,
DECORATION_DESCRIPTOR_SET,
DECORATION_LOCATION,
DECORATION_MATRIX_STRIDE,
DECORATION_OFFSET,
DECORATION_ROW_MAJOR,
EXECUTION_MODEL_COMPUTE,
EXECUTION_MODEL_FRAGMENT,
EXECUTION_MODEL_GEOMETRY,
EXECUTION_MODEL_TESS_CONTROL,
EXECUTION_MODEL_TESS_EVAL,
EXECUTION_MODEL_VERTEX,
OP_CONSTANT,
OP_DECORATE,
OP_ENTRY_POINT,
OP_MEMBER_DECORATE,
OP_MEMBER_NAME,
OP_NAME,
OP_TYPE_ARRAY,
OP_TYPE_FLOAT,
OP_TYPE_IMAGE,
OP_TYPE_INT,
OP_TYPE_MATRIX,
OP_TYPE_POINTER,
OP_TYPE_RUNTIME_ARRAY,
OP_TYPE_SAMPLED_IMAGE,
OP_TYPE_SAMPLER,
OP_TYPE_STRUCT,
OP_TYPE_VECTOR,
OP_VARIABLE,
PUSH_CONSTANT_HEADER_SIZE,
PUSH_CONSTANT_CUSTOM_OFFSET,
SPIRV_MAGIC,
STORAGE_CLASS_INPUT,
STORAGE_CLASS_PUSH_CONSTANT,
STORAGE_CLASS_STORAGE_BUFFER,
STORAGE_CLASS_UNIFORM,
STORAGE_CLASS_UNIFORM_CONSTANT,
)
from .type_mapping import calculate_std430_alignment, calculate_std430_size, spirv_type_to_cpp, spirv_type_to_compact_cpp
from .types import (
ArrayTypeInfo,
BaseType,
BindingInfo,
BufferInfo,
CombinedPushConstantInfo,
InstanceBufferInfo,
MatrixTypeInfo,
MemberInfo,
PointerTypeInfo,
PushConstantBaseInfo,
PushConstantBaseMember,
PushConstantBaseMemberInfo,
PushConstantInfo,
ScalarTypeInfo,
SPIRVReflection,
StagePushConstantInfo,
StructTypeInfo,
ToolError,
TypeInfo,
VariableInfo,
VectorTypeInfo,
VertexAttribute,
VertexLayout,
)
# ============ Binary Parsing Functions ============
def parse_spirv_words(spirv_data: bytes) -> List[int]:
    """Decode a SPIR-V binary into a list of 32-bit words.

    SPIR-V modules may be stored in either byte order; the magic number in
    the first word tells which one was used. Little-endian (the common case)
    is tried first, then big-endian.

    Args:
        spirv_data: Raw SPIR-V module bytes.

    Returns:
        The module as a list of unsigned 32-bit words in host order.

    Raises:
        ToolError: If the data is shorter than the 5-word header, is not a
            whole number of words, or does not start with the SPIR-V magic
            number in either byte order.
    """
    # The header alone is 5 words (20 bytes).
    if len(spirv_data) < 20:
        raise ToolError("SPIR-V data too small")
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")
    word_count = len(spirv_data) // 4
    words = struct.unpack(f"<{word_count}I", spirv_data)
    if words[0] != SPIRV_MAGIC:
        # A big-endian module reads as a byte-swapped magic in little-endian order.
        big_endian_words = struct.unpack(f">{word_count}I", spirv_data)
        if big_endian_words[0] != SPIRV_MAGIC:
            raise ToolError(f"Invalid SPIR-V magic number: 0x{words[0]:08x}")
        words = big_endian_words
    return list(words)
def parse_spirv_string(words: List[int], start_index: int) -> Tuple[str, int]:
    """Decode a nul-terminated UTF-8 literal string packed into SPIR-V words.

    Bytes are taken four per word, lowest-order byte first, until a 0 byte
    is found. The collected bytes are decoded as UTF-8; invalid sequences are
    replaced rather than raising.

    Args:
        words: The 32-bit words to read from.
        start_index: Index of the first word of the string.

    Returns:
        A ``(text, words_consumed)`` tuple. ``words_consumed`` includes the
        word containing the terminator. If no terminator is found, everything
        up to the end of ``words`` is decoded and the number of words read is
        returned.
    """
    raw = bytearray()
    word_count = 0
    for i in range(start_index, len(words)):
        word = words[i]
        word_count += 1
        for shift in (0, 8, 16, 24):
            byte = (word >> shift) & 0xFF
            if byte == 0:
                return raw.decode("utf-8", errors="replace"), word_count
            raw.append(byte)
    return raw.decode("utf-8", errors="replace"), word_count
def parse_spirv_instructions(words: List[int]) -> List[Tuple[int, List[int], int]]:
    """Split a SPIR-V word stream into individual instructions.

    Each instruction begins with a word whose high 16 bits hold the total
    instruction length in words (opcode word included) and whose low 16 bits
    hold the opcode.

    Args:
        words: The whole module as 32-bit words, header included.

    Returns:
        A list of ``(opcode, operands, index)`` tuples. ``operands`` are the
        words following the opcode word, and ``index`` is the position of the
        opcode word in ``words``. Scanning stops at the first instruction
        that declares a length of 0.
    """
    parsed: List[Tuple[int, List[int], int]] = []
    cursor = 5  # skip the 5-word module header
    total = len(words)
    while cursor < total:
        head = words[cursor]
        length = (head >> 16) & 0xFFFF
        if not length:
            break
        parsed.append((head & 0xFFFF, words[cursor + 1:cursor + length], cursor))
        cursor += length
    return parsed
# ============ Execution Model Conversion ============
def execution_model_to_stage(exec_model: int) -> str:
    """Translate a SPIR-V ExecutionModel value into a stage name.

    Args:
        exec_model: ExecutionModel operand (e.g. from OpEntryPoint).

    Returns:
        "vertex", "tess_control", "tess_eval", "geometry", "fragment",
        "compute", or "unknown" when the model is not one of these.
    """
    stage_names = {
        EXECUTION_MODEL_VERTEX: "vertex",
        EXECUTION_MODEL_TESS_CONTROL: "tess_control",
        EXECUTION_MODEL_TESS_EVAL: "tess_eval",
        EXECUTION_MODEL_GEOMETRY: "geometry",
        EXECUTION_MODEL_FRAGMENT: "fragment",
        EXECUTION_MODEL_COMPUTE: "compute",
    }
    try:
        return stage_names[exec_model]
    except KeyError:
        return "unknown"
def stage_to_execution_model(stage: str) -> int:
    """Translate a stage name (long or short form) into a SPIR-V ExecutionModel.

    Args:
        stage: Stage name such as "vertex"/"vert" or "fragment"/"frag".

    Returns:
        The matching ExecutionModel value; unrecognised names map to
        EXECUTION_MODEL_COMPUTE.
    """
    aliases_by_model = (
        (EXECUTION_MODEL_VERTEX, ("vertex", "vert")),
        (EXECUTION_MODEL_TESS_CONTROL, ("tess_control", "tesc")),
        (EXECUTION_MODEL_TESS_EVAL, ("tess_eval", "tese")),
        (EXECUTION_MODEL_GEOMETRY, ("geometry", "geom")),
        (EXECUTION_MODEL_FRAGMENT, ("fragment", "frag")),
        (EXECUTION_MODEL_COMPUTE, ("compute", "comp")),
    )
    lookup = {alias: model for model, aliases in aliases_by_model for alias in aliases}
    return lookup.get(stage, EXECUTION_MODEL_COMPUTE)
# ============ Type System Parsing ============
def parse_spirv_type_system(spirv_data: bytes, shader_stage: str) -> SPIRVReflection:
    """Parse a SPIR-V module and build its full reflection data.

    The instruction stream is scanned in passes: constants, type
    declarations, debug names, decorations, then variables. Afterwards struct
    members are filled in, type references are resolved, and buffers, push
    constants, vertex inputs (vertex stage only) and the entry-point name are
    extracted.

    Args:
        spirv_data: Raw SPIR-V module bytes.
        shader_stage: Stage name ("vertex"/"vert", "fragment"/"frag", ...).
            It selects which OpEntryPoint supplies ``entry_point`` and
            whether vertex inputs are extracted.

    Returns:
        The populated SPIRVReflection.

    Raises:
        ToolError: If parse_spirv_words rejects the binary.
    """
    from .types import ImageTypeInfo, SampledImageTypeInfo, SamplerTypeInfo

    words = parse_spirv_words(spirv_data)
    instructions = parse_spirv_instructions(words)
    reflection = SPIRVReflection(shader_stage=shader_stage)

    # Pass 1: constants. Only the first value word is kept, which is what
    # the array-length lookup in pass 2 needs.
    for opcode, operands, _ in instructions:
        if opcode == OP_CONSTANT and len(operands) >= 3:
            reflection.constants[operands[1]] = operands[2]

    # Pass 2: type declarations. Member type ids of each OpTypeStruct are
    # recorded here so members can be filled in later without rescanning
    # the whole instruction stream once per struct.
    struct_member_type_ids: Dict[int, List[int]] = {}
    for opcode, operands, _ in instructions:
        if opcode == OP_TYPE_INT and len(operands) >= 3:
            result_id = operands[0]
            is_signed = operands[2]
            reflection.types[result_id] = ScalarTypeInfo(
                id=result_id,
                base_type=BaseType.INT if is_signed else BaseType.UINT,
                bit_width=operands[1],
            )
        elif opcode == OP_TYPE_FLOAT and len(operands) >= 2:
            result_id = operands[0]
            reflection.types[result_id] = ScalarTypeInfo(
                id=result_id,
                base_type=BaseType.FLOAT,
                bit_width=operands[1],
            )
        elif opcode == OP_TYPE_VECTOR and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = VectorTypeInfo(
                id=result_id,
                component_type_id=operands[1],
                component_count=operands[2],
            )
        elif opcode == OP_TYPE_MATRIX and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = MatrixTypeInfo(
                id=result_id,
                column_type_id=operands[1],
                column_count=operands[2],
            )
        elif opcode == OP_TYPE_IMAGE and len(operands) >= 8:
            # OpTypeImage: result, sampled type, dim, depth, arrayed, MS, sampled, format
            result_id = operands[0]
            reflection.types[result_id] = ImageTypeInfo(
                id=result_id,
                sampled_type_id=operands[1],
                dimension=operands[2],
                depth=operands[3],
                arrayed=operands[4],
                ms=operands[5],
                sampled=operands[6],
                format=operands[7],
            )
        elif opcode == OP_TYPE_SAMPLED_IMAGE and len(operands) >= 2:
            result_id = operands[0]
            reflection.types[result_id] = SampledImageTypeInfo(
                id=result_id,
                image_type_id=operands[1],
            )
        elif opcode == OP_TYPE_SAMPLER and len(operands) >= 1:
            result_id = operands[0]
            reflection.types[result_id] = SamplerTypeInfo(id=result_id)
        elif opcode == OP_TYPE_ARRAY and len(operands) >= 3:
            result_id = operands[0]
            # The length operand is a constant id; ids not collected in pass 1
            # give a length of 0.
            reflection.types[result_id] = ArrayTypeInfo(
                id=result_id,
                element_type_id=operands[1],
                length=reflection.constants.get(operands[2], 0),
                is_runtime=False,
            )
        elif opcode == OP_TYPE_RUNTIME_ARRAY and len(operands) >= 2:
            result_id = operands[0]
            reflection.types[result_id] = ArrayTypeInfo(
                id=result_id,
                element_type_id=operands[1],
                length=None,
                is_runtime=True,
            )
        elif opcode == OP_TYPE_STRUCT and len(operands) >= 1:
            result_id = operands[0]
            reflection.types[result_id] = StructTypeInfo(id=result_id, members=[])
            struct_member_type_ids[result_id] = list(operands[1:])
        elif opcode == OP_TYPE_POINTER and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = PointerTypeInfo(
                id=result_id,
                storage_class=operands[1],
                pointee_type_id=operands[2],
            )

    # Pass 3: debug names. Strings are read from the instruction's own
    # operands so a malformed literal cannot spill into the next instruction.
    for opcode, operands, _ in instructions:
        if opcode == OP_NAME and len(operands) >= 1:
            name, _ = parse_spirv_string(operands, 1)
            reflection.names[operands[0]] = name
        elif opcode == OP_MEMBER_NAME and len(operands) >= 2:
            name, _ = parse_spirv_string(operands, 2)
            reflection.member_names.setdefault(operands[0], {})[operands[1]] = name

    # Pass 4: decorations. A decoration without a literal operand is stored as True.
    for opcode, operands, _ in instructions:
        if opcode == OP_DECORATE and len(operands) >= 2:
            target_decs = reflection.decorations.setdefault(operands[0], {})
            target_decs[operands[1]] = operands[2] if len(operands) >= 3 else True
        elif opcode == OP_MEMBER_DECORATE and len(operands) >= 3:
            per_member = reflection.member_decorations.setdefault(operands[0], {})
            member_decs = per_member.setdefault(operands[1], {})
            member_decs[operands[2]] = operands[3] if len(operands) >= 4 else True

    # Fill in struct members, Block/BufferBlock flags and struct names.
    for type_id, type_info in reflection.types.items():
        if not isinstance(type_info, StructTypeInfo):
            continue
        names_for_struct = reflection.member_names.get(type_id, {})
        decs_for_struct = reflection.member_decorations.get(type_id, {})
        type_info.members = []
        for idx, member_type_id in enumerate(struct_member_type_ids.get(type_id, [])):
            member_decs = decs_for_struct.get(idx, {})
            type_info.members.append(MemberInfo(
                index=idx,
                name=names_for_struct.get(idx, f"member_{idx}"),
                type_id=member_type_id,
                offset=member_decs.get(DECORATION_OFFSET, 0),
                array_stride=member_decs.get(DECORATION_ARRAY_STRIDE),
                matrix_stride=member_decs.get(DECORATION_MATRIX_STRIDE),
                is_row_major=DECORATION_ROW_MAJOR in member_decs,
            ))
        decs = reflection.decorations.get(type_id, {})
        type_info.is_block = DECORATION_BLOCK in decs
        type_info.is_buffer_block = DECORATION_BUFFER_BLOCK in decs
        type_info.name = reflection.names.get(type_id)

    # Pass 5: variables, with their binding/set decorations.
    for opcode, operands, _ in instructions:
        if opcode == OP_VARIABLE and len(operands) >= 3:
            result_id = operands[1]
            var_decs = reflection.decorations.get(result_id, {})
            reflection.variables[result_id] = VariableInfo(
                id=result_id,
                name=reflection.names.get(result_id),
                type_id=operands[0],
                storage_class=operands[2],
                binding=var_decs.get(DECORATION_BINDING),
                descriptor_set=var_decs.get(DECORATION_DESCRIPTOR_SET, 0),
            )

    # Link type references.
    resolve_all_types(reflection)
    # Buffers and combined image samplers (instructions are needed to spot OpTypeSampledImage).
    extract_buffers(reflection, instructions)
    reflection.push_constant = extract_push_constants(reflection)
    if shader_stage in ("vertex", "vert"):
        reflection.vertex_layout = extract_vertex_inputs(reflection)

    # Entry point: OpEntryPoint operands are model, function id, name literal, interface ids.
    target_model = stage_to_execution_model(shader_stage)
    for opcode, operands, _ in instructions:
        if opcode == OP_ENTRY_POINT and len(operands) >= 3 and operands[0] == target_model:
            reflection.entry_point, _ = parse_spirv_string(operands, 2)
            break

    return reflection
# ============ Type Resolution ============
def resolve_all_types(reflection: SPIRVReflection):
    """Link every type and variable in ``reflection`` to the TypeInfo objects they reference.

    All types are resolved first. Then every variable whose ``type_id`` is
    present in the type table gets ``resolved_type`` set, and that type is
    resolved again. Variables with unknown type ids are left untouched.
    """
    type_table = reflection.types
    for info in type_table.values():
        resolve_type(info, type_table)
    for variable in reflection.variables.values():
        if variable.type_id not in type_table:
            continue
        variable.resolved_type = type_table[variable.type_id]
        resolve_type(variable.resolved_type, type_table)
def resolve_type(
    type_info: TypeInfo,
    type_map: Dict[int, TypeInfo],
    _visited: Optional[set] = None,
) -> TypeInfo:
    """Recursively link ``type_info`` to the TypeInfo objects it references.

    Sets ``resolved_component_type`` (vectors), ``resolved_column_type``
    (matrices), ``resolved_element_type`` (arrays), member ``resolved_type``
    (structs) and ``resolved_pointee_type`` (pointers) for ids found in
    ``type_map``, then recurses into the referenced types.

    Each TypeInfo object is visited at most once per call tree. This prevents
    infinite recursion on self-referential types, such as a struct containing
    a pointer back to itself (for example a GLSL ``buffer_reference`` linked
    list).

    Args:
        type_info: Type to resolve in place.
        type_map: Map from SPIR-V result id to TypeInfo.
        _visited: Internal; identities of TypeInfo objects already visited.

    Returns:
        ``type_info`` itself.
    """
    if _visited is None:
        _visited = set()
    if id(type_info) in _visited:
        return type_info
    _visited.add(id(type_info))

    if isinstance(type_info, VectorTypeInfo):
        if type_info.component_type_id in type_map:
            type_info.resolved_component_type = type_map[type_info.component_type_id]
    elif isinstance(type_info, MatrixTypeInfo):
        if type_info.column_type_id in type_map:
            col_type = type_map[type_info.column_type_id]
            type_info.resolved_column_type = col_type
            if isinstance(col_type, VectorTypeInfo):
                resolve_type(col_type, type_map, _visited)
    elif isinstance(type_info, ArrayTypeInfo):
        if type_info.element_type_id in type_map:
            type_info.resolved_element_type = type_map[type_info.element_type_id]
            resolve_type(type_info.resolved_element_type, type_map, _visited)
    elif isinstance(type_info, StructTypeInfo):
        for member in type_info.members:
            if member.type_id in type_map:
                member.resolved_type = type_map[member.type_id]
                resolve_type(member.resolved_type, type_map, _visited)
    elif isinstance(type_info, PointerTypeInfo):
        if type_info.pointee_type_id in type_map:
            type_info.resolved_pointee_type = type_map[type_info.pointee_type_id]
            resolve_type(type_info.resolved_pointee_type, type_map, _visited)
    return type_info
# ============ Buffer Extraction ============
def extract_buffers(
    reflection: SPIRVReflection,
    instructions: Optional[List[Tuple[int, List[int], int]]] = None,
):
    """Collect descriptor-bound resources (uniform/storage buffers, combined image samplers).

    Only variables that have a Binding decoration and a pointer type are
    considered.

    - Non-struct Uniform/UniformConstant variables recognised as combined
      image samplers (or arrays of them) become BindingInfo objects. These
      are stored on ``reflection._extra_bindings`` (stages left empty for the
      caller to fill).
    - Struct-typed variables become BufferInfo entries appended to
      ``reflection.buffers``, which is then sorted by (set, binding).
      StorageBuffer structs, and Uniform structs decorated BufferBlock, are
      storage buffers. SPIR-V before 1.3 (e.g. glslang targeting Vulkan 1.0)
      emits std430 ``buffer`` blocks as Uniform + BufferBlock. Other Uniform
      structs are uniform buffers.

    Args:
        reflection: SPIR-V reflection data; updated in place.
        instructions: Optional parsed instruction list, used to find
            OpTypeSampledImage result ids.
    """
    extra_bindings: List[BindingInfo] = []

    # Result ids of every OpTypeSampledImage in the module.
    sampled_image_type_ids: set = set()
    if instructions:
        for opcode, operands, _ in instructions:
            if opcode == OP_TYPE_SAMPLED_IMAGE and len(operands) >= 2:
                sampled_image_type_ids.add(operands[0])

    for var_info in reflection.variables.values():
        if var_info.binding is None:
            continue
        var_type = reflection.types.get(var_info.type_id)
        if not isinstance(var_type, PointerTypeInfo):
            continue
        storage_class = var_type.storage_class

        pointee_type = var_type.resolved_pointee_type
        if pointee_type is None:
            # Not resolved yet; look the pointee up directly.
            pointee_type = reflection.types.get(var_type.pointee_type_id)

        # Non-struct resources such as sampler2D use UniformConstant (0), not Uniform (2).
        if storage_class in (STORAGE_CLASS_UNIFORM, STORAGE_CLASS_UNIFORM_CONSTANT):
            if pointee_type is None:
                continue
            if not isinstance(pointee_type, StructTypeInfo):
                is_sampled_image = False
                sampler_count = 1  # a single sampler/texture by default
                # Heuristic 1: the type's debug name.
                type_name = getattr(pointee_type, 'name', None)
                if type_name:
                    lowered = type_name.lower()
                    if 'sampled_image' in lowered or 'sampler2d' in lowered:
                        is_sampled_image = True
                # Heuristic 2: the type is an OpTypeSampledImage.
                if not is_sampled_image and pointee_type.id in sampled_image_type_ids:
                    is_sampled_image = True
                # Heuristic 3: an array of OpTypeSampledImage; its length is the descriptor count.
                if isinstance(pointee_type, ArrayTypeInfo):
                    elem_type = reflection.types.get(pointee_type.element_type_id)
                    if elem_type and elem_type.id in sampled_image_type_ids:
                        is_sampled_image = True
                        if pointee_type.length is not None and pointee_type.length > 0:
                            sampler_count = pointee_type.length
                if is_sampled_image:
                    extra_bindings.append(BindingInfo(
                        binding=var_info.binding,
                        descriptor_set=var_info.descriptor_set or 0,
                        descriptor_type="combined_image_sampler",
                        stages=[],  # filled in by the caller
                        name=var_info.name or f"sampler_{var_info.binding}",
                        count=sampler_count,
                    ))
                continue

        # Only struct-typed buffers remain.
        struct_type = pointee_type
        if not isinstance(struct_type, StructTypeInfo):
            continue
        if storage_class == STORAGE_CLASS_UNIFORM:
            # Uniform + BufferBlock is the pre-SPIR-V-1.3 encoding of an SSBO.
            if getattr(struct_type, "is_buffer_block", False):
                descriptor_type = "storage_buffer"
            else:
                descriptor_type = "uniform_buffer"
        elif storage_class == STORAGE_CLASS_STORAGE_BUFFER:
            descriptor_type = "storage_buffer"
        else:
            continue

        reflection.buffers.append(BufferInfo(
            name=struct_type.name or f"Buffer{var_info.binding}",
            binding=var_info.binding,
            descriptor_set=var_info.descriptor_set or 0,
            descriptor_type=descriptor_type,
            struct_type=struct_type,
            variable_name=var_info.name or f"buffer_{var_info.binding}",
        ))

    # Stored for extract_spirv_reflection to merge into its binding list.
    if extra_bindings:
        reflection._extra_bindings = extra_bindings

    reflection.buffers.sort(key=lambda b: (b.descriptor_set, b.binding))
# ============ Push Constants Extraction ============
# Push-constant header member name -> (flag, expected offset, expected size in bytes).
# Simplified header layout, bytes [0, 16): scale (vec2 at offset 0) + translate (vec2 at offset 8).
# Both the "u"-prefixed and the plain lower-case spellings are accepted.
HEADER_MEMBER_MAP = dict(
    uScale=(PushConstantBaseMember.SCALE, 0, 8),
    scale=(PushConstantBaseMember.SCALE, 0, 8),
    uTranslate=(PushConstantBaseMember.TRANSLATE, 8, 8),
    translate=(PushConstantBaseMember.TRANSLATE, 8, 8),
)
def _identify_header_member(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> Optional[PushConstantBaseMemberInfo]:
    """Classify a push-constant member lying in the fixed header region.

    Members at or beyond PUSH_CONSTANT_HEADER_SIZE are never header members.
    Otherwise the member is looked up in HEADER_MEMBER_MAP by name first. Only
    if no entry carries its name is it matched by expected offset. Name
    priority keeps e.g. a ``translate`` member misplaced at offset 0 from
    being reported as a valid ``scale`` just because ``scale`` is expected
    there.

    Args:
        member: Struct member to classify.
        type_map: Map from SPIR-V result id to TypeInfo.

    Returns:
        PushConstantBaseMemberInfo describing the member. ``is_valid`` is True
        only if both offset and std430 size equal the expected values. Returns
        None if the member is not a header member.
    """
    if member.offset >= PUSH_CONSTANT_HEADER_SIZE:
        return None

    entry = HEADER_MEMBER_MAP.get(member.name)
    if entry is None:
        entry = next(
            (candidate for candidate in HEADER_MEMBER_MAP.values() if candidate[1] == member.offset),
            None,
        )
    if entry is None:
        return None

    flag, expected_offset, expected_size = entry
    member_size = calculate_std430_size(member.resolved_type, type_map)
    return PushConstantBaseMemberInfo(
        name=member.name,
        member_flag=flag,
        offset=member.offset,
        size=member_size,
        expected_offset=expected_offset,
        expected_size=expected_size,
        type_name=spirv_type_to_cpp(member.resolved_type, type_map),
        is_valid=(member.offset == expected_offset and member_size == expected_size),
    )
def extract_push_constants(reflection: SPIRVReflection) -> Optional[PushConstantInfo]:
    """Build PushConstantInfo from the module's PushConstant-storage variable.

    The first variable with the PushConstant storage class is used, and its
    pointee must be a struct. Members are split into two groups:

    - Header: members recognised by ``_identify_header_member`` (scale and
      translate, bytes [0, 16)).
    - Custom: members at or after PUSH_CONSTANT_CUSTOM_OFFSET. They are copied
      with offsets rebased to be relative to that offset.

    Members that fall in neither group are omitted.

    Args:
        reflection: SPIR-V reflection data.

    Returns:
        PushConstantInfo, or None when there is no push-constant variable or
        its type is not a pointer to a struct. ``base_info`` is None when no
        header member was found.
    """
    pc_variable = next(
        (
            candidate
            for candidate in reflection.variables.values()
            if candidate.storage_class == STORAGE_CLASS_PUSH_CONSTANT
        ),
        None,
    )
    if pc_variable is None:
        return None

    pointer_type = reflection.types.get(pc_variable.type_id)
    if not isinstance(pointer_type, PointerTypeInfo):
        return None
    block_type = pointer_type.resolved_pointee_type
    if not isinstance(block_type, StructTypeInfo):
        return None

    header_entries = []
    flags = PushConstantBaseMember.NONE
    custom_members = []
    for member in block_type.members:
        header_entry = _identify_header_member(member, reflection.types)
        if header_entry:
            header_entries.append(header_entry)
            flags |= header_entry.member_flag
            continue
        if member.offset < PUSH_CONSTANT_CUSTOM_OFFSET:
            continue
        custom_members.append(MemberInfo(
            index=member.index,
            name=member.name,
            type_id=member.type_id,
            # Rebased so the Custom region starts at offset 0.
            offset=member.offset - PUSH_CONSTANT_CUSTOM_OFFSET,
            resolved_type=member.resolved_type,
            matrix_stride=member.matrix_stride,
            array_stride=member.array_stride,
            is_row_major=member.is_row_major,
        ))

    base_info = None
    if header_entries:
        base_info = PushConstantBaseInfo(
            members=header_entries,
            present_flags=flags,
            total_size=max(entry.offset + entry.size for entry in header_entries),
            # Standard = both scale and translate present and correctly placed/sized.
            is_standard_layout=(
                flags == PushConstantBaseMember.STANDARD_HEADER and
                all(entry.is_valid for entry in header_entries)
            ),
        )

    return PushConstantInfo(
        name=block_type.name or "PushConstant",
        struct_type=block_type,
        base_offset=PUSH_CONSTANT_CUSTOM_OFFSET,
        base_info=base_info,
        effect_members=custom_members,  # historical field name; holds the Custom members
    )
# ============ Instance Buffer Detection and Validation ============
def _is_vec4_float(type_info: TypeInfo, type_map: Dict[int, TypeInfo]) -> bool:
    """Return True if ``type_info`` is a 4-component vector of 32-bit floats.

    The component type comes from ``resolved_component_type``, or is looked
    up in ``type_map`` when that is unset.
    """
    if not isinstance(type_info, VectorTypeInfo) or type_info.component_count != 4:
        return False
    scalar = type_info.resolved_component_type
    if scalar is None:
        scalar = type_map.get(type_info.component_type_id)
    return (
        isinstance(scalar, ScalarTypeInfo)
        and scalar.base_type == BaseType.FLOAT
        and scalar.bit_width == 32
    )
def _validate_instance_struct(
    struct_type: StructTypeInfo,
    struct_name: str,
    type_map: Dict[int, TypeInfo],
    source_file: Optional[str] = None
) -> None:
    """Check that an instance-data struct has the required layout.

    Rules, checked in this order:
    1. The struct has at least one member.
    2. The first member is named ``rect``.
    3. ``rect`` is a vec4 of 32-bit floats.
    4. ``rect`` is at offset 0.
    5. The std430 size of the struct is a multiple of 16 bytes.

    Args:
        struct_type: Struct to validate.
        struct_name: Name used in error messages.
        type_map: Map from SPIR-V result id to TypeInfo.
        source_file: Optional source path appended to error messages.

    Raises:
        ToolError: On the first rule that is violated.
    """
    location_hint = f"\n In: {source_file}" if source_file else ""

    if not struct_type.members:
        raise ToolError(
            f"Instance buffer struct '{struct_name}' has no members. "
            f"Expected at least 'rect' field.{location_hint}"
        )

    rect = struct_type.members[0]
    if rect.name != "rect":
        raise ToolError(
            f"Instance buffer struct '{struct_name}': first member must be named 'rect', "
            f"got '{rect.name}'. The rect field is required for positioning.{location_hint}"
        )

    rect_type = rect.resolved_type
    if rect_type is None:
        rect_type = type_map.get(rect.type_id)
    if not _is_vec4_float(rect_type, type_map):
        actual_type = spirv_type_to_cpp(rect_type, type_map) if rect_type else "<unknown>"
        raise ToolError(
            f"Instance buffer struct '{struct_name}': 'rect' must be vec4 (16 bytes), "
            f"got {actual_type}.{location_hint}"
        )

    if rect.offset != 0:
        raise ToolError(
            f"Instance buffer struct '{struct_name}': 'rect' offset must be 0, "
            f"got {rect.offset}.{location_hint}"
        )

    total_size = calculate_std430_size(struct_type, type_map)
    if total_size % 16:
        raise ToolError(
            f"Instance buffer struct '{struct_name}': struct size must be multiple of 16, "
            f"got {total_size} bytes. Add padding to align.{location_hint}"
        )
def detect_instance_buffer(
    reflection: SPIRVReflection,
    annotations: Dict[str, Any],
    source_file: Optional[str] = None
) -> Optional[InstanceBufferInfo]:
    """Locate and validate the SSBO marked with ``@instance_buffer``.

    The binding comes from ``annotations['instance_buffer_binding']``. The
    first storage buffer with that binding is selected. If its struct has a
    runtime-array member whose element is a struct, that element struct is
    what gets validated; otherwise the buffer struct itself is validated.

    Args:
        reflection: SPIR-V reflection data.
        annotations: Shader annotations (may be None).
        source_file: Optional source path appended to error messages.

    Returns:
        InstanceBufferInfo, or None if there are no annotations or no
        instance-buffer binding was annotated.

    Raises:
        ToolError: If no matching storage buffer exists or the layout check fails.
    """
    if annotations is None:
        return None
    target_binding = annotations.get('instance_buffer_binding')
    if target_binding is None:
        return None

    location_hint = f"\n In: {source_file}" if source_file else ""

    target_buffer: Optional[BufferInfo] = next(
        (
            buf
            for buf in reflection.buffers
            if buf.binding == target_binding and buf.descriptor_type == "storage_buffer"
        ),
        None,
    )
    if target_buffer is None:
        raise ToolError(
            f"Cannot find @instance_buffer marked SSBO (binding={target_binding}). "
            f"Make sure the buffer is declared as 'layout(std430, ...) buffer' "
            f"and the binding number matches the annotation.{location_hint}"
        )

    container_type = target_buffer.struct_type
    container_name = container_type.name or target_buffer.name

    # An SSBO usually wraps a runtime array; the per-instance struct is its element type.
    element_struct_type = container_type
    element_struct_name = container_name
    for member in container_type.members:
        member_type = member.resolved_type
        if member_type is None:
            member_type = reflection.types.get(member.type_id)
        if not (isinstance(member_type, ArrayTypeInfo) and member_type.is_runtime):
            continue
        element_type = member_type.resolved_element_type
        if element_type is None:
            element_type = reflection.types.get(member_type.element_type_id)
        if isinstance(element_type, StructTypeInfo):
            element_struct_type = element_type
            element_struct_name = element_type.name or f"{container_name}Element"
        break

    _validate_instance_struct(
        element_struct_type,
        element_struct_name,
        reflection.types,
        source_file
    )

    return InstanceBufferInfo(
        name=target_buffer.variable_name,
        struct_type_name=element_struct_name,
        binding=target_buffer.binding,
        set_number=target_buffer.descriptor_set,
        members=element_struct_type.members,
        struct_type=element_struct_type,
        has_rect=True,   # guaranteed by _validate_instance_struct
        rect_offset=0,   # guaranteed by _validate_instance_struct
        total_size=calculate_std430_size(element_struct_type, reflection.types),
    )
# ============ Vertex Input Extraction ============
def _vk_format_for_scalar(scalar: ScalarTypeInfo, channels: int) -> Optional[str]:
    """Build the VkFormat name for ``channels`` components of type ``scalar``, or None if Vulkan has none."""
    if scalar.base_type == BaseType.FLOAT:
        suffix, supported_widths = "SFLOAT", (16, 32, 64)
    elif scalar.base_type == BaseType.INT:
        suffix, supported_widths = "SINT", (8, 16, 32, 64)
    elif scalar.base_type == BaseType.UINT:
        suffix, supported_widths = "UINT", (8, 16, 32, 64)
    else:
        return None
    if scalar.bit_width not in supported_widths or not 1 <= channels <= 4:
        return None
    channel_spec = "".join(f"{channel}{scalar.bit_width}" for channel in "RGBA"[:channels])
    return f"VK_FORMAT_{channel_spec}_{suffix}"


def _spirv_type_to_vk_format(type_info: TypeInfo, type_map: Dict[int, TypeInfo]) -> str:
    """Map a vertex-input SPIR-V type to a Vulkan format name.

    Scalars and vectors of float (16/32/64-bit) and signed/unsigned int
    (8/16/32/64-bit) map to the matching format, e.g. vec4 ->
    VK_FORMAT_R32G32B32A32_SFLOAT and dvec2 -> VK_FORMAT_R64G64_SFLOAT.
    Vectors wider than four components use the four-channel format. Vectors
    with an unknown component type are treated as 32-bit float. Anything else
    (None, matrices, unsupported scalars) yields VK_FORMAT_R32_SFLOAT.

    Args:
        type_info: Type of the input variable (pointee, not pointer).
        type_map: Map from SPIR-V result id to TypeInfo.

    Returns:
        A VkFormat enumerator name.
    """
    fallback = "VK_FORMAT_R32_SFLOAT"
    if isinstance(type_info, ScalarTypeInfo):
        return _vk_format_for_scalar(type_info, 1) or fallback
    if isinstance(type_info, VectorTypeInfo):
        component_type = type_info.resolved_component_type
        if component_type is None:
            component_type = type_map.get(type_info.component_type_id)
        count = type_info.component_count
        # Vulkan formats have at most four channels.
        channels = count if count in (1, 2, 3, 4) else 4
        if isinstance(component_type, ScalarTypeInfo):
            vk_format = _vk_format_for_scalar(component_type, channels)
            if vk_format is not None:
                return vk_format
        # Unknown or unsupported component type: assume 32-bit float components.
        return "VK_FORMAT_" + "".join(f"{channel}32" for channel in "RGBA"[:channels]) + "_SFLOAT"
    return fallback
def _clean_vertex_attribute_name(name: str) -> str:
"""清理顶点属性名称,去除前缀
Args:
name: 原始变量名称(如 in_position
Returns:
清理后的名称(如 position
"""
# 去除常见前缀
prefixes = ['in_', 'a_', 'v_', 'attr_', 'input_']
for prefix in prefixes:
if name.startswith(prefix):
return name[len(prefix):]
return name
def extract_vertex_inputs(reflection: SPIRVReflection) -> Optional[VertexLayout]:
    """Build a tightly packed vertex layout from Input variables that carry a Location.

    For each such variable (pointer types are unwrapped to their pointee),
    a VertexAttribute is produced with a cleaned name, compact C++ type,
    VkFormat and std430 size. Attributes are ordered by location and laid
    out back-to-back with no padding. The stride is the sum of their sizes.

    Args:
        reflection: SPIR-V reflection data.

    Returns:
        VertexLayout (struct_name "Vertex"), or None when no located inputs exist.
    """
    attributes = []
    for var_id, var_info in reflection.variables.items():
        if var_info.storage_class != STORAGE_CLASS_INPUT:
            continue
        location = reflection.decorations.get(var_id, {}).get(DECORATION_LOCATION)
        if location is None:
            continue  # e.g. built-ins, which have no Location

        declared_type = reflection.types.get(var_info.type_id)
        attr_type = declared_type
        if isinstance(declared_type, PointerTypeInfo):
            attr_type = declared_type.resolved_pointee_type
            if attr_type is None:
                attr_type = reflection.types.get(declared_type.pointee_type_id)
        if attr_type is None:
            continue

        display_name = _clean_vertex_attribute_name(var_info.name or f"attr_{location}")
        attributes.append(VertexAttribute(
            name=display_name,
            location=location,
            type_id=getattr(attr_type, 'id', 0),
            # Compact types avoid Eigen alignment padding in the C++ struct.
            cpp_type=spirv_type_to_compact_cpp(attr_type, reflection.types),
            vk_format=_spirv_type_to_vk_format(attr_type, reflection.types),
            offset=0,  # assigned below once attributes are sorted
            size=calculate_std430_size(attr_type, reflection.types),
            resolved_type=attr_type,
        ))

    if not attributes:
        return None

    attributes.sort(key=lambda attr: attr.location)
    cursor = 0
    for attr in attributes:
        attr.offset = cursor
        cursor += attr.size

    return VertexLayout(
        attributes=attributes,
        stride=cursor,
        struct_name="Vertex",  # default; callers may rename per shader
    )
# ============ Legacy Interface ============
def extract_spirv_reflection(
    spirv_data: bytes,
    shader_type: str,
    annotations: Optional[Dict[str, Any]] = None,
    source_file: Optional[Any] = None
) -> Dict[str, Any]:
    """Reflect a SPIR-V module and return the result in the legacy dict form.

    Args:
        spirv_data: SPIR-V module bytes.
        shader_type: Shader type such as "vert", "frag" or "comp"; short
            forms are expanded to full stage names for the bindings.
        annotations: Optional shader annotations; when given, instance-buffer
            detection runs and sets ``reflection.instance_buffer``.
        source_file: Optional source path, used only in error messages.

    Returns:
        A dict with ``'entry_point'``, ``'bindings'`` (buffers first, then any
        combined-image-sampler bindings, all tagged with the stage) and
        ``'reflection'`` (the SPIRVReflection).

    Raises:
        ToolError: Propagated from detect_instance_buffer.
    """
    reflection = parse_spirv_type_system(spirv_data, shader_type)

    long_stage_names = {
        "vert": "vertex",
        "frag": "fragment",
        "comp": "compute",
        "geom": "geometry",
        "tesc": "tess_control",
        "tese": "tess_eval",
    }
    stage = long_stage_names.get(shader_type, shader_type)

    if annotations is not None:
        reflection.instance_buffer = detect_instance_buffer(
            reflection,
            annotations,
            str(source_file) if source_file else None,
        )

    bindings = [
        BindingInfo(
            binding=buf.binding,
            descriptor_set=buf.descriptor_set,
            descriptor_type=buf.descriptor_type,
            stages=[stage],
            name=buf.variable_name,
            count=1,
        )
        for buf in reflection.buffers
    ]

    # extract_buffers parks sampler bindings on a private attribute; consume and drop it.
    if hasattr(reflection, '_extra_bindings'):
        for sampler_binding in reflection._extra_bindings:
            sampler_binding.stages = [stage]
            bindings.append(sampler_binding)
        delattr(reflection, '_extra_bindings')

    return {
        'entry_point': reflection.entry_point,
        'bindings': bindings,
        'reflection': reflection,
    }
# ============ Dual Stage Push Constants Analysis ============
def _extract_stage_push_constant_info(
    reflection: SPIRVReflection,
    stage: str
) -> Optional[StagePushConstantInfo]:
    """Summarise one stage's Custom-region push-constant members.

    Uses ``reflection.push_constant.effect_members``, whose offsets are
    already relative to the Custom region.

    Args:
        reflection: SPIR-V reflection data for the stage.
        stage: Stage label stored in the result ("vertex" or "fragment").

    Returns:
        StagePushConstantInfo whose ``total_size`` is the furthest member end
        (at least 0) and whose ``alignment`` is the largest std430 member
        alignment (at least 4). Returns None if the stage has no push
        constants or no Custom members.
    """
    pc = reflection.push_constant
    if not pc or not pc.effect_members:
        return None

    members = pc.effect_members
    types = reflection.types
    member_ends = [m.offset + calculate_std430_size(m.resolved_type, types) for m in members]
    member_alignments = [calculate_std430_alignment(m.resolved_type, types) for m in members]

    return StagePushConstantInfo(
        stage=stage,
        members=members,
        relative_offset=PUSH_CONSTANT_CUSTOM_OFFSET,
        total_size=max([0] + member_ends),
        alignment=max([4] + member_alignments),
    )
def _member_layouts_match(
    vert_member: MemberInfo,
    frag_member: MemberInfo,
    type_map: Dict[int, TypeInfo],
) -> bool:
    """Return True when two same-named members agree on offset, layout decorations, size and type."""
    if vert_member.offset != frag_member.offset:
        return False
    # Differing strides or majorness change how the same bytes are interpreted.
    if (
        vert_member.array_stride != frag_member.array_stride
        or vert_member.matrix_stride != frag_member.matrix_stride
        or vert_member.is_row_major != frag_member.is_row_major
    ):
        return False
    vert_size = calculate_std430_size(vert_member.resolved_type, type_map)
    frag_size = calculate_std430_size(frag_member.resolved_type, type_map)
    if vert_size != frag_size:
        return False
    # Equal size is not enough (e.g. float vs int); compare the C++ type spelling as well.
    return (
        spirv_type_to_cpp(vert_member.resolved_type, type_map)
        == spirv_type_to_cpp(frag_member.resolved_type, type_map)
    )


def _compare_member_layouts(
    vert_members: List[MemberInfo],
    frag_members: List[MemberInfo],
    type_map: Dict[int, TypeInfo]
) -> Tuple[bool, List[str]]:
    """Compare the Custom-region push-constant members of two stages.

    A member is shared when both stages declare a member with the same name
    and ``_member_layouts_match`` holds: same offset, array/matrix stride,
    row-major flag, std430 size and C++ type.

    NOTE(review): both stages' types are evaluated against ``type_map`` (the
    caller passes the vertex stage's map). This relies on the members'
    ``resolved_type`` links rather than id lookups.

    Args:
        vert_members: Vertex-stage Custom members.
        frag_members: Fragment-stage Custom members.
        type_map: Type map used for size/type evaluation.

    Returns:
        ``(is_identical, shared_member_names)``:
        - is_identical: True if both lists are empty, or if every member of
          both stages is shared (equal counts, at least one member).
        - shared_member_names: sorted names of the shared members.
    """
    if not vert_members and not frag_members:
        return True, []
    if not vert_members or not frag_members:
        return False, []

    vert_by_name = {m.name: m for m in vert_members}
    frag_by_name = {m.name: m for m in frag_members}
    shared_names = vert_by_name.keys() & frag_by_name.keys()
    if not shared_names:
        return False, []

    shared_members = sorted(
        name
        for name in shared_names
        if _member_layouts_match(vert_by_name[name], frag_by_name[name], type_map)
    )
    is_identical = (
        len(shared_members) == len(vert_members) == len(frag_members) and
        len(shared_members) > 0
    )
    return is_identical, shared_members
def analyze_dual_stage_push_constants(
    vert_reflection: SPIRVReflection,
    frag_reflection: SPIRVReflection
) -> CombinedPushConstantInfo:
    """Combine the vertex and fragment stages' Custom push-constant layouts.

    Each stage's Custom members (offset >= 16 in the block) are summarised.
    If both stages have Custom members, they are compared with
    ``_compare_member_layouts`` using the vertex stage's type map.
    ``overlapping`` is True only when the layouts are identical (shared
    mode); otherwise the stages use separate layouts. When at most one stage
    has Custom members, ``overlapping`` is False and ``shared_members`` is empty.

    Args:
        vert_reflection: Reflection of the vertex shader.
        frag_reflection: Reflection of the fragment shader.

    Returns:
        CombinedPushConstantInfo with per-stage info, ``overlapping`` and
        ``shared_members``.

    Example:
        >>> vert_ref = parse_spirv_type_system(vert_spirv, "vertex")
        >>> frag_ref = parse_spirv_type_system(frag_spirv, "fragment")
        >>> combined = analyze_dual_stage_push_constants(vert_ref, frag_ref)
        >>> combined.overlapping  # True -> one shared Custom layout
    """
    vert_info = _extract_stage_push_constant_info(vert_reflection, "vertex")
    frag_info = _extract_stage_push_constant_info(frag_reflection, "fragment")

    if vert_info and frag_info:
        # The two type maps are assumed to be compatible; the vertex one is used.
        is_identical, shared_members = _compare_member_layouts(
            vert_info.members,
            frag_info.members,
            vert_reflection.types,
        )
        return CombinedPushConstantInfo(
            vertex_info=vert_info,
            fragment_info=frag_info,
            overlapping=is_identical,
            shared_members=shared_members,
        )

    # Neither or only one stage has Custom members: nothing can overlap.
    return CombinedPushConstantInfo(
        vertex_info=vert_info,
        fragment_info=frag_info,
        overlapping=False,
        shared_members=[],
    )