210 lines
5.1 KiB
Python
210 lines
5.1 KiB
Python
#!/usr/bin/env python3
"""
数据类型定义

包含SPIR-V反射所需的所有数据类和枚举类型。
"""

from __future__ import annotations
|
||
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
|
||
class ToolError(RuntimeError):
    """Fatal error raised when the tool cannot continue."""


# ============ Enums ============

class StorageClass(Enum):
    """SPIR-V storage classes.

    Member values are the numeric ``StorageClass`` enumerants defined by the
    SPIR-V specification, so a raw operand word converts with
    ``StorageClass(word)``. Note that the ``storage_class`` fields of
    ``PointerTypeInfo`` and ``VariableInfo`` in this module hold the raw
    ``int`` rather than a member of this enum.
    """

    # Core storage classes.
    UNIFORM_CONSTANT = 0
    INPUT = 1
    UNIFORM = 2
    OUTPUT = 3
    WORKGROUP = 4
    CROSS_WORKGROUP = 5
    PRIVATE = 6
    FUNCTION = 7
    GENERIC = 8
    PUSH_CONSTANT = 9
    ATOMIC_COUNTER = 10
    IMAGE = 11
    STORAGE_BUFFER = 12

    # Extension / later-version storage classes. Without these members,
    # StorageClass(word) raises ValueError for modules that use ray tracing,
    # buffer device addresses (PhysicalStorageBuffer) or task/mesh payloads.
    CALLABLE_DATA_KHR = 5328
    INCOMING_CALLABLE_DATA_KHR = 5329
    RAY_PAYLOAD_KHR = 5338
    HIT_ATTRIBUTE_KHR = 5339
    INCOMING_RAY_PAYLOAD_KHR = 5342
    SHADER_RECORD_BUFFER_KHR = 5343
    PHYSICAL_STORAGE_BUFFER = 5349
    TASK_PAYLOAD_WORKGROUP_EXT = 5402


class BaseType(Enum):
    """Base (scalar) data kinds; each member's value is its lowercase name."""

    VOID = "void"
    BOOL = "bool"
    # Integer kinds
    INT = "int"
    UINT = "uint"
    # Floating-point kinds
    FLOAT = "float"
    DOUBLE = "double"


# ============ Type Info Classes ============

@dataclass
class TypeInfo:
    """Fields shared by every reflected type record.

    Attributes:
        id: Integer id of the type (presumably the same id that keys
            ``SPIRVReflection.types``).
        name: Debug name of the type, or ``None`` when it has none.
    """

    id: int
    name: Optional[str] = None


@dataclass
class ScalarTypeInfo(TypeInfo):
    """Scalar type record: a ``BaseType`` kind plus its width in bits.

    When only ``id`` (and optionally ``name``) is given, the record describes
    a 32-bit ``BaseType.FLOAT``.
    """

    base_type: BaseType = BaseType.FLOAT
    bit_width: int = 32


@dataclass
class VectorTypeInfo(TypeInfo):
    """Vector type record.

    Attributes:
        component_type_id: Id of the component type.
        component_count: Number of components (defaults to 4).
        resolved_component_type: The component's ``TypeInfo`` object, or
            ``None`` while it has not yet been resolved from
            ``component_type_id``.
    """

    component_type_id: int = 0
    component_count: int = 4
    resolved_component_type: Optional[TypeInfo] = None


@dataclass
class MatrixTypeInfo(TypeInfo):
    """Matrix type record, described as ``column_count`` column vectors.

    Attributes:
        column_type_id: Id of the column vector type.
        column_count: Number of columns (defaults to 4).
        resolved_column_type: The column's ``VectorTypeInfo``, or ``None``
            while it has not yet been resolved from ``column_type_id``.
    """

    column_type_id: int = 0
    column_count: int = 4
    resolved_column_type: Optional[VectorTypeInfo] = None

    @property
    def row_count(self) -> int:
        """Number of rows, i.e. the component count of the column vector.

        Falls back to 4 while the column type is unresolved.
        """
        column = self.resolved_column_type
        if not column:
            return 4
        return column.component_count


