1079 lines
42 KiB
C++
1079 lines
42 KiB
C++
// tools/shader_compile/compiler.cpp
|
||
// MIRAI Shader Compiler Implementation
|
||
|
||
#include "compiler.hpp"

#include <slang.h>
#include <slang-com-ptr.h>

#include <nlohmann/json.hpp>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <limits>
#include <optional>
#include <sstream>
#include <string>
#include <system_error>
#include <utility>
#include <vector>
|
||
|
||
namespace mirai::tools {
|
||
|
||
// ============================================================================
|
||
// Helper Functions
|
||
// ============================================================================
|
||
|
||
std::optional<shader_stage> parse_shader_stage(const std::string& str) {
|
||
if (str == "vertex" || str == "vert" || str == "vs") {
|
||
return shader_stage::vertex;
|
||
}
|
||
if (str == "fragment" || str == "frag" || str == "pixel" || str == "fs" || str == "ps") {
|
||
return shader_stage::fragment;
|
||
}
|
||
if (str == "compute" || str == "comp" || str == "cs") {
|
||
return shader_stage::compute;
|
||
}
|
||
if (str == "geometry" || str == "geom" || str == "gs") {
|
||
return shader_stage::geometry;
|
||
}
|
||
if (str == "tesscontrol" || str == "tesc" || str == "hull" || str == "hs") {
|
||
return shader_stage::tessellation_control;
|
||
}
|
||
if (str == "tesseval" || str == "tese" || str == "domain" || str == "ds") {
|
||
return shader_stage::tessellation_evaluation;
|
||
}
|
||
return std::nullopt;
|
||
}
|
||
|
||
/// Returns the canonical name of a shader stage.
///
/// The names round-trip through parse_shader_stage(). Values outside the
/// enumeration yield "unknown".
const char* shader_stage_to_string(shader_stage stage) {
    const char* name = "unknown";
    switch (stage) {
        case shader_stage::vertex:
            name = "vertex";
            break;
        case shader_stage::fragment:
            name = "fragment";
            break;
        case shader_stage::compute:
            name = "compute";
            break;
        case shader_stage::geometry:
            name = "geometry";
            break;
        case shader_stage::tessellation_control:
            name = "tesscontrol";
            break;
        case shader_stage::tessellation_evaluation:
            name = "tesseval";
            break;
    }
    return name;
}
|
||
|
||
/// Maps a MIRAI shader stage to the corresponding Slang stage.
///
/// Tessellation stages map to Slang's HULL/DOMAIN names. Values outside the
/// enumeration yield SLANG_STAGE_NONE.
static SlangStage to_slang_stage(shader_stage stage) {
    SlangStage mapped = SLANG_STAGE_NONE;
    switch (stage) {
        case shader_stage::vertex:
            mapped = SLANG_STAGE_VERTEX;
            break;
        case shader_stage::fragment:
            mapped = SLANG_STAGE_FRAGMENT;
            break;
        case shader_stage::compute:
            mapped = SLANG_STAGE_COMPUTE;
            break;
        case shader_stage::geometry:
            mapped = SLANG_STAGE_GEOMETRY;
            break;
        case shader_stage::tessellation_control:
            mapped = SLANG_STAGE_HULL;
            break;
        case shader_stage::tessellation_evaluation:
            mapped = SLANG_STAGE_DOMAIN;
            break;
    }
    return mapped;
}
|
||
|
||
/// Maps a Slang stage back to a MIRAI shader stage.
///
/// @return The matching stage, or std::nullopt for Slang stages this tool
///         does not model (anything other than the six handled below).
std::optional<shader_stage> slang_stage_to_shader_stage(SlangStage stage) {
    std::optional<shader_stage> mapped;
    switch (stage) {
        case SLANG_STAGE_VERTEX:
            mapped = shader_stage::vertex;
            break;
        case SLANG_STAGE_FRAGMENT:
            mapped = shader_stage::fragment;
            break;
        case SLANG_STAGE_COMPUTE:
            mapped = shader_stage::compute;
            break;
        case SLANG_STAGE_GEOMETRY:
            mapped = shader_stage::geometry;
            break;
        case SLANG_STAGE_HULL:
            mapped = shader_stage::tessellation_control;
            break;
        case SLANG_STAGE_DOMAIN:
            mapped = shader_stage::tessellation_evaluation;
            break;
        default:
            break;
    }
    return mapped;
}
|
||
|
||
// ============================================================================
|
||
// Implementation Class
|
||
// ============================================================================
|
||
|
||
struct shader_compiler::impl {
    /// Slang global session shared by every operation. It stays null if
    /// creation failed; every operation below checks it before use.
    Slang::ComPtr<slang::IGlobalSession> global_session;

