重构着色器编译器,支持自动检测多入口点,优化编译流程,更新命令行参数,改进输出文件管理

This commit is contained in:
2025-12-31 23:01:01 +08:00
parent 63bc415857
commit 596f503dfa
6 changed files with 571 additions and 281 deletions

View File

@@ -1,7 +1,8 @@
# ============================================================================
# MIRAI 着色器编译 CMake 模块
# ============================================================================
# 提供自动化着色器编译和 C++ 代码生成功能
# 提供自动化着色器编译和代码生成功能
# 自动检测 [shader("xxx")] 属性并编译所有入口点
#
# 主要函数:
# add_shader_library() - 创建着色器库目标
@@ -74,6 +75,7 @@ endfunction()
# ============================================================================
#
# 将着色器编译和代码生成附加到现有的编译目标上
# 自动检测所有 [shader("xxx")] 入口点并编译
#
# 用法:
# # 方案1: 自动搜索着色器,指定引用路径
@@ -112,7 +114,7 @@ function(add_shader_library)
# 验证必需参数
if(NOT SHADER_TARGET)
message(FATAL_ERROR "add_shader_library: TARGET is required (existing compilation target)")
message(FATAL_ERROR "add_shader_library: TARGET is required")
endif()
# 验证目标存在
@@ -168,8 +170,6 @@ function(add_shader_library)
# 收集生成的文件
set(ALL_GENERATED_HEADERS "")
set(ALL_SPIRV_FILES "")
set(ALL_CUSTOM_TARGETS "")
# 处理每个着色器文件
foreach(SHADER_SOURCE ${ALL_SHADERS})
@@ -184,112 +184,54 @@ function(add_shader_library)
endif()
endif()
# 输出文件路径
set(VERT_SPV "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}.vert.spv")
set(FRAG_SPV "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}.frag.spv")
set(VERT_REFLECT_JSON "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}.vert.reflect.json")
set(FRAG_REFLECT_JSON "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}.frag.reflect.json")
# 每个着色器文件的输出目录
set(SPIRV_OUTPUT_DIR "${SHADER_INTERMEDIATE_DIR}/${SHADER_NAME_WE}")
set(GENERATED_HPP "${SHADER_OUTPUT_DIR}/${SHADER_NAME_WE}_bindings.hpp")
# 编译顶点着色器
# 步骤1: 编译着色器(自动检测所有入口点)
add_custom_command(
OUTPUT "${VERT_SPV}"
OUTPUT "${SPIRV_OUTPUT_DIR}/.compiled"
COMMAND ${CMAKE_COMMAND} -E make_directory "${SPIRV_OUTPUT_DIR}"
COMMAND $<TARGET_FILE:${MIRAI_SHADER_COMPILER_TARGET}>
"${SHADER_ABS}"
-o "${VERT_SPV}"
-e vertexMain
-s vertex
-o "${SPIRV_OUTPUT_DIR}"
${INCLUDE_ARGS}
${DEFINE_ARGS}
COMMAND ${CMAKE_COMMAND} -E touch "${SPIRV_OUTPUT_DIR}/.compiled"
DEPENDS "${SHADER_ABS}" ${MIRAI_SHADER_COMPILER_TARGET}
COMMENT "Compiling vertex shader: ${SHADER_NAME_WE}"
COMMENT "Compiling shader: ${SHADER_NAME_WE}"
VERBATIM COMMAND_EXPAND_LISTS
)
# 编译片段着色器
add_custom_command(
OUTPUT "${FRAG_SPV}"
COMMAND $<TARGET_FILE:${MIRAI_SHADER_COMPILER_TARGET}>
"${SHADER_ABS}"
-o "${FRAG_SPV}"
-e fragmentMain
-s fragment
${INCLUDE_ARGS}
${DEFINE_ARGS}
DEPENDS "${SHADER_ABS}" ${MIRAI_SHADER_COMPILER_TARGET}
COMMENT "Compiling fragment shader: ${SHADER_NAME_WE}"
VERBATIM COMMAND_EXPAND_LISTS
)
# 生成顶点反射数据
add_custom_command(
OUTPUT "${VERT_REFLECT_JSON}"
COMMAND $<TARGET_FILE:${MIRAI_SHADER_COMPILER_TARGET}>
"${SHADER_ABS}"
-e vertexMain
-s vertex
-r "${VERT_REFLECT_JSON}"
${INCLUDE_ARGS}
${DEFINE_ARGS}
DEPENDS "${SHADER_ABS}" ${MIRAI_SHADER_COMPILER_TARGET}
COMMENT "Generating vertex reflection: ${SHADER_NAME_WE}"
VERBATIM COMMAND_EXPAND_LISTS
)
# 生成片段反射数据
add_custom_command(
OUTPUT "${FRAG_REFLECT_JSON}"
COMMAND $<TARGET_FILE:${MIRAI_SHADER_COMPILER_TARGET}>
"${SHADER_ABS}"
-e fragmentMain
-s fragment
-r "${FRAG_REFLECT_JSON}"
${INCLUDE_ARGS}
${DEFINE_ARGS}
DEPENDS "${SHADER_ABS}" ${MIRAI_SHADER_COMPILER_TARGET}
COMMENT "Generating fragment reflection: ${SHADER_NAME_WE}"
VERBATIM COMMAND_EXPAND_LISTS
)
# 生成 C++ 绑定头文件
# 步骤2: 生成绑定头文件Python 脚本自己查找 spv 文件)
add_custom_command(
OUTPUT "${GENERATED_HPP}"
COMMAND "${PYTHON_EXECUTABLE}" "${MIRAI_SHADER_GENERATOR_SCRIPT}"
--spirv "${VERT_SPV}" "${FRAG_SPV}"
--reflection "${VERT_REFLECT_JSON}" "${FRAG_REFLECT_JSON}"
COMMAND ${PYTHON_EXECUTABLE} "${MIRAI_SHADER_GENERATOR_SCRIPT}"
--dir "${SPIRV_OUTPUT_DIR}"
--output "${GENERATED_HPP}"
--name "${SHADER_NAME_WE}"
--template-dir "${MIRAI_SHADER_TEMPLATE_DIR}"
DEPENDS "${VERT_SPV}" "${FRAG_SPV}" "${VERT_REFLECT_JSON}" "${FRAG_REFLECT_JSON}"
"${MIRAI_SHADER_GENERATOR_SCRIPT}"
DEPENDS "${SPIRV_OUTPUT_DIR}/.compiled" "${MIRAI_SHADER_GENERATOR_SCRIPT}"
COMMENT "Generating bindings: ${SHADER_NAME_WE}"
VERBATIM
)
list(APPEND ALL_GENERATED_HEADERS "${GENERATED_HPP}")
list(APPEND ALL_SPIRV_FILES "${VERT_SPV}" "${FRAG_SPV}")
endforeach()
# ========================================================================
# 附加到目标:添加生成目录到 include 路径
# ========================================================================
target_include_directories(${SHADER_TARGET} PUBLIC
$<BUILD_INTERFACE:${SHADER_OUTPUT_DIR}>
)
# ========================================================================
# 添加编译依赖
# ========================================================================
if(ALL_GENERATED_HEADERS)
add_custom_target(${SHADER_TARGET}_shaders DEPENDS ${ALL_GENERATED_HEADERS})
add_dependencies(${SHADER_TARGET} ${SHADER_TARGET}_shaders)
endif()
# ========================================================================
# 导出变量供其他 CMake 代码使用
# ========================================================================
# 导出变量
set(${SHADER_TARGET}_SHADER_HEADERS ${ALL_GENERATED_HEADERS} PARENT_SCOPE)
set(${SHADER_TARGET}_SHADER_SPIRV ${ALL_SPIRV_FILES} PARENT_SCOPE)
set(${SHADER_TARGET}_SHADER_OUTPUT_DIR ${SHADER_OUTPUT_DIR} PARENT_SCOPE)
message(STATUS "Added shaders to target: ${SHADER_TARGET}")
@@ -297,7 +239,4 @@ function(add_shader_library)
if(ALL_SHADERS)
message(STATUS " Shaders: ${ALL_SHADERS}")
endif()
if(SHADER_REF_PATHS)
message(STATUS " Ref paths: ${SHADER_REF_PATHS}")
endif()
endfunction()
endfunction()

