Files
mirai/tools/shader_compile/compiler.cpp

1079 lines
42 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
// tools/shader_compile/compiler.cpp
// MIRAI Shader Compiler Implementation
#include "compiler.hpp"

#include <fstream>
#include <iostream>
#include <sstream>
#include <utility>

#include <slang.h>
#include <slang-com-ptr.h>
#include <nlohmann/json.hpp>
namespace mirai::tools {
// ============================================================================
// Helper Functions
// ============================================================================
std::optional<shader_stage> parse_shader_stage(const std::string& str) {
if (str == "vertex" || str == "vert" || str == "vs") {
return shader_stage::vertex;
}
if (str == "fragment" || str == "frag" || str == "pixel" || str == "fs" || str == "ps") {
return shader_stage::fragment;
}
if (str == "compute" || str == "comp" || str == "cs") {
return shader_stage::compute;
}
if (str == "geometry" || str == "geom" || str == "gs") {
return shader_stage::geometry;
}
if (str == "tesscontrol" || str == "tesc" || str == "hull" || str == "hs") {
return shader_stage::tessellation_control;
}
if (str == "tesseval" || str == "tese" || str == "domain" || str == "ds") {
return shader_stage::tessellation_evaluation;
}
return std::nullopt;
}
// Canonical long name for a shader_stage (the inverse of the first spelling
// accepted by parse_shader_stage). Out-of-range values map to "unknown".
const char* shader_stage_to_string(shader_stage stage) {
    const char* name = "unknown";
    switch (stage) {
    case shader_stage::vertex:
        name = "vertex";
        break;
    case shader_stage::fragment:
        name = "fragment";
        break;
    case shader_stage::compute:
        name = "compute";
        break;
    case shader_stage::geometry:
        name = "geometry";
        break;
    case shader_stage::tessellation_control:
        name = "tesscontrol";
        break;
    case shader_stage::tessellation_evaluation:
        name = "tesseval";
        break;
    }
    return name;
}
// Translate a shader_stage into the corresponding Slang stage constant.
// Tessellation stages use Slang's HLSL names (hull / domain). Values outside
// the enum yield SLANG_STAGE_NONE.
static SlangStage to_slang_stage(shader_stage stage) {
    SlangStage slang_stage = SLANG_STAGE_NONE;
    switch (stage) {
    case shader_stage::vertex:
        slang_stage = SLANG_STAGE_VERTEX;
        break;
    case shader_stage::fragment:
        slang_stage = SLANG_STAGE_FRAGMENT;
        break;
    case shader_stage::compute:
        slang_stage = SLANG_STAGE_COMPUTE;
        break;
    case shader_stage::geometry:
        slang_stage = SLANG_STAGE_GEOMETRY;
        break;
    case shader_stage::tessellation_control:
        slang_stage = SLANG_STAGE_HULL;
        break;
    case shader_stage::tessellation_evaluation:
        slang_stage = SLANG_STAGE_DOMAIN;
        break;
    }
    return slang_stage;
}
// Inverse of to_slang_stage: map a Slang stage back to a shader_stage.
// Stages with no equivalent here (e.g. anything other than the six
// rasterisation/compute stages) yield std::nullopt.
std::optional<shader_stage> slang_stage_to_shader_stage(SlangStage stage) {
    std::optional<shader_stage> mapped;
    switch (stage) {
    case SLANG_STAGE_VERTEX:
        mapped = shader_stage::vertex;
        break;
    case SLANG_STAGE_FRAGMENT:
        mapped = shader_stage::fragment;
        break;
    case SLANG_STAGE_COMPUTE:
        mapped = shader_stage::compute;
        break;
    case SLANG_STAGE_GEOMETRY:
        mapped = shader_stage::geometry;
        break;
    case SLANG_STAGE_HULL:
        mapped = shader_stage::tessellation_control;
        break;
    case SLANG_STAGE_DOMAIN:
        mapped = shader_stage::tessellation_evaluation;
        break;
    default:
        break;
    }
    return mapped;
}
// ============================================================================
// Implementation Class
// ============================================================================
struct shader_compiler::impl {
Slang::ComPtr<slang::IGlobalSession> global_session;
impl() {
SlangGlobalSessionDesc desc = {};
slang::createGlobalSession(&desc, global_session.writeRef());
}
~impl() = default;
compile_result compile(
const std::string& source,
const std::string& filename,
const compile_options& options
) {
compile_result result;
if (!global_session) {
result.error_message = "Slang global session not initialized";
return result;
}
// Create session options
slang::SessionDesc session_desc = {};
// Set target to SPIR-V
slang::TargetDesc target_desc = {};
target_desc.format = SLANG_SPIRV;
target_desc.profile = global_session->findProfile("glsl_450");
if (options.generate_debug_info) {
target_desc.flags |= SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY;
}
session_desc.targets = &target_desc;
session_desc.targetCount = 1;
// Store include paths strings to keep them alive
std::vector<std::string> include_path_strs;
std::vector<const char*> search_paths;
for (const auto& path : options.include_paths) {
include_path_strs.push_back(path.string());
search_paths.push_back(include_path_strs.back().c_str());
}
session_desc.searchPaths = search_paths.data();
session_desc.searchPathCount = static_cast<SlangInt>(search_paths.size());
// Store macro strings to keep them alive
std::vector<std::pair<std::string, std::string>> macro_strs;
std::vector<slang::PreprocessorMacroDesc> macros;
for (const auto& [name, value] : options.defines) {
macro_strs.emplace_back(name, value);
macros.push_back({macro_strs.back().first.c_str(), macro_strs.back().second.c_str()});
}
session_desc.preprocessorMacros = macros.data();
session_desc.preprocessorMacroCount = static_cast<SlangInt>(macros.size());
// Create session
Slang::ComPtr<slang::ISession> session;
if (SLANG_FAILED(global_session->createSession(session_desc, session.writeRef()))) {
result.error_message = "Failed to create Slang session";
return result;
}
// Load module
Slang::ComPtr<slang::IBlob> diagnostics_blob;
slang::IModule* module = session->loadModuleFromSourceString(
filename.c_str(),
filename.c_str(),
source.c_str(),
diagnostics_blob.writeRef()
);
if (diagnostics_blob) {
result.error_message = static_cast<const char*>(diagnostics_blob->getBufferPointer());
}
if (!module) {
if (result.error_message.empty()) {
result.error_message = "Failed to load shader module";
}
return result;
}
// Find entry point
Slang::ComPtr<slang::IEntryPoint> entry_point;
SlangStage slang_stage = options.stage ? to_slang_stage(*options.stage) : SLANG_STAGE_NONE;
SlangResult find_result = module->findEntryPointByName(options.entry_point.c_str(), entry_point.writeRef());
if (SLANG_FAILED(find_result)) {
// Try to find and check entry point with stage
if (options.stage) {
find_result = module->findAndCheckEntryPoint(
options.entry_point.c_str(),
slang_stage,
entry_point.writeRef(),
diagnostics_blob.writeRef()
);
if (SLANG_FAILED(find_result)) {
result.error_message = "Entry point '" + options.entry_point + "' not found";
if (diagnostics_blob) {
result.error_message += ": ";
result.error_message += static_cast<const char*>(diagnostics_blob->getBufferPointer());
}
return result;
}
} else {
result.error_message = "Entry point '" + options.entry_point + "' not found";
return result;
}
}
// Create composite program
std::vector<slang::IComponentType*> components = {module, entry_point.get()};
Slang::ComPtr<slang::IComponentType> program;
if (SLANG_FAILED(session->createCompositeComponentType(
components.data(),
static_cast<SlangInt>(components.size()),
program.writeRef(),
diagnostics_blob.writeRef()
))) {
result.error_message = "Failed to create composite component type";
if (diagnostics_blob) {
result.error_message += ": ";
result.error_message += static_cast<const char*>(diagnostics_blob->getBufferPointer());
}
return result;
}
// Link program
Slang::ComPtr<slang::IComponentType> linked_program;
if (SLANG_FAILED(program->link(linked_program.writeRef(), diagnostics_blob.writeRef()))) {
result.error_message = "Failed to link shader program";
if (diagnostics_blob) {
result.error_message += ": ";
result.error_message += static_cast<const char*>(diagnostics_blob->getBufferPointer());
}
return result;
}
// Generate SPIR-V code
if (options.emit_spirv) {
Slang::ComPtr<slang::IBlob> spirv_blob;
if (SLANG_FAILED(linked_program->getEntryPointCode(
0, // entry point index
0, // target index
spirv_blob.writeRef(),
diagnostics_blob.writeRef()
))) {
result.error_message = "Failed to generate SPIR-V code";
if (diagnostics_blob) {
result.error_message += ": ";
result.error_message += static_cast<const char*>(diagnostics_blob->getBufferPointer());
}
return result;
}
// Copy SPIR-V data
const uint32_t* spirv_data = static_cast<const uint32_t*>(spirv_blob->getBufferPointer());
size_t spirv_size = spirv_blob->getBufferSize() / sizeof(uint32_t);
result.spirv.assign(spirv_data, spirv_data + spirv_size);
}
// Generate reflection data
if (options.emit_reflection) {
result.reflection_json = generate_reflection_json(linked_program.get(), options.entry_point);
}
result.success = true;
return result;
}
std::string generate_reflection_json(slang::IComponentType* program, const std::string& entry_point_name) {
slang::ProgramLayout* layout = program->getLayout();
if (!layout) {
return "{}";
}
nlohmann::json root;
root["entryPoint"] = entry_point_name;
// Get parameter count
unsigned param_count = layout->getParameterCount();
// Uniform Buffers
nlohmann::json uniform_buffers = nlohmann::json::array();
for (unsigned i = 0; i < param_count; ++i) {
slang::VariableLayoutReflection* param = layout->getParameterByIndex(i);
if (!param) continue;
slang::TypeLayoutReflection* type_layout = param->getTypeLayout();
if (!type_layout) continue;
// Check if this is a Uniform Buffer
slang::TypeReflection::Kind kind = type_layout->getKind();
if (kind == slang::TypeReflection::Kind::ConstantBuffer ||
kind == slang::TypeReflection::Kind::ParameterBlock) {
nlohmann::json ub;
ub["name"] = param->getName() ? param->getName() : "unnamed";
ub["set"] = param->getBindingSpace();
ub["binding"] = param->getBindingIndex();
// Calculate size with fallback
size_t ub_size = type_layout->getSize();
if (ub_size == 0) {
// Try to get size from element type layout
slang::TypeLayoutReflection* element_type = type_layout->getElementTypeLayout();
if (element_type) {
ub_size = element_type->getSize();
}
}
if (ub_size == 0) {
// Fallback: calculate size from members (with STD140 padding)
slang::TypeLayoutReflection* element_type = type_layout->getElementTypeLayout();
if (element_type) {
int field_count = element_type->getFieldCount();
size_t max_offset = 0;
for (int j = 0; j < field_count; ++j) {
slang::VariableLayoutReflection* field = element_type->getFieldByIndex(static_cast<unsigned>(j));
if (!field) continue;
size_t offset = field->getOffset();
size_t size = field->getTypeLayout()->getSize();
// Round up to 16-byte alignment for struct (STD140 rule)
size_t field_end = offset + size;
if (field_end > max_offset) {
max_offset = field_end;
}
}
// Round up to multiple of 16 for uniform buffer alignment
ub_size = (max_offset + 15) & ~15ULL;
}
}
ub["size"] = ub_size;
// Get members
nlohmann::json members = nlohmann::json::array();
slang::TypeLayoutReflection* element_type = type_layout->getElementTypeLayout();
if (element_type) {
int field_count = element_type->getFieldCount();
for (int j = 0; j < field_count; ++j) {
slang::VariableLayoutReflection* field = element_type->getFieldByIndex(static_cast<unsigned>(j));
if (!field) continue;
nlohmann::json member;
member["name"] = field->getName() ? field->getName() : "unnamed";
member["type"] = get_type_name(field->getTypeLayout());
member["offset"] = field->getOffset();
member["size"] = field->getTypeLayout()->getSize();
members.push_back(member);
}
}
ub["members"] = members;
uniform_buffers.push_back(ub);
}
}
root["uniformBuffers"] = uniform_buffers;
// Samplers
nlohmann::json samplers = nlohmann::json::array();
for (unsigned i = 0; i < param_count; ++i) {
slang::VariableLayoutReflection* param = layout->getParameterByIndex(i);
if (!param) continue;
slang::TypeLayoutReflection* type_layout = param->getTypeLayout();
if (!type_layout) continue;
slang::TypeReflection::Kind kind = type_layout->getKind();
if (kind == slang::TypeReflection::Kind::Resource) {
// Get resource shape using SlangResourceShape
SlangResourceShape shape = type_layout->getResourceShape();
SlangResourceShape base_shape = static_cast<SlangResourceShape>(shape & SLANG_RESOURCE_BASE_SHAPE_MASK);
if (base_shape == SLANG_TEXTURE_1D ||
base_shape == SLANG_TEXTURE_2D ||
base_shape == SLANG_TEXTURE_3D ||
base_shape == SLANG_TEXTURE_CUBE) {
nlohmann::json sampler;
sampler["name"] = param->getName() ? param->getName() : "unnamed";
sampler["set"] = param->getBindingSpace();
sampler["binding"] = param->getBindingIndex();
sampler["dimension"] = get_texture_dimension(base_shape);
samplers.push_back(sampler);
}
}
}
root["samplers"] = samplers;
// Debug: Print all parameters for analysis
for (unsigned i = 0; i < param_count; ++i) {
slang::VariableLayoutReflection* param = layout->getParameterByIndex(i);
if (!param) continue;
slang::TypeLayoutReflection* type_layout = param->getTypeLayout();
if (!type_layout) continue;
slang::TypeReflection::Kind kind = type_layout->getKind();
slang::TypeReflection* type = type_layout->getType();
const char* param_name = param->getName();
const char* type_name = type ? type->getName() : "null";
unsigned binding_space = param->getBindingSpace();
unsigned binding_index = param->getBindingIndex();
// Debug output - can be enabled for troubleshooting
// printf("DEBUG: param[%u] name=%s type=%s kind=%d space=%u index=%u\n",
// i, param_name ? param_name : "null", type_name, (int)kind, binding_space, binding_index);
}
// Push Constants - try to find parameters that look like push constants
nlohmann::json push_constants = nlohmann::json::array();
for (unsigned i = 0; i < param_count; ++i) {
slang::VariableLayoutReflection* param = layout->getParameterByIndex(i);
if (!param) continue;
slang::TypeLayoutReflection* type_layout = param->getTypeLayout();
if (!type_layout) continue;
slang::TypeReflection::Kind kind = type_layout->getKind();
slang::TypeReflection* type = type_layout->getType();
// Method 1: Check for ConstantBuffer with struct element (likely push constant)
if (kind == slang::TypeReflection::Kind::ConstantBuffer) {
slang::TypeLayoutReflection* element_type = type_layout->getElementTypeLayout();
if (element_type) {
slang::TypeReflection* element_type_ref = element_type->getType();
if (element_type_ref && element_type_ref->getKind() == slang::TypeReflection::Kind::Struct) {
// Check if this could be a push constant
// Push constants typically have:
// - No binding index (or very small)
// - Element type is a struct
// - Name might contain "pc" or "push" or match the expected struct name
const char* type_name = type ? type->getName() : "";
const char* param_name = param->getName();
// Check if name contains "pc" or "push" or matches "PushConstants"
bool is_likely_push_constant = false;
if (param_name && (strstr(param_name, "pc") || strstr(param_name, "push") || strstr(param_name, "PushConstant"))) {
is_likely_push_constant = true;
}
if (type_name && (strstr(type_name, "pc") || strstr(type_name, "push") || strstr(type_name, "PushConstant"))) {
is_likely_push_constant = true;
}
// Also check if binding space is 0 (no descriptor set assigned)
// and this is the only ConstantBuffer without a proper binding
unsigned binding_space = param->getBindingSpace();
unsigned binding_index = param->getBindingIndex();
// If binding space is 0 but binding index is 0, and there's a struct element,
// this might be a push constant (descriptors usually have higher binding indices)
if (binding_space == 0 && binding_index == 0) {
// This could be a push constant - let's also check it's NOT in uniform buffers
// by checking if we've already processed it as a regular uniform buffer
is_likely_push_constant = true;
}
if (is_likely_push_constant) {
nlohmann::json pc;
pc["name"] = param_name ? param_name : (type_name ? type_name : "PushConstants");
// Calculate size
size_t pc_size = type_layout->getSize();
if (pc_size == 0) {
pc_size = element_type->getSize();
}
if (pc_size == 0) {
// Fallback: calculate size from members
int field_count = element_type->getFieldCount();
size_t max_offset = 0;
for (int j = 0; j < field_count; ++j) {
slang::VariableLayoutReflection* field = element_type->getFieldByIndex(static_cast<unsigned>(j));
if (!field) continue;
size_t offset = field->getOffset();
size_t size = field->getTypeLayout()->getSize();
size_t field_end = offset + size;
if (field_end > max_offset) {
max_offset = field_end;
}
}
pc_size = (max_offset + 15) & ~15ULL;
}
pc["size"] = pc_size;
// Get members
nlohmann::json members = nlohmann::json::array();
int field_count = element_type->getFieldCount();
for (int j = 0; j < field_count; ++j) {
slang::VariableLayoutReflection* field = element_type->getFieldByIndex(static_cast<unsigned>(j));
if (!field) continue;
nlohmann::json member;
member["name"] = field->getName() ? field->getName() : "unnamed";
member["type"] = get_type_name(field->getTypeLayout());
member["offset"] = field->getOffset();
member["size"] = field->getTypeLayout()->getSize();
members.push_back(member);
}
pc["members"] = members;
push_constants.push_back(pc);
}
}
}
}
}
root["pushConstants"] = push_constants;
return root.dump(2);
}
// Map a reflected type layout to a short HLSL-style type name for the JSON output.
//
//   Scalars:  float / int / uint / bool; any other scalar type gives "scalar".
//   Vectors:  float/int/uint with 2-4 components, e.g. "float3". Other widths
//             give "floatN" / "intN" / "uintN"; other element types give "vector".
//   Matrices: "float<R>x<C>" for float matrices with 2-4 rows and 2-4 columns.
//             Any other matrix (non-float element type or other dimensions)
//             gives "matrix".
//   Structs:  the struct's name, or "struct" if unnamed.
//   Anything else, or a null layout/type: "unknown".
const char* get_type_name(slang::TypeLayoutReflection* type_layout) {
    if (!type_layout) {
        return "unknown";
    }
    slang::TypeReflection* type = type_layout->getType();
    if (!type) {
        return "unknown";
    }

    using Kind = slang::TypeReflection::Kind;
    using Scalar = slang::TypeReflection::ScalarType;

    switch (type->getKind()) {
    case Kind::Scalar:
        switch (type->getScalarType()) {
        case Scalar::Float32: return "float";
        case Scalar::Int32: return "int";
        case Scalar::UInt32: return "uint";
        case Scalar::Bool: return "bool";
        default: return "scalar";
        }
    case Kind::Vector: {
        const unsigned count = type->getElementCount();
        switch (type->getScalarType()) {
        case Scalar::Float32:
            switch (count) {
            case 2: return "float2";
            case 3: return "float3";
            case 4: return "float4";
            default: return "floatN";
            }
        case Scalar::Int32:
            switch (count) {
            case 2: return "int2";
            case 3: return "int3";
            case 4: return "int4";
            default: return "intN";
            }
        case Scalar::UInt32:
            switch (count) {
            case 2: return "uint2";
            case 3: return "uint3";
            case 4: return "uint4";
            default: return "uintN";
            }
        default:
            return "vector";
        }
    }
    case Kind::Matrix: {
        // Only float matrices get a concrete name. Reporting an int or half
        // matrix as "float4x4" would misdescribe the member.
        if (type->getScalarType() != Scalar::Float32) {
            return "matrix";
        }
        // Indexed by [rows - 2][columns - 2].
        static const char* const float_matrix_names[3][3] = {
            {"float2x2", "float2x3", "float2x4"},
            {"float3x2", "float3x3", "float3x4"},
            {"float4x2", "float4x3", "float4x4"},
        };
        const unsigned rows = type->getRowCount();
        const unsigned cols = type->getColumnCount();
        if (rows >= 2 && rows <= 4 && cols >= 2 && cols <= 4) {
            return float_matrix_names[rows - 2][cols - 2];
        }
        return "matrix";
    }
    case Kind::Struct:
        return type->getName() ? type->getName() : "struct";
    default:
        return "unknown";
    }
}
// Human-readable dimension label for a base texture shape. Any shape other
// than 1D / 3D / Cube, including 2D itself, is labelled "2D".
const char* get_texture_dimension(SlangResourceShape base_shape) {
    const char* label = "2D";
    switch (base_shape) {
    case SLANG_TEXTURE_1D:
        label = "1D";
        break;
    case SLANG_TEXTURE_3D:
        label = "3D";
        break;
    case SLANG_TEXTURE_CUBE:
        label = "Cube";
        break;
    default:
        break;
    }
    return label;
}
// 检查是否存在入口函数
std::optional<bool> has_entry_point(
const std::string& source,
const std::string& filename,
const compile_options& options
) {
auto entry_points = get_entry_points(source, filename, options);
return !entry_points.empty();
}
// 获取所有入口函数
std::vector<std::string> get_entry_points(
const std::string& source,
const std::string& filename,
const compile_options& options
) {
std::vector<std::string> result;
if (!global_session) {
return result;
}
// 创建 session
slang::SessionDesc session_desc = {};
slang::TargetDesc target_desc = {};
target_desc.format = SLANG_SPIRV;
target_desc.profile = global_session->findProfile("glsl_450");
session_desc.targets = &target_desc;
session_desc.targetCount = 1;
std::vector<std::string> include_path_strs;
std::vector<const char*> search_paths;
for (const auto& path : options.include_paths) {
include_path_strs.push_back(path.string());
search_paths.push_back(include_path_strs.back().c_str());
}
session_desc.searchPaths = search_paths.data();
session_desc.searchPathCount = static_cast<SlangInt>(search_paths.size());
std::vector<std::pair<std::string, std::string>> macro_strs;
std::vector<slang::PreprocessorMacroDesc> macros;
for (const auto& [name, value] : options.defines) {
macro_strs.emplace_back(name, value);
macros.push_back({macro_strs.back().first.c_str(), macro_strs.back().second.c_str()});
}
session_desc.preprocessorMacros = macros.data();
session_desc.preprocessorMacroCount = static_cast<SlangInt>(macros.size());
Slang::ComPtr<slang::ISession> session;
if (SLANG_FAILED(global_session->createSession(session_desc, session.writeRef()))) {
return result;
}
// 加载模块
Slang::ComPtr<slang::IBlob> diagnostics_blob;
slang::IModule* module = session->loadModuleFromSourceString(
filename.c_str(),
filename.c_str(),
source.c_str(),
diagnostics_blob.writeRef()
);
if (!module) {
return result;
}
// 获取入口点列表
SlangInt32 entry_point_count = module->getDefinedEntryPointCount();
// 常见入口点名称模式
const char* common_names[] = {"main", "vertexMain", "fragmentMain", "computeMain",
"vertexShader", "fragmentShader", "computeShader"};
const SlangStage stages[] = {
SLANG_STAGE_VERTEX,
SLANG_STAGE_FRAGMENT,
SLANG_STAGE_COMPUTE,
SLANG_STAGE_GEOMETRY,
SLANG_STAGE_HULL,
SLANG_STAGE_DOMAIN
};
// 尝试每个可能的入口点名称和阶段的组合
for (const char* name : common_names) {
for (SlangStage stage : stages) {
Slang::ComPtr<slang::IEntryPoint> checked_entry;
SlangResult find_result = module->findAndCheckEntryPoint(
name,
stage,
checked_entry.writeRef(),
diagnostics_blob.writeRef()
);
if (SLANG_FAILED(find_result) || !checked_entry) {
continue;
}
// 验证并编译这个入口点来获取名称
std::vector<slang::IComponentType*> comps = {module, checked_entry.get()};
Slang::ComPtr<slang::IComponentType> composite;
if (SLANG_FAILED(session->createCompositeComponentType(
comps.data(),
static_cast<SlangInt>(comps.size()),
composite.writeRef(),
diagnostics_blob.writeRef()
))) {
continue;
}
Slang::ComPtr<slang::IComponentType> linked;
if (SLANG_FAILED(composite->link(linked.writeRef(), diagnostics_blob.writeRef()))) {
continue;
}
// 生成代码来触发 layout 生成
Slang::ComPtr<slang::IBlob> spirv_blob;
if (SLANG_FAILED(linked->getEntryPointCode(
0, 0, spirv_blob.writeRef(), diagnostics_blob.writeRef()))) {
continue;
}
// 获取 layout
slang::ProgramLayout* layout = linked->getLayout();
if (!layout || layout->getEntryPointCount() == 0) {
continue;
}
slang::EntryPointReflection* ep_ref = layout->getEntryPointByIndex(0);
if (!ep_ref) {
continue;
}
const char* entry_name = ep_ref->getName();
if (entry_name && entry_name[0] != '\0') {
// 检查是否已经添加
bool exists = false;
for (const auto& existing : result) {
if (existing == entry_name) {
exists = true;
break;
}
}
if (!exists) {
result.push_back(entry_name);
}
}
}
}
return result;
}
// 获取所有入口函数及其阶段
std::vector<entry_point_info> get_entry_points_with_stages(
const std::string& source,
const std::string& filename,
const compile_options& options
) {
std::vector<entry_point_info> result;
if (!global_session) {
return result;
}
// 创建 session
slang::SessionDesc session_desc = {};
slang::TargetDesc target_desc = {};
target_desc.format = SLANG_SPIRV;
target_desc.profile = global_session->findProfile("glsl_450");
session_desc.targets = &target_desc;
session_desc.targetCount = 1;
std::vector<std::string> include_path_strs;
std::vector<const char*> search_paths;
for (const auto& path : options.include_paths) {
include_path_strs.push_back(path.string());
search_paths.push_back(include_path_strs.back().c_str());
}
session_desc.searchPaths = search_paths.data();
session_desc.searchPathCount = static_cast<SlangInt>(search_paths.size());
std::vector<std::pair<std::string, std::string>> macro_strs;
std::vector<slang::PreprocessorMacroDesc> macros;
for (const auto& [name, value] : options.defines) {
macro_strs.emplace_back(name, value);
macros.push_back({macro_strs.back().first.c_str(), macro_strs.back().second.c_str()});
}
session_desc.preprocessorMacros = macros.data();
session_desc.preprocessorMacroCount = static_cast<SlangInt>(macros.size());
Slang::ComPtr<slang::ISession> session;
if (SLANG_FAILED(global_session->createSession(session_desc, session.writeRef()))) {
return result;
}
// 加载模块
Slang::ComPtr<slang::IBlob> diagnostics_blob;
slang::IModule* module = session->loadModuleFromSourceString(
filename.c_str(),
filename.c_str(),
source.c_str(),
diagnostics_blob.writeRef()
);
if (!module) {
return result;
}
// 获取所有定义的入口点数量
SlangInt32 entry_point_count = module->getDefinedEntryPointCount();
if (entry_point_count == 0) {
return result;
}
// 常见入口点名称模式
const char* common_names[] = {"main", "vertexMain", "fragmentMain", "computeMain",
"vertexShader", "fragmentShader", "computeShader"};
const SlangStage stages[] = {
SLANG_STAGE_VERTEX,
SLANG_STAGE_FRAGMENT,
SLANG_STAGE_COMPUTE,
SLANG_STAGE_GEOMETRY,
SLANG_STAGE_HULL,
SLANG_STAGE_DOMAIN
};
// 尝试每个可能的入口点名称和阶段的组合
for (const char* name : common_names) {
for (SlangStage stage : stages) {
// 如果已经找到所有入口点,提前退出
if (static_cast<SlangInt32>(result.size()) >= entry_point_count) {
break;
}
Slang::ComPtr<slang::IEntryPoint> checked_entry;
SlangResult find_result = module->findAndCheckEntryPoint(
name,
stage,
checked_entry.writeRef(),
diagnostics_blob.writeRef()
);
if (SLANG_FAILED(find_result) || !checked_entry) {
continue;
}
// 验证并编译这个入口点来获取名称
std::vector<slang::IComponentType*> comps = {module, checked_entry.get()};
Slang::ComPtr<slang::IComponentType> composite;
if (SLANG_FAILED(session->createCompositeComponentType(
comps.data(),
static_cast<SlangInt>(comps.size()),
composite.writeRef(),
diagnostics_blob.writeRef()
))) {
continue;
}
Slang::ComPtr<slang::IComponentType> linked;
if (SLANG_FAILED(composite->link(linked.writeRef(), diagnostics_blob.writeRef()))) {
continue;
}
// 生成代码来触发 layout 生成
Slang::ComPtr<slang::IBlob> spirv_blob;
if (SLANG_FAILED(linked->getEntryPointCode(
0, 0, spirv_blob.writeRef(), diagnostics_blob.writeRef()))) {
continue;
}
// 获取 layout
slang::ProgramLayout* layout = linked->getLayout();
if (!layout || layout->getEntryPointCount() == 0) {
continue;
}
slang::EntryPointReflection* ep_ref = layout->getEntryPointByIndex(0);
if (!ep_ref) {
continue;
}
const char* entry_name = ep_ref->getName();
SlangStage slang_stage = ep_ref->getStage();
if (entry_name && entry_name[0] != '\0') {
// 检查是否已经添加
bool exists = false;
for (const auto& existing : result) {
if (existing.name == entry_name) {
exists = true;
break;
}
}
if (!exists) {
auto shader_stage_opt = slang_stage_to_shader_stage(slang_stage);
if (shader_stage_opt) {
entry_point_info info;
info.name = entry_name;
info.stage = *shader_stage_opt;
result.push_back(info);
}
}
}
}
}
return result;
}
// 编译所有入口点
std::vector<compile_result> compile_all_entries(
const std::string& source,
const std::string& filename,
const compile_options& options
) {
std::vector<compile_result> results;
// 获取所有入口点及其阶段
auto entry_points = get_entry_points_with_stages(source, filename, options);
for (const auto& ep_info : entry_points) {
compile_options ep_options = options;
ep_options.entry_point = ep_info.name;
ep_options.stage = ep_info.stage;
compile_result result = compile(source, filename, ep_options);
result.entry_point = ep_info.name;
result.stage = ep_info.stage;
results.push_back(result);
}
return results;
}
};
// ============================================================================
// shader_compiler implementation
// ============================================================================
shader_compiler::shader_compiler()
: impl_(std::make_unique<impl>()) {
}
shader_compiler::~shader_compiler() = default;
shader_compiler::shader_compiler(shader_compiler&&) noexcept = default;
shader_compiler& shader_compiler::operator=(shader_compiler&&) noexcept = default;
// Read the shader file at `path` and compile it with `options`.
// The file's parent directory is prepended to the include paths so sibling
// includes/imports resolve, and the bare file name is used as the module name.
// If the file cannot be opened, returns a failed result with an explanatory message.
compile_result shader_compiler::compile_file(
    const std::filesystem::path& path,
    const compile_options& options
) {
    std::ifstream input(path);
    if (!input) {
        compile_result failure;
        failure.error_message = "Failed to open file: " + path.string();
        return failure;
    }

    std::ostringstream contents;
    contents << input.rdbuf();

    compile_options options_with_dir = options;
    auto& includes = options_with_dir.include_paths;
    includes.insert(includes.begin(), path.parent_path());

    return impl_->compile(contents.str(), path.filename().string(), options_with_dir);
}
// Compile Slang source held in memory. `filename` is used as the module name
// and in diagnostics; no include directories are added beyond those in `options`.
compile_result shader_compiler::compile_source(
    const std::string& source,
    const std::string& filename,
    const compile_options& options
) {
    auto& backend = *impl_;
    return backend.compile(source, filename, options);
}
std::optional<std::string> shader_compiler::generate_reflection(
const std::filesystem::path& path,
const compile_options& options
) {
compile_options reflection_options = options;
reflection_options.emit_spirv = false;
reflection_options.emit_reflection = true;
auto result = compile_file(path, reflection_options);
if (result.success) {
return result.reflection_json;
}
return std::nullopt;
}
std::optional<bool> shader_compiler::has_entry_point(
const std::filesystem::path& path,
const compile_options& options
) {
// 首先检查文件是否存在
if (!std::filesystem::exists(path)) {
return std::nullopt; // 无法检查
}
// 读取文件内容
std::ifstream file(path);
if (!file) {
return std::nullopt; // 无法检查
}
std::ostringstream ss;
ss << file.rdbuf();
std::string source = ss.str();
return impl_->has_entry_point(source, path.filename().string(), options);
}
std::vector<std::string> shader_compiler::get_entry_points(
const std::filesystem::path& path,
const compile_options& options
) {
// 首先检查文件是否存在
if (!std::filesystem::exists(path)) {
return {};
}
// 读取文件内容
std::ifstream file(path);
if (!file) {
return {};
}
std::ostringstream ss;
ss << file.rdbuf();
std::string source = ss.str();
return impl_->get_entry_points(source, path.filename().string(), options);
}
std::vector<entry_point_info> shader_compiler::get_entry_points_with_stages(
const std::filesystem::path& path,
const compile_options& options
) {
// 首先检查文件是否存在
if (!std::filesystem::exists(path)) {
return {};
}
// 读取文件内容
std::ifstream file(path);
if (!file) {
return {};
}
std::ostringstream ss;
ss << file.rdbuf();
std::string source = ss.str();
return impl_->get_entry_points_with_stages(source, path.filename().string(), options);
}
std::vector<compile_result> shader_compiler::compile_file_all_entries(
const std::filesystem::path& path,
const compile_options& options
) {
// 首先检查文件是否存在
if (!std::filesystem::exists(path)) {
return {};
}
// 读取文件内容
std::ifstream file(path);
if (!file) {
return {};
}
std::ostringstream ss;
ss << file.rdbuf();
std::string source = ss.str();
// 添加文件所在目录到包含路径
compile_options modified_options = options;
modified_options.include_paths.insert(
modified_options.include_paths.begin(),
path.parent_path()
);
return impl_->compile_all_entries(source, path.filename().string(), modified_options);
}
// True when the implementation exists and its Slang global session was created.
bool shader_compiler::is_available() const noexcept {
    if (!impl_) {
        return false;
    }
    return impl_->global_session.get() != nullptr;
}
// Describe the compiler backend.
//
// Returns "unavailable" when no Slang global session exists. Otherwise returns
// "Slang Shader Compiler", followed by " (<build tag>)" when Slang reports a
// non-empty build tag through IGlobalSession::getBuildTagString().
std::string shader_compiler::get_version() const {
    if (!impl_ || !impl_->global_session) {
        return "unavailable";
    }
    std::string version = "Slang Shader Compiler";
    const char* build_tag = impl_->global_session->getBuildTagString();
    if (build_tag && build_tag[0] != '\0') {
        version += " (";
        version += build_tag;
        version += ")";
    }
    return version;
}
} // namespace mirai::tools