    // ------------------------------------------------------------------------
    // Helper types
    // ------------------------------------------------------------------------

    /// Owns every string a slang::SessionDesc points into.
    ///
    /// The descriptor stores raw `const char*` pointers, so this object is
    /// non-copyable. It must stay alive until createSession() has consumed
    /// `session_desc`.
    struct session_config {
        std::vector<std::string> include_path_strs;
        std::vector<const char*> search_paths;
        std::vector<std::pair<std::string, std::string>> macro_strs;
        std::vector<slang::PreprocessorMacroDesc> macros;
        slang::TargetDesc target_desc = {};
        slang::SessionDesc session_desc = {};

        /// Builds a single SPIR-V (glsl_450) target with the include paths
        /// and preprocessor defines from `options`.
        session_config(slang::IGlobalSession* global, const compile_options& options) {
            target_desc.format = SLANG_SPIRV;
            target_desc.profile = global->findProfile("glsl_450");
            session_desc.targets = &target_desc;
            session_desc.targetCount = 1;

            // Materialize every string BEFORE taking c_str() pointers. Growing a
            // vector<std::string> moves its elements, and for short (SSO) strings
            // a move invalidates any c_str() pointer obtained earlier.
            for (const auto& path : options.include_paths) {
                include_path_strs.push_back(path.string());
            }
            for (const auto& [name, value] : options.defines) {
                macro_strs.emplace_back(name, value);
            }

            search_paths.reserve(include_path_strs.size());
            for (const auto& path : include_path_strs) {
                search_paths.push_back(path.c_str());
            }
            macros.reserve(macro_strs.size());
            for (const auto& [name, value] : macro_strs) {
                macros.push_back({name.c_str(), value.c_str()});
            }

            session_desc.searchPaths = search_paths.data();
            session_desc.searchPathCount = static_cast<SlangInt>(search_paths.size());
            session_desc.preprocessorMacros = macros.data();
            session_desc.preprocessorMacroCount = static_cast<SlangInt>(macros.size());
        }

        session_config(const session_config&) = delete;
        session_config& operator=(const session_config&) = delete;
    };

    /// An entry point that composed, linked and generated code, with the name
    /// and stage Slang reports for it in the linked layout.
    struct discovered_entry {
        std::string name;
        SlangStage stage = SLANG_STAGE_NONE;
    };

    impl() {
        SlangGlobalSessionDesc desc = {};
        // The result code is not inspected: on failure global_session simply
        // remains null, which every operation reports as "not initialized".
        slang::createGlobalSession(&desc, global_session.writeRef());
    }

    ~impl() = default;

    // ------------------------------------------------------------------------
    // Small utilities
    // ------------------------------------------------------------------------

    /// Converts a diagnostics blob to text; empty for a null or empty blob.
    /// The byte count is used explicitly, so the blob need not be
    /// NUL-terminated. Trailing NULs are dropped.
    static std::string blob_to_string(slang::IBlob* blob) {
        if (!blob) {
            return {};
        }
        const auto* text = static_cast<const char*>(blob->getBufferPointer());
        const std::size_t size = blob->getBufferSize();
        if (!text || size == 0) {
            return {};
        }
        std::string out(text, size);
        while (!out.empty() && out.back() == '\0') {
            out.pop_back();
        }
        return out;
    }

    /// Stores `message` in result.error_message, appending ": <diagnostics>"
    /// when `diagnostics` carries any text.
    static void set_failure(compile_result& result, std::string message, slang::IBlob* diagnostics) {
        const std::string details = blob_to_string(diagnostics);
        if (!details.empty()) {
            message += ": ";
            message += details;
        }
        result.error_message = std::move(message);
    }

    /// Returns `name`, or `fallback` when Slang reported no name.
    static const char* name_or(const char* name, const char* fallback) {
        return name ? name : fallback;
    }

    // ------------------------------------------------------------------------
    // Compilation
    // ------------------------------------------------------------------------

