- Removed legacy push constant structures and functions for better clarity and maintainability. - Introduced new `text_push_constants_t` structure for text rendering with optimized layout. - Implemented dual stage push constant analysis to support separate layouts for vertex and fragment shaders. - Added functions to generate push constant structures and fill functions based on shader reflection. - Enhanced static checks for push constant layouts to ensure compatibility and correctness. - Updated templates to accommodate new dual stage push constant generation. - Added support detection for procedural vertex shaders based on push constant layout.
1338 lines
47 KiB
Python
1338 lines
47 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
SPIR-V解析模块
|
||
|
||
负责解析SPIR-V二进制文件,提取类型系统、变量、装饰和反射信息。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import struct
|
||
from typing import Any, Dict, List, Optional, Tuple
|
||
|
||
from .constants import (
|
||
DECORATION_ARRAY_STRIDE,
|
||
DECORATION_BINDING,
|
||
DECORATION_BLOCK,
|
||
DECORATION_BUFFER_BLOCK,
|
||
DECORATION_COL_MAJOR,
|
||
DECORATION_DESCRIPTOR_SET,
|
||
DECORATION_LOCATION,
|
||
DECORATION_MATRIX_STRIDE,
|
||
DECORATION_OFFSET,
|
||
DECORATION_ROW_MAJOR,
|
||
EXECUTION_MODEL_COMPUTE,
|
||
EXECUTION_MODEL_FRAGMENT,
|
||
EXECUTION_MODEL_GEOMETRY,
|
||
EXECUTION_MODEL_TESS_CONTROL,
|
||
EXECUTION_MODEL_TESS_EVAL,
|
||
EXECUTION_MODEL_VERTEX,
|
||
OP_CONSTANT,
|
||
OP_DECORATE,
|
||
OP_ENTRY_POINT,
|
||
OP_MEMBER_DECORATE,
|
||
OP_MEMBER_NAME,
|
||
OP_NAME,
|
||
OP_TYPE_ARRAY,
|
||
OP_TYPE_FLOAT,
|
||
OP_TYPE_IMAGE,
|
||
OP_TYPE_INT,
|
||
OP_TYPE_MATRIX,
|
||
OP_TYPE_POINTER,
|
||
OP_TYPE_RUNTIME_ARRAY,
|
||
OP_TYPE_SAMPLED_IMAGE,
|
||
OP_TYPE_SAMPLER,
|
||
OP_TYPE_STRUCT,
|
||
OP_TYPE_VECTOR,
|
||
OP_VARIABLE,
|
||
PUSH_CONSTANT_HEADER_SIZE,
|
||
PUSH_CONSTANT_CUSTOM_OFFSET,
|
||
SPIRV_MAGIC,
|
||
STORAGE_CLASS_INPUT,
|
||
STORAGE_CLASS_PUSH_CONSTANT,
|
||
STORAGE_CLASS_STORAGE_BUFFER,
|
||
STORAGE_CLASS_UNIFORM,
|
||
STORAGE_CLASS_UNIFORM_CONSTANT,
|
||
)
|
||
from .type_mapping import calculate_std430_alignment, calculate_std430_size, spirv_type_to_cpp, spirv_type_to_compact_cpp
|
||
from .types import (
|
||
ArrayTypeInfo,
|
||
BaseType,
|
||
BindingInfo,
|
||
BufferInfo,
|
||
CombinedPushConstantInfo,
|
||
InstanceBufferInfo,
|
||
MatrixTypeInfo,
|
||
MemberInfo,
|
||
PointerTypeInfo,
|
||
PushConstantBaseInfo,
|
||
PushConstantBaseMember,
|
||
PushConstantBaseMemberInfo,
|
||
PushConstantInfo,
|
||
ScalarTypeInfo,
|
||
SPIRVReflection,
|
||
StagePushConstantInfo,
|
||
StructTypeInfo,
|
||
ToolError,
|
||
TypeInfo,
|
||
VariableInfo,
|
||
VectorTypeInfo,
|
||
VertexAttribute,
|
||
VertexLayout,
|
||
)
|
||
|
||
|
||
# ============ Binary Parsing Functions ============
|
||
|
||
def parse_spirv_words(spirv_data: bytes) -> List[int]:
    """Decode a SPIR-V binary into a list of 32-bit words.

    SPIR-V modules may be stored in either byte order; the magic number in
    the first word tells which. Little-endian is tried first; if the magic
    number reads back byte-swapped, the data is decoded as big-endian.

    Args:
        spirv_data: Raw SPIR-V module bytes.

    Returns:
        The whole module (header included) as a list of 32-bit unsigned ints.

    Raises:
        ToolError: If the data is shorter than the 5-word header, is not a
            whole number of words, or does not start with the SPIR-V magic
            number in either byte order.
    """
    if len(spirv_data) < 20:
        raise ToolError("SPIR-V data too small")
    if len(spirv_data) % 4 != 0:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")

    word_count = len(spirv_data) // 4
    words = struct.unpack(f"<{word_count}I", spirv_data)

    if words[0] != SPIRV_MAGIC:
        # The magic number doubles as a byte-order mark: seeing it
        # byte-swapped means the module was written big-endian.
        swapped_magic = struct.unpack("<I", struct.pack(">I", SPIRV_MAGIC))[0]
        if words[0] != swapped_magic:
            raise ToolError(f"Invalid SPIR-V magic number: 0x{words[0]:08x}")
        words = struct.unpack(f">{word_count}I", spirv_data)

    return list(words)
|
||
|
||
|
||
def parse_spirv_string(words: List[int], start_index: int) -> Tuple[str, int]:
    """Decode a null-terminated UTF-8 literal string packed into SPIR-V words.

    Bytes are taken from each word lowest-order byte first; decoding stops at
    the first zero byte. The collected bytes are decoded as UTF-8 (invalid
    sequences are replaced rather than raising), so non-ASCII identifiers
    come back intact instead of as per-byte Latin-1 characters.

    Args:
        words: The module's 32-bit words.
        start_index: Index of the first word holding string bytes.

    Returns:
        ``(text, consumed_words)``. If no terminator is found before the end
        of ``words``, the text collected so far and the number of words
        scanned are returned.
    """
    raw = bytearray()
    word_count = 0

    # Index loop (not a slice) so long modules are not copied for every name.
    for i in range(start_index, len(words)):
        word = words[i]
        word_count += 1
        for shift in (0, 8, 16, 24):
            byte = (word >> shift) & 0xFF
            if byte == 0:
                return raw.decode("utf-8", errors="replace"), word_count
            raw.append(byte)

    return raw.decode("utf-8", errors="replace"), word_count
