https://github.com/shader-slang/slang
Tip revision: ec530b300524635dfe0fd86949b0a4fc5c19a984 authored by Yong He on 27 April 2022, 20:58:55 UTC
gfx: Add interop API to control descriptor heap binding. (#2211)
gfx: Add interop API to control descriptor heap binding. (#2211)
Tip revision: ec530b3
slang-ir-dll-import.cpp
// slang-ir-dll-import.cpp
#include "slang-ir-dll-import.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
namespace Slang
{
struct DllImportContext
{
IRModule* module;
DiagnosticSink* diagnosticSink;
SharedIRBuilder sharedBuilder;
IRFunc* loadDllFunc = nullptr;
IRFunc* loadFuncPtrFunc = nullptr;
IRFunc* stringGetBufferFunc = nullptr;
IRFunc* createBuiltinIntrinsicFunc(UInt paramCount, IRType** paramTypes, IRType* resultType, UnownedStringSlice targetIntrinsic)
{
IRBuilder builder(sharedBuilder);
builder.setInsertInto(module->getModuleInst());
IRFunc* result = builder.createFunc();
builder.setInsertInto(result);
auto funcType = builder.getFuncType(paramCount, paramTypes, resultType);
builder.setDataType(result, funcType);
builder.addTargetIntrinsicDecoration(
result, CapabilitySet(CapabilityAtom::CPP), targetIntrinsic);
return result;
}
IRFunc* getLoadDllFunc()
{
if (!loadDllFunc)
{
IRBuilder builder(sharedBuilder);
builder.setInsertInto(module->getModuleInst());
IRType* stringType = builder.getStringType();
loadDllFunc = createBuiltinIntrinsicFunc(
1,
&stringType,
builder.getPtrType(builder.getVoidType()),
UnownedStringSlice("_slang_rt_load_dll($0)"));
}
return loadDllFunc;
}
IRFunc* getLoadFuncPtrFunc()
{
if (!loadFuncPtrFunc)
{
IRBuilder builder(sharedBuilder);
builder.setInsertInto(module->getModuleInst());
IRType* stringType = builder.getStringType();
IRType* paramTypes[] = {builder.getPtrType(builder.getVoidType()), stringType};
loadFuncPtrFunc = createBuiltinIntrinsicFunc(
2,
paramTypes,
builder.getPtrType(builder.getVoidType()),
UnownedStringSlice("_slang_rt_load_dll_func($0, $1)"));
}
return loadFuncPtrFunc;
}
IRFunc* getStringGetBufferFunc()
{
if (!stringGetBufferFunc)
{
IRBuilder builder(sharedBuilder);
builder.setInsertInto(module->getModuleInst());
IRType* stringType = builder.getStringType();
IRType* paramTypes[] = {stringType};
stringGetBufferFunc = createBuiltinIntrinsicFunc(
1,
paramTypes,
builder.getPtrType(builder.getCharType()),
UnownedStringSlice("const_cast<char*>($0.getBuffer())"));
}
return stringGetBufferFunc;
}
IRType* getNativeType(IRBuilder& builder, IRType* type)
{
switch (type->getOp())
{
case kIROp_StringType:
return builder.getPtrType(builder.getCharType());
default:
return type;
}
}
IRType* getNativeFuncType(IRBuilder& builder, IRFunc* func)
{
List<IRInst*> nativeParamTypes;
auto declaredFuncType = func->getDataType();
assert(declaredFuncType->getOp() == kIROp_FuncType);
for (UInt i = 0; i < declaredFuncType->getParamCount(); ++i)
{
auto paramType = declaredFuncType->getParamType(i);
nativeParamTypes.add(getNativeType(builder, as<IRType>(paramType)));
}
IRType* returnType = getNativeType(builder, func->getResultType());
auto funcType = builder.getFuncType(
nativeParamTypes.getCount(), (IRType**)nativeParamTypes.getBuffer(), returnType);
return funcType;
}
void marshalImportRefParameter(IRBuilder& builder, IRParam* param, List<IRInst*>& args)
{
SLANG_UNUSED(builder);
auto innerType = as<IRPtrTypeBase>(param->getDataType())->getValueType();
switch (innerType->getOp())
{
case kIROp_StringType:
{
diagnosticSink->diagnose(
param->sourceLoc,
Diagnostics::invalidTypeMarshallingForImportedDLLSymbol,
param->getParent()->getParent());
}
break;
default:
args.add(param);
break;
}
}
void marshalImportParameter(IRBuilder& builder, IRParam* param, List<IRInst*>& args)
{
switch (param->getDataType()->getOp())
{
case kIROp_InOutType:
case kIROp_RefType:
return marshalImportRefParameter(builder, param, args);
case kIROp_StringType:
{
auto getStringBufferFunc = getStringGetBufferFunc();
args.add(builder.emitCallInst(
builder.getPtrType(builder.getCharType()), getStringBufferFunc, 1, (IRInst**)¶m));
}
break;
default:
args.add(param);
break;
}
}
void processFunc(IRFunc* func, IRDllImportDecoration* dllImportDecoration)
{
assert(func->getFirstBlock() == nullptr);
IRBuilder builder(sharedBuilder);
auto nativeType = getNativeFuncType(builder, func);
builder.setInsertInto(module->getModuleInst());
auto funcPtr = builder.createGlobalVar(nativeType);
builder.setInsertInto(funcPtr);
builder.emitBlock();
builder.emitReturn(builder.getPtrValue(nullptr));
builder.setInsertInto(func);
auto block = builder.emitBlock();
builder.setInsertInto(block);
// Emit parameters.
auto declaredFuncType = func->getDataType();
List<IRParam*> params;
for (UInt i = 0; i < declaredFuncType->getParamCount(); ++i)
{
auto paramType = declaredFuncType->getParamType(i);
params.add(builder.emitParam((IRType*)paramType));
}
// Marshal parameters to arguments into native func.
List<IRInst*> args;
for (auto param : params)
{
marshalImportParameter(builder, param, args);
}
IRInst* cmpArgs[] = {builder.emitLoad(nativeType, funcPtr), builder.getPtrValue(nullptr)};
auto isUninitialized =
builder.emitIntrinsicInst(builder.getBoolType(), kIROp_Eql, 2, cmpArgs);
auto trueBlock = builder.emitBlock();
auto afterBlock = builder.emitBlock();
builder.setInsertInto(block);
builder.emitIf(isUninitialized, trueBlock, afterBlock);
builder.setInsertInto(trueBlock);
auto modulePtr = builder.emitCallInst(
builder.getPtrType(builder.getVoidType()),
getLoadDllFunc(),
builder.getStringValue(dllImportDecoration->getLibraryName()));
IRInst* loadDllFuncArgs[] = {
modulePtr, builder.getStringValue(dllImportDecoration->getFunctionName())};
auto loadedNativeFuncPtr = builder.emitCallInst(
builder.getPtrType(builder.getVoidType()), getLoadFuncPtrFunc(), 2, loadDllFuncArgs);
builder.emitStore(
funcPtr, builder.emitBitCast(nativeType, loadedNativeFuncPtr));
builder.emitBranch(afterBlock);
builder.setInsertInto(afterBlock);
IRType* nativeReturnType = getNativeType(builder, func->getResultType());
auto nativeFunc = builder.emitLoad(funcPtr);
auto call = builder.emitCallInst(nativeReturnType, nativeFunc, args);
if (declaredFuncType->getResultType()->getOp() != kIROp_VoidType)
{
builder.emitReturn(call);
}
}
void processModule()
{
for (auto childFunc : module->getGlobalInsts())
{
switch(childFunc->getOp())
{
case kIROp_Func:
if (auto dllImportDecoration = childFunc->findDecoration<IRDllImportDecoration>())
{
processFunc(as<IRFunc>(childFunc), dllImportDecoration);
}
break;
default:
break;
}
}
}
};
/// Entry point of the DLL-import pass: synthesizes bodies for every function in
/// `module` that is decorated as a DLL import, reporting problems to `sink`.
void generateDllImportFuncs(IRModule* module, DiagnosticSink* sink)
{
    // All pass state lives in a context that exists only for this call.
    DllImportContext pass;
    pass.sharedBuilder.init(module);
    pass.diagnosticSink = sink;
    pass.module = module;

    pass.processModule();
}
}