封装全局参数

This commit is contained in:
2025-06-07 03:39:44 +08:00
parent a496b9c0a9
commit 274af12589
7 changed files with 143 additions and 143 deletions

View File

@@ -17,15 +17,14 @@ from shader_types import ShaderInfo, TargetFormat
class SDL3GPUSlangCompiler: class SDL3GPUSlangCompiler:
def __init__(self, include_paths: List[str] = None): def __init__(self):
self.include_paths = include_paths or [] self.parser = ShaderParser()
self.parser = ShaderParser(self.include_paths)
self.binding_manager = BindingManager() self.binding_manager = BindingManager()
self.code_generator = CodeGenerator() self.code_generator = CodeGenerator()
def parse_slang_shader(self, source_path: str, output_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]: def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
"""解析Slang着色器源码提取资源信息""" """解析Slang着色器源码提取资源信息"""
return self.parser.parse_slang_shader(source_path, output_path, target, include_paths) return self.parser.parse_slang_shader()
def compile_shader(self, shader_info: ShaderInfo, target: TargetFormat) -> tuple[str, dict]: def compile_shader(self, shader_info: ShaderInfo, target: TargetFormat) -> tuple[str, dict]:
"""编译着色器并返回二进制路径和绑定信息""" """编译着色器并返回二进制路径和绑定信息"""
@@ -48,7 +47,7 @@ class SDL3GPUSlangCompiler:
try: try:
# 编译着色器 # 编译着色器
cmd = make_cmd(tmp_path, target, shader_info.stage, shader_info.entry_point, self.include_paths, output_path) cmd = make_cmd(tmp_path, target, shader_info.stage, shader_info.entry_point, output_path)
print(f"Compiling shader with command: {' '.join(cmd)}") print(f"Compiling shader with command: {' '.join(cmd)}")
subprocess.run(cmd, check=True) subprocess.run(cmd, check=True)

View File

@@ -2,12 +2,12 @@
""" """
SDL3_GPU Slang Compiler - Command Generation SDL3_GPU Slang Compiler - Command Generation
""" """
from global_vars import global_vars
from shader_types import TargetFormat, ShaderStage from shader_types import TargetFormat, ShaderStage
from slangc_finder import slangc_path from slangc_finder import slangc_path
def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, include_paths, def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, output_path: str):
output_path: str):
"""生成编译命令""" """生成编译命令"""
target_flag = { target_flag = {
TargetFormat.SPIRV: 'spirv', TargetFormat.SPIRV: 'spirv',
@@ -31,7 +31,7 @@ def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_p
'-stage', stage_flag, '-stage', stage_flag,
] ]
# 添加包含路径 # 添加包含路径
for include_path in include_paths: for include_path in global_vars.include_dirs:
cmd.extend(['-I', include_path]) cmd.extend(['-I', include_path])
if target in [TargetFormat.DXIL, TargetFormat.DXBC]: if target in [TargetFormat.DXIL, TargetFormat.DXBC]:

11
global_vars.py Normal file
View File

@@ -0,0 +1,11 @@
from shader_types import *
class GlobalVars:
source_file = ''
source_file_name = ''
source_path = ''
output_dir = ''
target: TargetFormat
layout: ShaderLayout
global_vars = GlobalVars()

33
main.py
View File

