Files
mirage/tools/types.py

210 lines
5.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
数据类型定义
包含SPIR-V反射所需的所有数据类和枚举类型。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Any, Dict, List, Optional
class ToolError(RuntimeError):
    """Raised when a tool runs into a fatal error."""
# ============ Enums ============
class StorageClass(Enum):
    """SPIR-V storage classes.

    Member values are the numeric storage-class operands used by SPIR-V.
    """

    UNIFORM_CONSTANT = 0  # UniformConstant
    INPUT = 1  # Input
    UNIFORM = 2  # Uniform
    OUTPUT = 3  # Output
    WORKGROUP = 4  # Workgroup
    CROSS_WORKGROUP = 5  # CrossWorkgroup
    PRIVATE = 6  # Private
    FUNCTION = 7  # Function
    GENERIC = 8  # Generic
    PUSH_CONSTANT = 9  # PushConstant
    ATOMIC_COUNTER = 10  # AtomicCounter
    IMAGE = 11  # Image
    STORAGE_BUFFER = 12  # StorageBuffer
class BaseType(Enum):
    """Basic data types.

    Each member's value is the lowercase type name as a string.
    """

    VOID = "void"
    BOOL = "bool"
    INT = "int"
    UINT = "uint"
    FLOAT = "float"
    DOUBLE = "double"
# ============ Type Info Classes ============
@dataclass
class TypeInfo:
    """Base class for type information.

    Attributes:
        id: Integer identifier of the type.
        name: Optional name of the type; ``None`` when not set.
    """

    id: int
    name: Optional[str] = None
@dataclass
class ScalarTypeInfo(TypeInfo):
    """Scalar type.

    Attributes:
        base_type: Basic data type of the scalar; defaults to
            ``BaseType.FLOAT``.
        bit_width: Width in bits; defaults to 32.
    """

    base_type: BaseType = BaseType.FLOAT
    bit_width: int = 32
@dataclass
class VectorTypeInfo(TypeInfo):
    """Vector type.

    Attributes:
        component_type_id: Id of the component type; defaults to 0.
        component_count: Number of components; defaults to 4.
        resolved_component_type: ``TypeInfo`` for the component type, or
            ``None`` (the default). Presumably filled in by a later
            resolution step that is not part of this module.
    """

    component_type_id: int = 0
    component_count: int = 4
    resolved_component_type: Optional[TypeInfo] = None
@dataclass
class MatrixTypeInfo(TypeInfo):
    """Matrix type.

    Attributes:
        column_type_id: Id of the column (vector) type; defaults to 0.
        column_count: Number of columns; defaults to 4.
        resolved_column_type: ``VectorTypeInfo`` for the column type, or
            ``None`` (the default).
    """

    column_type_id: int = 0
    column_count: int = 4
    resolved_column_type: Optional[VectorTypeInfo] = None

    @property
    def row_count(self) -> int:
        """Row count, i.e. the component count of the column vector.

        Returns:
            ``component_count`` of ``resolved_column_type`` when that is
            set, otherwise 4.
        """
        column = self.resolved_column_type
        if not column:
            # No resolved column vector to read from: fall back to 4.
            return 4
        return column.component_count
@dataclass
class ArrayTypeInfo(TypeInfo):
    """Array type.

    Attributes:
        element_type_id: Id of the element type; defaults to 0.
        length: Array length, or ``None`` (the default).
        is_runtime: Runtime-array flag; defaults to ``False``.
        stride: Array stride, or ``None`` (the default).
        resolved_element_type: ``TypeInfo`` for the element type, or
            ``None`` (the default).
    """

    element_type_id: int = 0
    length: Optional[int] = None
    is_runtime: bool = False
    stride: Optional[int] = None
    resolved_element_type: Optional[TypeInfo] = None
@dataclass
class MemberInfo:
    """Information about a single struct member.

    Attributes:
        index: Index of the member.
        name: Member name.
        type_id: Id of the member's type.
        offset: Member offset; defaults to 0.
        resolved_type: ``TypeInfo`` for the member's type, or ``None``
            (the default).
        matrix_stride: Matrix stride, or ``None`` (the default).
        array_stride: Array stride, or ``None`` (the default).
        is_row_major: Row-major flag; defaults to ``False``.
    """

    # Required fields.
    index: int
    name: str
    type_id: int

    # Optional fields.
    offset: int = 0
    resolved_type: Optional[TypeInfo] = None
    matrix_stride: Optional[int] = None
    array_stride: Optional[int] = None
    is_row_major: bool = False
@dataclass
class StructTypeInfo(TypeInfo):
    """Struct type.

    Attributes:
        members: ``MemberInfo`` entries of the struct; each instance gets
            its own fresh empty list by default.
        is_block: Defaults to ``False``.
        is_buffer_block: Defaults to ``False``.
    """

    members: List[MemberInfo] = field(default_factory=list)
    is_block: bool = False
    is_buffer_block: bool = False
@dataclass
class PointerTypeInfo(TypeInfo):
    """Pointer type.

    Attributes:
        storage_class: Storage class held as a plain ``int``; defaults
            to 0.
        pointee_type_id: Id of the pointed-to type; defaults to 0.
        resolved_pointee_type: ``TypeInfo`` for the pointed-to type, or
            ``None`` (the default).
    """

    storage_class: int = 0
    pointee_type_id: int = 0
    resolved_pointee_type: Optional[TypeInfo] = None
