diff --git a/src/mirage_render/src/render/render_elements.cpp b/src/mirage_render/src/render/render_elements.cpp index 1684684..00df4e6 100644 --- a/src/mirage_render/src/render/render_elements.cpp +++ b/src/mirage_render/src/render/render_elements.cpp @@ -7,6 +7,7 @@ #include "shaders/mirage_text.hlsl.h" #include "shaders/mirage_wireframe.hlsl.h" #include "shaders/mirage_line.hlsl.h" +#include "shaders/mirage_line.shader.h" template void compute_rect_vertices(const Eigen::MatrixBase& in_pos, diff --git a/tools/code_generator.py b/tools/code_generator.py index 4cde869..17cefad 100644 --- a/tools/code_generator.py +++ b/tools/code_generator.py @@ -5,53 +5,19 @@ Slang Compiler - Code Generator """ from pathlib import Path -from code_generator_helper import * from global_vars import global_vars from indent_manager import IndentManager -from shader_types import * +from shader_reflection_type import * class CodeGenerator: """C++绑定代码生成器""" - def generate_binding_functions(self, binding_infos: List[ShaderInfo], output_path: str) -> None: + def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str) -> None: """生成C++绑定函数的入口方法""" self._generate_cpp_bindings(binding_infos, output_path) - def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None: - """生成Uniform缓冲区结构体""" - writer.write("// Uniform buffer structures") - - for buffer in uniform_buffers: - struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer' - - # 计算总大小和对齐要求 - total_size = 0 - max_alignment = 16 # GPU通常要求16字节对齐 - - with writer.block(f'struct {struct_name}', '};'): - for i, field in enumerate(buffer.fields): - # 检查是否需要填充 - if field.offset > total_size: - padding_size = field.offset - total_size - writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding") - total_size = field.offset - - # 生成字段声明 - declaration = get_c_type_declaration(field.type, field.name) - writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}") - - total_size = field.offset + field.size - - # 确保结构体大小正确对齐 - aligned_size = get_aligned_size(total_size, max_alignment) - if aligned_size > total_size: - writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment") - - writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})") - writer.write() - - def _generate_cpp_bindings(self, binding_infos: List[ShaderInfo], output_path: str) -> None: + def _generate_cpp_bindings(self, binding_infos: ShaderInfos, output_path: str) -> None: """生成C++绑定函数""" output_file = Path(output_path) # 尝试创建输出目录 @@ -61,7 +27,7 @@ class CodeGenerator: writer = IndentManager(file) self._write_complete_file(writer, binding_infos) - def _write_complete_file(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None: + def _write_complete_file(self, writer: IndentManager, binding_infos: ShaderInfos) -> None: """写入完整的文件内容""" self._write_header(writer) @@ -81,19 +47,19 @@ class CodeGenerator: writer.write(header) writer.write() - def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None: + def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: ShaderInfos) -> None: """写入ShaderBindings类""" - self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers) + # self._generate_uniform_structs(writer, global_vars.vertex_layout.uniform_buffers) self._write_blob(writer, binding_infos) self._write_public_methods(writer, binding_infos) self._write_get_pipeline_desc_method(writer, binding_infos) - def _write_blob(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None: + def _write_blob(self, writer: IndentManager, binding_infos: ShaderInfos) -> None: """写入public部分的结构体定义""" # 写入二进制内容 - for info in binding_infos: - entry_point = info.entry_point + for stage, info in binding_infos.stages.items(): + entry_point = info.get_entry_name() blob_data = info.blob with writer.block(f'static constexpr std::array {global_vars.source_file_name}_{entry_point}_blob =', '};'): @@ -108,69 +74,100 @@ class CodeGenerator: writer.write() writer.write() - def _write_public_methods(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None: + def _write_public_methods(self, writer: IndentManager, binding_infos: ShaderInfos) -> None: """写入公共方法""" self._write_get_shader_info_method(writer, binding_infos) - def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None: + def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None: """写入getShaderInfo方法""" with writer.indent(f'static sg_shader_desc get_{global_vars.source_file_name}_shader_desc()'' {', '}'): writer.write('sg_shader_desc desc = {};') - writer.write(f'desc.label = "{global_vars.source_file_name}";') + writer.write(f'desc.label = "{global_vars.source_file_name}_shader_desc";') writer.write('// 顶点布局') with writer.indent('{', '}'): - for i, vertex_field in enumerate(global_vars.layout.vertex_fields): - writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semantic}";') - writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semantic_index};') + for i, vertex_field in enumerate(binding_infos.vertex_layout): + writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semanticName}";') + writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semanticIndex or 0};') - for info in binding_infos: - entry_point = info.entry_point + for stage, info in binding_infos.stages.items(): + entry_point = info.get_entry_name() entry_point_name = f'{global_vars.source_file_name}_{entry_point}_blob' - stage_name = info.stage.name + stage_name = info.get_stage().name writer.write(f'// {stage_name}') with writer.indent('{','}'): - writer.write(f'desc.{stage_name.lower()}_func.bytecode = SG_RANGE({entry_point_name})') + writer.write(f'desc.{stage_name.lower()}_func.bytecode = SG_RANGE({entry_point_name});') writer.write(f'desc.{stage_name.lower()}_func.entry = "{entry_point}";') writer.write('// 资源绑定') - resource_index_map = {} - # 处理资源绑定 - for resource in info.resources: - binding_index = resource.binding_index - param_index = resource_index_map.get(resource.type, 0) - resource_index_map[resource.type] = param_index + 1 + self._write_resource_binding(writer, info.parameters) - if resource.type == ResourceType.UNIFORM_BUFFER: - uniform_info = resource.uniform_data - writer.write(f'desc.uniform_blocks[{param_index}].name = "{resource.name}";') - writer.write(f'desc.uniform_blocks[{param_index}].size = {uniform_info.size};') - writer.write(f'desc.uniform_blocks[{param_index}].hlsl_register_b_n = {uniform_info.binding};') - writer.write(f'desc.uniform_blocks[{param_index}].msl_buffer_n = {uniform_info.binding};') - writer.write(f'desc.uniform_blocks[{param_index}].wgsl_group0_binding_n = {uniform_info.binding};') + def _write_resource_binding(self, writer: IndentManager, parameters: List[Parameter]) -> None: + resource_index = {} + target = global_vars.target + glsl_block_index = {} - elif resource.type == ResourceType.SAMPLED_TEXTURE: - writer.write(f'desc.images[{param_index}].stage = SG_SHADERSTAGE_{stage_name.upper()}') - writer.write(f'desc.images[{param_index}].name = "{resource.name}";') - writer.write(f'desc.sampled_textures[{param_index}].name = "{resource.name}";') + for p in parameters: + binding_kind = p.get_binding_kind() + index = resource_index.get(binding_kind, 0) + resource_index[binding_kind] = index + 1 + stage_name = p.stage.name.upper() + if binding_kind == BindingKind.UNIFORM: + t = f'desc.uniform_blocks[{index}]' + writer.write(f'{t}.stage = SG_SHADERSTAGE_{stage_name};') + writer.write(f'{t}.size = {p.get_byte_size()};') + if target == TargetFormat.GLSL: + glsl_index = glsl_block_index.get(binding_kind, 0) + glsl_block_index[binding_kind] = glsl_index + 1 + writer.write(f'{t}.glsl_uniforms[{glsl_index}].type = ;') + writer.write(f'{t}.glsl_uniforms[{glsl_index}].array_count = 1;') + writer.write(f'{t}.glsl_uniforms[{glsl_index}].glsl_name = "{p.name}";') + elif target == TargetFormat.DXBC: + writer.write(f'{t}.hlsl_register_b_n = {p.get_register_index()};') - elif resource.type == ResourceType.STORAGE_BUFFER: - writer.write(f'desc.storage_buffers[{param_index}].name = "{resource.name}";') - elif resource.type == ResourceType.SAMPLER: - writer.write(f'desc.samplers[{param_index}].name = "{resource.name}";') - - def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None: + def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None: """写入getPipelineDesc方法""" - with writer.indent(f'static sg_pipeline_desc get_{global_vars.source_file_name}_pipeline_desc(sg_shader shader, sg_pixel_format pixel_format, int32_t sample_count)'' {', '}'): + function_params = ('\n\t\t\tsg_shader shader, \n' + '\t\t\tsg_pixel_format pixel_format, \n' + '\t\t\tint32_t sample_count = 1, \n' + '\t\t\tsg_primitive_type primitive_type = SG_PRIMITIVETYPE_TRIANGLES, \n' + '\t\t\tsg_cull_mode cull_mode = SG_CULLMODE_NONE\n') + + with writer.indent(f'static sg_pipeline_desc get_{global_vars.source_file_name}_pipeline_desc({function_params}\t)'' {', '}'): writer.write('sg_pipeline_desc desc = {};') - writer.write(f'desc.label = "{global_vars.source_file_name}_pipeline";') writer.write('desc.shader = shader;') writer.write('desc.index_type = SG_INDEXTYPE_UINT32;') + writer.write() + + # 处理顶点缓冲区大小 + writer.write(f'desc.layout.buffers[0].stride = {binding_infos.vertex_size};') + writer.write('desc.layout.buffers[0].step_func = SG_VERTEXSTEP_PER_VERTEX;') + writer.write('desc.layout.buffers[0].step_rate = 1;') + writer.write() # 处理顶点输入布局 - for i, vertex_field in enumerate(global_vars.layout.vertex_fields): - writer.write(f'// {vertex_field.semantic}') - writer.write(f'desc.layout.attrs[{i}].buffer_index = {vertex_field.location}') - writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.offset}') - writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.type.scalar_type}') + for i, vertex_field in enumerate(binding_infos.vertex_layout): + writer.write(f'// {vertex_field.semanticName} {vertex_field.semanticIndex or ''}') + writer.write(f'desc.layout.attrs[{i}].buffer_index = 0;') + writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.binding.offset};') + writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.get_sg_format()};') + writer.write() + + writer.write('desc.primitive_type = primitive_type;') + writer.write('desc.cull_mode = cull_mode;') + writer.write('desc.face_winding = SG_FACEWINDING_CW;') + + writer.write('desc.depth.write_enabled = false;') + writer.write('desc.depth.compare = SG_COMPAREFUNC_NEVER;') + writer.write('desc.depth.pixel_format = SG_PIXELFORMAT_NONE;') + + writer.write('desc.colors[0].blend.enabled = true;') + writer.write('desc.colors[0].blend.src_factor_rgb = SG_BLENDFACTOR_SRC_ALPHA;') + writer.write('desc.colors[0].blend.dst_factor_rgb = SG_BLENDFACTOR_ONE_MINUS_SRC_ALPHA;') + writer.write('desc.colors[0].pixel_format = pixel_format;') + writer.write('desc.colors[0].write_mask = SG_COLORMASK_RGBA;') + + writer.write('desc.sample_count = sample_count;') + writer.write(f'desc.label = "{global_vars.source_file_name}_pipeline_desc";') + writer.write('return desc;') diff --git a/tools/code_generator_helper.py b/tools/code_generator_helper.py index 4644515..6a335f3 100644 --- a/tools/code_generator_helper.py +++ b/tools/code_generator_helper.py @@ -1,6 +1,6 @@ from typing import Tuple -from shader_types import FieldType +from shader_reflection_type import Field, TypeKind, ScalarType format_mapping = { ('int32', 1): 'SG_VERTEXFORMAT_INVALID', @@ -29,57 +29,56 @@ format_mapping = { # 类型大小映射 type_sizes = { - 'int32': 4, - 'uint32': 4, - 'float32': 4, - 'int8': 1, - 'uint8': 1, - 'int16': 2, - 'uint16': 2, - 'float16': 2, + ScalarType.INT32: 4, + ScalarType.UINT32: 4, + ScalarType.FLOAT32: 4, + ScalarType.INT8: 1, + ScalarType.UINT8: 1, + ScalarType.INT16: 2, + ScalarType.UINT16: 2, + ScalarType.FLOAT16: 2, # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现 } -def get_type_info(field_type: FieldType) -> Tuple[str, int, int]: +# C类型映射表 +c_type_mapping = { + ScalarType.INT32: 'int32_t', + ScalarType.UINT32: 'uint32_t', + ScalarType.FLOAT32: 'float', + ScalarType.INT8: 'int8_t', + ScalarType.UINT8: 'uint8_t', + ScalarType.INT16: 'int16_t', + ScalarType.UINT16: 'uint16_t', + ScalarType.FLOAT16: 'mirage_float16', # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现 +} + +# 对齐要求映射表 +alignment_mapping = { + ScalarType.INT32: 4, + ScalarType.UINT32: 4, + ScalarType.FLOAT32: 4, + ScalarType.INT8: 1, + ScalarType.UINT8: 1, + ScalarType.INT16: 2, + ScalarType.UINT16: 2, + ScalarType.FLOAT16: 2, +} + +def get_type_info(field_type: Field) -> Tuple[str, int, int]: """获取类型信息:C类型名称、元素数量、字节大小 Returns: Tuple[str, int, int]: (C类型名称, 元素数量, 总字节大小) """ - # C类型映射表 - c_type_mapping = { - 'int32': 'int32_t', - 'uint32': 'uint32_t', - 'float32': 'float', - 'int8': 'int8_t', - 'uint8': 'uint8_t', - 'int16': 'int16_t', - 'uint16': 'uint16_t', - 'float16': 'mirage_float16', # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现 - } + kind = field_type.type.kind + scalar_type = field_type.type.scalarType + element_count = field_type.type.elementCount + c_type = c_type_mapping.get(scalar_type, 'float') - # 对齐要求映射表 - alignment_mapping = { - 'int32': 4, - 'uint32': 4, - 'float32': 4, - 'int8': 1, - 'uint8': 1, - 'int16': 2, - 'uint16': 2, - 'float16': 2, - } - - if field_type.kind == 'scalar': - scalar_type = field_type.scalar_type or 'float32' - c_type = c_type_mapping.get(scalar_type, 'float') + if kind == TypeKind.SCALAR: size = type_sizes.get(scalar_type, 4) return c_type, 1, size - elif field_type.kind == 'vector': - scalar_type = field_type.scalar_type or 'float32' - c_type = c_type_mapping.get(scalar_type, 'float') - element_count = field_type.element_count or 4 - + elif kind == TypeKind.VECTOR: # 验证向量元素数量的合法性 if element_count not in [1, 2, 3, 4]: raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.") @@ -93,12 +92,9 @@ def get_type_info(field_type: FieldType) -> Tuple[str, int, int]: return c_type, element_count, size - elif field_type.kind == 'matrix': - scalar_type = field_type.scalar_type or 'float32' - c_type = c_type_mapping.get(scalar_type, 'float') - - row_count = field_type.row_count or 4 - column_count = field_type.column_count or 4 + elif kind == TypeKind.MATRIX: + row_count = field_type.type.rowCount or 4 + column_count = field_type.type.columnCount or 4 # 验证矩阵维度的合法性 if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]: @@ -125,8 +121,7 @@ def get_type_info(field_type: FieldType) -> Tuple[str, int, int]: print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.") return 'float', 1, 4 - -def get_c_type_declaration(field_type: FieldType, field_name: str) -> str: +def get_c_type_declaration(field_type: Field, field_name: str) -> str: """生成C类型声明 Args: @@ -137,11 +132,12 @@ def get_c_type_declaration(field_type: FieldType, field_name: str) -> str: str: 完整的C类型声明 """ c_type, count, _ = get_type_info(field_type) + kind = field_type.type.kind - if field_type.kind == 'scalar': + if kind == TypeKind.SCALAR: return f"{c_type} {field_name}" - elif field_type.kind == 'vector': + elif kind == TypeKind.VECTOR: # 对于向量,可以选择使用数组或专门的向量类型 if count == 1: return f"{c_type} {field_name}" @@ -152,9 +148,9 @@ def get_c_type_declaration(field_type: FieldType, field_name: str) -> str: # 选项2: 使用对齐的结构体(如果需要的话) # return f"struct {{ {c_type} data[{count}]; }} {field_name}" - elif field_type.kind == 'matrix': - rows = field_type.row_count or 4 - cols = field_type.column_count or 4 + elif kind == TypeKind.MATRIX: + rows = field_type.type.rowCount or 4 + cols = field_type.type.columnCount or 4 # 矩阵使用二维数组(列主序) return f"{c_type} {field_name}[{cols}][{rows}]" diff --git a/tools/compiler.py b/tools/compiler.py index cb630fd..bbb16b3 100644 --- a/tools/compiler.py +++ b/tools/compiler.py @@ -8,7 +8,7 @@ from typing import List, Dict from code_generator import CodeGenerator from shader_parser import ShaderParser -from shader_types import ShaderInfo +from shader_reflection_type import ShaderReflection, ShaderInfos class SlangCompiler: @@ -16,10 +16,10 @@ class SlangCompiler: self.parser = ShaderParser() self.code_generator = CodeGenerator() - def parse_slang_shader(self) -> Dict[str, ShaderInfo]: + def parse_slang_shader(self) -> List[ShaderReflection]: """解析Slang着色器源码,提取资源信息""" return self.parser.parse_slang_shader() - def generate_binding_functions(self, binding_infos: List[ShaderInfo], output_path: str): + def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str): """生成C/C++绑定函数""" self.code_generator.generate_binding_functions(binding_infos, output_path) diff --git a/tools/compiler_cmd.py b/tools/compiler_cmd.py index 22585b9..bdd82ad 100644 --- a/tools/compiler_cmd.py +++ b/tools/compiler_cmd.py @@ -4,21 +4,22 @@ Slang Compiler - Command Generation """ from exe_finder import slangc_path from global_vars import global_vars -from shader_types import TargetFormat, ShaderStage +from shader_reflection_type import * -def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, output_path: str): + +def make_cmd(source_file: str, target: TargetFormat, stage: Stage, entry_point: str, output_path: str): """生成编译命令""" target_flag = { - TargetFormat.SPIRV: 'spirv', + TargetFormat.GLSL: 'glsl', TargetFormat.DXBC: 'dxbc', TargetFormat.MSL: 'metal', TargetFormat.HLSL_DX11: 'hlsl', }[target] stage_flag = { - ShaderStage.VERTEX: 'vertex', - ShaderStage.FRAGMENT: 'fragment', - ShaderStage.COMPUTE: 'compute' + Stage.VERTEX: 'vertex', + Stage.FRAGMENT: 'fragment', + Stage.COMPUTE: 'compute' }[stage] cmd = [ diff --git a/tools/global_vars.py b/tools/global_vars.py index d46d21a..19b1d92 100644 --- a/tools/global_vars.py +++ b/tools/global_vars.py @@ -1,11 +1,11 @@ -from shader_types import * + +from shader_reflection_type import * class GlobalVars: source_file = '' source_file_name = '' source_path = '' output_dir = '' - target: TargetFormat - layout: ShaderLayout + target = TargetFormat.DXBC global_vars = GlobalVars() diff --git a/tools/main.py b/tools/main.py index 3d6b571..1c53eed 100644 --- a/tools/main.py +++ b/tools/main.py @@ -31,6 +31,10 @@ def main(): global_vars.output_dir = os.path.abspath(args.output_dir) global_vars.target = TargetFormat(args.target) + shader_infos = ShaderInfos( + stages={} + ) + # 仅保留路径部分 include_dirs = [ global_vars.source_path, @@ -45,12 +49,8 @@ def main(): # 解析着色器 print(f"**Parsing** {args.input}...") shaders = compiler.parse_slang_shader() - - # 编译每个入口点 - shader_infos = [] - - for name, shader_info in shaders.items(): - shader_infos.append(shader_info) + for shader_info in shaders: + shader_infos.add_shader_info(shader_info) binding_output_file_pathname = os.path.abspath(args.output_dir) binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{global_vars.source_file_name}.shader.h") diff --git a/tools/shader_layout.py b/tools/shader_layout.py deleted file mode 100644 index 876a7e3..0000000 --- a/tools/shader_layout.py +++ /dev/null @@ -1,85 +0,0 @@ -import json -from typing import List, Dict - -from shader_types import * -from shader_types import UniformBuffer - - -class ShaderLayoutParser: - """解析JSON数据并提取到类对象""" - - def parse(self, json_file: str) -> ShaderLayout: - """解析JSON文件并返回ShaderLayout对象""" - with open(json_file, 'r') as f: - json_data = json.load(f) - - layout = ShaderLayout() - layout.vertex_fields = self.parse_vertex_input(json_data) - layout.uniform_buffers = self.parse_uniform_buffers(json_data) - - return layout - - @staticmethod - def parse_vertex_input(json_data: Dict) -> List[VertexField]: - """解析顶点输入字段""" - vertex_fields = [] - - entry_points = json_data.get('entryPoints', []) - for entry in entry_points: - if entry.get('stage') == 'vertex': - parameters = entry.get('parameters', []) - offset = entry.get('offset', 0) - for param in parameters: - if param.get('name') == 'input' and param.get('stage') == 'vertex': - fields = param.get('type', {}).get('fields', []) - for field in fields: - vertex_field = VertexField( - name=field['name'], - type=FieldType.from_dict(field['type']), - location=field['binding']['index'], - semantic=field.get('semanticName', ''), - semantic_index=field.get('semanticIndex', 0) - ) - vertex_fields.append(vertex_field) - - return vertex_fields - - @staticmethod - def parse_uniform_buffer(param: Dict) -> UniformBuffer | None: - binding = param.get('binding', {}) - type_info = param.get('type', {}) - kind = type_info.get('kind', '') - if kind == 'constantBuffer' or kind == 'parameterBlock': - buffer = UniformBuffer( - name=param['name'], - binding=binding['index'] - ) - - # 解析缓冲区字段 - element_type = param.get('type', {}).get('elementType', {}) - if element_type.get('kind') == 'struct': - fields = element_type.get('fields', []) - for field in fields: - uniform_field = UniformField( - name=field['name'], - type=FieldType.from_dict(field['type']), - offset=field['binding']['offset'], - size=field['binding']['size'] - ) - buffer.size += uniform_field.size - buffer.fields.append(uniform_field) - - return buffer - return None - - @staticmethod - def parse_uniform_buffers(json_data: Dict) -> List[UniformBuffer]: - """解析Uniform缓冲区""" - uniform_buffers = [] - - parameters = json_data.get('parameters', []) - for param in parameters: - buffer = ShaderLayoutParser.parse_uniform_buffer(param) - uniform_buffers.append(buffer) - - return uniform_buffers \ No newline at end of file diff --git a/tools/shader_parser.py b/tools/shader_parser.py index 4715bd2..bf1c875 100644 --- a/tools/shader_parser.py +++ b/tools/shader_parser.py @@ -13,12 +13,11 @@ from typing import List, Dict, Optional from compiler_cmd import make_cmd from global_vars import global_vars -from shader_layout import ShaderLayoutParser -from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo, UniformBuffer, FieldType, ResourceSubType +from shader_reflection_type import * class ShaderParser: - def parse_slang_shader(self) -> Dict[str, ShaderInfo]: + def parse_slang_shader(self) -> List[ShaderReflection]: """解析Slang着色器源码,提取资源信息""" with open(global_vars.source_file, 'r', encoding='utf-8') as f: @@ -28,7 +27,7 @@ class ShaderParser: entry_points = self._find_entry_points(source) print(f"Found potential entry points: {entry_points}") - shaders = {} + shaders = [] # 为每个入口点单独进行完整编译和反射 for entry_name, stage in entry_points.items(): @@ -39,22 +38,22 @@ class ShaderParser: ) if shader_info: - shaders[entry_name] = shader_info + shaders.append(shader_info) else: print(f"Failed to process entry point: {entry_name}") return shaders - def _find_entry_points(self, source: str) -> Dict[str, ShaderStage]: + def _find_entry_points(self, source: str) -> Dict[str, Stage]: """在源码中查找入口点函数""" entry_points = {} # 1. 查找带有Slang属性的函数 attribute_patterns = [ - (r'$$shader\s*$\s*["\']vertex["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.VERTEX), - (r'$$shader\s*$\s*["\']fragment["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.FRAGMENT), - (r'$$shader\s*$\s*["\']pixel["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.FRAGMENT), - (r'$$shader\s*$\s*["\']compute["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.COMPUTE), + (r'$$shader\s*$\s*["\']vertex["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.VERTEX), + (r'$$shader\s*$\s*["\']fragment["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.FRAGMENT), + (r'$$shader\s*$\s*["\']pixel["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.FRAGMENT), + (r'$$shader\s*$\s*["\']compute["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.COMPUTE), ] for pattern, stage in attribute_patterns: @@ -66,29 +65,29 @@ class ShaderParser: # 2. 查找常见的命名约定 common_entry_points = { # Vertex shaders - 'vertex_main': ShaderStage.VERTEX, - 'vertexMain': ShaderStage.VERTEX, - 'vert_main': ShaderStage.VERTEX, - 'vs_main': ShaderStage.VERTEX, - 'vertex_shader': ShaderStage.VERTEX, - 'VS': ShaderStage.VERTEX, + 'vertex_main': Stage.VERTEX, + 'vertexMain': Stage.VERTEX, + 'vert_main': Stage.VERTEX, + 'vs_main': Stage.VERTEX, + 'vertex_shader': Stage.VERTEX, + 'VS': Stage.VERTEX, # Fragment shaders - 'fragment_main': ShaderStage.FRAGMENT, - 'fragmentMain': ShaderStage.FRAGMENT, - 'frag_main': ShaderStage.FRAGMENT, - 'pixel_main': ShaderStage.FRAGMENT, - 'ps_main': ShaderStage.FRAGMENT, - 'fragment_shader': ShaderStage.FRAGMENT, - 'PS': ShaderStage.FRAGMENT, + 'fragment_main': Stage.FRAGMENT, + 'fragmentMain': Stage.FRAGMENT, + 'frag_main': Stage.FRAGMENT, + 'pixel_main': Stage.FRAGMENT, + 'ps_main': Stage.FRAGMENT, + 'fragment_shader': Stage.FRAGMENT, + 'PS': Stage.FRAGMENT, # Compute shaders - 'compute_main': ShaderStage.COMPUTE, - 'computeMain': ShaderStage.COMPUTE, - 'comp_main': ShaderStage.COMPUTE, - 'cs_main': ShaderStage.COMPUTE, - 'compute_shader': ShaderStage.COMPUTE, - 'CS': ShaderStage.COMPUTE, + 'compute_main': Stage.COMPUTE, + 'computeMain': Stage.COMPUTE, + 'comp_main': Stage.COMPUTE, + 'cs_main': Stage.COMPUTE, + 'compute_shader': Stage.COMPUTE, + 'CS': Stage.COMPUTE, } # 查找这些函数名是否在源码中存在 @@ -112,22 +111,22 @@ class ShaderParser: # 顶点着色器关键词 if any(keyword in func_lower for keyword in ['vert', 'vs', 'vertex']): - entry_points[func_name] = ShaderStage.VERTEX + entry_points[func_name] = Stage.VERTEX print(f"Inferred vertex shader: {func_name}") # 片元着色器关键词 elif any(keyword in func_lower for keyword in ['frag', 'pixel', 'ps', 'fs', 'fragment']): - entry_points[func_name] = ShaderStage.FRAGMENT + entry_points[func_name] = Stage.FRAGMENT print(f"Inferred fragment shader: {func_name}") # 计算着色器关键词 elif any(keyword in func_lower for keyword in ['comp', 'cs', 'compute']): - entry_points[func_name] = ShaderStage.COMPUTE + entry_points[func_name] = Stage.COMPUTE print(f"Inferred compute shader: {func_name}") # 4. 查找带有HLSL语义的函数 semantic_patterns = [ - (r'(\w+)\s*$[^)]*$\s*:\s*SV_Position', ShaderStage.VERTEX), - (r'(\w+)\s*$[^)]*$\s*:\s*SV_Target', ShaderStage.FRAGMENT), - (r'(\w+)\s*$[^)]*$\s*:\s*POSITION', ShaderStage.VERTEX), + (r'(\w+)\s*$[^)]*$\s*:\s*SV_Position', Stage.VERTEX), + (r'(\w+)\s*$[^)]*$\s*:\s*SV_Target', Stage.FRAGMENT), + (r'(\w+)\s*$[^)]*$\s*:\s*POSITION', Stage.VERTEX), ] for pattern, stage in semantic_patterns: @@ -141,13 +140,13 @@ class ShaderParser: # 5. 如果没找到任何入口点,查找main函数 if not entry_points: if re.search(r'\bmain\s*\(', source, re.IGNORECASE): - entry_points['main'] = ShaderStage.VERTEX + entry_points['main'] = Stage.VERTEX print("Found main function, assuming vertex shader") print(f"Total entry points found: {len(entry_points)}") return entry_points - def _compile_and_reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str) -> Optional[ShaderInfo]: + def _compile_and_reflect_entry_point(self, entry_name: str, stage: Stage, source: str) -> Optional[ShaderReflection]: """为单个入口点进行完整编译和反射""" # 创建临时输出文件 @@ -180,10 +179,6 @@ class ShaderParser: if result.returncode == 0: # 读取反射数据 if os.path.exists(reflection_path): - if stage == ShaderStage.VERTEX: - shader_layout_ = ShaderLayoutParser() - global_vars.layout = shader_layout_.parse(reflection_path) - with open(reflection_path, 'r', encoding='utf-8') as f: reflection_json = f.read() print(f"Reflection JSON length: {len(reflection_json)}") @@ -191,17 +186,15 @@ class ShaderParser: if reflection_json.strip(): try: reflection = json.loads(reflection_json) - shader_info = self._create_shader_info_from_reflection( - reflection, entry_name, stage, source - ) + shader_reflection = ShaderReflection.from_dict(reflection) # 读取编译后的二进制数据 with open(temp_output_path, 'rb') as f: shader_binary = f.read() # 写入shader_info.blob - shader_info.blob = shader_binary + shader_reflection.blob = shader_binary - return shader_info + return shader_reflection except json.JSONDecodeError as e: print(f"JSON parsing error: {e}") print(f"Raw JSON: {reflection_json[:500]}...") @@ -217,161 +210,4 @@ class ShaderParser: if os.path.exists(temp_file): os.unlink(temp_file) - return self._create_shader_info_manual(entry_name, stage, source) - - def _create_shader_info_from_reflection(self, reflection: dict, entry_name: str, - stage: ShaderStage, source: str) -> ShaderInfo: - """从反射数据创建ShaderInfo""" - print(f"Processing reflection data for {entry_name}") - print(f"Reflection keys: {list(reflection.keys())}") - - shader_info = ShaderInfo( - stage=stage, - entry_point=entry_name, - resources=[], - blob=b'', # 二进制数据将在编译后填充 - ) - - # Slang反射数据的可能结构 - # 尝试不同的数据结构 - entry_point_data = None - - # 方法1: 直接在根级别查找 - if 'parameters' in reflection: - entry_point_data = reflection - - # 方法2: 在entryPoints数组中查找 - elif 'entryPoints' in reflection: - for ep in reflection['entryPoints']: - if ep.get('name') == entry_name: - entry_point_data = ep - break - - # 方法3: 在modules中查找 - elif 'modules' in reflection: - for module in reflection['modules']: - if 'entryPoints' in module: - for ep in module['entryPoints']: - if ep.get('name') == entry_name: - entry_point_data = ep - break - - if entry_point_data: - print(f"Found entry point data: {list(entry_point_data.keys())}") - - # 解析资源参数 - parameters = entry_point_data.get('parameters', []) - print(f"Found {len(parameters)} parameters") - - for param in parameters: - print(f"Processing parameter: {param}") - resource = self._parse_resource(param) - if resource: - shader_info.resources.append(resource) - print(f"Added resource: {resource.name} ({resource.type.value})") - else: - print("No entry point data found in reflection, using manual parsing") - # 使用手动解析作为fallback - manual_resources = self._extract_resources_from_source(source) - shader_info.resources.extend(manual_resources) - - print(f"Shader {entry_name} has {len(shader_info.resources)} resources") - return shader_info - - def _create_shader_info_manual(self, entry_name: str, stage: ShaderStage, source: str) -> ShaderInfo: - """手动创建ShaderInfo(fallback方法)""" - print(f"Creating shader info manually for {entry_name}") - shader_info = ShaderInfo( - stage=stage, - entry_point=entry_name, - resources=[], - blob=b'', - ) - - # 手动解析资源 - resources = self._extract_resources_from_source(source) - shader_info.resources.extend(resources) - - print(f"Manual parsing found {len(resources)} resources") - return shader_info - - def _extract_resources_from_source(self, source: str) -> List[Resource]: - """从源码中提取资源声明""" - resources = [] - - # 资源声明的正则表达式模式 - patterns = { - # Texture resources - ResourceType.SAMPLED_TEXTURE: [ - r'Texture2D\s*(?:<[^>]*>)?\s+(\w+)', - r'Texture3D\s*(?:<[^>]*>)?\s+(\w+)', - r'TextureCube\s*(?:<[^>]*>)?\s+(\w+)', - ], - ResourceType.STORAGE_TEXTURE: [ - r'RWTexture2D\s*(?:<[^>]*>)?\s+(\w+)', - r'RWTexture3D\s*(?:<[^>]*>)?\s+(\w+)', - ], - # Buffer resources - ResourceType.STORAGE_BUFFER: [ - r'RWStructuredBuffer\s*<[^>]*>\s+(\w+)', - r'RWByteAddressBuffer\s+(\w+)', - r'StructuredBuffer\s*<[^>]*>\s+(\w+)', - r'ByteAddressBuffer\s+(\w+)', - ], - ResourceType.UNIFORM_BUFFER: [ - r'ConstantBuffer\s*<[^>]*>\s+(\w+)', - r'cbuffer\s+(\w+)', - ], - ResourceType.SAMPLER: [ - r'SamplerState\s+(\w+)', - r'SamplerComparisonState\s+(\w+)', - ] - } - - for resource_type, type_patterns in patterns.items(): - for pattern in type_patterns: - matches = re.findall(pattern, source, re.IGNORECASE) - for match in matches: - resources.append(Resource(match, resource_type)) - print(f"Found resource: {match} ({resource_type.value})") - - return resources - - def _parse_resource(self, param: dict) -> Optional[Resource]: - """解析资源参数""" - type_info = param.get('type', {}) - - kind = type_info.get('kind', '') - if kind == 'resource': - kind = type_info.get('baseShape', '') - param_name = param.get('name', '') - resource = Resource(param_name, ResourceType.STORAGE_BUFFER) - resource.binding_index = param.get('binding', {}).get('index', -1) - - # 判断资源类型 - if 'texture' in kind: - if 'RW' in kind: - resource.type = ResourceType.STORAGE_TEXTURE - else: - resource.type = ResourceType.SAMPLED_TEXTURE - - if '2D' in kind: - resource.sub_type = ResourceSubType.TEXTURE_2D - elif 'cube' in kind: - resource.sub_type = ResourceSubType.TEXTURE_CUBE - elif '3D' in kind: - resource.sub_type = ResourceSubType.TEXTURE_3D - elif 'array' in kind: - resource.sub_type = ResourceSubType.TEXTURE_ARRAY - elif kind == 'structuredBuffer' or kind == 'byteAddressBuffer': - if 'RW' in kind: - resource.type = ResourceType.STORAGE_BUFFER - else: - resource.type = ResourceType.STORAGE_BUFFER - elif kind == 'constantBuffer' or kind == 'parameterBlock': - resource.type = ResourceType.UNIFORM_BUFFER - resource.uniform_data = ShaderLayoutParser.parse_uniform_buffer(param) - elif kind == 'samplerState' or 'sampler' in kind: - resource.type = ResourceType.SAMPLER - - return resource \ No newline at end of file + return None \ No newline at end of file diff --git a/tools/shader_reflection_type.py b/tools/shader_reflection_type.py new file mode 100644 index 0000000..3843a5d --- /dev/null +++ b/tools/shader_reflection_type.py @@ -0,0 +1,429 @@ +from dataclasses import dataclass, field +from enum import Enum +from typing import List, Optional, Union, Dict, Any + + +# 枚举类定义 +class TargetFormat(Enum): + GLSL = "glsl" + DXBC = "dxbc" + MSL = "msl" + HLSL_DX11 = "hlsl" + +class BindingKind(Enum): + SUB_ELEMENT_REGISTER_SPACE = "subElementRegisterSpace" + UNIFORM = "uniform" + TEXTURE = "texture" + STORAGE_BUFFER = "storageBuffer" + CONSTANT_BUFFER = "constantBuffer" + SAMPLER = "sampler" + VARYING_INPUT = "varyingInput" + VARYING_OUTPUT = "varyingOutput" + + +class TypeKind(Enum): + PARAMETER_BLOCK = "parameterBlock" + STRUCT = "struct" + MATRIX = "matrix" + SCALAR = "scalar" + VECTOR = "vector" + + +class ScalarType(Enum): + INT32 = "int32" + UINT32 = "uint32" + INT8 = "int8" + UINT8 = "uint8" + INT16 = "int16" + UINT16 = "uint16" + FLOAT16 = "float16" + FLOAT32 = "float32" + + +class Stage(Enum): + VERTEX = "vertex" + FRAGMENT = "fragment" + PIXEL = "pixel" + COMPUTE = "compute" + + +# 基础数据类 +@dataclass +class Binding: + kind: BindingKind + index: Optional[int] = None + offset: Optional[int] = None + size: Optional[int] = None + count: Optional[int] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'Binding': + return cls( + kind=BindingKind(data['kind']), + index=data.get('index'), + offset=data.get('offset'), + size=data.get('size'), + count=data.get('count') + ) + + +@dataclass +class ScalarTypeInfo: + kind: TypeKind + scalarType: ScalarType + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ScalarTypeInfo': + return cls( + kind=TypeKind(data['kind']), + scalarType=ScalarType(data['scalarType']) + ) + + +@dataclass +class MatrixType: + kind: TypeKind + rowCount: int + columnCount: int + elementType: ScalarTypeInfo + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'MatrixType': + return cls( + kind=TypeKind(data['kind']), + rowCount=data['rowCount'], + columnCount=data['columnCount'], + elementType=ScalarTypeInfo.from_dict(data['elementType']) + ) + + +@dataclass +class VectorType: + kind: TypeKind + elementCount: int + elementType: ScalarTypeInfo + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'VectorType': + return cls( + kind=TypeKind(data['kind']), + elementCount=data['elementCount'], + elementType=ScalarTypeInfo.from_dict(data['elementType']) + ) + + +@dataclass +class Field: + name: str + type: Union[MatrixType, VectorType, ScalarTypeInfo, 'StructType'] + binding: Optional[Binding] = None + stage: Optional[Stage] = None + semanticName: Optional[str] = None + semanticIndex: Optional[int] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'Field': + # 根据type的kind来决定如何解析 + type_data = data['type'] + type_kind = type_data['kind'] + + if type_kind == 'matrix': + type_obj = MatrixType.from_dict(type_data) + elif type_kind == 'vector': + type_obj = VectorType.from_dict(type_data) + elif type_kind == 'scalar': + type_obj = ScalarTypeInfo.from_dict(type_data) + elif type_kind == 'struct': + type_obj = StructType.from_dict(type_data) + else: + raise ValueError(f"Unknown type kind: {type_kind}") + + return cls( + name=data['name'], + type=type_obj, + binding=Binding.from_dict(data['binding']) if 'binding' in data else None, + stage=Stage(data['stage']) if 'stage' in data else None, + semanticName=data.get('semanticName'), + semanticIndex=data.get('semanticIndex') + ) + + def get_sg_format(self): + """获取对应的sg_format""" + out = 'SG_VERTEXFORMAT_' + scalar_type = ScalarType.FLOAT32 + element_count = 1 + + if isinstance(self.type, ScalarTypeInfo): + scalar_type = self.type.scalarType + elif isinstance(self.type, VectorType): + scalar_type = self.type.elementType.scalarType + + if isinstance(self.type, VectorType): + element_count = self.type.elementCount + + if scalar_type == ScalarType.FLOAT32: + out += 'FLOAT' + elif scalar_type == ScalarType.INT32: + out += 'INT' + elif scalar_type == ScalarType.UINT32: + out += 'UINT' + elif scalar_type == ScalarType.INT8: + out += 'BYTE' + elif scalar_type == ScalarType.UINT8: + out += 'UBYTE' + elif scalar_type == ScalarType.INT16: + out += 'SHORT' + elif scalar_type == ScalarType.UINT16: + out += 'USHORT' + elif scalar_type == ScalarType.FLOAT16: + out += 'HALF' + else: + raise ValueError(f"Unsupported scalar type: {scalar_type}") + + if element_count > 1: + out += str(element_count) + return out + + def get_byte_size(self): + """获取字段的字节大小""" + if isinstance(self.type, ScalarTypeInfo): + return 4 if self.type.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2 + elif isinstance(self.type, VectorType): + return self.type.elementCount * (4 if self.type.elementType.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2) + elif isinstance(self.type, MatrixType): + return self.type.rowCount * self.type.columnCount * (4 if self.type.elementType.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2) + elif isinstance(self.type, StructType): + return sum(f.get_byte_size() for f in self.type.fields) + return 0 + +@dataclass +class StructType: + kind: TypeKind + name: str + fields: List[Field] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'StructType': + fields = [Field.from_dict(f) for f in data['fields']] + # 计算offset + offset = 0 + for field in fields: + if field.binding is not None: + field.binding.offset = offset + offset += field.get_byte_size() + + return cls( + kind=TypeKind(data['kind']), + name=data['name'], + fields=fields + ) + + def get_byte_size(self) -> int: + """获取结构体的字节大小""" + return sum(field.get_byte_size() for field in self.fields) + + +@dataclass +class VarLayout: + type: StructType + binding: Binding + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'VarLayout': + return cls( + type=StructType.from_dict(data['type']), + binding=Binding.from_dict(data['binding']) + ) + + +@dataclass +class ContainerVarLayout: + bindings: List[Binding] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ContainerVarLayout': + return cls( + bindings=[Binding.from_dict(b) for b in data['bindings']] + ) + + +@dataclass +class ParameterBlockType: + kind: TypeKind + elementType: StructType + containerVarLayout: ContainerVarLayout + elementVarLayout: VarLayout + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ParameterBlockType': + return cls( + kind=TypeKind(data['kind']), + elementType=StructType.from_dict(data['elementType']), + containerVarLayout=ContainerVarLayout.from_dict(data['containerVarLayout']), + elementVarLayout=VarLayout.from_dict(data['elementVarLayout']) + ) + + +@dataclass +class Parameter: + name: str + binding: Binding + type: Union[ParameterBlockType, StructType] + stage: Optional[Stage] = None + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'Parameter': + type_data = data['type'] + if type_data['kind'] == 'struct': + type_obj = StructType.from_dict(type_data) + else: + type_obj = ParameterBlockType.from_dict(type_data) + + return cls( + name=data['name'], + binding=Binding.from_dict(data['binding']), + type=type_obj, + stage=Stage(data['stage']) if 'stage' in data else None + ) + + def get_byte_size(self) -> int: + """获取参数的字节大小""" + if isinstance(self.type, ParameterBlockType): + return self.type.elementVarLayout.type.get_byte_size() + elif isinstance(self.type, StructType): + return sum(field.get_byte_size() for field in self.type.fields) + return 0 + + def get_binding_kind(self) -> BindingKind: + kind = self.binding.kind + if kind == BindingKind.SUB_ELEMENT_REGISTER_SPACE: + kind = BindingKind.UNIFORM + return kind + + def get_register_index(self): + """获取HLSL注册索引""" + if self.binding.index is not None: + return self.binding.index + return 0 + +@dataclass +class EntryPointBinding: + name: str + binding: Binding + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'EntryPointBinding': + return cls( + name=data['name'], + binding=Binding.from_dict(data['binding']) + ) + + +@dataclass +class EntryPointResult: + stage: Stage + binding: Binding + type: Union[StructType, VectorType] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'EntryPointResult': + kind = data['type']['kind'] + if kind == 'struct': + type_obj = StructType.from_dict(data['type']) + elif kind == 'vector': + type_obj = VectorType.from_dict(data['type']) + else: + raise ValueError(f"Unknown result type kind: {kind}") + + return cls( + stage=Stage(data['stage']), + binding=Binding.from_dict(data['binding']), + type=type_obj + ) + + +@dataclass +class EntryPoint: + name: str + stage: Stage + parameters: List[Parameter] + result: EntryPointResult + bindings: List[EntryPointBinding] + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'EntryPoint': + return cls( + name=data['name'], + stage=Stage(data['stage']), + parameters=[Parameter.from_dict(p) for p in data['parameters']], + result=EntryPointResult.from_dict(data['result']), + bindings=[EntryPointBinding.from_dict(b) for b in data['bindings']] + ) + + +@dataclass +class ShaderReflection: + parameters: List[Parameter] + entryPoints: List[EntryPoint] + blob: Optional[bytes] = None # 用于存储编译后的着色器二进制数据 + + @classmethod + def from_dict(cls, data: Dict[str, Any]) -> 'ShaderReflection': + return cls( + parameters=[Parameter.from_dict(p) for p in data['parameters']], + entryPoints=[EntryPoint.from_dict(e) for e in data['entryPoints']] + ) + + @classmethod + def get_vertex_layout(cls, reflection: 'ShaderReflection') -> Optional[List[Field]]: + """获取顶点布局字段""" + for entry in reflection.entryPoints: + if entry.stage == Stage.VERTEX: + vertex_fields = [] + for param in entry.parameters: + if param.stage == Stage.VERTEX and param.type.kind == TypeKind.STRUCT: + vertex_fields.extend(param.type.fields) + return vertex_fields + return None + +@dataclass +class ShaderStageInfo: + parameters: List[Parameter] + entryPoint: EntryPoint + blob: Optional[bytes] = None # 用于存储编译后的着色器二进制数据 + + def __init__(self, parameters: List[Parameter], entryPoint: EntryPoint, blob: Optional[bytes] = None): + self.parameters = parameters + self.entryPoint = entryPoint + self.blob = blob + for p in self.parameters: + if p.stage is None: + p.stage = entryPoint.stage + else: + assert p.stage == entryPoint.stage, f"Parameter stage {p.stage} does not match entry point stage {entryPoint.stage}" + + def get_entry_name(self): + return self.entryPoint.name + def get_stage(self): + return self.entryPoint.stage + +@dataclass +class ShaderInfos: + """着色器信息集合""" + stages: dict[Stage, ShaderStageInfo] + vertex_layout: List[Field] = field(default_factory=list) + vertex_size: int = 0 + + def add_shader_info(self, shader_info: ShaderReflection): + """添加单个着色器信息""" + vertex_layout = ShaderReflection.get_vertex_layout(shader_info) + if vertex_layout is not None: + self.vertex_layout = vertex_layout + self.vertex_size = sum(field.get_byte_size() for field in vertex_layout) + for entry in shader_info.entryPoints: + stage = entry.stage + self.stages[stage] = ShaderStageInfo( + parameters=shader_info.parameters, + entryPoint=entry, + blob=shader_info.blob + ) diff --git a/tools/shader_types.py b/tools/shader_types.py deleted file mode 100644 index b45889e..0000000 --- a/tools/shader_types.py +++ /dev/null @@ -1,119 +0,0 @@ -#!/usr/bin/env python3 -""" -Slang Compiler - Type Definitions -数据类型和枚举定义 -""" - -from enum import Enum -from dataclasses import dataclass, field -from typing import List, Optional, Dict, Mapping - - -class ShaderStage(Enum): - VERTEX = "vertex" - FRAGMENT = "fragment" - COMPUTE = "compute" - -class ResourceType(Enum): - SAMPLED_TEXTURE = "sampled_texture" - STORAGE_TEXTURE = "storage_texture" - STORAGE_BUFFER = "storage_buffer" - UNIFORM_BUFFER = "uniform_buffer" - SAMPLER = "sampler" - -class ResourceSubType(Enum): - TEXTURE_2D = "SG_IMAGETYPE_2D" - TEXTURE_CUBE = "SG_IMAGETYPE_CUBE" - TEXTURE_3D = "SG_IMAGETYPE_3D" - TEXTURE_ARRAY = "SG_IMAGETYPE_ARRAY" - -class TargetFormat(Enum): - SPIRV = "spirv" - DXBC = "dxbc" - MSL = "msl" - HLSL_DX11 = "hlsl" - -# 数据模型类 -@dataclass -class FieldType: - """字段类型信息""" - kind: str # 'scalar', 'vector', 'matrix' - scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16' - element_count: Optional[int] = None # for vector - row_count: Optional[int] = None # for matrix - column_count: Optional[int] = None # for matrix - size: int = 0 - - @classmethod - def from_dict(cls, data: Dict) -> 'FieldType': - """从字典创建FieldType对象""" - kind = data.get('kind') - - if kind == 'vector': - return cls( - kind=kind, - scalar_type=data['elementType']['scalarType'], - element_count=data['elementCount'] - ) - elif kind == 'scalar': - return cls( - kind=kind, - scalar_type=data['scalarType'] - ) - elif kind == 'matrix': - return cls( - kind=kind, - scalar_type=data['elementType']['scalarType'], - row_count=data['rowCount'], - column_count=data['columnCount'] - ) - - return cls(kind='scalar', scalar_type='float32') - -@dataclass -class UniformField: - """Uniform缓冲区字段""" - name: str - type: FieldType - offset: int - size: int - -@dataclass -class UniformBuffer: - """Uniform缓冲区""" - name: str - binding: int - size: int = 0 - fields: List[UniformField] = field(default_factory=list) - -@dataclass -class Resource: - name: str - type: ResourceType - sub_type: Optional[ResourceSubType] = None - binding_index: int = -1 - field_type: Optional[FieldType] = None - uniform_data: Optional[UniformBuffer] = None - -@dataclass -class ShaderInfo: - stage: ShaderStage - entry_point: str - blob: bytes - resources: List[Resource] - -@dataclass -class VertexField: - """顶点输入字段""" - name: str - type: FieldType - location: int - semantic: str - semantic_index: int - offset: int = 0 - -@dataclass -class ShaderLayout: - """着色器布局数据""" - vertex_fields: List[VertexField] = field(default_factory=list) - uniform_buffers: List[UniformBuffer] = field(default_factory=list)