Raw File
addMissingAwait.ts
/* @internal */
namespace ts.codefix {
    /**
     * Runs `cb` against a change tracker and returns the resulting file edits.
     * In `getCodeActions` a fresh tracker is created per action via `textChanges.ChangeTracker.with`;
     * in `getAllCodeActions` `cb` writes into the shared fix-all tracker and an empty array is returned.
     */
    type ContextualTrackChangesFunction = (cb: (changeTracker: textChanges.ChangeTracker) => void) => FileTextChanges[];
    const fixId = "addMissingAwait";
    /** Special-cased in `makeChange`: the object of the property access is rewritten as `(await x)`. */
    const propertyAccessCode = Diagnostics.Property_0_does_not_exist_on_type_1.code;
    /** Special-cased in `makeChange`: when the parent is a call or `new` expression, the site is rewritten as `(await f)`. */
    const callableConstructableErrorCodes = [
        Diagnostics.This_expression_is_not_callable.code,
        Diagnostics.This_expression_is_not_constructable.code,
    ];
    /**
     * Diagnostics this fix registers for. An action is only produced when the checker also attached the
     * "Did you forget to use 'await'?" related information to the diagnostic (see `isMissingAwaitError`).
     */
    const errorCodes = [
        Diagnostics.An_arithmetic_operand_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.The_left_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.The_right_hand_side_of_an_arithmetic_operation_must_be_of_type_any_number_bigint_or_an_enum_type.code,
        Diagnostics.Operator_0_cannot_be_applied_to_type_1.code,
        Diagnostics.Operator_0_cannot_be_applied_to_types_1_and_2.code,
        Diagnostics.This_condition_will_always_return_0_since_the_types_1_and_2_have_no_overlap.code,
        Diagnostics.Type_0_is_not_an_array_type.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_Use_compiler_option_downlevelIteration_to_allow_iterating_of_iterators.code,
        Diagnostics.Type_0_is_not_an_array_type_or_a_string_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_is_not_an_array_type_or_does_not_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_must_have_a_Symbol_iterator_method_that_returns_an_iterator.code,
        Diagnostics.Type_0_must_have_a_Symbol_asyncIterator_method_that_returns_an_async_iterator.code,
        Diagnostics.Argument_of_type_0_is_not_assignable_to_parameter_of_type_1.code,
        propertyAccessCode,
        ...callableConstructableErrorCodes,
    ];

    registerCodeFix({
        fixIds: [fixId],
        errorCodes,
        // Offers up to two actions for one diagnostic: add `await` to the variable's initializer
        // (declaration site, when eligible) and add `await` at the flagged expression (use site).
        getCodeActions: context => {
            const { sourceFile, errorCode, span, cancellationToken, program } = context;
            const expression = getAwaitableExpression(sourceFile, errorCode, span, cancellationToken, program);
            if (!expression) {
                return;
            }

            const checker = program.getTypeChecker();
            const trackChanges: ContextualTrackChangesFunction = cb => textChanges.ChangeTracker.with(context, cb);
            return compact([
                getDeclarationSiteFix(context, expression, errorCode, checker, trackChanges),
                getUseSiteFix(context, expression, errorCode, checker, trackChanges)]);
        },
        // Fix-all applies at most one change per diagnostic and prefers the declaration-site fix
        // when one is available (the `||` short-circuits before the use-site fix runs).
        getAllCodeActions: context => {
            const { sourceFile, program, cancellationToken } = context;
            const checker = program.getTypeChecker();
            return codeFixAll(context, errorCodes, (t, diagnostic) => {
                const expression = getAwaitableExpression(sourceFile, diagnostic.code, diagnostic, cancellationToken, program);
                if (!expression) {
                    return;
                }
                // Edits are written directly into the shared tracker `t`, so each action's own change list is empty.
                const trackChanges: ContextualTrackChangesFunction = cb => (cb(t), []);
                return getDeclarationSiteFix(context, expression, diagnostic.code, checker, trackChanges)
                    || getUseSiteFix(context, expression, diagnostic.code, checker, trackChanges);
            });
        },
    });