    /// Compiles one entry point of `source` to SPIR-V and/or reflection JSON.
    ///
    /// @param source   Shader source text.
    /// @param filename Used as both the Slang module name and its path.
    /// @param options  Entry point name, optional stage, include paths,
    ///                 defines and output selection.
    /// @return success == true when every requested output was produced.
    ///         error_message holds the failure reason. On success it holds any
    ///         diagnostics emitted while loading the module (for example,
    ///         warnings).
    compile_result compile(
        const std::string& source,
        const std::string& filename,
        const compile_options& options
    ) {
        compile_result result;

        if (!global_session) {
            result.error_message = "Slang global session not initialized";
            return result;
        }

        session_config config(global_session.get(), options);
        if (options.generate_debug_info) {
            // NOTE(review): this flag selects Slang's direct SPIR-V emitter; it
            // does not by itself request debug information. Kept unchanged.
            config.target_desc.flags |= SLANG_TARGET_FLAG_GENERATE_SPIRV_DIRECTLY;
        }

        Slang::ComPtr<slang::ISession> session;
        if (SLANG_FAILED(global_session->createSession(config.session_desc, session.writeRef()))) {
            result.error_message = "Failed to create Slang session";
            return result;
        }

        Slang::ComPtr<slang::IBlob> diagnostics_blob;
        slang::IModule* module = session->loadModuleFromSourceString(
            filename.c_str(),
            filename.c_str(),
            source.c_str(),
            diagnostics_blob.writeRef()
        );

        // Diagnostics may be present even when loading succeeded (warnings).
        if (diagnostics_blob) {
            result.error_message = blob_to_string(diagnostics_blob.get());
        }
        if (!module) {
            if (result.error_message.empty()) {
                result.error_message = "Failed to load shader module";
            }
            return result;
        }

        // Look the entry point up by name first; if that fails and a stage was
        // given, let Slang check the named function against that stage.
        Slang::ComPtr<slang::IEntryPoint> entry_point;
        const bool found_by_name =
            SLANG_SUCCEEDED(module->findEntryPointByName(options.entry_point.c_str(), entry_point.writeRef()))
            && entry_point;
        if (!found_by_name) {
            const std::string not_found = "Entry point '" + options.entry_point + "' not found";
            if (!options.stage) {
                set_failure(result, not_found, nullptr);
                return result;
            }
            const SlangResult check_result = module->findAndCheckEntryPoint(
                options.entry_point.c_str(),
                to_slang_stage(*options.stage),
                entry_point.writeRef(),
                diagnostics_blob.writeRef()
            );
            if (SLANG_FAILED(check_result) || !entry_point) {
                set_failure(result, not_found, diagnostics_blob.get());
                return result;
            }
        }

        std::vector<slang::IComponentType*> components = {module, entry_point.get()};
        Slang::ComPtr<slang::IComponentType> program;
        if (SLANG_FAILED(session->createCompositeComponentType(
                components.data(),
                static_cast<SlangInt>(components.size()),
                program.writeRef(),
                diagnostics_blob.writeRef()))) {
            set_failure(result, "Failed to create composite component type", diagnostics_blob.get());
            return result;
        }

        Slang::ComPtr<slang::IComponentType> linked_program;
        if (SLANG_FAILED(program->link(linked_program.writeRef(), diagnostics_blob.writeRef()))) {
            set_failure(result, "Failed to link shader program", diagnostics_blob.get());
            return result;
        }

        if (options.emit_spirv) {
            Slang::ComPtr<slang::IBlob> spirv_blob;
            const SlangResult code_result = linked_program->getEntryPointCode(
                0,  // entry point index
                0,  // target index
                spirv_blob.writeRef(),
                diagnostics_blob.writeRef()
            );
            if (SLANG_FAILED(code_result) || !spirv_blob) {
                set_failure(result, "Failed to generate SPIR-V code", diagnostics_blob.get());
                return result;
            }

            // SPIR-V is a stream of 32-bit words.
            const auto* words = static_cast<const std::uint32_t*>(spirv_blob->getBufferPointer());
            const std::size_t word_count = spirv_blob->getBufferSize() / sizeof(std::uint32_t);
            result.spirv.assign(words, words + word_count);
        }

        if (options.emit_reflection) {
            result.reflection_json = generate_reflection_json(linked_program.get(), options.entry_point);
        }

        result.success = true;
        return result;
    }

    // ------------------------------------------------------------------------
    // Reflection
    // ------------------------------------------------------------------------

    /// True when the parameter occupies a push-constant range, according to
    /// Slang's PushConstantBuffer parameter category.
    static bool is_push_constant(slang::TypeLayoutReflection* type_layout) {
        const unsigned category_count = type_layout->getCategoryCount();
        for (unsigned c = 0; c < category_count; ++c) {
            if (type_layout->getCategoryByIndex(c) == slang::ParameterCategory::PushConstantBuffer) {
                return true;
            }
        }
        return false;
    }