View File

@@ -1,7 +1,8 @@
/**
* @file basic.slang
* @brief 基础 UI 着色器示例
* @description 展示 GUI 渲染中常用的 Uniform Buffer、纹理采样、顶点属性Push Constants
* @description 展示 GUI 渲染中常用的 Uniform Buffer、纹理采样、顶点属性Push Constants
* 以及多入口点支持(通过 [shader("xxx")] 属性自动检测)
*/
// ============================================================================
@@ -38,6 +39,18 @@ Texture2D<float4> uiTexture;
[[vk::binding(1, 1)]]
SamplerState linearSampler;
/**
* @brief 计算着色器用的输入纹理
*/
[[vk::binding(0, 2)]]
RWTexture2D<float4> computeInput;
/**
* @brief 计算着色器用的输出纹理
*/
[[vk::binding(1, 2)]]
RWTexture2D<float4> computeOutput;
// ============================================================================
// Push Constants
// ============================================================================
@@ -55,6 +68,18 @@ struct PushConstants {
[[vk::push_constant]]
PushConstants pc;
/**
* @brief 计算着色器的 Push Constants
*/
struct ComputePushConstants {
uint2 workgroupSize; // 工作组大小
uint2 workgroupCount; // 工作组数量
float intensity; // 强度参数
};
[[vk::push_constant]]
ComputePushConstants computePc;
// ============================================================================
// 顶点输入输出结构
// ============================================================================
@@ -114,4 +139,21 @@ float4 fragmentMain(VertexOutput input) : SV_Target {
color.a *= pc.opacity;
return color;
}
}
// ============================================================================
// 计算着色器
// ============================================================================
/**
* @brief 简单的图像处理计算着色器
* @description 演示多入口点支持 - 自动检测 [shader("compute")]
*/
[shader("compute")]
void computeMain(uint3 globalID : SV_DispatchThreadID) {
// 简单的图像处理:应用强度
float4 color = computeInput[globalID.xy];
color.rgb *= computePc.intensity;
color.a = 1.0f;
computeOutput[globalID.xy] = color;
}

