- Removed legacy push constant structures and functions for better clarity and maintainability. - Introduced new `text_push_constants_t` structure for text rendering with optimized layout. - Implemented dual stage push constant analysis to support separate layouts for vertex and fragment shaders. - Added functions to generate push constant structures and fill functions based on shader reflection. - Enhanced static checks for push constant layouts to ensure compatibility and correctness. - Updated templates to accommodate new dual stage push constant generation. - Added support detection for procedural vertex shaders based on push constant layout.
481 lines
15 KiB
Python
481 lines
15 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
数据类型定义
|
||
|
||
包含SPIR-V反射所需的所有数据类和枚举类型。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from dataclasses import dataclass, field
|
||
from enum import Enum, Flag, auto
|
||
from pathlib import Path
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
|
||
class ToolError(RuntimeError):
|
||
"""工具遇到致命错误时抛出"""
|
||
pass
|
||
|
||
|
||
# ============ Push Constant Base Members ============
|
||
|
||
class PushConstantBaseMember(Flag):
|
||
"""Push Constants Header 部分成员标志(简化布局)"""
|
||
NONE = 0
|
||
SCALE = auto() # vec2 scale (offset 0)
|
||
TRANSLATE = auto() # vec2 translate (offset 8)
|
||
|
||
# 预定义的组合
|
||
STANDARD_HEADER = SCALE | TRANSLATE # 标准 Header 布局
|
||
|
||
|
||
@dataclass
|
||
class PushConstantBaseMemberInfo:
|
||
"""Push Constants Header 部分单个成员信息"""
|
||
name: str # 成员名称(如 "scale")
|
||
member_flag: PushConstantBaseMember # 成员标志
|
||
offset: int # 实际偏移量
|
||
size: int # 实际大小
|
||
expected_offset: int # 期望偏移量
|
||
expected_size: int # 期望大小
|
||
type_name: str # C++ 类型名称
|
||
is_valid: bool = True # 是否与期望匹配
|
||
|
||
|
||
@dataclass
|
||
class PushConstantBaseInfo:
|
||
"""Push Constants Header 部分完整信息"""
|
||
members: List[PushConstantBaseMemberInfo] = field(default_factory=list)
|
||
present_flags: PushConstantBaseMember = PushConstantBaseMember.NONE
|
||
total_size: int = 0
|
||
is_standard_layout: bool = False # 是否为标准布局
|
||
|
||
# 期望布局定义(简化布局)
|
||
EXPECTED_LAYOUT = {
|
||
'uScale': (0, 8), # vec2 (offset, size)
|
||
'scale': (0, 8), # 兼容小写命名
|
||
'uTranslate': (8, 8), # vec2
|
||
'translate': (8, 8), # 兼容小写命名
|
||
}
|
||
EXPECTED_TOTAL_SIZE = 16
|
||
|
||
@property
|
||
def has_scale(self) -> bool:
|
||
"""检查是否包含 scale 成员"""
|
||
return any(m.name in ('uScale', 'scale') for m in self.members)
|
||
|
||
@property
|
||
def has_translate(self) -> bool:
|
||
"""检查是否包含 translate 成员"""
|
||
return any(m.name in ('uTranslate', 'translate') for m in self.members)
|
||
|
||
|
||
# ============ Enums ============
|
||
|
||
class StorageClass(Enum):
|
||
"""SPIR-V存储类"""
|
||
UNIFORM_CONSTANT = 0
|
||
INPUT = 1
|
||
UNIFORM = 2
|
||
OUTPUT = 3
|
||
WORKGROUP = 4
|
||
CROSS_WORKGROUP = 5
|
||
PRIVATE = 6
|
||
FUNCTION = 7
|
||
GENERIC = 8
|
||
PUSH_CONSTANT = 9
|
||
ATOMIC_COUNTER = 10
|
||
IMAGE = 11
|
||
STORAGE_BUFFER = 12
|
||
|
||
|
||
class BaseType(Enum):
|
||
"""基础数据类型"""
|
||
VOID = "void"
|
||
BOOL = "bool"
|
||
INT = "int"
|
||
UINT = "uint"
|
||
FLOAT = "float"
|
||
DOUBLE = "double"
|
||
|
||
|
||
# ============ Type Info Classes ============
|
||
|
||
@dataclass
|
||
class TypeInfo:
|
||
"""类型信息基类"""
|
||
id: int
|
||
name: Optional[str] = None
|
||
|
||
|
||
@dataclass
|
||
class ScalarTypeInfo(TypeInfo):
|
||
"""标量类型"""
|
||
base_type: BaseType = BaseType.FLOAT
|
||
bit_width: int = 32
|
||
|
||
|
||
@dataclass
|
||
class VectorTypeInfo(TypeInfo):
|
||
"""向量类型"""
|
||
component_type_id: int = 0
|
||
component_count: int = 4
|
||
resolved_component_type: Optional[TypeInfo] = None
|
||
|
||
|
||
@dataclass
|
||
class MatrixTypeInfo(TypeInfo):
|
||
"""矩阵类型"""
|
||
column_type_id: int = 0
|
||
column_count: int = 4
|
||
resolved_column_type: Optional[VectorTypeInfo] = None
|
||
|
||
@property
|
||
def row_count(self) -> int:
|
||
"""行数 = 列向量的分量数"""
|
||
if self.resolved_column_type:
|
||
return self.resolved_column_type.component_count
|
||
return 4
|
||
|
||
|
||
@dataclass
|
||
class ArrayTypeInfo(TypeInfo):
|
||
"""数组类型"""
|
||
element_type_id: int = 0
|
||
length: Optional[int] = None
|
||
is_runtime: bool = False
|
||
stride: Optional[int] = None
|
||
resolved_element_type: Optional[TypeInfo] = None
|
||
|
||
|
||
@dataclass
|
||
class MemberInfo:
|
||
"""结构体成员信息"""
|
||
index: int
|
||
name: str
|
||
type_id: int
|
||
offset: int = 0
|
||
resolved_type: Optional[TypeInfo] = None
|
||
matrix_stride: Optional[int] = None
|
||
array_stride: Optional[int] = None
|
||
is_row_major: bool = False
|
||
|
||
|
||
@dataclass
|
||
class StructTypeInfo(TypeInfo):
|
||
"""结构体类型"""
|
||
members: List[MemberInfo] = field(default_factory=list)
|
||
is_block: bool = False
|
||
is_buffer_block: bool = False
|
||
|
||
|
||
@dataclass
|
||
class ImageTypeInfo(TypeInfo):
|
||
"""图像类型(OpTypeImage)"""
|
||
sampled_type_id: int = 0
|
||
dimension: int = 0
|
||
depth: int = 0
|
||
arrayed: int = 0
|
||
ms: int = 0
|
||
sampled: int = 0
|
||
format: int = 0
|
||
|
||
|
||
@dataclass
|
||
class SampledImageTypeInfo(TypeInfo):
|
||
"""采样图像类型(OpTypeSampledImage)"""
|
||
image_type_id: int = 0
|
||
|
||
|
||
@dataclass
|
||
class SamplerTypeInfo(TypeInfo):
|
||
"""采样器类型(OpTypeSampler)"""
|
||
pass
|
||
|
||
|
||
@dataclass
|
||
class PointerTypeInfo(TypeInfo):
|
||
"""指针类型"""
|
||
storage_class: int = 0
|
||
pointee_type_id: int = 0
|
||
resolved_pointee_type: Optional[TypeInfo] = None
|
||
|
||
|
||
# ============ Reflection Data Classes ============
|
||
|
||
@dataclass
|
||
class VariableInfo:
|
||
"""变量信息"""
|
||
id: int
|
||
name: Optional[str]
|
||
type_id: int
|
||
storage_class: int
|
||
binding: Optional[int] = None
|
||
descriptor_set: Optional[int] = None
|
||
resolved_type: Optional[TypeInfo] = None
|
||
|
||
|
||
@dataclass
|
||
class BufferInfo:
|
||
"""Buffer/Uniform完整信息"""
|
||
name: str
|
||
binding: int
|
||
descriptor_set: int
|
||
descriptor_type: str
|
||
struct_type: StructTypeInfo
|
||
variable_name: str
|
||
|
||
element_count: Optional[int] = None # 数组元素数量(对于运行时数组为None)
|
||
is_dynamic: bool = False # 是否支持动态大小
|
||
access_mode: str = "read_write" # 访问模式: read_only, write_only, read_write
|
||
|
||
|
||
@dataclass
|
||
class InstanceBufferInfo:
|
||
"""实例缓冲信息
|
||
|
||
用于描述通过 @instance_buffer 注释标记的 SSBO,
|
||
包含用于实例化渲染的每实例数据。
|
||
|
||
必需字段验证:
|
||
- 结构体第一个成员必须是 'rect' (vec4)
|
||
- rect 偏移必须为 0
|
||
- 结构体大小必须是 16 字节的倍数
|
||
"""
|
||
name: str # 缓冲变量名称
|
||
struct_type_name: str # 结构体类型名称(如 GradientInstanceData)
|
||
binding: int # binding 编号
|
||
set_number: int # descriptor set 编号
|
||
members: List[MemberInfo] # 结构体成员列表
|
||
struct_type: StructTypeInfo # 完整的结构体类型信息
|
||
has_rect: bool = True # 是否包含必需的 rect 字段(验证后始终为 True)
|
||
rect_offset: int = 0 # rect 字段的偏移量(验证后始终为 0)
|
||
total_size: int = 0 # 结构体总大小(字节)
|
||
|
||
|
||
@dataclass
|
||
class BindingInfo:
|
||
"""描述符绑定信息"""
|
||
binding: int
|
||
descriptor_type: str
|
||
stages: List[str]
|
||
descriptor_set: int = 0
|
||
name: Optional[str] = None
|
||
count: int = 1
|
||
|
||
|
||
# ============ Vertex Input Data Classes ============
|
||
|
||
@dataclass
|
||
class VertexAttribute:
|
||
"""顶点输入属性信息
|
||
|
||
用于描述顶点着色器中声明的输入变量(in 修饰符)。
|
||
对应 GLSL 中的 layout(location = N) in type name;
|
||
|
||
Attributes:
|
||
name: 属性名称(去除 in_ 前缀后的名称)
|
||
location: 位置编号(对应 layout(location = N))
|
||
type_id: SPIR-V 类型 ID
|
||
cpp_type: C++ 类型名称
|
||
vk_format: Vulkan 格式枚举值
|
||
offset: 在顶点结构体中的偏移量(字节)
|
||
size: 属性大小(字节)
|
||
"""
|
||
name: str
|
||
location: int
|
||
type_id: int
|
||
cpp_type: str = "float"
|
||
vk_format: str = "VK_FORMAT_R32_SFLOAT"
|
||
offset: int = 0
|
||
size: int = 4
|
||
resolved_type: Optional[TypeInfo] = None
|
||
|
||
|
||
@dataclass
|
||
class VertexLayout:
|
||
"""顶点输入布局信息
|
||
|
||
包含顶点着色器中所有输入属性的集合,用于生成 C++ 顶点结构体
|
||
和 Vulkan 顶点输入状态描述。
|
||
|
||
Attributes:
|
||
attributes: 按 location 排序的顶点属性列表
|
||
stride: 顶点结构体的步长(字节)
|
||
struct_name: 生成的 C++ 结构体名称
|
||
"""
|
||
attributes: List[VertexAttribute] = field(default_factory=list)
|
||
stride: int = 0
|
||
struct_name: str = "Vertex"
|
||
|
||
@property
|
||
def attribute_count(self) -> int:
|
||
"""返回顶点属性数量"""
|
||
return len(self.attributes)
|
||
|
||
@property
|
||
def has_attributes(self) -> bool:
|
||
"""检查是否有顶点属性"""
|
||
return len(self.attributes) > 0
|
||
|
||
|
||
# Push Constants 布局常量(字节)
|
||
# 128 字节简化布局:
|
||
# - Header [0-15]: scale (vec2) + translate (vec2) - 16 字节(由系统填充)
|
||
# - Custom [16-127]: 用户自定义数据 - 112 字节
|
||
#
|
||
# PUSH_CONSTANT_HEADER_SIZE (16) 标记 Header 部分的大小
|
||
# 用户自定义参数 (custom_members/effect_members) 应该从 offset 16 开始
|
||
PUSH_CONSTANT_HEADER_SIZE = 16
|
||
|
||
|
||
@dataclass
|
||
class PushConstantInfo:
|
||
"""Push Constants 信息"""
|
||
name: str # 结构体名称
|
||
struct_type: StructTypeInfo # 完整结构体类型
|
||
base_offset: int = PUSH_CONSTANT_HEADER_SIZE # Header 部分结束的偏移量(字节)
|
||
base_info: Optional[PushConstantBaseInfo] = None # Header 部分信息
|
||
effect_members: List[MemberInfo] = field(default_factory=list) # 用户自定义部分的成员列表(为兼容性保留字段名)
|
||
|
||
@property
|
||
def custom_members(self) -> List[MemberInfo]:
|
||
"""用户自定义部分的成员列表(effect_members 的别名)"""
|
||
return self.effect_members
|
||
|
||
|
||
@dataclass
|
||
class StagePushConstantInfo:
|
||
"""单个着色器阶段的 Push Constant 信息
|
||
|
||
用于描述顶点或片元着色器中 Push Constant 的 Custom 区域布局。
|
||
Custom 区域从偏移 16 字节开始(Header 之后)。
|
||
|
||
Attributes:
|
||
stage: 着色器阶段,"vertex" 或 "fragment"
|
||
members: Custom 区域的成员列表(偏移量已转换为相对于 Custom 区域)
|
||
relative_offset: 相对于 Custom 区域的起始偏移(从16开始)
|
||
total_size: Custom 区域的总大小(字节)
|
||
alignment: 结构体对齐要求
|
||
"""
|
||
stage: str # "vertex" 或 "fragment"
|
||
members: List[MemberInfo] = field(default_factory=list)
|
||
relative_offset: int = PUSH_CONSTANT_HEADER_SIZE # 相对于 Custom 区域的起始偏移(从16开始)
|
||
total_size: int = 0
|
||
alignment: int = 4 # 结构体对齐要求
|
||
|
||
@property
|
||
def has_members(self) -> bool:
|
||
"""检查是否有 Custom 成员"""
|
||
return len(self.members) > 0
|
||
|
||
@property
|
||
def end_offset(self) -> int:
|
||
"""计算 Custom 区域的结束偏移量"""
|
||
return self.relative_offset + self.total_size
|
||
|
||
|
||
@dataclass
|
||
class CombinedPushConstantInfo:
|
||
"""合并后的双阶段 Push Constant 信息
|
||
|
||
用于分析顶点和片元着色器的 Push Constant 布局关系,
|
||
判断是否可以共享同一个结构体或需要分离处理。
|
||
|
||
布局模式:
|
||
1. 共享模式 (overlapping=True): 两个阶段使用相同的 Custom 区域布局
|
||
2. 分离模式 (overlapping=False): 两个阶段使用不同的 Custom 区域布局
|
||
|
||
Attributes:
|
||
vertex_info: 顶点着色器的 Push Constant 信息(可为 None)
|
||
fragment_info: 片元着色器的 Push Constant 信息(可为 None)
|
||
overlapping: 是否有重叠区域(共享模式)
|
||
shared_members: 共享的成员名称列表
|
||
"""
|
||
vertex_info: Optional[StagePushConstantInfo] = None
|
||
fragment_info: Optional[StagePushConstantInfo] = None
|
||
|
||
# 布局分析结果
|
||
overlapping: bool = False # 是否有重叠区域(共享模式)
|
||
shared_members: List[str] = field(default_factory=list) # 共享的成员名称
|
||
|
||
@property
|
||
def has_vertex(self) -> bool:
|
||
"""检查是否有顶点着色器的 Push Constant"""
|
||
return self.vertex_info is not None and self.vertex_info.has_members
|
||
|
||
@property
|
||
def has_fragment(self) -> bool:
|
||
"""检查是否有片元着色器的 Push Constant"""
|
||
return self.fragment_info is not None and self.fragment_info.has_members
|
||
|
||
@property
|
||
def is_shared_layout(self) -> bool:
|
||
"""检查是否使用共享布局模式"""
|
||
return self.overlapping and len(self.shared_members) > 0
|
||
|
||
@property
|
||
def vertex_only_members(self) -> List[str]:
|
||
"""获取仅在顶点着色器中存在的成员名称"""
|
||
if not self.has_vertex:
|
||
return []
|
||
vert_names = {m.name for m in self.vertex_info.members}
|
||
return [name for name in vert_names if name not in self.shared_members]
|
||
|
||
@property
|
||
def fragment_only_members(self) -> List[str]:
|
||
"""获取仅在片元着色器中存在的成员名称"""
|
||
if not self.has_fragment:
|
||
return []
|
||
frag_names = {m.name for m in self.fragment_info.members}
|
||
return [name for name in frag_names if name not in self.shared_members]
|
||
|
||
|
||
@dataclass
|
||
class SPIRVReflection:
|
||
"""完整的SPIR-V反射信息"""
|
||
types: Dict[int, TypeInfo] = field(default_factory=dict)
|
||
variables: Dict[int, VariableInfo] = field(default_factory=dict)
|
||
names: Dict[int, str] = field(default_factory=dict)
|
||
member_names: Dict[int, Dict[int, str]] = field(default_factory=dict)
|
||
decorations: Dict[int, Dict[int, Any]] = field(default_factory=dict)
|
||
member_decorations: Dict[int, Dict[int, Dict[int, Any]]] = field(default_factory=dict)
|
||
constants: Dict[int, Any] = field(default_factory=dict)
|
||
buffers: List[BufferInfo] = field(default_factory=list)
|
||
push_constant: Optional[PushConstantInfo] = None # Push Constants 信息
|
||
instance_buffer: Optional[InstanceBufferInfo] = None # 实例缓冲信息(@instance_buffer 标记)
|
||
vertex_layout: Optional[VertexLayout] = None # 顶点输入布局信息(仅顶点着色器)
|
||
entry_point: str = "main"
|
||
shader_stage: str = "compute"
|
||
|
||
|
||
# ============ Shader Metadata Classes ============
|
||
|
||
@dataclass
|
||
class ShaderMetadata:
|
||
"""着色器配置元数据"""
|
||
name: str
|
||
shader_type: str
|
||
entry_points: Dict[str, str]
|
||
output_header: str
|
||
namespace: Optional[str] = None
|
||
bindings: List[BindingInfo] = field(default_factory=list)
|
||
include_paths: List[Path] = field(default_factory=list)
|
||
defines: Dict[str, str] = field(default_factory=dict)
|
||
reflection: Optional[SPIRVReflection] = None
|
||
|
||
generate_buffer_helpers: bool = True # 生成基础辅助函数
|
||
generate_typed_buffers: bool = True # 生成类型安全的 buffer 包装器
|
||
generate_buffer_manager: bool = True # 生成 buffer 管理器类
|
||
|
||
|
||
@dataclass
|
||
class CompilationResult:
|
||
"""编译结果"""
|
||
source_file: Path
|
||
shader_type: str
|
||
spirv_data: bytes
|
||
entry_point: str
|
||
bindings: List[BindingInfo] = field(default_factory=list)
|
||
reflection: Optional[SPIRVReflection] = None |