    /**
     * Builds the "Add 'await' to initializer for …" action when `expression` names a variable whose
     * initializer is eligible (see `findAwaitableInitializer`). Otherwise returns `undefined`.
     * The action is created without a fix-all id (`createCodeFixActionNoFixId`).
     */
    function getDeclarationSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction) {
        const { sourceFile } = context;
        const awaitableInitializer = findAwaitableInitializer(expression, sourceFile, checker);
        if (awaitableInitializer) {
            const initializerChanges = trackChanges(t => makeChange(t, errorCode, sourceFile, checker, awaitableInitializer));
            return createCodeFixActionNoFixId(
                "addMissingAwaitToInitializer",
                initializerChanges,
                [Diagnostics.Add_await_to_initializer_for_0, expression.getText(sourceFile)]);
        }
    }

    /** Builds the "Add 'await'" action that inserts `await` at the flagged expression itself. */
    function getUseSiteFix(context: CodeFixContext | CodeFixAllContext, expression: Expression, errorCode: number, checker: TypeChecker, trackChanges: ContextualTrackChangesFunction) {
        const changes = trackChanges(t => makeChange(t, errorCode, context.sourceFile, checker, expression));
        return createCodeFixAction(fixId, changes, Diagnostics.Add_await, fixId, Diagnostics.Fix_all_expressions_possibly_missing_await);
    }

    /**
     * Returns true if the diagnostics-producing checker reported, for `sourceFile`, a diagnostic with code
     * `errorCode` and exactly the range `span`, and that diagnostic carries the
     * "Did you forget to use 'await'?" related information.
     */
    function isMissingAwaitError(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program) {
        const checker = program.getDiagnosticsProducingTypeChecker();
        const diagnostics = checker.getDiagnostics(sourceFile, cancellationToken);
        return some(diagnostics, ({ start, length, relatedInformation, code }) =>
            isNumber(start) && isNumber(length) && textSpansEqual({ start, length }, span) &&
            code === errorCode &&
            !!relatedInformation &&
            some(relatedInformation, related => related.code === Diagnostics.Did_you_forget_to_use_await.code));
    }

    /**
     * Finds the expression whose range exactly matches the diagnostic `span`. The search walks up from the
     * token at `span.start` and stops as soon as an ancestor extends outside the span. The expression is
     * returned only if the diagnostic is a "missing await" error and the expression is inside an awaitable body.
     */
    function getAwaitableExpression(sourceFile: SourceFile, errorCode: number, span: TextSpan, cancellationToken: CancellationToken, program: Program): Expression | undefined {
        const token = getTokenAtPosition(sourceFile, span.start);
        // Checker has already done work to determine that await might be possible, and has attached
        // related info to the node, so start by finding the expression that exactly matches up
        // with the diagnostic range.
        const expression = findAncestor(token, node => {
            if (node.getStart(sourceFile) < span.start || node.getEnd() > textSpanEnd(span)) {
                return "quit";
            }
            return isExpression(node) && textSpansEqual(span, createTextSpanFromNode(node, sourceFile));
        }) as Expression | undefined;

        return expression
            && isMissingAwaitError(sourceFile, errorCode, span, cancellationToken, program)
            && isInsideAwaitableBody(expression)
                ? expression
                : undefined;
    }

    /**
     * If `expression` is an identifier bound to a variable declaration, returns that declaration's
     * initializer when all of the following hold:
     * - the declaration has an identifier name, an initializer, and no type annotation;
     * - its variable statement is in `sourceFile` and is not exported;
     * - the initializer is inside an awaitable body;
     * - the variable is not referenced anywhere else in `sourceFile` besides `expression`.
     * Otherwise returns `undefined`.
     */
    function findAwaitableInitializer(expression: Node, sourceFile: SourceFile, checker: TypeChecker): Expression | undefined {
        if (!isIdentifier(expression)) {
            return;
        }

        const symbol = checker.getSymbolAtLocation(expression);
        if (!symbol) {
            return;
        }

        const declaration = tryCast(symbol.valueDeclaration, isVariableDeclaration);
        const variableName = tryCast(declaration && declaration.name, isIdentifier);
        const variableStatement = getAncestor(declaration, SyntaxKind.VariableStatement);
        if (!declaration || !variableStatement ||
            declaration.type ||
            !declaration.initializer ||
            variableStatement.getSourceFile() !== sourceFile ||
            hasModifier(variableStatement, ModifierFlags.Export) ||
            !variableName ||
            !isInsideAwaitableBody(declaration.initializer)) {
            return;
        }

        // Presumably truthy as soon as the callback returns true for some reference, i.e. when any
        // reference other than `expression` exists in this file.
        const isUsedElsewhere = FindAllReferences.Core.eachSymbolReferenceInFile(variableName, checker, sourceFile, identifier => {
            return identifier !== expression;
        });

        if (isUsedElsewhere) {
            return;
        }

        return declaration.initializer;
    }

    /**
     * Returns true if `await` could be placed at `node`. That is the case when the parser marked `node` as
     * being in an await context, or when `node` lies within a function declaration, function expression,
     * arrow function, or method body (including an arrow function's expression body).
     * Note: this does not require the containing function to already be `async`.
     */
    function isInsideAwaitableBody(node: Node): boolean {
        // `AwaitContext` is a `NodeFlags` bit, so it must be tested against `node.flags`.
        // Masking `node.kind` (a `SyntaxKind`) with it never actually tested the flag.
        return !!(node.flags & NodeFlags.AwaitContext) || !!findAncestor(node, ancestor =>
            ancestor.parent && isArrowFunction(ancestor.parent) && ancestor.parent.body === ancestor ||
            isBlock(ancestor) && (
                ancestor.parent.kind === SyntaxKind.FunctionDeclaration ||
                ancestor.parent.kind === SyntaxKind.FunctionExpression ||
                ancestor.parent.kind === SyntaxKind.ArrowFunction ||
                ancestor.parent.kind === SyntaxKind.MethodDeclaration));
    }

    /**
     * Records the edit that inserts `await` at `insertionSite`:
     * - binary expression: each operand whose type is a promise (per `getPromisedTypeOfPromise`) is wrapped in `await`;
     * - property-access error: the object of the parent property access becomes `(await obj)`;
     * - not-callable/not-constructable error under a call or `new`: the site becomes `(await site)`;
     * - otherwise: the site becomes `await site`.
     */
    function makeChange(changeTracker: textChanges.ChangeTracker, errorCode: number, sourceFile: SourceFile, checker: TypeChecker, insertionSite: Expression) {
        if (isBinaryExpression(insertionSite)) {
            const { left, right } = insertionSite;
            const leftType = checker.getTypeAtLocation(left);
            const rightType = checker.getTypeAtLocation(right);
            const newLeft = checker.getPromisedTypeOfPromise(leftType) ? createAwait(left) : left;
            const newRight = checker.getPromisedTypeOfPromise(rightType) ? createAwait(right) : right;
            changeTracker.replaceNode(sourceFile, left, newLeft);
            changeTracker.replaceNode(sourceFile, right, newRight);
        }
        else if (errorCode === propertyAccessCode && isPropertyAccessExpression(insertionSite.parent)) {
            // Await the object of the access (`parent.expression`), not the site itself, because the
            // diagnostic may point at the property name. Parentheses keep `await` from applying to the
            // whole access.
            changeTracker.replaceNode(
                sourceFile,
                insertionSite.parent.expression,
                createParen(createAwait(insertionSite.parent.expression)));
        }
        else if (contains(callableConstructableErrorCodes, errorCode) && isCallOrNewExpression(insertionSite.parent)) {
            // Parenthesize so the awaited value, rather than the call result, is what gets invoked.
            changeTracker.replaceNode(sourceFile, insertionSite, createParen(createAwait(insertionSite)));
        }
        else {
            changeTracker.replaceNode(sourceFile, insertionSite, createAwait(insertionSite));
        }
    }
}
back to top