@dataclass
class ArrayTypeInfo(TypeInfo):
    """Array type record (fixed-length or runtime-sized).

    Attributes:
        element_type_id: Id of the element type.
        length: Element count, or ``None`` when no fixed length is recorded.
        is_runtime: Whether this is a runtime-sized array.
        stride: Stride between elements if known (presumably from an
            ``ArrayStride`` decoration), else ``None``.
        resolved_element_type: The element's ``TypeInfo`` once resolved,
            else ``None``.
    """

    element_type_id: int = 0
    length: Optional[int] = None
    is_runtime: bool = False
    stride: Optional[int] = None
    resolved_element_type: Optional[TypeInfo] = None


@dataclass
class MemberInfo:
    """One member of a struct type.

    Attributes:
        index: Position of the member inside its struct.
        name: Member name.
        type_id: Id of the member's type.
        offset: Offset within the struct (presumably bytes, from an
            ``Offset`` decoration); defaults to 0.
        resolved_type: The member's ``TypeInfo`` once resolved, else ``None``.
        matrix_stride: Matrix stride, or ``None`` when not recorded.
        array_stride: Array stride, or ``None`` when not recorded.
        is_row_major: Whether the member is flagged row-major.
    """

    index: int
    name: str
    type_id: int
    offset: int = 0
    resolved_type: Optional[TypeInfo] = None
    # Layout details, only meaningful for matrix / array members.
    matrix_stride: Optional[int] = None
    array_stride: Optional[int] = None
    is_row_major: bool = False


@dataclass
class StructTypeInfo(TypeInfo):
    """Struct type record.

    Attributes:
        members: Member descriptions; every instance gets its own list.
        is_block: Presumably set when the struct carries the ``Block``
            decoration; TODO confirm against the parser.
        is_buffer_block: Presumably set for the ``BufferBlock`` decoration;
            TODO confirm against the parser.
    """

    members: List[MemberInfo] = field(default_factory=list)
    is_block: bool = False
    is_buffer_block: bool = False


@dataclass
class PointerTypeInfo(TypeInfo):
    """Pointer type record.

    Attributes:
        storage_class: Raw storage-class number (presumably a ``StorageClass``
            value); defaults to 0.
        pointee_type_id: Id of the pointed-to type.
        resolved_pointee_type: The pointee's ``TypeInfo`` once resolved,
            else ``None``.
    """

    storage_class: int = 0
    pointee_type_id: int = 0
    resolved_pointee_type: Optional[TypeInfo] = None


# ============ Reflection Data Classes ============

@dataclass
class VariableInfo:
    """A reflected variable.

    Attributes:
        id: Integer id of the variable.
        name: Debug name, or ``None`` when the variable is unnamed.
        type_id: Id of the variable's type.
        storage_class: Raw storage-class number (presumably a
            ``StorageClass`` value).
        binding: Descriptor binding index, or ``None`` when not set.
        descriptor_set: Descriptor set index, or ``None`` when not set.
        resolved_type: The variable's ``TypeInfo`` once resolved, else ``None``.
    """

    # Required identification fields.
    id: int
    name: Optional[str]
    type_id: int
    storage_class: int
    # Optional descriptor placement and resolution result.
    binding: Optional[int] = None
    descriptor_set: Optional[int] = None
    resolved_type: Optional[TypeInfo] = None


@dataclass
class BufferInfo:
    """Everything known about one buffer / uniform block.

    Attributes:
        name: Buffer name.
        binding: Descriptor binding index.
        descriptor_set: Descriptor set index.
        descriptor_type: Descriptor type, as a string.
        struct_type: Struct type describing the buffer contents.
        variable_name: Name of the variable bound to the buffer.
        element_count: Number of array elements; ``None`` for runtime arrays.
        is_dynamic: Whether the buffer supports a dynamic size.
        access_mode: Intended to be ``"read_only"``, ``"write_only"`` or
            ``"read_write"`` (the default); not validated here.
    """

    name: str
    binding: int
    descriptor_set: int
    descriptor_type: str
    struct_type: StructTypeInfo
    variable_name: str

    # Optional details.
    element_count: Optional[int] = None
    is_dynamic: bool = False
    access_mode: str = "read_write"