@@ -4,13 +4,13 @@ SDL3_GPU Slang Compiler - Command Line Interface
命令行接口 命令行接口
""" """
import os
import sys
import tempfile
import json
import argparse import argparse
import os
from compiler import SDL3GPUSlangCompiler from compiler import SDL3GPUSlangCompiler
from shader_types import TargetFormat from shader_types import TargetFormat
from global_vars import *
def main(): def main():
parser = argparse.ArgumentParser(description='SDL3 GPU Slang Compiler') parser = argparse.ArgumentParser(description='SDL3 GPU Slang Compiler')
@@ -19,23 +19,34 @@ def main():
parser.add_argument('-t', '--target', choices=['spirv', 'dxil', 'dxbc', 'msl'], parser.add_argument('-t', '--target', choices=['spirv', 'dxil', 'dxbc', 'msl'],
required=True, help='Target shader format') required=True, help='Target shader format')
parser.add_argument('-o', '--output-dir', required=True, help='Output path for binding code') parser.add_argument('-o', '--output-dir', required=True, help='Output path for binding code')
parser.add_argument('-i', '--include-dir', help='Include path for slang shader')
args = parser.parse_args() args = parser.parse_args()
# 获取编译文件的绝对路径 # 获取编译文件的绝对路径
input_path = os.path.abspath(args.input) input_path = os.path.abspath(args.input)
# 获取不包含拓展名的源文件名
source_file_name = os.path.splitext(os.path.basename(input_path))[0] global_vars.source_file = input_path
global_vars.source_file_name = os.path.splitext(os.path.basename(input_path))[0]
global_vars.source_path = os.path.dirname(input_path)
global_vars.output_dir = os.path.abspath(args.output_dir)
global_vars.target = target = TargetFormat(args.target)
# 仅保留路径部分 # 仅保留路径部分
input_path = os.path.dirname(input_path) include_dirs = [
global_vars.source_path,
]
if args.include_dir:
include_dirs.append(args.include_dir)
global_vars.include_dirs = include_dirs
# 创建编译器实例 # 创建编译器实例
compiler = SDL3GPUSlangCompiler([input_path]) compiler = SDL3GPUSlangCompiler()
target = TargetFormat(args.target)
# 解析着色器 # 解析着色器
print(f"**Parsing** {args.input}...") print(f"**Parsing** {args.input}...")
shaders = compiler.parse_slang_shader(args.input, args.output_dir, target) shaders = compiler.parse_slang_shader()
# 编译每个入口点 # 编译每个入口点
binding_infos = [] binding_infos = []
@@ -47,7 +58,7 @@ def main():
binding_infos.append(binding_info) binding_infos.append(binding_info)
binding_output_file_pathname = os.path.abspath(args.output_dir) binding_output_file_pathname = os.path.abspath(args.output_dir)
binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{source_file_name}.shader.h") binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{global_vars.source_file_name}.shader.h")
# 生成绑定代码 # 生成绑定代码
print(f"\n**Generating** binding code to {binding_output_file_pathname}...") print(f"\n**Generating** binding code to {binding_output_file_pathname}...")

View File

