Files
mirage/tools/types.py
daiqingshuang 5a8d62f841 Refactor Push Constants and Add Dual Stage Support
- Removed legacy push constant structures and functions for better clarity and maintainability.
- Introduced new `text_push_constants_t` structure for text rendering with optimized layout.
- Implemented dual stage push constant analysis to support separate layouts for vertex and fragment shaders.
- Added functions to generate push constant structures and fill functions based on shader reflection.
- Enhanced static checks for push constant layouts to ensure compatibility and correctness.
- Updated templates to accommodate new dual stage push constant generation.
- Added support detection for procedural vertex shaders based on push constant layout.
2025-12-25 21:04:39 +08:00

481 lines
15 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
#!/usr/bin/env python3
"""
数据类型定义
包含SPIR-V反射所需的所有数据类和枚举类型。
"""
from __future__ import annotations
from dataclasses import dataclass, field
from enum import Enum, Flag, auto
from pathlib import Path
from typing import Any, Dict, List, Optional
class ToolError(RuntimeError):
"""工具遇到致命错误时抛出"""
pass
# ============ Push Constant Base Members ============
class PushConstantBaseMember(Flag):
"""Push Constants Header 部分成员标志(简化布局)"""
NONE = 0
SCALE = auto() # vec2 scale (offset 0)
TRANSLATE = auto() # vec2 translate (offset 8)
# 预定义的组合
STANDARD_HEADER = SCALE | TRANSLATE # 标准 Header 布局
@dataclass
class PushConstantBaseMemberInfo:
"""Push Constants Header 部分单个成员信息"""
name: str # 成员名称(如 "scale"
member_flag: PushConstantBaseMember # 成员标志
offset: int # 实际偏移量
size: int # 实际大小
expected_offset: int # 期望偏移量
expected_size: int # 期望大小
type_name: str # C++ 类型名称
is_valid: bool = True # 是否与期望匹配
@dataclass
class PushConstantBaseInfo:
"""Push Constants Header 部分完整信息"""
members: List[PushConstantBaseMemberInfo] = field(default_factory=list)
present_flags: PushConstantBaseMember = PushConstantBaseMember.NONE
total_size: int = 0
is_standard_layout: bool = False # 是否为标准布局
# 期望布局定义(简化布局)
EXPECTED_LAYOUT = {
'uScale': (0, 8), # vec2 (offset, size)
'scale': (0, 8), # 兼容小写命名
'uTranslate': (8, 8), # vec2
'translate': (8, 8), # 兼容小写命名
}
EXPECTED_TOTAL_SIZE = 16
@property
def has_scale(self) -> bool:
"""检查是否包含 scale 成员"""
return any(m.name in ('uScale', 'scale') for m in self.members)
@property
def has_translate(self) -> bool:
"""检查是否包含 translate 成员"""
return any(m.name in ('uTranslate', 'translate') for m in self.members)
# ============ Enums ============
class StorageClass(Enum):
"""SPIR-V存储类"""
UNIFORM_CONSTANT = 0
INPUT = 1
UNIFORM = 2
OUTPUT = 3
WORKGROUP = 4
CROSS_WORKGROUP = 5
PRIVATE = 6
FUNCTION = 7
GENERIC = 8
PUSH_CONSTANT = 9
ATOMIC_COUNTER = 10
IMAGE = 11
STORAGE_BUFFER = 12
class BaseType(Enum):
"""基础数据类型"""
VOID = "void"
BOOL = "bool"
INT = "int"
UINT = "uint"
FLOAT = "float"
DOUBLE = "double"
# ============ Type Info Classes ============
@dataclass
class TypeInfo:
"""类型信息基类"""
id: int
name: Optional[str] = None
@dataclass
class ScalarTypeInfo(TypeInfo):
"""标量类型"""
base_type: BaseType = BaseType.FLOAT
bit_width: int = 32
@dataclass
class VectorTypeInfo(TypeInfo):
"""向量类型"""
component_type_id: int = 0
component_count: int = 4
resolved_component_type: Optional[TypeInfo] = None
@dataclass
class MatrixTypeInfo(TypeInfo):
"""矩阵类型"""
column_type_id: int = 0
column_count: int = 4
resolved_column_type: Optional[VectorTypeInfo] = None
@property
def row_count(self) -> int:
"""行数 = 列向量的分量数"""
if self.resolved_column_type:
return self.resolved_column_type.component_count
return 4
@dataclass
class ArrayTypeInfo(TypeInfo):
"""数组类型"""
element_type_id: int = 0
length: Optional[int] = None
is_runtime: bool = False
stride: Optional[int] = None
resolved_element_type: Optional[TypeInfo] = None
@dataclass
class MemberInfo:
"""结构体成员信息"""
index: int
name: str
type_id: int
offset: int = 0
resolved_type: Optional[TypeInfo] = None
matrix_stride: Optional[int] = None
array_stride: Optional[int] = None
is_row_major: bool = False
@dataclass
class StructTypeInfo(TypeInfo):
"""结构体类型"""
members: List[MemberInfo] = field(default_factory=list)
is_block: bool = False
is_buffer_block: bool = False
@dataclass
class ImageTypeInfo(TypeInfo):
"""图像类型OpTypeImage"""
sampled_type_id: int = 0
dimension: int = 0
depth: int = 0
arrayed: int = 0
ms: int = 0
sampled: int = 0
format: int = 0
@dataclass
class SampledImageTypeInfo(TypeInfo):
"""采样图像类型OpTypeSampledImage"""
image_type_id: int = 0
@dataclass
class SamplerTypeInfo(TypeInfo):
"""采样器类型OpTypeSampler"""
pass
@dataclass
class PointerTypeInfo(TypeInfo):
"""指针类型"""
storage_class: int = 0
pointee_type_id: int = 0
resolved_pointee_type: Optional[TypeInfo] = None
# ============ Reflection Data Classes ============
@dataclass
class VariableInfo:
"""变量信息"""
id: int
name: Optional[str]
type_id: int
storage_class: int
binding: Optional[int] = None
descriptor_set: Optional[int] = None
resolved_type: Optional[TypeInfo] = None
@dataclass
class BufferInfo:
"""Buffer/Uniform完整信息"""
name: str
binding: int
descriptor_set: int
descriptor_type: str
struct_type: StructTypeInfo
variable_name: str
element_count: Optional[int] = None # 数组元素数量对于运行时数组为None
is_dynamic: bool = False # 是否支持动态大小
access_mode: str = "read_write" # 访问模式: read_only, write_only, read_write
@dataclass
class InstanceBufferInfo:
"""实例缓冲信息
用于描述通过 @instance_buffer 注释标记的 SSBO
包含用于实例化渲染的每实例数据。
必需字段验证:
- 结构体第一个成员必须是 'rect' (vec4)
- rect 偏移必须为 0
- 结构体大小必须是 16 字节的倍数
"""
name: str # 缓冲变量名称
struct_type_name: str # 结构体类型名称(如 GradientInstanceData
binding: int # binding 编号
set_number: int # descriptor set 编号
members: List[MemberInfo] # 结构体成员列表
struct_type: StructTypeInfo # 完整的结构体类型信息
has_rect: bool = True # 是否包含必需的 rect 字段(验证后始终为 True
rect_offset: int = 0 # rect 字段的偏移量(验证后始终为 0
total_size: int = 0 # 结构体总大小(字节)
@dataclass
class BindingInfo:
"""描述符绑定信息"""
binding: int
descriptor_type: str
stages: List[str]
descriptor_set: int = 0
name: Optional[str] = None
count: int = 1
# ============ Vertex Input Data Classes ============
@dataclass
class VertexAttribute:
"""顶点输入属性信息
用于描述顶点着色器中声明的输入变量in 修饰符)。
对应 GLSL 中的 layout(location = N) in type name;
Attributes:
name: 属性名称(去除 in_ 前缀后的名称)
location: 位置编号(对应 layout(location = N)
type_id: SPIR-V 类型 ID
cpp_type: C++ 类型名称
vk_format: Vulkan 格式枚举值
offset: 在顶点结构体中的偏移量(字节)
size: 属性大小(字节)
"""
name: str
location: int
type_id: int
cpp_type: str = "float"
vk_format: str = "VK_FORMAT_R32_SFLOAT"
offset: int = 0
size: int = 4
resolved_type: Optional[TypeInfo] = None
@dataclass
class VertexLayout:
"""顶点输入布局信息
包含顶点着色器中所有输入属性的集合,用于生成 C++ 顶点结构体
和 Vulkan 顶点输入状态描述。
Attributes:
attributes: 按 location 排序的顶点属性列表
stride: 顶点结构体的步长(字节)
struct_name: 生成的 C++ 结构体名称
"""
attributes: List[VertexAttribute] = field(default_factory=list)
stride: int = 0
struct_name: str = "Vertex"
@property
def attribute_count(self) -> int:
"""返回顶点属性数量"""
return len(self.attributes)
@property
def has_attributes(self) -> bool:
"""检查是否有顶点属性"""
return len(self.attributes) > 0
# Push Constants 布局常量(字节)
# 128 字节简化布局:
# - Header [0-15]: scale (vec2) + translate (vec2) - 16 字节(由系统填充)
# - Custom [16-127]: 用户自定义数据 - 112 字节
#
# PUSH_CONSTANT_HEADER_SIZE (16) 标记 Header 部分的大小
# 用户自定义参数 (custom_members/effect_members) 应该从 offset 16 开始
PUSH_CONSTANT_HEADER_SIZE = 16
@dataclass
class PushConstantInfo:
"""Push Constants 信息"""
name: str # 结构体名称
struct_type: StructTypeInfo # 完整结构体类型
base_offset: int = PUSH_CONSTANT_HEADER_SIZE # Header 部分结束的偏移量(字节)
base_info: Optional[PushConstantBaseInfo] = None # Header 部分信息
effect_members: List[MemberInfo] = field(default_factory=list) # 用户自定义部分的成员列表(为兼容性保留字段名)
@property
def custom_members(self) -> List[MemberInfo]:
"""用户自定义部分的成员列表effect_members 的别名)"""
return self.effect_members
@dataclass
class StagePushConstantInfo:
"""单个着色器阶段的 Push Constant 信息
用于描述顶点或片元着色器中 Push Constant 的 Custom 区域布局。
Custom 区域从偏移 16 字节开始Header 之后)。
Attributes:
stage: 着色器阶段,"vertex""fragment"
members: Custom 区域的成员列表(偏移量已转换为相对于 Custom 区域)
relative_offset: 相对于 Custom 区域的起始偏移从16开始
total_size: Custom 区域的总大小(字节)
alignment: 结构体对齐要求
"""
stage: str # "vertex" 或 "fragment"
members: List[MemberInfo] = field(default_factory=list)
relative_offset: int = PUSH_CONSTANT_HEADER_SIZE # 相对于 Custom 区域的起始偏移从16开始
total_size: int = 0
alignment: int = 4 # 结构体对齐要求
@property
def has_members(self) -> bool:
"""检查是否有 Custom 成员"""
return len(self.members) > 0
@property
def end_offset(self) -> int:
"""计算 Custom 区域的结束偏移量"""
return self.relative_offset + self.total_size
@dataclass
class CombinedPushConstantInfo:
"""合并后的双阶段 Push Constant 信息
用于分析顶点和片元着色器的 Push Constant 布局关系,
判断是否可以共享同一个结构体或需要分离处理。
布局模式:
1. 共享模式 (overlapping=True): 两个阶段使用相同的 Custom 区域布局
2. 分离模式 (overlapping=False): 两个阶段使用不同的 Custom 区域布局
Attributes:
vertex_info: 顶点着色器的 Push Constant 信息(可为 None
fragment_info: 片元着色器的 Push Constant 信息(可为 None
overlapping: 是否有重叠区域(共享模式)
shared_members: 共享的成员名称列表
"""
vertex_info: Optional[StagePushConstantInfo] = None
fragment_info: Optional[StagePushConstantInfo] = None
# 布局分析结果
overlapping: bool = False # 是否有重叠区域(共享模式)
shared_members: List[str] = field(default_factory=list) # 共享的成员名称
@property
def has_vertex(self) -> bool:
"""检查是否有顶点着色器的 Push Constant"""
return self.vertex_info is not None and self.vertex_info.has_members
@property
def has_fragment(self) -> bool:
"""检查是否有片元着色器的 Push Constant"""
return self.fragment_info is not None and self.fragment_info.has_members
@property
def is_shared_layout(self) -> bool:
"""检查是否使用共享布局模式"""
return self.overlapping and len(self.shared_members) > 0
@property
def vertex_only_members(self) -> List[str]:
"""获取仅在顶点着色器中存在的成员名称"""
if not self.has_vertex:
return []
vert_names = {m.name for m in self.vertex_info.members}
return [name for name in vert_names if name not in self.shared_members]
@property
def fragment_only_members(self) -> List[str]:
"""获取仅在片元着色器中存在的成员名称"""
if not self.has_fragment:
return []
frag_names = {m.name for m in self.fragment_info.members}
return [name for name in frag_names if name not in self.shared_members]
@dataclass
class SPIRVReflection:
"""完整的SPIR-V反射信息"""
types: Dict[int, TypeInfo] = field(default_factory=dict)
variables: Dict[int, VariableInfo] = field(default_factory=dict)
names: Dict[int, str] = field(default_factory=dict)
member_names: Dict[int, Dict[int, str]] = field(default_factory=dict)
decorations: Dict[int, Dict[int, Any]] = field(default_factory=dict)
member_decorations: Dict[int, Dict[int, Dict[int, Any]]] = field(default_factory=dict)
constants: Dict[int, Any] = field(default_factory=dict)
buffers: List[BufferInfo] = field(default_factory=list)
push_constant: Optional[PushConstantInfo] = None # Push Constants 信息
instance_buffer: Optional[InstanceBufferInfo] = None # 实例缓冲信息(@instance_buffer 标记)
vertex_layout: Optional[VertexLayout] = None # 顶点输入布局信息(仅顶点着色器)
entry_point: str = "main"
shader_stage: str = "compute"
# ============ Shader Metadata Classes ============
@dataclass
class ShaderMetadata:
"""着色器配置元数据"""
name: str
shader_type: str
entry_points: Dict[str, str]
output_header: str
namespace: Optional[str] = None
bindings: List[BindingInfo] = field(default_factory=list)
include_paths: List[Path] = field(default_factory=list)
defines: Dict[str, str] = field(default_factory=dict)
reflection: Optional[SPIRVReflection] = None
generate_buffer_helpers: bool = True # 生成基础辅助函数
generate_typed_buffers: bool = True # 生成类型安全的 buffer 包装器
generate_buffer_manager: bool = True # 生成 buffer 管理器类
@dataclass
class CompilationResult:
"""编译结果"""
source_file: Path
shader_type: str
spirv_data: bytes
entry_point: str
bindings: List[BindingInfo] = field(default_factory=list)
reflection: Optional[SPIRVReflection] = None