优化代码可读性

This commit is contained in:
daiqingshuang
2025-06-06 17:48:25 +08:00
parent de9cf7415e
commit ece556b138
2 changed files with 467 additions and 189 deletions

View File

@@ -1,62 +1,105 @@
#!/usr/bin/env python3
import json
import sys
from typing import Dict, List, Any, Tuple
from typing import Dict, List, Any, Tuple, Optional, TextIO
from dataclasses import dataclass, field
from enum import Enum
from contextlib import contextmanager
import io
class ShaderLayoutGenerator:
def __init__(self):
# SDL GPU 格式映射
self.format_mapping = {
('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
}
from code_generator import IndentManager
# 类型大小映射
self.type_sizes = {
'float32': 4,
'int32': 4,
'uint32': 4,
}
def get_type_info(self, type_obj: Dict) -> Tuple[str, int, int]:
"""获取类型信息C类型名称、元素数量、字节大小"""
kind = type_obj.get('kind')
# 数据模型类
@dataclass
class FieldType:
"""字段类型信息"""
kind: str # 'scalar', 'vector', 'matrix'
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
element_count: Optional[int] = None # for vector
row_count: Optional[int] = None # for matrix
column_count: Optional[int] = None # for matrix
@classmethod
def from_dict(cls, data: Dict) -> 'FieldType':
"""从字典创建FieldType对象"""
kind = data.get('kind')
if kind == 'vector':
element_type = type_obj['elementType']['scalarType']
element_count = type_obj['elementCount']
c_type = 'float' if element_type == 'float32' else element_type
size = self.type_sizes[element_type] * element_count
return c_type, element_count, size
return cls(
kind=kind,
scalar_type=data['elementType']['scalarType'],
element_count=data['elementCount']
)
elif kind == 'scalar':
scalar_type = type_obj['scalarType']
c_type = 'float' if scalar_type == 'float32' else scalar_type
size = self.type_sizes[scalar_type]
return c_type, 1, size
return cls(
kind=kind,
scalar_type=data['scalarType']
)
elif kind == 'matrix':
rows = type_obj['rowCount']
cols = type_obj['columnCount']
element_type = type_obj['elementType']['scalarType']
size = self.type_sizes[element_type] * rows * cols
return 'float', rows * cols, size
return cls(
kind=kind,
scalar_type=data['elementType']['scalarType'],
row_count=data['rowCount'],
column_count=data['columnCount']
)
return 'float', 1, 4
return cls(kind='scalar', scalar_type='float32')
def parse_vertex_input(self, json_data: Dict) -> List[Dict]:
@dataclass
class VertexField:
"""顶点输入字段"""
name: str
type: FieldType
location: int
semantic: str
semantic_index: int
@dataclass
class UniformField:
"""Uniform缓冲区字段"""
name: str
type: FieldType
offset: int
size: int
@dataclass
class UniformBuffer:
"""Uniform缓冲区"""
name: str
binding: int
fields: List[UniformField] = field(default_factory=list)
@dataclass
class ShaderLayout:
"""着色器布局数据"""
vertex_fields: List[VertexField] = field(default_factory=list)
uniform_buffers: List[UniformBuffer] = field(default_factory=list)
# 数据解析器(保持不变)
class ShaderLayoutParser:
"""解析JSON数据并提取到类对象"""
def parse(self, json_file: str) -> ShaderLayout:
"""解析JSON文件并返回ShaderLayout对象"""
with open(json_file, 'r') as f:
json_data = json.load(f)
layout = ShaderLayout()
layout.vertex_fields = self._parse_vertex_input(json_data)
layout.uniform_buffers = self._parse_uniform_buffers(json_data)
return layout
def _parse_vertex_input(self, json_data: Dict) -> List[VertexField]:
"""解析顶点输入字段"""
vertex_fields = []
# 查找顶点输入参数
entry_points = json_data.get('entryPoints', [])
for entry in entry_points:
if entry.get('stage') == 'vertex':
@@ -65,18 +108,18 @@ class ShaderLayoutGenerator:
if param.get('name') == 'input' and param.get('stage') == 'vertex':
fields = param.get('type', {}).get('fields', [])
for field in fields:
field_info = {
'name': field['name'],
'type': field['type'],
'location': field['binding']['index'],
'semantic': field.get('semanticName', ''),
'semantic_index': field.get('semanticIndex', 0)
}
vertex_fields.append(field_info)
vertex_field = VertexField(
name=field['name'],
type=FieldType.from_dict(field['type']),
location=field['binding']['index'],
semantic=field.get('semanticName', ''),
semantic_index=field.get('semanticIndex', 0)
)
vertex_fields.append(vertex_field)
return vertex_fields
def parse_uniform_buffers(self, json_data: Dict) -> List[Dict]:
def _parse_uniform_buffers(self, json_data: Dict) -> List[UniformBuffer]:
"""解析Uniform缓冲区"""
uniform_buffers = []
@@ -84,182 +127,412 @@ class ShaderLayoutGenerator:
for param in parameters:
binding = param.get('binding', {})
if binding.get('kind') == 'constantBuffer':
buffer_info = {
'name': param['name'],
'binding': binding['index'],
'fields': []
}
buffer = UniformBuffer(
name=param['name'],
binding=binding['index']
)
# 解析缓冲区字段
element_type = param.get('type', {}).get('elementType', {})
if element_type.get('kind') == 'struct':
fields = element_type.get('fields', [])
for field in fields:
field_info = {
'name': field['name'],
'type': field['type'],
'offset': field['binding']['offset'],
'size': field['binding']['size']
}
buffer_info['fields'].append(field_info)
uniform_field = UniformField(
name=field['name'],
type=FieldType.from_dict(field['type']),
offset=field['binding']['offset'],
size=field['binding']['size']
)
buffer.fields.append(uniform_field)
uniform_buffers.append(buffer_info)
uniform_buffers.append(buffer)
return uniform_buffers
def generate_vertex_struct(self, vertex_fields: List[Dict]) -> str:
"""生成顶点结构体"""
code = "// Auto-generated vertex structure\n"
code += "typedef struct Vertex {\n"
for field in vertex_fields:
c_type, count, _ = self.get_type_info(field['type'])
if count == 1:
code += f" {c_type} {field['name']};\n"
# **重构后的代码生成器**
class ShaderLayoutCodeGenerator:
"""从ShaderLayout对象生成C代码"""
def __init__(self):
# SDL GPU 格式映射
self.format_mapping = {
('int32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_INT',
('int32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_INT2',
('int32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_INT3',
('int32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_INT4',
('uint32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT',
('uint32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT2',
('uint32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT3',
('uint32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UINT4',
('float32', 1): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT',
('float32', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT2',
('float32', 3): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT3',
('float32', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4',
('int8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE2',
('int8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_BYTE4',
('uint8', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE2',
('uint8', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_UBYTE4',
('int16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT2',
('int16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_SHORT4',
('uint16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT2',
('uint16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_USHORT4',
('float16', 2): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF2',
('float16', 4): 'SDL_GPU_VERTEXELEMENTFORMAT_HALF4',
}
# 类型大小映射
self.type_sizes = {
'int32': 4,
'uint32': 4,
'float32': 4,
'int8': 1,
'uint8': 1,
'int16': 2,
'uint16': 2,
'float16': 2,
}
def generate(self, layout: ShaderLayout, source_file: str) -> str:
"""生成完整的头文件内容"""
output = io.StringIO()
writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进
writer.write("#pragma once")
writer.write()
writer.write("#include <SDL3/SDL.h>")
writer.write("#include <SDL3/SDL_gpu.h>")
writer.write()
writer.write(f"// Auto-generated from: {source_file}")
writer.write()
self._generate_vertex_struct(writer, layout.vertex_fields)
self._generate_uniform_structs(writer, layout.uniform_buffers)
self._generate_vertex_attributes(writer, layout.vertex_fields)
self._generate_helper_functions(writer, layout.vertex_fields)
return output.getvalue()
def _get_type_info(self, field_type: FieldType) -> Tuple[str, int, int]:
"""获取类型信息C类型名称、元素数量、字节大小
Returns:
Tuple[str, int, int]: (C类型名称, 元素数量, 总字节大小)
"""
# C类型映射表
c_type_mapping = {
'int32': 'int32_t',
'uint32': 'uint32_t',
'float32': 'float',
'int8': 'int8_t',
'uint8': 'uint8_t',
'int16': 'int16_t',
'uint16': 'uint16_t',
'float16': 'mirage_float16', # 注意C语言中没有标准的float16, 所以这里使用mirage_float16宏便于用户实现
}
# 对齐要求映射表
alignment_mapping = {
'int32': 4,
'uint32': 4,
'float32': 4,
'int8': 1,
'uint8': 1,
'int16': 2,
'uint16': 2,
'float16': 2,
}
if field_type.kind == 'scalar':
scalar_type = field_type.scalar_type or 'float32'
c_type = c_type_mapping.get(scalar_type, 'float')
size = self.type_sizes.get(scalar_type, 4)
return c_type, 1, size
elif field_type.kind == 'vector':
scalar_type = field_type.scalar_type or 'float32'
c_type = c_type_mapping.get(scalar_type, 'float')
element_count = field_type.element_count or 4
# 验证向量元素数量的合法性
if element_count not in [1, 2, 3, 4]:
raise ValueError(f"Invalid vector element count: {element_count}. Must be 1, 2, 3, or 4.")
size = self.type_sizes.get(scalar_type, 4) * element_count
# 对于vec3通常需要按vec4对齐
if element_count == 3:
alignment = alignment_mapping.get(scalar_type, 4) * 4
size = ((size + alignment - 1) // alignment) * alignment
return c_type, element_count, size
elif field_type.kind == 'matrix':
scalar_type = field_type.scalar_type or 'float32'
c_type = c_type_mapping.get(scalar_type, 'float')
row_count = field_type.row_count or 4
column_count = field_type.column_count or 4
# 验证矩阵维度的合法性
if row_count not in [2, 3, 4] or column_count not in [2, 3, 4]:
raise ValueError(f"Invalid matrix dimensions: {row_count}x{column_count}. "
f"Rows and columns must be 2, 3, or 4.")
# 矩阵通常按列主序存储,每列都需要对齐
element_size = self.type_sizes.get(scalar_type, 4)
# 每列的大小(考虑对齐)
if row_count == 3:
# mat3xN 的每列通常按 vec4 对齐
column_size = element_size * 4
else:
code += f" {c_type} {field['name']}[{count}];\n"
column_size = element_size * row_count
# 添加注释
semantic = field['semantic']
if field['semantic_index'] > 0:
semantic += str(field['semantic_index'])
code += f" // location: {field['location']}, semantic: {semantic}\n"
total_size = column_size * column_count
total_elements = row_count * column_count
code += "} Vertex;\n\n"
return code
return c_type, total_elements, total_size
def generate_vertex_attributes(self, vertex_fields: List[Dict]) -> str:
"""生成顶点属性数组"""
code = "// Vertex attribute descriptions\n"
code += f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}\n\n"
code += "static const SDL_GPUVertexAttribute g_vertexAttributes[] = {\n"
else:
# 未知类型,返回默认值并记录警告
print(f"Warning: Unknown field type kind '{field_type.kind}'. Using default float.")
return 'float', 1, 4
offset = 0
for i, field in enumerate(vertex_fields):
c_type, count, size = self.get_type_info(field['type'])
scalar_type = field['type'].get('elementType', field['type']).get('scalarType', 'float32')
def _get_c_type_declaration(self, field_type: FieldType, field_name: str) -> str:
"""生成C类型声明
format_key = (scalar_type, count)
format_name = self.format_mapping.get(format_key, 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4')
Args:
field_type: 字段类型信息
field_name: 字段名称
code += " {\n"
code += f" .location = {field['location']},\n"
code += f" .buffer_slot = 0,\n"
code += f" .format = {format_name},\n"
code += f" .offset = {offset}\n"
code += " }"
Returns:
str: 完整的C类型声明
"""
c_type, count, _ = self._get_type_info(field_type)
if i < len(vertex_fields) - 1:
code += ","
if field_type.kind == 'scalar':
return f"{c_type} {field_name}"
code += f" // {field['name']}\n"
elif field_type.kind == 'vector':
# 对于向量,可以选择使用数组或专门的向量类型
if count == 1:
return f"{c_type} {field_name}"
else:
# 选项1: 使用数组
return f"{c_type} {field_name}[{count}]"
offset += size
# 选项2: 使用对齐的结构体(如果需要的话)
# return f"struct {{ {c_type} data[{count}]; }} {field_name}"
code += "};\n\n"
elif field_type.kind == 'matrix':
rows = field_type.row_count or 4
cols = field_type.column_count or 4
# 生成顶点缓冲区描述
code += "// Vertex buffer description\n"
code += "static const SDL_GPUVertexBufferDescription g_vertexBufferDesc = {\n"
code += " .slot = 0,\n"
code += f" .pitch = {offset}, // sizeof(Vertex)\n"
code += " .input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,\n"
code += " .instance_step_rate = 0\n"
code += "};\n\n"
# 矩阵使用二维数组(列主序)
return f"{c_type} {field_name}[{cols}][{rows}]"
# 生成顶点输入状态
code += "// Vertex input state\n"
code += "static const SDL_GPUVertexInputState g_vertexInputState = {\n"
code += " .vertex_buffer_descriptions = &g_vertexBufferDesc,\n"
code += " .num_vertex_buffers = 1,\n"
code += " .vertex_attributes = g_vertexAttributes,\n"
code += " .num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT\n"
code += "};\n\n"
else:
return f"{c_type} {field_name}"
return code
def _get_aligned_size(self, size: int, alignment: int) -> int:
"""计算对齐后的大小
def generate_uniform_structs(self, uniform_buffers: List[Dict]) -> str:
Args:
size: 原始大小
alignment: 对齐要求
Returns:
int: 对齐后的大小
"""
return ((size + alignment - 1) // alignment) * alignment
def _generate_vertex_struct(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
"""生成顶点结构体"""
writer.write("// Auto-generated vertex structure")
# 添加必要的头文件包含
writer.write("#include <stdint.h> // For fixed-width integer types")
writer.write()
with writer.block("typedef struct Vertex", "} Vertex;"):
total_offset = 0
for field in vertex_fields:
# 生成字段声明
declaration = self._get_c_type_declaration(field.type, field.name)
# 获取字段信息
c_type, count, size = self._get_type_info(field.type)
# 添加详细注释
semantic = field.semantic
if field.semantic_index > 0:
semantic += str(field.semantic_index)
comment = f"// location: {field.location}, semantic: {semantic}"
comment += f", offset: {total_offset}, size: {size} bytes"
writer.write(comment)
writer.write(f"{declaration};")
total_offset += size
# 添加结构体大小的静态断言
writer.write()
writer.write(f"// Total size: {total_offset} bytes")
writer.write(f"static_assert(sizeof(Vertex) == {total_offset}, \"Vertex struct size mismatch\");")
writer.write()
def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
"""生成Uniform缓冲区结构体"""
code = "// Uniform buffer structures\n"
writer.write("// Uniform buffer structures")
for buffer in uniform_buffers:
struct_name = buffer['name'].replace('_buffer', '').title().replace('_', '') + 'Buffer'
code += f"typedef struct {struct_name} {{\n"
struct_name = buffer.name.replace('_buffer', '').title().replace('_', '') + 'Buffer'
for field in buffer['fields']:
c_type, count, _ = self.get_type_info(field['type'])
if field['type']['kind'] == 'matrix':
rows = field['type']['rowCount']
cols = field['type']['columnCount']
code += f" {c_type} {field['name']}[{rows}][{cols}];\n"
elif count == 1:
code += f" {c_type} {field['name']};\n"
else:
code += f" {c_type} {field['name']}[{count}];\n"
# 计算总大小和对齐要求
total_size = 0
max_alignment = 16 # GPU通常要求16字节对齐
code += f"}} {struct_name};\n"
code += f"// Binding: {buffer['binding']}, Size: {sum(f['size'] for f in buffer['fields'])} bytes\n\n"
with writer.block(f"typedef struct {struct_name}", f"}} {struct_name};"):
for i, field in enumerate(buffer.fields):
# 检查是否需要填充
if field.offset > total_size:
padding_size = field.offset - total_size
writer.write(f"uint8_t _padding{i}[{padding_size}]; // Padding")
total_size = field.offset
return code
# 生成字段声明
declaration = self._get_c_type_declaration(field.type, field.name)
writer.write(f"{declaration}; // offset: {field.offset}, size: {field.size}")
def generate_helper_functions(self, vertex_fields: List[Dict]) -> str:
total_size = field.offset + field.size
# 确保结构体大小正确对齐
aligned_size = self._get_aligned_size(total_size, max_alignment)
if aligned_size > total_size:
writer.write(f"// Note: Structure may need padding to {aligned_size} bytes for alignment")
writer.write(f"// Binding: {buffer.binding}, Size: {total_size} bytes (aligned: {aligned_size})")
writer.write()
def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
"""生成顶点属性数组"""
writer.write("// Vertex attribute descriptions")
writer.write(f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}")
writer.write()
writer.write("static const SDL_GPUVertexAttribute g_vertexAttributes[] = {")
offset = 0
with writer.indent():
for i, field in enumerate(vertex_fields):
c_type, count, size = self._get_type_info(field.type)
scalar_type = field.type.scalar_type or 'float32'
format_key = (scalar_type, count)
format_name = self.format_mapping.get(format_key, 'SDL_GPU_VERTEXELEMENTFORMAT_FLOAT4')
with writer.indent("{", "}," if i < len(vertex_fields) - 1 else "}"):
writer.write(f"// {field.name}")
writer.write(f".location = {field.location},")
writer.write(f".buffer_slot = 0,")
writer.write(f".format = {format_name},")
writer.write(f".offset = {offset}")
offset += size
writer.write("};")
writer.write()
# 生成顶点缓冲区描述
writer.write("// Vertex buffer description")
with writer.block("static const SDL_GPUVertexBufferDescription g_vertexBufferDesc =", "};"):
writer.write(".slot = 0,")
writer.write(f".pitch = {offset}, // sizeof(Vertex)")
writer.write(".input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,")
writer.write(".instance_step_rate = 0")
writer.write()
# 生成顶点输入状态
writer.write("// Vertex input state")
with writer.block("static const SDL_GPUVertexInputState g_vertexInputState =", "};"):
writer.write(".vertex_buffer_descriptions = &g_vertexBufferDesc,")
writer.write(".num_vertex_buffers = 1,")
writer.write(".vertex_attributes = g_vertexAttributes,")
writer.write(".num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT")
writer.write()
def _generate_helper_functions(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
"""生成辅助函数"""
code = "// Helper functions\n\n"
writer.write("// Helper functions")
writer.write()
# 创建顶点缓冲区函数
code += "static SDL_GPUBuffer* createVertexBuffer(SDL_GPUDevice* device, \n"
code += " const Vertex* vertices, \n"
code += " Uint32 vertexCount) {\n"
code += " SDL_GPUBufferCreateInfo bufferInfo = {\n"
code += " .usage = SDL_GPU_BUFFERUSAGE_VERTEX,\n"
code += " .size = static_cast<Uint32>(sizeof(Vertex)) * vertexCount\n"
code += " };\n"
code += " \n"
code += " SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &bufferInfo);\n"
code += " \n"
code += " // Upload vertex data\n"
code += " SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device,\n"
code += " SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD, bufferInfo.size);\n"
code += " \n"
code += " void* mapped = SDL_MapGPUTransferBuffer(device, transfer, SDL_FALSE);\n"
code += " SDL_memcpy(mapped, vertices, bufferInfo.size);\n"
code += " SDL_UnmapGPUTransferBuffer(device, transfer);\n"
code += " \n"
code += " // Copy to GPU\n"
code += " SDL_GPUCommandBuffer* cmd = SDL_AcquireGPUCommandBuffer(device);\n"
code += " SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);\n"
code += " \n"
code += " SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};\n"
code += " SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = bufferInfo.size};\n"
code += " \n"
code += " SDL_UploadToGPUBuffer(copy, &src, &dst, SDL_FALSE);\n"
code += " SDL_EndGPUCopyPass(copy);\n"
code += " SDL_SubmitGPUCommandBuffer(cmd);\n"
code += " \n"
code += " SDL_ReleaseGPUTransferBuffer(device, transfer);\n"
code += " return buffer;\n"
code += "}\n\n"
writer.write("static SDL_GPUBuffer* createVertexBuffer(SDL_GPUDevice* device,")
writer.write(" const Vertex* vertices,")
return code
with writer.block(" Uint32 vertexCount)", "}"):
with writer.block("SDL_GPUBufferCreateInfo bufferInfo =", "};"):
writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
writer.write(".size = static_cast<Uint32>(sizeof(Vertex)) * vertexCount")
writer.write()
writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &bufferInfo);")
writer.write()
with writer.block("SDL_GPUTransferBufferCreateInfo transferInfo =", "};"):
writer.write(".usage = SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD,")
writer.write(".size = bufferInfo.size")
writer.write()
writer.write("// Upload vertex data")
writer.write("SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device, &transferInfo);")
writer.write()
writer.write("void* mapped = SDL_MapGPUTransferBuffer(device, transfer, SDL_FALSE);")
writer.write("SDL_memcpy(mapped, vertices, bufferInfo.size);")
writer.write("SDL_UnmapGPUTransferBuffer(device, transfer);")
writer.write()
writer.write("// Copy to GPU")
writer.write("SDL_GPUCommandBuffer* cmd = SDL_AcquireGPUCommandBuffer(device);")
writer.write("SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);")
writer.write()
writer.write("SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};")
writer.write("SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = bufferInfo.size};")
writer.write()
writer.write("SDL_UploadToGPUBuffer(copy, &src, &dst, SDL_FALSE);")
writer.write("SDL_EndGPUCopyPass(copy);")
writer.write("SDL_SubmitGPUCommandBuffer(cmd);")
writer.write()
writer.write("SDL_ReleaseGPUTransferBuffer(device, transfer);")
writer.write("return buffer;")
writer.write()
# 主生成器类(保持不变)
class ShaderLayoutGenerator:
"""主生成器类,整合解析和代码生成功能"""
def __init__(self):
self.parser = ShaderLayoutParser()
self.code_generator = ShaderLayoutCodeGenerator()
def generate_header(self, json_file: str) -> str:
"""生成头文件内容"""
with open(json_file, 'r') as f:
json_data = json.load(f)
# 解析JSON到类对象
layout = self.parser.parse(json_file)
# 解析数据
vertex_fields = self.parse_vertex_input(json_data)
uniform_buffers = self.parse_uniform_buffers(json_data)
# 基于类对象生成代码
return self.code_generator.generate(layout, json_file)
# 生成代码
code = "#pragma once\n\n"
code += "#include <SDL3/SDL.h>\n"
code += "#include <SDL3/SDL_gpu.h>\n\n"
code += "// Auto-generated from: " + json_file + "\n\n"
code += self.generate_vertex_struct(vertex_fields)
code += self.generate_uniform_structs(uniform_buffers)
code += self.generate_vertex_attributes(vertex_fields)
code += self.generate_helper_functions(vertex_fields)
# 使用示例
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: shader_layout_gen.py <shader.json>")
sys.exit(1)
return code
generator = ShaderLayoutGenerator()
header_content = generator.generate_header(sys.argv[1])
print(header_content)

View File

@@ -12,7 +12,7 @@ import shutil
import re
from typing import List, Dict, Optional
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo
from shader_layout_generator import ShaderLayoutGenerator
class ShaderParser:
def __init__(self, slangc_path: str, include_paths: List[str] = None):
@@ -220,6 +220,11 @@ class ShaderParser:
if reflection_json.strip():
try:
reflection = json.loads(reflection_json)
generator = ShaderLayoutGenerator()
header_content = generator.generate_header(reflection_path)
layout_path_filename = os.path.join(os.path.dirname(output_path), f"{entry_name}_layout.h")
with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
out_file.write(header_content)
return self._create_shader_info_from_reflection(
reflection, entry_name, stage, source
)