- 新增ShaderCompiler类,通过glslc工具实现GLSL到SPIR-V的编译。 - 实现构建glslc命令、运行编译器及从SPIR-V二进制文件提取反射数据的功能。 - 为SPIR-V指令集、装饰符、执行模型及存储类创建常量。 - 开发SPIR-V解析器以提取类型信息、变量细节及反射数据 - 引入类型映射函数实现SPIR-V类型到C++类型的转换,并计算std430内存布局 - 定义着色器元数据、编译结果及SPIR-V反射信息的数据类 - 添加着色器发现、分组及元数据加载的实用函数
446 lines
16 KiB
Python
446 lines
16 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
SPIR-V解析模块
|
||
|
||
负责解析SPIR-V二进制文件,提取类型系统、变量、装饰和反射信息。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import struct
|
||
from typing import Any, Dict, List, Tuple
|
||
|
||
from .constants import *
|
||
from .types import (
|
||
ArrayTypeInfo,
|
||
BaseType,
|
||
BindingInfo,
|
||
BufferInfo,
|
||
MatrixTypeInfo,
|
||
MemberInfo,
|
||
PointerTypeInfo,
|
||
ScalarTypeInfo,
|
||
SPIRVReflection,
|
||
StructTypeInfo,
|
||
ToolError,
|
||
TypeInfo,
|
||
VariableInfo,
|
||
VectorTypeInfo,
|
||
)
|
||
|
||
|
||
# ============ Binary Parsing Functions ============
|
||
|
||
def parse_spirv_words(spirv_data: bytes) -> List[int]:
    """Decode a SPIR-V binary blob into a list of 32-bit words.

    The blob is read as a sequence of little-endian unsigned 32-bit integers.

    Args:
        spirv_data: Raw contents of a SPIR-V module.

    Returns:
        Every word of the module, header included.

    Raises:
        ToolError: If the blob is shorter than 20 bytes, if its size is not a
            multiple of 4, or if its first word differs from ``SPIRV_MAGIC``.
    """
    byte_count = len(spirv_data)
    if byte_count < 20:
        raise ToolError("SPIR-V data too small")
    if byte_count % 4:
        raise ToolError(f"SPIR-V data size {byte_count} is not a multiple of 4")

    decoded = list(struct.unpack(f"<{byte_count // 4}I", spirv_data))

    magic = decoded[0]
    if magic != SPIRV_MAGIC:
        raise ToolError(f"Invalid SPIR-V magic number: 0x{magic:08x}")

    return decoded
|
||
|
||
|
||
def parse_spirv_string(words: List[int], start_index: int) -> Tuple[str, int]:
    """Decode a null-terminated UTF-8 string packed into SPIR-V words.

    Each word contributes four bytes, lowest-order byte first. Bytes are
    gathered up to (not including) the first zero byte and then decoded as
    UTF-8 in one go, so multi-byte characters are reassembled correctly
    instead of being mapped byte-by-byte to Latin-1 code points.

    Args:
        words: The module's word list.
        start_index: Index of the first word that holds string data.

    Returns:
        A ``(text, consumed)`` tuple where ``consumed`` is the number of words
        read, including the word that holds the terminator. If no terminator
        is found, everything up to the end of ``words`` is decoded.
    """
    raw = bytearray()
    consumed = 0

    for i in range(start_index, len(words)):
        consumed += 1
        packed = (words[i] & 0xFFFFFFFF).to_bytes(4, "little")
        terminator = packed.find(0)
        if terminator != -1:
            raw += packed[:terminator]
            break
        raw += packed

    # errors="replace" keeps this function non-raising on malformed input,
    # matching the previous behaviour of never failing on arbitrary bytes.
    return raw.decode("utf-8", errors="replace"), consumed
|
||
|
||
|
||
def parse_spirv_instructions(words: List[int]) -> List[Tuple[int, List[int], int]]:
    """Split a SPIR-V word list into its instructions.

    The first five words (the module header) are skipped. For every
    instruction, the high 16 bits of its first word give the total word count
    and the low 16 bits give the opcode.

    Args:
        words: The module's word list, header included.

    Returns:
        A list of ``(opcode, operands, index)`` tuples, where ``operands`` are
        the words following the instruction's first word and ``index`` is the
        position of that first word in ``words``. Decoding stops at the first
        instruction whose word count is zero.
    """
    decoded = []
    total = len(words)
    cursor = 5  # the module header occupies words 0..4

    while cursor < total:
        first_word = words[cursor]
        length = (first_word >> 16) & 0xFFFF
        if not length:
            # A zero-length instruction would never advance the cursor.
            break

        following = cursor + length
        decoded.append((first_word & 0xFFFF, words[cursor + 1:following], cursor))
        cursor = following

    return decoded