|
||
|
||
|
||
def parse_spirv_instructions(words: List[int]) -> List[Tuple[int, List[int], int]]:
    """Split a SPIR-V word stream into individual instructions.

    Decoding starts right after the 5-word module header. It stops at the
    end of ``words`` or at the first instruction whose word count is zero.

    Args:
        words: The module's 32-bit words, header included.

    Returns:
        A list of ``(opcode, operands, index)`` tuples, where ``index`` is the
        position of the instruction's first word in ``words``.
    """
    header_words = 5
    decoded: List[Tuple[int, List[int], int]] = []
    cursor = header_words
    total = len(words)

    while cursor < total:
        head = words[cursor]
        length = (head >> 16) & 0xFFFF
        op = head & 0xFFFF
        if not length:
            # A zero word count can never advance the cursor; stop decoding.
            break
        decoded.append((op, words[cursor + 1:cursor + length], cursor))
        cursor += length

    return decoded
|
||
|
||
|
||
# ============ Execution Model Conversion ============
|
||
|
||
def execution_model_to_stage(exec_model: int) -> str:
    """Translate a SPIR-V ExecutionModel value into this tool's stage name.

    Args:
        exec_model: ExecutionModel operand from an ``OpEntryPoint``.

    Returns:
        One of ``"vertex"``, ``"tess_control"``, ``"tess_eval"``,
        ``"geometry"``, ``"fragment"``, ``"compute"``, or ``"unknown"`` for
        any other value.
    """
    known_models = (
        (EXECUTION_MODEL_VERTEX, "vertex"),
        (EXECUTION_MODEL_TESS_CONTROL, "tess_control"),
        (EXECUTION_MODEL_TESS_EVAL, "tess_eval"),
        (EXECUTION_MODEL_GEOMETRY, "geometry"),
        (EXECUTION_MODEL_FRAGMENT, "fragment"),
        (EXECUTION_MODEL_COMPUTE, "compute"),
    )
    for model, stage_name in known_models:
        if model == exec_model:
            return stage_name
    return "unknown"
|
||
|
||
|
||
def stage_to_execution_model(stage: str) -> int:
    """Translate a stage name into the matching SPIR-V ExecutionModel value.

    Both the long and the short spelling of each stage are accepted
    (``"vertex"``/``"vert"``, ``"tess_control"``/``"tesc"``,
    ``"tess_eval"``/``"tese"``, ``"geometry"``/``"geom"``,
    ``"fragment"``/``"frag"``, ``"compute"``/``"comp"``).

    Args:
        stage: Stage name.

    Returns:
        The ExecutionModel value; unrecognized names map to compute.
    """
    spellings = {
        EXECUTION_MODEL_VERTEX: ("vertex", "vert"),
        EXECUTION_MODEL_TESS_CONTROL: ("tess_control", "tesc"),
        EXECUTION_MODEL_TESS_EVAL: ("tess_eval", "tese"),
        EXECUTION_MODEL_GEOMETRY: ("geometry", "geom"),
        EXECUTION_MODEL_FRAGMENT: ("fragment", "frag"),
        EXECUTION_MODEL_COMPUTE: ("compute", "comp"),
    }
    for model, names in spellings.items():
        if stage in names:
            return model
    return EXECUTION_MODEL_COMPUTE
