Files
mirage/tools/spirv_parser.py
daiqingshuang a635b0464f 功能:实现着色器编译与反射信息提取
- 新增ShaderCompiler类,通过glslc工具实现GLSL到SPIR-V的编译。
- 实现构建glslc命令、运行编译器及从SPIR-V二进制文件提取反射数据的功能。
- 为SPIR-V指令集、装饰符、执行模型及存储类创建常量。
- 开发SPIR-V解析器以提取类型信息、变量细节及反射数据
- 引入类型映射函数实现SPIR-V类型到C++类型的转换,并计算std430内存布局
- 定义着色器元数据、编译结果及SPIR-V反射信息的数据类
- 添加着色器发现、分组及元数据加载的实用函数
2025-11-22 11:51:31 +08:00

446 lines
16 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
SPIR-V解析模块
负责解析SPIR-V二进制文件提取类型系统、变量、装饰和反射信息。
"""
from __future__ import annotations
import struct
from typing import Any, Dict, List, Tuple
from .constants import *
from .types import (
ArrayTypeInfo,
BaseType,
BindingInfo,
BufferInfo,
MatrixTypeInfo,
MemberInfo,
PointerTypeInfo,
ScalarTypeInfo,
SPIRVReflection,
StructTypeInfo,
ToolError,
TypeInfo,
VariableInfo,
VectorTypeInfo,
)
# ============ Binary Parsing Functions ============
def parse_spirv_words(spirv_data: bytes) -> List[int]:
    """Decode a SPIR-V binary blob into a list of 32-bit words.

    The blob is interpreted as a sequence of little-endian unsigned
    32-bit integers.

    Args:
        spirv_data: Raw contents of a SPIR-V module.

    Returns:
        Every word of the module, header included.

    Raises:
        ToolError: If the blob is shorter than 20 bytes, if its size is not
            a multiple of 4, or if the first word is not ``SPIRV_MAGIC``.
    """
    byte_count = len(spirv_data)
    if byte_count < 20:
        raise ToolError("SPIR-V data too small")
    if byte_count % 4:
        raise ToolError(f"SPIR-V data size {len(spirv_data)} is not a multiple of 4")

    layout = "<" + str(byte_count // 4) + "I"
    words = list(struct.unpack(layout, spirv_data))

    # The first word identifies the blob as SPIR-V.
    if words[0] != SPIRV_MAGIC:
        raise ToolError(f"Invalid SPIR-V magic number: 0x{words[0]:08x}")
    return words
def parse_spirv_string(words: List[int], start_index: int) -> Tuple[str, int]:
    """Read a NUL-terminated UTF-8 string packed into consecutive words.

    Each word contributes four bytes, lowest-order byte first. Reading
    stops at the first zero byte, or at the end of ``words`` when no
    terminator is found.

    Args:
        words: The module's word list.
        start_index: Index of the first word that holds string data.

    Returns:
        A ``(text, word_count)`` tuple where ``word_count`` is the number of
        words consumed, including the word that holds the terminator.
    """
    raw = bytearray()
    word_count = 0
    for i in range(start_index, len(words)):
        word = words[i]
        word_count += 1
        for byte_idx in range(4):
            byte = (word >> (byte_idx * 8)) & 0xFF
            if byte == 0:
                # Decode once at the end: a UTF-8 code point may span
                # several bytes (and even a word boundary), so converting
                # byte-by-byte with chr() would produce mojibake.
                return raw.decode("utf-8", errors="replace"), word_count
            raw.append(byte)
    # No terminator found: return what was read rather than failing.
    # Malformed sequences are replaced instead of raising, so a damaged
    # debug name never aborts parsing.
    return raw.decode("utf-8", errors="replace"), word_count
def parse_spirv_instructions(words: List[int]) -> List[Tuple[int, List[int], int]]:
    """Split the word stream of a module into individual instructions.

    The first five words (the module header) are skipped. The leading word
    of every instruction carries its total word count in the upper 16 bits
    and its opcode in the lower 16 bits.

    Args:
        words: The module's word list, header included.

    Returns:
        One ``(opcode, operands, index)`` tuple per instruction, where
        ``index`` is the position of the instruction's leading word in
        ``words``. Scanning stops at the first instruction whose word count
        is zero.
    """
    parsed = []
    total = len(words)
    cursor = 5  # first word after the five-word header
    while cursor < total:
        leading = words[cursor]
        length = (leading >> 16) & 0xFFFF
        if not length:
            # A zero-length instruction would never advance the cursor.
            break
        parsed.append((leading & 0xFFFF, words[cursor + 1:cursor + length], cursor))
        cursor += length
    return parsed
# ============ Execution Model Conversion ============
def execution_model_to_stage(exec_model: int) -> str:
    """Return the shader stage name for a SPIR-V execution model.

    Args:
        exec_model: One of the ``EXECUTION_MODEL_*`` constants.

    Returns:
        The long stage name (for example ``"vertex"``), or ``"unknown"``
        when the execution model is not one of the six handled here.
    """
    known_models = (
        EXECUTION_MODEL_VERTEX,
        EXECUTION_MODEL_TESS_CONTROL,
        EXECUTION_MODEL_TESS_EVAL,
        EXECUTION_MODEL_GEOMETRY,
        EXECUTION_MODEL_FRAGMENT,
        EXECUTION_MODEL_COMPUTE,
    )
    stage_names = (
        "vertex",
        "tess_control",
        "tess_eval",
        "geometry",
        "fragment",
        "compute",
    )
    lookup = dict(zip(known_models, stage_names))
    return lookup.get(exec_model, "unknown")
def stage_to_execution_model(stage: str) -> int:
    """Return the SPIR-V execution model for a shader stage name.

    Both the long names (``"vertex"``) and the four-letter short names
    (``"vert"``) are accepted.

    Args:
        stage: Shader stage name.

    Returns:
        The matching ``EXECUTION_MODEL_*`` constant. Unrecognised names
        fall back to ``EXECUTION_MODEL_COMPUTE``.
    """
    # Each execution model paired with every spelling that selects it.
    spellings = (
        (EXECUTION_MODEL_VERTEX, ("vertex", "vert")),
        (EXECUTION_MODEL_TESS_CONTROL, ("tess_control", "tesc")),
        (EXECUTION_MODEL_TESS_EVAL, ("tess_eval", "tese")),
        (EXECUTION_MODEL_GEOMETRY, ("geometry", "geom")),
        (EXECUTION_MODEL_FRAGMENT, ("fragment", "frag")),
        (EXECUTION_MODEL_COMPUTE, ("compute", "comp")),
    )
    lookup = {
        spelling: model
        for model, names in spellings
        for spelling in names
    }
    return lookup.get(stage, EXECUTION_MODEL_COMPUTE)
# ============ Type System Parsing ============
def parse_spirv_type_system(spirv_data: bytes, shader_stage: str) -> SPIRVReflection:
    """Extract the type system and reflection data from a SPIR-V module.

    The instruction stream is walked in several passes (constants, types,
    debug names, decorations, variables). Struct members are then filled
    in, type references resolved, buffers extracted and the entry point
    looked up.

    Args:
        spirv_data: Raw contents of the SPIR-V module.
        shader_stage: Stage name; stored on the result and used to select
            the entry point whose execution model matches.

    Returns:
        A populated ``SPIRVReflection``. ``entry_point`` is only assigned
        when an entry point with the matching execution model exists.

    Raises:
        ToolError: Propagated from ``parse_spirv_words`` for malformed input.
    """
    words = parse_spirv_words(spirv_data)
    instructions = parse_spirv_instructions(words)
    reflection = SPIRVReflection(shader_stage=shader_stage)

    # Constants come first because array lengths are looked up by id while
    # the type declarations are processed.
    _collect_constants(instructions, reflection)
    struct_member_type_ids = _collect_types(instructions, reflection)
    _collect_names(words, instructions, reflection)
    _collect_decorations(instructions, reflection)
    # Needs names and decorations, hence after those passes.
    _populate_structs(reflection, struct_member_type_ids)
    _collect_variables(instructions, reflection)

    resolve_all_types(reflection)
    extract_buffers(reflection)

    entry_point = _find_entry_point(words, instructions, shader_stage)
    if entry_point is not None:
        reflection.entry_point = entry_point
    return reflection


def _collect_constants(
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> None:
    """Record the first value word of every OpConstant, keyed by result id."""
    for opcode, operands, _ in instructions:
        # Operands: result type id, result id, value word(s).
        if opcode == OP_CONSTANT and len(operands) >= 3:
            reflection.constants[operands[1]] = operands[2]


def _collect_types(
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> Dict[int, List[int]]:
    """Register every supported OpType* declaration in ``reflection.types``.

    Returns:
        Mapping of struct type id to its member type ids, so struct members
        can be filled in later without rescanning the instruction stream.
    """
    struct_member_type_ids: Dict[int, List[int]] = {}
    for opcode, operands, _ in instructions:
        if opcode == OP_TYPE_INT and len(operands) >= 3:
            result_id, bit_width, is_signed = operands[:3]
            reflection.types[result_id] = ScalarTypeInfo(
                id=result_id,
                base_type=BaseType.INT if is_signed else BaseType.UINT,
                bit_width=bit_width,
            )
        elif opcode == OP_TYPE_FLOAT and len(operands) >= 2:
            result_id, bit_width = operands[:2]
            reflection.types[result_id] = ScalarTypeInfo(
                id=result_id,
                base_type=BaseType.FLOAT,
                bit_width=bit_width,
            )
        elif opcode == OP_TYPE_VECTOR and len(operands) >= 3:
            result_id, component_type_id, component_count = operands[:3]
            reflection.types[result_id] = VectorTypeInfo(
                id=result_id,
                component_type_id=component_type_id,
                component_count=component_count,
            )
        elif opcode == OP_TYPE_MATRIX and len(operands) >= 3:
            result_id, column_type_id, column_count = operands[:3]
            reflection.types[result_id] = MatrixTypeInfo(
                id=result_id,
                column_type_id=column_type_id,
                column_count=column_count,
            )
        elif opcode == OP_TYPE_ARRAY and len(operands) >= 3:
            result_id, element_type_id, length_id = operands[:3]
            reflection.types[result_id] = ArrayTypeInfo(
                id=result_id,
                element_type_id=element_type_id,
                # The length operand is the id of a constant, not a literal;
                # unknown ids yield a length of 0.
                length=reflection.constants.get(length_id, 0),
                is_runtime=False,
            )
        elif opcode == OP_TYPE_RUNTIME_ARRAY and len(operands) >= 2:
            result_id, element_type_id = operands[:2]
            reflection.types[result_id] = ArrayTypeInfo(
                id=result_id,
                element_type_id=element_type_id,
                length=None,
                is_runtime=True,
            )
        elif opcode == OP_TYPE_STRUCT and len(operands) >= 1:
            result_id = operands[0]
            reflection.types[result_id] = StructTypeInfo(
                id=result_id,
                members=[],
            )
            # Operands after the result id are the member type ids. Keep
            # the first declaration if an id is (invalidly) declared twice.
            struct_member_type_ids.setdefault(result_id, operands[1:])
        elif opcode == OP_TYPE_POINTER and len(operands) >= 3:
            result_id, storage_class, pointee_type_id = operands[:3]
            reflection.types[result_id] = PointerTypeInfo(
                id=result_id,
                storage_class=storage_class,
                pointee_type_id=pointee_type_id,
            )
    return struct_member_type_ids


def _collect_names(
    words: List[int],
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> None:
    """Record OpName and OpMemberName debug names."""
    for opcode, operands, index in instructions:
        if opcode == OP_NAME and len(operands) >= 1:
            # Layout: leading word, target id, then the string.
            name, _ = parse_spirv_string(words, index + 2)
            reflection.names[operands[0]] = name
        elif opcode == OP_MEMBER_NAME and len(operands) >= 2:
            struct_id, member_index = operands[:2]
            # Layout: leading word, struct id, member index, then the string.
            name, _ = parse_spirv_string(words, index + 3)
            reflection.member_names.setdefault(struct_id, {})[member_index] = name


def _collect_decorations(
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> None:
    """Record OpDecorate and OpMemberDecorate entries.

    A decoration with an extra operand stores that operand as its value;
    a decoration without one stores ``True``.
    """
    for opcode, operands, _ in instructions:
        if opcode == OP_DECORATE and len(operands) >= 2:
            target_id, decoration = operands[:2]
            value = operands[2] if len(operands) >= 3 else True
            reflection.decorations.setdefault(target_id, {})[decoration] = value
        elif opcode == OP_MEMBER_DECORATE and len(operands) >= 3:
            struct_id, member_index, decoration = operands[:3]
            value = operands[3] if len(operands) >= 4 else True
            per_struct = reflection.member_decorations.setdefault(struct_id, {})
            per_struct.setdefault(member_index, {})[decoration] = value


def _populate_structs(
    reflection: SPIRVReflection,
    struct_member_type_ids: Dict[int, List[int]],
) -> None:
    """Fill in members, block flags and the name of every struct type."""
    for type_id, type_info in reflection.types.items():
        if not isinstance(type_info, StructTypeInfo):
            continue

        member_names = reflection.member_names.get(type_id, {})
        member_decorations = reflection.member_decorations.get(type_id, {})
        members = []
        for idx, member_type_id in enumerate(struct_member_type_ids.get(type_id, [])):
            member_decs = member_decorations.get(idx, {})
            members.append(MemberInfo(
                index=idx,
                name=member_names.get(idx, f"member_{idx}"),
                type_id=member_type_id,
                offset=member_decs.get(DECORATION_OFFSET, 0),
                array_stride=member_decs.get(DECORATION_ARRAY_STRIDE),
                matrix_stride=member_decs.get(DECORATION_MATRIX_STRIDE),
                is_row_major=DECORATION_ROW_MAJOR in member_decs,
            ))
        type_info.members = members

        decs = reflection.decorations.get(type_id, {})
        type_info.is_block = DECORATION_BLOCK in decs
        type_info.is_buffer_block = DECORATION_BUFFER_BLOCK in decs
        type_info.name = reflection.names.get(type_id)


def _collect_variables(
    instructions: List[Tuple[int, List[int], int]],
    reflection: SPIRVReflection,
) -> None:
    """Record every OpVariable together with its binding decorations."""
    for opcode, operands, _ in instructions:
        # Operands: result type id, result id, storage class.
        if opcode == OP_VARIABLE and len(operands) >= 3:
            result_type_id, result_id, storage_class = operands[:3]
            var_decs = reflection.decorations.get(result_id, {})
            reflection.variables[result_id] = VariableInfo(
                id=result_id,
                name=reflection.names.get(result_id),
                type_id=result_type_id,
                storage_class=storage_class,
                binding=var_decs.get(DECORATION_BINDING),
                descriptor_set=var_decs.get(DECORATION_DESCRIPTOR_SET, 0),
            )


def _find_entry_point(
    words: List[int],
    instructions: List[Tuple[int, List[int], int]],
    shader_stage: str,
):
    """Return the name of the first entry point matching ``shader_stage``.

    Returns ``None`` when no OpEntryPoint uses the stage's execution model.
    """
    target_model = stage_to_execution_model(shader_stage)
    for opcode, operands, index in instructions:
        if opcode == OP_ENTRY_POINT and len(operands) >= 2 and operands[0] == target_model:
            # Layout: leading word, execution model, entry point id, then
            # the name string.
            name, _ = parse_spirv_string(words, index + 3)
            return name
    return None
# ============ Type Resolution ============
def resolve_all_types(reflection: SPIRVReflection):
    """Resolve type-id references for every type and variable.

    Each entry of ``reflection.types`` is passed through ``resolve_type``.
    Every variable whose ``type_id`` is a known type then gets that type
    object as ``resolved_type``, which is resolved as well.
    """
    known_types = reflection.types

    for declared in known_types.values():
        resolve_type(declared, known_types)

    for variable in reflection.variables.values():
        if variable.type_id not in known_types:
            # Variables with an unregistered type are left unresolved.
            continue
        variable.resolved_type = known_types[variable.type_id]
        resolve_type(variable.resolved_type, known_types)
def resolve_type(type_info: TypeInfo, type_map: Dict[int, TypeInfo], _visited=None) -> TypeInfo:
    """Attach the referenced type objects to ``type_info``, recursively.

    Type ids stored on the object (component, column, element, member and
    pointee ids) are looked up in ``type_map``; ids that are missing from
    the map are left unresolved.

    Args:
        type_info: The type whose references should be resolved in place.
        type_map: Mapping of type id to type object.
        _visited: Internal bookkeeping for the recursion (a set of
            ``id()`` values of objects already handled). Callers should
            leave it unset.

    Returns:
        ``type_info`` itself.
    """
    if _visited is None:
        _visited = set()

    # Type graphs can contain cycles (a struct holding a pointer back to
    # itself) and heavily shared subtrees. Handling each object once per
    # top-level call keeps the recursion finite and avoids repeated work.
    marker = id(type_info)
    if marker in _visited:
        return type_info
    _visited.add(marker)

    if isinstance(type_info, VectorTypeInfo):
        if type_info.component_type_id in type_map:
            type_info.resolved_component_type = type_map[type_info.component_type_id]
    elif isinstance(type_info, MatrixTypeInfo):
        if type_info.column_type_id in type_map:
            col_type = type_map[type_info.column_type_id]
            type_info.resolved_column_type = col_type
            if isinstance(col_type, VectorTypeInfo):
                resolve_type(col_type, type_map, _visited)
    elif isinstance(type_info, ArrayTypeInfo):
        if type_info.element_type_id in type_map:
            type_info.resolved_element_type = type_map[type_info.element_type_id]
            resolve_type(type_info.resolved_element_type, type_map, _visited)
    elif isinstance(type_info, StructTypeInfo):
        for member in type_info.members:
            if member.type_id in type_map:
                member.resolved_type = type_map[member.type_id]
                resolve_type(member.resolved_type, type_map, _visited)
    elif isinstance(type_info, PointerTypeInfo):
        if type_info.pointee_type_id in type_map:
            type_info.resolved_pointee_type = type_map[type_info.pointee_type_id]
            resolve_type(type_info.resolved_pointee_type, type_map, _visited)
    return type_info
# ============ Buffer Extraction ============
def extract_buffers(reflection: SPIRVReflection):
    """Append a ``BufferInfo`` for every bound buffer variable.

    A variable qualifies when it has a binding decoration, its type is a
    pointer to a struct, and the pointer's storage class is ``Uniform`` or
    ``StorageBuffer``. The resulting ``reflection.buffers`` list is sorted
    by binding number.

    Expects type references to be resolved already, since the pointer's
    ``resolved_pointee_type`` is read.
    """
    for var_info in reflection.variables.values():
        if var_info.binding is None:
            continue

        var_type = reflection.types.get(var_info.type_id)
        if not isinstance(var_type, PointerTypeInfo):
            continue

        struct_type = var_type.resolved_pointee_type
        if not isinstance(struct_type, StructTypeInfo):
            continue

        storage_class = var_type.storage_class
        if storage_class == STORAGE_CLASS_STORAGE_BUFFER:
            descriptor_type = "storage_buffer"
        elif storage_class == STORAGE_CLASS_UNIFORM:
            # Modules older than SPIR-V 1.3 express storage buffers as a
            # Uniform-class struct decorated BufferBlock, so the storage
            # class alone cannot tell the two descriptor types apart.
            if getattr(struct_type, "is_buffer_block", False):
                descriptor_type = "storage_buffer"
            else:
                descriptor_type = "uniform_buffer"
        else:
            continue

        # Unnamed structs and variables get a name derived from the binding.
        struct_name = struct_type.name
        if not struct_name:
            struct_name = f"Buffer{var_info.binding}"

        reflection.buffers.append(BufferInfo(
            name=struct_name,
            binding=var_info.binding,
            descriptor_set=var_info.descriptor_set or 0,
            descriptor_type=descriptor_type,
            struct_type=struct_type,
            variable_name=var_info.name or f"buffer_{var_info.binding}",
        ))

    reflection.buffers.sort(key=lambda b: b.binding)
# ============ Legacy Interface ============
def extract_spirv_reflection(spirv_data: bytes, shader_type: str) -> Dict[str, Any]:
    """Extract reflection data in the shape expected by the legacy interface.

    Args:
        spirv_data: Raw contents of the SPIR-V module.
        shader_type: Shader stage name, long (``"vertex"``) or short
            (``"vert"``).

    Returns:
        A dict with three keys: ``'entry_point'`` (the entry point name),
        ``'bindings'`` (one ``BindingInfo`` per extracted buffer) and
        ``'reflection'`` (the full ``SPIRVReflection``).
    """
    reflection = parse_spirv_type_system(spirv_data, shader_type)

    # Report the long stage name for every short alias that
    # stage_to_execution_model accepts; other values pass through unchanged.
    stage_aliases = {
        "vert": "vertex",
        "tesc": "tess_control",
        "tese": "tess_eval",
        "geom": "geometry",
        "frag": "fragment",
        "comp": "compute",
    }
    stage = stage_aliases.get(shader_type, shader_type)

    bindings = [
        BindingInfo(
            binding=buffer.binding,
            descriptor_type=buffer.descriptor_type,
            stages=[stage],
            name=buffer.variable_name,
            count=1,
        )
        for buffer in reflection.buffers
    ]
    return {
        'entry_point': reflection.entry_point,
        'bindings': bindings,
        'reflection': reflection,
    }