#include "slang-ir-pytorch-cpp-binding.h"
#include "slang-ir.h"
#include "slang-ir-insts.h"
#include "slang-diagnostics.h"

namespace Slang
{

// NOTE(review): in this copy of the file the explicit template arguments look
// like they were lost (`as(...)` and `List` appear without any `<...>`).
// The tokens are kept exactly as found here; restore the arguments from the
// authoritative source before building.

// Convert a type to a target tuple type.
//
// Returns `type` unchanged when it matches one of the first three `as(...)`
// checks. For the `vectorType` branch the result is a tuple holding
// `count` copies of the element type; for the `arrayType` branch it is a
// tuple holding `arraySize` copies of the recursively translated element
// type; for the `structType` branch it is a tuple of the recursively
// translated field types, in field order.
//
// Returns nullptr when the type cannot be translated: the element count does
// not pass the `as(...)` check on it, a nested element/field type cannot be
// translated, or the type matches none of the branches. Callers are expected
// to check for nullptr.
static IRType* translateToTupleType(IRBuilder& builder, IRType* type)
{
    if (as(type))
        return type;
    if (as(type))
        return type;
    else if (as(type))
        return type;
    else if (auto vectorType = as(type))
    {
        auto count = as(vectorType->getElementCount());
        if (!count)
        {
            return nullptr;
        }
        List elementTypes;
        for (IRIntegerValue i = 0; i < count->getValue(); i++)
        {
            elementTypes.addRange(vectorType->getElementType());
        }
        return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
    }
    else if (auto arrayType = as(type))
    {
        auto arraySize = as(arrayType->getElementCount());
        if (!arraySize)
        {
            return nullptr;
        }
        List subElementTypes;
        auto subElementType = translateToTupleType(builder, arrayType->getElementType());
        // An untranslatable element type makes the whole array untranslatable.
        // Without this check a null type would be placed into the tuple's
        // element list instead of the failure being reported to the caller
        // (the struct branch below already propagates failure the same way).
        if (!subElementType)
        {
            return nullptr;
        }
        for (IRIntegerValue i = 0; i < arraySize->getValue(); i++)
        {
            subElementTypes.addRange(subElementType);
        }
        return builder.getTargetTupleType((UInt)subElementTypes.getCount(), subElementTypes.getBuffer());
    }
    else if (auto structType = as(type))
    {
        List elementTypes;
        for (auto field : structType->getFields())
        {
            auto fieldType = translateToTupleType(builder, field->getFieldType());
            if (!fieldType)
            {
                return nullptr;
            }
            elementTypes.addRange(fieldType);
        }
        return builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer());
    }
    else
    {
        return nullptr;
    }
}

