Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
Type guards using discriminant properties of string literal types
  • Loading branch information
ahejlsberg committed Jun 10, 2016
commit 4a8f94a5535548a3b9cc8eaf4c02c0830fba088f
56 changes: 41 additions & 15 deletions src/compiler/binder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -577,20 +577,14 @@ namespace ts {
}
}

function isNarrowableReference(expr: Expression): boolean {
return expr.kind === SyntaxKind.Identifier ||
expr.kind === SyntaxKind.ThisKeyword ||
expr.kind === SyntaxKind.PropertyAccessExpression && isNarrowableReference((<PropertyAccessExpression>expr).expression);
}

function isNarrowingExpression(expr: Expression): boolean {
switch (expr.kind) {
case SyntaxKind.Identifier:
case SyntaxKind.ThisKeyword:
case SyntaxKind.PropertyAccessExpression:
return isNarrowableReference(expr);
case SyntaxKind.CallExpression:
return true;
return hasNarrowableArgument(<CallExpression>expr);
case SyntaxKind.ParenthesizedExpression:
return isNarrowingExpression((<ParenthesizedExpression>expr).expression);
case SyntaxKind.BinaryExpression:
Expand All @@ -601,6 +595,27 @@ namespace ts {
return false;
}

function isNarrowableReference(expr: Expression): boolean {
return expr.kind === SyntaxKind.Identifier ||
expr.kind === SyntaxKind.ThisKeyword ||
expr.kind === SyntaxKind.PropertyAccessExpression && isNarrowableReference((<PropertyAccessExpression>expr).expression);
}

function hasNarrowableArgument(expr: CallExpression) {
if (expr.arguments) {
for (const argument of expr.arguments) {
if (isNarrowableReference(argument)) {
return true;
}
}
}
if (expr.expression.kind === SyntaxKind.PropertyAccessExpression &&
isNarrowableReference((<PropertyAccessExpression>expr.expression).expression)) {
return true;
}
return false;
}

function isNarrowingBinaryExpression(expr: BinaryExpression) {
switch (expr.operatorToken.kind) {
case SyntaxKind.EqualsToken:
Expand All @@ -609,21 +624,32 @@ namespace ts {
case SyntaxKind.ExclamationEqualsToken:
case SyntaxKind.EqualsEqualsEqualsToken:
case SyntaxKind.ExclamationEqualsEqualsToken:
if (isNarrowingExpression(expr.left) && (expr.right.kind === SyntaxKind.NullKeyword || expr.right.kind === SyntaxKind.Identifier)) {
return true;
}
if (expr.left.kind === SyntaxKind.TypeOfExpression && isNarrowingExpression((<TypeOfExpression>expr.left).expression) && expr.right.kind === SyntaxKind.StringLiteral) {
return true;
}
return false;
return (expr.right.kind === SyntaxKind.NullKeyword || expr.right.kind === SyntaxKind.Identifier && (<Identifier>expr.right).text === "undefined") && isNarrowableOperand(expr.left) ||
expr.left.kind === SyntaxKind.PropertyAccessExpression && isNarrowableReference((<PropertyAccessExpression>expr.left).expression) ||
expr.left.kind === SyntaxKind.TypeOfExpression && isNarrowableOperand((<TypeOfExpression>expr.left).expression) && expr.right.kind === SyntaxKind.StringLiteral;
case SyntaxKind.InstanceOfKeyword:
return isNarrowingExpression(expr.left);
return isNarrowableOperand(expr.left);
case SyntaxKind.CommaToken:
return isNarrowingExpression(expr.right);
}
return false;
}

function isNarrowableOperand(expr: Expression): boolean {
switch (expr.kind) {
case SyntaxKind.ParenthesizedExpression:
return isNarrowableOperand((<ParenthesizedExpression>expr).expression);
case SyntaxKind.BinaryExpression:
switch ((<BinaryExpression>expr).operatorToken.kind) {
case SyntaxKind.EqualsToken:
return isNarrowableOperand((<BinaryExpression>expr).left);
case SyntaxKind.CommaToken:
return isNarrowableOperand((<BinaryExpression>expr).right);
}
}
return isNarrowableReference(expr);
}

function createBranchLabel(): FlowLabel {
return {
flags: FlowFlags.BranchLabel,
Expand Down
47 changes: 39 additions & 8 deletions src/compiler/checker.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5160,7 +5160,6 @@ namespace ts {
if (hasProperty(stringLiteralTypes, text)) {
return stringLiteralTypes[text];
}

const type = stringLiteralTypes[text] = <StringLiteralType>createType(TypeFlags.StringLiteral);
type.text = text;
return type;
Expand Down Expand Up @@ -5625,6 +5624,10 @@ namespace ts {
return checkTypeComparableTo(source, target, /*errorNode*/ undefined);
}

function areTypesComparable(type1: Type, type2: Type): boolean {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

💯

Add a comment that this relationship is bidirectional and doesn't report errors.

return isTypeComparableTo(type1, type2) || isTypeComparableTo(type2, type1);
}

function checkTypeSubtypeOf(source: Type, target: Type, errorNode: Node, headMessage?: DiagnosticMessage, containingMessageChain?: DiagnosticMessageChain): boolean {
return checkTypeRelatedTo(source, target, subtypeRelation, errorNode, headMessage, containingMessageChain);
}
Expand Down Expand Up @@ -6805,8 +6808,10 @@ namespace ts {
return !!getPropertyOfType(type, "0");
}

function isStringLiteralType(type: Type) {
return type.flags & TypeFlags.StringLiteral;
function isStringLiteralUnionType(type: Type): boolean {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should really be isUnionWithStringLiterals

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It can be either a single string literal type or a union of string literal types. I'm not sure your suggestion better covers that.

return type.flags & TypeFlags.StringLiteral ? true :
type.flags & TypeFlags.Union ? forEach((<UnionType>type).types, isStringLiteralUnionType) :
false;
}

/**
Expand Down Expand Up @@ -7873,6 +7878,9 @@ namespace ts {
if (isNullOrUndefinedLiteral(expr.right)) {
return narrowTypeByNullCheck(type, expr, assumeTrue);
}
if (expr.left.kind === SyntaxKind.PropertyAccessExpression) {
return narrowTypeByDiscriminant(type, expr, assumeTrue);
}
if (expr.left.kind === SyntaxKind.TypeOfExpression && expr.right.kind === SyntaxKind.StringLiteral) {
return narrowTypeByTypeof(type, expr, assumeTrue);
}
Expand Down Expand Up @@ -7903,6 +7911,33 @@ namespace ts {
return getTypeWithFacts(type, facts);
}

function narrowTypeByDiscriminant(type: Type, expr: BinaryExpression, assumeTrue: boolean): Type {
// We have '==', '!=', '===', or '!==' operator with property access on left
if (!(type.flags & TypeFlags.Union) || !isMatchingReference(reference, (<PropertyAccessExpression>expr.left).expression)) {
return type;
}
const propName = (<PropertyAccessExpression>expr.left).name.text;
const propType = getTypeOfPropertyOfType(type, propName);
if (!propType || !isStringLiteralUnionType(propType)) {
return type;
}
const discriminantType = expr.right.kind === SyntaxKind.StringLiteral ? getStringLiteralTypeForText((<StringLiteral>expr.right).text) : checkExpression(expr.right);
if (!isStringLiteralUnionType(discriminantType)) {
return type;
}
if (expr.operatorToken.kind === SyntaxKind.ExclamationEqualsToken ||
expr.operatorToken.kind === SyntaxKind.ExclamationEqualsEqualsToken) {
assumeTrue = !assumeTrue;
}
if (assumeTrue) {
return getUnionType(filter((<UnionType>type).types, t => areTypesComparable(getTypeOfPropertyOfType(t, propName), discriminantType)));
}
if (discriminantType.flags & TypeFlags.StringLiteral) {
return getUnionType(filter((<UnionType>type).types, t => getTypeOfPropertyOfType(t, propName) !== discriminantType));
}
return type;
}

function narrowTypeByTypeof(type: Type, expr: BinaryExpression, assumeTrue: boolean): Type {
// We have '==', '!=', '====', or !==' operator with 'typeof xxx' on the left
// and string literal on the right
Expand Down Expand Up @@ -8892,10 +8927,6 @@ namespace ts {
return applyToContextualType(type, t => getIndexTypeOfStructuredType(t, kind));
}

function contextualTypeIsStringLiteralType(type: Type): boolean {
return !!(type.flags & TypeFlags.Union ? forEach((<UnionType>type).types, isStringLiteralType) : isStringLiteralType(type));
}

// Return true if the given contextual type is a tuple-like type
function contextualTypeIsTupleLikeType(type: Type): boolean {
return !!(type.flags & TypeFlags.Union ? forEach((<UnionType>type).types, isTupleLikeType) : isTupleLikeType(type));
Expand Down Expand Up @@ -12557,7 +12588,7 @@ namespace ts {

function checkStringLiteralExpression(node: StringLiteral): Type {
const contextualType = getContextualType(node);
if (contextualType && contextualTypeIsStringLiteralType(contextualType)) {
if (contextualType && isStringLiteralUnionType(contextualType)) {
return getStringLiteralTypeForText(node.text);
}

Expand Down