Skip to content

Commit

Permalink
Changed inference logic for exception groups to more closely match th…
Browse files Browse the repository at this point in the history
…e runtime. If a non-base exception is targeted, the inferred type is now `ExceptionGroup` rather than `BaseExceptionGroup`. This addresses #9466.
  • Loading branch information
erictraut committed Nov 14, 2024
1 parent db2c9b0 commit f820bc4
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 4 deletions.
12 changes: 10 additions & 2 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19446,6 +19446,7 @@ export function createTypeEvaluator(

const exceptionTypeResult = getTypeOfExpression(node.d.typeExpr!);
const exceptionTypes = exceptionTypeResult.type;
let includesBaseException = false;

function getExceptionType(exceptionType: Type, errorNode: ExpressionNode) {
exceptionType = makeTopLevelTypeVarsConcrete(exceptionType);
Expand All @@ -19455,6 +19456,9 @@ export function createTypeEvaluator(
}

if (isInstantiableClass(exceptionType)) {
if (ClassType.isBuiltIn(exceptionType, 'BaseException')) {
includesBaseException = true;
}
return ClassType.cloneAsInstance(exceptionType);
}

Expand Down Expand Up @@ -19492,9 +19496,13 @@ export function createTypeEvaluator(
return getExceptionType(subType, node.d.typeExpr!);
});

// If this is an except group, wrap the exception type in an BaseExceptionGroup.
// If this is an except group, wrap the exception type in an ExceptionGroup
// or BaseExceptionGroup depending on whether the target exception is
// a BaseException.
if (node.d.isExceptGroup) {
targetType = getBuiltInObject(node, 'BaseExceptionGroup', [targetType]);
targetType = getBuiltInObject(node, includesBaseException ? 'BaseExceptionGroup' : 'ExceptionGroup', [
targetType,
]);
}

if (node.d.name) {
Expand Down
18 changes: 17 additions & 1 deletion packages/pyright-internal/src/tests/samples/exceptionGroup1.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def func1():

# This should generate an error if using Python 3.10 or earlier.
except* ValueError as e:
reveal_type(e, expected_text="BaseExceptionGroup[ValueError]")
reveal_type(e, expected_text="ExceptionGroup[ValueError]")
pass

# This should generate an error if using Python 3.10 or earlier.
Expand Down Expand Up @@ -105,3 +105,19 @@ def inner():
# return is not allowed in an except* block.
return



def func8():

try:
pass

# This should generate an error if using Python 3.10 or earlier.
except* (ValueError, FloatingPointError) as e:
reveal_type(e, expected_text="ExceptionGroup[ValueError | FloatingPointError]")
pass

# This should generate an error if using Python 3.10 or earlier.
except* BaseException as e:
reveal_type(e, expected_text="BaseExceptionGroup[BaseException]")
pass
2 changes: 1 addition & 1 deletion packages/pyright-internal/src/tests/typeEvaluator7.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -971,7 +971,7 @@ test('exceptionGroup1', () => {

configOptions.defaultPythonVersion = pythonVersion3_10;
const analysisResults1 = TestUtils.typeAnalyzeSampleFiles(['exceptionGroup1.py'], configOptions);
TestUtils.validateResults(analysisResults1, 28);
TestUtils.validateResults(analysisResults1, 34);

configOptions.defaultPythonVersion = pythonVersion3_11;
const analysisResults2 = TestUtils.typeAnalyzeSampleFiles(['exceptionGroup1.py'], configOptions);
Expand Down

0 comments on commit f820bc4

Please sign in to comment.