|
||
|
||
|
||
# ============ Execution Model Conversion ============
|
||
|
||
def execution_model_to_stage(exec_model: int) -> str:
    """Translate a SPIR-V execution model value into a shader stage name.

    Args:
        exec_model: One of the ``EXECUTION_MODEL_*`` constants.

    Returns:
        The long stage name (``"vertex"``, ``"tess_control"``, ``"tess_eval"``,
        ``"geometry"``, ``"fragment"`` or ``"compute"``), or ``"unknown"`` for
        any other value.
    """
    stage_by_model = dict((
        (EXECUTION_MODEL_VERTEX, "vertex"),
        (EXECUTION_MODEL_TESS_CONTROL, "tess_control"),
        (EXECUTION_MODEL_TESS_EVAL, "tess_eval"),
        (EXECUTION_MODEL_GEOMETRY, "geometry"),
        (EXECUTION_MODEL_FRAGMENT, "fragment"),
        (EXECUTION_MODEL_COMPUTE, "compute"),
    ))
    try:
        return stage_by_model[exec_model]
    except KeyError:
        return "unknown"
|
||
|
||
|
||
def stage_to_execution_model(stage: str) -> int:
    """Translate a shader stage name into a SPIR-V execution model value.

    Both the long names (``"vertex"``, ``"fragment"``, ...) and the short
    aliases (``"vert"``, ``"frag"``, ...) are accepted.

    Args:
        stage: Stage name or short alias.

    Returns:
        The matching ``EXECUTION_MODEL_*`` constant. Unrecognised names fall
        back to ``EXECUTION_MODEL_COMPUTE``.
    """
    aliases_by_model = (
        (EXECUTION_MODEL_VERTEX, ("vertex", "vert")),
        (EXECUTION_MODEL_TESS_CONTROL, ("tess_control", "tesc")),
        (EXECUTION_MODEL_TESS_EVAL, ("tess_eval", "tese")),
        (EXECUTION_MODEL_GEOMETRY, ("geometry", "geom")),
        (EXECUTION_MODEL_FRAGMENT, ("fragment", "frag")),
        (EXECUTION_MODEL_COMPUTE, ("compute", "comp")),
    )
    model_by_alias = {}
    for model, aliases in aliases_by_model:
        for alias in aliases:
            model_by_alias[alias] = model

    return model_by_alias.get(stage, EXECUTION_MODEL_COMPUTE)
|
||
|
||
|
||
# ============ Type System Parsing ============
|
||
|
||
def parse_spirv_type_system(spirv_data: bytes, shader_stage: str) -> SPIRVReflection:
    """Build a ``SPIRVReflection`` from a SPIR-V binary.

    The instruction stream is walked in several passes: constants, type
    declarations, debug names, decorations, struct member details and
    variables. Type references are then resolved, buffers extracted, and the
    entry point matching ``shader_stage`` is looked up.

    Args:
        spirv_data: Raw contents of a SPIR-V module.
        shader_stage: Stage name or short alias, as understood by
            ``stage_to_execution_model``.

    Returns:
        The populated reflection object. ``entry_point`` is only assigned when
        an entry point with the matching execution model exists.

    Raises:
        ToolError: Propagated from ``parse_spirv_words`` for malformed input.
    """
    words = parse_spirv_words(spirv_data)
    instructions = parse_spirv_instructions(words)
    reflection = SPIRVReflection(shader_stage=shader_stage)

    # Constants go first: array types look their length up by constant id.
    _collect_constants(instructions, reflection)
    struct_member_type_ids = _collect_types(instructions, reflection)
    _collect_names(instructions, words, reflection)
    _collect_decorations(instructions, reflection)
    # Member details depend on the names and decorations gathered above.
    _populate_struct_members(reflection, struct_member_type_ids)
    _collect_variables(instructions, reflection)

    resolve_all_types(reflection)
    extract_buffers(reflection)

    entry_point = _find_entry_point(instructions, words, shader_stage)
    if entry_point is not None:
        reflection.entry_point = entry_point

    return reflection