    /// Byte size of a uniform block.
    ///
    /// Tries Slang's size for `container`, then for `element`. If both are
    /// zero, uses the furthest field end rounded up to a multiple of 16 bytes
    /// (std140 struct alignment).
    static std::size_t uniform_block_size(
        slang::TypeLayoutReflection* container,
        slang::TypeLayoutReflection* element
    ) {
        std::size_t size = container->getSize();
        if (size == 0 && element) {
            size = element->getSize();
        }
        if (size != 0 || !element) {
            return size;
        }

        std::size_t furthest_end = 0;
        const unsigned field_count = element->getFieldCount();
        for (unsigned j = 0; j < field_count; ++j) {
            slang::VariableLayoutReflection* field = element->getFieldByIndex(j);
            if (!field || !field->getTypeLayout()) {
                continue;
            }
            furthest_end = std::max(furthest_end, field->getOffset() + field->getTypeLayout()->getSize());
        }
        constexpr std::size_t alignment = 16;
        return (furthest_end + alignment - 1) / alignment * alignment;
    }

    /// JSON array describing each field of `element`: name, type, offset and
    /// size. Empty when `element` is null.
    static nlohmann::json describe_members(slang::TypeLayoutReflection* element) {
        nlohmann::json members = nlohmann::json::array();
        if (!element) {
            return members;
        }
        const unsigned field_count = element->getFieldCount();
        for (unsigned j = 0; j < field_count; ++j) {
            slang::VariableLayoutReflection* field = element->getFieldByIndex(j);
            if (!field) {
                continue;
            }
            slang::TypeLayoutReflection* field_layout = field->getTypeLayout();

            nlohmann::json member;
            member["name"] = name_or(field->getName(), "unnamed");
            member["type"] = get_type_name(field_layout);
            member["offset"] = field->getOffset();
            member["size"] = field_layout ? field_layout->getSize() : std::size_t{0};
            members.push_back(std::move(member));
        }
        return members;
    }

    /// Describes a ConstantBuffer/ParameterBlock parameter as a uniform buffer.
    static nlohmann::json describe_uniform_buffer(
        slang::VariableLayoutReflection* param,
        slang::TypeLayoutReflection* type_layout
    ) {
        slang::TypeLayoutReflection* element = type_layout->getElementTypeLayout();

        nlohmann::json ub;
        ub["name"] = name_or(param->getName(), "unnamed");
        ub["set"] = param->getBindingSpace();
        ub["binding"] = param->getBindingIndex();
        ub["size"] = uniform_block_size(type_layout, element);
        ub["members"] = describe_members(element);
        return ub;
    }

    /// Describes a push-constant parameter whose payload is a struct, either
    /// wrapped in a ConstantBuffer or declared directly. Returns nullopt for
    /// other shapes.
    static std::optional<nlohmann::json> describe_push_constant(
        slang::VariableLayoutReflection* param,
        slang::TypeLayoutReflection* type_layout
    ) {
        slang::TypeLayoutReflection* element = nullptr;
        switch (type_layout->getKind()) {
            case slang::TypeReflection::Kind::ConstantBuffer:
                element = type_layout->getElementTypeLayout();
                break;
            case slang::TypeReflection::Kind::Struct:
                element = type_layout;
                break;
            default:
                return std::nullopt;
        }
        if (!element) {
            return std::nullopt;
        }
        slang::TypeReflection* element_type = element->getType();
        if (!element_type || element_type->getKind() != slang::TypeReflection::Kind::Struct) {
            return std::nullopt;
        }

        slang::TypeReflection* container_type = type_layout->getType();
        const char* type_name = container_type ? container_type->getName() : nullptr;
        const char* param_name = param->getName();

        nlohmann::json pc;
        if (param_name) {
            pc["name"] = param_name;
        } else if (type_name && type_name[0] != '\0') {
            pc["name"] = type_name;
        } else {
            pc["name"] = "PushConstants";
        }
        pc["size"] = uniform_block_size(type_layout, element);
        pc["members"] = describe_members(element);
        return pc;
    }

    /// Describes a 1D/2D/3D/Cube texture resource parameter. Returns nullopt
    /// for any other resource shape.
    static std::optional<nlohmann::json> describe_texture(
        slang::VariableLayoutReflection* param,
        slang::TypeLayoutReflection* type_layout
    ) {
        const auto base_shape = static_cast<SlangResourceShape>(
            type_layout->getResourceShape() & SLANG_RESOURCE_BASE_SHAPE_MASK);
        const bool is_texture = base_shape == SLANG_TEXTURE_1D
            || base_shape == SLANG_TEXTURE_2D
            || base_shape == SLANG_TEXTURE_3D
            || base_shape == SLANG_TEXTURE_CUBE;
        if (!is_texture) {
            return std::nullopt;
        }

        nlohmann::json sampler;
        sampler["name"] = name_or(param->getName(), "unnamed");
        sampler["set"] = param->getBindingSpace();
        sampler["binding"] = param->getBindingIndex();
        sampler["dimension"] = get_texture_dimension(base_shape);
        return sampler;
    }

