整理输出格式
This commit is contained in:
@@ -23,9 +23,9 @@ class SDL3GPUSlangCompiler:
|
||||
self.binding_manager = BindingManager()
|
||||
self.code_generator = CodeGenerator()
|
||||
|
||||
def parse_slang_shader(self, source_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
|
||||
def parse_slang_shader(self, source_path: str, output_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
|
||||
"""解析Slang着色器源码,提取资源信息"""
|
||||
return self.parser.parse_slang_shader(source_path, target, include_paths)
|
||||
return self.parser.parse_slang_shader(source_path, output_path, target, include_paths)
|
||||
|
||||
def compile_shader(self, shader_info: ShaderInfo, target: TargetFormat) -> tuple[str, dict]:
|
||||
"""编译着色器并返回二进制路径和绑定信息"""
|
||||
|
||||
2
main.py
2
main.py
@@ -35,7 +35,7 @@ def main():
|
||||
|
||||
# 解析着色器
|
||||
print(f"**Parsing** {args.input}...")
|
||||
shaders = compiler.parse_slang_shader(args.input, target)
|
||||
shaders = compiler.parse_slang_shader(args.input, args.output_dir, target)
|
||||
|
||||
# 编译每个入口点
|
||||
binding_infos = []
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List, Any, Tuple, Optional, TextIO
|
||||
from dataclasses import dataclass, field
|
||||
@@ -193,7 +194,7 @@ class ShaderLayoutCodeGenerator:
|
||||
'float16': 2,
|
||||
}
|
||||
|
||||
def generate(self, layout: ShaderLayout, source_file: str) -> str:
|
||||
def generate(self, layout: ShaderLayout, source_file: str, namespace: str) -> str:
|
||||
"""生成完整的头文件内容"""
|
||||
output = io.StringIO()
|
||||
writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进
|
||||
@@ -202,14 +203,17 @@ class ShaderLayoutCodeGenerator:
|
||||
writer.write()
|
||||
writer.write("#include <SDL3/SDL.h>")
|
||||
writer.write("#include <SDL3/SDL_gpu.h>")
|
||||
writer.write("#include <cstdint> // For fixed-width integer types")
|
||||
writer.write()
|
||||
writer.write(f"// Auto-generated from: {source_file}")
|
||||
writer.write()
|
||||
|
||||
self._generate_vertex_struct(writer, layout.vertex_fields)
|
||||
self._generate_uniform_structs(writer, layout.uniform_buffers)
|
||||
self._generate_vertex_attributes(writer, layout.vertex_fields)
|
||||
self._generate_helper_functions(writer, layout.vertex_fields)
|
||||
writer.write("// Auto-generated vertex structure")
|
||||
with writer.block(f'namespace {namespace}'):
|
||||
self._generate_vertex_struct(writer, layout.vertex_fields)
|
||||
self._generate_uniform_structs(writer, layout.uniform_buffers)
|
||||
self._generate_vertex_attributes(writer, layout.vertex_fields)
|
||||
self._generate_helper_functions(writer, layout.vertex_fields)
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
@@ -349,13 +353,7 @@ class ShaderLayoutCodeGenerator:
|
||||
|
||||
def _generate_vertex_struct(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
||||
"""生成顶点结构体"""
|
||||
writer.write("// Auto-generated vertex structure")
|
||||
|
||||
# 添加必要的头文件包含
|
||||
writer.write("#include <stdint.h> // For fixed-width integer types")
|
||||
writer.write()
|
||||
|
||||
with writer.block("typedef struct Vertex", "} Vertex;"):
|
||||
with writer.block("struct vertex_t", "};"):
|
||||
total_offset = 0
|
||||
for field in vertex_fields:
|
||||
# 生成字段声明
|
||||
@@ -380,7 +378,7 @@ class ShaderLayoutCodeGenerator:
|
||||
# 添加结构体大小的静态断言
|
||||
writer.write()
|
||||
writer.write(f"// Total size: {total_offset} bytes")
|
||||
writer.write(f"static_assert(sizeof(Vertex) == {total_offset}, \"Vertex struct size mismatch\");")
|
||||
writer.write(f"static_assert(sizeof(vertex_t) == {total_offset}, \"Vertex struct size mismatch\");")
|
||||
writer.write()
|
||||
|
||||
def _generate_uniform_structs(self, writer: IndentManager, uniform_buffers: List[UniformBuffer]) -> None:
|
||||
@@ -422,7 +420,7 @@ class ShaderLayoutCodeGenerator:
|
||||
writer.write(f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}")
|
||||
writer.write()
|
||||
|
||||
writer.write("static const SDL_GPUVertexAttribute g_vertexAttributes[] = {")
|
||||
writer.write("static const SDL_GPUVertexAttribute vertex_attributes[] = {")
|
||||
|
||||
offset = 0
|
||||
with writer.indent():
|
||||
@@ -447,19 +445,19 @@ class ShaderLayoutCodeGenerator:
|
||||
|
||||
# 生成顶点缓冲区描述
|
||||
writer.write("// Vertex buffer description")
|
||||
with writer.block("static const SDL_GPUVertexBufferDescription g_vertexBufferDesc =", "};"):
|
||||
with writer.block("static constexpr SDL_GPUVertexBufferDescription vertex_buffer_desc =", "};"):
|
||||
writer.write(".slot = 0,")
|
||||
writer.write(f".pitch = {offset}, // sizeof(Vertex)")
|
||||
writer.write(f".pitch = {offset}, // sizeof(vertex)")
|
||||
writer.write(".input_rate = SDL_GPU_VERTEXINPUTRATE_VERTEX,")
|
||||
writer.write(".instance_step_rate = 0")
|
||||
writer.write()
|
||||
|
||||
# 生成顶点输入状态
|
||||
writer.write("// Vertex input state")
|
||||
with writer.block("static const SDL_GPUVertexInputState g_vertexInputState =", "};"):
|
||||
writer.write(".vertex_buffer_descriptions = &g_vertexBufferDesc,")
|
||||
with writer.block("static constexpr SDL_GPUVertexInputState vertex_input_state =", "};"):
|
||||
writer.write(".vertex_buffer_descriptions = &vertex_buffer_desc,")
|
||||
writer.write(".num_vertex_buffers = 1,")
|
||||
writer.write(".vertex_attributes = g_vertexAttributes,")
|
||||
writer.write(".vertex_attributes = vertex_attributes,")
|
||||
writer.write(".num_vertex_attributes = VERTEX_ATTRIBUTE_COUNT")
|
||||
writer.write()
|
||||
|
||||
@@ -469,28 +467,28 @@ class ShaderLayoutCodeGenerator:
|
||||
writer.write()
|
||||
|
||||
# 创建顶点缓冲区函数
|
||||
writer.write("static SDL_GPUBuffer* createVertexBuffer(SDL_GPUDevice* device,")
|
||||
writer.write(" const Vertex* vertices,")
|
||||
writer.write("static SDL_GPUBuffer* create_vertex_buffer(SDL_GPUDevice* device,")
|
||||
writer.write(" const vertex_t* vertices,")
|
||||
|
||||
with writer.block(" Uint32 vertexCount)", "}"):
|
||||
with writer.block("SDL_GPUBufferCreateInfo bufferInfo =", "};"):
|
||||
with writer.block(" Uint32 vertex_count)", "}"):
|
||||
with writer.block("SDL_GPUBufferCreateInfo buffer_info =", "};"):
|
||||
writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
|
||||
writer.write(".size = static_cast<Uint32>(sizeof(Vertex)) * vertexCount")
|
||||
writer.write(".size = static_cast<Uint32>(sizeof(vertex)) * vertex_count")
|
||||
|
||||
writer.write()
|
||||
writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &bufferInfo);")
|
||||
writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &buffer_info);")
|
||||
writer.write()
|
||||
|
||||
with writer.block("SDL_GPUTransferBufferCreateInfo transferInfo =", "};"):
|
||||
with writer.block("SDL_GPUTransferBufferCreateInfo transfer_info =", "};"):
|
||||
writer.write(".usage = SDL_GPU_TRANSFERBUFFERUSAGE_UPLOAD,")
|
||||
writer.write(".size = bufferInfo.size")
|
||||
writer.write(".size = buffer_info.size")
|
||||
|
||||
writer.write()
|
||||
writer.write("// Upload vertex data")
|
||||
writer.write("SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device, &transferInfo);")
|
||||
writer.write("SDL_GPUTransferBuffer* transfer = SDL_CreateGPUTransferBuffer(device, &transfer_info);")
|
||||
writer.write()
|
||||
writer.write("void* mapped = SDL_MapGPUTransferBuffer(device, transfer, SDL_FALSE);")
|
||||
writer.write("SDL_memcpy(mapped, vertices, bufferInfo.size);")
|
||||
writer.write("void* mapped = SDL_MapGPUTransferBuffer(device, transfer, false);")
|
||||
writer.write("SDL_memcpy(mapped, vertices, buffer_info.size);")
|
||||
writer.write("SDL_UnmapGPUTransferBuffer(device, transfer);")
|
||||
writer.write()
|
||||
writer.write("// Copy to GPU")
|
||||
@@ -498,9 +496,9 @@ class ShaderLayoutCodeGenerator:
|
||||
writer.write("SDL_GPUCopyPass* copy = SDL_BeginGPUCopyPass(cmd);")
|
||||
writer.write()
|
||||
writer.write("SDL_GPUTransferBufferLocation src = {.transfer_buffer = transfer, .offset = 0};")
|
||||
writer.write("SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = bufferInfo.size};")
|
||||
writer.write("SDL_GPUBufferRegion dst = {.buffer = buffer, .offset = 0, .size = buffer_info.size};")
|
||||
writer.write()
|
||||
writer.write("SDL_UploadToGPUBuffer(copy, &src, &dst, SDL_FALSE);")
|
||||
writer.write("SDL_UploadToGPUBuffer(copy, &src, &dst, false);")
|
||||
writer.write("SDL_EndGPUCopyPass(copy);")
|
||||
writer.write("SDL_SubmitGPUCommandBuffer(cmd);")
|
||||
writer.write()
|
||||
@@ -518,13 +516,13 @@ class ShaderLayoutGenerator:
|
||||
self.parser = ShaderLayoutParser()
|
||||
self.code_generator = ShaderLayoutCodeGenerator()
|
||||
|
||||
def generate_header(self, json_file: str) -> str:
|
||||
def generate_header(self, json_file: str, namespace: str) -> str:
|
||||
"""生成头文件内容"""
|
||||
# 解析JSON到类对象
|
||||
layout = self.parser.parse(json_file)
|
||||
|
||||
# 基于类对象生成代码
|
||||
return self.code_generator.generate(layout, json_file)
|
||||
return self.code_generator.generate(layout, json_file, namespace)
|
||||
|
||||
|
||||
# 使用示例
|
||||
@@ -534,5 +532,5 @@ if __name__ == "__main__":
|
||||
sys.exit(1)
|
||||
|
||||
generator = ShaderLayoutGenerator()
|
||||
header_content = generator.generate_header(sys.argv[1])
|
||||
header_content = generator.generate_header(sys.argv[1], 'test_shader')
|
||||
print(header_content)
|
||||
|
||||
@@ -20,7 +20,7 @@ class ShaderParser:
|
||||
def __init__(self, include_paths: List[str] = None):
|
||||
self.include_paths = include_paths or []
|
||||
|
||||
def parse_slang_shader(self, source_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
|
||||
def parse_slang_shader(self, source_path: str, output_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
|
||||
"""解析Slang着色器源码,提取资源信息"""
|
||||
|
||||
# 合并包含路径
|
||||
@@ -49,7 +49,7 @@ class ShaderParser:
|
||||
print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")
|
||||
|
||||
shader_info = self._compile_and_reflect_entry_point(
|
||||
source_path, entry_name, target, stage, all_include_paths, source
|
||||
source_path, output_path, entry_name, target, stage, all_include_paths, source
|
||||
)
|
||||
|
||||
if shader_info:
|
||||
@@ -161,19 +161,19 @@ class ShaderParser:
|
||||
print(f"Total entry points found: {len(entry_points)}")
|
||||
return entry_points
|
||||
|
||||
def _compile_and_reflect_entry_point(self, source_path: str, entry_name: str, target: TargetFormat,
|
||||
def _compile_and_reflect_entry_point(self, source_path: str, output_path: str, entry_name: str, target: TargetFormat,
|
||||
stage: ShaderStage, include_paths: List[str],
|
||||
source: str) -> Optional[ShaderInfo]:
|
||||
"""为单个入口点进行完整编译和反射"""
|
||||
|
||||
# 创建临时输出文件
|
||||
with tempfile.NamedTemporaryFile(suffix='.shader.tmp', delete=False) as tmp_output:
|
||||
output_path = tmp_output.name
|
||||
temp_output_path = tmp_output.name
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_reflection:
|
||||
reflection_path = tmp_reflection.name
|
||||
|
||||
cmd = make_cmd(source_path, target, stage, entry_name, include_paths, output_path)
|
||||
cmd = make_cmd(source_path, target, stage, entry_name, include_paths, temp_output_path)
|
||||
cmd.extend([
|
||||
'-reflection-json', reflection_path, # 反射输出
|
||||
])
|
||||
@@ -197,9 +197,9 @@ class ShaderParser:
|
||||
if stage == ShaderStage.VERTEX:
|
||||
# 生成顶点布局代码
|
||||
generator = ShaderLayoutGenerator()
|
||||
header_content = generator.generate_header(reflection_path)
|
||||
source_file_name = os.path.splitext(os.path.basename(source_path))[0]
|
||||
layout_path_filename = os.path.join(os.path.dirname(output_path), f"{source_file_name}_layout.h")
|
||||
header_content = generator.generate_header(reflection_path, f'{source_file_name}_shader')
|
||||
layout_path_filename = os.path.join(output_path, f"{source_file_name}_layout.h")
|
||||
with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
|
||||
out_file.write(header_content)
|
||||
|
||||
@@ -226,7 +226,7 @@ class ShaderParser:
|
||||
# 如果反射失败,尝试手动解析
|
||||
print(f"Falling back to manual parsing for {entry_name}")
|
||||
# 清理临时文件
|
||||
for temp_file in [output_path, reflection_path]:
|
||||
for temp_file in [temp_output_path, reflection_path]:
|
||||
if os.path.exists(temp_file):
|
||||
os.unlink(temp_file)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user