|
||
|
||
|
||
# ============ Type System Parsing ============
|
||
|
||
def parse_spirv_type_system(spirv_data: bytes, shader_stage: str) -> SPIRVReflection:
    """Extract the full type system and reflection data from a SPIR-V module.

    The instruction list is walked in several passes because later passes
    consume earlier results: array lengths need constants, and struct members
    need names and decorations.

    Args:
        spirv_data: Raw SPIR-V module bytes.
        shader_stage: Stage name such as ``"vertex"``/``"vert"`` or
            ``"fragment"``. It selects the entry point whose execution model
            matches, and enables vertex-input extraction for vertex shaders.

    Returns:
        A populated ``SPIRVReflection``.

    Raises:
        ToolError: If the binary is malformed (raised by ``parse_spirv_words``).
    """
    from .types import ImageTypeInfo, SampledImageTypeInfo, SamplerTypeInfo

    words = parse_spirv_words(spirv_data)
    instructions = parse_spirv_instructions(words)

    reflection = SPIRVReflection(shader_stage=shader_stage)

    # Pass 1: constants (OpTypeArray refers to its length by constant id).
    for opcode, operands, _ in instructions:
        if opcode == OP_CONSTANT and len(operands) >= 3:
            # operands: result type id, result id, first (low-order) value word
            reflection.constants[operands[1]] = operands[2]

    # Pass 2: type declarations. Struct member type ids are recorded here so
    # members can be filled in later without rescanning the whole
    # instruction list once per struct.
    struct_member_type_ids: Dict[int, List[int]] = {}
    for opcode, operands, _ in instructions:
        if opcode == OP_TYPE_INT and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = ScalarTypeInfo(
                id=result_id,
                base_type=BaseType.INT if operands[2] else BaseType.UINT,
                bit_width=operands[1],
            )

        elif opcode == OP_TYPE_FLOAT and len(operands) >= 2:
            result_id = operands[0]
            reflection.types[result_id] = ScalarTypeInfo(
                id=result_id,
                base_type=BaseType.FLOAT,
                bit_width=operands[1],
            )

        elif opcode == OP_TYPE_VECTOR and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = VectorTypeInfo(
                id=result_id,
                component_type_id=operands[1],
                component_count=operands[2],
            )

        elif opcode == OP_TYPE_MATRIX and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = MatrixTypeInfo(
                id=result_id,
                column_type_id=operands[1],
                column_count=operands[2],
            )

        elif opcode == OP_TYPE_IMAGE and len(operands) >= 8:
            # operands: result id, sampled type, dim, depth, arrayed, MS, sampled, format
            result_id = operands[0]
            reflection.types[result_id] = ImageTypeInfo(
                id=result_id,
                sampled_type_id=operands[1],
                dimension=operands[2],
                depth=operands[3],
                arrayed=operands[4],
                ms=operands[5],
                sampled=operands[6],
                format=operands[7],
            )

        elif opcode == OP_TYPE_SAMPLED_IMAGE and len(operands) >= 2:
            result_id = operands[0]
            reflection.types[result_id] = SampledImageTypeInfo(
                id=result_id,
                image_type_id=operands[1],
            )

        elif opcode == OP_TYPE_SAMPLER and len(operands) >= 1:
            result_id = operands[0]
            reflection.types[result_id] = SamplerTypeInfo(id=result_id)

        elif opcode == OP_TYPE_ARRAY and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = ArrayTypeInfo(
                id=result_id,
                element_type_id=operands[1],
                # Length is given as a constant id; ids not recorded in pass 1 yield 0.
                length=reflection.constants.get(operands[2], 0),
                is_runtime=False,
            )

        elif opcode == OP_TYPE_RUNTIME_ARRAY and len(operands) >= 2:
            result_id = operands[0]
            reflection.types[result_id] = ArrayTypeInfo(
                id=result_id,
                element_type_id=operands[1],
                length=None,
                is_runtime=True,
            )

        elif opcode == OP_TYPE_STRUCT and len(operands) >= 1:
            result_id = operands[0]
            reflection.types[result_id] = StructTypeInfo(id=result_id, members=[])
            # Keep the first declaration's member list if an id repeats.
            struct_member_type_ids.setdefault(result_id, list(operands[1:]))

        elif opcode == OP_TYPE_POINTER and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = PointerTypeInfo(
                id=result_id,
                storage_class=operands[1],
                pointee_type_id=operands[2],
            )

    # Pass 3: debug names. The string starts after the fixed operands.
    for opcode, operands, index in instructions:
        if opcode == OP_NAME and len(operands) >= 1:
            name, _ = parse_spirv_string(words, index + 2)
            reflection.names[operands[0]] = name
        elif opcode == OP_MEMBER_NAME and len(operands) >= 2:
            name, _ = parse_spirv_string(words, index + 3)
            reflection.member_names.setdefault(operands[0], {})[operands[1]] = name

    # Pass 4: decorations. Decorations without a literal operand are stored as True.
    for opcode, operands, _ in instructions:
        if opcode == OP_DECORATE and len(operands) >= 2:
            target_id, decoration = operands[0], operands[1]
            value = operands[2] if len(operands) >= 3 else True
            reflection.decorations.setdefault(target_id, {})[decoration] = value
        elif opcode == OP_MEMBER_DECORATE and len(operands) >= 3:
            struct_id, member_index, decoration = operands[0], operands[1], operands[2]
            value = operands[3] if len(operands) >= 4 else True
            per_struct = reflection.member_decorations.setdefault(struct_id, {})
            per_struct.setdefault(member_index, {})[decoration] = value

    # Fill in struct members, block flags and struct names.
    for type_id, type_info in reflection.types.items():
        if not isinstance(type_info, StructTypeInfo):
            continue

        names_by_index = reflection.member_names.get(type_id, {})
        decorations_by_index = reflection.member_decorations.get(type_id, {})
        type_info.members = []
        for idx, member_type_id in enumerate(struct_member_type_ids.get(type_id, [])):
            member_decs = decorations_by_index.get(idx, {})
            type_info.members.append(MemberInfo(
                index=idx,
                name=names_by_index.get(idx, f"member_{idx}"),
                type_id=member_type_id,
                offset=member_decs.get(DECORATION_OFFSET, 0),
                array_stride=member_decs.get(DECORATION_ARRAY_STRIDE),
                matrix_stride=member_decs.get(DECORATION_MATRIX_STRIDE),
                is_row_major=DECORATION_ROW_MAJOR in member_decs,
            ))

        struct_decs = reflection.decorations.get(type_id, {})
        type_info.is_block = DECORATION_BLOCK in struct_decs
        type_info.is_buffer_block = DECORATION_BUFFER_BLOCK in struct_decs
        type_info.name = reflection.names.get(type_id)

    # Pass 5: variables.
    for opcode, operands, _ in instructions:
        if opcode == OP_VARIABLE and len(operands) >= 3:
            # operands: result type id, result id, storage class
            result_id = operands[1]
            var_decs = reflection.decorations.get(result_id, {})
            reflection.variables[result_id] = VariableInfo(
                id=result_id,
                name=reflection.names.get(result_id),
                type_id=operands[0],
                storage_class=operands[2],
                binding=var_decs.get(DECORATION_BINDING),
                descriptor_set=var_decs.get(DECORATION_DESCRIPTOR_SET, 0),
            )

    # Link type ids to type objects.
    resolve_all_types(reflection)

    # Buffers and samplers (instructions are passed so OpTypeSampledImage can be detected).
    extract_buffers(reflection, instructions)

    reflection.push_constant = extract_push_constants(reflection)

    # Vertex input layout exists only for vertex shaders.
    if shader_stage in ("vertex", "vert"):
        reflection.vertex_layout = extract_vertex_inputs(reflection)

    # Entry point: the first OpEntryPoint whose execution model matches the stage.
    target_model = stage_to_execution_model(shader_stage)
    for opcode, operands, index in instructions:
        if opcode == OP_ENTRY_POINT and len(operands) >= 2 and operands[0] == target_model:
            # operands: execution model, entry point id, then the name string
            name, _ = parse_spirv_string(words, index + 3)
            reflection.entry_point = name
            break

    return reflection
|
||
|
||
|
||
# ============ Type Resolution ============
|
||
|
||
def resolve_all_types(reflection: SPIRVReflection):
    """Link every type and variable in ``reflection`` to its resolved types.

    First each entry of ``reflection.types`` is resolved in place. Then each
    variable whose ``type_id`` is a known type gets ``resolved_type`` set,
    and that type is resolved as well.
    """
    type_map = reflection.types

    for info in type_map.values():
        resolve_type(info, type_map)

    for variable in reflection.variables.values():
        if variable.type_id not in type_map:
            continue
        variable.resolved_type = type_map[variable.type_id]
        resolve_type(variable.resolved_type, type_map)