    /// Serializes the program's global parameters to pretty-printed JSON.
    ///
    /// The output has the keys "entryPoint", "uniformBuffers",
    /// "pushConstants" and "samplers". Push-constant buffers are reported only
    /// under "pushConstants", never as uniform buffers. Texture resources are
    /// reported under "samplers". Returns "{}" when the program has no layout.
    std::string generate_reflection_json(slang::IComponentType* program, const std::string& entry_point_name) {
        slang::ProgramLayout* layout = program->getLayout();
        if (!layout) {
            return "{}";
        }

        nlohmann::json uniform_buffers = nlohmann::json::array();
        nlohmann::json push_constants = nlohmann::json::array();
        nlohmann::json samplers = nlohmann::json::array();

        const unsigned param_count = layout->getParameterCount();
        for (unsigned i = 0; i < param_count; ++i) {
            slang::VariableLayoutReflection* param = layout->getParameterByIndex(i);
            if (!param) {
                continue;
            }
            slang::TypeLayoutReflection* type_layout = param->getTypeLayout();
            if (!type_layout) {
                continue;
            }

            const slang::TypeReflection::Kind kind = type_layout->getKind();
            if (is_push_constant(type_layout)) {
                if (auto pc = describe_push_constant(param, type_layout)) {
                    push_constants.push_back(std::move(*pc));
                }
            } else if (kind == slang::TypeReflection::Kind::ConstantBuffer
                       || kind == slang::TypeReflection::Kind::ParameterBlock) {
                uniform_buffers.push_back(describe_uniform_buffer(param, type_layout));
            } else if (kind == slang::TypeReflection::Kind::Resource) {
                if (auto sampler = describe_texture(param, type_layout)) {
                    samplers.push_back(std::move(*sampler));
                }
            }
        }

        nlohmann::json root;
        root["entryPoint"] = entry_point_name;
        root["uniformBuffers"] = uniform_buffers;
        root["samplers"] = samplers;
        root["pushConstants"] = push_constants;
        return root.dump(2);
    }

    /// Human-readable name for a field type, used in reflection member lists.
    /// Only common 32-bit scalar/vector types and square matrices get
    /// specific names; everything else falls back to a generic label.
    static const char* get_type_name(slang::TypeLayoutReflection* type_layout) {
        if (!type_layout) {
            return "unknown";
        }
        slang::TypeReflection* type = type_layout->getType();
        if (!type) {
            return "unknown";
        }

        switch (type->getKind()) {
            case slang::TypeReflection::Kind::Scalar:
                switch (type->getScalarType()) {
                    case slang::TypeReflection::ScalarType::Float32: return "float";
                    case slang::TypeReflection::ScalarType::Int32: return "int";
                    case slang::TypeReflection::ScalarType::UInt32: return "uint";
                    case slang::TypeReflection::ScalarType::Bool: return "bool";
                    default: return "scalar";
                }

            case slang::TypeReflection::Kind::Vector: {
                const std::size_t count = type->getElementCount();
                switch (type->getScalarType()) {
                    case slang::TypeReflection::ScalarType::Float32:
                        switch (count) {
                            case 2: return "float2";
                            case 3: return "float3";
                            case 4: return "float4";
                            default: return "floatN";
                        }
                    case slang::TypeReflection::ScalarType::Int32:
                        switch (count) {
                            case 2: return "int2";
                            case 3: return "int3";
                            case 4: return "int4";
                            default: return "intN";
                        }
                    case slang::TypeReflection::ScalarType::UInt32:
                        switch (count) {
                            case 2: return "uint2";
                            case 3: return "uint3";
                            case 4: return "uint4";
                            default: return "uintN";
                        }
                    default:
                        return "vector";
                }
            }

            case slang::TypeReflection::Kind::Matrix: {
                const unsigned rows = type->getRowCount();
                const unsigned cols = type->getColumnCount();
                if (rows == 4 && cols == 4) return "float4x4";
                if (rows == 3 && cols == 3) return "float3x3";
                if (rows == 2 && cols == 2) return "float2x2";
                return "matrix";
            }

            case slang::TypeReflection::Kind::Struct:
                return type->getName() ? type->getName() : "struct";

            default:
                return "unknown";
        }
    }

