优化代码可读性
This commit is contained in:
@@ -1,62 +1,105 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import sys
|
||||
from typing import Dict, List, Any, Tuple
|
||||
from typing import Dict, List, Any, Tuple, Optional, TextIO
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from contextlib import contextmanager
|
||||
import io
|
||||
|
||||
class ShaderLayoutGenerator:
|
||||
def __init__(self):
|
||||
# SDL GPU 格式映射
|
||||
self.format_mapping = {
|
||||
('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
|
||||
('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
|
||||
('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
|
||||
('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
|
||||
('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
|
||||
('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
|
||||
('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
|
||||
('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
|
||||
('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
|
||||
('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
|
||||
('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
|
||||
('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
|
||||
}
|
||||
from code_generator import IndentManager
|
||||
|
||||
# 类型大小映射
|
||||
self.type_sizes = {
|
||||
'float32': 4,
|
||||
'int32': 4,
|
||||
'uint32': 4,
|
||||
}
|
||||
|
||||
def get_type_info(self, type_obj: Dict) -> Tuple[str, int, int]:
|
||||
"""获取类型信息:C类型名称、元素数量、字节大小"""
|
||||
kind = type_obj.get('kind')
|
||||
# 数据模型类
|
||||
@dataclass
|
||||
class FieldType:
|
||||
"""字段类型信息"""
|
||||
kind: str # 'scalar', 'vector', 'matrix'
|
||||
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
|
||||
element_count: Optional[int] = None # for vector
|
||||
row_count: Optional[int] = None # for matrix
|
||||
column_count: Optional[int] = None # for matrix
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'FieldType':
|
||||
"""从字典创建FieldType对象"""
|
||||
kind = data.get('kind')
|
||||
|
||||
if kind == 'vector':
|
||||
element_type = type_obj['elementType']['scalarType']
|
||||
element_count = type_obj['elementCount']
|
||||
c_type = 'float' if element_type == 'float32' else element_type
|
||||
size = self.type_sizes[element_type] * element_count
|
||||
return c_type, element_count, size
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['elementType']['scalarType'],
|
||||
element_count=data['elementCount']
|
||||
)
|
||||
elif kind == 'scalar':
|
||||
scalar_type = type_obj['scalarType']
|
||||
c_type = 'float' if scalar_type == 'float32' else scalar_type
|
||||
size = self.type_sizes[scalar_type]
|
||||
return c_type, 1, size
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['scalarType']
|
||||
)
|
||||
elif kind == 'matrix':
|
||||
rows = type_obj['rowCount']
|
||||
cols = type_obj['columnCount']
|
||||
element_type = type_obj['elementType']['scalarType']
|
||||
size = self.type_sizes[element_type] * rows * cols
|
||||
return 'float', rows * cols, size
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['elementType']['scalarType'],
|
||||
row_count=data['rowCount'],
|
||||
column_count=data['columnCount']
|
||||
)
|
||||
|
||||
return 'float', 1, 4
|
||||
return cls(kind='scalar', scalar_type='float32')
|
||||
|
||||
def parse_vertex_input(self, json_data: Dict) -> List[Dict]:
|
||||
|
||||
@dataclass
|
||||
class VertexField:
|
||||
"""顶点输入字段"""
|
||||
name: str
|
||||
type: FieldType
|
||||
location: int
|
||||
semantic: str
|
||||
semantic_index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UniformField:
|
||||
"""Uniform缓冲区字段"""
|
||||
name: str
|
||||
type: FieldType
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UniformBuffer:
|
||||
"""Uniform缓冲区"""
|
||||
name: str
|
||||
binding: int
|
||||
fields: List[UniformField] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShaderLayout:
|
||||
"""着色器布局数据"""
|
||||
vertex_fields: List[VertexField] = field(default_factory=list)
|
||||
uniform_buffers: List[UniformBuffer] = field(default_factory=list)
|
||||
|
||||
|
||||
# 数据解析器(保持不变)
|
||||
class ShaderLayoutParser:
|
||||
"""解析JSON数据并提取到类对象"""
|
||||
|
||||
def parse(self, json_file: str) -> ShaderLayout:
|
||||
"""解析JSON文件并返回ShaderLayout对象"""
|
||||
with open(json_file, 'r') as f:
|
||||
json_data = json.load(f)
|
||||
|
||||
layout = ShaderLayout()
|
||||
layout.vertex_fields = self._parse_vertex_input(json_data)
|
||||
layout.uniform_buffers = self._parse_uniform_buffers(json_data)
|
||||
|
||||
return layout
|
||||
|
||||
def _parse_vertex_input(self, json_data: Dict) -> List[VertexField]:
|
||||
"""解析顶点输入字段"""
|
||||
vertex_fields = []
|
||||
|
||||
# 查找顶点输入参数
|
||||
entry_points = json_data.get('entryPoints', [])
|
||||
for entry in entry_points:
|
||||
if entry.get('stage') == 'vertex':
|
||||
@@ -65,18 +108,18 @@ class ShaderLayoutGenerator:
|
||||
if param.get('name') == 'input' and param.get('stage') == 'vertex':
|
||||
fields = param.get('type', {}).get('fields', [])
|
||||
for field in fields:
|
||||
field_info = {
|
||||
'name': field['name'],
|
||||
'type': field['type'],
|
||||
'location': field['binding']['index'],
|
||||
'semantic': field.get('semanticName', ''),
|
||||
'semantic_index': field.get('semanticIndex', 0)
|
||||
}
|
||||
vertex_fields.append(field_info)
|
||||
vertex_field = VertexField(
|
||||
name=field['name'],
|
||||
type=FieldType.from_dict(field['type']),
|
||||
location=field['binding']['index'],
|
||||
semantic=field.get('semanticName', ''),
|
||||
semantic_index=field.get('semanticIndex', 0)
|
||||
)
|
||||
vertex_fields.append(vertex_field)
|
||||
|
||||
return vertex_fields
|
||||
|
||||
def parse_uniform_buffers(self, json_data: Dict) -> List[Dict]:
|
||||
def _parse_uniform_buffers(self, json_data: Dict) -> List[UniformBuffer]:
|
||||
"""解析Uniform缓冲区"""
|
||||
uniform_buffers = []
|
||||
|
||||
@@ -84,182 +127,412 @@ class ShaderLayoutGenerator:
|
||||
for param in parameters:
|
||||
binding = param.get('binding', {})
|
||||
if binding.get('kind') == 'constantBuffer':
|
||||
buffer_info = {
|
||||
'name': param['name'],
|
||||
'binding': binding['index'],
|
||||
'fields': []
|
||||
}
|
||||
buffer = UniformBuffer(
|
||||
name=param['name'],
|
||||
binding=binding['index']
|
||||
)
|
||||
|
||||
# 解析缓冲区字段
|
||||
element_type = param.get('type', {}).get('elementType', {})
|
||||
if element_type.get('kind') == 'struct':
|
||||
fields = element_type.get('fields', [])
|
||||
for field in fields:
|
||||
field_info = {
|
||||
'name': field['name'],
|
||||
'type': field['type'],
|
||||
'offset': field['binding']['offset'],
|
||||
'size': field['binding']['size']
|
||||
}
|
||||
buffer_info['fields'].append(field_info)
|
||||
uniform_field = UniformField(
|
||||
name=field['name'],
|
||||
type=FieldType.from_dict(field['type']),
|
||||
offset=field['binding']['offset'],
|
||||
size=field['binding']['size']
|
||||
)
|
||||
buffer.fields.append(uniform_field)
|
||||
|
||||
uniform_buffers.append(buffer_info)
|
||||
uniform_buffers.append(buffer)
|
||||
|
||||
return uniform_buffers
|
||||
|
||||
def generate_vertex_struct(self, vertex_fields: List[Dict]) -> str:
|
||||
"""生成顶点结构体"""
|
||||
code = "// Auto-generated vertex structure\n"
|
||||
code += "typedef struct Vertex {\n"
|
||||
|
||||
for field in vertex_fields:
|
||||
c_type, count, _ = self.get_type_info(field['type'])
|
||||
if count == 1:
|
||||
code += f" {c_type} {field['name']};\n"
|
||||
# **重构后的代码生成器**
|
||||
class ShaderLayoutCodeGenerator:
|
||||
"""从ShaderLayout对象生成C代码"""
|
||||
|
||||
def __init__(self):
|
||||
# SDL GPU 格式映射
|
||||
self.format_mapping = {
|
||||
('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
|
||||
('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
|
||||
('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
|
||||
('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
|
||||
('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
|
||||
('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
|
||||
('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
|
||||
('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
|
||||
('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
|
||||
('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
|
||||
('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
|
||||
('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
|
||||
('int8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE2',
|
||||
('int8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE4',
|
||||
('uint8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE2',
|
||||
('uint8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE4',
|
||||
('int16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT2',
|
||||
('int16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT4',
|
||||
('uint16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT2',
|
||||
('uint16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT4',
|
||||
('float16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF2',
|
||||
('float16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF4',
|
||||
}
|
||||
|
||||
# 类型大小映射
|
||||
self.type_sizes = {
|
||||
'int32': 4,
|
||||
'uint32': 4,
|
||||
'float32': 4,
|
||||
'int8': 1,
|
||||
'uint8': 1,
|
||||
'int16': 2,
|
||||
'uint16': 2,
|
||||
'float16': 2,
|
||||
}
|
||||
|
||||
def generate(self, layout: ShaderLayout, source_file: str) -> str:
|
||||
"""生成完整的头文件内容"""
|
||||
output = io.StringIO()
|
||||
writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进
|
||||
|
||||
writer.write("#pragma once")
|
||||
writer.write()
|
||||
writer.write("#include <SDL3/SDL.h>")
|
||||
writer.write("#include <SDL3/SDL_gpu.h>")
|
||||
writer.write()
|
||||
writer.write(f"// Auto-generated from: {source_file}")
|
||||
writer.write()
|
||||
|
||||
self._generate_vertex_struct(writer, layout.vertex_fields)
|
||||
self._generate_uniform_structs(writer, layout.uniform_buffers)
|
||||
self._generate_vertex_attributes(writer, layout.vertex_fields)
|
||||
self._generate_helper_functions(writer, layout.vertex_fields)
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
def _get_type_info(self, field_type: FieldType) -> Tuple[str, int, int]:
|
||||
"""获取类型信息:C类型名称、元素数量、字节大小
|
||||
|
||||
Returns:
|
||||
Tuple[str, int, int]: (C类型名称, 元素数量, 总字节大小)
|
||||
"""
|
||||
# C类型映射表
|
||||
c_type_mapping = {
|
||||
'int32': 'int32_t',
|
||||
'uint32': 'uint32_t',
|
||||
'float32': 'float',
|
||||
'int8': 'int8_t',
|
||||
'uint8': 'uint8_t',
|
||||
'int16': 'int16_t',
|
||||
'uint16': 'uint16_t',
|
||||
'float16': 'mirage_float16', # 注意:C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
|
||||
}
|
||||
|
||||
# 对齐要求映射表
|
||||
alignment_mapping = {
|
||||
'int32': 4,
|
||||
'uint32': 4,
|
||||
'float32': 4,
|
||||
'int8': 1,
|
||||
'uint8': 1,
|
||||
'int16': 2,
|
||||
'uint16': 2,
|
||||
'float16': 2,
|
||||
}
|
||||
|
||||
if field_type.kind == 'scalar':
|
||||
scalar_type = field_type.scalar_type or 'float32'
|
||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||
size = self.type_sizes.get(scalar_type, 4)
|
||||
return c_type, 1, size
|
||||
|
||||
elif field_type.kind == 'vector':
|
||||
scalar_type = field_type.scalar_type or 'float32'
|
||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||
element_count = field_type.element_count or 4
|
||||
|
||||
# 验证向量元素数量的合法性
|
||||
if element_count not in [1, 2, 3, 4]:
|
||||
raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")
|
||||
|
||||
size = self.type_sizes.get(scalar_type, 4) * element_count
|
||||
|
||||
# 对于vec3,通常需要按vec4对齐
|
||||
if element_count == 3:
|
||||
alignment = alignment_mapping.get(scalar_type, 4) * 4
|
||||
size = ((size + alignment - 1) // alignment) * alignment
|
||||
|
||||
return c_type, element_count, size
|
||||
|
||||
elif field_type.kind == 'matrix':
|
||||
scalar_type = field_type.scalar_type or 'float32'
|
||||
c_type = c_type_mapping.get(scalar_type, 'float')
|
||||
|
||||
row_count = field_type.row_count or 4
|
||||
column_count = field_type.column_count or 4
|
||||
|
||||
# 验证矩阵维度的合法性
|
||||
if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]:
|
||||
raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
|
||||
f"Rows and columns must be 2, 3, or 4.")
|
||||
|
||||
# 矩阵通常按列主序存储,每列都需要对齐
|
||||
element_size = self.type_sizes.get(scalar_type, 4)
|
||||
|
||||
# 每列的大小(考虑对齐)
|
||||
if row_count == 3:
|
||||
# mat3xN 的每列通常按 vec4 对齐
|
||||
column_size = element_size * 4
|
||||
else:
|
||||
code += f" {c_type} {field['name']}[{count}];\n"
|
||||
column_size = element_size * row_count
|
||||
|
||||
# 添加注释
|
||||
semantic = field['semantic']
|
||||
if field['semantic_index'] > 0:
|
||||
semantic += str(field['semantic_index'])
|
||||
code += f" // location: {field['location']}, semantic: {semantic}\n"
|
||||
total_size = column_size * column_count
|
||||
total_elements = row_count * column_count
|
||||
|
||||
code += "} Vertex;\n\n"
|
||||
return code
|
||||
return c_type, total_elements, total_size
|
||||
|
||||
def generate_vertex_attributes(self, vertex_fields: List[Dict]) -> str:
|
||||
"""生成顶点属性数组"""
|
||||
code = "// Vertex attribute descriptions\n"
|
||||
code += f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}\n\n"
|
||||
code += "static const SDL_GPUVertexAttribute g_vertexAttributes[] = {\n"
|
||||
else:
|
||||
# 未知类型,返回默认值并记录警告
|
||||
print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.")
|
||||
return 'float', 1, 4
|
||||
|
||||
offset = 0
|
||||
for i, field in enumerate(vertex_fields):
|
||||
c_type, count, size = self.get_type_info(field['type'])
|
||||
scalar_type = field['type'].get('elementType', field['type']).get('scalarType', 'float32')
|
||||
def _get_c_type_declaration(self, field_type: FieldType, field_name: str) -> str:
|
||||
"""生成C类型声明
|
||||
|
||||
format_key = (scalar_type, count)
|
||||
format_name = self.format_mapping.get(format_key, 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4')
|
||||
Args:
|
||||
field_type: 字段类型信息
|
||||
field_name: 字段名称
|
||||
|
||||
code += " {\n"
|
||||
code += f" .location = {field['location']},\n"
|
||||
code += f" .buffer_slot = 0,\n"
|
||||
code += f" .format = {format_name},\n"
|
||||
code += f" .offset = {offset}\n"
|
||||
code += " }"
|
||||
Returns:
|
||||
str: 完整的C类型声明
|
||||
"""
|
||||
c_type, count, _ = self._get_type_info(field_type)
|
||||
|
||||
if i < len(vertex_fields) - 1:
|
||||
code += ","
|
||||
if field_type.kind == 'scalar':
|
||||
return f"{c_type} {field_name}"
|
||||
|
||||
code += f" // {field['name']}\n"
|
||||
elif field_type.kind == 'vector':
|
||||
# 对于向量,可以选择使用数组或专门的向量类型
|
||||
if count == 1:
|
||||
return f"{c_type} {field_name}"
|
||||
else:
|
||||
# 选项1: 使用数组
|
||||
return f"{c_type} {field_name}[{count}]"
|
||||
|
||||
offset += size
|
||||
# 选项2: 使用对齐的结构体(如果需要的话)
|
||||
# return f"struct {{ {c_type} data[{count}]; }} {field_name}"
|
||||
|
||||
code += "};\n\n"
|
||||
elif field_type.kind == 'matrix':
|
||||
rows = field_type.row_count or 4
|
||||
cols = field_type.column_count or 4
|
||||
|
||||
# 生成顶点缓冲区描述
|
||||
code += "// Vertex buffer description\n"
|
||||
code += "static const SDL_GPUVertexBufferDescription g_vertexBufferDesc = {\n"
|
||||
code += " .slot = 0,\n"
|
||||
code += f" .pitch = {offset}, // sizeof(Vertex)\n"
|
||||
code += " .input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,\n"
|
||||
code += " .instance_step_rate = 0\n"
|
||||
code += "};\n\n"
|
||||
# 矩阵使用二维数组(列主序)
|
||||
return f"{c_type} {field_name}[{cols}][{rows}]"
|
||||
|
||||
# 生成顶点输入状态
|
||||
code += "// Vertex input state\n"
|
||||
code += "static const SDL_GPUVertexInputState g_vertexInputState = {\n"
|
||||
code += " .vertex_buffer_descriptions = &g_vertexBufferDesc,\n"
|
||||
code += " .num_vertex_buffers = 1,\n"
|
||||
code += " .vertex_attributes = g_vertexAttributes,\n"
|
||||
code += " .num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT\n"
|
||||
code += "};\n\n"
|
||||
else:
|
||||
return f"{c_type} {field_name}"
|
||||
|
||||
return code
|
||||
def _get_aligned_size(self, size: int, alignment: int) -> int:
|
||||
"""计算对齐后的大小
|
||||
|
||||
def generate_uniform_structs(self, uniform_buffers: List[Dict]) -> str:
|
||||
Args:
|
||||
size: 原始大小
|
||||
alignment: 对齐要求
|
||||
|
||||
Returns:
|
||||
int: 对齐后的大小
|
||||
"""
|
||||
return ((size + alignment - 1) // alignment) * alignment
|
||||
|
||||
def _generate_vertex_struct(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
||||
"""生成顶点结构体"""
|
||||
writer.write("// Auto-generated vertex structure")
|
||||
|
||||
# 添加必要的头文件包含
|
||||
writer.write("#include <stdint.h> // For fixed-width integer types")
|
||||
writer.write()
|
||||
|
||||
with writer.block("typedef struct Vertex", "} Vertex;"):
|
||||
total_offset = 0
|
||||
for field in vertex_fields:
|
||||
# 生成字段声明
|
||||
declaration = self._get_c_type_declaration(field.type, field.name)
|
||||
|
||||
# 获取字段信息
|
||||
c_type, count, size = self._get_type_info(field.type)
|
||||
|
||||
# 添加详细注释
|
||||
semantic = field.semantic
|
||||
if field.semantic_index > 0:
|
||||
semantic += str(field.semantic_index)
|
||||
|
||||
comment = f"// location: {field.location}, semantic: {semantic}"
|
||||
comment += f", offset: {total_offset}, size: {size} bytes"
|
||||
|
||||
writer.write(comment)
|
||||
writer.write(f"{declaration};")
|
||||
|
||||
total_offset += size
|
||||
|
||||
# 添加结构体大小的静态断言
|
||||
writer.write()
|
||||
writer.write(f"// Total size: {total_offset} bytes")
|
||||
writer.write(f"static_assert(sizeof(Vertex) == {total_offset}, \"Vertex struct size mismatch\");")
|
||||
writer.write()
|
||||
|
||||
def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
|
||||
"""生成Uniform缓冲区结构体"""
|
||||
code = "// Uniform buffer structures\n"
|
||||
writer.write("// Uniform buffer structures")
|
||||
|
||||
for buffer in uniform_buffers:
|
||||
struct_name = buffer['name'].replace('_buffer', '').title().replace('_', '') + 'Buffer'
|
||||
code += f"typedef struct {struct_name} {{\n"
|
||||
struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer'
|
||||
|
||||
for field in buffer['fields']:
|
||||
c_type, count, _ = self.get_type_info(field['type'])
|
||||
if field['type']['kind'] == 'matrix':
|
||||
rows = field['type']['rowCount']
|
||||
cols = field['type']['columnCount']
|
||||
code += f" {c_type} {field['name']}[{rows}][{cols}];\n"
|
||||
elif count == 1:
|
||||
code += f" {c_type} {field['name']};\n"
|
||||
else:
|
||||
code += f" {c_type} {field['name']}[{count}];\n"
|
||||
# 计算总大小和对齐要求
|
||||
total_size = 0
|
||||
max_alignment = 16 # GPU通常要求16字节对齐
|
||||
|
||||
code += f"}} {struct_name};\n"
|
||||
code += f"// Binding: {buffer['binding']}, Size: {sum(f['size'] for f in buffer['fields'])} bytes\n\n"
|
||||
with writer.block(f"typedef struct {struct_name}", f"}} {struct_name};"):
|
||||
for i, field in enumerate(buffer.fields):
|
||||
# 检查是否需要填充
|
||||
if field.offset > total_size:
|
||||
padding_size = field.offset - total_size
|
||||
writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding")
|
||||
total_size = field.offset
|
||||
|
||||
return code
|
||||
# 生成字段声明
|
||||
declaration = self._get_c_type_declaration(field.type, field.name)
|
||||
writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")
|
||||
|
||||
def generate_helper_functions(self, vertex_fields: List[Dict]) -> str:
|
||||
total_size = field.offset + field.size
|
||||
|
||||
# 确保结构体大小正确对齐
|
||||
aligned_size = self._get_aligned_size(total_size, max_alignment)
|
||||
if aligned_size > total_size:
|
||||
writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")
|
||||
|
||||
writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})")
|
||||
writer.write()
|
||||
|
||||
def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
||||
"""生成顶点属性数组"""
|
||||
writer.write("// Vertex attribute descriptions")
|
||||
writer.write(f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}")
|
||||
writer.write()
|
||||
|
||||
writer.write("static const SDL_GPUVertexAttribute g_vertexAttributes[] = {")
|
||||
|
||||
offset = 0
|
||||
with writer.indent():
|
||||
for i, field in enumerate(vertex_fields):
|
||||
c_type, count, size = self._get_type_info(field.type)
|
||||
scalar_type = field.type.scalar_type or 'float32'
|
||||
|
||||
format_key = (scalar_type, count)
|
||||
format_name = self.format_mapping.get(format_key, 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4')
|
||||
|
||||
with writer.indent("{", "}," if i < len(vertex_fields) - 1 else "}"):
|
||||
writer.write(f"// {field.name}")
|
||||
writer.write(f".location = {field.location},")
|
||||
writer.write(f".buffer_slot = 0,")
|
||||
writer.write(f".format = {format_name},")
|
||||
writer.write(f".offset = {offset}")
|
||||
|
||||
offset += size
|
||||
|
||||
writer.write("};")
|
||||
writer.write()
|
||||
|
||||
# 生成顶点缓冲区描述
|
||||
writer.write("// Vertex buffer description")
|
||||
with writer.block("static const SDL_GPUVertexBufferDescription g_vertexBufferDesc =", "};"):
|
||||
writer.write(".slot = 0,")
|
||||
writer.write(f".pitch = {offset}, // sizeof(Vertex)")
|
||||
writer.write(".input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,")
|
||||
writer.write(".instance_step_rate = 0")
|
||||
writer.write()
|
||||
|
||||
# 生成顶点输入状态
|
||||
writer.write("// Vertex input state")
|
||||
with writer.block("static const SDL_GPUVertexInputState g_vertexInputState =", "};"):
|
||||
writer.write(".vertex_buffer_descriptions = &g_vertexBufferDesc,")
|
||||
writer.write(".num_vertex_buffers = 1,")
|
||||
writer.write(".vertex_attributes = g_vertexAttributes,")
|
||||
writer.write(".num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT")
|
||||
writer.write()
|
||||
|
||||
def _generate_helper_functions(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
||||
"""生成辅助函数"""
|
||||
code = "// Helper functions\n\n"
|
||||
writer.write("// Helper functions")
|
||||
writer.write()
|
||||
|
||||
# 创建顶点缓冲区函数
|
||||
code += "static SDL_GPUBuffer* createVertexBuffer(SDL_GPUDevice* device, \n"
|
||||
code += " const Vertex* vertices, \n"
|
||||
code += " Uint32 vertexCount) {\n"
|
||||
code += " SDL_GPUBufferCreateInfo bufferInfo = {\n"
|
||||
code += " .usage = SDL_GPU_BUFFERUSAGE_VERTEX,\n"
|
||||
code += " .size = static_cast<Uint32>(sizeof(Vertex)) * vertexCount\n"
|
||||
code += " };\n"
|
||||
code += " \n"
|
||||
code += " SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &bufferInfo);\n"
|
||||
code += " \n"
|
||||
code += " // Upload vertex data\n"
|
||||
code += " SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device,\n"
|
||||
code += " SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD, bufferInfo.size);\n"
|
||||
code += " \n"
|
||||
code += " void* mapped = SDL_MapGPUTransferBuffer(device, transfer, SDL_FALSE);\n"
|
||||
code += " SDL_memcpy(mapped, vertices, bufferInfo.size);\n"
|
||||
code += " SDL_UnmapGPUTransferBuffer(device, transfer);\n"
|
||||
code += " \n"
|
||||
code += " // Copy to GPU\n"
|
||||
code += " SDL_GPUCommandBuffer* cmd = SDL_AcquireGPUCommandBuffer(device);\n"
|
||||
code += " SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);\n"
|
||||
code += " \n"
|
||||
code += " SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};\n"
|
||||
code += " SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = bufferInfo.size};\n"
|
||||
code += " \n"
|
||||
code += " SDL_UploadToGPUBuffer(copy, &src, &dst, SDL_FALSE);\n"
|
||||
code += " SDL_EndGPUCopyPass(copy);\n"
|
||||
code += " SDL_SubmitGPUCommandBuffer(cmd);\n"
|
||||
code += " \n"
|
||||
code += " SDL_ReleaseGPUTransferBuffer(device, transfer);\n"
|
||||
code += " return buffer;\n"
|
||||
code += "}\n\n"
|
||||
writer.write("static SDL_GPUBuffer* createVertexBuffer(SDL_GPUDevice* device,")
|
||||
writer.write(" const Vertex* vertices,")
|
||||
|
||||
return code
|
||||
with writer.block(" Uint32 vertexCount)", "}"):
|
||||
with writer.block("SDL_GPUBufferCreateInfo bufferInfo =", "};"):
|
||||
writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
|
||||
writer.write(".size = static_cast<Uint32>(sizeof(Vertex)) * vertexCount")
|
||||
|
||||
writer.write()
|
||||
writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &bufferInfo);")
|
||||
writer.write()
|
||||
|
||||
with writer.block("SDL_GPUTransferBufferCreateInfo transferInfo =", "};"):
|
||||
writer.write(".usage = SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD,")
|
||||
writer.write(".size = bufferInfo.size")
|
||||
|
||||
writer.write()
|
||||
writer.write("// Upload vertex data")
|
||||
writer.write("SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device, &transferInfo);")
|
||||
writer.write()
|
||||
writer.write("void* mapped = SDL_MapGPUTransferBuffer(device, transfer, SDL_FALSE);")
|
||||
writer.write("SDL_memcpy(mapped, vertices, bufferInfo.size);")
|
||||
writer.write("SDL_UnmapGPUTransferBuffer(device, transfer);")
|
||||
writer.write()
|
||||
writer.write("// Copy to GPU")
|
||||
writer.write("SDL_GPUCommandBuffer* cmd = SDL_AcquireGPUCommandBuffer(device);")
|
||||
writer.write("SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);")
|
||||
writer.write()
|
||||
writer.write("SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};")
|
||||
writer.write("SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = bufferInfo.size};")
|
||||
writer.write()
|
||||
writer.write("SDL_UploadToGPUBuffer(copy, &src, &dst, SDL_FALSE);")
|
||||
writer.write("SDL_EndGPUCopyPass(copy);")
|
||||
writer.write("SDL_SubmitGPUCommandBuffer(cmd);")
|
||||
writer.write()
|
||||
writer.write("SDL_ReleaseGPUTransferBuffer(device, transfer);")
|
||||
writer.write("return buffer;")
|
||||
|
||||
writer.write()
|
||||
|
||||
|
||||
# 主生成器类(保持不变)
|
||||
class ShaderLayoutGenerator:
|
||||
"""主生成器类,整合解析和代码生成功能"""
|
||||
|
||||
def __init__(self):
|
||||
self.parser = ShaderLayoutParser()
|
||||
self.code_generator = ShaderLayoutCodeGenerator()
|
||||
|
||||
def generate_header(self, json_file: str) -> str:
|
||||
"""生成头文件内容"""
|
||||
with open(json_file, 'r') as f:
|
||||
json_data = json.load(f)
|
||||
# 解析JSON到类对象
|
||||
layout = self.parser.parse(json_file)
|
||||
|
||||
# 解析数据
|
||||
vertex_fields = self.parse_vertex_input(json_data)
|
||||
uniform_buffers = self.parse_uniform_buffers(json_data)
|
||||
# 基于类对象生成代码
|
||||
return self.code_generator.generate(layout, json_file)
|
||||
|
||||
# 生成代码
|
||||
code = "#pragma once\n\n"
|
||||
code += "#include <SDL3/SDL.h>\n"
|
||||
code += "#include <SDL3/SDL_gpu.h>\n\n"
|
||||
code += "// Auto-generated from: " + json_file + "\n\n"
|
||||
|
||||
code += self.generate_vertex_struct(vertex_fields)
|
||||
code += self.generate_uniform_structs(uniform_buffers)
|
||||
code += self.generate_vertex_attributes(vertex_fields)
|
||||
code += self.generate_helper_functions(vertex_fields)
|
||||
# 使用示例
|
||||
if __name__ == "__main__":
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: shader_layout_gen.py <shader.json>")
|
||||
sys.exit(1)
|
||||
|
||||
return code
|
||||
generator = ShaderLayoutGenerator()
|
||||
header_content = generator.generate_header(sys.argv[1])
|
||||
print(header_content)
|
||||
|
||||
@@ -12,7 +12,7 @@ import shutil
|
||||
import re
|
||||
from typing import List, Dict, Optional
|
||||
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo
|
||||
|
||||
from shader_layout_generator import ShaderLayoutGenerator
|
||||
|
||||
class ShaderParser:
|
||||
def __init__(self, slangc_path: str, include_paths: List[str] = None):
|
||||
@@ -220,6 +220,11 @@ class ShaderParser:
|
||||
if reflection_json.strip():
|
||||
try:
|
||||
reflection = json.loads(reflection_json)
|
||||
generator = ShaderLayoutGenerator()
|
||||
header_content = generator.generate_header(reflection_path)
|
||||
layout_path_filename = os.path.join(os.path.dirname(output_path), f"{entry_name}_layout.h")
|
||||
with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
|
||||
out_file.write(header_content)
|
||||
return self._create_shader_info_from_reflection(
|
||||
reflection, entry_name, stage, source
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user