@@ -1,85 +1,12 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
import json
import os
import sys
from typing import Dict, List, Any, Tuple, Optional, TextIO
from dataclasses import dataclass, field
from enum import Enum
from contextlib import contextmanager
import io import io
import json
import sys
from typing import Tuple
from code_generator import IndentManager from code_generator import IndentManager
from global_vars import global_vars
from shader_types import *
# 数据模型类
@dataclass
class FieldType:
"""字段类型信息"""
kind: str # 'scalar', 'vector', 'matrix'
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
element_count: Optional[int] = None # for vector
row_count: Optional[int] = None # for matrix
column_count: Optional[int] = None # for matrix
@classmethod
def from_dict(cls, data: Dict) -> 'FieldType':
"""从字典创建FieldType对象"""
kind = data.get('kind')
if kind == 'vector':
return cls(
kind=kind,
scalar_type=data['elementType']['scalarType'],
element_count=data['elementCount']
)
elif kind == 'scalar':
return cls(
kind=kind,
scalar_type=data['scalarType']
)
elif kind == 'matrix':
return cls(
kind=kind,
scalar_type=data['elementType']['scalarType'],
row_count=data['rowCount'],
column_count=data['columnCount']
)
return cls(kind='scalar', scalar_type='float32')
@dataclass
class VertexField:
"""顶点输入字段"""
name: str
type: FieldType
location: int
semantic: str
semantic_index: int
@dataclass
class UniformField:
"""Uniform缓冲区字段"""
name: str
type: FieldType
offset: int
size: int
@dataclass
class UniformBuffer:
"""Uniform缓冲区"""
name: str
binding: int
fields: List[UniformField] = field(default_factory=list)
@dataclass
class ShaderLayout:
"""着色器布局数据"""
vertex_fields: List[VertexField] = field(default_factory=list)
uniform_buffers: List[UniformBuffer] = field(default_factory=list)
# 数据解析器(保持不变) # 数据解析器(保持不变)
@@ -194,7 +121,7 @@ class ShaderLayoutCodeGenerator:
'float16': 2, 'float16': 2,
} }
def generate(self, layout: ShaderLayout, source_file: str, namespace: str) -> str: def generate(self) -> str:
"""生成完整的头文件内容""" """生成完整的头文件内容"""
output = io.StringIO() output = io.StringIO()
writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进 writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进
@@ -205,15 +132,13 @@ class ShaderLayoutCodeGenerator:
writer.write("#include <SDL3/SDL_gpu.h>") writer.write("#include <SDL3/SDL_gpu.h>")
writer.write("#include <cstdint> // For fixed-width integer types") writer.write("#include <cstdint> // For fixed-width integer types")
writer.write() writer.write()
writer.write(f"// Auto-generated from: {source_file}")
writer.write()
writer.write("// Auto-generated vertex structure") writer.write("// Auto-generated vertex structure")
with writer.block(f'namespace {namespace}'): with writer.block(f'namespace {global_vars.source_file_name}_shader'):
self._generate_vertex_struct(writer, layout.vertex_fields) self._generate_vertex_struct(writer, global_vars.layout.vertex_fields)
self._generate_uniform_structs(writer, layout.uniform_buffers) self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
self._generate_vertex_attributes(writer, layout.vertex_fields) self._generate_vertex_attributes(writer, global_vars.layout.vertex_fields)
self._generate_helper_functions(writer, layout.vertex_fields) self._generate_helper_functions(writer, global_vars.layout.vertex_fields)
return output.getvalue() return output.getvalue()
@@ -417,7 +342,7 @@ class ShaderLayoutCodeGenerator:
def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None: def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
"""生成顶点属性数组""" """生成顶点属性数组"""
writer.write("// Vertex attribute descriptions") writer.write("// Vertex attribute descriptions")
writer.write(f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}") writer.write(f"static constexpr uint32_t VERTEX_ATTRIBUTE_COUNT = {len(vertex_fields)};")
writer.write() writer.write()
writer.write("static const SDL_GPUVertexAttribute vertex_attributes[] = {") writer.write("static const SDL_GPUVertexAttribute vertex_attributes[] = {")
@@ -473,7 +398,7 @@ class ShaderLayoutCodeGenerator:
with writer.block(" Uint32 vertex_count)", "}"): with writer.block(" Uint32 vertex_count)", "}"):
with writer.block("SDL_GPUBufferCreateInfo buffer_info =", "};"): with writer.block("SDL_GPUBufferCreateInfo buffer_info =", "};"):
writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,") writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
writer.write(".size = static_cast<Uint32>(sizeof(vertex)) * vertex_count") writer.write(".size = static_cast<Uint32>(sizeof(vertex_t)) * vertex_count")
writer.write() writer.write()
writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &buffer_info);") writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &buffer_info);")
@@ -516,13 +441,13 @@ class ShaderLayoutGenerator:
self.parser = ShaderLayoutParser() self.parser = ShaderLayoutParser()
self.code_generator = ShaderLayoutCodeGenerator() self.code_generator = ShaderLayoutCodeGenerator()
def generate_header(self, json_file: str, namespace: str) -> str: def generate_header(self, json_file: str) -> str:
"""生成头文件内容""" """生成头文件内容"""
# 解析JSON到类对象 # 解析JSON到类对象
layout = self.parser.parse(json_file) global_vars.layout = self.parser.parse(json_file)
# 基于类对象生成代码 # 基于类对象生成代码
return self.code_generator.generate(layout, json_file, namespace) return self.code_generator.generate()
# 使用示例 # 使用示例
@@ -532,5 +457,5 @@ if __name__ == "__main__":
sys.exit(1) sys.exit(1)
generator = ShaderLayoutGenerator() generator = ShaderLayoutGenerator()
header_content = generator.generate_header(sys.argv[1], 'test_shader') header_content = generator.generate_header(sys.argv[1])
print(header_content) print(header_content)

View File