# ============ Reflection Data Classes ============
@dataclass
class VariableInfo:
    """Information about a variable.

    Attributes:
        id: Integer identifier of the variable.
        name: Variable name; may be ``None``.
        type_id: Id of the variable's type.
        storage_class: Storage class held as a plain ``int``.
        binding: Binding number, or ``None`` (the default).
        descriptor_set: Descriptor set number, or ``None`` (the default).
        resolved_type: ``TypeInfo`` for the variable's type, or ``None``
            (the default).
    """

    # Required fields.
    id: int
    name: Optional[str]
    type_id: int
    storage_class: int

    # Optional fields.
    binding: Optional[int] = None
    descriptor_set: Optional[int] = None
    resolved_type: Optional[TypeInfo] = None
@dataclass
class BufferInfo:
    """Complete information about a buffer / uniform.

    Attributes:
        name: Buffer name.
        binding: Binding number.
        descriptor_set: Descriptor set number.
        descriptor_type: Descriptor type, as a string.
        struct_type: ``StructTypeInfo`` associated with the buffer.
        variable_name: Name of the associated variable.
        element_count: Number of array elements; ``None`` for runtime
            arrays.
        is_dynamic: Whether dynamic sizing is supported.
        access_mode: Access mode: ``read_only``, ``write_only`` or
            ``read_write``.
    """

    # Required fields.
    name: str
    binding: int
    descriptor_set: int
    descriptor_type: str
    struct_type: StructTypeInfo
    variable_name: str

    # Optional fields.
    element_count: Optional[int] = None  # array element count; None for runtime arrays
    is_dynamic: bool = False  # whether dynamic sizing is supported
    access_mode: str = "read_write"  # access mode: read_only, write_only, read_write
@dataclass
class BindingInfo:
    """Descriptor binding information.

    Attributes:
        binding: Binding number.
        descriptor_type: Descriptor type, as a string.
        stages: List of stage names, as strings.
        name: Optional binding name; ``None`` by default.
        count: Defaults to 1.
    """

    # Required fields.
    binding: int
    descriptor_type: str
    stages: List[str]

    # Optional fields.
    name: Optional[str] = None
    count: int = 1
@dataclass
class SPIRVReflection:
    """Complete SPIR-V reflection information.

    Every container field defaults to a fresh, empty container per
    instance.

    Attributes:
        types: ``TypeInfo`` objects keyed by int.
        variables: ``VariableInfo`` objects keyed by int.
        names: Name strings keyed by int.
        member_names: Two-level int-keyed mapping to member name strings.
        decorations: Two-level int-keyed mapping to decoration values.
        member_decorations: Three-level int-keyed mapping to decoration
            values.
        constants: Constant values keyed by int.
        buffers: List of ``BufferInfo`` entries.
        entry_point: Entry point name; defaults to ``"main"``.
        shader_stage: Shader stage name; defaults to ``"compute"``.
    """

    # Int-keyed lookup tables.
    types: Dict[int, TypeInfo] = field(default_factory=dict)
    variables: Dict[int, VariableInfo] = field(default_factory=dict)
    names: Dict[int, str] = field(default_factory=dict)
    member_names: Dict[int, Dict[int, str]] = field(default_factory=dict)
    decorations: Dict[int, Dict[int, Any]] = field(default_factory=dict)
    member_decorations: Dict[int, Dict[int, Dict[int, Any]]] = field(
        default_factory=dict
    )
    constants: Dict[int, Any] = field(default_factory=dict)

    buffers: List[BufferInfo] = field(default_factory=list)

    entry_point: str = "main"
    shader_stage: str = "compute"
# ============ Shader Metadata Classes ============
@dataclass
class ShaderMetadata:
    """Shader configuration metadata.

    Attributes:
        name: Shader name.
        shader_type: Shader type, as a string.
        entry_points: String-to-string mapping of entry points.
        output_header: Output header, as a string.
        namespace: Optional namespace; ``None`` by default.
        bindings: List of ``BindingInfo``; fresh empty list per instance.
        include_paths: List of ``Path``; fresh empty list per instance.
        defines: String-to-string mapping; fresh empty dict per instance.
        reflection: Optional ``SPIRVReflection``; ``None`` by default.
        generate_buffer_helpers: Generate basic helper functions.
        generate_typed_buffers: Generate type-safe buffer wrappers.
        generate_buffer_manager: Generate the buffer manager class.
    """

    # Required fields.
    name: str
    shader_type: str
    entry_points: Dict[str, str]
    output_header: str

    # Optional fields.
    namespace: Optional[str] = None
    bindings: List[BindingInfo] = field(default_factory=list)
    include_paths: List[Path] = field(default_factory=list)
    defines: Dict[str, str] = field(default_factory=dict)
    reflection: Optional[SPIRVReflection] = None

    # Code-generation switches, all enabled by default.
    generate_buffer_helpers: bool = True  # generate basic helper functions
    generate_typed_buffers: bool = True  # generate type-safe buffer wrappers
    generate_buffer_manager: bool = True  # generate the buffer manager class
@dataclass
class CompilationResult:
    """Result of a compilation.

    Attributes:
        source_file: Path of the source file.
        shader_type: Shader type, as a string.
        spirv_data: SPIR-V data as bytes.
        entry_point: Entry point name.
        bindings: List of ``BindingInfo``; fresh empty list per instance.
        reflection: Optional ``SPIRVReflection``; ``None`` by default.
    """

    # Required fields.
    source_file: Path
    shader_type: str
    spirv_data: bytes
    entry_point: str

    # Optional fields.
    bindings: List[BindingInfo] = field(default_factory=list)
    reflection: Optional[SPIRVReflection] = None