修改代码,使uniform_buffer的结构生成在绑定阶段
This commit is contained in:
@@ -6,7 +6,6 @@ SDL3_GPU Slang Compiler Package
|
|||||||
from shader_layout_generator import ShaderLayoutGenerator
|
from shader_layout_generator import ShaderLayoutGenerator
|
||||||
from shader_types import ShaderStage, ResourceType, TargetFormat, Resource, ShaderInfo
|
from shader_types import ShaderStage, ResourceType, TargetFormat, Resource, ShaderInfo
|
||||||
from compiler import SDL3GPUSlangCompiler
|
from compiler import SDL3GPUSlangCompiler
|
||||||
from slangc_finder import SlangcFinder
|
|
||||||
from shader_parser import ShaderParser
|
from shader_parser import ShaderParser
|
||||||
from binding_manager import BindingManager
|
from binding_manager import BindingManager
|
||||||
from code_generator import CodeGenerator
|
from code_generator import CodeGenerator
|
||||||
@@ -18,7 +17,6 @@ __all__ = [
|
|||||||
'Resource',
|
'Resource',
|
||||||
'ShaderInfo',
|
'ShaderInfo',
|
||||||
'SDL3GPUSlangCompiler',
|
'SDL3GPUSlangCompiler',
|
||||||
'SlangcFinder',
|
|
||||||
'ShaderParser',
|
'ShaderParser',
|
||||||
'BindingManager',
|
'BindingManager',
|
||||||
'CodeGenerator',
|
'CodeGenerator',
|
||||||
|
|||||||
@@ -3,14 +3,15 @@
|
|||||||
SDL3_GPU Slang Compiler - Code Generator
|
SDL3_GPU Slang Compiler - Code Generator
|
||||||
生成C/C++绑定代码
|
生成C/C++绑定代码
|
||||||
"""
|
"""
|
||||||
import os.path
|
|
||||||
from typing import List, Dict, TextIO, Optional
|
|
||||||
from pathlib import Path
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TextIO
|
||||||
|
|
||||||
|
from code_generator_helper import *
|
||||||
from global_vars import global_vars
|
from global_vars import global_vars
|
||||||
from shader_types import *
|
from shader_types import *
|
||||||
|
|
||||||
|
|
||||||
class IndentManager:
|
class IndentManager:
|
||||||
"""RAII风格的缩进管理器"""
|
"""RAII风格的缩进管理器"""
|
||||||
|
|
||||||
@@ -72,6 +73,40 @@ class CodeGenerator:
|
|||||||
"""生成C++绑定函数的入口方法"""
|
"""生成C++绑定函数的入口方法"""
|
||||||
self._generate_cpp_bindings(binding_infos, output_path)
|
self._generate_cpp_bindings(binding_infos, output_path)
|
||||||
|
|
||||||
|
def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
|
||||||
|
"""生成Uniform缓冲区结构体"""
|
||||||
|
writer.write("// Uniform buffer structures")
|
||||||
|
|
||||||
|
for buffer in uniform_buffers:
|
||||||
|
struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer'
|
||||||
|
|
||||||
|
# 计算总大小和对齐要求
|
||||||
|
total_size = 0
|
||||||
|
max_alignment = 16 # GPU通常要求16字节对齐
|
||||||
|
|
||||||
|
with writer.block(f'struct {struct_name}', '};'):
|
||||||
|
for i, field in enumerate(buffer.fields):
|
||||||
|
# 检查是否需要填充
|
||||||
|
if field.offset > total_size:
|
||||||
|
padding_size = field.offset - total_size
|
||||||
|
writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding")
|
||||||
|
total_size = field.offset
|
||||||
|
|
||||||
|
# 生成字段声明
|
||||||
|
declaration = get_c_type_declaration(field.type, field.name)
|
||||||
|
writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")
|
||||||
|
|
||||||
|
total_size = field.offset + field.size
|
||||||
|
|
||||||
|
# 确保结构体大小正确对齐
|
||||||
|
aligned_size = get_aligned_size(total_size, max_alignment)
|
||||||
|
if aligned_size > total_size:
|
||||||
|
writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")
|
||||||
|
|
||||||
|
writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})")
|
||||||
|
writer.write()
|
||||||
|
|
||||||
|
|
||||||
def _generate_cpp_bindings(self, binding_infos: List[Dict], output_path: str) -> None:
|
def _generate_cpp_bindings(self, binding_infos: List[Dict], output_path: str) -> None:
|
||||||
"""生成C++绑定函数"""
|
"""生成C++绑定函数"""
|
||||||
output_file = Path(output_path)
|
output_file = Path(output_path)
|
||||||
@@ -86,15 +121,13 @@ class CodeGenerator:
|
|||||||
"""写入完整的文件内容"""
|
"""写入完整的文件内容"""
|
||||||
self._write_header(writer)
|
self._write_header(writer)
|
||||||
|
|
||||||
with writer.block('namespace SDL3GPU', '} // namespace SDL3GPU'):
|
self._write_shader_bindings_class(writer, binding_infos)
|
||||||
writer.write()
|
writer.write()
|
||||||
self._write_shader_bindings_class(writer, binding_infos)
|
# 查找是否有像素着色器或计算着色器的处理
|
||||||
writer.write()
|
if any(info['stage'] == ShaderStage.FRAGMENT.value for info in binding_infos):
|
||||||
# 查找是否有像素着色器或计算着色器的处理
|
self._write_handle_class(writer, 'pixel_shader_handle_t', binding_infos)
|
||||||
if any(info['stage'] == ShaderStage.FRAGMENT.value for info in binding_infos):
|
else:
|
||||||
self._write_handle_class(writer, 'pixel_shader_handle_t', binding_infos)
|
self._write_handle_class(writer, 'compute_shader_handle_t', binding_infos)
|
||||||
else:
|
|
||||||
self._write_handle_class(writer, 'compute_shader_handle_t', binding_infos)
|
|
||||||
|
|
||||||
def _write_header(self, writer: IndentManager) -> None:
|
def _write_header(self, writer: IndentManager) -> None:
|
||||||
"""写入文件头部"""
|
"""写入文件头部"""
|
||||||
@@ -126,6 +159,7 @@ class CodeGenerator:
|
|||||||
def _write_public_section(self, writer: IndentManager, binding_infos: List[Dict]) -> None:
|
def _write_public_section(self, writer: IndentManager, binding_infos: List[Dict]) -> None:
|
||||||
"""写入public部分的结构体定义"""
|
"""写入public部分的结构体定义"""
|
||||||
writer.write('public:')
|
writer.write('public:')
|
||||||
|
self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
|
||||||
|
|
||||||
# 写入二进制内容
|
# 写入二进制内容
|
||||||
for info in binding_infos:
|
for info in binding_infos:
|
||||||
|
|||||||
176
code_generator_helper.py
Normal file
176
code_generator_helper.py
Normal file
@@ -0,0 +1,176 @@
|
|||||||
|
from typing import Tuple
|
||||||
|
|
||||||
|
from shader_types import FieldType
|
||||||
|
|
||||||
|
# SDL GPU 格式映射
|
||||||
|
format_mapping = {
|
||||||
|
('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
|
||||||
|
('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
|
||||||
|
('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
|
||||||
|
('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
|
||||||
|
('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
|
||||||
|
('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
|
||||||
|
('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
|
||||||
|
('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
|
||||||
|
('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
|
||||||
|
('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
|
||||||
|
('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
|
||||||
|
('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
|
||||||
|
('int8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE2',
|
||||||
|
('int8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE4',
|
||||||
|
('uint8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE2',
|
||||||
|
('uint8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE4',
|
||||||
|
('int16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT2',
|
||||||
|
('int16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT4',
|
||||||
|
('uint16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT2',
|
||||||
|
('uint16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT4',
|
||||||
|
('float16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF2',
|
||||||
|
('float16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF4',
|
||||||
|
}
|
||||||
|
|
||||||
|
# 类型大小映射
|
||||||
|
type_sizes = {
|
||||||
|
'int32': 4,
|
||||||
|
'uint32': 4,
|
||||||
|
'float32': 4,
|
||||||
|
'int8': 1,
|
||||||
|
'uint8': 1,
|
||||||
|
'int16': 2,
|
||||||
|
'uint16': 2,
|
||||||
|
'float16': 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
def get_type_info(field_type: FieldType) -> Tuple[str, int, int]:
|
||||||
|
"""获取类型信息:C类型名称、元素数量、字节大小
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, int, int]: (C类型名称, 元素数量, 总字节大小)
|
||||||
|
"""
|
||||||
|
# C类型映射表
|
||||||
|
c_type_mapping = {
|
||||||
|
'int32': 'int32_t',
|
||||||
|
'uint32': 'uint32_t',
|
||||||
|
'float32': 'float',
|
||||||
|
'int8': 'int8_t',
|
||||||
|
'uint8': 'uint8_t',
|
||||||
|
'int16': 'int16_t',
|
||||||
|
'uint16': 'uint16_t',
|
||||||
|
'float16': 'mirage_float16', # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
|
||||||
|
}
|
||||||
|
|
||||||
|
# 对齐要求映射表
|
||||||
|
alignment_mapping = {
|
||||||
|
'int32': 4,
|
||||||
|
'uint32': 4,
|
||||||
|
'float32': 4,
|
||||||
|
'int8': 1,
|
||||||
|
'uint8': 1,
|
||||||
|
'int16': 2,
|
||||||
|
'uint16': 2,
|
||||||
|
'float16': 2,
|
||||||
|
}
|
||||||
|
|
||||||
|
if field_type.kind == 'scalar':
|
||||||
|
scalar_type = field_type.scalar_type or 'float32'
|
||||||
|
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||||
|
size = type_sizes.get(scalar_type, 4)
|
||||||
|
return c_type, 1, size
|
||||||
|
|
||||||
|
elif field_type.kind == 'vector':
|
||||||
|
scalar_type = field_type.scalar_type or 'float32'
|
||||||
|
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||||
|
element_count = field_type.element_count or 4
|
||||||
|
|
||||||
|
# 验证向量元素数量的合法性
|
||||||
|
if element_count not in [1, 2, 3, 4]:
|
||||||
|
raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")
|
||||||
|
|
||||||
|
size = type_sizes.get(scalar_type, 4) * element_count
|
||||||
|
|
||||||
|
# 对于vec3,通常需要按vec4对齐
|
||||||
|
if element_count == 3:
|
||||||
|
alignment = alignment_mapping.get(scalar_type, 4) * 4
|
||||||
|
size = ((size + alignment - 1) // alignment) * alignment
|
||||||
|
|
||||||
|
return c_type, element_count, size
|
||||||
|
|
||||||
|
elif field_type.kind == 'matrix':
|
||||||
|
scalar_type = field_type.scalar_type or 'float32'
|
||||||
|
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||||
|
|
||||||
|
row_count = field_type.row_count or 4
|
||||||
|
column_count = field_type.column_count or 4
|
||||||
|
|
||||||
|
# 验证矩阵维度的合法性
|
||||||
|
if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]:
|
||||||
|
raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
|
||||||
|
f"Rows and columns must be 2, 3, or 4.")
|
||||||
|
|
||||||
|
# 矩阵通常按列主序存储,每列都需要对齐
|
||||||
|
element_size = type_sizes.get(scalar_type, 4)
|
||||||
|
|
||||||
|
# 每列的大小(考虑对齐)
|
||||||
|
if row_count == 3:
|
||||||
|
# mat3xN 的每列通常按 vec4 对齐
|
||||||
|
column_size = element_size * 4
|
||||||
|
else:
|
||||||
|
column_size = element_size * row_count
|
||||||
|
|
||||||
|
total_size = column_size * column_count
|
||||||
|
total_elements = row_count * column_count
|
||||||
|
|
||||||
|
return c_type, total_elements, total_size
|
||||||
|
|
||||||
|
else:
|
||||||
|
# 未知类型,返回默认值并记录警告
|
||||||
|
print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.")
|
||||||
|
return 'float', 1, 4
|
||||||
|
|
||||||
|
|
||||||
|
def get_c_type_declaration(field_type: FieldType, field_name: str) -> str:
|
||||||
|
"""生成C类型声明
|
||||||
|
|
||||||
|
Args:
|
||||||
|
field_type: 字段类型信息
|
||||||
|
field_name: 字段名称
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: 完整的C类型声明
|
||||||
|
"""
|
||||||
|
c_type, count, _ = get_type_info(field_type)
|
||||||
|
|
||||||
|
if field_type.kind == 'scalar':
|
||||||
|
return f"{c_type} {field_name}"
|
||||||
|
|
||||||
|
elif field_type.kind == 'vector':
|
||||||
|
# 对于向量,可以选择使用数组或专门的向量类型
|
||||||
|
if count == 1:
|
||||||
|
return f"{c_type} {field_name}"
|
||||||
|
else:
|
||||||
|
# 选项1: 使用数组
|
||||||
|
return f"{c_type} {field_name}[{count}]"
|
||||||
|
|
||||||
|
# 选项2: 使用对齐的结构体(如果需要的话)
|
||||||
|
# return f"struct {{ {c_type} data[{count}]; }} {field_name}"
|
||||||
|
|
||||||
|
elif field_type.kind == 'matrix':
|
||||||
|
rows = field_type.row_count or 4
|
||||||
|
cols = field_type.column_count or 4
|
||||||
|
|
||||||
|
# 矩阵使用二维数组(列主序)
|
||||||
|
return f"{c_type} {field_name}[{cols}][{rows}]"
|
||||||
|
|
||||||
|
else:
|
||||||
|
return f"{c_type} {field_name}"
|
||||||
|
|
||||||
|
def get_aligned_size(size: int, alignment: int) -> int:
|
||||||
|
"""计算对齐后的大小
|
||||||
|
|
||||||
|
Args:
|
||||||
|
size: 原始大小
|
||||||
|
alignment: 对齐要求
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: 对齐后的大小
|
||||||
|
"""
|
||||||
|
return ((size + alignment - 1) // alignment) * alignment
|
||||||
@@ -42,7 +42,7 @@ class SDL3GPUSlangCompiler:
|
|||||||
# 生成带绑定信息的着色器代码
|
# 生成带绑定信息的着色器代码
|
||||||
modified_source = self.binding_manager.inject_bindings(shader_info, target)
|
modified_source = self.binding_manager.inject_bindings(shader_info, target)
|
||||||
|
|
||||||
temp_source = os.path.join(global_vars.source_path, global_vars.source_file_name + '.slang.temp')
|
temp_source = os.path.join(global_vars.source_path, global_vars.source_file_name + '.slang')
|
||||||
|
|
||||||
# 写入临时文件
|
# 写入临时文件
|
||||||
with open(temp_source, 'w', encoding='utf8') as f:
|
with open(temp_source, 'w', encoding='utf8') as f:
|
||||||
|
|||||||
@@ -2,13 +2,12 @@
|
|||||||
import io
|
import io
|
||||||
import json
|
import json
|
||||||
import sys
|
import sys
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
from code_generator import IndentManager
|
from code_generator import IndentManager
|
||||||
|
from code_generator_helper import *
|
||||||
from global_vars import global_vars
|
from global_vars import global_vars
|
||||||
from shader_types import *
|
from shader_types import *
|
||||||
|
|
||||||
|
|
||||||
# 数据解析器(保持不变)
|
# 数据解析器(保持不变)
|
||||||
class ShaderLayoutParser:
|
class ShaderLayoutParser:
|
||||||
"""解析JSON数据并提取到类对象"""
|
"""解析JSON数据并提取到类对象"""
|
||||||
@@ -77,50 +76,10 @@ class ShaderLayoutParser:
|
|||||||
|
|
||||||
return uniform_buffers
|
return uniform_buffers
|
||||||
|
|
||||||
|
|
||||||
# **重构后的代码生成器**
|
# **重构后的代码生成器**
|
||||||
class ShaderLayoutCodeGenerator:
|
class ShaderLayoutCodeGenerator:
|
||||||
"""从ShaderLayout对象生成C代码"""
|
"""从ShaderLayout对象生成C代码"""
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
# SDL GPU 格式映射
|
|
||||||
self.format_mapping = {
|
|
||||||
('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
|
|
||||||
('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
|
|
||||||
('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
|
|
||||||
('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
|
|
||||||
('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
|
|
||||||
('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
|
|
||||||
('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
|
|
||||||
('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
|
|
||||||
('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
|
|
||||||
('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
|
|
||||||
('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
|
|
||||||
('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
|
|
||||||
('int8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE2',
|
|
||||||
('int8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE4',
|
|
||||||
('uint8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE2',
|
|
||||||
('uint8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE4',
|
|
||||||
('int16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT2',
|
|
||||||
('int16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT4',
|
|
||||||
('uint16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT2',
|
|
||||||
('uint16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT4',
|
|
||||||
('float16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF2',
|
|
||||||
('float16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF4',
|
|
||||||
}
|
|
||||||
|
|
||||||
# 类型大小映射
|
|
||||||
self.type_sizes = {
|
|
||||||
'int32': 4,
|
|
||||||
'uint32': 4,
|
|
||||||
'float32': 4,
|
|
||||||
'int8': 1,
|
|
||||||
'uint8': 1,
|
|
||||||
'int16': 2,
|
|
||||||
'uint16': 2,
|
|
||||||
'float16': 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
def generate(self) -> str:
|
def generate(self) -> str:
|
||||||
"""生成完整的头文件内容"""
|
"""生成完整的头文件内容"""
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
@@ -136,156 +95,21 @@ class ShaderLayoutCodeGenerator:
|
|||||||
writer.write("// Auto-generated vertex structure")
|
writer.write("// Auto-generated vertex structure")
|
||||||
with writer.block(f'namespace {global_vars.source_file_name}_shader'):
|
with writer.block(f'namespace {global_vars.source_file_name}_shader'):
|
||||||
self._generate_vertex_struct(writer, global_vars.layout.vertex_fields)
|
self._generate_vertex_struct(writer, global_vars.layout.vertex_fields)
|
||||||
self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
|
|
||||||
self._generate_vertex_attributes(writer, global_vars.layout.vertex_fields)
|
self._generate_vertex_attributes(writer, global_vars.layout.vertex_fields)
|
||||||
self._generate_helper_functions(writer, global_vars.layout.vertex_fields)
|
self._generate_helper_functions(writer, global_vars.layout.vertex_fields)
|
||||||
|
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
||||||
def _get_type_info(self, field_type: FieldType) -> Tuple[str, int, int]:
|
|
||||||
"""获取类型信息:C类型名称、元素数量、字节大小
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Tuple[str, int, int]: (C类型名称, 元素数量, 总字节大小)
|
|
||||||
"""
|
|
||||||
# C类型映射表
|
|
||||||
c_type_mapping = {
|
|
||||||
'int32': 'int32_t',
|
|
||||||
'uint32': 'uint32_t',
|
|
||||||
'float32': 'float',
|
|
||||||
'int8': 'int8_t',
|
|
||||||
'uint8': 'uint8_t',
|
|
||||||
'int16': 'int16_t',
|
|
||||||
'uint16': 'uint16_t',
|
|
||||||
'float16': 'mirage_float16', # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
|
|
||||||
}
|
|
||||||
|
|
||||||
# 对齐要求映射表
|
|
||||||
alignment_mapping = {
|
|
||||||
'int32': 4,
|
|
||||||
'uint32': 4,
|
|
||||||
'float32': 4,
|
|
||||||
'int8': 1,
|
|
||||||
'uint8': 1,
|
|
||||||
'int16': 2,
|
|
||||||
'uint16': 2,
|
|
||||||
'float16': 2,
|
|
||||||
}
|
|
||||||
|
|
||||||
if field_type.kind == 'scalar':
|
|
||||||
scalar_type = field_type.scalar_type or 'float32'
|
|
||||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
|
||||||
size = self.type_sizes.get(scalar_type, 4)
|
|
||||||
return c_type, 1, size
|
|
||||||
|
|
||||||
elif field_type.kind == 'vector':
|
|
||||||
scalar_type = field_type.scalar_type or 'float32'
|
|
||||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
|
||||||
element_count = field_type.element_count or 4
|
|
||||||
|
|
||||||
# 验证向量元素数量的合法性
|
|
||||||
if element_count not in [1, 2, 3, 4]:
|
|
||||||
raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")
|
|
||||||
|
|
||||||
size = self.type_sizes.get(scalar_type, 4) * element_count
|
|
||||||
|
|
||||||
# 对于vec3,通常需要按vec4对齐
|
|
||||||
if element_count == 3:
|
|
||||||
alignment = alignment_mapping.get(scalar_type, 4) * 4
|
|
||||||
size = ((size + alignment - 1) // alignment) * alignment
|
|
||||||
|
|
||||||
return c_type, element_count, size
|
|
||||||
|
|
||||||
elif field_type.kind == 'matrix':
|
|
||||||
scalar_type = field_type.scalar_type or 'float32'
|
|
||||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
|
||||||
|
|
||||||
row_count = field_type.row_count or 4
|
|
||||||
column_count = field_type.column_count or 4
|
|
||||||
|
|
||||||
# 验证矩阵维度的合法性
|
|
||||||
if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]:
|
|
||||||
raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
|
|
||||||
f"Rows and columns must be 2, 3, or 4.")
|
|
||||||
|
|
||||||
# 矩阵通常按列主序存储,每列都需要对齐
|
|
||||||
element_size = self.type_sizes.get(scalar_type, 4)
|
|
||||||
|
|
||||||
# 每列的大小(考虑对齐)
|
|
||||||
if row_count == 3:
|
|
||||||
# mat3xN 的每列通常按 vec4 对齐
|
|
||||||
column_size = element_size * 4
|
|
||||||
else:
|
|
||||||
column_size = element_size * row_count
|
|
||||||
|
|
||||||
total_size = column_size * column_count
|
|
||||||
total_elements = row_count * column_count
|
|
||||||
|
|
||||||
return c_type, total_elements, total_size
|
|
||||||
|
|
||||||
else:
|
|
||||||
# 未知类型,返回默认值并记录警告
|
|
||||||
print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.")
|
|
||||||
return 'float', 1, 4
|
|
||||||
|
|
||||||
def _get_c_type_declaration(self, field_type: FieldType, field_name: str) -> str:
|
|
||||||
"""生成C类型声明
|
|
||||||
|
|
||||||
Args:
|
|
||||||
field_type: 字段类型信息
|
|
||||||
field_name: 字段名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
str: 完整的C类型声明
|
|
||||||
"""
|
|
||||||
c_type, count, _ = self._get_type_info(field_type)
|
|
||||||
|
|
||||||
if field_type.kind == 'scalar':
|
|
||||||
return f"{c_type} {field_name}"
|
|
||||||
|
|
||||||
elif field_type.kind == 'vector':
|
|
||||||
# 对于向量,可以选择使用数组或专门的向量类型
|
|
||||||
if count == 1:
|
|
||||||
return f"{c_type} {field_name}"
|
|
||||||
else:
|
|
||||||
# 选项1: 使用数组
|
|
||||||
return f"{c_type} {field_name}[{count}]"
|
|
||||||
|
|
||||||
# 选项2: 使用对齐的结构体(如果需要的话)
|
|
||||||
# return f"struct {{ {c_type} data[{count}]; }} {field_name}"
|
|
||||||
|
|
||||||
elif field_type.kind == 'matrix':
|
|
||||||
rows = field_type.row_count or 4
|
|
||||||
cols = field_type.column_count or 4
|
|
||||||
|
|
||||||
# 矩阵使用二维数组(列主序)
|
|
||||||
return f"{c_type} {field_name}[{cols}][{rows}]"
|
|
||||||
|
|
||||||
else:
|
|
||||||
return f"{c_type} {field_name}"
|
|
||||||
|
|
||||||
def _get_aligned_size(self, size: int, alignment: int) -> int:
|
|
||||||
"""计算对齐后的大小
|
|
||||||
|
|
||||||
Args:
|
|
||||||
size: 原始大小
|
|
||||||
alignment: 对齐要求
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: 对齐后的大小
|
|
||||||
"""
|
|
||||||
return ((size + alignment - 1) // alignment) * alignment
|
|
||||||
|
|
||||||
def _generate_vertex_struct(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
def _generate_vertex_struct(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
||||||
"""生成顶点结构体"""
|
"""生成顶点结构体"""
|
||||||
with writer.block("struct vertex_t", "};"):
|
with writer.block("struct vertex_t", "};"):
|
||||||
total_offset = 0
|
total_offset = 0
|
||||||
for field in vertex_fields:
|
for field in vertex_fields:
|
||||||
# 生成字段声明
|
# 生成字段声明
|
||||||
declaration = self._get_c_type_declaration(field.type, field.name)
|
declaration = get_c_type_declaration(field.type, field.name)
|
||||||
|
|
||||||
# 获取字段信息
|
# 获取字段信息
|
||||||
c_type, count, size = self._get_type_info(field.type)
|
c_type, count, size = get_type_info(field.type)
|
||||||
|
|
||||||
# 添加详细注释
|
# 添加详细注释
|
||||||
semantic = field.semantic
|
semantic = field.semantic
|
||||||
@@ -306,39 +130,6 @@ class ShaderLayoutCodeGenerator:
|
|||||||
writer.write(f"static_assert(sizeof(vertex_t) == {total_offset}, \"Vertex struct size mismatch\");")
|
writer.write(f"static_assert(sizeof(vertex_t) == {total_offset}, \"Vertex struct size mismatch\");")
|
||||||
writer.write()
|
writer.write()
|
||||||
|
|
||||||
def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
|
|
||||||
"""生成Uniform缓冲区结构体"""
|
|
||||||
writer.write("// Uniform buffer structures")
|
|
||||||
|
|
||||||
for buffer in uniform_buffers:
|
|
||||||
struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer'
|
|
||||||
|
|
||||||
# 计算总大小和对齐要求
|
|
||||||
total_size = 0
|
|
||||||
max_alignment = 16 # GPU通常要求16字节对齐
|
|
||||||
|
|
||||||
with writer.block(f"typedef struct {struct_name}", f"}} {struct_name};"):
|
|
||||||
for i, field in enumerate(buffer.fields):
|
|
||||||
# 检查是否需要填充
|
|
||||||
if field.offset > total_size:
|
|
||||||
padding_size = field.offset - total_size
|
|
||||||
writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding")
|
|
||||||
total_size = field.offset
|
|
||||||
|
|
||||||
# 生成字段声明
|
|
||||||
declaration = self._get_c_type_declaration(field.type, field.name)
|
|
||||||
writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")
|
|
||||||
|
|
||||||
total_size = field.offset + field.size
|
|
||||||
|
|
||||||
# 确保结构体大小正确对齐
|
|
||||||
aligned_size = self._get_aligned_size(total_size, max_alignment)
|
|
||||||
if aligned_size > total_size:
|
|
||||||
writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")
|
|
||||||
|
|
||||||
writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})")
|
|
||||||
writer.write()
|
|
||||||
|
|
||||||
def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
||||||
"""生成顶点属性数组"""
|
"""生成顶点属性数组"""
|
||||||
writer.write("// Vertex attribute descriptions")
|
writer.write("// Vertex attribute descriptions")
|
||||||
@@ -350,11 +141,11 @@ class ShaderLayoutCodeGenerator:
|
|||||||
offset = 0
|
offset = 0
|
||||||
with writer.indent():
|
with writer.indent():
|
||||||
for i, field in enumerate(vertex_fields):
|
for i, field in enumerate(vertex_fields):
|
||||||
c_type, count, size = self._get_type_info(field.type)
|
c_type, count, size = get_type_info(field.type)
|
||||||
scalar_type = field.type.scalar_type or 'float32'
|
scalar_type = field.type.scalar_type or 'float32'
|
||||||
|
|
||||||
format_key = (scalar_type, count)
|
format_key = (scalar_type, count)
|
||||||
format_name = self.format_mapping.get(format_key, 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4')
|
format_name = format_mapping.get(format_key, 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4')
|
||||||
|
|
||||||
with writer.indent("{", "}," if i < len(vertex_fields) - 1 else "}"):
|
with writer.indent("{", "}," if i < len(vertex_fields) - 1 else "}"):
|
||||||
writer.write(f"// {field.name}")
|
writer.write(f"// {field.name}")
|
||||||
|
|||||||
@@ -178,10 +178,11 @@ class ShaderParser:
|
|||||||
print(f"Stderr: {result.stderr}")
|
print(f"Stderr: {result.stderr}")
|
||||||
|
|
||||||
if result.returncode == 0:
|
if result.returncode == 0:
|
||||||
|
generator = ShaderLayoutGenerator()
|
||||||
|
header_content = generator.generate_header(reflection_path)
|
||||||
|
|
||||||
if stage == ShaderStage.VERTEX:
|
if stage == ShaderStage.VERTEX:
|
||||||
# 生成顶点布局代码
|
# 生成顶点布局代码
|
||||||
generator = ShaderLayoutGenerator()
|
|
||||||
header_content = generator.generate_header(reflection_path)
|
|
||||||
layout_path_filename = os.path.join(global_vars.output_dir, f"{global_vars.source_file_name}_layout.h")
|
layout_path_filename = os.path.join(global_vars.output_dir, f"{global_vars.source_file_name}_layout.h")
|
||||||
with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
|
with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
|
||||||
out_file.write(header_content)
|
out_file.write(header_content)
|
||||||
|
|||||||
Reference in New Issue
Block a user