diff --git a/cmake/shader_compile.cmake b/cmake/shader_compile.cmake index 95aa390..44ada55 100644 --- a/cmake/shader_compile.cmake +++ b/cmake/shader_compile.cmake @@ -1,7 +1,8 @@ # ============================================================================ # MIRAI 着色器编译 CMake 模块 # ============================================================================ -# 提供自动化着色器编译和 C++ 代码生成功能 +# 提供自动化着色器编译和代码生成功能 +# 自动检测 [shader("xxx")] 属性并编译所有入口点 # # 主要函数: # add_shader_library() - 创建着色器库目标 @@ -74,6 +75,7 @@ endfunction() # ============================================================================ # # 将着色器编译和代码生成附加到现有的编译目标上 +# 自动检测所有 [shader("xxx")] 入口点并编译 # # 用法: # # 方案1: 自动搜索着色器,指定引用路径 @@ -112,7 +114,7 @@ function(add_shader_library) # 验证必需参数 if(NOT SHADER_TARGET) - message(FATAL_ERROR "add_shader_library: TARGET is required (existing compilation target)") + message(FATAL_ERROR "add_shader_library: TARGET is required") endif() # 验证目标存在 @@ -168,8 +170,6 @@ function(add_shader_library) # 收集生成的文件 set(ALL_GENERATED_HEADERS "") - set(ALL_SPIRV_FILES "") - set(ALL_CUSTOM_TARGETS "") # 处理每个着色器文件 foreach(SHADER_SOURCE ${ALL_SHADERS}) @@ -184,112 +184,54 @@ function(add_shader_library) endif() endif() - # 输出文件路径 - set(VERT_SPV "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}.vert.spv") - set(FRAG_SPV "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}.frag.spv") - set(VERT_REFLECT_JSON "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}.vert.reflect.json") - set(FRAG_REFLECT_JSON "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}.frag.reflect.json") + # 每个着色器文件的输出目录 + set(SPIRV_OUTPUT_DIR "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}") set(GENERATED_HPP "${SHADER_OUTPUT_DIR}/${SHADER_NAME_WE}_bindings.hpp") - # 编译顶点着色器 + # 步骤1: 编译着色器(自动检测所有入口点) add_custom_command( - OUTPUT "${VERT_SPV}" + OUTPUT "${SPIRV_OUTPUT_DIR}/.compiled" + COMMAND ${CMAKE_COMMAND} -E make_directory "${SPIRV_OUTPUT_DIR}" COMMAND $ "${SHADER_ABS}" - -o "${VERT_SPV}" - -e vertexMain - -s vertex + -o "${SPIRV_OUTPUT_DIR}" ${INCLUDE_ARGS} ${DEFINE_ARGS} + COMMAND ${CMAKE_COMMAND} -E touch "${SPIRV_OUTPUT_DIR}/.compiled" DEPENDS "${SHADER_ABS}" ${MIRAI_SHADER_COMPILER_TARGET} - COMMENT "Compiling vertex shader: ${SHADER_NAME_WE}" + COMMENT "Compiling shader: ${SHADER_NAME_WE}" VERBATIM COMMAND_EXPAND_LISTS ) - # 编译片段着色器 - add_custom_command( - OUTPUT "${FRAG_SPV}" - COMMAND $ - "${SHADER_ABS}" - -o "${FRAG_SPV}" - -e fragmentMain - -s fragment - ${INCLUDE_ARGS} - ${DEFINE_ARGS} - DEPENDS "${SHADER_ABS}" ${MIRAI_SHADER_COMPILER_TARGET} - COMMENT "Compiling fragment shader: ${SHADER_NAME_WE}" - VERBATIM COMMAND_EXPAND_LISTS - ) - - # 生成顶点反射数据 - add_custom_command( - OUTPUT "${VERT_REFLECT_JSON}" - COMMAND $ - "${SHADER_ABS}" - -e vertexMain - -s vertex - -r "${VERT_REFLECT_JSON}" - ${INCLUDE_ARGS} - ${DEFINE_ARGS} - DEPENDS "${SHADER_ABS}" ${MIRAI_SHADER_COMPILER_TARGET} - COMMENT "Generating vertex reflection: ${SHADER_NAME_WE}" - VERBATIM COMMAND_EXPAND_LISTS - ) - - # 生成片段反射数据 - add_custom_command( - OUTPUT "${FRAG_REFLECT_JSON}" - COMMAND $ - "${SHADER_ABS}" - -e fragmentMain - -s fragment - -r "${FRAG_REFLECT_JSON}" - ${INCLUDE_ARGS} - ${DEFINE_ARGS} - DEPENDS "${SHADER_ABS}" ${MIRAI_SHADER_COMPILER_TARGET} - COMMENT "Generating fragment reflection: ${SHADER_NAME_WE}" - VERBATIM COMMAND_EXPAND_LISTS - ) - - # 生成 C++ 绑定头文件 + # 步骤2: 生成绑定头文件(Python 脚本自己查找 spv 文件) add_custom_command( OUTPUT "${GENERATED_HPP}" - COMMAND "${PYTHON_EXECUTABLE}" "${MIRAI_SHADER_GENERATOR_SCRIPT}" - --spirv "${VERT_SPV}" "${FRAG_SPV}" - --reflection "${VERT_REFLECT_JSON}" "${FRAG_REFLECT_JSON}" + COMMAND ${PYTHON_EXECUTABLE} "${MIRAI_SHADER_GENERATOR_SCRIPT}" + --dir "${SPIRV_OUTPUT_DIR}" --output "${GENERATED_HPP}" --name "${SHADER_NAME_WE}" --template-dir "${MIRAI_SHADER_TEMPLATE_DIR}" - DEPENDS "${VERT_SPV}" "${FRAG_SPV}" "${VERT_REFLECT_JSON}" "${FRAG_REFLECT_JSON}" - "${MIRAI_SHADER_GENERATOR_SCRIPT}" + DEPENDS "${SPIRV_OUTPUT_DIR}/.compiled" "${MIRAI_SHADER_GENERATOR_SCRIPT}" COMMENT "Generating bindings: ${SHADER_NAME_WE}" VERBATIM ) list(APPEND ALL_GENERATED_HEADERS "${GENERATED_HPP}") - list(APPEND ALL_SPIRV_FILES "${VERT_SPV}" "${FRAG_SPV}") endforeach() - # ======================================================================== # 附加到目标:添加生成目录到 include 路径 - # ======================================================================== target_include_directories(${SHADER_TARGET} PUBLIC $ ) - # ======================================================================== # 添加编译依赖 - # ======================================================================== if(ALL_GENERATED_HEADERS) add_custom_target(${SHADER_TARGET}_shaders DEPENDS ${ALL_GENERATED_HEADERS}) add_dependencies(${SHADER_TARGET} ${SHADER_TARGET}_shaders) endif() - # ======================================================================== - # 导出变量供其他 CMake 代码使用 - # ======================================================================== + # 导出变量 set(${SHADER_TARGET}_SHADER_HEADERS ${ALL_GENERATED_HEADERS} PARENT_SCOPE) - set(${SHADER_TARGET}_SHADER_SPIRV ${ALL_SPIRV_FILES} PARENT_SCOPE) set(${SHADER_TARGET}_SHADER_OUTPUT_DIR ${SHADER_OUTPUT_DIR} PARENT_SCOPE) message(STATUS "Added shaders to target: ${SHADER_TARGET}") @@ -297,7 +239,4 @@ function(add_shader_library) if(ALL_SHADERS) message(STATUS " Shaders: ${ALL_SHADERS}") endif() - if(SHADER_REF_PATHS) - message(STATUS " Ref paths: ${SHADER_REF_PATHS}") - endif() -endfunction() \ No newline at end of file +endfunction() diff --git a/src/shader/shaders/basic.slang b/src/shader/shaders/basic.slang index daaa7b1..d2374cf 100644 --- a/src/shader/shaders/basic.slang +++ b/src/shader/shaders/basic.slang @@ -1,7 +1,8 @@ /** * @file basic.slang * @brief 基础 UI 着色器示例 - * @description 展示 GUI 渲染中常用的 Uniform Buffer、纹理采样、顶点属性和 Push Constants + * @description 展示 GUI 渲染中常用的 Uniform Buffer、纹理采样、顶点属性、Push Constants + * 以及多入口点支持(通过 [shader("xxx")] 属性自动检测) */ // ============================================================================ @@ -38,6 +39,18 @@ Texture2D uiTexture; [[vk::binding(1, 1)]] SamplerState linearSampler; +/** + * @brief 计算着色器用的输入纹理 + */ +[[vk::binding(0, 2)]] +RWTexture2D computeInput; + +/** + * @brief 计算着色器用的输出纹理 + */ +[[vk::binding(1, 2)]] +RWTexture2D computeOutput; + // ============================================================================ // Push Constants // ============================================================================ @@ -55,6 +68,18 @@ struct PushConstants { [[vk::push_constant]] PushConstants pc; +/** + * @brief 计算着色器的 Push Constants + */ +struct ComputePushConstants { + uint2 workgroupSize; // 工作组大小 + uint2 workgroupCount; // 工作组数量 + float intensity; // 强度参数 +}; + +[[vk::push_constant]] +ComputePushConstants computePc; + // ============================================================================ // 顶点输入输出结构 // ============================================================================ @@ -114,4 +139,21 @@ float4 fragmentMain(VertexOutput input) : SV_Target { color.a *= pc.opacity; return color; -} \ No newline at end of file +} + +// ============================================================================ +// 计算着色器 +// ============================================================================ + +/** + * @brief 简单的图像处理计算着色器 + * @description 演示多入口点支持 - 自动检测 [shader("compute")] + */ +[shader("compute")] +void computeMain(uint3 globalID : SV_DispatchThreadID) { + // 简单的图像处理:应用强度 + float4 color = computeInput[globalID.xy]; + color.rgb *= computePc.intensity; + color.a = 1.0f; + computeOutput[globalID.xy] = color; +} diff --git a/tools/generate_shader_bindings.py b/tools/generate_shader_bindings.py index 8da97e4..9937052 100644 --- a/tools/generate_shader_bindings.py +++ b/tools/generate_shader_bindings.py @@ -6,8 +6,7 @@ MIRAI 着色器绑定代码生成器 Usage: python generate_shader_bindings.py \ - --spirv shader.vert.spv shader.frag.spv \ - --reflection shader.reflect.json \ + --dir ./shader_intermediate \ --output generated/shader_name_bindings.hpp \ --name ShaderName """ @@ -493,25 +492,16 @@ def main(): epilog=""" 示例: python generate_shader_bindings.py \\ - --spirv shader.vert.spv shader.frag.spv \\ - --reflection shader.vert.reflect.json shader.frag.reflect.json \\ - --output generated/shader_bindings.hpp \\ - --name MyShader + --dir ./shader_intermediate/basic \\ + --output generated/basic_bindings.hpp \\ + --name basic """ ) parser.add_argument( - '--spirv', '-s', - nargs='+', + '--dir', '-d', required=True, - help='SPIR-V 二进制文件 (.spv),可以指定多个文件' - ) - - parser.add_argument( - '--reflection', '-r', - nargs='+', - required=True, - help='反射 JSON 文件 (.json),可以指定多个文件(每个阶段一个)' + help='包含 SPIR-V 和反射文件的目录' ) parser.add_argument( @@ -550,45 +540,43 @@ def main(): print(f"Error: Template directory not found: {template_dir}") return 1 - # 解析反射数据(支持多个文件) - all_reflections = [] - reflection_paths = [] - found_any_reflection = False - for reflection_path_str in args.reflection: - reflection_path = Path(reflection_path_str) - if reflection_path.exists(): - found_any_reflection = True - if args.verbose: - print(f"Parsing reflection: {reflection_path}") - reflection = parse_reflection_json(reflection_path) - all_reflections.append(reflection) - reflection_paths.append(reflection_path) - else: - if args.verbose: - print(f"Warning: Reflection file not found: {reflection_path} (skipping)") + # 从目录中查找文件 + shader_dir = Path(args.dir) + if not shader_dir.exists(): + print(f"Info: Shader directory not found: {shader_dir}. No entry points, skipping.") + return 0 - # 如果没有找到任何反射文件,说明着色器没有入口函数,跳过生成 - if not found_any_reflection: + # 查找所有 spv 和 reflect.json 文件 + spirv_files = sorted(shader_dir.glob("*.spv")) + reflect_files = sorted(shader_dir.glob("*.reflect.json")) + + if args.verbose: + print(f"Found {len(spirv_files)} SPIR-V files") + print(f"Found {len(reflect_files)} reflection files") + + # 如果没有反射文件,跳过生成 + if not reflect_files: + print(f"Info: No reflection files found in {shader_dir}. Skipping binding generation.") + return 0 + + # 解析反射数据 + all_reflections = [] + for reflection_path in reflect_files: if args.verbose: - print(f"Info: No reflection files found for {args.name}. Shader has no entry points, skipping binding generation.") - return 0 # 正常退出,不生成文件 - else: - # 合并多个反射数据(取并集) - reflection = merge_reflections(all_reflections) + print(f"Parsing reflection: {reflection_path}") + reflection = parse_reflection_json(reflection_path) + all_reflections.append(reflection) + + # 合并多个反射数据(取并集) + reflection = merge_reflections(all_reflections) # 加载 SPIR-V 文件 stages: Dict[str, bytes] = {} - for spirv_file in args.spirv: - spirv_path = Path(spirv_file) - if not spirv_path.exists(): - if args.verbose: - print(f"Warning: SPIR-V file not found: {spirv_path} (skipping)") - continue - + for spirv_path in spirv_files: stage = detect_stage_from_filename(spirv_path.name) if stage is None: - print(f"Warning: Cannot detect shader stage from filename: {spirv_path.name}") - # 尝试使用文件名作为阶段名 + if args.verbose: + print(f"Warning: Cannot detect shader stage from filename: {spirv_path.name}") stage = spirv_path.stem.split('.')[-1] if '.' in spirv_path.stem else 'unknown' if args.verbose: @@ -623,4 +611,4 @@ def main(): if __name__ == '__main__': - exit(main()) \ No newline at end of file + exit(main()) diff --git a/tools/shader_compile/compiler.cpp b/tools/shader_compile/compiler.cpp index 1f00740..b441f2f 100644 --- a/tools/shader_compile/compiler.cpp +++ b/tools/shader_compile/compiler.cpp @@ -9,6 +9,7 @@ #include #include +#include #include namespace mirai::tools { @@ -63,6 +64,18 @@ static SlangStage to_slang_stage(shader_stage stage) { return SLANG_STAGE_NONE; } +std::optional slang_stage_to_shader_stage(SlangStage stage) { + switch (stage) { + case SLANG_STAGE_VERTEX: return shader_stage::vertex; + case SLANG_STAGE_FRAGMENT: return shader_stage::fragment; + case SLANG_STAGE_COMPUTE: return shader_stage::compute; + case SLANG_STAGE_GEOMETRY: return shader_stage::geometry; + case SLANG_STAGE_HULL: return shader_stage::tessellation_control; + case SLANG_STAGE_DOMAIN: return shader_stage::tessellation_evaluation; + default: return std::nullopt; + } +} + // ============================================================================ // Implementation Class // ============================================================================ @@ -578,6 +591,9 @@ struct shader_compiler::impl { target_desc.format = SLANG_SPIRV; target_desc.profile = global_session->findProfile("glsl_450"); + session_desc.targets = &target_desc; + session_desc.targetCount = 1; + std::vector include_path_strs; std::vector search_paths; for (const auto& path : options.include_paths) { @@ -614,26 +630,246 @@ struct shader_compiler::impl { return result; } - // 获取入口点列表 - 使用正确的 API + // 获取入口点列表 SlangInt32 entry_point_count = module->getDefinedEntryPointCount(); - for (SlangInt32 i = 0; i < entry_point_count; ++i) { - Slang::ComPtr entry_point; - if (SLANG_SUCCEEDED(module->getDefinedEntryPoint(i, entry_point.writeRef()))) { - // 获取入口点的反射信息来获取名称 - slang::IComponentType* entry_comp = entry_point.get(); - if (entry_comp) { - // 尝试通过 getLayout 获取名称 - slang::ProgramLayout* layout = entry_comp->getLayout(); - if (layout) { - SlangUInt count = layout->getEntryPointCount(); - if (count > 0) { - slang::EntryPointReflection* ep_ref = layout->getEntryPointByIndex(0); - if (ep_ref) { - const char* entry_name = ep_ref->getName(); - if (entry_name && entry_name[0] != '\0') { - result.push_back(entry_name); - } - } + + // 常见入口点名称模式 + const char* common_names[] = {"main", "vertexMain", "fragmentMain", "computeMain", + "vertexShader", "fragmentShader", "computeShader"}; + + const SlangStage stages[] = { + SLANG_STAGE_VERTEX, + SLANG_STAGE_FRAGMENT, + SLANG_STAGE_COMPUTE, + SLANG_STAGE_GEOMETRY, + SLANG_STAGE_HULL, + SLANG_STAGE_DOMAIN + }; + + // 尝试每个可能的入口点名称和阶段的组合 + for (const char* name : common_names) { + for (SlangStage stage : stages) { + Slang::ComPtr checked_entry; + SlangResult find_result = module->findAndCheckEntryPoint( + name, + stage, + checked_entry.writeRef(), + diagnostics_blob.writeRef() + ); + + if (SLANG_FAILED(find_result) || !checked_entry) { + continue; + } + + // 验证并编译这个入口点来获取名称 + std::vector comps = {module, checked_entry.get()}; + Slang::ComPtr composite; + + if (SLANG_FAILED(session->createCompositeComponentType( + comps.data(), + static_cast(comps.size()), + composite.writeRef(), + diagnostics_blob.writeRef() + ))) { + continue; + } + + Slang::ComPtr linked; + if (SLANG_FAILED(composite->link(linked.writeRef(), diagnostics_blob.writeRef()))) { + continue; + } + + // 生成代码来触发 layout 生成 + Slang::ComPtr spirv_blob; + if (SLANG_FAILED(linked->getEntryPointCode( + 0, 0, spirv_blob.writeRef(), diagnostics_blob.writeRef()))) { + continue; + } + + // 获取 layout + slang::ProgramLayout* layout = linked->getLayout(); + if (!layout || layout->getEntryPointCount() == 0) { + continue; + } + + slang::EntryPointReflection* ep_ref = layout->getEntryPointByIndex(0); + if (!ep_ref) { + continue; + } + + const char* entry_name = ep_ref->getName(); + if (entry_name && entry_name[0] != '\0') { + // 检查是否已经添加 + bool exists = false; + for (const auto& existing : result) { + if (existing == entry_name) { + exists = true; + break; + } + } + if (!exists) { + result.push_back(entry_name); + } + } + } + } + + return result; + } + + // 获取所有入口函数及其阶段 + std::vector get_entry_points_with_stages( + const std::string& source, + const std::string& filename, + const compile_options& options + ) { + std::vector result; + + if (!global_session) { + return result; + } + + // 创建 session + slang::SessionDesc session_desc = {}; + slang::TargetDesc target_desc = {}; + target_desc.format = SLANG_SPIRV; + target_desc.profile = global_session->findProfile("glsl_450"); + + session_desc.targets = &target_desc; + session_desc.targetCount = 1; + + std::vector include_path_strs; + std::vector search_paths; + for (const auto& path : options.include_paths) { + include_path_strs.push_back(path.string()); + search_paths.push_back(include_path_strs.back().c_str()); + } + session_desc.searchPaths = search_paths.data(); + session_desc.searchPathCount = static_cast(search_paths.size()); + + std::vector> macro_strs; + std::vector macros; + for (const auto& [name, value] : options.defines) { + macro_strs.emplace_back(name, value); + macros.push_back({macro_strs.back().first.c_str(), macro_strs.back().second.c_str()}); + } + session_desc.preprocessorMacros = macros.data(); + session_desc.preprocessorMacroCount = static_cast(macros.size()); + + Slang::ComPtr session; + if (SLANG_FAILED(global_session->createSession(session_desc, session.writeRef()))) { + return result; + } + + // 加载模块 + Slang::ComPtr diagnostics_blob; + slang::IModule* module = session->loadModuleFromSourceString( + filename.c_str(), + filename.c_str(), + source.c_str(), + diagnostics_blob.writeRef() + ); + + if (!module) { + return result; + } + + // 获取所有定义的入口点数量 + SlangInt32 entry_point_count = module->getDefinedEntryPointCount(); + + if (entry_point_count == 0) { + return result; + } + + // 常见入口点名称模式 + const char* common_names[] = {"main", "vertexMain", "fragmentMain", "computeMain", + "vertexShader", "fragmentShader", "computeShader"}; + + const SlangStage stages[] = { + SLANG_STAGE_VERTEX, + SLANG_STAGE_FRAGMENT, + SLANG_STAGE_COMPUTE, + SLANG_STAGE_GEOMETRY, + SLANG_STAGE_HULL, + SLANG_STAGE_DOMAIN + }; + + // 尝试每个可能的入口点名称和阶段的组合 + for (const char* name : common_names) { + for (SlangStage stage : stages) { + // 如果已经找到所有入口点,提前退出 + if (static_cast(result.size()) >= entry_point_count) { + break; + } + + Slang::ComPtr checked_entry; + SlangResult find_result = module->findAndCheckEntryPoint( + name, + stage, + checked_entry.writeRef(), + diagnostics_blob.writeRef() + ); + + if (SLANG_FAILED(find_result) || !checked_entry) { + continue; + } + + // 验证并编译这个入口点来获取名称 + std::vector comps = {module, checked_entry.get()}; + Slang::ComPtr composite; + + if (SLANG_FAILED(session->createCompositeComponentType( + comps.data(), + static_cast(comps.size()), + composite.writeRef(), + diagnostics_blob.writeRef() + ))) { + continue; + } + + Slang::ComPtr linked; + if (SLANG_FAILED(composite->link(linked.writeRef(), diagnostics_blob.writeRef()))) { + continue; + } + + // 生成代码来触发 layout 生成 + Slang::ComPtr spirv_blob; + if (SLANG_FAILED(linked->getEntryPointCode( + 0, 0, spirv_blob.writeRef(), diagnostics_blob.writeRef()))) { + continue; + } + + // 获取 layout + slang::ProgramLayout* layout = linked->getLayout(); + if (!layout || layout->getEntryPointCount() == 0) { + continue; + } + + slang::EntryPointReflection* ep_ref = layout->getEntryPointByIndex(0); + if (!ep_ref) { + continue; + } + + const char* entry_name = ep_ref->getName(); + SlangStage slang_stage = ep_ref->getStage(); + + if (entry_name && entry_name[0] != '\0') { + // 检查是否已经添加 + bool exists = false; + for (const auto& existing : result) { + if (existing.name == entry_name) { + exists = true; + break; + } + } + + if (!exists) { + auto shader_stage_opt = slang_stage_to_shader_stage(slang_stage); + if (shader_stage_opt) { + entry_point_info info; + info.name = entry_name; + info.stage = *shader_stage_opt; + result.push_back(info); } } } @@ -642,6 +878,31 @@ struct shader_compiler::impl { return result; } + + // 编译所有入口点 + std::vector compile_all_entries( + const std::string& source, + const std::string& filename, + const compile_options& options + ) { + std::vector results; + + // 获取所有入口点及其阶段 + auto entry_points = get_entry_points_with_stages(source, filename, options); + + for (const auto& ep_info : entry_points) { + compile_options ep_options = options; + ep_options.entry_point = ep_info.name; + ep_options.stage = ep_info.stage; + + compile_result result = compile(source, filename, ep_options); + result.entry_point = ep_info.name; + result.stage = ep_info.stage; + results.push_back(result); + } + + return results; + } }; // ============================================================================ @@ -751,6 +1012,57 @@ std::vector shader_compiler::get_entry_points( return impl_->get_entry_points(source, path.filename().string(), options); } +std::vector shader_compiler::get_entry_points_with_stages( + const std::filesystem::path& path, + const compile_options& options +) { + // 首先检查文件是否存在 + if (!std::filesystem::exists(path)) { + return {}; + } + + // 读取文件内容 + std::ifstream file(path); + if (!file) { + return {}; + } + + std::ostringstream ss; + ss << file.rdbuf(); + std::string source = ss.str(); + + return impl_->get_entry_points_with_stages(source, path.filename().string(), options); +} + +std::vector shader_compiler::compile_file_all_entries( + const std::filesystem::path& path, + const compile_options& options +) { + // 首先检查文件是否存在 + if (!std::filesystem::exists(path)) { + return {}; + } + + // 读取文件内容 + std::ifstream file(path); + if (!file) { + return {}; + } + + std::ostringstream ss; + ss << file.rdbuf(); + std::string source = ss.str(); + + // 添加文件所在目录到包含路径 + compile_options modified_options = options; + modified_options.include_paths.insert( + modified_options.include_paths.begin(), + path.parent_path() + ); + + return impl_->compile_all_entries(source, path.filename().string(), modified_options); +} + bool shader_compiler::is_available() const noexcept { return impl_ && impl_->global_session; } diff --git a/tools/shader_compile/compiler.hpp b/tools/shader_compile/compiler.hpp index 5ed03b8..1ee9529 100644 --- a/tools/shader_compile/compiler.hpp +++ b/tools/shader_compile/compiler.hpp @@ -7,6 +7,7 @@ #include #include #include +#include #include #include @@ -34,11 +35,18 @@ enum class shader_stage { */ [[nodiscard]] const char* shader_stage_to_string(shader_stage stage); +/** + * @brief 将 Slang 阶段转换为 shader_stage + */ +[[nodiscard]] std::optional slang_stage_to_shader_stage(SlangStage stage); + /** * @brief 编译结果 */ struct compile_result { bool success{false}; + std::string entry_point; // 入口点名称 + shader_stage stage; // 着色器阶段 std::vector spirv; std::string reflection_json; std::string error_message; @@ -59,6 +67,14 @@ struct compile_options { bool emit_reflection{false}; }; +/** + * @brief 入口点信息(包含名称和阶段) + */ +struct entry_point_info { + std::string name; + shader_stage stage; +}; + /** * @brief 着色器编译器 * @@ -136,6 +152,28 @@ public: const compile_options& options ); + /** + * @brief 获取模块中的所有入口函数及其阶段 + * @param path 输入文件路径 + * @param options 编译选项 + * @return 入口点信息列表 + */ + [[nodiscard]] std::vector get_entry_points_with_stages( + const std::filesystem::path& path, + const compile_options& options + ); + + /** + * @brief 编译着色器文件(自动检测所有入口点) + * @param path 输入文件路径 + * @param options 编译选项 + * @return 编译结果列表(每个入口点一个结果) + */ + [[nodiscard]] std::vector compile_file_all_entries( + const std::filesystem::path& path, + const compile_options& options + ); + /** * @brief 检查编译器是否可用 */ @@ -151,4 +189,4 @@ private: std::unique_ptr impl_; }; -} // namespace mirai::tools \ No newline at end of file +} // namespace mirai::tools diff --git a/tools/shader_compile/main.cpp b/tools/shader_compile/main.cpp index f55c1ce..a249586 100644 --- a/tools/shader_compile/main.cpp +++ b/tools/shader_compile/main.cpp @@ -1,5 +1,6 @@ // tools/shader_compile/main.cpp // MIRAI 着色器编译器命令行入口 +// 自动检测所有 [shader("xxx")] 入口点并编译 #include "compiler.hpp" @@ -12,47 +13,47 @@ namespace { void print_usage(const char* program_name) { std::cout << R"( -MIRAI Shader Compiler +MIRAI Shader Compiler - Auto Mode Usage: )" << program_name << R"( [options] Options: - -o, --output Output SPIR-V file (.spv) - -I, --include Add include search path - -D, --define Define preprocessor macro (e.g., -DDEBUG or -DVALUE=1) - --entry Entry point name (default: main) - --stage Shader stage: vertex, fragment, compute, geometry, - tesscontrol, tesseval (or short: vert, frag, comp, etc.) - --emit-spirv Output SPIR-V (default if -o is specified) - --emit-reflection Output reflection JSON - -r, --reflection Output reflection JSON to file - -g Generate debug info - -O0 Disable optimization - --deps Write dependency file (Make format) - -h, --help Show this help - -v, --version Show version info + -o, --output Output directory for compiled shaders + -I, --include Add include search path + -D, --define Define preprocessor macro (e.g., -DDEBUG or -DVALUE=1) + --prefix Prefix for output file names (default: shader name) + -g Generate debug info + -O0 Disable optimization + --list Only list entry points, don't compile + -h, --help Show this help + -v, --version Show version info Examples: - )" << program_name << R"( shader.slang -o shader.vert.spv --entry vertexMain --stage vertex - )" << program_name << R"( shader.slang -o shader.frag.spv --entry fragmentMain --stage fragment - )" << program_name << R"( shader.slang -r shader.reflect.json --emit-reflection + )" << program_name << R"( shader.slang -o ./compiled_shaders + )" << program_name << R"( shader.slang -o ./shaders --prefix myshader + )" << program_name << R"( shader.slang --list + +Output Format: + Automatically compiles all entry points found via [shader("xxx")] attributes. + Output files: ...spv + ...reflect.json )"; } void print_version() { - std::cout << "MIRAI Shader Compiler v1.0.0\n"; + std::cout << "MIRAI Shader Compiler v2.0.0 (Auto Mode)\n"; std::cout << "Based on Slang Shader Language\n"; } struct command_line_args { std::filesystem::path input_path; - std::filesystem::path output_path; - std::filesystem::path reflection_path; - std::filesystem::path deps_path; + std::filesystem::path output_dir; + std::string prefix; mirai::tools::compile_options options; bool show_help = false; bool show_version = false; + bool list_only = false; }; bool parse_args(int argc, char* argv[], command_line_args& args) { @@ -72,16 +73,7 @@ bool parse_args(int argc, char* argv[], command_line_args& args) { std::cerr << "Error: -o requires an argument\n"; return false; } - args.output_path = argv[i]; - args.options.emit_spirv = true; - } - else if (arg == "-r" || arg == "--reflection") { - if (++i >= argc) { - std::cerr << "Error: -r requires an argument\n"; - return false; - } - args.reflection_path = argv[i]; - args.options.emit_reflection = true; + args.output_dir = argv[i]; } else if (arg == "-I" || arg == "--include") { if (++i >= argc) { @@ -115,30 +107,15 @@ bool parse_args(int argc, char* argv[], command_line_args& args) { args.options.defines.emplace_back(def, "1"); } } - else if (arg == "--entry" || arg == "-e") { + else if (arg == "--prefix") { if (++i >= argc) { - std::cerr << "Error: --entry requires an argument\n"; + std::cerr << "Error: --prefix requires an argument\n"; return false; } - args.options.entry_point = argv[i]; + args.prefix = argv[i]; } - else if (arg == "--stage" || arg == "-s") { - if (++i >= argc) { - std::cerr << "Error: --stage requires an argument\n"; - return false; - } - auto stage = mirai::tools::parse_shader_stage(argv[i]); - if (!stage) { - std::cerr << "Error: Invalid shader stage: " << argv[i] << "\n"; - return false; - } - args.options.stage = stage; - } - else if (arg == "--emit-spirv") { - args.options.emit_spirv = true; - } - else if (arg == "--emit-reflection") { - args.options.emit_reflection = true; + else if (arg == "--list") { + args.list_only = true; } else if (arg == "-g") { args.options.generate_debug_info = true; @@ -146,13 +123,6 @@ bool parse_args(int argc, char* argv[], command_line_args& args) { else if (arg == "-O0") { args.options.optimize = false; } - else if (arg == "--deps") { - if (++i >= argc) { - std::cerr << "Error: --deps requires an argument\n"; - return false; - } - args.deps_path = argv[i]; - } else if (!arg.starts_with("-")) { if (!args.input_path.empty()) { std::cerr << "Error: Multiple input files specified\n"; @@ -169,6 +139,18 @@ bool parse_args(int argc, char* argv[], command_line_args& args) { return true; } +std::string stage_to_suffix(mirai::tools::shader_stage stage) { + switch (stage) { + case mirai::tools::shader_stage::vertex: return "vert"; + case mirai::tools::shader_stage::fragment: return "frag"; + case mirai::tools::shader_stage::compute: return "comp"; + case mirai::tools::shader_stage::geometry: return "geom"; + case mirai::tools::shader_stage::tessellation_control: return "tesc"; + case mirai::tools::shader_stage::tessellation_evaluation: return "tese"; + } + return "unknown"; +} + bool write_spirv(const std::filesystem::path& path, const std::vector& spirv) { std::ofstream file(path, std::ios::binary); if (!file) { @@ -191,25 +173,6 @@ bool write_text(const std::filesystem::path& path, const std::string& text) { return file.good(); } -bool write_deps( - const std::filesystem::path& path, - const std::filesystem::path& target, - const std::vector& deps -) { - std::ofstream file(path); - if (!file) { - std::cerr << "Error: Failed to open deps file: " << path << "\n"; - return false; - } - - file << target.string() << ":"; - for (const auto& dep : deps) { - file << " \\\n " << dep.string(); - } - file << "\n"; - return file.good(); -} - } // anonymous namespace int main(int argc, char* argv[]) { @@ -235,12 +198,6 @@ int main(int argc, char* argv[]) { return 1; } - // 如果没有指定输出,至少需要指定一种输出 - if (args.output_path.empty() && args.reflection_path.empty()) { - std::cerr << "Error: No output specified. Use -o for SPIR-V or -r for reflection JSON\n"; - return 1; - } - // 创建编译器 mirai::tools::shader_compiler compiler; if (!compiler.is_available()) { @@ -248,60 +205,74 @@ int main(int argc, char* argv[]) { return 1; } - // 编译 - auto result = compiler.compile_file(args.input_path, args.options); + // 获取所有入口点及其阶段 + auto entry_points = compiler.get_entry_points_with_stages(args.input_path, args.options); - if (!result.success) { - // 检查是否是入口点未找到的错误(这对于纯模块文件是正常的) - if (result.error_message.find("Entry point") != std::string::npos && - result.error_message.find("not found") != std::string::npos) { - // 模块文件没有入口点是正常行为,跳过编译 - std::cout << "Warning: No entry point found (module file?), skipping: " << args.input_path << "\n"; - return 0; - } - std::cerr << "Compilation failed:\n" << result.error_message << "\n"; - return 1; + if (entry_points.empty()) { + std::cout << "Warning: No entry points found in " << args.input_path << "\n"; + std::cout << " (This may be a module file with no [shader(...)] attributes)\n"; + return 0; } - // 写入 SPIR-V - if (!args.output_path.empty() && !result.spirv.empty()) { - // 创建输出目录 - auto parent = args.output_path.parent_path(); - if (!parent.empty()) { - std::filesystem::create_directories(parent); - } - - if (!write_spirv(args.output_path, result.spirv)) { - return 1; - } - std::cout << "Generated: " << args.output_path << " (" << result.spirv.size() * 4 << " bytes)\n"; + // 列出入口点 + std::cout << "Found " << entry_points.size() << " entry point(s) in " << args.input_path.filename() << ":\n"; + for (const auto& ep : entry_points) { + std::cout << " - " << ep.name << " [" << mirai::tools::shader_stage_to_string(ep.stage) << "]\n"; } - // 写入反射 JSON - if (!args.reflection_path.empty() && !result.reflection_json.empty()) { - auto parent = args.reflection_path.parent_path(); - if (!parent.empty()) { - std::filesystem::create_directories(parent); - } - - if (!write_text(args.reflection_path, result.reflection_json)) { - return 1; - } - std::cout << "Generated: " << args.reflection_path << "\n"; + // 如果只是列出,不编译 + if (args.list_only) { + return 0; } - // 写入依赖文件 - if (!args.deps_path.empty()) { - auto target = args.output_path.empty() ? args.reflection_path : args.output_path; - - // 添加输入文件作为依赖 - auto deps = result.dependencies; - deps.insert(deps.begin(), args.input_path); - - if (!write_deps(args.deps_path, target, deps)) { - return 1; - } + // 确保输出目录存在 + if (args.output_dir.empty()) { + args.output_dir = "."; + } + if (!std::filesystem::exists(args.output_dir)) { + std::filesystem::create_directories(args.output_dir); } - return 0; + // 确定前缀 + if (args.prefix.empty()) { + args.prefix = args.input_path.stem().string(); + } + + // 编译所有入口点 + args.options.emit_spirv = true; + args.options.emit_reflection = true; + + auto results = compiler.compile_file_all_entries(args.input_path, args.options); + + int success_count = 0; + int fail_count = 0; + + for (const auto& result : results) { + if (!result.success) { + std::cerr << "Error compiling " << result.entry_point << ": " << result.error_message << "\n"; + fail_count++; + continue; + } + + std::string suffix = stage_to_suffix(result.stage); + std::string base_name = args.prefix + "." + result.entry_point + "." + suffix; + + // 写入 SPIR-V + std::filesystem::path spv_path = args.output_dir / (base_name + ".spv"); + if (write_spirv(spv_path, result.spirv)) { + std::cout << "Generated: " << spv_path << " (" << result.spirv.size() * 4 << " bytes)\n"; + } + + // 写入反射 JSON + std::filesystem::path reflect_path = args.output_dir / (base_name + ".reflect.json"); + if (write_text(reflect_path, result.reflection_json)) { + std::cout << "Generated: " << reflect_path << "\n"; + } + + success_count++; + } + + std::cout << "\nSummary: " << success_count << " succeeded, " << fail_count << " failed\n"; + + return fail_count > 0 ? 1 : 0; }