@@ -4,38 +4,24 @@ SDL3_GPU Slang Compiler - Shader Parser
着色器源码解析功能 着色器源码解析功能
""" """
import os
import json import json
import os
import re
import subprocess import subprocess
import tempfile import tempfile
import shutil
import re
from typing import List, Dict, Optional from typing import List, Dict, Optional
from compiler_cmd import make_cmd from compiler_cmd import make_cmd
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo, TargetFormat from global_vars import global_vars
from shader_layout_generator import ShaderLayoutGenerator from shader_layout_generator import ShaderLayoutGenerator
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo
class ShaderParser: class ShaderParser:
def __init__(self, include_paths: List[str] = None): def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
self.include_paths = include_paths or []
def parse_slang_shader(self, source_path: str, output_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
"""解析Slang着色器源码提取资源信息""" """解析Slang着色器源码提取资源信息"""
# 合并包含路径 with open(global_vars.source_file, 'r', encoding='utf-8') as f:
all_include_paths = self.include_paths.copy()
if include_paths:
all_include_paths.extend(include_paths)
# 添加源文件所在目录作为包含路径
source_dir = os.path.dirname(os.path.abspath(source_path))
if source_dir not in all_include_paths:
all_include_paths.insert(0, source_dir)
print(f"Include paths: {all_include_paths}")
with open(source_path, 'r', encoding='utf-8') as f:
source = f.read() source = f.read()
# 首先分析源码找到所有入口点 # 首先分析源码找到所有入口点
@@ -49,7 +35,7 @@ class ShaderParser:
print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})") print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")
shader_info = self._compile_and_reflect_entry_point( shader_info = self._compile_and_reflect_entry_point(
source_path, output_path, entry_name, target, stage, all_include_paths, source entry_name, stage, source
) )
if shader_info: if shader_info:
@@ -161,9 +147,7 @@ class ShaderParser:
print(f"Total entry points found: {len(entry_points)}") print(f"Total entry points found: {len(entry_points)}")
return entry_points return entry_points
def _compile_and_reflect_entry_point(self, source_path: str, output_path: str, entry_name: str, target: TargetFormat, def _compile_and_reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str) -> Optional[ShaderInfo]:
stage: ShaderStage, include_paths: List[str],
source: str) -> Optional[ShaderInfo]:
"""为单个入口点进行完整编译和反射""" """为单个入口点进行完整编译和反射"""
# 创建临时输出文件 # 创建临时输出文件
@@ -173,7 +157,7 @@ class ShaderParser:
with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_reflection: with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_reflection:
reflection_path = tmp_reflection.name reflection_path = tmp_reflection.name
cmd = make_cmd(source_path, target, stage, entry_name, include_paths, temp_output_path) cmd = make_cmd(global_vars.source_file, global_vars.target, stage, entry_name, temp_output_path)
cmd.extend([ cmd.extend([
'-reflection-json', reflection_path, # 反射输出 '-reflection-json', reflection_path, # 反射输出
]) ])
@@ -197,9 +181,8 @@ class ShaderParser:
if stage == ShaderStage.VERTEX: if stage == ShaderStage.VERTEX:
# 生成顶点布局代码 # 生成顶点布局代码
generator = ShaderLayoutGenerator() generator = ShaderLayoutGenerator()
source_file_name = os.path.splitext(os.path.basename(source_path))[0] header_content = generator.generate_header(reflection_path)
header_content = generator.generate_header(reflection_path, f'{source_file_name}_shader') layout_path_filename = os.path.join(global_vars.output_dir, f"{global_vars.source_file_name}_layout.h")
layout_path_filename = os.path.join(output_path, f"{source_file_name}_layout.h")
with open(layout_path_filename, 'w', encoding='utf-8') as out_file: with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
out_file.write(header_content) out_file.write(header_content)

View File

@@ -5,8 +5,9 @@ SDL3_GPU Slang Compiler - Type Definitions
""" """
from enum import Enum from enum import Enum
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import List from typing import List, Optional, Dict
class ShaderStage(Enum): class ShaderStage(Enum):
VERTEX = "vertex" VERTEX = "vertex"
@@ -42,3 +43,73 @@ class ShaderInfo:
entry_point: str entry_point: str
resources: List[Resource] resources: List[Resource]
source_code: str source_code: str
# 数据模型类
@dataclass
class FieldType:
"""字段类型信息"""
kind: str # 'scalar', 'vector', 'matrix'
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
element_count: Optional[int] = None # for vector
row_count: Optional[int] = None # for matrix
column_count: Optional[int] = None # for matrix
@classmethod
def from_dict(cls, data: Dict) -> 'FieldType':
"""从字典创建FieldType对象"""
kind = data.get('kind')
if kind == 'vector':
return cls(
kind=kind,
scalar_type=data['elementType']['scalarType'],
element_count=data['elementCount']
)
elif kind == 'scalar':
return cls(
kind=kind,
scalar_type=data['scalarType']
)
elif kind == 'matrix':
return cls(
kind=kind,
scalar_type=data['elementType']['scalarType'],
row_count=data['rowCount'],
column_count=data['columnCount']
)
return cls(kind='scalar', scalar_type='float32')
@dataclass
class VertexField:
"""顶点输入字段"""
name: str
type: FieldType
location: int
semantic: str
semantic_index: int
@dataclass
class UniformField:
"""Uniform缓冲区字段"""
name: str
type: FieldType
offset: int
size: int
@dataclass
class UniformBuffer:
"""Uniform缓冲区"""
name: str
binding: int
fields: List[UniformField] = field(default_factory=list)
@dataclass
class ShaderLayout:
"""着色器布局数据"""
vertex_fields: List[VertexField] = field(default_factory=list)
uniform_buffers: List[UniformBuffer] = field(default_factory=list)