封装全局参数
This commit is contained in:
11
compiler.py
11
compiler.py
@@ -17,15 +17,14 @@ from shader_types import ShaderInfo, TargetFormat
|
|||||||
|
|
||||||
|
|
||||||
class SDL3GPUSlangCompiler:
|
class SDL3GPUSlangCompiler:
|
||||||
def __init__(self, include_paths: List[str] = None):
|
def __init__(self):
|
||||||
self.include_paths = include_paths or []
|
self.parser = ShaderParser()
|
||||||
self.parser = ShaderParser(self.include_paths)
|
|
||||||
self.binding_manager = BindingManager()
|
self.binding_manager = BindingManager()
|
||||||
self.code_generator = CodeGenerator()
|
self.code_generator = CodeGenerator()
|
||||||
|
|
||||||
def parse_slang_shader(self, source_path: str, output_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
|
def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
|
||||||
"""解析Slang着色器源码,提取资源信息"""
|
"""解析Slang着色器源码,提取资源信息"""
|
||||||
return self.parser.parse_slang_shader(source_path, output_path, target, include_paths)
|
return self.parser.parse_slang_shader()
|
||||||
|
|
||||||
def compile_shader(self, shader_info: ShaderInfo, target: TargetFormat) -> tuple[str, dict]:
|
def compile_shader(self, shader_info: ShaderInfo, target: TargetFormat) -> tuple[str, dict]:
|
||||||
"""编译着色器并返回二进制路径和绑定信息"""
|
"""编译着色器并返回二进制路径和绑定信息"""
|
||||||
@@ -48,7 +47,7 @@ class SDL3GPUSlangCompiler:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# 编译着色器
|
# 编译着色器
|
||||||
cmd = make_cmd(tmp_path, target, shader_info.stage, shader_info.entry_point, self.include_paths, output_path)
|
cmd = make_cmd(tmp_path, target, shader_info.stage, shader_info.entry_point, output_path)
|
||||||
print(f"Compiling shader with command: {' '.join(cmd)}")
|
print(f"Compiling shader with command: {' '.join(cmd)}")
|
||||||
|
|
||||||
subprocess.run(cmd, check=True)
|
subprocess.run(cmd, check=True)
|
||||||
|
|||||||
@@ -2,12 +2,12 @@
|
|||||||
"""
|
"""
|
||||||
SDL3_GPU Slang Compiler - Command Generation
|
SDL3_GPU Slang Compiler - Command Generation
|
||||||
"""
|
"""
|
||||||
|
from global_vars import global_vars
|
||||||
from shader_types import TargetFormat, ShaderStage
|
from shader_types import TargetFormat, ShaderStage
|
||||||
from slangc_finder import slangc_path
|
from slangc_finder import slangc_path
|
||||||
|
|
||||||
|
|
||||||
def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, include_paths,
|
def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_point: str, output_path: str):
|
||||||
output_path: str):
|
|
||||||
"""生成编译命令"""
|
"""生成编译命令"""
|
||||||
target_flag = {
|
target_flag = {
|
||||||
TargetFormat.SPIRV: 'spirv',
|
TargetFormat.SPIRV: 'spirv',
|
||||||
@@ -31,7 +31,7 @@ def make_cmd(source_file: str, target: TargetFormat, stage: ShaderStage, entry_p
|
|||||||
'-stage', stage_flag,
|
'-stage', stage_flag,
|
||||||
]
|
]
|
||||||
# 添加包含路径
|
# 添加包含路径
|
||||||
for include_path in include_paths:
|
for include_path in global_vars.include_dirs:
|
||||||
cmd.extend(['-I', include_path])
|
cmd.extend(['-I', include_path])
|
||||||
|
|
||||||
if target in [TargetFormat.DXIL, TargetFormat.DXBC]:
|
if target in [TargetFormat.DXIL, TargetFormat.DXBC]:
|
||||||
|
|||||||
11
global_vars.py
Normal file
11
global_vars.py
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
from shader_types import *
|
||||||
|
|
||||||
|
class GlobalVars:
|
||||||
|
source_file = ''
|
||||||
|
source_file_name = ''
|
||||||
|
source_path = ''
|
||||||
|
output_dir = ''
|
||||||
|
target: TargetFormat
|
||||||
|
layout: ShaderLayout
|
||||||
|
|
||||||
|
global_vars = GlobalVars()
|
||||||
33
main.py
33
main.py
@@ -4,13 +4,13 @@ SDL3_GPU Slang Compiler - Command Line Interface
|
|||||||
命令行接口
|
命令行接口
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
import tempfile
|
|
||||||
import json
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import os
|
||||||
|
|
||||||
from compiler import SDL3GPUSlangCompiler
|
from compiler import SDL3GPUSlangCompiler
|
||||||
from shader_types import TargetFormat
|
from shader_types import TargetFormat
|
||||||
|
from global_vars import *
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser(description='SDL3 GPU Slang Compiler')
|
parser = argparse.ArgumentParser(description='SDL3 GPU Slang Compiler')
|
||||||
@@ -19,23 +19,34 @@ def main():
|
|||||||
parser.add_argument('-t', '--target', choices=['spirv', 'dxil', 'dxbc', 'msl'],
|
parser.add_argument('-t', '--target', choices=['spirv', 'dxil', 'dxbc', 'msl'],
|
||||||
required=True, help='Target shader format')
|
required=True, help='Target shader format')
|
||||||
parser.add_argument('-o', '--output-dir', required=True, help='Output path for binding code')
|
parser.add_argument('-o', '--output-dir', required=True, help='Output path for binding code')
|
||||||
|
parser.add_argument('-i', '--include-dir', help='Include path for slang shader')
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
# 获取编译文件的绝对路径
|
# 获取编译文件的绝对路径
|
||||||
input_path = os.path.abspath(args.input)
|
input_path = os.path.abspath(args.input)
|
||||||
# 获取不包含拓展名的源文件名
|
|
||||||
source_file_name = os.path.splitext(os.path.basename(input_path))[0]
|
global_vars.source_file = input_path
|
||||||
|
global_vars.source_file_name = os.path.splitext(os.path.basename(input_path))[0]
|
||||||
|
global_vars.source_path = os.path.dirname(input_path)
|
||||||
|
global_vars.output_dir = os.path.abspath(args.output_dir)
|
||||||
|
global_vars.target = target = TargetFormat(args.target)
|
||||||
|
|
||||||
# 仅保留路径部分
|
# 仅保留路径部分
|
||||||
input_path = os.path.dirname(input_path)
|
include_dirs = [
|
||||||
|
global_vars.source_path,
|
||||||
|
]
|
||||||
|
if args.include_dir:
|
||||||
|
include_dirs.append(args.include_dir)
|
||||||
|
global_vars.include_dirs = include_dirs
|
||||||
|
|
||||||
# 创建编译器实例
|
# 创建编译器实例
|
||||||
compiler = SDL3GPUSlangCompiler([input_path])
|
compiler = SDL3GPUSlangCompiler()
|
||||||
target = TargetFormat(args.target)
|
|
||||||
|
|
||||||
# 解析着色器
|
# 解析着色器
|
||||||
print(f"**Parsing** {args.input}...")
|
print(f"**Parsing** {args.input}...")
|
||||||
shaders = compiler.parse_slang_shader(args.input, args.output_dir, target)
|
shaders = compiler.parse_slang_shader()
|
||||||
|
|
||||||
# 编译每个入口点
|
# 编译每个入口点
|
||||||
binding_infos = []
|
binding_infos = []
|
||||||
@@ -47,7 +58,7 @@ def main():
|
|||||||
binding_infos.append(binding_info)
|
binding_infos.append(binding_info)
|
||||||
|
|
||||||
binding_output_file_pathname = os.path.abspath(args.output_dir)
|
binding_output_file_pathname = os.path.abspath(args.output_dir)
|
||||||
binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{source_file_name}.shader.h")
|
binding_output_file_pathname = os.path.join(binding_output_file_pathname, f"{global_vars.source_file_name}.shader.h")
|
||||||
|
|
||||||
# 生成绑定代码
|
# 生成绑定代码
|
||||||
print(f"\n**Generating** binding code to {binding_output_file_pathname}...")
|
print(f"\n**Generating** binding code to {binding_output_file_pathname}...")
|
||||||
|
|||||||
@@ -1,85 +1,12 @@
|
|||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import sys
|
|
||||||
from typing import Dict, List, Any, Tuple, Optional, TextIO
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from enum import Enum
|
|
||||||
from contextlib import contextmanager
|
|
||||||
import io
|
import io
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
from typing import Tuple
|
||||||
|
|
||||||
from code_generator import IndentManager
|
from code_generator import IndentManager
|
||||||
|
from global_vars import global_vars
|
||||||
|
from shader_types import *
|
||||||
# 数据模型类
|
|
||||||
@dataclass
|
|
||||||
class FieldType:
|
|
||||||
"""字段类型信息"""
|
|
||||||
kind: str # 'scalar', 'vector', 'matrix'
|
|
||||||
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
|
|
||||||
element_count: Optional[int] = None # for vector
|
|
||||||
row_count: Optional[int] = None # for matrix
|
|
||||||
column_count: Optional[int] = None # for matrix
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_dict(cls, data: Dict) -> 'FieldType':
|
|
||||||
"""从字典创建FieldType对象"""
|
|
||||||
kind = data.get('kind')
|
|
||||||
|
|
||||||
if kind == 'vector':
|
|
||||||
return cls(
|
|
||||||
kind=kind,
|
|
||||||
scalar_type=data['elementType']['scalarType'],
|
|
||||||
element_count=data['elementCount']
|
|
||||||
)
|
|
||||||
elif kind == 'scalar':
|
|
||||||
return cls(
|
|
||||||
kind=kind,
|
|
||||||
scalar_type=data['scalarType']
|
|
||||||
)
|
|
||||||
elif kind == 'matrix':
|
|
||||||
return cls(
|
|
||||||
kind=kind,
|
|
||||||
scalar_type=data['elementType']['scalarType'],
|
|
||||||
row_count=data['rowCount'],
|
|
||||||
column_count=data['columnCount']
|
|
||||||
)
|
|
||||||
|
|
||||||
return cls(kind='scalar', scalar_type='float32')
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class VertexField:
|
|
||||||
"""顶点输入字段"""
|
|
||||||
name: str
|
|
||||||
type: FieldType
|
|
||||||
location: int
|
|
||||||
semantic: str
|
|
||||||
semantic_index: int
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UniformField:
|
|
||||||
"""Uniform缓冲区字段"""
|
|
||||||
name: str
|
|
||||||
type: FieldType
|
|
||||||
offset: int
|
|
||||||
size: int
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class UniformBuffer:
|
|
||||||
"""Uniform缓冲区"""
|
|
||||||
name: str
|
|
||||||
binding: int
|
|
||||||
fields: List[UniformField] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ShaderLayout:
|
|
||||||
"""着色器布局数据"""
|
|
||||||
vertex_fields: List[VertexField] = field(default_factory=list)
|
|
||||||
uniform_buffers: List[UniformBuffer] = field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# 数据解析器(保持不变)
|
# 数据解析器(保持不变)
|
||||||
@@ -194,7 +121,7 @@ class ShaderLayoutCodeGenerator:
|
|||||||
'float16': 2,
|
'float16': 2,
|
||||||
}
|
}
|
||||||
|
|
||||||
def generate(self, layout: ShaderLayout, source_file: str, namespace: str) -> str:
|
def generate(self) -> str:
|
||||||
"""生成完整的头文件内容"""
|
"""生成完整的头文件内容"""
|
||||||
output = io.StringIO()
|
output = io.StringIO()
|
||||||
writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进
|
writer = IndentManager(output, indent_char=' ') # 使用4个空格缩进
|
||||||
@@ -205,15 +132,13 @@ class ShaderLayoutCodeGenerator:
|
|||||||
writer.write("#include <SDL3/SDL_gpu.h>")
|
writer.write("#include <SDL3/SDL_gpu.h>")
|
||||||
writer.write("#include <cstdint> // For fixed-width integer types")
|
writer.write("#include <cstdint> // For fixed-width integer types")
|
||||||
writer.write()
|
writer.write()
|
||||||
writer.write(f"// Auto-generated from: {source_file}")
|
|
||||||
writer.write()
|
|
||||||
|
|
||||||
writer.write("// Auto-generated vertex structure")
|
writer.write("// Auto-generated vertex structure")
|
||||||
with writer.block(f'namespace {namespace}'):
|
with writer.block(f'namespace {global_vars.source_file_name}_shader'):
|
||||||
self._generate_vertex_struct(writer, layout.vertex_fields)
|
self._generate_vertex_struct(writer, global_vars.layout.vertex_fields)
|
||||||
self._generate_uniform_structs(writer, layout.uniform_buffers)
|
self._generate_uniform_structs(writer, global_vars.layout.uniform_buffers)
|
||||||
self._generate_vertex_attributes(writer, layout.vertex_fields)
|
self._generate_vertex_attributes(writer, global_vars.layout.vertex_fields)
|
||||||
self._generate_helper_functions(writer, layout.vertex_fields)
|
self._generate_helper_functions(writer, global_vars.layout.vertex_fields)
|
||||||
|
|
||||||
return output.getvalue()
|
return output.getvalue()
|
||||||
|
|
||||||
@@ -417,7 +342,7 @@ class ShaderLayoutCodeGenerator:
|
|||||||
def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
def _generate_vertex_attributes(self, writer: IndentManager, vertex_fields: List[VertexField]) -> None:
|
||||||
"""生成顶点属性数组"""
|
"""生成顶点属性数组"""
|
||||||
writer.write("// Vertex attribute descriptions")
|
writer.write("// Vertex attribute descriptions")
|
||||||
writer.write(f"#define VERTEX_ATTRIBUTE_COUNT {len(vertex_fields)}")
|
writer.write(f"static constexpr uint32_t VERTEX_ATTRIBUTE_COUNT = {len(vertex_fields)};")
|
||||||
writer.write()
|
writer.write()
|
||||||
|
|
||||||
writer.write("static const SDL_GPUVertexAttribute vertex_attributes[] = {")
|
writer.write("static const SDL_GPUVertexAttribute vertex_attributes[] = {")
|
||||||
@@ -473,7 +398,7 @@ class ShaderLayoutCodeGenerator:
|
|||||||
with writer.block(" Uint32 vertex_count)", "}"):
|
with writer.block(" Uint32 vertex_count)", "}"):
|
||||||
with writer.block("SDL_GPUBufferCreateInfo buffer_info =", "};"):
|
with writer.block("SDL_GPUBufferCreateInfo buffer_info =", "};"):
|
||||||
writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
|
writer.write(".usage = SDL_GPU_BUFFERUSAGE_VERTEX,")
|
||||||
writer.write(".size = static_cast<Uint32>(sizeof(vertex)) * vertex_count")
|
writer.write(".size = static_cast<Uint32>(sizeof(vertex_t)) * vertex_count")
|
||||||
|
|
||||||
writer.write()
|
writer.write()
|
||||||
writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &buffer_info);")
|
writer.write("SDL_GPUBuffer* buffer = SDL_CreateGPUBuffer(device, &buffer_info);")
|
||||||
@@ -516,13 +441,13 @@ class ShaderLayoutGenerator:
|
|||||||
self.parser = ShaderLayoutParser()
|
self.parser = ShaderLayoutParser()
|
||||||
self.code_generator = ShaderLayoutCodeGenerator()
|
self.code_generator = ShaderLayoutCodeGenerator()
|
||||||
|
|
||||||
def generate_header(self, json_file: str, namespace: str) -> str:
|
def generate_header(self, json_file: str) -> str:
|
||||||
"""生成头文件内容"""
|
"""生成头文件内容"""
|
||||||
# 解析JSON到类对象
|
# 解析JSON到类对象
|
||||||
layout = self.parser.parse(json_file)
|
global_vars.layout = self.parser.parse(json_file)
|
||||||
|
|
||||||
# 基于类对象生成代码
|
# 基于类对象生成代码
|
||||||
return self.code_generator.generate(layout, json_file, namespace)
|
return self.code_generator.generate()
|
||||||
|
|
||||||
|
|
||||||
# 使用示例
|
# 使用示例
|
||||||
@@ -532,5 +457,5 @@ if __name__ == "__main__":
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
generator = ShaderLayoutGenerator()
|
generator = ShaderLayoutGenerator()
|
||||||
header_content = generator.generate_header(sys.argv[1], 'test_shader')
|
header_content = generator.generate_header(sys.argv[1])
|
||||||
print(header_content)
|
print(header_content)
|
||||||
|
|||||||
@@ -4,38 +4,24 @@ SDL3_GPU Slang Compiler - Shader Parser
|
|||||||
着色器源码解析功能
|
着色器源码解析功能
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import re
|
||||||
import subprocess
|
import subprocess
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
|
||||||
import re
|
|
||||||
from typing import List, Dict, Optional
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
from compiler_cmd import make_cmd
|
from compiler_cmd import make_cmd
|
||||||
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo, TargetFormat
|
from global_vars import global_vars
|
||||||
from shader_layout_generator import ShaderLayoutGenerator
|
from shader_layout_generator import ShaderLayoutGenerator
|
||||||
|
from shader_types import ShaderStage, ResourceType, Resource, ShaderInfo
|
||||||
|
|
||||||
|
|
||||||
class ShaderParser:
|
class ShaderParser:
|
||||||
def __init__(self, include_paths: List[str] = None):
|
def parse_slang_shader(self) -> Dict[str, ShaderInfo]:
|
||||||
self.include_paths = include_paths or []
|
|
||||||
|
|
||||||
def parse_slang_shader(self, source_path: str, output_path: str, target: TargetFormat, include_paths: List[str] = None) -> Dict[str, ShaderInfo]:
|
|
||||||
"""解析Slang着色器源码,提取资源信息"""
|
"""解析Slang着色器源码,提取资源信息"""
|
||||||
|
|
||||||
# 合并包含路径
|
with open(global_vars.source_file, 'r', encoding='utf-8') as f:
|
||||||
all_include_paths = self.include_paths.copy()
|
|
||||||
if include_paths:
|
|
||||||
all_include_paths.extend(include_paths)
|
|
||||||
|
|
||||||
# 添加源文件所在目录作为包含路径
|
|
||||||
source_dir = os.path.dirname(os.path.abspath(source_path))
|
|
||||||
if source_dir not in all_include_paths:
|
|
||||||
all_include_paths.insert(0, source_dir)
|
|
||||||
|
|
||||||
print(f"Include paths: {all_include_paths}")
|
|
||||||
|
|
||||||
with open(source_path, 'r', encoding='utf-8') as f:
|
|
||||||
source = f.read()
|
source = f.read()
|
||||||
|
|
||||||
# 首先分析源码找到所有入口点
|
# 首先分析源码找到所有入口点
|
||||||
@@ -49,7 +35,7 @@ class ShaderParser:
|
|||||||
print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")
|
print(f"\nProcessing entry point: {entry_name} (stage: {stage.value})")
|
||||||
|
|
||||||
shader_info = self._compile_and_reflect_entry_point(
|
shader_info = self._compile_and_reflect_entry_point(
|
||||||
source_path, output_path, entry_name, target, stage, all_include_paths, source
|
entry_name, stage, source
|
||||||
)
|
)
|
||||||
|
|
||||||
if shader_info:
|
if shader_info:
|
||||||
@@ -161,9 +147,7 @@ class ShaderParser:
|
|||||||
print(f"Total entry points found: {len(entry_points)}")
|
print(f"Total entry points found: {len(entry_points)}")
|
||||||
return entry_points
|
return entry_points
|
||||||
|
|
||||||
def _compile_and_reflect_entry_point(self, source_path: str, output_path: str, entry_name: str, target: TargetFormat,
|
def _compile_and_reflect_entry_point(self, entry_name: str, stage: ShaderStage, source: str) -> Optional[ShaderInfo]:
|
||||||
stage: ShaderStage, include_paths: List[str],
|
|
||||||
source: str) -> Optional[ShaderInfo]:
|
|
||||||
"""为单个入口点进行完整编译和反射"""
|
"""为单个入口点进行完整编译和反射"""
|
||||||
|
|
||||||
# 创建临时输出文件
|
# 创建临时输出文件
|
||||||
@@ -173,7 +157,7 @@ class ShaderParser:
|
|||||||
with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_reflection:
|
with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as tmp_reflection:
|
||||||
reflection_path = tmp_reflection.name
|
reflection_path = tmp_reflection.name
|
||||||
|
|
||||||
cmd = make_cmd(source_path, target, stage, entry_name, include_paths, temp_output_path)
|
cmd = make_cmd(global_vars.source_file, global_vars.target, stage, entry_name, temp_output_path)
|
||||||
cmd.extend([
|
cmd.extend([
|
||||||
'-reflection-json', reflection_path, # 反射输出
|
'-reflection-json', reflection_path, # 反射输出
|
||||||
])
|
])
|
||||||
@@ -197,9 +181,8 @@ class ShaderParser:
|
|||||||
if stage == ShaderStage.VERTEX:
|
if stage == ShaderStage.VERTEX:
|
||||||
# 生成顶点布局代码
|
# 生成顶点布局代码
|
||||||
generator = ShaderLayoutGenerator()
|
generator = ShaderLayoutGenerator()
|
||||||
source_file_name = os.path.splitext(os.path.basename(source_path))[0]
|
header_content = generator.generate_header(reflection_path)
|
||||||
header_content = generator.generate_header(reflection_path, f'{source_file_name}_shader')
|
layout_path_filename = os.path.join(global_vars.output_dir, f"{global_vars.source_file_name}_layout.h")
|
||||||
layout_path_filename = os.path.join(output_path, f"{source_file_name}_layout.h")
|
|
||||||
with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
|
with open(layout_path_filename, 'w', encoding='utf-8') as out_file:
|
||||||
out_file.write(header_content)
|
out_file.write(header_content)
|
||||||
|
|
||||||
|
|||||||
@@ -5,8 +5,9 @@ SDL3_GPU Slang Compiler - Type Definitions
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import List
|
from typing import List, Optional, Dict
|
||||||
|
|
||||||
|
|
||||||
class ShaderStage(Enum):
|
class ShaderStage(Enum):
|
||||||
VERTEX = "vertex"
|
VERTEX = "vertex"
|
||||||
@@ -42,3 +43,73 @@ class ShaderInfo:
|
|||||||
entry_point: str
|
entry_point: str
|
||||||
resources: List[Resource]
|
resources: List[Resource]
|
||||||
source_code: str
|
source_code: str
|
||||||
|
|
||||||
|
# 数据模型类
|
||||||
|
@dataclass
|
||||||
|
class FieldType:
|
||||||
|
"""字段类型信息"""
|
||||||
|
kind: str # 'scalar', 'vector', 'matrix'
|
||||||
|
scalar_type: Optional[str] = None # 'int32', 'uint32', 'float32', 'int8', 'uint8', 'int16', 'uint16', 'float16'
|
||||||
|
element_count: Optional[int] = None # for vector
|
||||||
|
row_count: Optional[int] = None # for matrix
|
||||||
|
column_count: Optional[int] = None # for matrix
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_dict(cls, data: Dict) -> 'FieldType':
|
||||||
|
"""从字典创建FieldType对象"""
|
||||||
|
kind = data.get('kind')
|
||||||
|
|
||||||
|
if kind == 'vector':
|
||||||
|
return cls(
|
||||||
|
kind=kind,
|
||||||
|
scalar_type=data['elementType']['scalarType'],
|
||||||
|
element_count=data['elementCount']
|
||||||
|
)
|
||||||
|
elif kind == 'scalar':
|
||||||
|
return cls(
|
||||||
|
kind=kind,
|
||||||
|
scalar_type=data['scalarType']
|
||||||
|
)
|
||||||
|
elif kind == 'matrix':
|
||||||
|
return cls(
|
||||||
|
kind=kind,
|
||||||
|
scalar_type=data['elementType']['scalarType'],
|
||||||
|
row_count=data['rowCount'],
|
||||||
|
column_count=data['columnCount']
|
||||||
|
)
|
||||||
|
|
||||||
|
return cls(kind='scalar', scalar_type='float32')
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class VertexField:
|
||||||
|
"""顶点输入字段"""
|
||||||
|
name: str
|
||||||
|
type: FieldType
|
||||||
|
location: int
|
||||||
|
semantic: str
|
||||||
|
semantic_index: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UniformField:
|
||||||
|
"""Uniform缓冲区字段"""
|
||||||
|
name: str
|
||||||
|
type: FieldType
|
||||||
|
offset: int
|
||||||
|
size: int
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class UniformBuffer:
|
||||||
|
"""Uniform缓冲区"""
|
||||||
|
name: str
|
||||||
|
binding: int
|
||||||
|
fields: List[UniformField] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ShaderLayout:
|
||||||
|
"""着色器布局数据"""
|
||||||
|
vertex_fields: List[VertexField] = field(default_factory=list)
|
||||||
|
uniform_buffers: List[UniformBuffer] = field(default_factory=list)
|
||||||
|
|||||||
Reference in New Issue
Block a user