Files
mirage_slang/shader_types.py
2025-06-07 03:39:44 +08:00

116 lines
2.7 KiB
Python

#!/usr/bin/env python3
"""
SDL3_GPU Slang Compiler - Type Definitions
数据类型和枚举定义
"""
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Dict
class ShaderStage(Enum):
    """Shader pipeline stage; each member's value is its lowercase name."""

    # Graphics pipeline stages
    VERTEX = "vertex"
    FRAGMENT = "fragment"
    # Compute stage
    COMPUTE = "compute"
class ResourceType(Enum):
    """Kind of resource a shader binds; values are lowercase snake_case names."""

    # Texture resources
    SAMPLED_TEXTURE = "sampled_texture"
    STORAGE_TEXTURE = "storage_texture"
    # Buffer resources
    STORAGE_BUFFER = "storage_buffer"
    UNIFORM_BUFFER = "uniform_buffer"
    # Sampler state
    SAMPLER = "sampler"
class TargetFormat(Enum):
    """Output shader format; each member's value is its lowercase name."""

    SPIRV = "spirv"  # SPIR-V
    DXIL = "dxil"  # DirectX Intermediate Language
    DXBC = "dxbc"  # DirectX bytecode
    MSL = "msl"  # Metal Shading Language
@dataclass
class Resource:
    """A named shader resource together with its binding locations.

    Every integer location field defaults to -1 and ``register`` defaults to
    an empty string. Presumably these defaults mean "not assigned for this
    backend"; confirm against the code that fills them in.
    """

    name: str
    type: ResourceType
    # Presumably the SPIR-V binding and descriptor set (confirm)
    binding: int = -1
    set: int = -1
    # Presumably the D3D register space and register name (confirm)
    space: int = -1
    register: str = ""
    # Presumably the Metal resource index (confirm)
    metal_index: int = -1
@dataclass
class ShaderInfo:
    """A shader entry point with its stage, resources, and source text."""

    stage: ShaderStage
    entry_point: str  # name of the entry-point function
    resources: List[Resource]
    source_code: str
# Data model classes
@dataclass
class FieldType:
    """Type information for a single field.

    Which optional attributes are set depends on ``kind``:
    ``scalar_type`` is set for every kind, ``element_count`` only for
    vectors, and ``row_count`` / ``column_count`` only for matrices.
    """

    kind: str  # 'scalar', 'vector', 'matrix'
    # One of 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
    scalar_type: Optional[str] = None
    element_count: Optional[int] = None  # vectors only
    row_count: Optional[int] = None  # matrices only
    column_count: Optional[int] = None  # matrices only

    @classmethod
    def from_dict(cls, data: Dict) -> 'FieldType':
        """Create a FieldType from a type-description dictionary.

        Recognized shapes (presumably Slang JSON reflection output; confirm
        against the caller):

        * ``{'kind': 'scalar', 'scalarType': S}``
        * ``{'kind': 'vector', 'elementCount': N, 'elementType': {'scalarType': S}}``
        * ``{'kind': 'matrix', 'rowCount': R, 'columnCount': C,
          'elementType': {'scalarType': S}}``

        Any other or missing ``kind`` falls back to a ``float32`` scalar so
        that callers keep working. A ``RuntimeWarning`` is emitted in that
        case, because the fallback misdescribes the field's real type.

        Args:
            data: The type-description dictionary.

        Returns:
            The corresponding FieldType.

        Raises:
            KeyError: If a key required by a recognized ``kind`` is missing.
        """
        import warnings

        kind = data.get('kind')
        if kind == 'scalar':
            return cls(kind=kind, scalar_type=data['scalarType'])
        if kind == 'vector':
            return cls(
                kind=kind,
                scalar_type=data['elementType']['scalarType'],
                element_count=data['elementCount'],
            )
        if kind == 'matrix':
            return cls(
                kind=kind,
                scalar_type=data['elementType']['scalarType'],
                row_count=data['rowCount'],
                column_count=data['columnCount'],
            )
        # Best-effort fallback kept for compatibility, but made visible
        # instead of silently reporting the wrong type/size.
        warnings.warn(
            f"Unsupported field type kind {kind!r}; falling back to float32 scalar",
            RuntimeWarning,
            stacklevel=2,
        )
        return cls(kind='scalar', scalar_type='float32')
@dataclass
class VertexField:
    """One vertex input field and its type, location, and semantic."""

    name: str
    type: FieldType
    location: int  # presumably the vertex input location (confirm)
    # Presumably an HLSL-style semantic name plus its index (confirm)
    semantic: str
    semantic_index: int
@dataclass
class UniformField:
    """One field inside a uniform buffer."""

    name: str
    type: FieldType
    # Presumably byte offset within the buffer and byte size (confirm units)
    offset: int
    size: int
@dataclass
class UniformBuffer:
    """A uniform buffer: its name, binding, and member fields."""

    name: str
    binding: int
    # default_factory gives each instance its own fresh list
    fields: List[UniformField] = field(default_factory=list)
@dataclass
class ShaderLayout:
    """Layout data for a shader: its vertex inputs and uniform buffers."""

    # Both lists start empty; default_factory avoids sharing one list
    # between instances.
    vertex_fields: List[VertexField] = field(default_factory=list)
    uniform_buffers: List[UniformBuffer] = field(default_factory=list)