|
||
|
||
|
||
def resolve_type(
    type_info: TypeInfo,
    type_map: Dict[int, TypeInfo],
    _visited: Optional[set] = None,
) -> TypeInfo:
    """Resolve, in place, the type ids referenced by ``type_info``, recursively.

    Fills ``resolved_component_type`` (vectors), ``resolved_column_type``
    (matrices), ``resolved_element_type`` (arrays), ``resolved_type`` on each
    member (structs) and ``resolved_pointee_type`` (pointers). Ids missing
    from ``type_map`` are left unresolved.

    Each type object is entered at most once per top-level call. This keeps
    self-referential type graphs from recursing forever, e.g. a struct that
    contains a pointer to itself (the shape produced by physical-storage-buffer
    pointers).

    Args:
        type_info: Type to resolve.
        type_map: All known types keyed by result id.
        _visited: Internal recursion guard; callers should not pass it.

    Returns:
        ``type_info`` itself.
    """
    if _visited is None:
        _visited = set()
    if id(type_info) in _visited:
        return type_info
    _visited.add(id(type_info))

    if isinstance(type_info, VectorTypeInfo):
        if type_info.component_type_id in type_map:
            type_info.resolved_component_type = type_map[type_info.component_type_id]

    elif isinstance(type_info, MatrixTypeInfo):
        if type_info.column_type_id in type_map:
            col_type = type_map[type_info.column_type_id]
            type_info.resolved_column_type = col_type
            if isinstance(col_type, VectorTypeInfo):
                resolve_type(col_type, type_map, _visited)

    elif isinstance(type_info, ArrayTypeInfo):
        if type_info.element_type_id in type_map:
            type_info.resolved_element_type = type_map[type_info.element_type_id]
            resolve_type(type_info.resolved_element_type, type_map, _visited)

    elif isinstance(type_info, StructTypeInfo):
        for member in type_info.members:
            if member.type_id in type_map:
                member.resolved_type = type_map[member.type_id]
                resolve_type(member.resolved_type, type_map, _visited)

    elif isinstance(type_info, PointerTypeInfo):
        if type_info.pointee_type_id in type_map:
            type_info.resolved_pointee_type = type_map[type_info.pointee_type_id]
            resolve_type(type_info.resolved_pointee_type, type_map, _visited)

    return type_info
|
||
|
||
|
||
# ============ Buffer Extraction ============
|
||
|
||
def extract_buffers(
    reflection: SPIRVReflection,
    instructions: Optional[List[Tuple[int, List[int], int]]] = None,
):
    """Collect descriptor-backed resources from the module's variables.

    Only variables that carry a Binding decoration and have pointer type are
    considered.

    - Struct pointees become ``BufferInfo`` entries in ``reflection.buffers``
      (sorted by ``(descriptor_set, binding)`` before returning).
    - Storage buffers are recognized in both encodings: the ``StorageBuffer``
      storage class, and the older ``Uniform`` storage class with a
      ``BufferBlock``-decorated struct (the encoding before SPIR-V 1.3).
    - Non-struct pointees in Uniform/UniformConstant storage that look like
      combined image samplers (e.g. ``sampler2D``, which uses UniformConstant
      rather than Uniform) become ``BindingInfo`` entries in
      ``reflection._extra_bindings``. That attribute is only set when at
      least one was found.

    Args:
        reflection: Reflection data; pointee types are read from
            ``resolved_pointee_type`` (falling back to ``reflection.types``).
        instructions: Optional parsed instruction list. When given, the ids of
            ``OpTypeSampledImage`` types are taken from it. Otherwise they are
            taken from the ``SampledImageTypeInfo`` entries of
            ``reflection.types``.
    """
    extra_bindings: List[BindingInfo] = []

    # Ids of OpTypeSampledImage types, used to recognise combined image samplers.
    if instructions:
        sampled_image_type_ids = {
            operands[0]
            for opcode, operands, _ in instructions
            if opcode == OP_TYPE_SAMPLED_IMAGE and len(operands) >= 2
        }
    else:
        from .types import SampledImageTypeInfo
        sampled_image_type_ids = {
            type_id
            for type_id, type_info in reflection.types.items()
            if isinstance(type_info, SampledImageTypeInfo)
        }

    for var_info in reflection.variables.values():
        if var_info.binding is None:
            continue

        var_type = reflection.types.get(var_info.type_id)
        if not isinstance(var_type, PointerTypeInfo):
            continue
        storage_class = var_type.storage_class

        pointee_type = var_type.resolved_pointee_type
        if pointee_type is None:
            pointee_type = reflection.types.get(var_type.pointee_type_id)
        if pointee_type is None:
            continue

        if not isinstance(pointee_type, StructTypeInfo):
            # Non-struct uniforms: only combined image samplers are reported.
            if storage_class in (STORAGE_CLASS_UNIFORM, STORAGE_CLASS_UNIFORM_CONSTANT):
                is_sampled_image = False
                sampler_count = 1  # single sampler unless it is an array

                # 1) By type name.
                type_name = getattr(pointee_type, "name", None)
                if type_name:
                    lowered = type_name.lower()
                    if "sampled_image" in lowered or "sampler2d" in lowered:
                        is_sampled_image = True

                # 2) The type itself is an OpTypeSampledImage.
                if not is_sampled_image and pointee_type.id in sampled_image_type_ids:
                    is_sampled_image = True

                # 3) An array of OpTypeSampledImage; its length is the descriptor count.
                if isinstance(pointee_type, ArrayTypeInfo):
                    elem_type = reflection.types.get(pointee_type.element_type_id)
                    if elem_type is not None and elem_type.id in sampled_image_type_ids:
                        is_sampled_image = True
                        if pointee_type.length is not None and pointee_type.length > 0:
                            sampler_count = pointee_type.length

                if is_sampled_image:
                    extra_bindings.append(BindingInfo(
                        binding=var_info.binding,
                        descriptor_set=var_info.descriptor_set or 0,
                        descriptor_type="combined_image_sampler",
                        stages=[],  # filled in later by the caller
                        name=var_info.name or f"sampler_{var_info.binding}",
                        count=sampler_count,
                    ))
            continue

        if storage_class == STORAGE_CLASS_STORAGE_BUFFER:
            descriptor_type = "storage_buffer"
        elif storage_class == STORAGE_CLASS_UNIFORM:
            # Before SPIR-V 1.3, storage buffers were Uniform + BufferBlock.
            is_buffer_block = getattr(pointee_type, "is_buffer_block", False) or (
                DECORATION_BUFFER_BLOCK in reflection.decorations.get(pointee_type.id, {})
            )
            descriptor_type = "storage_buffer" if is_buffer_block else "uniform_buffer"
        else:
            continue

        reflection.buffers.append(BufferInfo(
            name=pointee_type.name or f"Buffer{var_info.binding}",
            binding=var_info.binding,
            descriptor_set=var_info.descriptor_set or 0,
            descriptor_type=descriptor_type,
            struct_type=pointee_type,
            variable_name=var_info.name or f"buffer_{var_info.binding}",
        ))

    # Keep the extra bindings for later consumers.
    if extra_bindings:
        reflection._extra_bindings = extra_bindings

    reflection.buffers.sort(key=lambda b: (b.descriptor_set, b.binding))