def _collect_constants(
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> None:
    """Record the first value word of every OP_CONSTANT, keyed by result id."""
    for opcode, operands, _ in instructions:
        if opcode == OP_CONSTANT and len(operands) >= 3:
            # Operands are: result type, result id, value word(s).
            reflection.constants[operands[1]] = operands[2]


def _collect_types(
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> Dict[int, List[int]]:
    """Register every type declaration and return struct member type ids.

    The returned mapping goes from struct type id to the list of its member
    type ids, so that member details can be filled in later without scanning
    the instruction list again for every struct.
    """
    struct_member_type_ids: Dict[int, List[int]] = {}

    for opcode, operands, _ in instructions:
        if opcode == OP_TYPE_INT and len(operands) >= 3:
            result_id = operands[0]
            is_signed = operands[2]
            reflection.types[result_id] = ScalarTypeInfo(
                id=result_id,
                base_type=BaseType.INT if is_signed else BaseType.UINT,
                bit_width=operands[1],
            )

        elif opcode == OP_TYPE_FLOAT and len(operands) >= 2:
            result_id = operands[0]
            reflection.types[result_id] = ScalarTypeInfo(
                id=result_id,
                base_type=BaseType.FLOAT,
                bit_width=operands[1],
            )

        elif opcode == OP_TYPE_VECTOR and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = VectorTypeInfo(
                id=result_id,
                component_type_id=operands[1],
                component_count=operands[2],
            )

        elif opcode == OP_TYPE_MATRIX and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = MatrixTypeInfo(
                id=result_id,
                column_type_id=operands[1],
                column_count=operands[2],
            )

        elif opcode == OP_TYPE_ARRAY and len(operands) >= 3:
            result_id = operands[0]
            # The length operand is the id of a constant; ids that were not
            # recorded as OP_CONSTANT yield a length of 0.
            reflection.types[result_id] = ArrayTypeInfo(
                id=result_id,
                element_type_id=operands[1],
                length=reflection.constants.get(operands[2], 0),
                is_runtime=False,
            )

        elif opcode == OP_TYPE_RUNTIME_ARRAY and len(operands) >= 2:
            result_id = operands[0]
            reflection.types[result_id] = ArrayTypeInfo(
                id=result_id,
                element_type_id=operands[1],
                length=None,
                is_runtime=True,
            )

        elif opcode == OP_TYPE_STRUCT and len(operands) >= 1:
            result_id = operands[0]
            reflection.types[result_id] = StructTypeInfo(
                id=result_id,
                members=[],
            )
            # Keep the first declaration's member list for a given id.
            struct_member_type_ids.setdefault(result_id, operands[1:])

        elif opcode == OP_TYPE_POINTER and len(operands) >= 3:
            result_id = operands[0]
            reflection.types[result_id] = PointerTypeInfo(
                id=result_id,
                storage_class=operands[1],
                pointee_type_id=operands[2],
            )

    return struct_member_type_ids


def _collect_names(
    instructions: List[Tuple[int, List[int], int]],
    words: List[int],
    reflection: SPIRVReflection,
) -> None:
    """Record OP_NAME and OP_MEMBER_NAME debug names."""
    for opcode, operands, index in instructions:
        if opcode == OP_NAME and len(operands) >= 1:
            # Layout: instruction word, target id, string.
            name, _ = parse_spirv_string(words, index + 2)
            reflection.names[operands[0]] = name

        elif opcode == OP_MEMBER_NAME and len(operands) >= 2:
            # Layout: instruction word, struct id, member index, string.
            name, _ = parse_spirv_string(words, index + 3)
            reflection.member_names.setdefault(operands[0], {})[operands[1]] = name


def _collect_decorations(
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> None:
    """Record OP_DECORATE and OP_MEMBER_DECORATE entries.

    A decoration maps to its first extra operand when it has one and to
    ``True`` otherwise.
    """
    for opcode, operands, _ in instructions:
        if opcode == OP_DECORATE and len(operands) >= 2:
            target_decs = reflection.decorations.setdefault(operands[0], {})
            target_decs[operands[1]] = operands[2] if len(operands) >= 3 else True

        elif opcode == OP_MEMBER_DECORATE and len(operands) >= 3:
            per_member = reflection.member_decorations.setdefault(operands[0], {})
            member_decs = per_member.setdefault(operands[1], {})
            member_decs[operands[2]] = operands[3] if len(operands) >= 4 else True


def _populate_struct_members(
    reflection: SPIRVReflection,
    struct_member_type_ids: Dict[int, List[int]],
) -> None:
    """Fill in members, block flags and the name of every struct type."""
    for type_id, type_info in reflection.types.items():
        if not isinstance(type_info, StructTypeInfo):
            continue

        member_names = reflection.member_names.get(type_id, {})
        decs_by_member = reflection.member_decorations.get(type_id, {})

        members = []
        for idx, member_type_id in enumerate(struct_member_type_ids.get(type_id, [])):
            member_decs = decs_by_member.get(idx, {})
            members.append(MemberInfo(
                index=idx,
                name=member_names.get(idx, f"member_{idx}"),
                type_id=member_type_id,
                offset=member_decs.get(DECORATION_OFFSET, 0),
                array_stride=member_decs.get(DECORATION_ARRAY_STRIDE),
                matrix_stride=member_decs.get(DECORATION_MATRIX_STRIDE),
                is_row_major=DECORATION_ROW_MAJOR in member_decs,
            ))
        type_info.members = members

        decs = reflection.decorations.get(type_id, {})
        type_info.is_block = DECORATION_BLOCK in decs
        type_info.is_buffer_block = DECORATION_BUFFER_BLOCK in decs

        type_info.name = reflection.names.get(type_id)


def _collect_variables(
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> None:
    """Record every OP_VARIABLE with its name, binding and descriptor set."""
    for opcode, operands, _ in instructions:
        if opcode == OP_VARIABLE and len(operands) >= 3:
            # Operands are: result type, result id, storage class.
            result_id = operands[1]
            var_decs = reflection.decorations.get(result_id, {})

            reflection.variables[result_id] = VariableInfo(
                id=result_id,
                name=reflection.names.get(result_id),
                type_id=operands[0],
                storage_class=operands[2],
                binding=var_decs.get(DECORATION_BINDING),
                descriptor_set=var_decs.get(DECORATION_DESCRIPTOR_SET, 0),
            )


def _find_entry_point(
    instructions: List[Tuple[int, List[int], int]],
    words: List[int],
    shader_stage: str,
):
    """Return the name of the first entry point for ``shader_stage``, or None."""
    target_model = stage_to_execution_model(shader_stage)
    for opcode, operands, index in instructions:
        if opcode == OP_ENTRY_POINT and len(operands) >= 2 and operands[0] == target_model:
            # Layout: instruction word, execution model, function id, string.
            name, _ = parse_spirv_string(words, index + 3)
            return name
    return None
|
||
|
||
|
||
# ============ Type Resolution ============
|
||
|
||
def resolve_all_types(reflection: SPIRVReflection):
    """Resolve type references for every type and variable in ``reflection``.

    Each entry of ``reflection.types`` is passed through ``resolve_type``.
    Every variable whose ``type_id`` is a known type then gets that type
    object stored in ``resolved_type``, and it is resolved as well.
    Variables with an unknown ``type_id`` are left untouched.
    """
    type_map = reflection.types

    for declared_type in type_map.values():
        resolve_type(declared_type, type_map)

    for variable in reflection.variables.values():
        if variable.type_id not in type_map:
            continue
        variable.resolved_type = type_map[variable.type_id]
        resolve_type(variable.resolved_type, type_map)
|
||
|
||
|
||
def resolve_type(
    type_info: TypeInfo,
    type_map: Dict[int, TypeInfo],
    _visited: set | None = None,
) -> TypeInfo:
    """Recursively attach referenced type objects to ``type_info``.

    Depending on the kind of ``type_info``, the ids it refers to (component,
    column, element, member or pointee type) are looked up in ``type_map`` and
    the matching objects are stored in the corresponding ``resolved_*``
    attribute. Ids that are missing from ``type_map`` are skipped.

    Args:
        type_info: The type to resolve; it is modified in place.
        type_map: Mapping from type id to type object.
        _visited: Internal recursion state holding the ``id()`` of every
            object already handled during this traversal. Callers should
            leave it at its default.

    Returns:
        ``type_info`` itself.
    """
    if _visited is None:
        _visited = set()

    # SPIR-V permits self-referential pointer types (a struct that reaches a
    # pointer back to itself through a forward-declared pointer). Without this
    # guard such a module would recurse until RecursionError.
    marker = id(type_info)
    if marker in _visited:
        return type_info
    _visited.add(marker)

    if isinstance(type_info, VectorTypeInfo):
        if type_info.component_type_id in type_map:
            type_info.resolved_component_type = type_map[type_info.component_type_id]

    elif isinstance(type_info, MatrixTypeInfo):
        if type_info.column_type_id in type_map:
            col_type = type_map[type_info.column_type_id]
            type_info.resolved_column_type = col_type
            if isinstance(col_type, VectorTypeInfo):
                resolve_type(col_type, type_map, _visited)

    elif isinstance(type_info, ArrayTypeInfo):
        if type_info.element_type_id in type_map:
            type_info.resolved_element_type = type_map[type_info.element_type_id]
            resolve_type(type_info.resolved_element_type, type_map, _visited)

    elif isinstance(type_info, StructTypeInfo):
        for member in type_info.members:
            if member.type_id in type_map:
                member.resolved_type = type_map[member.type_id]
                resolve_type(member.resolved_type, type_map, _visited)

    elif isinstance(type_info, PointerTypeInfo):
        if type_info.pointee_type_id in type_map:
            type_info.resolved_pointee_type = type_map[type_info.pointee_type_id]
            resolve_type(type_info.resolved_pointee_type, type_map, _visited)

    return type_info
|
||
|
||
|
||
# ============ Buffer Extraction ============
|
||
|
||
def extract_buffers(reflection: SPIRVReflection):
    """Append a ``BufferInfo`` to ``reflection.buffers`` for each buffer variable.

    A variable qualifies when it has a binding, its type is a pointer whose
    resolved pointee is a struct, and the pointer's storage class is
    ``STORAGE_CLASS_UNIFORM`` or ``STORAGE_CLASS_STORAGE_BUFFER``. Type
    references must already be resolved (see ``resolve_all_types``).

    Descriptor type selection:
        * ``STORAGE_CLASS_STORAGE_BUFFER`` -> ``"storage_buffer"``
        * ``STORAGE_CLASS_UNIFORM`` with a struct flagged ``is_buffer_block``
          -> ``"storage_buffer"``
        * ``STORAGE_CLASS_UNIFORM`` otherwise -> ``"uniform_buffer"``

    After collection the list is sorted by binding number.
    """
    for var_info in reflection.variables.values():
        if var_info.binding is None:
            continue

        var_type = reflection.types.get(var_info.type_id)
        if not isinstance(var_type, PointerTypeInfo):
            continue

        struct_type = var_type.resolved_pointee_type
        if not isinstance(struct_type, StructTypeInfo):
            continue

        storage_class = var_type.storage_class
        if storage_class == STORAGE_CLASS_UNIFORM:
            # SPIR-V versions before 1.3 express storage buffers as
            # Uniform-class variables whose struct carries the BufferBlock
            # decoration, so the storage class alone cannot tell them apart
            # from uniform buffers.
            if getattr(struct_type, "is_buffer_block", False):
                descriptor_type = "storage_buffer"
            else:
                descriptor_type = "uniform_buffer"
        elif storage_class == STORAGE_CLASS_STORAGE_BUFFER:
            descriptor_type = "storage_buffer"
        else:
            continue

        # Fall back to a binding-derived name when no debug name is present.
        struct_name = struct_type.name
        if not struct_name:
            struct_name = f"Buffer{var_info.binding}"

        reflection.buffers.append(BufferInfo(
            name=struct_name,
            binding=var_info.binding,
            descriptor_set=var_info.descriptor_set or 0,
            descriptor_type=descriptor_type,
            struct_type=struct_type,
            variable_name=var_info.name or f"buffer_{var_info.binding}",
        ))

    reflection.buffers.sort(key=lambda b: b.binding)
|
||
|
||
|
||
# ============ Legacy Interface ============
|
||
|
||
def extract_spirv_reflection(spirv_data: bytes, shader_type: str) -> Dict[str, Any]:
    """Extract reflection data from SPIR-V (legacy dictionary interface).

    Args:
        spirv_data: Raw contents of a SPIR-V module.
        shader_type: Stage name or short alias (``"vert"``, ``"frag"``, ...).

    Returns:
        A dictionary with three keys:
            * ``'entry_point'``: the reflection's entry point name.
            * ``'bindings'``: one ``BindingInfo`` per extracted buffer, each
              with a single-element ``stages`` list and ``count`` of 1.
            * ``'reflection'``: the full ``SPIRVReflection`` object.
    """
    reflection = parse_spirv_type_system(spirv_data, shader_type)

    # Report the long stage name for every short alias that
    # stage_to_execution_model accepts; other values pass through unchanged.
    long_stage_names = {
        "vert": "vertex",
        "tesc": "tess_control",
        "tese": "tess_eval",
        "geom": "geometry",
        "frag": "fragment",
        "comp": "compute",
    }
    stage = long_stage_names.get(shader_type, shader_type)

    bindings = [
        BindingInfo(
            binding=buffer.binding,
            descriptor_type=buffer.descriptor_type,
            stages=[stage],
            name=buffer.variable_name,
            count=1,
        )
        for buffer in reflection.buffers
    ]

    return {
        'entry_point': reflection.entry_point,
        'bindings': bindings,
        'reflection': reflection,
    }