// Convert a value to a target tuple type.
static IRInst* makeTargetTuple(IRBuilder& builder, IRInst* val) { auto type = val->getDataType(); if (as(type)) return val; if (as(type)) return val; else if (as(type)) return val; else if (auto vectorType = as(type)) { auto count = as(vectorType->getElementCount()); if (!count) { return nullptr; } List resultElements; List elementTypes; for (IRIntegerValue i = 0; i < count->getValue(); i++) { auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); auto tupleElement = makeTargetTuple(builder, elementVal); if (!tupleElement) return nullptr; resultElements.add(tupleElement); elementTypes.add(tupleElement->getFullType()); } auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else if (auto arrayType = as(type)) { auto arraySize = as(arrayType->getElementCount()); if (!arraySize) { return nullptr; } List resultElements; List elementTypes; for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { auto elementVal = builder.emitElementExtract(val, builder.getIntValue(builder.getIntType(), i)); auto tupleElement = makeTargetTuple(builder, elementVal); if (!tupleElement) return nullptr; resultElements.add(tupleElement); elementTypes.add(tupleElement->getFullType()); } auto resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else if (auto structType = as(type)) { List resultElements; List elementTypes; for (auto field : structType->getFields()) { auto elementVal = builder.emitFieldExtract(field->getFieldType(), val, field->getKey()); auto tupleElement = makeTargetTuple(builder, elementVal); if (!tupleElement) return nullptr; resultElements.add(tupleElement); elementTypes.add(tupleElement->getFullType()); } auto 
resultType = builder.getTargetTupleType((UInt)elementTypes.getCount(), elementTypes.getBuffer()); return builder.emitMakeTargetTuple(resultType, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else { return nullptr; } } // Convert a target tuple type to a value. static IRInst* makeValueFromTargetTuple(IRBuilder& builder, IRType* type, IRInst* val) { if (as(type)) return val; if (as(type)) return val; else if (as(type)) return val; else if (auto vectorType = as(type)) { auto count = as(vectorType->getElementCount()); if (!count) { return nullptr; } List resultElements; auto elementType = vectorType->getElementType(); for (IRIntegerValue i = 0; i < count->getValue(); i++) { auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); if (!convertedElement) return nullptr; resultElements.add(convertedElement); } return builder.emitMakeVector(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else if (auto arrayType = as(type)) { auto arraySize = as(arrayType->getElementCount()); if (!arraySize) { return nullptr; } List resultElements; auto elementType = arrayType->getElementType(); for (IRIntegerValue i = 0; i < arraySize->getValue(); i++) { auto tupleElement = builder.emitTargetTupleGetElement(elementType, val, builder.getIntValue(builder.getIntType(), i)); auto convertedElement = makeValueFromTargetTuple(builder, elementType, tupleElement); if (!convertedElement) return nullptr; resultElements.add(convertedElement); } return builder.emitMakeArray(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else if (auto structType = as(type)) { List resultElements; IRIntegerValue i = 0; for (auto field : structType->getFields()) { auto tupleElement = builder.emitTargetTupleGetElement(field->getFieldType(), val, builder.getIntValue(builder.getIntType(), i)); auto convertedElement = 
makeValueFromTargetTuple(builder, field->getFieldType(), tupleElement); if (!convertedElement) return nullptr; resultElements.add(convertedElement); i++; } return builder.emitMakeStruct(type, (UInt)resultElements.getCount(), resultElements.getBuffer()); } else { return nullptr; } } static void generateCppBindingForFunc(IRFunc* func, DiagnosticSink* sink) { IRBuilder builder(func); builder.setInsertBefore(func); auto hostReturnType = translateToTupleType(builder, func->getResultType()); if (!hostReturnType) { sink->diagnose(func->sourceLoc, Diagnostics::invalidTorchKernelReturnType, func->getResultType()); return; } List hostParamTypes; auto funcType = as(func->getDataType()); for (UInt i = 0; i < funcType->getParamCount(); i++) { hostParamTypes.add(translateToTupleType(builder, funcType->getParamType(i))); } auto bindingFuncType = builder.getFuncType(hostParamTypes, hostReturnType); func->setFullType(bindingFuncType); builder.setInsertBefore(func->getFirstBlock()->getFirstOrdinaryInst()); List instsToRemove; List oldParams; for (auto param : func->getFirstBlock()->getParams()) { oldParams.add(param); } List newParams; for (auto param : oldParams) { auto paramType = param->getFullType(); auto newParamType = translateToTupleType(builder, paramType); if (!newParamType) { sink->diagnose(param->sourceLoc, Diagnostics::invalidTorchKernelParamType, paramType); return; } auto newParam = builder.emitParam(newParamType); param->transferDecorationsTo(newParam); newParams.add(newParam); } // Convert all new parameters from tuples to their original types. 
for (Index i = 0; i < newParams.getCount(); i++) { auto oldParam = oldParams[i]; auto newParam = newParams[i]; auto convertedParam = makeValueFromTargetTuple(builder, oldParam->getFullType(), newParam); if (!convertedParam) { return; } oldParam->replaceUsesWith(convertedParam); oldParam->removeAndDeallocate(); } for (auto block : func->getBlocks()) { for (auto inst : block->getChildren()) { if (auto kernelDispatch = as(inst)) { builder.setInsertBefore(kernelDispatch); List kernelArgs; auto kernelArgCount = kernelDispatch->getArgCount(); auto argArrayType = builder.getArrayType(builder.getPtrType(builder.getVoidType()), builder.getIntValue(builder.getIntType(), kernelArgCount)); auto argArrayVar = builder.emitVar(argArrayType); for (UInt i = 0; i < kernelArgCount; i++) { auto arg = kernelDispatch->getArg(i); auto argVar = builder.emitVar(arg->getFullType()); builder.emitStore(argVar, arg); auto addr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), i)); builder.emitStore(addr, argVar); } auto argArrayPtr = builder.emitElementAddress(argArrayVar, builder.getIntValue(builder.getIntType(), 0)); builder.emitCudaKernelLaunch( kernelDispatch->getBaseFn(), kernelDispatch->getDispatchSize(), kernelDispatch->getThreadGroupSize(), argArrayPtr, builder.emitGetTorchCudaStream()); instsToRemove.add(inst); } else if (auto getView = as(inst)) { builder.setInsertBefore(getView); auto makeView = builder.emitMakeTensorView(getView->getFullType(), inst->getOperand(0)); getView->replaceUsesWith(makeView); instsToRemove.add(getView); } else if (auto ret = as(inst)) { builder.setInsertBefore(ret); auto retVal = makeTargetTuple(builder, ret->getVal()); ret->setOperand(0, retVal); } } } for (auto inst : instsToRemove) inst->removeAndDeallocate(); } void generatePyTorchCppBinding(IRModule* module, DiagnosticSink* sink) { List workList; List cudaKernels; for (auto globalInst : module->getGlobalInsts()) { auto func = as(globalInst); if (!func) continue; if 
(func->findDecoration()) { workList.add(func); } else if (func->findDecoration()) { cudaKernels.add(func); } else { // Remove all other export decorations if this is not a cuda host func. if (auto decor = func->findDecoration()) decor->removeAndDeallocate(); if (auto decor = func->findDecoration()) decor->removeAndDeallocate(); if (auto decor = func->findDecoration()) decor->removeAndDeallocate(); if (auto decor = func->findDecoration()) decor->removeAndDeallocate(); } } for (auto func : workList) generateCppBindingForFunc(func, sink); for (auto func : cudaKernels) { for (auto block = func->getFirstBlock(); block;) { auto nextBlock = block->getNextBlock(); block->removeAndDeallocate(); block = nextBlock; } } } // Remove all [TorchEntryPoint] functions when emitting CUDA source. void removeTorchKernels(IRModule* module) { List toRemove; for (auto globalInst : module->getGlobalInsts()) { if (!as(globalInst)) continue; if (globalInst->findDecoration()) toRemove.add(globalInst); } for (auto inst : toRemove) inst->removeAndDeallocate(); } }