|
||
|
||
|
||
# ============ Push Constants Extraction ============
|
||
|
||
# Header member name -> (flag, expected byte offset, expected byte size).
# Header bytes [0, 16): scale (vec2 at offset 0), then translate (vec2 at
# offset 8). Both the "u"-prefixed and the lower-case spellings are accepted.
# Keep the entry order: lookups that match by offset take the first entry
# with that offset.
HEADER_MEMBER_MAP = {
    alias: layout
    for aliases, layout in (
        (("uScale", "scale"), (PushConstantBaseMember.SCALE, 0, 8)),              # vec2, 8 bytes
        (("uTranslate", "translate"), (PushConstantBaseMember.TRANSLATE, 8, 8)),  # vec2, 8 bytes
    )
    for alias in aliases
}
|
||
|
||
|
||
def _identify_header_member(member: MemberInfo, type_map: Dict[int, TypeInfo]) -> Optional[PushConstantBaseMemberInfo]:
    """Classify a push-constant member as one of the reserved header members.

    Only members starting inside the header
    (``offset < PUSH_CONSTANT_HEADER_SIZE``) are considered.

    The member is looked up in ``HEADER_MEMBER_MAP`` by name first. Only when
    its name is not a known header name is it matched by offset (first entry
    expecting that offset). Matching by name first means, for example, a
    member named ``uTranslate`` placed at offset 0 is reported as a misplaced
    translate (``is_valid=False``). It is not mistaken for a valid scale just
    because its offset is 0.

    Args:
        member: Struct member of the push-constant block.
        type_map: All known types keyed by result id.

    Returns:
        A ``PushConstantBaseMemberInfo`` whose ``is_valid`` tells whether the
        member's offset and std430 size match the expected ones, or ``None``
        if the member is not a header member.
    """
    if member.offset >= PUSH_CONSTANT_HEADER_SIZE:
        return None

    layout = HEADER_MEMBER_MAP.get(member.name)
    if layout is None:
        # Unknown name: fall back to the first entry expecting this offset.
        layout = next(
            (spec for spec in HEADER_MEMBER_MAP.values() if spec[1] == member.offset),
            None,
        )
    if layout is None:
        return None

    flag, expected_offset, expected_size = layout
    member_size = calculate_std430_size(member.resolved_type, type_map)
    return PushConstantBaseMemberInfo(
        name=member.name,
        member_flag=flag,
        offset=member.offset,
        size=member_size,
        expected_offset=expected_offset,
        expected_size=expected_size,
        type_name=spirv_type_to_cpp(member.resolved_type, type_map),
        is_valid=(member.offset == expected_offset and member_size == expected_size),
    )
|
||
|
||
|
||
def extract_push_constants(reflection: SPIRVReflection) -> Optional[PushConstantInfo]:
    """Build push-constant reflection data, split into header and custom parts.

    The first variable in the PushConstant storage class (9) is used. Its
    block struct is divided according to the 128-byte layout:

    - header ``[0, 16)``: scale (vec2) + translate (vec2), 16 bytes, filled
      by the system;
    - custom ``[16, 128)``: user-defined data, 112 bytes.

    Header members are the ones recognised by ``_identify_header_member``.
    Members at offset >= ``PUSH_CONSTANT_CUSTOM_OFFSET`` become custom
    members, with offsets rebased to the start of the custom region. Any
    other member (inside the header range but not recognised) is not
    reported.

    Args:
        reflection: Reflection data with resolved types and variables.

    Returns:
        ``PushConstantInfo``, or ``None`` when there is no push-constant
        variable or its type is not a pointer to a struct.
    """
    push_var = next(
        (
            candidate
            for candidate in reflection.variables.values()
            if candidate.storage_class == STORAGE_CLASS_PUSH_CONSTANT
        ),
        None,
    )
    if push_var is None:
        return None

    pointer_type = reflection.types.get(push_var.type_id)
    if not isinstance(pointer_type, PointerTypeInfo):
        return None
    block = pointer_type.resolved_pointee_type
    if not isinstance(block, StructTypeInfo):
        return None

    header_parts: List[PushConstantBaseMemberInfo] = []
    header_flags = PushConstantBaseMember.NONE
    custom_parts: List[MemberInfo] = []

    for member in block.members:
        header_part = _identify_header_member(member, reflection.types)
        if header_part:
            header_parts.append(header_part)
            header_flags |= header_part.member_flag
            continue
        if member.offset < PUSH_CONSTANT_CUSTOM_OFFSET:
            continue
        custom_parts.append(MemberInfo(
            index=member.index,
            name=member.name,
            type_id=member.type_id,
            # Rebased so the custom region starts at offset 0.
            offset=member.offset - PUSH_CONSTANT_CUSTOM_OFFSET,
            resolved_type=member.resolved_type,
            matrix_stride=member.matrix_stride,
            array_stride=member.array_stride,
            is_row_major=member.is_row_major,
        ))

    # Header info exists only when at least one header member was found.
    base_info = None
    if header_parts:
        base_info = PushConstantBaseInfo(
            members=header_parts,
            present_flags=header_flags,
            total_size=max(part.offset + part.size for part in header_parts),
            is_standard_layout=(
                header_flags == PushConstantBaseMember.STANDARD_HEADER
                and all(part.is_valid for part in header_parts)
            ),
        )

    return PushConstantInfo(
        name=block.name or "PushConstant",
        struct_type=block,
        base_offset=PUSH_CONSTANT_CUSTOM_OFFSET,
        base_info=base_info,
        # The field name is kept for compatibility; it holds the custom members.
        effect_members=custom_parts,
    )
|
||
|
||
|
||
# ============ Instance Buffer Detection and Validation ============
|
||
|
||
def _is_vec4_float(type_info: TypeInfo, type_map: Dict[int, TypeInfo]) -> bool:
    """Return True if ``type_info`` is a 4-component vector of 32-bit floats.

    Args:
        type_info: Type to inspect.
        type_map: All known types keyed by result id, used when the vector's
            component type has not been resolved.
    """
    if not (isinstance(type_info, VectorTypeInfo) and type_info.component_count == 4):
        return False

    scalar = type_info.resolved_component_type
    if scalar is None:
        scalar = type_map.get(type_info.component_type_id)

    return (
        isinstance(scalar, ScalarTypeInfo)
        and scalar.base_type == BaseType.FLOAT
        and scalar.bit_width == 32
    )