    /// Dimension label for a texture base shape; unknown shapes map to "2D".
    static const char* get_texture_dimension(SlangResourceShape base_shape) {
        switch (base_shape) {
            case SLANG_TEXTURE_1D: return "1D";
            case SLANG_TEXTURE_2D: return "2D";
            case SLANG_TEXTURE_3D: return "3D";
            case SLANG_TEXTURE_CUBE: return "Cube";
            default: return "2D";
        }
    }

    // ------------------------------------------------------------------------
    // Entry point discovery
    // ------------------------------------------------------------------------

    /// Creates a session from `config` and loads `source` into it.
    ///
    /// The returned module pointer is not reference-counted here, so `session`
    /// must outlive every use of it. Returns nullptr on failure.
    slang::IModule* open_module(
        const session_config& config,
        const std::string& source,
        const std::string& filename,
        Slang::ComPtr<slang::ISession>& session
    ) {
        if (SLANG_FAILED(global_session->createSession(config.session_desc, session.writeRef())) || !session) {
            return nullptr;
        }
        Slang::ComPtr<slang::IBlob> diagnostics;
        return session->loadModuleFromSourceString(
            filename.c_str(), filename.c_str(), source.c_str(), diagnostics.writeRef());
    }

    /// Composes `entry` with `module`, links it and generates target code.
    ///
    /// Code generation mirrors what compile() does, so only entry points that
    /// can actually be compiled are reported. Returns the name and stage from
    /// the linked layout, or nullopt if any step fails.
    static std::optional<discovered_entry> validate_entry_point(
        slang::ISession* session,
        slang::IModule* module,
        slang::IEntryPoint* entry
    ) {
        if (!entry) {
            return std::nullopt;
        }
        Slang::ComPtr<slang::IBlob> diagnostics;

        std::vector<slang::IComponentType*> components = {module, entry};
        Slang::ComPtr<slang::IComponentType> composite;
        if (SLANG_FAILED(session->createCompositeComponentType(
                components.data(),
                static_cast<SlangInt>(components.size()),
                composite.writeRef(),
                diagnostics.writeRef()))) {
            return std::nullopt;
        }

        Slang::ComPtr<slang::IComponentType> linked;
        if (SLANG_FAILED(composite->link(linked.writeRef(), diagnostics.writeRef()))) {
            return std::nullopt;
        }

        Slang::ComPtr<slang::IBlob> code;
        if (SLANG_FAILED(linked->getEntryPointCode(0, 0, code.writeRef(), diagnostics.writeRef()))) {
            return std::nullopt;
        }

        slang::ProgramLayout* layout = linked->getLayout();
        if (!layout || layout->getEntryPointCount() == 0) {
            return std::nullopt;
        }
        slang::EntryPointReflection* reflection = layout->getEntryPointByIndex(0);
        if (!reflection) {
            return std::nullopt;
        }
        const char* name = reflection->getName();
        if (!name || name[0] == '\0') {
            return std::nullopt;
        }
        return discovered_entry{name, reflection->getStage()};
    }

    /// Collects validated, name-unique entry points of `module`.
    ///
    /// Two passes are made:
    /// 1. Entry points the module defines (getDefinedEntryPoint). These are
    ///    presumably those carrying a [shader("...")] attribute; confirm for
    ///    the Slang version in use.
    /// 2. Functions with conventional entry-point names, checked against each
    ///    stage in turn; the first stage that validates wins.
    ///
    /// Discovery stops once `limit` entries have been found.
    static std::vector<discovered_entry> discover_entry_points(
        slang::ISession* session,
        slang::IModule* module,
        std::size_t limit = std::numeric_limits<std::size_t>::max()
    ) {
        std::vector<discovered_entry> found;
        const auto is_known = [&found](const char* name) {
            return std::any_of(found.begin(), found.end(),
                               [name](const discovered_entry& e) { return e.name == name; });
        };
        const auto add = [&](slang::IEntryPoint* entry) {
            if (auto info = validate_entry_point(session, module, entry);
                info && !is_known(info->name.c_str())) {
                found.push_back(std::move(*info));
            }
        };

        const SlangInt32 defined_count = module->getDefinedEntryPointCount();
        for (SlangInt32 i = 0; i < defined_count && found.size() < limit; ++i) {
            Slang::ComPtr<slang::IEntryPoint> entry;
            if (SLANG_SUCCEEDED(module->getDefinedEntryPoint(i, entry.writeRef()))) {
                add(entry.get());
            }
        }

        static constexpr const char* common_names[] = {
            "main", "vertexMain", "fragmentMain", "computeMain",
            "vertexShader", "fragmentShader", "computeShader"
        };
        static constexpr SlangStage stages[] = {
            SLANG_STAGE_VERTEX,
            SLANG_STAGE_FRAGMENT,
            SLANG_STAGE_COMPUTE,
            SLANG_STAGE_GEOMETRY,
            SLANG_STAGE_HULL,
            SLANG_STAGE_DOMAIN
        };

        for (const char* name : common_names) {
            for (SlangStage stage : stages) {
                // A name already recorded would only be de-duplicated again.
                if (found.size() >= limit || is_known(name)) {
                    break;
                }
                Slang::ComPtr<slang::IEntryPoint> entry;
                Slang::ComPtr<slang::IBlob> diagnostics;
                if (SLANG_FAILED(module->findAndCheckEntryPoint(
                        name, stage, entry.writeRef(), diagnostics.writeRef()))) {
                    continue;
                }
                add(entry.get());
            }
        }
        return found;
    }