@dataclass
class BindingInfo:
    """A single descriptor binding.

    Attributes:
        binding: Binding index.
        descriptor_type: Descriptor type, as a string.
        stages: Names of the shader stages that use the binding.
        name: Optional resource name.
        count: Number of descriptors in the binding (defaults to 1).
    """

    binding: int
    descriptor_type: str
    stages: List[str]
    name: Optional[str] = None
    count: int = 1


@dataclass
class SPIRVReflection:
    """Aggregate reflection data for one SPIR-V module.

    Every container field uses a ``default_factory``, so separate instances
    never share state.

    Attributes:
        types: Type records keyed by id.
        variables: Variable records keyed by id.
        names: Debug names keyed by id.
        member_names: Nested map, presumably struct id -> member index -> name.
        decorations: Nested map, presumably id -> decoration number -> value.
        member_decorations: Nested map, presumably
            struct id -> member index -> decoration number -> value.
        constants: Constant values keyed by id.
        buffers: Buffer descriptions gathered from the module.
        entry_point: Entry point name (defaults to ``"main"``).
        shader_stage: Shader stage name (defaults to ``"compute"``).
    """

    # Id-keyed lookup tables.
    types: Dict[int, TypeInfo] = field(default_factory=dict)
    variables: Dict[int, VariableInfo] = field(default_factory=dict)
    names: Dict[int, str] = field(default_factory=dict)
    member_names: Dict[int, Dict[int, str]] = field(default_factory=dict)
    decorations: Dict[int, Dict[int, Any]] = field(default_factory=dict)
    member_decorations: Dict[int, Dict[int, Dict[int, Any]]] = field(default_factory=dict)
    constants: Dict[int, Any] = field(default_factory=dict)

    # Derived results and module-level settings.
    buffers: List[BufferInfo] = field(default_factory=list)
    entry_point: str = "main"
    shader_stage: str = "compute"


# ============ Shader Metadata Classes ============

@dataclass
class ShaderMetadata:
    """Per-shader configuration.

    Attributes:
        name: Shader name.
        shader_type: Shader type, as a string.
        entry_points: String-to-string mapping of entry points.
        output_header: Output header, as a string.
        namespace: Optional namespace name.
        bindings: Descriptor bindings.
        include_paths: Include directories.
        defines: Name -> value defines.
        reflection: Reflection data, when available.
        generate_buffer_helpers: Generate basic helper functions.
        generate_typed_buffers: Generate type-safe buffer wrappers.
        generate_buffer_manager: Generate a buffer manager class.
    """

    name: str
    shader_type: str
    entry_points: Dict[str, str]
    output_header: str
    namespace: Optional[str] = None
    bindings: List[BindingInfo] = field(default_factory=list)
    include_paths: List[Path] = field(default_factory=list)
    defines: Dict[str, str] = field(default_factory=dict)
    reflection: Optional[SPIRVReflection] = None

    # Code-generation switches (all enabled by default).
    generate_buffer_helpers: bool = True
    generate_typed_buffers: bool = True
    generate_buffer_manager: bool = True


@dataclass
class CompilationResult:
    """Outcome of compiling one shader source.

    Attributes:
        source_file: Path of the source file that was compiled.
        shader_type: Shader type, as a string.
        spirv_data: Raw SPIR-V binary.
        entry_point: Entry point name.
        bindings: Descriptor bindings (empty list by default).
        reflection: Reflection data, if any was produced.
    """

    source_file: Path
    shader_type: str
    spirv_data: bytes
    entry_point: str
    bindings: List[BindingInfo] = field(default_factory=list)
    reflection: Optional[SPIRVReflection] = None