|
||
|
||
|
||
def _validate_instance_struct(
    struct_type: StructTypeInfo,
    struct_name: str,
    type_map: Dict[int, TypeInfo],
    source_file: Optional[str] = None
) -> None:
    """Check that a per-instance struct has the required layout.

    Rules, checked in this order:
      1. the struct has at least one member;
      2. the first member is named ``rect``;
      3. ``rect`` is a vec4 of 32-bit floats;
      4. ``rect`` is at offset 0;
      5. the std430 size of the struct is a multiple of 16 bytes.

    Args:
        struct_type: Struct to validate.
        struct_name: Name used in error messages.
        type_map: All known types keyed by result id.
        source_file: Source path appended to error messages, if given.

    Raises:
        ToolError: For the first rule that is violated.
    """
    hint = f"\n In: {source_file}" if source_file else ""
    label = f"Instance buffer struct '{struct_name}'"

    members = struct_type.members
    if not members:
        raise ToolError(f"{label} has no members. Expected at least 'rect' field.{hint}")

    rect = members[0]
    if rect.name != "rect":
        raise ToolError(
            f"{label}: first member must be named 'rect', "
            f"got '{rect.name}'. The rect field is required for positioning.{hint}"
        )

    rect_type = rect.resolved_type
    if rect_type is None:
        rect_type = type_map.get(rect.type_id)
    if not _is_vec4_float(rect_type, type_map):
        shown = spirv_type_to_cpp(rect_type, type_map) if rect_type else "<unknown>"
        raise ToolError(f"{label}: 'rect' must be vec4 (16 bytes), got {shown}.{hint}")

    if rect.offset != 0:
        raise ToolError(f"{label}: 'rect' offset must be 0, got {rect.offset}.{hint}")

    size = calculate_std430_size(struct_type, type_map)
    if size % 16 != 0:
        raise ToolError(
            f"{label}: struct size must be multiple of 16, "
            f"got {size} bytes. Add padding to align.{hint}"
        )
|
||
|
||
|
||
def detect_instance_buffer(
    reflection: SPIRVReflection,
    annotations: Dict[str, Any],
    source_file: Optional[str] = None
) -> Optional[InstanceBufferInfo]:
    """Locate and validate the SSBO marked with ``@instance_buffer``.

    The binding number comes from ``annotations['instance_buffer_binding']``.
    The first storage buffer in ``reflection.buffers`` with that binding is
    used; the descriptor set is not compared. If the buffer's block struct
    has a runtime-array member whose elements are structs, that element
    struct is the per-instance record. Otherwise the block struct itself is.

    Args:
        reflection: SPIR-V reflection data with ``buffers`` already extracted.
        annotations: Shader annotations (from ``parse_shader_annotations``);
            ``None`` is accepted.
        source_file: Source path appended to error messages, if given.

    Returns:
        ``InstanceBufferInfo`` for the validated record, or ``None`` when no
        instance buffer is annotated.

    Raises:
        ToolError: If no storage buffer has the annotated binding, or the
            record fails ``_validate_instance_struct``.
    """
    if annotations is None:
        return None
    binding = annotations.get('instance_buffer_binding')
    if binding is None:
        return None

    hint = f"\n In: {source_file}" if source_file else ""

    ssbo = next(
        (
            candidate
            for candidate in reflection.buffers
            if candidate.binding == binding and candidate.descriptor_type == "storage_buffer"
        ),
        None,
    )
    if ssbo is None:
        raise ToolError(
            f"Cannot find @instance_buffer marked SSBO (binding={binding}). "
            f"Make sure the buffer is declared as 'layout(std430, ...) buffer' "
            f"and the binding number matches the annotation.{hint}"
        )

    block = ssbo.struct_type
    block_name = block.name or ssbo.name

    # Default to the block itself; prefer the element struct of a runtime array.
    record, record_name = block, block_name
    for member in block.members:
        member_type = member.resolved_type
        if member_type is None:
            member_type = reflection.types.get(member.type_id)
        if not (isinstance(member_type, ArrayTypeInfo) and member_type.is_runtime):
            continue
        element_type = member_type.resolved_element_type
        if element_type is None:
            element_type = reflection.types.get(member_type.element_type_id)
        if isinstance(element_type, StructTypeInfo):
            record = element_type
            record_name = element_type.name or f"{block_name}Element"
            break

    _validate_instance_struct(record, record_name, reflection.types, source_file)

    return InstanceBufferInfo(
        name=ssbo.variable_name,
        struct_type_name=record_name,
        binding=ssbo.binding,
        set_number=ssbo.descriptor_set,
        members=record.members,
        struct_type=record,
        has_rect=True,   # guaranteed by _validate_instance_struct
        rect_offset=0,   # guaranteed by _validate_instance_struct
        total_size=calculate_std430_size(record, reflection.types),
    )
|
||
|
||
|
||
# ============ Vertex Input Extraction ============
|
||
|
||
def _spirv_type_to_vk_format(type_info: TypeInfo, type_map: Dict[int, TypeInfo]) -> str:
    """Map a SPIR-V scalar or vector type to a VkFormat enumerator name.

    The format is built from three parts:

    - the component base type: float -> ``SFLOAT``, int -> ``SINT``,
      uint -> ``UINT``;
    - the component bit width;
    - the component count.

    Examples: ``vec4`` -> ``VK_FORMAT_R32G32B32A32_SFLOAT``,
    ``dvec2`` -> ``VK_FORMAT_R64G64_SFLOAT``,
    ``i16vec3`` -> ``VK_FORMAT_R16G16B16_SINT``.

    Fallbacks:
      * ``None`` and types that are neither scalars nor vectors map to
        ``VK_FORMAT_R32_SFLOAT``;
      * vector sizes outside 1..4 are treated as 4 components;
      * component types with no matching Vulkan format (unknown base type or
        bit width) use the 32-bit float format with the same component count.

    Args:
        type_info: Type of the vertex input.
        type_map: All known types keyed by result id, used when a vector's
            component type has not been resolved.

    Returns:
        A ``VK_FORMAT_*`` name as a string.
    """
    if isinstance(type_info, ScalarTypeInfo):
        component_type, count = type_info, 1
    elif isinstance(type_info, VectorTypeInfo):
        component_type = type_info.resolved_component_type
        if component_type is None:
            component_type = type_map.get(type_info.component_type_id)
        count = type_info.component_count
    else:
        return "VK_FORMAT_R32_SFLOAT"

    if count not in (1, 2, 3, 4):
        count = 4

    # Suffix and the bit widths for which Vulkan defines R/RG/RGB/RGBA formats.
    supported = {
        BaseType.FLOAT: ("SFLOAT", (16, 32, 64)),
        BaseType.INT: ("SINT", (8, 16, 32, 64)),
        BaseType.UINT: ("UINT", (8, 16, 32, 64)),
    }
    suffix, width = "SFLOAT", 32
    if isinstance(component_type, ScalarTypeInfo):
        kind = supported.get(component_type.base_type)
        if kind is not None and component_type.bit_width in kind[1]:
            suffix, width = kind[0], component_type.bit_width

    channels = "".join(f"{channel}{width}" for channel in "RGBA"[:count])
    return f"VK_FORMAT_{channels}_{suffix}"