    /// True when get_entry_points() finds at least one entry point.
    std::optional<bool> has_entry_point(
        const std::string& source,
        const std::string& filename,
        const compile_options& options
    ) {
        return !get_entry_points(source, filename, options).empty();
    }

    /// Names of all discoverable entry points (both discovery passes, no cap).
    /// Returns an empty list if the session or module cannot be created.
    std::vector<std::string> get_entry_points(
        const std::string& source,
        const std::string& filename,
        const compile_options& options
    ) {
        std::vector<std::string> names;
        if (!global_session) {
            return names;
        }

        session_config config(global_session.get(), options);
        Slang::ComPtr<slang::ISession> session;
        slang::IModule* module = open_module(config, source, filename, session);
        if (!module) {
            return names;
        }

        for (auto& entry : discover_entry_points(session.get(), module)) {
            names.push_back(std::move(entry.name));
        }
        return names;
    }

    /// Entry points with their stages.
    ///
    /// Returns nothing when the module defines no entry points, and stops
    /// after as many entries as the module defines. Entries whose Slang stage
    /// has no shader_stage equivalent are dropped.
    std::vector<entry_point_info> get_entry_points_with_stages(
        const std::string& source,
        const std::string& filename,
        const compile_options& options
    ) {
        std::vector<entry_point_info> result;
        if (!global_session) {
            return result;
        }

        session_config config(global_session.get(), options);
        Slang::ComPtr<slang::ISession> session;
        slang::IModule* module = open_module(config, source, filename, session);
        if (!module) {
            return result;
        }

        const SlangInt32 defined_count = module->getDefinedEntryPointCount();
        if (defined_count <= 0) {
            return result;
        }

        for (const auto& entry : discover_entry_points(session.get(), module, static_cast<std::size_t>(defined_count))) {
            if (auto stage = slang_stage_to_shader_stage(entry.stage)) {
                entry_point_info info;
                info.name = entry.name;
                info.stage = *stage;
                result.push_back(std::move(info));
            }
        }
        return result;
    }