View File

@@ -6,8 +6,7 @@ MIRAI 着色器绑定代码生成器
Usage:
python generate_shader_bindings.py \
--spirv shader.vert.spv shader.frag.spv \
--reflection shader.reflect.json \
--dir ./shader_intermediate \
--output generated/shader_name_bindings.hpp \
--name ShaderName
"""
@@ -493,25 +492,16 @@ def main():
epilog="""
示例:
python generate_shader_bindings.py \\
--spirv shader.vert.spv shader.frag.spv \\
--reflection shader.vert.reflect.json shader.frag.reflect.json \\
--output generated/shader_bindings.hpp \\
--name MyShader
--dir ./shader_intermediate/basic \\
--output generated/basic_bindings.hpp \\
--name basic
"""
)
parser.add_argument(
'--spirv', '-s',
nargs='+',
'--dir', '-d',
required=True,
help='SPIR-V 二进制文件 (.spv),可以指定多个文件'
)
parser.add_argument(
'--reflection', '-r',
nargs='+',
required=True,
help='反射 JSON 文件 (.json),可以指定多个文件(每个阶段一个)'
help='包含 SPIR-V 和反射文件的目录'
)
parser.add_argument(
@@ -550,45 +540,43 @@ def main():
print(f"Error: Template directory not found: {template_dir}")
return 1
# 解析反射数据(支持多个文件
all_reflections = []
reflection_paths = []
found_any_reflection = False
for reflection_path_str in args.reflection:
reflection_path = Path(reflection_path_str)
if reflection_path.exists():
found_any_reflection = True
if args.verbose:
print(f"Parsing reflection: {reflection_path}")
reflection = parse_reflection_json(reflection_path)
all_reflections.append(reflection)
reflection_paths.append(reflection_path)
else:
if args.verbose:
print(f"Warning: Reflection file not found: {reflection_path} (skipping)")
# 从目录中查找文件
shader_dir = Path(args.dir)
if not shader_dir.exists():
print(f"Info: Shader directory not found: {shader_dir}. No entry points, skipping.")
return 0
# 如果没有找到任何反射文件,说明着色器没有入口函数,跳过生成
if not found_any_reflection:
# 查找所有 spv 和 reflect.json 文件
spirv_files = sorted(shader_dir.glob("*.spv"))
reflect_files = sorted(shader_dir.glob("*.reflect.json"))
if args.verbose:
print(f"Found {len(spirv_files)} SPIR-V files")
print(f"Found {len(reflect_files)} reflection files")
# 如果没有反射文件,跳过生成
if not reflect_files:
print(f"Info: No reflection files found in {shader_dir}. Skipping binding generation.")
return 0
# 解析反射数据
all_reflections = []
for reflection_path in reflect_files:
if args.verbose:
print(f"Info: No reflection files found for {args.name}. Shader has no entry points, skipping binding generation.")
return 0 # 正常退出,不生成文件
else:
# 合并多个反射数据(取并集)
reflection = merge_reflections(all_reflections)
print(f"Parsing reflection: {reflection_path}")
reflection = parse_reflection_json(reflection_path)
all_reflections.append(reflection)
# 合并多个反射数据(取并集)
reflection = merge_reflections(all_reflections)
# 加载 SPIR-V 文件
stages: Dict[str, bytes] = {}
for spirv_file in args.spirv:
spirv_path = Path(spirv_file)
if not spirv_path.exists():
if args.verbose:
print(f"Warning: SPIR-V file not found: {spirv_path} (skipping)")
continue
for spirv_path in spirv_files:
stage = detect_stage_from_filename(spirv_path.name)
if stage is None:
print(f"Warning: Cannot detect shader stage from filename: {spirv_path.name}")
# 尝试使用文件名作为阶段名
if args.verbose:
print(f"Warning: Cannot detect shader stage from filename: {spirv_path.name}")
stage = spirv_path.stem.split('.')[-1] if '.' in spirv_path.stem else 'unknown'
if args.verbose:
@@ -623,4 +611,4 @@ def main():
if __name__ == '__main__':
exit(main())
exit(main())

View File

@@ -9,6 +9,7 @@
#include <nlohmann/json.hpp>
#include <fstream>
#include <iostream>
#include <sstream>
namespace mirai::tools {
@@ -63,6 +64,18 @@ static SlangStage to_slang_stage(shader_stage stage) {
return SLANG_STAGE_NONE;
}
std::optional<shader_stage> slang_stage_to_shader_stage(SlangStage stage) {
switch (stage) {
case SLANG_STAGE_VERTEX: return shader_stage::vertex;
case SLANG_STAGE_FRAGMENT: return shader_stage::fragment;
case SLANG_STAGE_COMPUTE: return shader_stage::compute;
case SLANG_STAGE_GEOMETRY: return shader_stage::geometry;
case SLANG_STAGE_HULL: return shader_stage::tessellation_control;
case SLANG_STAGE_DOMAIN: return shader_stage::tessellation_evaluation;
default: return std::nullopt;
}
}
// ============================================================================
// Implementation Class
// ============================================================================
@@ -578,6 +591,9 @@ struct shader_compiler::impl {
target_desc.format = SLANG_SPIRV;
target_desc.profile = global_session->findProfile("glsl_450");
session_desc.targets = &target_desc;
session_desc.targetCount = 1;
std::vector<std::string> include_path_strs;
std::vector<const char*> search_paths;
for (const auto& path : options.include_paths) {
@@ -614,26 +630,246 @@ struct shader_compiler::impl {
return result;
}
// 获取入口点列表 - 使用正确的 API
// 获取入口点列表
SlangInt32 entry_point_count = module->getDefinedEntryPointCount();
for (SlangInt32 i = 0; i < entry_point_count; ++i) {
Slang::ComPtr<slang::IEntryPoint> entry_point;
if (SLANG_SUCCEEDED(module->getDefinedEntryPoint(i, entry_point.writeRef()))) {
// 获取入口点的反射信息来获取名称
slang::IComponentType* entry_comp = entry_point.get();
if (entry_comp) {
// 尝试通过 getLayout 获取名称
slang::ProgramLayout* layout = entry_comp->getLayout();
if (layout) {
SlangUInt count = layout->getEntryPointCount();
if (count > 0) {
slang::EntryPointReflection* ep_ref = layout->getEntryPointByIndex(0);
if (ep_ref) {
const char* entry_name = ep_ref->getName();
if (entry_name && entry_name[0] != '\0') {
result.push_back(entry_name);
}
}
// 常见入口点名称模式
const char* common_names[] = {"main", "vertexMain", "fragmentMain", "computeMain",
"vertexShader", "fragmentShader", "computeShader"};
const SlangStage stages[] = {
SLANG_STAGE_VERTEX,
SLANG_STAGE_FRAGMENT,
SLANG_STAGE_COMPUTE,
SLANG_STAGE_GEOMETRY,
SLANG_STAGE_HULL,
SLANG_STAGE_DOMAIN
};
// 尝试每个可能的入口点名称和阶段的组合
for (const char* name : common_names) {
for (SlangStage stage : stages) {
Slang::ComPtr<slang::IEntryPoint> checked_entry;
SlangResult find_result = module->findAndCheckEntryPoint(
name,
stage,
checked_entry.writeRef(),
diagnostics_blob.writeRef()
);
if (SLANG_FAILED(find_result) || !checked_entry) {
continue;
}
// 验证并编译这个入口点来获取名称
std::vector<slang::IComponentType*> comps = {module, checked_entry.get()};
Slang::ComPtr<slang::IComponentType> composite;
if (SLANG_FAILED(session->createCompositeComponentType(
comps.data(),
static_cast<SlangInt>(comps.size()),
composite.writeRef(),
diagnostics_blob.writeRef()
))) {
continue;
}
Slang::ComPtr<slang::IComponentType> linked;
if (SLANG_FAILED(composite->link(linked.writeRef(), diagnostics_blob.writeRef()))) {
continue;
}
// 生成代码来触发 layout 生成
Slang::ComPtr<slang::IBlob> spirv_blob;
if (SLANG_FAILED(linked->getEntryPointCode(
0, 0, spirv_blob.writeRef(), diagnostics_blob.writeRef()))) {
continue;
}
// 获取 layout
slang::ProgramLayout* layout = linked->getLayout();
if (!layout || layout->getEntryPointCount() == 0) {
continue;
}
slang::EntryPointReflection* ep_ref = layout->getEntryPointByIndex(0);
if (!ep_ref) {
continue;
}
const char* entry_name = ep_ref->getName();
if (entry_name && entry_name[0] != '\0') {
// 检查是否已经添加
bool exists = false;
for (const auto& existing : result) {
if (existing == entry_name) {
exists = true;
break;
}
}
if (!exists) {
result.push_back(entry_name);
}
}
}
}
return result;
}
// 获取所有入口函数及其阶段
std::vector<entry_point_info> get_entry_points_with_stages(
const std::string& source,
const std::string& filename,
const compile_options& options
) {
std::vector<entry_point_info> result;
if (!global_session) {
return result;
}
// 创建 session
slang::SessionDesc session_desc = {};
slang::TargetDesc target_desc = {};
target_desc.format = SLANG_SPIRV;
target_desc.profile = global_session->findProfile("glsl_450");
session_desc.targets = &target_desc;
session_desc.targetCount = 1;
std::vector<std::string> include_path_strs;
std::vector<const char*> search_paths;
for (const auto& path : options.include_paths) {
include_path_strs.push_back(path.string());
search_paths.push_back(include_path_strs.back().c_str());
}
session_desc.searchPaths = search_paths.data();
session_desc.searchPathCount = static_cast<SlangInt>(search_paths.size());
std::vector<std::pair<std::string, std::string>> macro_strs;
std::vector<slang::PreprocessorMacroDesc> macros;
for (const auto& [name, value] : options.defines) {
macro_strs.emplace_back(name, value);
macros.push_back({macro_strs.back().first.c_str(), macro_strs.back().second.c_str()});
}
session_desc.preprocessorMacros = macros.data();
session_desc.preprocessorMacroCount = static_cast<SlangInt>(macros.size());
Slang::ComPtr<slang::ISession> session;
if (SLANG_FAILED(global_session->createSession(session_desc, session.writeRef()))) {
return result;
}
// 加载模块
Slang::ComPtr<slang::IBlob> diagnostics_blob;
slang::IModule* module = session->loadModuleFromSourceString(
filename.c_str(),
filename.c_str(),
source.c_str(),
diagnostics_blob.writeRef()
);
if (!module) {
return result;
}
// 获取所有定义的入口点数量
SlangInt32 entry_point_count = module->getDefinedEntryPointCount();
if (entry_point_count == 0) {
return result;
}
// 常见入口点名称模式
const char* common_names[] = {"main", "vertexMain", "fragmentMain", "computeMain",
"vertexShader", "fragmentShader", "computeShader"};
const SlangStage stages[] = {
SLANG_STAGE_VERTEX,
SLANG_STAGE_FRAGMENT,
SLANG_STAGE_COMPUTE,
SLANG_STAGE_GEOMETRY,
SLANG_STAGE_HULL,
SLANG_STAGE_DOMAIN
};
// 尝试每个可能的入口点名称和阶段的组合
for (const char* name : common_names) {
for (SlangStage stage : stages) {
// 如果已经找到所有入口点,提前退出
if (static_cast<SlangInt32>(result.size()) >= entry_point_count) {
break;
}
Slang::ComPtr<slang::IEntryPoint> checked_entry;
SlangResult find_result = module->findAndCheckEntryPoint(
name,
stage,
checked_entry.writeRef(),
diagnostics_blob.writeRef()
);
if (SLANG_FAILED(find_result) || !checked_entry) {
continue;
}
// 验证并编译这个入口点来获取名称
std::vector<slang::IComponentType*> comps = {module, checked_entry.get()};
Slang::ComPtr<slang::IComponentType> composite;
if (SLANG_FAILED(session->createCompositeComponentType(
comps.data(),
static_cast<SlangInt>(comps.size()),
composite.writeRef(),
diagnostics_blob.writeRef()
))) {
continue;
}
Slang::ComPtr<slang::IComponentType> linked;
if (SLANG_FAILED(composite->link(linked.writeRef(), diagnostics_blob.writeRef()))) {
continue;
}
// 生成代码来触发 layout 生成
Slang::ComPtr<slang::IBlob> spirv_blob;
if (SLANG_FAILED(linked->getEntryPointCode(
0, 0, spirv_blob.writeRef(), diagnostics_blob.writeRef()))) {
continue;
}
// 获取 layout
slang::ProgramLayout* layout = linked->getLayout();
if (!layout || layout->getEntryPointCount() == 0) {
continue;
}
slang::EntryPointReflection* ep_ref = layout->getEntryPointByIndex(0);
if (!ep_ref) {
continue;
}
const char* entry_name = ep_ref->getName();
SlangStage slang_stage = ep_ref->getStage();
if (entry_name && entry_name[0] != '\0') {
// 检查是否已经添加
bool exists = false;
for (const auto& existing : result) {
if (existing.name == entry_name) {
exists = true;
break;
}
}
if (!exists) {
auto shader_stage_opt = slang_stage_to_shader_stage(slang_stage);
if (shader_stage_opt) {
entry_point_info info;
info.name = entry_name;
info.stage = *shader_stage_opt;
result.push_back(info);
}
}
}
@@ -642,6 +878,31 @@ struct shader_compiler::impl {
return result;
}
// 编译所有入口点
std::vector<compile_result> compile_all_entries(
const std::string& source,
const std::string& filename,
const compile_options& options
) {
std::vector<compile_result> results;
// 获取所有入口点及其阶段
auto entry_points = get_entry_points_with_stages(source, filename, options);
for (const auto& ep_info : entry_points) {
compile_options ep_options = options;
ep_options.entry_point = ep_info.name;
ep_options.stage = ep_info.stage;
compile_result result = compile(source, filename, ep_options);
result.entry_point = ep_info.name;
result.stage = ep_info.stage;
results.push_back(result);
}
return results;
}
};
// ============================================================================
@@ -751,6 +1012,57 @@ std::vector<std::string> shader_compiler::get_entry_points(
return impl_->get_entry_points(source, path.filename().string(), options);
}
std::vector<entry_point_info> shader_compiler::get_entry_points_with_stages(
const std::filesystem::path& path,
const compile_options& options
) {
// 首先检查文件是否存在
if (!std::filesystem::exists(path)) {
return {};
}
// 读取文件内容
std::ifstream file(path);
if (!file) {
return {};
}
std::ostringstream ss;
ss << file.rdbuf();
std::string source = ss.str();
return impl_->get_entry_points_with_stages(source, path.filename().string(), options);
}
std::vector<compile_result> shader_compiler::compile_file_all_entries(
const std::filesystem::path& path,
const compile_options& options
) {
// 首先检查文件是否存在
if (!std::filesystem::exists(path)) {
return {};
}
// 读取文件内容
std::ifstream file(path);
if (!file) {
return {};
}
std::ostringstream ss;
ss << file.rdbuf();
std::string source = ss.str();
// 添加文件所在目录到包含路径
compile_options modified_options = options;
modified_options.include_paths.insert(
modified_options.include_paths.begin(),
path.parent_path()
);
return impl_->compile_all_entries(source, path.filename().string(), modified_options);
}
bool shader_compiler::is_available() const noexcept {
return impl_ && impl_->global_session;
}

View File

@@ -7,6 +7,7 @@
#include <filesystem>
#include <memory>
#include <optional>
#include <slang.h>
#include <string>
#include <vector>
@@ -34,11 +35,18 @@ enum class shader_stage {
*/
[[nodiscard]] const char* shader_stage_to_string(shader_stage stage);
/**
* @brief 将 Slang 阶段转换为 shader_stage
*/
[[nodiscard]] std::optional<shader_stage> slang_stage_to_shader_stage(SlangStage stage);
/**
* @brief 编译结果
*/
struct compile_result {
bool success{false};
std::string entry_point; // 入口点名称
shader_stage stage; // 着色器阶段
std::vector<uint32_t> spirv;
std::string reflection_json;
std::string error_message;
@@ -59,6 +67,14 @@ struct compile_options {
bool emit_reflection{false};
};
/**
* @brief 入口点信息(包含名称和阶段)
*/
struct entry_point_info {
std::string name;
shader_stage stage;
};
/**
* @brief 着色器编译器
*
@@ -136,6 +152,28 @@ public:
const compile_options& options
);
/**
* @brief 获取模块中的所有入口函数及其阶段
* @param path 输入文件路径
* @param options 编译选项
* @return 入口点信息列表
*/
[[nodiscard]] std::vector<entry_point_info> get_entry_points_with_stages(
const std::filesystem::path& path,
const compile_options& options
);
/**
* @brief 编译着色器文件(自动检测所有入口点)
* @param path 输入文件路径
* @param options 编译选项
* @return 编译结果列表(每个入口点一个结果)
*/
[[nodiscard]] std::vector<compile_result> compile_file_all_entries(
const std::filesystem::path& path,
const compile_options& options
);
/**
* @brief 检查编译器是否可用
*/
@@ -151,4 +189,4 @@ private:
std::unique_ptr<impl> impl_;
};
} // namespace mirai::tools
} // namespace mirai::tools

View File

@@ -1,5 +1,6 @@
// tools/shader_compile/main.cpp
// MIRAI 着色器编译器命令行入口
// 自动检测所有 [shader("xxx")] 入口点并编译
#include "compiler.hpp"
@@ -12,47 +13,47 @@ namespace {
void print_usage(const char* program_name) {
std::cout << R"(
MIRAI Shader Compiler
MIRAI Shader Compiler - Auto Mode
Usage: )" << program_name << R"( [options] <input.slang>
Options:
-o, --output <file> Output SPIR-V file (.spv)
-I, --include <path> Add include search path
-D, --define <MACRO> Define preprocessor macro (e.g., -DDEBUG or -DVALUE=1)
--entry <name> Entry point name (default: main)
--stage <stage> Shader stage: vertex, fragment, compute, geometry,
tesscontrol, tesseval (or short: vert, frag, comp, etc.)
--emit-spirv Output SPIR-V (default if -o is specified)
--emit-reflection Output reflection JSON
-r, --reflection <file> Output reflection JSON to file
-g Generate debug info
-O0 Disable optimization
--deps <file> Write dependency file (Make format)
-h, --help Show this help
-v, --version Show version info
-o, --output <dir> Output directory for compiled shaders
-I, --include <path> Add include search path
-D, --define <MACRO> Define preprocessor macro (e.g., -DDEBUG or -DVALUE=1)
--prefix <name> Prefix for output file names (default: shader name)
-g Generate debug info
-O0 Disable optimization
--list Only list entry points, don't compile
-h, --help Show this help
-v, --version Show version info
Examples:
)" << program_name << R"( shader.slang -o shader.vert.spv --entry vertexMain --stage vertex
)" << program_name << R"( shader.slang -o shader.frag.spv --entry fragmentMain --stage fragment
)" << program_name << R"( shader.slang -r shader.reflect.json --emit-reflection
)" << program_name << R"( shader.slang -o ./compiled_shaders
)" << program_name << R"( shader.slang -o ./shaders --prefix myshader
)" << program_name << R"( shader.slang --list
Output Format:
Automatically compiles all entry points found via [shader("xxx")] attributes.
Output files: <prefix>.<entry_point>.<stage>.spv
<prefix>.<entry_point>.<stage>.reflect.json
)";
}
void print_version() {
std::cout << "MIRAI Shader Compiler v1.0.0\n";
std::cout << "MIRAI Shader Compiler v2.0.0 (Auto Mode)\n";
std::cout << "Based on Slang Shader Language\n";
}
struct command_line_args {
std::filesystem::path input_path;
std::filesystem::path output_path;
std::filesystem::path reflection_path;
std::filesystem::path deps_path;
std::filesystem::path output_dir;
std::string prefix;
mirai::tools::compile_options options;
bool show_help = false;
bool show_version = false;
bool list_only = false;
};
bool parse_args(int argc, char* argv[], command_line_args& args) {
@@ -72,16 +73,7 @@ bool parse_args(int argc, char* argv[], command_line_args& args) {
std::cerr << "Error: -o requires an argument\n";
return false;
}
args.output_path = argv[i];
args.options.emit_spirv = true;
}
else if (arg == "-r" || arg == "--reflection") {
if (++i >= argc) {
std::cerr << "Error: -r requires an argument\n";
return false;
}
args.reflection_path = argv[i];
args.options.emit_reflection = true;
args.output_dir = argv[i];
}
else if (arg == "-I" || arg == "--include") {
if (++i >= argc) {
@@ -115,30 +107,15 @@ bool parse_args(int argc, char* argv[], command_line_args& args) {
args.options.defines.emplace_back(def, "1");
}
}
else if (arg == "--entry" || arg == "-e") {
else if (arg == "--prefix") {
if (++i >= argc) {
std::cerr << "Error: --entry requires an argument\n";
std::cerr << "Error: --prefix requires an argument\n";
return false;
}
args.options.entry_point = argv[i];
args.prefix = argv[i];
}
else if (arg == "--stage" || arg == "-s") {
if (++i >= argc) {
std::cerr << "Error: --stage requires an argument\n";
return false;
}
auto stage = mirai::tools::parse_shader_stage(argv[i]);
if (!stage) {
std::cerr << "Error: Invalid shader stage: " << argv[i] << "\n";
return false;
}
args.options.stage = stage;
}
else if (arg == "--emit-spirv") {
args.options.emit_spirv = true;
}
else if (arg == "--emit-reflection") {
args.options.emit_reflection = true;
else if (arg == "--list") {
args.list_only = true;
}
else if (arg == "-g") {
args.options.generate_debug_info = true;
@@ -146,13 +123,6 @@ bool parse_args(int argc, char* argv[], command_line_args& args) {
else if (arg == "-O0") {
args.options.optimize = false;
}
else if (arg == "--deps") {
if (++i >= argc) {
std::cerr << "Error: --deps requires an argument\n";
return false;
}
args.deps_path = argv[i];
}
else if (!arg.starts_with("-")) {
if (!args.input_path.empty()) {
std::cerr << "Error: Multiple input files specified\n";
@@ -169,6 +139,18 @@ bool parse_args(int argc, char* argv[], command_line_args& args) {
return true;
}
std::string stage_to_suffix(mirai::tools::shader_stage stage) {
switch (stage) {
case mirai::tools::shader_stage::vertex: return "vert";
case mirai::tools::shader_stage::fragment: return "frag";
case mirai::tools::shader_stage::compute: return "comp";
case mirai::tools::shader_stage::geometry: return "geom";
case mirai::tools::shader_stage::tessellation_control: return "tesc";
case mirai::tools::shader_stage::tessellation_evaluation: return "tese";
}
return "unknown";
}
bool write_spirv(const std::filesystem::path& path, const std::vector<uint32_t>& spirv) {
std::ofstream file(path, std::ios::binary);
if (!file) {
@@ -191,25 +173,6 @@ bool write_text(const std::filesystem::path& path, const std::string& text) {
return file.good();
}
bool write_deps(
const std::filesystem::path& path,
const std::filesystem::path& target,
const std::vector<std::filesystem::path>& deps
) {
std::ofstream file(path);
if (!file) {
std::cerr << "Error: Failed to open deps file: " << path << "\n";
return false;
}
file << target.string() << ":";
for (const auto& dep : deps) {
file << " \\\n " << dep.string();
}
file << "\n";
return file.good();
}
} // anonymous namespace
int main(int argc, char* argv[]) {
@@ -235,12 +198,6 @@ int main(int argc, char* argv[]) {
return 1;
}
// 如果没有指定输出,至少需要指定一种输出
if (args.output_path.empty() && args.reflection_path.empty()) {
std::cerr << "Error: No output specified. Use -o for SPIR-V or -r for reflection JSON\n";
return 1;
}
// 创建编译器
mirai::tools::shader_compiler compiler;
if (!compiler.is_available()) {
@@ -248,60 +205,74 @@ int main(int argc, char* argv[]) {
return 1;
}
// 编译
auto result = compiler.compile_file(args.input_path, args.options);
// 获取所有入口点及其阶段
auto entry_points = compiler.get_entry_points_with_stages(args.input_path, args.options);
if (!result.success) {
// 检查是否是入口点未找到的错误(这对于纯模块文件是正常的)
if (result.error_message.find("Entry point") != std::string::npos &&
result.error_message.find("not found") != std::string::npos) {
// 模块文件没有入口点是正常行为,跳过编译
std::cout << "Warning: No entry point found (module file?), skipping: " << args.input_path << "\n";
return 0;
}
std::cerr << "Compilation failed:\n" << result.error_message << "\n";
return 1;
if (entry_points.empty()) {
std::cout << "Warning: No entry points found in " << args.input_path << "\n";
std::cout << " (This may be a module file with no [shader(...)] attributes)\n";
return 0;
}
// 写入 SPIR-V
if (!args.output_path.empty() && !result.spirv.empty()) {
// 创建输出目录
auto parent = args.output_path.parent_path();
if (!parent.empty()) {
std::filesystem::create_directories(parent);
}
if (!write_spirv(args.output_path, result.spirv)) {
return 1;
}
std::cout << "Generated: " << args.output_path << " (" << result.spirv.size() * 4 << " bytes)\n";
// 列出入口点
std::cout << "Found " << entry_points.size() << " entry point(s) in " << args.input_path.filename() << ":\n";
for (const auto& ep : entry_points) {
std::cout << " - " << ep.name << " [" << mirai::tools::shader_stage_to_string(ep.stage) << "]\n";
}
// 写入反射 JSON
if (!args.reflection_path.empty() && !result.reflection_json.empty()) {
auto parent = args.reflection_path.parent_path();
if (!parent.empty()) {
std::filesystem::create_directories(parent);
}
if (!write_text(args.reflection_path, result.reflection_json)) {
return 1;
}
std::cout << "Generated: " << args.reflection_path << "\n";
// 如果只是列出,不编译
if (args.list_only) {
return 0;
}
// 写入依赖文件
if (!args.deps_path.empty()) {
auto target = args.output_path.empty() ? args.reflection_path : args.output_path;
// 添加输入文件作为依赖
auto deps = result.dependencies;
deps.insert(deps.begin(), args.input_path);
if (!write_deps(args.deps_path, target, deps)) {
return 1;
}
// 确保输出目录存在
if (args.output_dir.empty()) {
args.output_dir = ".";
}
if (!std::filesystem::exists(args.output_dir)) {
std::filesystem::create_directories(args.output_dir);
}
return 0;
// 确定前缀
if (args.prefix.empty()) {
args.prefix = args.input_path.stem().string();
}
// 编译所有入口点
args.options.emit_spirv = true;
args.options.emit_reflection = true;
auto results = compiler.compile_file_all_entries(args.input_path, args.options);
int success_count = 0;
int fail_count = 0;
for (const auto& result : results) {
if (!result.success) {
std::cerr << "Error compiling " << result.entry_point << ": " << result.error_message << "\n";
fail_count++;
continue;
}
std::string suffix = stage_to_suffix(result.stage);
std::string base_name = args.prefix + "." + result.entry_point + "." + suffix;
// 写入 SPIR-V
std::filesystem::path spv_path = args.output_dir / (base_name + ".spv");
if (write_spirv(spv_path, result.spirv)) {
std::cout << "Generated: " << spv_path << " (" << result.spirv.size() * 4 << " bytes)\n";
}
// 写入反射 JSON
std::filesystem::path reflect_path = args.output_dir / (base_name + ".reflect.json");
if (write_text(reflect_path, result.reflection_json)) {
std::cout << "Generated: " << reflect_path << "\n";
}
success_count++;
}
std::cout << "\nSummary: " << success_count << " succeeded, " << fail_count << " failed\n";
return fail_count > 0 ? 1 : 0;
}