Files
mirai/tools/shader_compile/main.cpp
2026-01-04 09:58:08 +08:00

369 lines
12 KiB
C++

// tools/shader_compile/main.cpp
// MIRAI 着色器反射工具命令行入口
// 从 SPIR-V 文件提取反射信息并生成绑定代码
#include "compiler.hpp"

#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <sstream>
#include <string>
#include <string_view>
#include <system_error>
#include <vector>
namespace {
/// Prints command-line help to standard output.
/// @param program_name Name shown in the usage line and examples (normally argv[0]).
///        A null pointer (possible when argc == 0) falls back to "mirai_shader_compile".
void print_usage(const char* program_name) {
    // Streaming a null const char* is undefined behavior, so substitute a fixed name.
    const char* const shown_name = program_name != nullptr ? program_name : "mirai_shader_compile";
    std::cout << R"(
MIRAI Shader Reflection Tool
Usage: )" << shown_name << R"( [options]
Options:
--dir <path> Directory containing SPIR-V files
--file <path> Single SPIR-V file to process (ignored when --dir is given)
-o, --output <path> Output file path (header or JSON)
--name <name> Shader group name (for header generation; default: output file stem without _bindings)
--json Output JSON reflection data
--header Output C++ header with bindings (default)
-h, --help Show this help
-v, --version Show version info
Examples:
)" << shown_name << R"( --dir ./shaders --output bindings.hpp --name my_shader
)" << shown_name << R"( --file shader.vert.spv --output shader.json --json
)" << shown_name << R"( --dir ./compiled --output reflection.json --json
Output Formats:
Header (.hpp): C++ header with embedded SPIR-V and reflection data
JSON (.json): Structured reflection data for external tools
)";
}
/// Writes the tool's two-line version banner to standard output.
void print_version() {
    constexpr const char* banner =
        "MIRAI Shader Reflection Tool v1.0.0\n"
        "Based on SPIRV-Cross\n";
    std::cout << banner;
}
/// Options collected from the command line by parse_args().
struct command_line_args {
    std::filesystem::path input_dir;    // --dir: directory to reflect
    std::filesystem::path input_file;   // --file: single SPIR-V file to reflect
    std::filesystem::path output_path;  // --output / -o: destination file
    std::string name;                   // --name: shader group name for header output
    bool output_json{false};            // set by --json (which also clears output_header)
    bool output_header{true};           // set by --header (which also clears output_json); default mode
    bool show_help{false};              // -h / --help
    bool show_version{false};           // -v / --version
};
/// Parses argv[1..argc) into `args`.
/// Returns false after printing a diagnostic to stderr when an option is unknown or a
/// value-taking option (--dir, --file, --output/-o, --name) has no following argument.
/// -h/--help and -v/--version set their flag and stop parsing immediately, so any
/// arguments after them are not examined. --json and --header are mutually exclusive;
/// whichever appears last wins.
bool parse_args(int argc, char* argv[], command_line_args& args) {
    int index = 1;

    // Advances to the argument after the current option and stores it in `out`.
    // `label` is the option spelling used in the error message.
    auto take_value = [&](const char* label, auto& out) -> bool {
        ++index;
        if (index >= argc) {
            std::cerr << "Error: " << label << " requires an argument\n";
            return false;
        }
        out = argv[index];
        return true;
    };

    for (; index < argc; ++index) {
        const std::string option = argv[index];

        if (option == "-h" || option == "--help") {
            args.show_help = true;
            return true;
        }
        if (option == "-v" || option == "--version") {
            args.show_version = true;
            return true;
        }

        bool ok = true;
        if (option == "--dir") {
            ok = take_value("--dir", args.input_dir);
        } else if (option == "--file") {
            ok = take_value("--file", args.input_file);
        } else if (option == "--output" || option == "-o") {
            ok = take_value("--output", args.output_path);
        } else if (option == "--name") {
            ok = take_value("--name", args.name);
        } else if (option == "--json") {
            args.output_header = false;
            args.output_json = true;
        } else if (option == "--header") {
            args.output_json = false;
            args.output_header = true;
        } else {
            std::cerr << "Error: Unknown option: " << option << "\n";
            return false;
        }

        if (!ok) {
            return false;
        }
    }
    return true;
}
/// Writes `text` to `path`, replacing any existing file.
/// Missing parent directories are created first. Returns true only if the file was
/// opened, written, flushed and closed without error. Every failure is reported on
/// stderr before returning false.
bool write_text(const std::filesystem::path& path, const std::string& text) {
    // --output may point into a directory that does not exist yet. A failure here is
    // deliberately not reported on its own: the open below fails and reports it.
    if (const auto parent = path.parent_path(); !parent.empty()) {
        std::error_code ec;
        std::filesystem::create_directories(parent, ec);
    }

    std::ofstream file(path);
    if (!file) {
        std::cerr << "Error: Failed to open output file: " << path << "\n";
        return false;
    }
    file << text;

    // close() flushes the buffer; write errors may only become visible at this point.
    file.close();
    if (!file) {
        std::cerr << "Error: Failed to write output file: " << path << "\n";
        return false;
    }
    return true;
}
/// Reads the whole file at `path` into a byte vector.
/// Returns an empty vector if the file cannot be opened, its size cannot be determined,
/// or it cannot be read completely. Callers cannot distinguish these cases from an
/// empty file.
std::vector<uint8_t> read_binary_file(const std::filesystem::path& path) {
    std::ifstream file(path, std::ios::binary | std::ios::ate);
    if (!file) {
        return {};
    }

    // Opened with std::ios::ate, so tellg() reports the file size. It returns -1 on
    // failure, which must not reach the vector constructor (it would become a huge size_t).
    const std::streamoff size = file.tellg();
    if (size <= 0) {
        return {};
    }

    file.seekg(0, std::ios::beg);
    std::vector<uint8_t> data(static_cast<std::size_t>(size));
    if (!file.read(reinterpret_cast<char*>(data.data()), static_cast<std::streamsize>(size))) {
        // Short read or I/O error: do not hand back partially filled (zero-padded) data.
        return {};
    }
    return data;
}
std::string generate_header(
const std::string& name,
const std::vector<mirai::tools::reflection_result>& results
) {
std::ostringstream ss;
// Header guard
std::string guard_name = name;
for (auto& c : guard_name) {
c = static_cast<char>(std::toupper(static_cast<unsigned char>(c)));
}
ss << "// Auto-generated shader bindings for: " << name << "\n";
ss << "// DO NOT EDIT - Generated by mirai_shader_compile\n\n";
ss << "#pragma once\n\n";
ss << "#include <array>\n";
ss << "#include <cstdint>\n";
ss << "#include <span>\n";
ss << "#include <string_view>\n\n";
ss << "namespace mirai::shaders {\n\n";
ss << "namespace " << name << " {\n\n";
// Generate SPIR-V data for each shader
for (const auto& result : results) {
if (!result.success) {
continue;
}
const auto& reflection = result.reflection;
std::string stage_name = mirai::tools::shader_stage_to_string(reflection.stage);
ss << "// " << stage_name << " shader\n";
ss << "namespace " << stage_name << " {\n\n";
ss << "constexpr std::string_view entry_point = \"" << reflection.entry_point << "\";\n\n";
// Embed SPIR-V bytecode
if (!result.spirv_data.empty()) {
ss << "// SPIR-V bytecode (" << result.spirv_data.size() << " words, "
<< (result.spirv_data.size() * 4) << " bytes)\n";
ss << "inline constexpr std::array<uint32_t, " << result.spirv_data.size() << "> spirv_code = {\n";
// Write SPIR-V data in rows of 8 values
for (size_t i = 0; i < result.spirv_data.size(); ++i) {
if (i % 8 == 0) {
ss << " ";
}
ss << "0x" << std::hex << std::setfill('0') << std::setw(8) << result.spirv_data[i];
if (i + 1 < result.spirv_data.size()) {
ss << ",";
}
if ((i + 1) % 8 == 0 || i + 1 == result.spirv_data.size()) {
ss << "\n";
} else {
ss << " ";
}
}
ss << std::dec; // Reset to decimal
ss << "};\n\n";
// Provide a span view for easy access
ss << "inline constexpr std::span<const uint32_t> spirv() {\n";
ss << " return spirv_code;\n";
ss << "}\n\n";
// Provide byte view
ss << "inline const uint8_t* spirv_bytes() {\n";
ss << " return reinterpret_cast<const uint8_t*>(spirv_code.data());\n";
ss << "}\n\n";
ss << "inline constexpr size_t spirv_size() {\n";
ss << " return spirv_code.size() * sizeof(uint32_t);\n";
ss << "}\n\n";
}
// Uniform buffer info
if (!reflection.uniform_buffers.empty()) {
ss << "// Uniform Buffers\n";
for (const auto& ub : reflection.uniform_buffers) {
ss << "struct " << ub.name << " {\n";
ss << " static constexpr uint32_t set = " << ub.set << ";\n";
ss << " static constexpr uint32_t binding = " << ub.binding << ";\n";
ss << " static constexpr uint32_t size = " << ub.size << ";\n";
ss << "};\n\n";
}
}
// Push constant info
if (!reflection.push_constants.empty()) {
ss << "// Push Constants\n";
for (const auto& pc : reflection.push_constants) {
ss << "struct " << pc.name << " {\n";
ss << " static constexpr uint32_t size = " << pc.size << ";\n";
ss << "};\n\n";
}
}
// Sampler info
if (!reflection.samplers.empty()) {
ss << "// Samplers\n";
for (const auto& sampler : reflection.samplers) {
ss << "struct " << sampler.name << "_info {\n";
ss << " static constexpr uint32_t set = " << sampler.set << ";\n";
ss << " static constexpr uint32_t binding = " << sampler.binding << ";\n";
ss << " static constexpr std::string_view dimension = \"" << sampler.dimension << "\";\n";
ss << "};\n\n";
}
}
ss << "} // namespace " << stage_name << "\n\n";
}
// Reflection JSON
ss << "// Combined reflection data\n";
ss << "constexpr std::string_view reflection_json = R\"JSON(\n";
std::vector<mirai::tools::shader_reflection> reflections;
for (const auto& result : results) {
if (result.success) {
reflections.push_back(result.reflection);
}
}
ss << mirai::tools::reflections_to_json(reflections);
ss << "\n)JSON\";\n\n";
ss << "} // namespace " << name << "\n\n";
ss << "} // namespace mirai::shaders\n";
return ss.str();
}
} // anonymous namespace
int main(int argc, char* argv[]) {
command_line_args args;
if (!parse_args(argc, argv, args)) {
return 1;
}
if (args.show_help) {
print_usage(argv[0]);
return 0;
}
if (args.show_version) {
print_version();
return 0;
}
if (args.input_dir.empty() && args.input_file.empty()) {
std::cerr << "Error: No input specified. Use --dir or --file.\n";
print_usage(argv[0]);
return 1;
}
if (args.output_path.empty()) {
std::cerr << "Error: No output path specified. Use --output.\n";
return 1;
}
// Create reflector
mirai::tools::spirv_reflector reflector;
std::vector<mirai::tools::reflection_result> results;
// Process input
if (!args.input_dir.empty()) {
results = reflector.reflect_directory(args.input_dir);
} else if (!args.input_file.empty()) {
auto result = reflector.reflect_file(args.input_file);
results.push_back(result);
}
if (results.empty()) {
std::cout << "Warning: No SPIR-V files found to process.\n";
return 0;
}
// Check for errors
int error_count = 0;
for (const auto& result : results) {
if (!result.success) {
std::cerr << "Error: " << result.error_message << "\n";
error_count++;
}
}
// Generate output
std::string output;
if (args.output_json) {
std::vector<mirai::tools::shader_reflection> reflections;
for (const auto& result : results) {
if (result.success) {
reflections.push_back(result.reflection);
}
}
output = mirai::tools::reflections_to_json(reflections);
} else {
// Header output
if (args.name.empty()) {
// Try to derive name from output path
args.name = args.output_path.stem().string();
// Remove _bindings suffix if present
if (args.name.ends_with("_bindings")) {
args.name = args.name.substr(0, args.name.length() - 9);
}
}
output = generate_header(args.name, results);
}
// Write output
if (!write_text(args.output_path, output)) {
return 1;
}
std::cout << "Generated: " << args.output_path << "\n";
int success_count = static_cast<int>(results.size()) - error_count;
std::cout << "Processed " << success_count << " shader(s)";
if (error_count > 0) {
std::cout << " (" << error_count << " error(s))";
}
std::cout << "\n";
return error_count > 0 ? 1 : 0;
}