diff --git a/packages/pyright-internal/src/analyzer/constraintSolver.ts b/packages/pyright-internal/src/analyzer/constraintSolver.ts index 25dcd0804775..5a9e0d1cf9cd 100644 --- a/packages/pyright-internal/src/analyzer/constraintSolver.ts +++ b/packages/pyright-internal/src/analyzer/constraintSolver.ts @@ -642,7 +642,7 @@ export function updateTypeVarType( } } - typeVarContext.setTypeVarType(destType, lowerBound, lowerBoundNoLiterals, upperBound); //, tupleTypes); + typeVarContext.setTypeVarType(destType, lowerBound, lowerBoundNoLiterals, upperBound); } function assignTypeToConstrainedTypeVar( diff --git a/packages/pyright-internal/src/analyzer/typeEvaluator.ts b/packages/pyright-internal/src/analyzer/typeEvaluator.ts index 58b9e8721158..24855cbd8c8f 100644 --- a/packages/pyright-internal/src/analyzer/typeEvaluator.ts +++ b/packages/pyright-internal/src/analyzer/typeEvaluator.ts @@ -418,7 +418,6 @@ interface MatchedOverloadInfo { interface ValidateArgTypeOptions { skipUnknownArgCheck?: boolean; - skipOverloadArg?: boolean; isArgFirstPass?: boolean; conditionFilter?: TypeCondition[]; skipReportError?: boolean; @@ -11453,7 +11452,6 @@ export function createTypeEvaluator( { type, isIncomplete: matchResults.isTypeIncomplete }, { skipUnknownArgCheck, - skipOverloadArg: i === 0, isArgFirstPass: passCount > 1 && i === 0, conditionFilter: typeCondition, skipReportError: true, @@ -11955,7 +11953,16 @@ export function createTypeEvaluator( // Assign the argument type back to the expected type to assign // values to any unification variables. const typeVarContextClone = typeVarContext.clone(); - if (assignType(expectedType, argType, /* diag */ undefined, typeVarContextClone)) { + if ( + assignType( + expectedType, + argType, + /* diag */ undefined, + typeVarContextClone, + /* srcTypeVarContext */ undefined, + options?.isArgFirstPass ? AssignTypeFlags.ArgAssignmentFirstPass : AssignTypeFlags.Default + ) + ) { typeVarContext.copyFromClone(typeVarContextClone); } else { isCompatible = false; @@ -12029,56 +12036,6 @@ export function createTypeEvaluator( } } - // If we are asked to skip overload arguments, determine whether the argument - // is an explicit overload type, an overloaded class constructor, or a - // an overloaded callback protocol. - if (options.skipOverloadArg) { - if (isOverloadedFunction(argType)) { - return { - isCompatible, - argType, - isTypeIncomplete, - skippedOverloadArg: true, - skippedBareTypeVarExpectedType, - condition, - }; - } - - const concreteParamType = makeTopLevelTypeVarsConcrete(argParam.paramType); - if (isFunction(concreteParamType) || isOverloadedFunction(concreteParamType)) { - if (isInstantiableClass(argType)) { - const constructor = createFunctionFromConstructor(evaluatorInterface, argType); - if (constructor) { - return { - isCompatible, - argType, - isTypeIncomplete, - skippedOverloadArg: true, - skippedBareTypeVarExpectedType, - condition, - }; - } - } - - if (isClassInstance(argType)) { - const callMember = lookUpObjectMember(argType, '__call__', MemberAccessFlags.SkipInstanceMembers); - if (callMember) { - const memberType = getTypeOfMember(callMember); - if (isOverloadedFunction(memberType)) { - return { - isCompatible, - argType, - isTypeIncomplete, - skippedOverloadArg: true, - skippedBareTypeVarExpectedType, - condition, - }; - } - } - } - } - } - let assignTypeFlags = AssignTypeFlags.Default; if (argParam.isinstanceParam) { @@ -17617,7 +17574,7 @@ export function createTypeEvaluator( argParam, new TypeVarContext(), { type: newMethodType }, - { skipUnknownArgCheck: true, skipOverloadArg: true } + { skipUnknownArgCheck: true } ); paramMap.delete(arg.name.d.value); } else { @@ -24043,12 +24000,14 @@ export function createTypeEvaluator( return false; } - if (destTypeVarContext) { - destTypeVarContext.addSolutionSets(destTypeVarSignatures); - } + if (filteredOverloads.length === 1 || (flags & AssignTypeFlags.ArgAssignmentFirstPass) === 0) { + if (destTypeVarContext) { + destTypeVarContext.addSolutionSets(destTypeVarSignatures); + } - if (srcTypeVarContext) { - srcTypeVarContext.addSolutionSets(srcTypeVarSignatures); + if (srcTypeVarContext) { + srcTypeVarContext.addSolutionSets(srcTypeVarSignatures); + } } return true; diff --git a/packages/pyright-internal/src/analyzer/typeUtils.ts b/packages/pyright-internal/src/analyzer/typeUtils.ts index 0d14fd85c7d7..ccc642df65ef 100644 --- a/packages/pyright-internal/src/analyzer/typeUtils.ts +++ b/packages/pyright-internal/src/analyzer/typeUtils.ts @@ -4538,11 +4538,7 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer { return false; } - if (this._options.useDefaultForUnsolved || this._options.useUnknownForUnsolved) { - return true; - } - - return this._typeVarContext.hasSolveForScope(typeVar.priv.scopeId); + return true; } private _shouldReplaceUnsolvedTypeVar(typeVar: TypeVarType): boolean { diff --git a/packages/pyright-internal/src/tests/samples/solver37.py b/packages/pyright-internal/src/tests/samples/solver37.py new file mode 100644 index 000000000000..70508f6b6afd --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/solver37.py @@ -0,0 +1,26 @@ +# This sample tests a complex TypeVar unification scenario. + +from typing import Callable, Generic, TypeVar + +A = TypeVar("A") +B = TypeVar("B") + + +class Gen(Generic[A]): + ... + + +def func1(x: A) -> A: + ... + + +def func2(x: Gen[A], y: A) -> Gen[Gen[A]]: + ... + + +def func3(x: Gen[Gen[A]]) -> Gen[A]: + return func4(x, func1, func2) + + +def func4(x: Gen[A], id_: Callable[[B], B], step: Callable[[A, B], Gen[A]]) -> A: + ... diff --git a/packages/pyright-internal/src/tests/samples/solver38.py b/packages/pyright-internal/src/tests/samples/solver38.py new file mode 100644 index 000000000000..92a515af7795 --- /dev/null +++ b/packages/pyright-internal/src/tests/samples/solver38.py @@ -0,0 +1,42 @@ +# This sample tests a complex TypeVar unification scenario. + +from typing import Protocol, TypeVar + +A = TypeVar("A", contravariant=True) +B = TypeVar("B", covariant=True) +T = TypeVar("T") +U = TypeVar("U") +V = TypeVar("V") + + +class Getter(Protocol[A, B]): + def __call__(self, x: A, /) -> B: + ... + + +class PolymorphicListItemGetter(Protocol): + def __call__(self, l: list[T], /) -> T: + ... + + +def compose(get1: Getter[T, U], get2: Getter[U, V]) -> Getter[T, V]: + ... + + +class HasMethod(Protocol): + @property + def method(self) -> int: + ... + + +def get_value(x: HasMethod) -> int: + ... + + +def upcast(x: PolymorphicListItemGetter) -> Getter[list[HasMethod], HasMethod]: + return x + + +def test(poly_getter: PolymorphicListItemGetter): + compose(poly_getter, get_value) + compose(upcast(poly_getter), get_value) diff --git a/packages/pyright-internal/src/tests/samples/solverHigherOrder3.py b/packages/pyright-internal/src/tests/samples/solverHigherOrder3.py index 71920bb0debe..27647565a21f 100644 --- a/packages/pyright-internal/src/tests/samples/solverHigherOrder3.py +++ b/packages/pyright-internal/src/tests/samples/solverHigherOrder3.py @@ -18,7 +18,7 @@ def func2(x: T, y: T) -> T: reveal_type( - func2(func1, func2), expected_text="(x: U@func1, y: T@func1) -> (U@func1 | T@func1)" + func2(func1, func2), expected_text="(x: T(1)@func2, y: T(1)@func2) -> T(1)@func2" ) diff --git a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts index c73e15ec39d7..695433bf27be 100644 --- a/packages/pyright-internal/src/tests/typeEvaluator2.test.ts +++ b/packages/pyright-internal/src/tests/typeEvaluator2.test.ts @@ -783,6 +783,18 @@ test('Solver36', () => { TestUtils.validateResults(analysisResults, 1); }); +test('Solver37', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['solver37.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + +test('Solver38', () => { + const analysisResults = TestUtils.typeAnalyzeSampleFiles(['solver38.py']); + + TestUtils.validateResults(analysisResults, 0); +}); + test('SolverScoring1', () => { const analysisResults = TestUtils.typeAnalyzeSampleFiles(['solverScoring1.py']);