重构80%
This commit is contained in:
@@ -7,6 +7,7 @@
|
||||
#include "shaders/mirage_text.hlsl.h"
|
||||
#include "shaders/mirage_wireframe.hlsl.h"
|
||||
#include "shaders/mirage_line.hlsl.h"
|
||||
#include "shaders/mirage_line.shader.h"
|
||||
|
||||
template<typename Derived>
|
||||
void compute_rect_vertices(const Eigen::MatrixBase<Derived>& in_pos,
|
||||
|
||||
@@ -5,53 +5,19 @@ Slang Compiler - Code Generator
|
||||
"""
|
||||
from pathlib import Path
|
||||
|
||||
from code_generator_helper import *
|
||||
from global_vars import global_vars
|
||||
from indent_manager import IndentManager
|
||||
from shader_types import *
|
||||
from shader_reflection_type import *
|
||||
|
||||
|
||||
class CodeGenerator:
|
||||
"""C++绑定代码生成器"""
|
||||
|
||||
def generate_binding_functions(self, binding_infos: List[ShaderInfo], output_path: str) -> None:
|
||||
def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str) -> None:
|
||||
"""生成C++绑定函数的入口方法"""
|
||||
self._generate_cpp_bindings(binding_infos, output_path)
|
||||
|
||||
def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
|
||||
"""生成Uniform缓冲区结构体"""
|
||||
writer.write("// Uniform buffer structures")
|
||||
|
||||
for buffer in uniform_buffers:
|
||||
struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer'
|
||||
|
||||
# 计算总大小和对齐要求
|
||||
total_size = 0
|
||||
max_alignment = 16 # GPU通常要求16字节对齐
|
||||
|
||||
with writer.block(f'struct {struct_name}', '};'):
|
||||
for i, field in enumerate(buffer.fields):
|
||||
# 检查是否需要填充
|
||||
if field.offset > total_size:
|
||||
padding_size = field.offset - total_size
|
||||
writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding")
|
||||
total_size = field.offset
|
||||
|
||||
# 生成字段声明
|
||||
declaration = get_c_type_declaration(field.type, field.name)
|
||||
writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")
|
||||
|
||||
total_size = field.offset + field.size
|
||||
|
||||
# 确保结构体大小正确对齐
|
||||
aligned_size = get_aligned_size(total_size, max_alignment)
|
||||
if aligned_size > total_size:
|
||||
writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")
|
||||
|
||||
writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})")
|
||||
writer.write()
|
||||
|
||||
def _generate_cpp_bindings(self, binding_infos: List[ShaderInfo], output_path: str) -> None:
|
||||
def _generate_cpp_bindings(self, binding_infos: ShaderInfos, output_path: str) -> None:
|
||||
"""生成C++绑定函数"""
|
||||
output_file = Path(output_path)
|
||||
# 尝试创建输出目录
|
||||
@@ -61,7 +27,7 @@ class CodeGenerator:
|
||||
writer = IndentManager(file)
|
||||
self._write_complete_file(writer, binding_infos)
|
||||
|
||||
def _write_complete_file(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
|
||||
def _write_complete_file(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
|
||||
"""写入完整的文件内容"""
|
||||
self._write_header(writer)
|
||||
|
||||
@@ -81,19 +47,19 @@ class CodeGenerator:
|
||||
writer.write(header)
|
||||
writer.write()
|
||||
|
||||
def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
|
||||
def _write_shader_bindings_class(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
|
||||
"""写入ShaderBindings类"""
|
||||
self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
|
||||
# self._generate_uniform_structs(writer, global_vars.vertex_layout.uniform_buffers)
|
||||
self._write_blob(writer, binding_infos)
|
||||
self._write_public_methods(writer, binding_infos)
|
||||
self._write_get_pipeline_desc_method(writer, binding_infos)
|
||||
|
||||
def _write_blob(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
|
||||
def _write_blob(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
|
||||
"""写入public部分的结构体定义"""
|
||||
|
||||
# 写入二进制内容
|
||||
for info in binding_infos:
|
||||
entry_point = info.entry_point
|
||||
for stage, info in binding_infos.stages.items():
|
||||
entry_point = info.get_entry_name()
|
||||
blob_data = info.blob
|
||||
|
||||
with writer.block(f'static constexpr std::array<uint8_t, {len(blob_data)}> {global_vars.source_file_name}_{entry_point}_blob =', '};'):
|
||||
@@ -108,69 +74,100 @@ class CodeGenerator:
|
||||
writer.write()
|
||||
writer.write()
|
||||
|
||||
def _write_public_methods(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
|
||||
def _write_public_methods(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
|
||||
"""写入公共方法"""
|
||||
self._write_get_shader_info_method(writer, binding_infos)
|
||||
|
||||
def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
|
||||
def _write_get_shader_info_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
|
||||
"""写入getShaderInfo方法"""
|
||||
with writer.indent(f'static sg_shader_desc get_{global_vars.source_file_name}_shader_desc()'' {', '}'):
|
||||
writer.write('sg_shader_desc desc = {};')
|
||||
writer.write(f'desc.label = "{global_vars.source_file_name}";')
|
||||
writer.write(f'desc.label = "{global_vars.source_file_name}_shader_desc";')
|
||||
writer.write('// 顶点布局')
|
||||
with writer.indent('{', '}'):
|
||||
for i, vertex_field in enumerate(global_vars.layout.vertex_fields):
|
||||
writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semantic}";')
|
||||
writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semantic_index};')
|
||||
for i, vertex_field in enumerate(binding_infos.vertex_layout):
|
||||
writer.write(f'desc.attrs[{i}].hlsl_sem_name = "{vertex_field.semanticName}";')
|
||||
writer.write(f'desc.attrs[{i}].hlsl_sem_index = {vertex_field.semanticIndex or 0};')
|
||||
|
||||
for info in binding_infos:
|
||||
entry_point = info.entry_point
|
||||
for stage, info in binding_infos.stages.items():
|
||||
entry_point = info.get_entry_name()
|
||||
entry_point_name = f'{global_vars.source_file_name}_{entry_point}_blob'
|
||||
stage_name = info.stage.name
|
||||
stage_name = info.get_stage().name
|
||||
writer.write(f'// {stage_name}')
|
||||
with writer.indent('{','}'):
|
||||
writer.write(f'desc.{stage_name.lower()}_func.bytecode = SG_RANGE({entry_point_name})')
|
||||
writer.write(f'desc.{stage_name.lower()}_func.bytecode = SG_RANGE({entry_point_name});')
|
||||
writer.write(f'desc.{stage_name.lower()}_func.entry = "{entry_point}";')
|
||||
|
||||
writer.write('// 资源绑定')
|
||||
resource_index_map = {}
|
||||
# 处理资源绑定
|
||||
for resource in info.resources:
|
||||
binding_index = resource.binding_index
|
||||
param_index = resource_index_map.get(resource.type, 0)
|
||||
resource_index_map[resource.type] = param_index + 1
|
||||
self._write_resource_binding(writer, info.parameters)
|
||||
|
||||
if resource.type == ResourceType.UNIFORM_BUFFER:
|
||||
uniform_info = resource.uniform_data
|
||||
writer.write(f'desc.uniform_blocks[{param_index}].name = "{resource.name}";')
|
||||
writer.write(f'desc.uniform_blocks[{param_index}].size = {uniform_info.size};')
|
||||
writer.write(f'desc.uniform_blocks[{param_index}].hlsl_register_b_n = {uniform_info.binding};')
|
||||
writer.write(f'desc.uniform_blocks[{param_index}].msl_buffer_n = {uniform_info.binding};')
|
||||
writer.write(f'desc.uniform_blocks[{param_index}].wgsl_group0_binding_n = {uniform_info.binding};')
|
||||
def _write_resource_binding(self, writer: IndentManager, parameters: List[Parameter]) -> None:
|
||||
resource_index = {}
|
||||
target = global_vars.target
|
||||
glsl_block_index = {}
|
||||
|
||||
elif resource.type == ResourceType.SAMPLED_TEXTURE:
|
||||
writer.write(f'desc.images[{param_index}].stage = SG_SHADERSTAGE_{stage_name.upper()}')
|
||||
writer.write(f'desc.images[{param_index}].name = "{resource.name}";')
|
||||
writer.write(f'desc.sampled_textures[{param_index}].name = "{resource.name}";')
|
||||
for p in parameters:
|
||||
binding_kind = p.get_binding_kind()
|
||||
index = resource_index.get(binding_kind, 0)
|
||||
resource_index[binding_kind] = index + 1
|
||||
stage_name = p.stage.name.upper()
|
||||
if binding_kind == BindingKind.UNIFORM:
|
||||
t = f'desc.uniform_blocks[{index}]'
|
||||
writer.write(f'{t}.stage = SG_SHADERSTAGE_{stage_name};')
|
||||
writer.write(f'{t}.size = {p.get_byte_size()};')
|
||||
if target == TargetFormat.GLSL:
|
||||
glsl_index = glsl_block_index.get(binding_kind, 0)
|
||||
glsl_block_index[binding_kind] = glsl_index + 1
|
||||
writer.write(f'{t}.glsl_uniforms[{glsl_index}].type = ;')
|
||||
writer.write(f'{t}.glsl_uniforms[{glsl_index}].array_count = 1;')
|
||||
writer.write(f'{t}.glsl_uniforms[{glsl_index}].glsl_name = "{p.name}";')
|
||||
elif target == TargetFormat.DXBC:
|
||||
writer.write(f'{t}.hlsl_register_b_n = {p.get_register_index()};')
|
||||
|
||||
elif resource.type == ResourceType.STORAGE_BUFFER:
|
||||
writer.write(f'desc.storage_buffers[{param_index}].name = "{resource.name}";')
|
||||
|
||||
elif resource.type == ResourceType.SAMPLER:
|
||||
writer.write(f'desc.samplers[{param_index}].name = "{resource.name}";')
|
||||
|
||||
def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: List[ShaderInfo]) -> None:
|
||||
def _write_get_pipeline_desc_method(self, writer: IndentManager, binding_infos: ShaderInfos) -> None:
|
||||
"""写入getPipelineDesc方法"""
|
||||
with writer.indent(f'static sg_pipeline_desc get_{global_vars.source_file_name}_pipeline_desc(sg_shader shader, sg_pixel_format pixel_format, int32_t sample_count)'' {', '}'):
|
||||
function_params = ('\n\t\t\tsg_shader shader, \n'
|
||||
'\t\t\tsg_pixel_format pixel_format, \n'
|
||||
'\t\t\tint32_t sample_count = 1, \n'
|
||||
'\t\t\tsg_primitive_type primitive_type = SG_PRIMITIVETYPE_TRIANGLES, \n'
|
||||
'\t\t\tsg_cull_mode cull_mode = SG_CULLMODE_NONE\n')
|
||||
|
||||
with writer.indent(f'static sg_pipeline_desc get_{global_vars.source_file_name}_pipeline_desc({function_params}\t)'' {', '}'):
|
||||
writer.write('sg_pipeline_desc desc = {};')
|
||||
writer.write(f'desc.label = "{global_vars.source_file_name}_pipeline";')
|
||||
|
||||
writer.write('desc.shader = shader;')
|
||||
writer.write('desc.index_type = SG_INDEXTYPE_UINT32;')
|
||||
writer.write()
|
||||
|
||||
# 处理顶点缓冲区大小
|
||||
writer.write(f'desc.layout.buffers[0].stride = {binding_infos.vertex_size};')
|
||||
writer.write('desc.layout.buffers[0].step_func = SG_VERTEXSTEP_PER_VERTEX;')
|
||||
writer.write('desc.layout.buffers[0].step_rate = 1;')
|
||||
writer.write()
|
||||
|
||||
# 处理顶点输入布局
|
||||
for i, vertex_field in enumerate(global_vars.layout.vertex_fields):
|
||||
writer.write(f'// {vertex_field.semantic}')
|
||||
writer.write(f'desc.layout.attrs[{i}].buffer_index = {vertex_field.location}')
|
||||
writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.offset}')
|
||||
writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.type.scalar_type}')
|
||||
for i, vertex_field in enumerate(binding_infos.vertex_layout):
|
||||
writer.write(f'// {vertex_field.semanticName} {vertex_field.semanticIndex or ''}')
|
||||
writer.write(f'desc.layout.attrs[{i}].buffer_index = 0;')
|
||||
writer.write(f'desc.layout.attrs[{i}].offset = {vertex_field.binding.offset};')
|
||||
writer.write(f'desc.layout.attrs[{i}].format = {vertex_field.get_sg_format()};')
|
||||
writer.write()
|
||||
|
||||
writer.write('desc.primitive_type = primitive_type;')
|
||||
writer.write('desc.cull_mode = cull_mode;')
|
||||
writer.write('desc.face_winding = SG_FACEWINDING_CW;')
|
||||
|
||||
writer.write('desc.depth.write_enabled = false;')
|
||||
writer.write('desc.depth.compare = SG_COMPAREFUNC_NEVER;')
|
||||
writer.write('desc.depth.pixel_format = SG_PIXELFORMAT_NONE;')
|
||||
|
||||
writer.write('desc.colors[0].blend.enabled = true;')
|
||||
writer.write('desc.colors[0].blend.src_factor_rgb = SG_BLENDFACTOR_SRC_ALPHA;')
|
||||
writer.write('desc.colors[0].blend.dst_factor_rgb = SG_BLENDFACTOR_ONE_MINUS_SRC_ALPHA;')
|
||||
writer.write('desc.colors[0].pixel_format = pixel_format;')
|
||||
writer.write('desc.colors[0].write_mask = SG_COLORMASK_RGBA;')
|
||||
|
||||
writer.write('desc.sample_count = sample_count;')
|
||||
writer.write(f'desc.label = "{global_vars.source_file_name}_pipeline_desc";')
|
||||
writer.write('return desc;')
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import Tuple
|
||||
|
||||
from shader_types import FieldType
|
||||
from shader_reflection_type import Field, TypeKind, ScalarType
|
||||
|
||||
format_mapping = {
|
||||
('int32', 1): 'SG_VERTEXFORMAT_INVALID',
|
||||
@@ -29,57 +29,56 @@ format_mapping = {
|
||||
|
||||
# 类型大小映射
|
||||
type_sizes = {
|
||||
'int32': 4,
|
||||
'uint32': 4,
|
||||
'float32': 4,
|
||||
'int8': 1,
|
||||
'uint8': 1,
|
||||
'int16': 2,
|
||||
'uint16': 2,
|
||||
'float16': 2,
|
||||
ScalarType.INT32: 4,
|
||||
ScalarType.UINT32: 4,
|
||||
ScalarType.FLOAT32: 4,
|
||||
ScalarType.INT8: 1,
|
||||
ScalarType.UINT8: 1,
|
||||
ScalarType.INT16: 2,
|
||||
ScalarType.UINT16: 2,
|
||||
ScalarType.FLOAT16: 2, # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
|
||||
}
|
||||
|
||||
def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
|
||||
# C类型映射表
|
||||
c_type_mapping = {
|
||||
ScalarType.INT32: 'int32_t',
|
||||
ScalarType.UINT32: 'uint32_t',
|
||||
ScalarType.FLOAT32: 'float',
|
||||
ScalarType.INT8: 'int8_t',
|
||||
ScalarType.UINT8: 'uint8_t',
|
||||
ScalarType.INT16: 'int16_t',
|
||||
ScalarType.UINT16: 'uint16_t',
|
||||
ScalarType.FLOAT16: 'mirage_float16', # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
|
||||
}
|
||||
|
||||
# 对齐要求映射表
|
||||
alignment_mapping = {
|
||||
ScalarType.INT32: 4,
|
||||
ScalarType.UINT32: 4,
|
||||
ScalarType.FLOAT32: 4,
|
||||
ScalarType.INT8: 1,
|
||||
ScalarType.UINT8: 1,
|
||||
ScalarType.INT16: 2,
|
||||
ScalarType.UINT16: 2,
|
||||
ScalarType.FLOAT16: 2,
|
||||
}
|
||||
|
||||
def get_type_info(field_type: Field) -> Tuple[str, int, int]:
|
||||
"""获取类型信息:C类型名称、元素数量、字节大小
|
||||
|
||||
Returns:
|
||||
Tuple[str, int, int]: (C类型名称, 元素数量, 总字节大小)
|
||||
"""
|
||||
# C类型映射表
|
||||
c_type_mapping = {
|
||||
'int32': 'int32_t',
|
||||
'uint32': 'uint32_t',
|
||||
'float32': 'float',
|
||||
'int8': 'int8_t',
|
||||
'uint8': 'uint8_t',
|
||||
'int16': 'int16_t',
|
||||
'uint16': 'uint16_t',
|
||||
'float16': 'mirage_float16', # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
|
||||
}
|
||||
kind = field_type.type.kind
|
||||
scalar_type = field_type.type.scalarType
|
||||
element_count = field_type.type.elementCount
|
||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||
|
||||
# 对齐要求映射表
|
||||
alignment_mapping = {
|
||||
'int32': 4,
|
||||
'uint32': 4,
|
||||
'float32': 4,
|
||||
'int8': 1,
|
||||
'uint8': 1,
|
||||
'int16': 2,
|
||||
'uint16': 2,
|
||||
'float16': 2,
|
||||
}
|
||||
|
||||
if field_type.kind == 'scalar':
|
||||
scalar_type = field_type.scalar_type or 'float32'
|
||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||
if kind == TypeKind.SCALAR:
|
||||
size = type_sizes.get(scalar_type, 4)
|
||||
return c_type, 1, size
|
||||
|
||||
elif field_type.kind == 'vector':
|
||||
scalar_type = field_type.scalar_type or 'float32'
|
||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||
element_count = field_type.element_count or 4
|
||||
|
||||
elif kind == TypeKind.VECTOR:
|
||||
# 验证向量元素数量的合法性
|
||||
if element_count not in [1, 2, 3, 4]:
|
||||
raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")
|
||||
@@ -93,12 +92,9 @@ def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
|
||||
|
||||
return c_type, element_count, size
|
||||
|
||||
elif field_type.kind == 'matrix':
|
||||
scalar_type = field_type.scalar_type or 'float32'
|
||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||
|
||||
row_count = field_type.row_count or 4
|
||||
column_count = field_type.column_count or 4
|
||||
elif kind == TypeKind.MATRIX:
|
||||
row_count = field_type.type.rowCount or 4
|
||||
column_count = field_type.type.columnCount or 4
|
||||
|
||||
# 验证矩阵维度的合法性
|
||||
if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]:
|
||||
@@ -125,8 +121,7 @@ def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
|
||||
print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.")
|
||||
return 'float', 1, 4
|
||||
|
||||
|
||||
def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
|
||||
def get_c_type_declaration(field_type: Field, field_name: str) -> str:
|
||||
"""生成C类型声明
|
||||
|
||||
Args:
|
||||
@@ -137,11 +132,12 @@ def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
|
||||
str: 完整的C类型声明
|
||||
"""
|
||||
c_type, count, _ = get_type_info(field_type)
|
||||
kind = field_type.type.kind
|
||||
|
||||
if field_type.kind == 'scalar':
|
||||
if kind == TypeKind.SCALAR:
|
||||
return f"{c_type} {field_name}"
|
||||
|
||||
elif field_type.kind == 'vector':
|
||||
elif kind == TypeKind.VECTOR:
|
||||
# 对于向量,可以选择使用数组或专门的向量类型
|
||||
if count == 1:
|
||||
return f"{c_type} {field_name}"
|
||||
@@ -152,9 +148,9 @@ def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
|
||||
# 选项2: 使用对齐的结构体(如果需要的话)
|
||||
# return f"struct {{ {c_type} data[{count}]; }} {field_name}"
|
||||
|
||||
elif field_type.kind == 'matrix':
|
||||
rows = field_type.row_count or 4
|
||||
cols = field_type.column_count or 4
|
||||
elif kind == TypeKind.MATRIX:
|
||||
rows = field_type.type.rowCount or 4
|
||||
cols = field_type.type.columnCount or 4
|
||||
|
||||
# 矩阵使用二维数组(列主序)
|
||||
return f"{c_type} {field_name}[{cols}][{rows}]"
|
||||
|
||||
@@ -8,7 +8,7 @@ from typing import List, Dict
|
||||
|
||||
from code_generator import CodeGenerator
|
||||
from shader_parser import ShaderParser
|
||||
from shader_types import ShaderInfo
|
||||
from shader_reflection_type import ShaderReflection, ShaderInfos
|
||||
|
||||
|
||||
class SlangCompiler:
|
||||
@@ -16,10 +16,10 @@ class SlangCompiler:
|
||||
self.parser = ShaderParser()
|
||||
self.code_generator = CodeGenerator()
|
||||
|
||||
def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
|
||||
def parse_slang_shader(self) -> List[ShaderReflection]:
|
||||
"""解析Slang着色器源码,提取资源信息"""
|
||||
return self.parser.parse_slang_shader()
|
||||
|
||||
def generate_binding_functions(self, binding_infos: List[ShaderInfo], output_path: str):
|
||||
def generate_binding_functions(self, binding_infos: ShaderInfos, output_path: str):
|
||||
"""生成C/C++绑定函数"""
|
||||
self.code_generator.generate_binding_functions(binding_infos, output_path)
|
||||
|
||||
@@ -4,21 +4,22 @@ Slang Compiler - Command Generation
|
||||
"""
|
||||
from exe_finder import slangc_path
|
||||
from global_vars import global_vars
|
||||
from shader_types import TargetFormat, ShaderStage
|
||||
from shader_reflection_type import *
|
||||
|
||||
def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, output_path: str):
|
||||
|
||||
def make_cmd(source_file: str, target: TargetFormat, stage: Stage, entry_point: str, output_path: str):
|
||||
"""生成编译命令"""
|
||||
target_flag = {
|
||||
TargetFormat.SPIRV: 'spirv',
|
||||
TargetFormat.GLSL: 'glsl',
|
||||
TargetFormat.DXBC: 'dxbc',
|
||||
TargetFormat.MSL: 'metal',
|
||||
TargetFormat.HLSL_DX11: 'hlsl',
|
||||
}[target]
|
||||
|
||||
stage_flag = {
|
||||
ShaderStage.VERTEX: 'vertex',
|
||||
ShaderStage.FRAGMENT: 'fragment',
|
||||
ShaderStage.COMPUTE: 'compute'
|
||||
Stage.VERTEX: 'vertex',
|
||||
Stage.FRAGMENT: 'fragment',
|
||||
Stage.COMPUTE: 'compute'
|
||||
}[stage]
|
||||
|
||||
cmd = [
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
from shader_types import *
|
||||
|
||||
from shader_reflection_type import *
|
||||
|
||||
class GlobalVars:
|
||||
source_file = ''
|
||||
source_file_name = ''
|
||||
source_path = ''
|
||||
output_dir = ''
|
||||
target: TargetFormat
|
||||
layout: ShaderLayout
|
||||
target = TargetFormat.DXBC
|
||||
|
||||
global_vars = GlobalVars()
|
||||
|
||||
@@ -31,6 +31,10 @@ def main():
|
||||
global_vars.output_dir = os.path.abspath(args.output_dir)
|
||||
global_vars.target = TargetFormat(args.target)
|
||||
|
||||
shader_infos = ShaderInfos(
|
||||
stages={}
|
||||
)
|
||||
|
||||
# 仅保留路径部分
|
||||
include_dirs = [
|
||||
global_vars.source_path,
|
||||
@@ -45,12 +49,8 @@ def main():
|
||||
# 解析着色器
|
||||
print(f"**Parsing** {args.input}...")
|
||||
shaders = compiler.parse_slang_shader()
|
||||
|
||||
# 编译每个入口点
|
||||
shader_infos = []
|
||||
|
||||
for name, shader_info in shaders.items():
|
||||
shader_infos.append(shader_info)
|
||||
for shader_info in shaders:
|
||||
shader_infos.add_shader_info(shader_info)
|
||||
|
||||
binding_output_file_pathname = os.path.abspath(args.output_dir)
|
||||
binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{global_vars.source_file_name}.shader.h")
|
||||
|
||||
@@ -1,85 +0,0 @@
|
||||
import json
|
||||
from typing import List, Dict
|
||||
|
||||
from shader_types import *
|
||||
from shader_types import UniformBuffer
|
||||
|
||||
|
||||
class ShaderLayoutParser:
|
||||
"""解析JSON数据并提取到类对象"""
|
||||
|
||||
def parse(self, json_file: str) -> ShaderLayout:
|
||||
"""解析JSON文件并返回ShaderLayout对象"""
|
||||
with open(json_file, 'r') as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
layout = ShaderLayout()
|
||||
layout.vertex_fields = self.parse_vertex_input(json_data)
|
||||
layout.uniform_buffers = self.parse_uniform_buffers(json_data)
|
||||
|
||||
return layout
|
||||
|
||||
@staticmethod
|
||||
def parse_vertex_input(json_data: Dict) -> List[VertexField]:
|
||||
"""解析顶点输入字段"""
|
||||
vertex_fields = []
|
||||
|
||||
entry_points = json_data.get('entryPoints', [])
|
||||
for entry in entry_points:
|
||||
if entry.get('stage') == 'vertex':
|
||||
parameters = entry.get('parameters', [])
|
||||
offset = entry.get('offset', 0)
|
||||
for param in parameters:
|
||||
if param.get('name') == 'input' and param.get('stage') == 'vertex':
|
||||
fields = param.get('type', {}).get('fields', [])
|
||||
for field in fields:
|
||||
vertex_field = VertexField(
|
||||
name=field['name'],
|
||||
type=FieldType.from_dict(field['type']),
|
||||
location=field['binding']['index'],
|
||||
semantic=field.get('semanticName', ''),
|
||||
semantic_index=field.get('semanticIndex', 0)
|
||||
)
|
||||
vertex_fields.append(vertex_field)
|
||||
|
||||
return vertex_fields
|
||||
|
||||
@staticmethod
|
||||
def parse_uniform_buffer(param: Dict) -> UniformBuffer | None:
|
||||
binding = param.get('binding', {})
|
||||
type_info = param.get('type', {})
|
||||
kind = type_info.get('kind', '')
|
||||
if kind == 'constantBuffer' or kind == 'parameterBlock':
|
||||
buffer = UniformBuffer(
|
||||
name=param['name'],
|
||||
binding=binding['index']
|
||||
)
|
||||
|
||||
# 解析缓冲区字段
|
||||
element_type = param.get('type', {}).get('elementType', {})
|
||||
if element_type.get('kind') == 'struct':
|
||||
fields = element_type.get('fields', [])
|
||||
for field in fields:
|
||||
uniform_field = UniformField(
|
||||
name=field['name'],
|
||||
type=FieldType.from_dict(field['type']),
|
||||
offset=field['binding']['offset'],
|
||||
size=field['binding']['size']
|
||||
)
|
||||
buffer.size += uniform_field.size
|
||||
buffer.fields.append(uniform_field)
|
||||
|
||||
return buffer
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_uniform_buffers(json_data: Dict) -> List[UniformBuffer]:
|
||||
"""解析Uniform缓冲区"""
|
||||
uniform_buffers = []
|
||||
|
||||
parameters = json_data.get('parameters', [])
|
||||
for param in parameters:
|
||||
buffer = ShaderLayoutParser.parse_uniform_buffer(param)
|
||||
uniform_buffers.append(buffer)
|
||||
|
||||
return uniform_buffers
|
||||
@@ -13,12 +13,11 @@ from typing import List, Dict, Optional
|
||||
|
||||
from compiler_cmd import make_cmd
|
||||
from global_vars import global_vars
|
||||
from shader_layout import ShaderLayoutParser
|
||||
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo, UniformBuffer, FieldType, ResourceSubType
|
||||
from shader_reflection_type import *
|
||||
|
||||
|
||||
class ShaderParser:
|
||||
def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
|
||||
def parse_slang_shader(self) -> List[ShaderReflection]:
|
||||
"""解析Slang着色器源码,提取资源信息"""
|
||||
|
||||
with open(global_vars.source_file, 'r', encoding='utf-8') as f:
|
||||
@@ -28,7 +27,7 @@ class ShaderParser:
|
||||
entry_points = self._find_entry_points(source)
|
||||
print(f"Found potential entry points: {entry_points}")
|
||||
|
||||
shaders = {}
|
||||
shaders = []
|
||||
|
||||
# 为每个入口点单独进行完整编译和反射
|
||||
for entry_name, stage in entry_points.items():
|
||||
@@ -39,22 +38,22 @@ class ShaderParser:
|
||||
)
|
||||
|
||||
if shader_info:
|
||||
shaders[entry_name] = shader_info
|
||||
shaders.append(shader_info)
|
||||
else:
|
||||
print(f"Failed to process entry point: {entry_name}")
|
||||
|
||||
return shaders
|
||||
|
||||
def _find_entry_points(self, source: str) -> Dict[str, ShaderStage]:
|
||||
def _find_entry_points(self, source: str) -> Dict[str, Stage]:
|
||||
"""在源码中查找入口点函数"""
|
||||
entry_points = {}
|
||||
|
||||
# 1. 查找带有Slang属性的函数
|
||||
attribute_patterns = [
|
||||
(r'$$shader\s*$\s*["\']vertex["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.VERTEX),
|
||||
(r'$$shader\s*$\s*["\']fragment["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.FRAGMENT),
|
||||
(r'$$shader\s*$\s*["\']pixel["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.FRAGMENT),
|
||||
(r'$$shader\s*$\s*["\']compute["\']\s*$\s*$$.*?(\w+)\s*\(', ShaderStage.COMPUTE),
|
||||
(r'$$shader\s*$\s*["\']vertex["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.VERTEX),
|
||||
(r'$$shader\s*$\s*["\']fragment["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.FRAGMENT),
|
||||
(r'$$shader\s*$\s*["\']pixel["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.FRAGMENT),
|
||||
(r'$$shader\s*$\s*["\']compute["\']\s*$\s*$$.*?(\w+)\s*\(', Stage.COMPUTE),
|
||||
]
|
||||
|
||||
for pattern, stage in attribute_patterns:
|
||||
@@ -66,29 +65,29 @@ class ShaderParser:
|
||||
# 2. 查找常见的命名约定
|
||||
common_entry_points = {
|
||||
# Vertex shaders
|
||||
'vertex_main': ShaderStage.VERTEX,
|
||||
'vertexMain': ShaderStage.VERTEX,
|
||||
'vert_main': ShaderStage.VERTEX,
|
||||
'vs_main': ShaderStage.VERTEX,
|
||||
'vertex_shader': ShaderStage.VERTEX,
|
||||
'VS': ShaderStage.VERTEX,
|
||||
'vertex_main': Stage.VERTEX,
|
||||
'vertexMain': Stage.VERTEX,
|
||||
'vert_main': Stage.VERTEX,
|
||||
'vs_main': Stage.VERTEX,
|
||||
'vertex_shader': Stage.VERTEX,
|
||||
'VS': Stage.VERTEX,
|
||||
|
||||
# Fragment shaders
|
||||
'fragment_main': ShaderStage.FRAGMENT,
|
||||
'fragmentMain': ShaderStage.FRAGMENT,
|
||||
'frag_main': ShaderStage.FRAGMENT,
|
||||
'pixel_main': ShaderStage.FRAGMENT,
|
||||
'ps_main': ShaderStage.FRAGMENT,
|
||||
'fragment_shader': ShaderStage.FRAGMENT,
|
||||
'PS': ShaderStage.FRAGMENT,
|
||||
'fragment_main': Stage.FRAGMENT,
|
||||
'fragmentMain': Stage.FRAGMENT,
|
||||
'frag_main': Stage.FRAGMENT,
|
||||
'pixel_main': Stage.FRAGMENT,
|
||||
'ps_main': Stage.FRAGMENT,
|
||||
'fragment_shader': Stage.FRAGMENT,
|
||||
'PS': Stage.FRAGMENT,
|
||||
|
||||
# Compute shaders
|
||||
'compute_main': ShaderStage.COMPUTE,
|
||||
'computeMain': ShaderStage.COMPUTE,
|
||||
'comp_main': ShaderStage.COMPUTE,
|
||||
'cs_main': ShaderStage.COMPUTE,
|
||||
'compute_shader': ShaderStage.COMPUTE,
|
||||
'CS': ShaderStage.COMPUTE,
|
||||
'compute_main': Stage.COMPUTE,
|
||||
'computeMain': Stage.COMPUTE,
|
||||
'comp_main': Stage.COMPUTE,
|
||||
'cs_main': Stage.COMPUTE,
|
||||
'compute_shader': Stage.COMPUTE,
|
||||
'CS': Stage.COMPUTE,
|
||||
}
|
||||
|
||||
# 查找这些函数名是否在源码中存在
|
||||
@@ -112,22 +111,22 @@ class ShaderParser:
|
||||
|
||||
# 顶点着色器关键词
|
||||
if any(keyword in func_lower for keyword in ['vert', 'vs', 'vertex']):
|
||||
entry_points[func_name] = ShaderStage.VERTEX
|
||||
entry_points[func_name] = Stage.VERTEX
|
||||
print(f"Inferred vertex shader: {func_name}")
|
||||
# 片元着色器关键词
|
||||
elif any(keyword in func_lower for keyword in ['frag', 'pixel', 'ps', 'fs', 'fragment']):
|
||||
entry_points[func_name] = ShaderStage.FRAGMENT
|
||||
entry_points[func_name] = Stage.FRAGMENT
|
||||
print(f"Inferred fragment shader: {func_name}")
|
||||
# 计算着色器关键词
|
||||
elif any(keyword in func_lower for keyword in ['comp', 'cs', 'compute']):
|
||||
entry_points[func_name] = ShaderStage.COMPUTE
|
||||
entry_points[func_name] = Stage.COMPUTE
|
||||
print(f"Inferred compute shader: {func_name}")
|
||||
|
||||
# 4. 查找带有HLSL语义的函数
|
||||
semantic_patterns = [
|
||||
(r'(\w+)\s*$[^)]*$\s*:\s*SV_Position', ShaderStage.VERTEX),
|
||||
(r'(\w+)\s*$[^)]*$\s*:\s*SV_Target', ShaderStage.FRAGMENT),
|
||||
(r'(\w+)\s*$[^)]*$\s*:\s*POSITION', ShaderStage.VERTEX),
|
||||
(r'(\w+)\s*$[^)]*$\s*:\s*SV_Position', Stage.VERTEX),
|
||||
(r'(\w+)\s*$[^)]*$\s*:\s*SV_Target', Stage.FRAGMENT),
|
||||
(r'(\w+)\s*$[^)]*$\s*:\s*POSITION', Stage.VERTEX),
|
||||
]
|
||||
|
||||
for pattern, stage in semantic_patterns:
|
||||
@@ -141,13 +140,13 @@ class ShaderParser:
|
||||
# 5. 如果没找到任何入口点,查找main函数
|
||||
if not entry_points:
|
||||
if re.search(r'\bmain\s*\(', source, re.IGNORECASE):
|
||||
entry_points['main'] = ShaderStage.VERTEX
|
||||
entry_points['main'] = Stage.VERTEX
|
||||
print("Found main function, assuming vertex shader")
|
||||
|
||||
print(f"Total entry points found: {len(entry_points)}")
|
||||
return entry_points
|
||||
|
||||
def _compile_and_reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str) -> Optional[ShaderInfo]:
|
||||
def _compile_and_reflect_entry_point(self, entry_name: str, stage: Stage, source: str) -> Optional[ShaderReflection]:
|
||||
"""为单个入口点进行完整编译和反射"""
|
||||
|
||||
# 创建临时输出文件
|
||||
@@ -180,10 +179,6 @@ class ShaderParser:
|
||||
if result.returncode == 0:
|
||||
# 读取反射数据
|
||||
if os.path.exists(reflection_path):
|
||||
if stage == ShaderStage.VERTEX:
|
||||
shader_layout_ = ShaderLayoutParser()
|
||||
global_vars.layout = shader_layout_.parse(reflection_path)
|
||||
|
||||
with open(reflection_path, 'r', encoding='utf-8') as f:
|
||||
reflection_json = f.read()
|
||||
print(f"Reflection JSON length: {len(reflection_json)}")
|
||||
@@ -191,17 +186,15 @@ class ShaderParser:
|
||||
if reflection_json.strip():
|
||||
try:
|
||||
reflection = json.loads(reflection_json)
|
||||
shader_info = self._create_shader_info_from_reflection(
|
||||
reflection, entry_name, stage, source
|
||||
)
|
||||
shader_reflection = ShaderReflection.from_dict(reflection)
|
||||
|
||||
# 读取编译后的二进制数据
|
||||
with open(temp_output_path, 'rb') as f:
|
||||
shader_binary = f.read()
|
||||
# 写入shader_info.blob
|
||||
shader_info.blob = shader_binary
|
||||
shader_reflection.blob = shader_binary
|
||||
|
||||
return shader_info
|
||||
return shader_reflection
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"JSON parsing error: {e}")
|
||||
print(f"Raw JSON: {reflection_json[:500]}...")
|
||||
@@ -217,161 +210,4 @@ class ShaderParser:
|
||||
if os.path.exists(temp_file):
|
||||
os.unlink(temp_file)
|
||||
|
||||
return self._create_shader_info_manual(entry_name, stage, source)
|
||||
|
||||
def _create_shader_info_from_reflection(self, reflection: dict, entry_name: str,
|
||||
stage: ShaderStage, source: str) -> ShaderInfo:
|
||||
"""从反射数据创建ShaderInfo"""
|
||||
print(f"Processing reflection data for {entry_name}")
|
||||
print(f"Reflection keys: {list(reflection.keys())}")
|
||||
|
||||
shader_info = ShaderInfo(
|
||||
stage=stage,
|
||||
entry_point=entry_name,
|
||||
resources=[],
|
||||
blob=b'', # 二进制数据将在编译后填充
|
||||
)
|
||||
|
||||
# Slang反射数据的可能结构
|
||||
# 尝试不同的数据结构
|
||||
entry_point_data = None
|
||||
|
||||
# 方法1: 直接在根级别查找
|
||||
if 'parameters' in reflection:
|
||||
entry_point_data = reflection
|
||||
|
||||
# 方法2: 在entryPoints数组中查找
|
||||
elif 'entryPoints' in reflection:
|
||||
for ep in reflection['entryPoints']:
|
||||
if ep.get('name') == entry_name:
|
||||
entry_point_data = ep
|
||||
break
|
||||
|
||||
# 方法3: 在modules中查找
|
||||
elif 'modules' in reflection:
|
||||
for module in reflection['modules']:
|
||||
if 'entryPoints' in module:
|
||||
for ep in module['entryPoints']:
|
||||
if ep.get('name') == entry_name:
|
||||
entry_point_data = ep
|
||||
break
|
||||
|
||||
if entry_point_data:
|
||||
print(f"Found entry point data: {list(entry_point_data.keys())}")
|
||||
|
||||
# 解析资源参数
|
||||
parameters = entry_point_data.get('parameters', [])
|
||||
print(f"Found {len(parameters)} parameters")
|
||||
|
||||
for param in parameters:
|
||||
print(f"Processing parameter: {param}")
|
||||
resource = self._parse_resource(param)
|
||||
if resource:
|
||||
shader_info.resources.append(resource)
|
||||
print(f"Added resource: {resource.name} ({resource.type.value})")
|
||||
else:
|
||||
print("No entry point data found in reflection, using manual parsing")
|
||||
# 使用手动解析作为fallback
|
||||
manual_resources = self._extract_resources_from_source(source)
|
||||
shader_info.resources.extend(manual_resources)
|
||||
|
||||
print(f"Shader {entry_name} has {len(shader_info.resources)} resources")
|
||||
return shader_info
|
||||
|
||||
def _create_shader_info_manual(self, entry_name: str, stage: ShaderStage, source: str) -> ShaderInfo:
|
||||
"""手动创建ShaderInfo(fallback方法)"""
|
||||
print(f"Creating shader info manually for {entry_name}")
|
||||
shader_info = ShaderInfo(
|
||||
stage=stage,
|
||||
entry_point=entry_name,
|
||||
resources=[],
|
||||
blob=b'',
|
||||
)
|
||||
|
||||
# 手动解析资源
|
||||
resources = self._extract_resources_from_source(source)
|
||||
shader_info.resources.extend(resources)
|
||||
|
||||
print(f"Manual parsing found {len(resources)} resources")
|
||||
return shader_info
|
||||
|
||||
def _extract_resources_from_source(self, source: str) -> List[Resource]:
|
||||
"""从源码中提取资源声明"""
|
||||
resources = []
|
||||
|
||||
# 资源声明的正则表达式模式
|
||||
patterns = {
|
||||
# Texture resources
|
||||
ResourceType.SAMPLED_TEXTURE: [
|
||||
r'Texture2D\s*(?:<[^>]*>)?\s+(\w+)',
|
||||
r'Texture3D\s*(?:<[^>]*>)?\s+(\w+)',
|
||||
r'TextureCube\s*(?:<[^>]*>)?\s+(\w+)',
|
||||
],
|
||||
ResourceType.STORAGE_TEXTURE: [
|
||||
r'RWTexture2D\s*(?:<[^>]*>)?\s+(\w+)',
|
||||
r'RWTexture3D\s*(?:<[^>]*>)?\s+(\w+)',
|
||||
],
|
||||
# Buffer resources
|
||||
ResourceType.STORAGE_BUFFER: [
|
||||
r'RWStructuredBuffer\s*<[^>]*>\s+(\w+)',
|
||||
r'RWByteAddressBuffer\s+(\w+)',
|
||||
r'StructuredBuffer\s*<[^>]*>\s+(\w+)',
|
||||
r'ByteAddressBuffer\s+(\w+)',
|
||||
],
|
||||
ResourceType.UNIFORM_BUFFER: [
|
||||
r'ConstantBuffer\s*<[^>]*>\s+(\w+)',
|
||||
r'cbuffer\s+(\w+)',
|
||||
],
|
||||
ResourceType.SAMPLER: [
|
||||
r'SamplerState\s+(\w+)',
|
||||
r'SamplerComparisonState\s+(\w+)',
|
||||
]
|
||||
}
|
||||
|
||||
for resource_type, type_patterns in patterns.items():
|
||||
for pattern in type_patterns:
|
||||
matches = re.findall(pattern, source, re.IGNORECASE)
|
||||
for match in matches:
|
||||
resources.append(Resource(match, resource_type))
|
||||
print(f"Found resource: {match} ({resource_type.value})")
|
||||
|
||||
return resources
|
||||
|
||||
def _parse_resource(self, param: dict) -> Optional[Resource]:
|
||||
"""解析资源参数"""
|
||||
type_info = param.get('type', {})
|
||||
|
||||
kind = type_info.get('kind', '')
|
||||
if kind == 'resource':
|
||||
kind = type_info.get('baseShape', '')
|
||||
param_name = param.get('name', '')
|
||||
resource = Resource(param_name, ResourceType.STORAGE_BUFFER)
|
||||
resource.binding_index = param.get('binding', {}).get('index', -1)
|
||||
|
||||
# 判断资源类型
|
||||
if 'texture' in kind:
|
||||
if 'RW' in kind:
|
||||
resource.type = ResourceType.STORAGE_TEXTURE
|
||||
else:
|
||||
resource.type = ResourceType.SAMPLED_TEXTURE
|
||||
|
||||
if '2D' in kind:
|
||||
resource.sub_type = ResourceSubType.TEXTURE_2D
|
||||
elif 'cube' in kind:
|
||||
resource.sub_type = ResourceSubType.TEXTURE_CUBE
|
||||
elif '3D' in kind:
|
||||
resource.sub_type = ResourceSubType.TEXTURE_3D
|
||||
elif 'array' in kind:
|
||||
resource.sub_type = ResourceSubType.TEXTURE_ARRAY
|
||||
elif kind == 'structuredBuffer' or kind == 'byteAddressBuffer':
|
||||
if 'RW' in kind:
|
||||
resource.type = ResourceType.STORAGE_BUFFER
|
||||
else:
|
||||
resource.type = ResourceType.STORAGE_BUFFER
|
||||
elif kind == 'constantBuffer' or kind == 'parameterBlock':
|
||||
resource.type = ResourceType.UNIFORM_BUFFER
|
||||
resource.uniform_data = ShaderLayoutParser.parse_uniform_buffer(param)
|
||||
elif kind == 'samplerState' or 'sampler' in kind:
|
||||
resource.type = ResourceType.SAMPLER
|
||||
|
||||
return resource
|
||||
return None
|
||||
429
tools/shader_reflection_type.py
Normal file
429
tools/shader_reflection_type.py
Normal file
@@ -0,0 +1,429 @@
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import List, Optional, Union, Dict, Any
|
||||
|
||||
|
||||
# 枚举类定义
|
||||
class TargetFormat(Enum):
|
||||
GLSL = "glsl"
|
||||
DXBC = "dxbc"
|
||||
MSL = "msl"
|
||||
HLSL_DX11 = "hlsl"
|
||||
|
||||
class BindingKind(Enum):
|
||||
SUB_ELEMENT_REGISTER_SPACE = "subElementRegisterSpace"
|
||||
UNIFORM = "uniform"
|
||||
TEXTURE = "texture"
|
||||
STORAGE_BUFFER = "storageBuffer"
|
||||
CONSTANT_BUFFER = "constantBuffer"
|
||||
SAMPLER = "sampler"
|
||||
VARYING_INPUT = "varyingInput"
|
||||
VARYING_OUTPUT = "varyingOutput"
|
||||
|
||||
|
||||
class TypeKind(Enum):
|
||||
PARAMETER_BLOCK = "parameterBlock"
|
||||
STRUCT = "struct"
|
||||
MATRIX = "matrix"
|
||||
SCALAR = "scalar"
|
||||
VECTOR = "vector"
|
||||
|
||||
|
||||
class ScalarType(Enum):
|
||||
INT32 = "int32"
|
||||
UINT32 = "uint32"
|
||||
INT8 = "int8"
|
||||
UINT8 = "uint8"
|
||||
INT16 = "int16"
|
||||
UINT16 = "uint16"
|
||||
FLOAT16 = "float16"
|
||||
FLOAT32 = "float32"
|
||||
|
||||
|
||||
class Stage(Enum):
|
||||
VERTEX = "vertex"
|
||||
FRAGMENT = "fragment"
|
||||
PIXEL = "pixel"
|
||||
COMPUTE = "compute"
|
||||
|
||||
|
||||
# 基础数据类
|
||||
@dataclass
|
||||
class Binding:
|
||||
kind: BindingKind
|
||||
index: Optional[int] = None
|
||||
offset: Optional[int] = None
|
||||
size: Optional[int] = None
|
||||
count: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'Binding':
|
||||
return cls(
|
||||
kind=BindingKind(data['kind']),
|
||||
index=data.get('index'),
|
||||
offset=data.get('offset'),
|
||||
size=data.get('size'),
|
||||
count=data.get('count')
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScalarTypeInfo:
|
||||
kind: TypeKind
|
||||
scalarType: ScalarType
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ScalarTypeInfo':
|
||||
return cls(
|
||||
kind=TypeKind(data['kind']),
|
||||
scalarType=ScalarType(data['scalarType'])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class MatrixType:
|
||||
kind: TypeKind
|
||||
rowCount: int
|
||||
columnCount: int
|
||||
elementType: ScalarTypeInfo
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'MatrixType':
|
||||
return cls(
|
||||
kind=TypeKind(data['kind']),
|
||||
rowCount=data['rowCount'],
|
||||
columnCount=data['columnCount'],
|
||||
elementType=ScalarTypeInfo.from_dict(data['elementType'])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VectorType:
|
||||
kind: TypeKind
|
||||
elementCount: int
|
||||
elementType: ScalarTypeInfo
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'VectorType':
|
||||
return cls(
|
||||
kind=TypeKind(data['kind']),
|
||||
elementCount=data['elementCount'],
|
||||
elementType=ScalarTypeInfo.from_dict(data['elementType'])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Field:
|
||||
name: str
|
||||
type: Union[MatrixType, VectorType, ScalarTypeInfo, 'StructType']
|
||||
binding: Optional[Binding] = None
|
||||
stage: Optional[Stage] = None
|
||||
semanticName: Optional[str] = None
|
||||
semanticIndex: Optional[int] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'Field':
|
||||
# 根据type的kind来决定如何解析
|
||||
type_data = data['type']
|
||||
type_kind = type_data['kind']
|
||||
|
||||
if type_kind == 'matrix':
|
||||
type_obj = MatrixType.from_dict(type_data)
|
||||
elif type_kind == 'vector':
|
||||
type_obj = VectorType.from_dict(type_data)
|
||||
elif type_kind == 'scalar':
|
||||
type_obj = ScalarTypeInfo.from_dict(type_data)
|
||||
elif type_kind == 'struct':
|
||||
type_obj = StructType.from_dict(type_data)
|
||||
else:
|
||||
raise ValueError(f"Unknown type kind: {type_kind}")
|
||||
|
||||
return cls(
|
||||
name=data['name'],
|
||||
type=type_obj,
|
||||
binding=Binding.from_dict(data['binding']) if 'binding' in data else None,
|
||||
stage=Stage(data['stage']) if 'stage' in data else None,
|
||||
semanticName=data.get('semanticName'),
|
||||
semanticIndex=data.get('semanticIndex')
|
||||
)
|
||||
|
||||
def get_sg_format(self):
|
||||
"""获取对应的sg_format"""
|
||||
out = 'SG_VERTEXFORMAT_'
|
||||
scalar_type = ScalarType.FLOAT32
|
||||
element_count = 1
|
||||
|
||||
if isinstance(self.type, ScalarTypeInfo):
|
||||
scalar_type = self.type.scalarType
|
||||
elif isinstance(self.type, VectorType):
|
||||
scalar_type = self.type.elementType.scalarType
|
||||
|
||||
if isinstance(self.type, VectorType):
|
||||
element_count = self.type.elementCount
|
||||
|
||||
if scalar_type == ScalarType.FLOAT32:
|
||||
out += 'FLOAT'
|
||||
elif scalar_type == ScalarType.INT32:
|
||||
out += 'INT'
|
||||
elif scalar_type == ScalarType.UINT32:
|
||||
out += 'UINT'
|
||||
elif scalar_type == ScalarType.INT8:
|
||||
out += 'BYTE'
|
||||
elif scalar_type == ScalarType.UINT8:
|
||||
out += 'UBYTE'
|
||||
elif scalar_type == ScalarType.INT16:
|
||||
out += 'SHORT'
|
||||
elif scalar_type == ScalarType.UINT16:
|
||||
out += 'USHORT'
|
||||
elif scalar_type == ScalarType.FLOAT16:
|
||||
out += 'HALF'
|
||||
else:
|
||||
raise ValueError(f"Unsupported scalar type: {scalar_type}")
|
||||
|
||||
if element_count > 1:
|
||||
out += str(element_count)
|
||||
return out
|
||||
|
||||
def get_byte_size(self):
|
||||
"""获取字段的字节大小"""
|
||||
if isinstance(self.type, ScalarTypeInfo):
|
||||
return 4 if self.type.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2
|
||||
elif isinstance(self.type, VectorType):
|
||||
return self.type.elementCount * (4 if self.type.elementType.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2)
|
||||
elif isinstance(self.type, MatrixType):
|
||||
return self.type.rowCount * self.type.columnCount * (4 if self.type.elementType.scalarType in (ScalarType.FLOAT32, ScalarType.UINT32, ScalarType.INT32) else 2)
|
||||
elif isinstance(self.type, StructType):
|
||||
return sum(f.get_byte_size() for f in self.type.fields)
|
||||
return 0
|
||||
|
||||
@dataclass
|
||||
class StructType:
|
||||
kind: TypeKind
|
||||
name: str
|
||||
fields: List[Field]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'StructType':
|
||||
fields = [Field.from_dict(f) for f in data['fields']]
|
||||
# 计算offset
|
||||
offset = 0
|
||||
for field in fields:
|
||||
if field.binding is not None:
|
||||
field.binding.offset = offset
|
||||
offset += field.get_byte_size()
|
||||
|
||||
return cls(
|
||||
kind=TypeKind(data['kind']),
|
||||
name=data['name'],
|
||||
fields=fields
|
||||
)
|
||||
|
||||
def get_byte_size(self) -> int:
|
||||
"""获取结构体的字节大小"""
|
||||
return sum(field.get_byte_size() for field in self.fields)
|
||||
|
||||
|
||||
@dataclass
|
||||
class VarLayout:
|
||||
type: StructType
|
||||
binding: Binding
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'VarLayout':
|
||||
return cls(
|
||||
type=StructType.from_dict(data['type']),
|
||||
binding=Binding.from_dict(data['binding'])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ContainerVarLayout:
|
||||
bindings: List[Binding]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ContainerVarLayout':
|
||||
return cls(
|
||||
bindings=[Binding.from_dict(b) for b in data['bindings']]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParameterBlockType:
|
||||
kind: TypeKind
|
||||
elementType: StructType
|
||||
containerVarLayout: ContainerVarLayout
|
||||
elementVarLayout: VarLayout
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ParameterBlockType':
|
||||
return cls(
|
||||
kind=TypeKind(data['kind']),
|
||||
elementType=StructType.from_dict(data['elementType']),
|
||||
containerVarLayout=ContainerVarLayout.from_dict(data['containerVarLayout']),
|
||||
elementVarLayout=VarLayout.from_dict(data['elementVarLayout'])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Parameter:
|
||||
name: str
|
||||
binding: Binding
|
||||
type: Union[ParameterBlockType, StructType]
|
||||
stage: Optional[Stage] = None
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'Parameter':
|
||||
type_data = data['type']
|
||||
if type_data['kind'] == 'struct':
|
||||
type_obj = StructType.from_dict(type_data)
|
||||
else:
|
||||
type_obj = ParameterBlockType.from_dict(type_data)
|
||||
|
||||
return cls(
|
||||
name=data['name'],
|
||||
binding=Binding.from_dict(data['binding']),
|
||||
type=type_obj,
|
||||
stage=Stage(data['stage']) if 'stage' in data else None
|
||||
)
|
||||
|
||||
def get_byte_size(self) -> int:
|
||||
"""获取参数的字节大小"""
|
||||
if isinstance(self.type, ParameterBlockType):
|
||||
return self.type.elementVarLayout.type.get_byte_size()
|
||||
elif isinstance(self.type, StructType):
|
||||
return sum(field.get_byte_size() for field in self.type.fields)
|
||||
return 0
|
||||
|
||||
def get_binding_kind(self) -> BindingKind:
|
||||
kind = self.binding.kind
|
||||
if kind == BindingKind.SUB_ELEMENT_REGISTER_SPACE:
|
||||
kind = BindingKind.UNIFORM
|
||||
return kind
|
||||
|
||||
def get_register_index(self):
|
||||
"""获取HLSL注册索引"""
|
||||
if self.binding.index is not None:
|
||||
return self.binding.index
|
||||
return 0
|
||||
|
||||
@dataclass
|
||||
class EntryPointBinding:
|
||||
name: str
|
||||
binding: Binding
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'EntryPointBinding':
|
||||
return cls(
|
||||
name=data['name'],
|
||||
binding=Binding.from_dict(data['binding'])
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntryPointResult:
|
||||
stage: Stage
|
||||
binding: Binding
|
||||
type: Union[StructType, VectorType]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'EntryPointResult':
|
||||
kind = data['type']['kind']
|
||||
if kind == 'struct':
|
||||
type_obj = StructType.from_dict(data['type'])
|
||||
elif kind == 'vector':
|
||||
type_obj = VectorType.from_dict(data['type'])
|
||||
else:
|
||||
raise ValueError(f"Unknown result type kind: {kind}")
|
||||
|
||||
return cls(
|
||||
stage=Stage(data['stage']),
|
||||
binding=Binding.from_dict(data['binding']),
|
||||
type=type_obj
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EntryPoint:
|
||||
name: str
|
||||
stage: Stage
|
||||
parameters: List[Parameter]
|
||||
result: EntryPointResult
|
||||
bindings: List[EntryPointBinding]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'EntryPoint':
|
||||
return cls(
|
||||
name=data['name'],
|
||||
stage=Stage(data['stage']),
|
||||
parameters=[Parameter.from_dict(p) for p in data['parameters']],
|
||||
result=EntryPointResult.from_dict(data['result']),
|
||||
bindings=[EntryPointBinding.from_dict(b) for b in data['bindings']]
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShaderReflection:
|
||||
parameters: List[Parameter]
|
||||
entryPoints: List[EntryPoint]
|
||||
blob: Optional[bytes] = None # 用于存储编译后的着色器二进制数据
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict[str, Any]) -> 'ShaderReflection':
|
||||
return cls(
|
||||
parameters=[Parameter.from_dict(p) for p in data['parameters']],
|
||||
entryPoints=[EntryPoint.from_dict(e) for e in data['entryPoints']]
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_vertex_layout(cls, reflection: 'ShaderReflection') -> Optional[List[Field]]:
|
||||
"""获取顶点布局字段"""
|
||||
for entry in reflection.entryPoints:
|
||||
if entry.stage == Stage.VERTEX:
|
||||
vertex_fields = []
|
||||
for param in entry.parameters:
|
||||
if param.stage == Stage.VERTEX and param.type.kind == TypeKind.STRUCT:
|
||||
vertex_fields.extend(param.type.fields)
|
||||
return vertex_fields
|
||||
return None
|
||||
|
||||
@dataclass
|
||||
class ShaderStageInfo:
|
||||
parameters: List[Parameter]
|
||||
entryPoint: EntryPoint
|
||||
blob: Optional[bytes] = None # 用于存储编译后的着色器二进制数据
|
||||
|
||||
def __init__(self, parameters: List[Parameter], entryPoint: EntryPoint, blob: Optional[bytes] = None):
|
||||
self.parameters = parameters
|
||||
self.entryPoint = entryPoint
|
||||
self.blob = blob
|
||||
for p in self.parameters:
|
||||
if p.stage is None:
|
||||
p.stage = entryPoint.stage
|
||||
else:
|
||||
assert p.stage == entryPoint.stage, f"Parameter stage {p.stage} does not match entry point stage {entryPoint.stage}"
|
||||
|
||||
def get_entry_name(self):
|
||||
return self.entryPoint.name
|
||||
def get_stage(self):
|
||||
return self.entryPoint.stage
|
||||
|
||||
@dataclass
|
||||
class ShaderInfos:
|
||||
"""着色器信息集合"""
|
||||
stages: dict[Stage, ShaderStageInfo]
|
||||
vertex_layout: List[Field] = field(default_factory=list)
|
||||
vertex_size: int = 0
|
||||
|
||||
def add_shader_info(self, shader_info: ShaderReflection):
|
||||
"""添加单个着色器信息"""
|
||||
vertex_layout = ShaderReflection.get_vertex_layout(shader_info)
|
||||
if vertex_layout is not None:
|
||||
self.vertex_layout = vertex_layout
|
||||
self.vertex_size = sum(field.get_byte_size() for field in vertex_layout)
|
||||
for entry in shader_info.entryPoints:
|
||||
stage = entry.stage
|
||||
self.stages[stage] = ShaderStageInfo(
|
||||
parameters=shader_info.parameters,
|
||||
entryPoint=entry,
|
||||
blob=shader_info.blob
|
||||
)
|
||||
@@ -1,119 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Slang Compiler - Type Definitions
|
||||
数据类型和枚举定义
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict, Mapping
|
||||
|
||||
|
||||
class ShaderStage(Enum):
|
||||
VERTEX = "vertex"
|
||||
FRAGMENT = "fragment"
|
||||
COMPUTE = "compute"
|
||||
|
||||
class ResourceType(Enum):
|
||||
SAMPLED_TEXTURE = "sampled_texture"
|
||||
STORAGE_TEXTURE = "storage_texture"
|
||||
STORAGE_BUFFER = "storage_buffer"
|
||||
UNIFORM_BUFFER = "uniform_buffer"
|
||||
SAMPLER = "sampler"
|
||||
|
||||
class ResourceSubType(Enum):
|
||||
TEXTURE_2D = "SG_IMAGETYPE_2D"
|
||||
TEXTURE_CUBE = "SG_IMAGETYPE_CUBE"
|
||||
TEXTURE_3D = "SG_IMAGETYPE_3D"
|
||||
TEXTURE_ARRAY = "SG_IMAGETYPE_ARRAY"
|
||||
|
||||
class TargetFormat(Enum):
|
||||
SPIRV = "spirv"
|
||||
DXBC = "dxbc"
|
||||
MSL = "msl"
|
||||
HLSL_DX11 = "hlsl"
|
||||
|
||||
# 数据模型类
|
||||
@dataclass
|
||||
class FieldType:
|
||||
"""字段类型信息"""
|
||||
kind: str # 'scalar', 'vector', 'matrix'
|
||||
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
|
||||
element_count: Optional[int] = None # for vector
|
||||
row_count: Optional[int] = None # for matrix
|
||||
column_count: Optional[int] = None # for matrix
|
||||
size: int = 0
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'FieldType':
|
||||
"""从字典创建FieldType对象"""
|
||||
kind = data.get('kind')
|
||||
|
||||
if kind == 'vector':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['elementType']['scalarType'],
|
||||
element_count=data['elementCount']
|
||||
)
|
||||
elif kind == 'scalar':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['scalarType']
|
||||
)
|
||||
elif kind == 'matrix':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['elementType']['scalarType'],
|
||||
row_count=data['rowCount'],
|
||||
column_count=data['columnCount']
|
||||
)
|
||||
|
||||
return cls(kind='scalar', scalar_type='float32')
|
||||
|
||||
@dataclass
|
||||
class UniformField:
|
||||
"""Uniform缓冲区字段"""
|
||||
name: str
|
||||
type: FieldType
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
@dataclass
|
||||
class UniformBuffer:
|
||||
"""Uniform缓冲区"""
|
||||
name: str
|
||||
binding: int
|
||||
size: int = 0
|
||||
fields: List[UniformField] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class Resource:
|
||||
name: str
|
||||
type: ResourceType
|
||||
sub_type: Optional[ResourceSubType] = None
|
||||
binding_index: int = -1
|
||||
field_type: Optional[FieldType] = None
|
||||
uniform_data: Optional[UniformBuffer] = None
|
||||
|
||||
@dataclass
|
||||
class ShaderInfo:
|
||||
stage: ShaderStage
|
||||
entry_point: str
|
||||
blob: bytes
|
||||
resources: List[Resource]
|
||||
|
||||
@dataclass
|
||||
class VertexField:
|
||||
"""顶点输入字段"""
|
||||
name: str
|
||||
type: FieldType
|
||||
location: int
|
||||
semantic: str
|
||||
semantic_index: int
|
||||
offset: int = 0
|
||||
|
||||
@dataclass
|
||||
class ShaderLayout:
|
||||
"""着色器布局数据"""
|
||||
vertex_fields: List[VertexField] = field(default_factory=list)
|
||||
uniform_buffers: List[UniformBuffer] = field(default_factory=list)
|
||||
Reference in New Issue
Block a user