封装全局参数
This commit is contained in:
11
compiler.py
11
compiler.py
@@ -17,15 +17,14 @@ from shader_types import ShaderInfo, TargetFormat
|
||||
|
||||
|
||||
class SDL3GPUSlangCompiler:
|
||||
def __init__(self, include_paths: List[str] = None):
|
||||
self.include_paths = include_paths or []
|
||||
self.parser = ShaderParser(self.include_paths)
|
||||
def __init__(self):
|
||||
self.parser = ShaderParser()
|
||||
self.binding_manager = BindingManager()
|
||||
self.code_generator = CodeGenerator()
|
||||
|
||||
def parse_slang_shader(self, source_path: str, output_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
|
||||
def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
|
||||
"""解析Slang着色器源码,提取资源信息"""
|
||||
return self.parser.parse_slang_shader(source_path, output_path, target, include_paths)
|
||||
return self.parser.parse_slang_shader()
|
||||
|
||||
def compile_shader(self, shader_info: ShaderInfo, target: TargetFormat) -> tuple[str, dict]:
|
||||
"""编译着色器并返回二进制路径和绑定信息"""
|
||||
@@ -48,7 +47,7 @@ class SDL3GPUSlangCompiler:
|
||||
|
||||
try:
|
||||
# 编译着色器
|
||||
cmd = make_cmd(tmp_path, target, shader_info.stage, shader_info.entry_point, self.include_paths, output_path)
|
||||
cmd = make_cmd(tmp_path, target, shader_info.stage, shader_info.entry_point, output_path)
|
||||
print(f"Compiling shader with command: {' '.join(cmd)}")
|
||||
|
||||
subprocess.run(cmd, check=True)
|
||||
|
||||
@@ -2,12 +2,12 @@
|
||||
"""
|
||||
SDL3_GPU Slang Compiler - Command Generation
|
||||
"""
|
||||
from global_vars import global_vars
|
||||
from shader_types import TargetFormat, ShaderStage
|
||||
from slangc_finder import slangc_path
|
||||
|
||||
|
||||
def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, include_paths,
|
||||
output_path: str):
|
||||
def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, output_path: str):
|
||||
"""生成编译命令"""
|
||||
target_flag = {
|
||||
TargetFormat.SPIRV: 'spirv',
|
||||
@@ -31,7 +31,7 @@ def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_p
|
||||
'-stage', stage_flag,
|
||||
]
|
||||
# 添加包含路径
|
||||
for include_path in include_paths:
|
||||
for include_path in global_vars.include_dirs:
|
||||
cmd.extend(['-I', include_path])
|
||||
|
||||
if target in [TargetFormat.DXIL, TargetFormat.DXBC]:
|
||||
|
||||
11
global_vars.py
Normal file
11
global_vars.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from shader_types import *
|
||||
|
||||
class GlobalVars:
|
||||
source_file = ''
|
||||
source_file_name = ''
|
||||
source_path = ''
|
||||
output_dir = ''
|
||||
target: TargetFormat
|
||||
layout: ShaderLayout
|
||||
|
||||
global_vars = GlobalVars()
|
||||
33
main.py
33
main.py
@@ -4,13 +4,13 @@ SDL3_GPU Slang Compiler - Command Line Interface
|
||||
命令行接口
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import json
|
||||
import argparse
|
||||
import os
|
||||
|
||||
from compiler import SDL3GPUSlangCompiler
|
||||
from shader_types import TargetFormat
|
||||
from global_vars import *
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description='SDL3 GPU Slang Compiler')
|
||||
@@ -19,23 +19,34 @@ def main():
|
||||
parser.add_argument('-t', '--target', choices=['spirv', 'dxil', 'dxbc', 'msl'],
|
||||
required=True, help='Target shader format')
|
||||
parser.add_argument('-o', '--output-dir', required=True, help='Output path for binding code')
|
||||
parser.add_argument('-i', '--include-dir', help='Include path for slang shader')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# 获取编译文件的绝对路径
|
||||
input_path = os.path.abspath(args.input)
|
||||
# 获取不包含拓展名的源文件名
|
||||
source_file_name = os.path.splitext(os.path.basename(input_path))[0]
|
||||
|
||||
global_vars.source_file = input_path
|
||||
global_vars.source_file_name = os.path.splitext(os.path.basename(input_path))[0]
|
||||
global_vars.source_path = os.path.dirname(input_path)
|
||||
global_vars.output_dir = os.path.abspath(args.output_dir)
|
||||
global_vars.target = target = TargetFormat(args.target)
|
||||
|
||||
# 仅保留路径部分
|
||||
input_path = os.path.dirname(input_path)
|
||||
include_dirs = [
|
||||
global_vars.source_path,
|
||||
]
|
||||
if args.include_dir:
|
||||
include_dirs.append(args.include_dir)
|
||||
global_vars.include_dirs = include_dirs
|
||||
|
||||
# 创建编译器实例
|
||||
compiler = SDL3GPUSlangCompiler([input_path])
|
||||
target = TargetFormat(args.target)
|
||||
compiler = SDL3GPUSlangCompiler()
|
||||
|
||||
|
||||
# 解析着色器
|
||||
print(f"**Parsing** {args.input}...")
|
||||
shaders = compiler.parse_slang_shader(args.input, args.output_dir, target)
|
||||
shaders = compiler.parse_slang_shader()
|
||||
|
||||
# 编译每个入口点
|
||||
binding_infos = []
|
||||
@@ -47,7 +58,7 @@ def main():
|
||||
binding_infos.append(binding_info)
|
||||
|
||||
binding_output_file_pathname = os.path.abspath(args.output_dir)
|
||||
binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{source_file_name}.shader.h")
|
||||
binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{global_vars.source_file_name}.shader.h")
|
||||
|
||||
# 生成绑定代码
|
||||
print(f"\n**Generating** binding code to {binding_output_file_pathname}...")
|
||||
|
||||
@@ -1,85 +1,12 @@
|
||||
#!/usr/bin/env python3
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from typing import Dict, List, Any, Tuple, Optional, TextIO
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from contextlib import contextmanager
|
||||
import io
|
||||
import json
|
||||
import sys
|
||||
from typing import Tuple
|
||||
|
||||
from code_generator import IndentManager
|
||||
|
||||
|
||||
# 数据模型类
|
||||
@dataclass
|
||||
class FieldType:
|
||||
"""字段类型信息"""
|
||||
kind: str # 'scalar', 'vector', 'matrix'
|
||||
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
|
||||
element_count: Optional[int] = None # for vector
|
||||
row_count: Optional[int] = None # for matrix
|
||||
column_count: Optional[int] = None # for matrix
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'FieldType':
|
||||
"""从字典创建FieldType对象"""
|
||||
kind = data.get('kind')
|
||||
|
||||
if kind == 'vector':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['elementType']['scalarType'],
|
||||
element_count=data['elementCount']
|
||||
)
|
||||
elif kind == 'scalar':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['scalarType']
|
||||
)
|
||||
elif kind == 'matrix':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['elementType']['scalarType'],
|
||||
row_count=data['rowCount'],
|
||||
column_count=data['columnCount']
|
||||
)
|
||||
|
||||
return cls(kind='scalar', scalar_type='float32')
|
||||
|
||||
|
||||
@dataclass
|
||||
class VertexField:
|
||||
"""顶点输入字段"""
|
||||
name: str
|
||||
type: FieldType
|
||||
location: int
|
||||
semantic: str
|
||||
semantic_index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UniformField:
|
||||
"""Uniform缓冲区字段"""
|
||||
name: str
|
||||
type: FieldType
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UniformBuffer:
|
||||
"""Uniform缓冲区"""
|
||||
name: str
|
||||
binding: int
|
||||
fields: List[UniformField] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShaderLayout:
|
||||
"""着色器布局数据"""
|
||||
vertex_fields: List[VertexField] = field(default_factory=list)
|
||||
uniform_buffers: List[UniformBuffer] = field(default_factory=list)
|
||||
from global_vars import global_vars
|
||||
from shader_types import *
|
||||
|
||||
|
||||
# 数据解析器(保持不变)
|
||||
@@ -194,7 +121,7 @@ class ShaderLayoutCodeGenerator:
|
||||
'float16': 2,
|
||||
}
|
||||
|
||||
def generate(self, layout: ShaderLayout, source_file: str, namespace: str) -> str:
|
||||
def generate(self) -> str:
|
||||
"""生成完整的头文件内容"""
|
||||
output = io.StringIO()
|
||||
writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进
|
||||
@@ -205,15 +132,13 @@ class ShaderLayoutCodeGenerator:
|
||||
writer.write("#include <SDL3/SDL_gpu.h>")
|
||||
writer.write("#include <cstdint> // For fixed-width integer types")
|
||||
writer.write()
|
||||
writer.write(f"// Auto-generated from: {source_file}")
|
||||
writer.write()
|
||||
|
||||
writer.write("// Auto-generated vertex structure")
|
||||
with writer.block(f'namespace {namespace}'):
|
||||
self._generate_vertex_struct(writer, layout.vertex_fields)
|
||||
self._generate_uniform_structs(writer, layout.uniform_buffers)
|
||||
self._generate_vertex_attributes(writer, layout.vertex_fields)
|
||||
self._generate_helper_functions(writer, layout.vertex_fields)
|
||||
with writer.block(f'namespace {global_vars.source_file_name}_shader'):
|
||||
self._generate_vertex_struct(writer, global_vars.layout.vertex_fields)
|
||||
self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
|
||||
self._generate_vertex_attributes(writer, global_vars.layout.vertex_fields)
|
||||
self._generate_helper_functions(writer, global_vars.layout.vertex_fields)
|
||||
|
||||
return output.getvalue()
|
||||
|
||||
@@ -417,7 +342,7 @@ class ShaderLayoutCodeGenerator:
|
||||
def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
||||
"""生成顶点属性数组"""
|
||||
writer.write("// Vertex attribute descriptions")
|
||||
writer.write(f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}")
|
||||
writer.write(f"static constexpr uint32_t VERTEX_ATTRIBUTE_COUNT = {len(vertex_fields)};")
|
||||
writer.write()
|
||||
|
||||
writer.write("static const SDL_GPUVertexAttribute vertex_attributes[] = {")
|
||||
@@ -473,7 +398,7 @@ class ShaderLayoutCodeGenerator:
|
||||
with writer.block(" Uint32 vertex_count)", "}"):
|
||||
with writer.block("SDL_GPUBufferCreateInfo buffer_info =", "};"):
|
||||
writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
|
||||
writer.write(".size = static_cast<Uint32>(sizeof(vertex)) * vertex_count")
|
||||
writer.write(".size = static_cast<Uint32>(sizeof(vertex_t)) * vertex_count")
|
||||
|
||||
writer.write()
|
||||
writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &buffer_info);")
|
||||
@@ -516,13 +441,13 @@ class ShaderLayoutGenerator:
|
||||
self.parser = ShaderLayoutParser()
|
||||
self.code_generator = ShaderLayoutCodeGenerator()
|
||||
|
||||
def generate_header(self, json_file: str, namespace: str) -> str:
|
||||
def generate_header(self, json_file: str) -> str:
|
||||
"""生成头文件内容"""
|
||||
# 解析JSON到类对象
|
||||
layout = self.parser.parse(json_file)
|
||||
global_vars.layout = self.parser.parse(json_file)
|
||||
|
||||
# 基于类对象生成代码
|
||||
return self.code_generator.generate(layout, json_file, namespace)
|
||||
return self.code_generator.generate()
|
||||
|
||||
|
||||
# 使用示例
|
||||
@@ -532,5 +457,5 @@ if __name__ == "__main__":
|
||||
sys.exit(1)
|
||||
|
||||
generator = ShaderLayoutGenerator()
|
||||
header_content = generator.generate_header(sys.argv[1], 'test_shader')
|
||||
header_content = generator.generate_header(sys.argv[1])
|
||||
print(header_content)
|
||||
|
||||
@@ -4,38 +4,24 @@ SDL3_GPU Slang Compiler - Shader Parser
|
||||
着色器源码解析功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
import tempfile
|
||||
import shutil
|
||||
import re
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from compiler_cmd import make_cmd
|
||||
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo, TargetFormat
|
||||
from global_vars import global_vars
|
||||
from shader_layout_generator import ShaderLayoutGenerator
|
||||
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo
|
||||
|
||||
|
||||
class ShaderParser:
|
||||
def __init__(self, include_paths: List[str] = None):
|
||||
self.include_paths = include_paths or []
|
||||
|
||||
def parse_slang_shader(self, source_path: str, output_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
|
||||
def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
|
||||
"""解析Slang着色器源码,提取资源信息"""
|
||||
|
||||
# 合并包含路径
|
||||
all_include_paths = self.include_paths.copy()
|
||||
if include_paths:
|
||||
all_include_paths.extend(include_paths)
|
||||
|
||||
# 添加源文件所在目录作为包含路径
|
||||
source_dir = os.path.dirname(os.path.abspath(source_path))
|
||||
if source_dir not in all_include_paths:
|
||||
all_include_paths.insert(0, source_dir)
|
||||
|
||||
print(f"Include paths: {all_include_paths}")
|
||||
|
||||
with open(source_path, 'r', encoding='utf-8') as f:
|
||||
with open(global_vars.source_file, 'r', encoding='utf-8') as f:
|
||||
source = f.read()
|
||||
|
||||
# 首先分析源码找到所有入口点
|
||||
@@ -49,7 +35,7 @@ class ShaderParser:
|
||||
print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")
|
||||
|
||||
shader_info = self._compile_and_reflect_entry_point(
|
||||
source_path, output_path, entry_name, target, stage, all_include_paths, source
|
||||
entry_name, stage, source
|
||||
)
|
||||
|
||||
if shader_info:
|
||||
@@ -161,9 +147,7 @@ class ShaderParser:
|
||||
print(f"Total entry points found: {len(entry_points)}")
|
||||
return entry_points
|
||||
|
||||
def _compile_and_reflect_entry_point(self, source_path: str, output_path: str, entry_name: str, target: TargetFormat,
|
||||
stage: ShaderStage, include_paths: List[str],
|
||||
source: str) -> Optional[ShaderInfo]:
|
||||
def _compile_and_reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str) -> Optional[ShaderInfo]:
|
||||
"""为单个入口点进行完整编译和反射"""
|
||||
|
||||
# 创建临时输出文件
|
||||
@@ -173,7 +157,7 @@ class ShaderParser:
|
||||
with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_reflection:
|
||||
reflection_path = tmp_reflection.name
|
||||
|
||||
cmd = make_cmd(source_path, target, stage, entry_name, include_paths, temp_output_path)
|
||||
cmd = make_cmd(global_vars.source_file, global_vars.target, stage, entry_name, temp_output_path)
|
||||
cmd.extend([
|
||||
'-reflection-json', reflection_path, # 反射输出
|
||||
])
|
||||
@@ -197,9 +181,8 @@ class ShaderParser:
|
||||
if stage == ShaderStage.VERTEX:
|
||||
# 生成顶点布局代码
|
||||
generator = ShaderLayoutGenerator()
|
||||
source_file_name = os.path.splitext(os.path.basename(source_path))[0]
|
||||
header_content = generator.generate_header(reflection_path, f'{source_file_name}_shader')
|
||||
layout_path_filename = os.path.join(output_path, f"{source_file_name}_layout.h")
|
||||
header_content = generator.generate_header(reflection_path)
|
||||
layout_path_filename = os.path.join(global_vars.output_dir, f"{global_vars.source_file_name}_layout.h")
|
||||
with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
|
||||
out_file.write(header_content)
|
||||
|
||||
|
||||
@@ -5,8 +5,9 @@ SDL3_GPU Slang Compiler - Type Definitions
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from typing import List
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Dict
|
||||
|
||||
|
||||
class ShaderStage(Enum):
|
||||
VERTEX = "vertex"
|
||||
@@ -42,3 +43,73 @@ class ShaderInfo:
|
||||
entry_point: str
|
||||
resources: List[Resource]
|
||||
source_code: str
|
||||
|
||||
# 数据模型类
|
||||
@dataclass
|
||||
class FieldType:
|
||||
"""字段类型信息"""
|
||||
kind: str # 'scalar', 'vector', 'matrix'
|
||||
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
|
||||
element_count: Optional[int] = None # for vector
|
||||
row_count: Optional[int] = None # for matrix
|
||||
column_count: Optional[int] = None # for matrix
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Dict) -> 'FieldType':
|
||||
"""从字典创建FieldType对象"""
|
||||
kind = data.get('kind')
|
||||
|
||||
if kind == 'vector':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['elementType']['scalarType'],
|
||||
element_count=data['elementCount']
|
||||
)
|
||||
elif kind == 'scalar':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['scalarType']
|
||||
)
|
||||
elif kind == 'matrix':
|
||||
return cls(
|
||||
kind=kind,
|
||||
scalar_type=data['elementType']['scalarType'],
|
||||
row_count=data['rowCount'],
|
||||
column_count=data['columnCount']
|
||||
)
|
||||
|
||||
return cls(kind='scalar', scalar_type='float32')
|
||||
|
||||
|
||||
@dataclass
|
||||
class VertexField:
|
||||
"""顶点输入字段"""
|
||||
name: str
|
||||
type: FieldType
|
||||
location: int
|
||||
semantic: str
|
||||
semantic_index: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UniformField:
|
||||
"""Uniform缓冲区字段"""
|
||||
name: str
|
||||
type: FieldType
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
@dataclass
|
||||
class UniformBuffer:
|
||||
"""Uniform缓冲区"""
|
||||
name: str
|
||||
binding: int
|
||||
fields: List[UniformField] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ShaderLayout:
|
||||
"""着色器布局数据"""
|
||||
vertex_fields: List[VertexField] = field(default_factory=list)
|
||||
uniform_buffers: List[UniformBuffer] = field(default_factory=list)
|
||||
|
||||
Reference in New Issue
Block a user