Skip to content

Commit

Permalink
Further simplification of the logic involved in type var solving.
Browse files Browse the repository at this point in the history
This addresses #8301 and #5855.
  • Loading branch information
erictraut committed Jul 30, 2024
1 parent b629011 commit cae2162
Show file tree
Hide file tree
Showing 7 changed files with 101 additions and 66 deletions.
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/analyzer/constraintSolver.ts
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ export function updateTypeVarType(
}
}

typeVarContext.setTypeVarType(destType, lowerBound, lowerBoundNoLiterals, upperBound); //, tupleTypes);
typeVarContext.setTypeVarType(destType, lowerBound, lowerBoundNoLiterals, upperBound);
}

function assignTypeToConstrainedTypeVar(
Expand Down
77 changes: 18 additions & 59 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -418,7 +418,6 @@ interface MatchedOverloadInfo {

interface ValidateArgTypeOptions {
skipUnknownArgCheck?: boolean;
skipOverloadArg?: boolean;
isArgFirstPass?: boolean;
conditionFilter?: TypeCondition[];
skipReportError?: boolean;
Expand Down Expand Up @@ -11453,7 +11452,6 @@ export function createTypeEvaluator(
{ type, isIncomplete: matchResults.isTypeIncomplete },
{
skipUnknownArgCheck,
skipOverloadArg: i === 0,
isArgFirstPass: passCount > 1 && i === 0,
conditionFilter: typeCondition,
skipReportError: true,
Expand Down Expand Up @@ -11955,7 +11953,16 @@ export function createTypeEvaluator(
// Assign the argument type back to the expected type to assign
// values to any unification variables.
const typeVarContextClone = typeVarContext.clone();
if (assignType(expectedType, argType, /* diag */ undefined, typeVarContextClone)) {
if (
assignType(
expectedType,
argType,
/* diag */ undefined,
typeVarContextClone,
/* srcTypeVarContext */ undefined,
options?.isArgFirstPass ? AssignTypeFlags.ArgAssignmentFirstPass : AssignTypeFlags.Default
)
) {
typeVarContext.copyFromClone(typeVarContextClone);
} else {
isCompatible = false;
Expand Down Expand Up @@ -12029,56 +12036,6 @@ export function createTypeEvaluator(
}
}

// If we are asked to skip overload arguments, determine whether the argument
// is an explicit overload type, an overloaded class constructor, or a
// an overloaded callback protocol.
if (options.skipOverloadArg) {
if (isOverloadedFunction(argType)) {
return {
isCompatible,
argType,
isTypeIncomplete,
skippedOverloadArg: true,
skippedBareTypeVarExpectedType,
condition,
};
}

const concreteParamType = makeTopLevelTypeVarsConcrete(argParam.paramType);
if (isFunction(concreteParamType) || isOverloadedFunction(concreteParamType)) {
if (isInstantiableClass(argType)) {
const constructor = createFunctionFromConstructor(evaluatorInterface, argType);
if (constructor) {
return {
isCompatible,
argType,
isTypeIncomplete,
skippedOverloadArg: true,
skippedBareTypeVarExpectedType,
condition,
};
}
}

if (isClassInstance(argType)) {
const callMember = lookUpObjectMember(argType, '__call__', MemberAccessFlags.SkipInstanceMembers);
if (callMember) {
const memberType = getTypeOfMember(callMember);
if (isOverloadedFunction(memberType)) {
return {
isCompatible,
argType,
isTypeIncomplete,
skippedOverloadArg: true,
skippedBareTypeVarExpectedType,
condition,
};
}
}
}
}
}

let assignTypeFlags = AssignTypeFlags.Default;

if (argParam.isinstanceParam) {
Expand Down Expand Up @@ -17617,7 +17574,7 @@ export function createTypeEvaluator(
argParam,
new TypeVarContext(),
{ type: newMethodType },
{ skipUnknownArgCheck: true, skipOverloadArg: true }
{ skipUnknownArgCheck: true }
);
paramMap.delete(arg.name.d.value);
} else {
Expand Down Expand Up @@ -24043,12 +24000,14 @@ export function createTypeEvaluator(
return false;
}

if (destTypeVarContext) {
destTypeVarContext.addSolutionSets(destTypeVarSignatures);
}
if (filteredOverloads.length === 1 || (flags & AssignTypeFlags.ArgAssignmentFirstPass) === 0) {
if (destTypeVarContext) {
destTypeVarContext.addSolutionSets(destTypeVarSignatures);
}

if (srcTypeVarContext) {
srcTypeVarContext.addSolutionSets(srcTypeVarSignatures);
if (srcTypeVarContext) {
srcTypeVarContext.addSolutionSets(srcTypeVarSignatures);
}
}

return true;
Expand Down
6 changes: 1 addition & 5 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4538,11 +4538,7 @@ class ApplySolvedTypeVarsTransformer extends TypeVarTransformer {
return false;
}

if (this._options.useDefaultForUnsolved || this._options.useUnknownForUnsolved) {
return true;
}

return this._typeVarContext.hasSolveForScope(typeVar.priv.scopeId);
return true;
}

private _shouldReplaceUnsolvedTypeVar(typeVar: TypeVarType): boolean {
Expand Down
26 changes: 26 additions & 0 deletions packages/pyright-internal/src/tests/samples/solver37.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# This sample tests a complex TypeVar unification scenario.

from typing import Callable, Generic, TypeVar

A = TypeVar("A")
B = TypeVar("B")


class Gen(Generic[A]):
...


def func1(x: A) -> A:
...


def func2(x: Gen[A], y: A) -> Gen[Gen[A]]:
...


def func3(x: Gen[Gen[A]]) -> Gen[A]:
return func4(x, func1, func2)


def func4(x: Gen[A], id_: Callable[[B], B], step: Callable[[A, B], Gen[A]]) -> A:
...
42 changes: 42 additions & 0 deletions packages/pyright-internal/src/tests/samples/solver38.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# This sample tests a complex TypeVar unification scenario.

from typing import Protocol, TypeVar

A = TypeVar("A", contravariant=True)
B = TypeVar("B", covariant=True)
T = TypeVar("T")
U = TypeVar("U")
V = TypeVar("V")


class Getter(Protocol[A, B]):
def __call__(self, x: A, /) -> B:
...


class PolymorphicListItemGetter(Protocol):
def __call__(self, l: list[T], /) -> T:
...


def compose(get1: Getter[T, U], get2: Getter[U, V]) -> Getter[T, V]:
...


class HasMethod(Protocol):
@property
def method(self) -> int:
...


def get_value(x: HasMethod) -> int:
...


def upcast(x: PolymorphicListItemGetter) -> Getter[list[HasMethod], HasMethod]:
return x


def test(poly_getter: PolymorphicListItemGetter):
compose(poly_getter, get_value)
compose(upcast(poly_getter), get_value)
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def func2(x: T, y: T) -> T:


reveal_type(
func2(func1, func2), expected_text="(x: U@func1, y: T@func1) -> (U@func1 | T@func1)"
func2(func1, func2), expected_text="(x: T(1)@func2, y: T(1)@func2) -> T(1)@func2"
)


Expand Down
12 changes: 12 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator2.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -783,6 +783,18 @@ test('Solver36', () => {
TestUtils.validateResults(analysisResults, 1);
});

test('Solver37', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['solver37.py']);

TestUtils.validateResults(analysisResults, 0);
});

test('Solver38', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['solver38.py']);

TestUtils.validateResults(analysisResults, 0);
});

test('SolverScoring1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['solverScoring1.py']);

Expand Down

0 comments on commit cae2162

Please sign in to comment.