|
||
|
||
|
||
def _clean_vertex_attribute_name(name: str) -> str:
|
||
"""清理顶点属性名称,去除前缀
|
||
|
||
Args:
|
||
name: 原始变量名称(如 in_position)
|
||
|
||
Returns:
|
||
清理后的名称(如 position)
|
||
"""
|
||
# 去除常见前缀
|
||
prefixes = ['in_', 'a_', 'v_', 'attr_', 'input_']
|
||
for prefix in prefixes:
|
||
if name.startswith(prefix):
|
||
return name[len(prefix):]
|
||
return name
|
||
|
||
|
||
def extract_vertex_inputs(reflection: SPIRVReflection) -> Optional[VertexLayout]:
    """Build a vertex input layout from the shader's ``Input`` variables.

    Every entry of ``reflection.variables`` that has storage class
    ``STORAGE_CLASS_INPUT`` and a ``Location`` decoration becomes one
    ``VertexAttribute``. Pointer types are followed to their pointee type.
    Variables whose type cannot be resolved are skipped.

    Attributes are ordered by location and packed back to back: each offset
    is the running sum of the preceding attributes' ``calculate_std430_size``
    values, and the stride is the final sum.

    Args:
        reflection: Parsed SPIR-V reflection data.

    Returns:
        A ``VertexLayout`` with struct name ``"Vertex"``, or ``None`` if no
        located input variable was found.
    """
    collected = []

    for var_id, var_info in reflection.variables.items():
        # Only shader inputs carrying an explicit location are vertex attributes.
        if var_info.storage_class != STORAGE_CLASS_INPUT:
            continue
        location = reflection.decorations.get(var_id, {}).get(DECORATION_LOCATION)
        if location is None:
            continue

        # Follow the variable's pointer type to the type actually stored.
        resolved = reflection.types.get(var_info.type_id)
        if isinstance(resolved, PointerTypeInfo):
            pointer = resolved
            resolved = pointer.resolved_pointee_type
            if resolved is None:
                resolved = reflection.types.get(pointer.pointee_type_id)
        if resolved is None:
            continue

        attribute_name = _clean_vertex_attribute_name(var_info.name or f"attr_{location}")

        # Compact C++ type mapping avoids Eigen alignment issues; the offset
        # is filled in once all attributes are known.
        collected.append(VertexAttribute(
            name=attribute_name,
            location=location,
            type_id=getattr(resolved, 'id', 0),
            cpp_type=spirv_type_to_compact_cpp(resolved, reflection.types),
            vk_format=_spirv_type_to_vk_format(resolved, reflection.types),
            offset=0,
            size=calculate_std430_size(resolved, reflection.types),
            resolved_type=resolved,
        ))

    if not collected:
        return None

    ordered = sorted(collected, key=lambda attribute: attribute.location)

    running_offset = 0
    for attribute in ordered:
        attribute.offset = running_offset
        running_offset += attribute.size

    # The struct name is a default; callers may rename it per shader.
    return VertexLayout(
        attributes=ordered,
        stride=running_offset,
        struct_name="Vertex",
    )
|
||
|
||
|
||
# ============ Legacy Interface ============
|
||
|
||
def extract_spirv_reflection(
    spirv_data: bytes,
    shader_type: str,
    annotations: Optional[Dict[str, Any]] = None,
    source_file: Optional[Any] = None
) -> Dict[str, Any]:
    """Extract reflection information from SPIR-V (legacy-compatible interface).

    Args:
        spirv_data: SPIR-V binary data.
        shader_type: Shader type (``vert``, ``frag``, ``comp``, ...). Known
            short names are expanded to stage names (``vert`` -> ``vertex``,
            ...); any other value is used as the stage name verbatim.
        annotations: Optional shader annotations. When given, they are passed
            to ``detect_instance_buffer`` and the result is stored on
            ``reflection.instance_buffer``.
        source_file: Optional source file path, passed as a string to
            ``detect_instance_buffer`` (intended for error messages).

    Returns:
        A dict with keys ``'entry_point'``, ``'bindings'`` (``BindingInfo``
        entries for every buffer plus any extra bindings, each tagged with
        this stage) and ``'reflection'``.
    """
    reflection = parse_spirv_type_system(spirv_data, shader_type)

    stage = {
        "vert": "vertex",
        "frag": "fragment",
        "comp": "compute",
        "geom": "geometry",
        "tesc": "tess_control",
        "tese": "tess_eval",
    }.get(shader_type, shader_type)

    # Instance-buffer detection only runs when annotations are supplied.
    if annotations is not None:
        reflection.instance_buffer = detect_instance_buffer(
            reflection,
            annotations,
            str(source_file) if source_file else None,
        )

    bindings = [
        BindingInfo(
            binding=buffer.binding,
            descriptor_set=buffer.descriptor_set,
            descriptor_type=buffer.descriptor_type,
            stages=[stage],
            name=buffer.variable_name,
            count=1,
        )
        for buffer in reflection.buffers
    ]

    # Additional bindings (e.g. sampled images) are parked on the reflection
    # under a private attribute: tag them with this stage, append them, then
    # remove the temporary attribute.
    if hasattr(reflection, '_extra_bindings'):
        pending = reflection._extra_bindings
        for extra_binding in pending:
            extra_binding.stages = [stage]
            bindings.append(extra_binding)
        del reflection._extra_bindings

    return {
        'entry_point': reflection.entry_point,
        'bindings': bindings,
        'reflection': reflection,
    }
