重构80%

This commit is contained in:
2025-06-18 16:33:11 +08:00
parent 562f598a41
commit 6cc65e12b2
11 changed files with 619 additions and 563 deletions

View File

@@ -7,6 +7,7 @@
#include "shaders/mirage_text.hlsl.h"
#include "shaders/mirage_wireframe.hlsl.h"
#include "shaders/mirage_line.hlsl.h"
#include "shaders/mirage_line.shader.h"
template<typename Derived>
void compute_rect_vertices(const Eigen::MatrixBase<Derived>& in_pos,

View File

@@ -5,53 +5,19 @@ Slang Compiler - Code Generator
"""
from pathlib import Path
from code_generator_helper import *
from global_vars import global_vars
from indent_manager import IndentManager
from shader_types import *
from shader_reflection_type import *
class CodeGenerator:
"""C++绑定代码生成器"""
def generate_binding_functions(self, binding_infos: List[ShaderInfo], output_path: str) -> None:
def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str) -> None:
"""生成C++绑定函数的入口方法"""
self._generate_cpp_bindings(binding_infos, output_path)
def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
"""生成Uniform缓冲区结构体"""
writer.write("// Uniform buffer structures")
for buffer in uniform_buffers:
struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer'
# 计算总大小和对齐要求
total_size = 0
max_alignment = 16 # GPU通常要求16字节对齐
with writer.block(f'struct {struct_name}', '};'):
for i, field in enumerate(buffer.fields):
# 检查是否需要填充
if field.offset > total_size:
padding_size = field.offset - total_size
writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding")
total_size = field.offset
# 生成字段声明
declaration = get_c_type_declaration(field.type, field.name)
writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")
total_size = field.offset + field.size
# 确保结构体大小正确对齐
aligned_size = get_aligned_size(total_size, max_alignment)
if aligned_size > total_size:
writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")
writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})")
writer.write()
def _generate_cpp_bindings(self, binding_infos: List[ShaderInfo], output_path: str) -> None:
def _generate_cpp_bindings(self, binding_infos: ShaderInfos, output_path: str) -> None:
"""生成C++绑定函数"""
output_file = Path(output_path)
# 尝试创建输出目录
@@ -61,7 +27,7 @@ class CodeGenerator:
writer = IndentManager(file)
self._write_complete_file(writer, binding_infos)
def _write_complete_file(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
def _write_complete_file(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
"""写入完整的文件内容"""
self._write_header(writer)
@@ -81,19 +47,19 @@ class CodeGenerator:
writer.write(header)
writer.write()
def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
"""写入ShaderBindings类"""
self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
# self._generate_uniform_structs(writer, global_vars.vertex_layout.uniform_buffers)
self._write_blob(writer, binding_infos)
self._write_public_methods(writer, binding_infos)
self._write_get_pipeline_desc_method(writer, binding_infos)
def _write_blob(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
def _write_blob(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
"""写入public部分的结构体定义"""
# 写入二进制内容
for info in binding_infos:
entry_point = info.entry_point
for stage, info in binding_infos.stages.items():
entry_point = info.get_entry_name()
blob_data = info.blob
with writer.block(f'static constexpr std::array<uint8_t, {len(blob_data)}> {global_vars.source_file_name}_{entry_point}_blob =', '};'):
@@ -108,69 +74,100 @@ class CodeGenerator:
writer.write()
writer.write()
def _write_public_methods(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
def _write_public_methods(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
"""写入公共方法"""
self._write_get_shader_info_method(writer, binding_infos)
def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
"""写入getShaderInfo方法"""
with writer.indent(f'static sg_shader_desc get_{global_vars.source_file_name}_shader_desc()'' {', '}'):
writer.write('sg_shader_desc desc = {};')
writer.write(f'desc.label = "{global_vars.source_file_name}";')
writer.write(f'desc.label = "{global_vars.source_file_name}_shader_desc";')
writer.write('// 顶点布局')
with writer.indent('{', '}'):
for i, vertex_field in enumerate(global_vars.layout.vertex_fields):
writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semantic}";')
writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semantic_index};')
for i, vertex_field in enumerate(binding_infos.vertex_layout):
writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semanticName}";')
writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semanticIndex or 0};')
for info in binding_infos:
entry_point = info.entry_point
for stage, info in binding_infos.stages.items():
entry_point = info.get_entry_name()
entry_point_name = f'{global_vars.source_file_name}_{entry_point}_blob'
stage_name = info.stage.name
stage_name = info.get_stage().name
writer.write(f'// {stage_name}')
with writer.indent('{','}'):
writer.write(f'desc.{stage_name.lower()}_func.bytecode = SG_RANGE({entry_point_name})')
writer.write(f'desc.{stage_name.lower()}_func.bytecode = SG_RANGE({entry_point_name});')
writer.write(f'desc.{stage_name.lower()}_func.entry = "{entry_point}";')
writer.write('// 资源绑定')
resource_index_map = {}
# 处理资源绑定
for resource in info.resources:
binding_index = resource.binding_index
param_index = resource_index_map.get(resource.type, 0)
resource_index_map[resource.type] = param_index + 1
self._write_resource_binding(writer, info.parameters)
if resource.type == ResourceType.UNIFORM_BUFFER:
uniform_info = resource.uniform_data
writer.write(f'desc.uniform_blocks[{param_index}].name = "{resource.name}";')
writer.write(f'desc.uniform_blocks[{param_index}].size = {uniform_info.size};')
writer.write(f'desc.uniform_blocks[{param_index}].hlsl_register_b_n = {uniform_info.binding};')
writer.write(f'desc.uniform_blocks[{param_index}].msl_buffer_n = {uniform_info.binding};')
writer.write(f'desc.uniform_blocks[{param_index}].wgsl_group0_binding_n = {uniform_info.binding};')
def _write_resource_binding(self, writer: IndentManager, parameters: List[Parameter]) -> None:
resource_index = {}
target = global_vars.target
glsl_block_index = {}
elif resource.type == ResourceType.SAMPLED_TEXTURE:
writer.write(f'desc.images[{param_index}].stage = SG_SHADERSTAGE_{stage_name.upper()}')
writer.write(f'desc.images[{param_index}].name = "{resource.name}";')
writer.write(f'desc.sampled_textures[{param_index}].name = "{resource.name}";')
for p in parameters:
binding_kind = p.get_binding_kind()
index = resource_index.get(binding_kind, 0)
resource_index[binding_kind] = index + 1
stage_name = p.stage.name.upper()
if binding_kind == BindingKind.UNIFORM:
t = f'desc.uniform_blocks[{index}]'
writer.write(f'{t}.stage = SG_SHADERSTAGE_{stage_name};')
writer.write(f'{t}.size = {p.get_byte_size()};')
if target == TargetFormat.GLSL:
glsl_index = glsl_block_index.get(binding_kind, 0)
glsl_block_index[binding_kind] = glsl_index + 1
writer.write(f'{t}.glsl_uniforms[{glsl_index}].type = ;')
writer.write(f'{t}.glsl_uniforms[{glsl_index}].array_count = 1;')
writer.write(f'{t}.glsl_uniforms[{glsl_index}].glsl_name = "{p.name}";')
elif target == TargetFormat.DXBC:
writer.write(f'{t}.hlsl_register_b_n = {p.get_register_index()};')
elif resource.type == ResourceType.STORAGE_BUFFER:
writer.write(f'desc.storage_buffers[{param_index}].name = "{resource.name}";')
elif resource.type == ResourceType.SAMPLER:
writer.write(f'desc.samplers[{param_index}].name = "{resource.name}";')
def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
"""写入getPipelineDesc方法"""
with writer.indent(f'static sg_pipeline_desc get_{global_vars.source_file_name}_pipeline_desc(sg_shader shader, sg_pixel_format pixel_format, int32_t sample_count)'' {', '}'):
function_params = ('\n\t\t\tsg_shader shader, \n'
'\t\t\tsg_pixel_format pixel_format, \n'
'\t\t\tint32_t sample_count = 1, \n'
'\t\t\tsg_primitive_type primitive_type = SG_PRIMITIVETYPE_TRIANGLES, \n'
'\t\t\tsg_cull_mode cull_mode = SG_CULLMODE_NONE\n')
with writer.indent(f'static sg_pipeline_desc get_{global_vars.source_file_name}_pipeline_desc({function_params}\t)'' {', '}'):
writer.write('sg_pipeline_desc desc = {};')
writer.write(f'desc.label = "{global_vars.source_file_name}_pipeline";')
writer.write('desc.shader = shader;')
writer.write('desc.index_type = SG_INDEXTYPE_UINT32;')
writer.write()
# 处理顶点缓冲区大小
writer.write(f'desc.layout.buffers[0].stride = {binding_infos.vertex_size};')
writer.write('desc.layout.buffers[0].step_func = SG_VERTEXSTEP_PER_VERTEX;')
writer.write('desc.layout.buffers[0].step_rate = 1;')
writer.write()
# 处理顶点输入布局
for i, vertex_field in enumerate(global_vars.layout.vertex_fields):
writer.write(f'// {vertex_field.semantic}')
writer.write(f'desc.layout.attrs[{i}].buffer_index = {vertex_field.location}')
writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.offset}')
writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.type.scalar_type}')
for i, vertex_field in enumerate(binding_infos.vertex_layout):
writer.write(f'// {vertex_field.semanticName} {vertex_field.semanticIndex or ''}')
writer.write(f'desc.layout.attrs[{i}].buffer_index = 0;')
writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.binding.offset};')
writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.get_sg_format()};')
writer.write()
writer.write('desc.primitive_type = primitive_type;')
writer.write('desc.cull_mode = cull_mode;')
writer.write('desc.face_winding = SG_FACEWINDING_CW;')
writer.write('desc.depth.write_enabled = false;')
writer.write('desc.depth.compare = SG_COMPAREFUNC_NEVER;')
writer.write('desc.depth.pixel_format = SG_PIXELFORMAT_NONE;')
writer.write('desc.colors[0].blend.enabled = true;')
writer.write('desc.colors[0].blend.src_factor_rgb = SG_BLENDFACTOR_SRC_ALPHA;')
writer.write('desc.colors[0].blend.dst_factor_rgb = SG_BLENDFACTOR_ONE_MINUS_SRC_ALPHA;')
writer.write('desc.colors[0].pixel_format = pixel_format;')
writer.write('desc.colors[0].write_mask = SG_COLORMASK_RGBA;')
writer.write('desc.sample_count = sample_count;')
writer.write(f'desc.label = "{global_vars.source_file_name}_pipeline_desc";')
writer.write('return desc;')

View File

@@ -1,6 +1,6 @@
from typing import Tuple
from shader_types import FieldType
from shader_reflection_type import Field, TypeKind, ScalarType
format_mapping = {
('int32', 1): 'SG_VERTEXFORMAT_INVALID',
@@ -29,57 +29,56 @@ format_mapping = {
# 类型大小映射
type_sizes = {
'int32': 4,
'uint32': 4,
'float32': 4,
'int8': 1,
'uint8': 1,
'int16': 2,
'uint16': 2,
'float16': 2,
ScalarType.INT32: 4,
ScalarType.UINT32: 4,
ScalarType.FLOAT32: 4,
ScalarType.INT8: 1,
ScalarType.UINT8: 1,
ScalarType.INT16: 2,
ScalarType.UINT16: 2,
ScalarType.FLOAT16: 2, # 注意C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
}
def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
# C类型映射表
c_type_mapping = {
ScalarType.INT32: 'int32_t',
ScalarType.UINT32: 'uint32_t',
ScalarType.FLOAT32: 'float',
ScalarType.INT8: 'int8_t',
ScalarType.UINT8: 'uint8_t',
ScalarType.INT16: 'int16_t',
ScalarType.UINT16: 'uint16_t',
ScalarType.FLOAT16: 'mirage_float16', # 注意C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
}
# 对齐要求映射表
alignment_mapping = {
ScalarType.INT32: 4,
ScalarType.UINT32: 4,
ScalarType.FLOAT32: 4,
ScalarType.INT8: 1,
ScalarType.UINT8: 1,
ScalarType.INT16: 2,
ScalarType.UINT16: 2,
ScalarType.FLOAT16: 2,
}
def get_type_info(field_type: Field) -> Tuple[str, int, int]:
"""获取类型信息C类型名称、元素数量、字节大小
Returns:
Tuple[str, int, int]: (C类型名称, 元素数量, 总字节大小)
"""
# C类型映射表
c_type_mapping = {
'int32': 'int32_t',
'uint32': 'uint32_t',
'float32': 'float',
'int8': 'int8_t',
'uint8': 'uint8_t',
'int16': 'int16_t',
'uint16': 'uint16_t',
'float16': 'mirage_float16', # 注意C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
}
kind = field_type.type.kind
scalar_type = field_type.type.scalarType
element_count = field_type.type.elementCount
c_type = c_type_mapping.get(scalar_type, 'float')
# 对齐要求映射表
alignment_mapping = {
'int32': 4,
'uint32': 4,
'float32': 4,
'int8': 1,
'uint8': 1,
'int16': 2,
'uint16': 2,
'float16': 2,
}
if field_type.kind == 'scalar':
scalar_type = field_type.scalar_type or 'float32'
c_type = c_type_mapping.get(scalar_type, 'float')
if kind == TypeKind.SCALAR:
size = type_sizes.get(scalar_type, 4)
return c_type, 1, size
elif field_type.kind == 'vector':
scalar_type = field_type.scalar_type or 'float32'
c_type = c_type_mapping.get(scalar_type, 'float')
element_count = field_type.element_count or 4
elif kind == TypeKind.VECTOR:
# 验证向量元素数量的合法性
if element_count not in [1, 2, 3, 4]:
raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")
@@ -93,12 +92,9 @@ def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
return c_type, element_count, size
elif field_type.kind == 'matrix':
scalar_type = field_type.scalar_type or 'float32'
c_type = c_type_mapping.get(scalar_type, 'float')
row_count = field_type.row_count or 4
column_count = field_type.column_count or 4
elif kind == TypeKind.MATRIX:
row_count = field_type.type.rowCount or 4
column_count = field_type.type.columnCount or 4
# 验证矩阵维度的合法性
if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]:
@@ -125,8 +121,7 @@ def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.")
return 'float', 1, 4
def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
def get_c_type_declaration(field_type: Field, field_name: str) -> str:
"""生成C类型声明
Args:
@@ -137,11 +132,12 @@ def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
str: 完整的C类型声明
"""
c_type, count, _ = get_type_info(field_type)
kind = field_type.type.kind
if field_type.kind == 'scalar':
if kind == TypeKind.SCALAR:
return f"{c_type} {field_name}"
elif field_type.kind == 'vector':
elif kind == TypeKind.VECTOR:
# 对于向量,可以选择使用数组或专门的向量类型
if count == 1:
return f"{c_type} {field_name}"
@@ -152,9 +148,9 @@ def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
# 选项2: 使用对齐的结构体(如果需要的话)
# return f"struct {{ {c_type} data[{count}]; }} {field_name}"
elif field_type.kind == 'matrix':
rows = field_type.row_count or 4
cols = field_type.column_count or 4
elif kind == TypeKind.MATRIX:
rows = field_type.type.rowCount or 4
cols = field_type.type.columnCount or 4
# 矩阵使用二维数组(列主序)
return f"{c_type} {field_name}[{cols}][{rows}]"

View File

@@ -8,7 +8,7 @@ from typing import List, Dict
from code_generator import CodeGenerator
from shader_parser import ShaderParser
from shader_types import ShaderInfo
from shader_reflection_type import ShaderReflection, ShaderInfos
class SlangCompiler:
@@ -16,10 +16,10 @@ class SlangCompiler:
self.parser = ShaderParser()
self.code_generator = CodeGenerator()
def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
def parse_slang_shader(self) -> List[ShaderReflection]:
"""解析Slang着色器源码提取资源信息"""
return self.parser.parse_slang_shader()
def generate_binding_functions(self, binding_infos: List[ShaderInfo], output_path: str):
def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str):
"""生成C/C++绑定函数"""
self.code_generator.generate_binding_functions(binding_infos, output_path)

View File

@@ -4,21 +4,22 @@ Slang Compiler - Command Generation
"""
from exe_finder import slangc_path
from global_vars import global_vars
from shader_types import TargetFormat, ShaderStage
from shader_reflection_type import *
def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, output_path: str):
def make_cmd(source_file: str, target: TargetFormat, stage: Stage, entry_point: str, output_path: str):
"""生成编译命令"""
target_flag = {
TargetFormat.SPIRV: 'spirv',
TargetFormat.GLSL: 'glsl',
TargetFormat.DXBC: 'dxbc',
TargetFormat.MSL: 'metal',
TargetFormat.HLSL_DX11: 'hlsl',
}[target]
stage_flag = {
ShaderStage.VERTEX: 'vertex',
ShaderStage.FRAGMENT: 'fragment',
ShaderStage.COMPUTE: 'compute'
Stage.VERTEX: 'vertex',
Stage.FRAGMENT: 'fragment',
Stage.COMPUTE: 'compute'
}[stage]
cmd = [

View File

@@ -1,11 +1,11 @@
from shader_types import *
from shader_reflection_type import *
class GlobalVars:
source_file = ''
source_file_name = ''
source_path = ''
output_dir = ''
target: TargetFormat
layout: ShaderLayout
target = TargetFormat.DXBC
global_vars = GlobalVars()

View File

@@ -31,6 +31,10 @@ def main():
global_vars.output_dir = os.path.abspath(args.output_dir)
global_vars.target = TargetFormat(args.target)
shader_infos = ShaderInfos(
stages={}
)
# 仅保留路径部分
include_dirs = [
global_vars.source_path,
@@ -45,12 +49,8 @@ def main():
# 解析着色器
print(f"**Parsing** {args.input}...")
shaders = compiler.parse_slang_shader()
# 编译每个入口点
shader_infos = []
for name, shader_info in shaders.items():
shader_infos.append(shader_info)
for shader_info in shaders:
shader_infos.add_shader_info(shader_info)
binding_output_file_pathname = os.path.abspath(args.output_dir)
binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{global_vars.source_file_name}.shader.h")

View File

@@ -1,85 +0,0 @@
import json
from typing import List, Dict
from shader_types import *
from shader_types import UniformBuffer
class ShaderLayoutParser:
"""解析JSON数据并提取到类对象"""
def parse(self, json_file: str) -> ShaderLayout:
"""解析JSON文件并返回ShaderLayout对象"""
with open(json_file, 'r') as f:
json_data = json.load(f)
layout = ShaderLayout()
layout.vertex_fields = self.parse_vertex_input(json_data)
layout.uniform_buffers = self.parse_uniform_buffers(json_data)
return layout
@staticmethod
def parse_vertex_input(json_data: Dict) -> List[VertexField]:
"""解析顶点输入字段"""
vertex_fields = []
entry_points = json_data.get('entryPoints', [])
for entry in entry_points:
if entry.get('stage') == 'vertex':
parameters = entry.get('parameters', [])
offset = entry.get('offset', 0)
for param in parameters:
if param.get('name') == 'input' and param.get('stage') == 'vertex':
fields = param.get('type', {}).get('fields', [])
for field in fields:
vertex_field = VertexField(
name=field['name'],
type=FieldType.from_dict(field['type']),
location=field['binding']['index'],
semantic=field.get('semanticName', ''),
semantic_index=field.get('semanticIndex', 0)
)
vertex_fields.append(vertex_field)
return vertex_fields
@staticmethod
def parse_uniform_buffer(param: Dict) -> UniformBuffer | None:
binding = param.get('binding', {})
type_info = param.get('type', {})
kind = type_info.get('kind', '')
if kind == 'constantBuffer' or kind == 'parameterBlock':
buffer = UniformBuffer(
name=param['name'],
binding=binding['index']
)
# 解析缓冲区字段
element_type = param.get('type', {}).get('elementType', {})
if element_type.get('kind') == 'struct':
fields = element_type.get('fields', [])
for field in fields:
uniform_field = UniformField(
name=field['name'],
type=FieldType.from_dict(field['type']),
offset=field['binding']['offset'],
size=field['binding']['size']
)
buffer.size += uniform_field.size
buffer.fields.append(uniform_field)
return buffer
return None
@staticmethod
def parse_uniform_buffers(json_data: Dict) -> List[UniformBuffer]:
"""解析Uniform缓冲区"""
uniform_buffers = []
parameters = json_data.get('parameters', [])
for param in parameters:
buffer = ShaderLayoutParser.parse_uniform_buffer(param)
uniform_buffers.append(buffer)
return uniform_buffers

View File

@@ -13,12 +13,11 @@ from typing import List, Dict, Optional
from compiler_cmd import make_cmd
from global_vars import global_vars
from shader_layout import ShaderLayoutParser
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo, UniformBuffer, FieldType, ResourceSubType
from shader_reflection_type import *
class ShaderParser:
def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
def parse_slang_shader(self) -> List[ShaderReflection]:
"""解析Slang着色器源码提取资源信息"""
with open(global_vars.source_file, 'r', encoding='utf-8') as f:
@@ -28,7 +27,7 @@ class ShaderParser:
entry_points = self._find_entry_points(source)
print(f"Found potential entry points: {entry_points}")
shaders = {}
shaders = []
# 为每个入口点单独进行完整编译和反射
for entry_name, stage in entry_points.items():
@@ -39,22 +38,22 @@ class ShaderParser:
)
if shader_info:
shaders[entry_name] = shader_info
shaders.append(shader_info)
else:
print(f"Failed to process entry point: {entry_name}")
return shaders
def _find_entry_points(self, source: str) -> Dict[str, ShaderStage]:
def _find_entry_points(self, source: str) -> Dict[str, Stage]:
"""在源码中查找入口点函数"""
entry_points = {}
# 1. 查找带有Slang属性的函数
attribute_patterns = [
(r'$$shader\s*$\s*["\']vertex["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.VERTEX),
(r'$$shader\s*$\s*["\']fragment["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.FRAGMENT),
(r'$$shader\s*$\s*["\']pixel["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.FRAGMENT),
(r'$$shader\s*$\s*["\']compute["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.COMPUTE),
(r'$$shader\s*$\s*["\']vertex["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.VERTEX),
(r'$$shader\s*$\s*["\']fragment["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.FRAGMENT),
(r'$$shader\s*$\s*["\']pixel["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.FRAGMENT),
(r'$$shader\s*$\s*["\']compute["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.COMPUTE),
]
for pattern, stage in attribute_patterns:
@@ -66,29 +65,29 @@ class ShaderParser:
# 2. 查找常见的命名约定
common_entry_points = {
# Vertex shaders
'vertex_main': ShaderStage.VERTEX,
'vertexMain': ShaderStage.VERTEX,
'vert_main': ShaderStage.VERTEX,
'vs_main': ShaderStage.VERTEX,
'vertex_shader': ShaderStage.VERTEX,
'VS': ShaderStage.VERTEX,
'vertex_main': Stage.VERTEX,
'vertexMain': Stage.VERTEX,
'vert_main': Stage.VERTEX,
'vs_main': Stage.VERTEX,
'vertex_shader': Stage.VERTEX,
'VS': Stage.VERTEX,
# Fragment shaders
'fragment_main': ShaderStage.FRAGMENT,
'fragmentMain': ShaderStage.FRAGMENT,
'frag_main': ShaderStage.FRAGMENT,
'pixel_main': ShaderStage.FRAGMENT,
'ps_main': ShaderStage.FRAGMENT,
'fragment_shader': ShaderStage.FRAGMENT,
'PS': ShaderStage.FRAGMENT,
'fragment_main': Stage.FRAGMENT,
'fragmentMain': Stage.FRAGMENT,
'frag_main': Stage.FRAGMENT,
'pixel_main': Stage.FRAGMENT,
'ps_main': Stage.FRAGMENT,
'fragment_shader': Stage.FRAGMENT,
'PS': Stage.FRAGMENT,
# Compute shaders
'compute_main': ShaderStage.COMPUTE,
'computeMain': ShaderStage.COMPUTE,
'comp_main': ShaderStage.COMPUTE,
'cs_main': ShaderStage.COMPUTE,
'compute_shader': ShaderStage.COMPUTE,
'CS': ShaderStage.COMPUTE,
'compute_main': Stage.COMPUTE,
'computeMain': Stage.COMPUTE,
'comp_main': Stage.COMPUTE,
'cs_main': Stage.COMPUTE,
'compute_shader': Stage.COMPUTE,
'CS': Stage.COMPUTE,
}
# 查找这些函数名是否在源码中存在
@@ -112,22 +111,22 @@ class ShaderParser:
# 顶点着色器关键词
if any(keyword in func_lower for keyword in ['vert', 'vs', 'vertex']):
entry_points[func_name] = ShaderStage.VERTEX
entry_points[func_name] = Stage.VERTEX
print(f"Inferred vertex shader: {func_name}")
# 片元着色器关键词
elif any(keyword in func_lower for keyword in ['frag', 'pixel', 'ps', 'fs', 'fragment']):
entry_points[func_name] = ShaderStage.FRAGMENT
entry_points[func_name] = Stage.FRAGMENT
print(f"Inferred fragment shader: {func_name}")
# 计算着色器关键词
elif any(keyword in func_lower for keyword in ['comp', 'cs', 'compute']):
entry_points[func_name] = ShaderStage.COMPUTE
entry_points[func_name] = Stage.COMPUTE
print(f"Inferred compute shader: {func_name}")
# 4. 查找带有HLSL语义的函数
semantic_patterns = [
(r'(\w+)\s*$[^)]*$\s*:\s*SV_Position', ShaderStage.VERTEX),
(r'(\w+)\s*$[^)]*$\s*:\s*SV_Target', ShaderStage.FRAGMENT),
(r'(\w+)\s*$[^)]*$\s*:\s*POSITION', ShaderStage.VERTEX),
(r'(\w+)\s*$[^)]*$\s*:\s*SV_Position', Stage.VERTEX),
(r'(\w+)\s*$[^)]*$\s*:\s*SV_Target', Stage.FRAGMENT),
(r'(\w+)\s*$[^)]*$\s*:\s*POSITION', Stage.VERTEX),
]
for pattern, stage in semantic_patterns:
@@ -141,13 +140,13 @@ class ShaderParser:
# 5. 如果没找到任何入口点查找main函数
if not entry_points:
if re.search(r'\bmain\s*\(', source, re.IGNORECASE):
entry_points['main'] = ShaderStage.VERTEX
entry_points['main'] = Stage.VERTEX
print("Found main function, assuming vertex shader")
print(f"Total entry points found: {len(entry_points)}")
return entry_points
def _compile_and_reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str) -> Optional[ShaderInfo]:
def _compile_and_reflect_entry_point(self, entry_name: str, stage: Stage, source: str) -> Optional[ShaderReflection]:
"""为单个入口点进行完整编译和反射"""
# 创建临时输出文件
@@ -180,10 +179,6 @@ class ShaderParser:
if result.returncode == 0:
# 读取反射数据
if os.path.exists(reflection_path):
if stage == ShaderStage.VERTEX:
shader_layout_ = ShaderLayoutParser()
global_vars.layout = shader_layout_.parse(reflection_path)
with open(reflection_path, 'r', encoding='utf-8') as f:
reflection_json = f.read()
print(f"Reflection JSON length: {len(reflection_json)}")
@@ -191,17 +186,15 @@ class ShaderParser:
if reflection_json.strip():
try:
reflection = json.loads(reflection_json)
shader_info = self._create_shader_info_from_reflection(
reflection, entry_name, stage, source
)
shader_reflection = ShaderReflection.from_dict(reflection)
# 读取编译后的二进制数据
with open(temp_output_path, 'rb') as f:
shader_binary = f.read()
# 写入shader_info.blob
shader_info.blob = shader_binary
shader_reflection.blob = shader_binary
return shader_info
return shader_reflection
except json.JSONDecodeError as e:
print(f"JSON parsing error: {e}")
print(f"Raw JSON: {reflection_json[:500]}...")
@@ -217,161 +210,4 @@ class ShaderParser:
if os.path.exists(temp_file):
os.unlink(temp_file)
return self._create_shader_info_manual(entry_name, stage, source)
def _create_shader_info_from_reflection(self, reflection: dict, entry_name: str,
stage: ShaderStage, source: str) -> ShaderInfo:
"""从反射数据创建ShaderInfo"""
print(f"Processing reflection data for {entry_name}")
print(f"Reflection keys: {list(reflection.keys())}")
shader_info = ShaderInfo(
stage=stage,
entry_point=entry_name,
resources=[],
blob=b'', # 二进制数据将在编译后填充
)
# Slang反射数据的可能结构
# 尝试不同的数据结构
entry_point_data = None
# 方法1: 直接在根级别查找
if 'parameters' in reflection:
entry_point_data = reflection
# 方法2: 在entryPoints数组中查找
elif 'entryPoints' in reflection:
for ep in reflection['entryPoints']:
if ep.get('name') == entry_name:
entry_point_data = ep
break
# 方法3: 在modules中查找
elif 'modules' in reflection:
for module in reflection['modules']:
if 'entryPoints' in module:
for ep in module['entryPoints']:
if ep.get('name') == entry_name:
entry_point_data = ep
break
if entry_point_data:
print(f"Found entry point data: {list(entry_point_data.keys())}")
# 解析资源参数
parameters = entry_point_data.get('parameters', [])
print(f"Found {len(parameters)} parameters")
for param in parameters:
print(f"Processing parameter: {param}")
resource = self._parse_resource(param)
if resource:
shader_info.resources.append(resource)
print(f"Added resource: {resource.name} ({resource.type.value})")
else:
print("No entry point data found in reflection, using manual parsing")
# 使用手动解析作为fallback
manual_resources = self._extract_resources_from_source(source)
shader_info.resources.extend(manual_resources)
print(f"Shader {entry_name} has {len(shader_info.resources)} resources")
return shader_info
def _create_shader_info_manual(self, entry_name: str, stage: ShaderStage, source: str) -> ShaderInfo:
"""手动创建ShaderInfofallback方法"""
print(f"Creating shader info manually for {entry_name}")
shader_info = ShaderInfo(
stage=stage,
entry_point=entry_name,
resources=[],
blob=b'',
)
# 手动解析资源
resources = self._extract_resources_from_source(source)
shader_info.resources.extend(resources)
print(f"Manual parsing found {len(resources)} resources")
return shader_info
def _extract_resources_from_source(self, source: str) -> List[Resource]:
"""从源码中提取资源声明"""
resources = []
# 资源声明的正则表达式模式
patterns = {
# Texture resources
ResourceType.SAMPLED_TEXTURE: [
r'Texture2D\s*(?:<[^>]*>)?\s+(\w+)',
r'Texture3D\s*(?:<[^>]*>)?\s+(\w+)',
r'TextureCube\s*(?:<[^>]*>)?\s+(\w+)',
],
ResourceType.STORAGE_TEXTURE: [
r'RWTexture2D\s*(?:<[^>]*>)?\s+(\w+)',
r'RWTexture3D\s*(?:<[^>]*>)?\s+(\w+)',
],
# Buffer resources
ResourceType.STORAGE_BUFFER: [
r'RWStructuredBuffer\s*<[^>]*>\s+(\w+)',
r'RWByteAddressBuffer\s+(\w+)',
r'StructuredBuffer\s*<[^>]*>\s+(\w+)',
r'ByteAddressBuffer\s+(\w+)',
],
ResourceType.UNIFORM_BUFFER: [
r'ConstantBuffer\s*<[^>]*>\s+(\w+)',
r'cbuffer\s+(\w+)',
],
ResourceType.SAMPLER: [
r'SamplerState\s+(\w+)',
r'SamplerComparisonState\s+(\w+)',
]
}
for resource_type, type_patterns in patterns.items():
for pattern in type_patterns:
matches = re.findall(pattern, source, re.IGNORECASE)
for match in matches:
resources.append(Resource(match, resource_type))
print(f"Found resource: {match} ({resource_type.value})")
return resources
def _parse_resource(self, param: dict) -> Optional[Resource]:
"""解析资源参数"""
type_info = param.get('type', {})
kind = type_info.get('kind', '')
if kind == 'resource':
kind = type_info.get('baseShape', '')
param_name = param.get('name', '')
resource = Resource(param_name, ResourceType.STORAGE_BUFFER)
resource.binding_index = param.get('binding', {}).get('index', -1)
# 判断资源类型
if 'texture' in kind:
if 'RW' in kind:
resource.type = ResourceType.STORAGE_TEXTURE
else:
resource.type = ResourceType.SAMPLED_TEXTURE
if '2D' in kind:
resource.sub_type = ResourceSubType.TEXTURE_2D
elif 'cube' in kind:
resource.sub_type = ResourceSubType.TEXTURE_CUBE
elif '3D' in kind:
resource.sub_type = ResourceSubType.TEXTURE_3D
elif 'array' in kind:
resource.sub_type = ResourceSubType.TEXTURE_ARRAY
elif kind == 'structuredBuffer' or kind == 'byteAddressBuffer':
if 'RW' in kind:
resource.type = ResourceType.STORAGE_BUFFER
else:
resource.type = ResourceType.STORAGE_BUFFER
elif kind == 'constantBuffer' or kind == 'parameterBlock':
resource.type = ResourceType.UNIFORM_BUFFER
resource.uniform_data = ShaderLayoutParser.parse_uniform_buffer(param)
elif kind == 'samplerState' or 'sampler' in kind:
resource.type = ResourceType.SAMPLER
return resource
return None

View File

@@ -0,0 +1,429 @@
from dataclasses import dataclass, field
from enum import Enum
from typing import List, Optional, Union, Dict, Any
# 枚举类定义
class TargetFormat(Enum):
GLSL = "glsl"
DXBC = "dxbc"
MSL = "msl"
HLSL_DX11 = "hlsl"
class BindingKind(Enum):
SUB_ELEMENT_REGISTER_SPACE = "subElementRegisterSpace"
UNIFORM = "uniform"
TEXTURE = "texture"
STORAGE_BUFFER = "storageBuffer"
CONSTANT_BUFFER = "constantBuffer"
SAMPLER = "sampler"
VARYING_INPUT = "varyingInput"
VARYING_OUTPUT = "varyingOutput"
class TypeKind(Enum):
PARAMETER_BLOCK = "parameterBlock"
STRUCT = "struct"
MATRIX = "matrix"
SCALAR = "scalar"
VECTOR = "vector"
class ScalarType(Enum):
INT32 = "int32"
UINT32 = "uint32"
INT8 = "int8"
UINT8 = "uint8"
INT16 = "int16"
UINT16 = "uint16"
FLOAT16 = "float16"
FLOAT32 = "float32"
class Stage(Enum):
VERTEX = "vertex"
FRAGMENT = "fragment"
PIXEL = "pixel"
COMPUTE = "compute"
# 基础数据类
@dataclass
class Binding:
kind: BindingKind
index: Optional[int] = None
offset: Optional[int] = None
size: Optional[int] = None
count: Optional[int] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Binding':
return cls(
kind=BindingKind(data['kind']),
index=data.get('index'),
offset=data.get('offset'),
size=data.get('size'),
count=data.get('count')
)
@dataclass
class ScalarTypeInfo:
kind: TypeKind
scalarType: ScalarType
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ScalarTypeInfo':
return cls(
kind=TypeKind(data['kind']),
scalarType=ScalarType(data['scalarType'])
)
@dataclass
class MatrixType:
kind: TypeKind
rowCount: int
columnCount: int
elementType: ScalarTypeInfo
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'MatrixType':
return cls(
kind=TypeKind(data['kind']),
rowCount=data['rowCount'],
columnCount=data['columnCount'],
elementType=ScalarTypeInfo.from_dict(data['elementType'])
)
@dataclass
class VectorType:
kind: TypeKind
elementCount: int
elementType: ScalarTypeInfo
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'VectorType':
return cls(
kind=TypeKind(data['kind']),
elementCount=data['elementCount'],
elementType=ScalarTypeInfo.from_dict(data['elementType'])
)
@dataclass
class Field:
name: str
type: Union[MatrixType, VectorType, ScalarTypeInfo, 'StructType']
binding: Optional[Binding] = None
stage: Optional[Stage] = None
semanticName: Optional[str] = None
semanticIndex: Optional[int] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Field':
# 根据type的kind来决定如何解析
type_data = data['type']
type_kind = type_data['kind']
if type_kind == 'matrix':
type_obj = MatrixType.from_dict(type_data)
elif type_kind == 'vector':
type_obj = VectorType.from_dict(type_data)
elif type_kind == 'scalar':
type_obj = ScalarTypeInfo.from_dict(type_data)
elif type_kind == 'struct':
type_obj = StructType.from_dict(type_data)
else:
raise ValueError(f"Unknown type kind: {type_kind}")
return cls(
name=data['name'],
type=type_obj,
binding=Binding.from_dict(data['binding']) if 'binding' in data else None,
stage=Stage(data['stage']) if 'stage' in data else None,
semanticName=data.get('semanticName'),
semanticIndex=data.get('semanticIndex')
)
def get_sg_format(self):
"""获取对应的sg_format"""
out = 'SG_VERTEXFORMAT_'
scalar_type = ScalarType.FLOAT32
element_count = 1
if isinstance(self.type, ScalarTypeInfo):
scalar_type = self.type.scalarType
elif isinstance(self.type, VectorType):
scalar_type = self.type.elementType.scalarType
if isinstance(self.type, VectorType):
element_count = self.type.elementCount
if scalar_type == ScalarType.FLOAT32:
out += 'FLOAT'
elif scalar_type == ScalarType.INT32:
out += 'INT'
elif scalar_type == ScalarType.UINT32:
out += 'UINT'
elif scalar_type == ScalarType.INT8:
out += 'BYTE'
elif scalar_type == ScalarType.UINT8:
out += 'UBYTE'
elif scalar_type == ScalarType.INT16:
out += 'SHORT'
elif scalar_type == ScalarType.UINT16:
out += 'USHORT'
elif scalar_type == ScalarType.FLOAT16:
out += 'HALF'
else:
raise ValueError(f"Unsupported scalar type: {scalar_type}")
if element_count > 1:
out += str(element_count)
return out
def get_byte_size(self):
"""获取字段的字节大小"""
if isinstance(self.type, ScalarTypeInfo):
return 4 if self.type.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2
elif isinstance(self.type, VectorType):
return self.type.elementCount * (4 if self.type.elementType.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2)
elif isinstance(self.type, MatrixType):
return self.type.rowCount * self.type.columnCount * (4 if self.type.elementType.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2)
elif isinstance(self.type, StructType):
return sum(f.get_byte_size() for f in self.type.fields)
return 0
@dataclass
class StructType:
kind: TypeKind
name: str
fields: List[Field]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'StructType':
fields = [Field.from_dict(f) for f in data['fields']]
# 计算offset
offset = 0
for field in fields:
if field.binding is not None:
field.binding.offset = offset
offset += field.get_byte_size()
return cls(
kind=TypeKind(data['kind']),
name=data['name'],
fields=fields
)
def get_byte_size(self) -> int:
"""获取结构体的字节大小"""
return sum(field.get_byte_size() for field in self.fields)
@dataclass
class VarLayout:
type: StructType
binding: Binding
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'VarLayout':
return cls(
type=StructType.from_dict(data['type']),
binding=Binding.from_dict(data['binding'])
)
@dataclass
class ContainerVarLayout:
bindings: List[Binding]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ContainerVarLayout':
return cls(
bindings=[Binding.from_dict(b) for b in data['bindings']]
)
@dataclass
class ParameterBlockType:
kind: TypeKind
elementType: StructType
containerVarLayout: ContainerVarLayout
elementVarLayout: VarLayout
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ParameterBlockType':
return cls(
kind=TypeKind(data['kind']),
elementType=StructType.from_dict(data['elementType']),
containerVarLayout=ContainerVarLayout.from_dict(data['containerVarLayout']),
elementVarLayout=VarLayout.from_dict(data['elementVarLayout'])
)
@dataclass
class Parameter:
name: str
binding: Binding
type: Union[ParameterBlockType, StructType]
stage: Optional[Stage] = None
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'Parameter':
type_data = data['type']
if type_data['kind'] == 'struct':
type_obj = StructType.from_dict(type_data)
else:
type_obj = ParameterBlockType.from_dict(type_data)
return cls(
name=data['name'],
binding=Binding.from_dict(data['binding']),
type=type_obj,
stage=Stage(data['stage']) if 'stage' in data else None
)
def get_byte_size(self) -> int:
"""获取参数的字节大小"""
if isinstance(self.type, ParameterBlockType):
return self.type.elementVarLayout.type.get_byte_size()
elif isinstance(self.type, StructType):
return sum(field.get_byte_size() for field in self.type.fields)
return 0
def get_binding_kind(self) -> BindingKind:
kind = self.binding.kind
if kind == BindingKind.SUB_ELEMENT_REGISTER_SPACE:
kind = BindingKind.UNIFORM
return kind
def get_register_index(self):
"""获取HLSL注册索引"""
if self.binding.index is not None:
return self.binding.index
return 0
@dataclass
class EntryPointBinding:
name: str
binding: Binding
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'EntryPointBinding':
return cls(
name=data['name'],
binding=Binding.from_dict(data['binding'])
)
@dataclass
class EntryPointResult:
stage: Stage
binding: Binding
type: Union[StructType, VectorType]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'EntryPointResult':
kind = data['type']['kind']
if kind == 'struct':
type_obj = StructType.from_dict(data['type'])
elif kind == 'vector':
type_obj = VectorType.from_dict(data['type'])
else:
raise ValueError(f"Unknown result type kind: {kind}")
return cls(
stage=Stage(data['stage']),
binding=Binding.from_dict(data['binding']),
type=type_obj
)
@dataclass
class EntryPoint:
name: str
stage: Stage
parameters: List[Parameter]
result: EntryPointResult
bindings: List[EntryPointBinding]
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'EntryPoint':
return cls(
name=data['name'],
stage=Stage(data['stage']),
parameters=[Parameter.from_dict(p) for p in data['parameters']],
result=EntryPointResult.from_dict(data['result']),
bindings=[EntryPointBinding.from_dict(b) for b in data['bindings']]
)
@dataclass
class ShaderReflection:
parameters: List[Parameter]
entryPoints: List[EntryPoint]
blob: Optional[bytes] = None # 用于存储编译后的着色器二进制数据
@classmethod
def from_dict(cls, data: Dict[str, Any]) -> 'ShaderReflection':
return cls(
parameters=[Parameter.from_dict(p) for p in data['parameters']],
entryPoints=[EntryPoint.from_dict(e) for e in data['entryPoints']]
)
@classmethod
def get_vertex_layout(cls, reflection: 'ShaderReflection') -> Optional[List[Field]]:
"""获取顶点布局字段"""
for entry in reflection.entryPoints:
if entry.stage == Stage.VERTEX:
vertex_fields = []
for param in entry.parameters:
if param.stage == Stage.VERTEX and param.type.kind == TypeKind.STRUCT:
vertex_fields.extend(param.type.fields)
return vertex_fields
return None
@dataclass
class ShaderStageInfo:
parameters: List[Parameter]
entryPoint: EntryPoint
blob: Optional[bytes] = None # 用于存储编译后的着色器二进制数据
def __init__(self, parameters: List[Parameter], entryPoint: EntryPoint, blob: Optional[bytes] = None):
self.parameters = parameters
self.entryPoint = entryPoint
self.blob = blob
for p in self.parameters:
if p.stage is None:
p.stage = entryPoint.stage
else:
assert p.stage == entryPoint.stage, f"Parameter stage {p.stage} does not match entry point stage {entryPoint.stage}"
def get_entry_name(self):
return self.entryPoint.name
def get_stage(self):
return self.entryPoint.stage
@dataclass
class ShaderInfos:
"""着色器信息集合"""
stages: dict[Stage, ShaderStageInfo]
vertex_layout: List[Field] = field(default_factory=list)
vertex_size: int = 0
def add_shader_info(self, shader_info: ShaderReflection):
"""添加单个着色器信息"""
vertex_layout = ShaderReflection.get_vertex_layout(shader_info)
if vertex_layout is not None:
self.vertex_layout = vertex_layout
self.vertex_size = sum(field.get_byte_size() for field in vertex_layout)
for entry in shader_info.entryPoints:
stage = entry.stage
self.stages[stage] = ShaderStageInfo(
parameters=shader_info.parameters,
entryPoint=entry,
blob=shader_info.blob
)

View File

@@ -1,119 +0,0 @@
#!/usr/bin/env python3
"""
Slang Compiler - Type Definitions
数据类型和枚举定义
"""
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Optional, Dict, Mapping
class ShaderStage(Enum):
VERTEX = "vertex"
FRAGMENT = "fragment"
COMPUTE = "compute"
class ResourceType(Enum):
SAMPLED_TEXTURE = "sampled_texture"
STORAGE_TEXTURE = "storage_texture"
STORAGE_BUFFER = "storage_buffer"
UNIFORM_BUFFER = "uniform_buffer"
SAMPLER = "sampler"
class ResourceSubType(Enum):
TEXTURE_2D = "SG_IMAGETYPE_2D"
TEXTURE_CUBE = "SG_IMAGETYPE_CUBE"
TEXTURE_3D = "SG_IMAGETYPE_3D"
TEXTURE_ARRAY = "SG_IMAGETYPE_ARRAY"
class TargetFormat(Enum):
SPIRV = "spirv"
DXBC = "dxbc"
MSL = "msl"
HLSL_DX11 = "hlsl"
# 数据模型类
@dataclass
class FieldType:
"""字段类型信息"""
kind: str # 'scalar', 'vector', 'matrix'
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
element_count: Optional[int] = None # for vector
row_count: Optional[int] = None # for matrix
column_count: Optional[int] = None # for matrix
size: int = 0
@classmethod
def from_dict(cls, data: Dict) -> 'FieldType':
"""从字典创建FieldType对象"""
kind = data.get('kind')
if kind == 'vector':
return cls(
kind=kind,
scalar_type=data['elementType']['scalarType'],
element_count=data['elementCount']
)
elif kind == 'scalar':
return cls(
kind=kind,
scalar_type=data['scalarType']
)
elif kind == 'matrix':
return cls(
kind=kind,
scalar_type=data['elementType']['scalarType'],
row_count=data['rowCount'],
column_count=data['columnCount']
)
return cls(kind='scalar', scalar_type='float32')
@dataclass
class UniformField:
"""Uniform缓冲区字段"""
name: str
type: FieldType
offset: int
size: int
@dataclass
class UniformBuffer:
"""Uniform缓冲区"""
name: str
binding: int
size: int = 0
fields: List[UniformField] = field(default_factory=list)
@dataclass
class Resource:
name: str
type: ResourceType
sub_type: Optional[ResourceSubType] = None
binding_index: int = -1
field_type: Optional[FieldType] = None
uniform_data: Optional[UniformBuffer] = None
@dataclass
class ShaderInfo:
stage: ShaderStage
entry_point: str
blob: bytes
resources: List[Resource]
@dataclass
class VertexField:
"""顶点输入字段"""
name: str
type: FieldType
location: int
semantic: str
semantic_index: int
offset: int = 0
@dataclass
class ShaderLayout:
"""着色器布局数据"""
vertex_fields: List[VertexField] = field(default_factory=list)
uniform_buffers: List[UniformBuffer] = field(default_factory=list)