    /// Compiles every entry point reported by get_entry_points_with_stages().
    /// Each result is tagged with its entry point name and stage.
    std::vector<compile_result> compile_all_entries(
        const std::string& source,
        const std::string& filename,
        const compile_options& options
    ) {
        const auto entry_points = get_entry_points_with_stages(source, filename, options);

        std::vector<compile_result> results;
        results.reserve(entry_points.size());
        for (const auto& ep_info : entry_points) {
            compile_options ep_options = options;
            ep_options.entry_point = ep_info.name;
            ep_options.stage = ep_info.stage;

            compile_result result = compile(source, filename, ep_options);
            result.entry_point = ep_info.name;
            result.stage = ep_info.stage;
            results.push_back(std::move(result));
        }
        return results;
    }
};
|
||
|
||
// ============================================================================
|
||
// shader_compiler implementation
|
||
// ============================================================================
|
||
|
||
shader_compiler::shader_compiler()
|
||
: impl_(std::make_unique<impl>()) {
|
||
}
|
||
|
||
shader_compiler::~shader_compiler() = default;
|
||
|
||
shader_compiler::shader_compiler(shader_compiler&&) noexcept = default;
|
||
shader_compiler& shader_compiler::operator=(shader_compiler&&) noexcept = default;
|
||
|
||
/// Reads a shader file and compiles it.
///
/// The file's own directory is placed first in the include search list, so
/// relative includes and imports resolve against it.
///
/// @return The compile result, or a failed result naming the file when it
///         cannot be opened.
compile_result shader_compiler::compile_file(
    const std::filesystem::path& path,
    const compile_options& options
) {
    std::ifstream input{path};
    if (!input) {
        compile_result failure;
        failure.error_message = "Failed to open file: " + path.string();
        return failure;
    }

    std::ostringstream contents;
    contents << input.rdbuf();

    compile_options file_options = options;
    auto& includes = file_options.include_paths;
    includes.insert(includes.begin(), path.parent_path());

    return impl_->compile(contents.str(), path.filename().string(), file_options);
}
|
||
|
||
/// Compiles in-memory shader source.
///
/// Unlike compile_file(), no directory is implicitly added to the include
/// paths; `filename` is used only as the module name and path.
compile_result shader_compiler::compile_source(
    const std::string& source,
    const std::string& filename,
    const compile_options& options
) {
    compile_result outcome = impl_->compile(source, filename, options);
    return outcome;
}
|
||
|
||
std::optional<std::string> shader_compiler::generate_reflection(
|
||
const std::filesystem::path& path,
|
||
const compile_options& options
|
||
) {
|
||
compile_options reflection_options = options;
|
||
reflection_options.emit_spirv = false;
|
||
reflection_options.emit_reflection = true;
|
||
|
||
auto result = compile_file(path, reflection_options);
|
||
if (result.success) {
|
||
return result.reflection_json;
|
||
}
|
||
return std::nullopt;
|
||
}
|
||
|
||
std::optional<bool> shader_compiler::has_entry_point(
|
||
const std::filesystem::path& path,
|
||
const compile_options& options
|
||
) {
|
||
// 首先检查文件是否存在
|
||
if (!std::filesystem::exists(path)) {
|
||
return std::nullopt; // 无法检查
|
||
}
|
||
|
||
// 读取文件内容
|
||
std::ifstream file(path);
|
||
if (!file) {
|
||
return std::nullopt; // 无法检查
|
||
}
|
||
|
||
std::ostringstream ss;
|
||
ss << file.rdbuf();
|
||
std::string source = ss.str();
|
||
|
||
return impl_->has_entry_point(source, path.filename().string(), options);
|
||
}
|
||
|
||
std::vector<std::string> shader_compiler::get_entry_points(
|
||
const std::filesystem::path& path,
|
||
const compile_options& options
|
||
) {
|
||
// 首先检查文件是否存在
|
||
if (!std::filesystem::exists(path)) {
|
||
return {};
|
||
}
|
||
|
||
// 读取文件内容
|
||
std::ifstream file(path);
|
||
if (!file) {
|
||
return {};
|
||
}
|
||
|
||
std::ostringstream ss;
|
||
ss << file.rdbuf();
|
||
std::string source = ss.str();
|
||
|
||
return impl_->get_entry_points(source, path.filename().string(), options);
|
||
}
|
||
|
||
std::vector<entry_point_info> shader_compiler::get_entry_points_with_stages(
|
||
const std::filesystem::path& path,
|
||
const compile_options& options
|
||
) {
|
||
// 首先检查文件是否存在
|
||
if (!std::filesystem::exists(path)) {
|
||
return {};
|
||
}
|
||
|
||
// 读取文件内容
|
||
std::ifstream file(path);
|
||
if (!file) {
|
||
return {};
|
||
}
|
||
|
||
std::ostringstream ss;
|
||
ss << file.rdbuf();
|
||
std::string source = ss.str();
|
||
|
||
return impl_->get_entry_points_with_stages(source, path.filename().string(), options);
|
||
}
|
||
|
||
std::vector<compile_result> shader_compiler::compile_file_all_entries(
|
||
const std::filesystem::path& path,
|
||
const compile_options& options
|
||
) {
|
||
// 首先检查文件是否存在
|
||
if (!std::filesystem::exists(path)) {
|
||
return {};
|
||
}
|
||
|
||
// 读取文件内容
|
||
std::ifstream file(path);
|
||
if (!file) {
|
||
return {};
|
||
}
|
||
|
||
std::ostringstream ss;
|
||
ss << file.rdbuf();
|
||
std::string source = ss.str();
|
||
|
||
// 添加文件所在目录到包含路径
|
||
compile_options modified_options = options;
|
||
modified_options.include_paths.insert(
|
||
modified_options.include_paths.begin(),
|
||
path.parent_path()
|
||
);
|
||
|
||
return impl_->compile_all_entries(source, path.filename().string(), modified_options);
|
||
}
|
||
|
||
bool shader_compiler::is_available() const noexcept {
|
||
return impl_ && impl_->global_session;
|
||
}
|
||
|
||
std::string shader_compiler::get_version() const {
|
||
if (!impl_ || !impl_->global_session) {
|
||
return "unavailable";
|
||
}
|
||
// Slang 没有直接的版本 API,返回固定字符串
|
||
return "Slang Shader Compiler";
|
||
}
|
||
|
||
} // namespace mirai::tools
|