|
||
|
||
|
||
# ============ Dual Stage Push Constants Analysis ============
|
||
|
||
def _extract_stage_push_constant_info(
|
||
reflection: SPIRVReflection,
|
||
stage: str
|
||
) -> Optional[StagePushConstantInfo]:
|
||
"""从反射信息中提取单个阶段的 Push Constant 信息
|
||
|
||
提取 Custom 区域(偏移 >= 16)的成员信息。
|
||
|
||
Args:
|
||
reflection: SPIR-V 反射信息
|
||
stage: 着色器阶段名称("vertex" 或 "fragment")
|
||
|
||
Returns:
|
||
StagePushConstantInfo 或 None
|
||
"""
|
||
push_constant = reflection.push_constant
|
||
if not push_constant or not push_constant.effect_members:
|
||
return None
|
||
|
||
# effect_members 已经是 Custom 区域成员(偏移量已调整为相对偏移)
|
||
members = push_constant.effect_members
|
||
|
||
if not members:
|
||
return None
|
||
|
||
# 计算总大小和对齐
|
||
total_size = 0
|
||
max_alignment = 4
|
||
|
||
for member in members:
|
||
member_size = calculate_std430_size(member.resolved_type, reflection.types)
|
||
member_alignment = calculate_std430_alignment(member.resolved_type, reflection.types)
|
||
|
||
# 更新结束位置
|
||
end_offset = member.offset + member_size
|
||
if end_offset > total_size:
|
||
total_size = end_offset
|
||
|
||
# 更新最大对齐
|
||
if member_alignment > max_alignment:
|
||
max_alignment = member_alignment
|
||
|
||
return StagePushConstantInfo(
|
||
stage=stage,
|
||
members=members,
|
||
relative_offset=PUSH_CONSTANT_CUSTOM_OFFSET,
|
||
total_size=total_size,
|
||
alignment=max_alignment,
|
||
)
|
||
|
||
|
||
def _compare_member_layouts(
|
||
vert_members: List[MemberInfo],
|
||
frag_members: List[MemberInfo],
|
||
type_map: Dict[int, TypeInfo]
|
||
) -> Tuple[bool, List[str]]:
|
||
"""比较两个阶段的成员布局
|
||
|
||
检查成员是否具有相同的名称、偏移量和类型。
|
||
|
||
Args:
|
||
vert_members: 顶点着色器成员列表
|
||
frag_members: 片元着色器成员列表
|
||
type_map: 类型映射
|
||
|
||
Returns:
|
||
(is_identical, shared_member_names) 元组:
|
||
- is_identical: 布局是否完全相同
|
||
- shared_member_names: 共享的成员名称列表
|
||
"""
|
||
if not vert_members and not frag_members:
|
||
return True, []
|
||
|
||
if not vert_members or not frag_members:
|
||
return False, []
|
||
|
||
# 构建成员映射(按名称)
|
||
vert_by_name = {m.name: m for m in vert_members}
|
||
frag_by_name = {m.name: m for m in frag_members}
|
||
|
||
# 找到共享的成员名称
|
||
shared_names = set(vert_by_name.keys()) & set(frag_by_name.keys())
|
||
|
||
if not shared_names:
|
||
return False, []
|
||
|
||
# 检查共享成员的布局是否相同
|
||
shared_members = []
|
||
for name in shared_names:
|
||
vert_member = vert_by_name[name]
|
||
frag_member = frag_by_name[name]
|
||
|
||
# 检查偏移量是否相同
|
||
if vert_member.offset != frag_member.offset:
|
||
continue
|
||
|
||
# 检查类型大小是否相同
|
||
vert_size = calculate_std430_size(vert_member.resolved_type, type_map)
|
||
frag_size = calculate_std430_size(frag_member.resolved_type, type_map)
|
||
if vert_size != frag_size:
|
||
continue
|
||
|
||
shared_members.append(name)
|
||
|
||
# 判断是否完全相同(所有成员都匹配且数量相同)
|
||
is_identical = (
|
||
len(shared_members) == len(vert_members) == len(frag_members) and
|
||
len(shared_members) > 0
|
||
)
|
||
|
||
return is_identical, sorted(shared_members)
|
||
|
||
|
||
def analyze_dual_stage_push_constants(
|
||
vert_reflection: SPIRVReflection,
|
||
frag_reflection: SPIRVReflection
|
||
) -> CombinedPushConstantInfo:
|
||
"""分析顶点和片元着色器的 Push Constant 布局
|
||
|
||
分析两个着色器阶段的 Push Constant Custom 区域布局:
|
||
1. 提取各阶段的 Custom 成员(偏移量 >= 16 的成员)
|
||
2. 检测布局是否相同(共享模式)或不同(分离模式)
|
||
3. 返回组合信息
|
||
|
||
布局分析规则:
|
||
- 如果两个阶段的 Custom 区域成员完全相同(名称、偏移、类型大小都匹配),
|
||
则使用共享模式(overlapping=True)
|
||
- 如果不同,则使用分离模式(overlapping=False)
|
||
|
||
Args:
|
||
vert_reflection: 顶点着色器的 SPIR-V 反射信息
|
||
frag_reflection: 片元着色器的 SPIR-V 反射信息
|
||
|
||
Returns:
|
||
CombinedPushConstantInfo 包含双阶段分析结果
|
||
|
||
Example:
|
||
>>> vert_ref = parse_spirv_type_system(vert_spirv, "vertex")
|
||
>>> frag_ref = parse_spirv_type_system(frag_spirv, "fragment")
|
||
>>> combined = analyze_dual_stage_push_constants(vert_ref, frag_ref)
|
||
>>> if combined.is_shared_layout:
|
||
>>> print("使用共享 Push Constant 布局")
|
||
>>> else:
|
||
>>> print("使用分离 Push Constant 布局")
|
||
"""
|
||
# 提取各阶段的 Push Constant 信息
|
||
vert_info = _extract_stage_push_constant_info(vert_reflection, "vertex")
|
||
frag_info = _extract_stage_push_constant_info(frag_reflection, "fragment")
|
||
|
||
# 如果两个阶段都没有 Custom 成员,返回空结果
|
||
if not vert_info and not frag_info:
|
||
return CombinedPushConstantInfo(
|
||
vertex_info=None,
|
||
fragment_info=None,
|
||
overlapping=False,
|
||
shared_members=[],
|
||
)
|
||
|
||
# 如果只有一个阶段有 Custom 成员,不存在重叠
|
||
if not vert_info or not frag_info:
|
||
return CombinedPushConstantInfo(
|
||
vertex_info=vert_info,
|
||
fragment_info=frag_info,
|
||
overlapping=False,
|
||
shared_members=[],
|
||
)
|
||
|
||
# 使用顶点着色器的类型映射(两者应该是兼容的)
|
||
type_map = vert_reflection.types
|
||
|
||
# 比较两个阶段的成员布局
|
||
is_identical, shared_members = _compare_member_layouts(
|
||
vert_info.members,
|
||
frag_info.members,
|
||
type_map,
|
||
)
|
||
|
||
return CombinedPushConstantInfo(
|
||
vertex_info=vert_info,
|
||
fragment_info=frag_info,
|
||
overlapping=is_identical,
|
||
shared_members=shared_members,
|
||
) |