From 674f699052ae0444a95b5fb817f039cc1854b92d Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 25 Apr 2023 21:47:02 -0700 Subject: [PATCH 1/3] feat: more flexible "in" operator and filter expressions --- .../validator/expression-validator.ts | 17 +- .../access-policy/expression-writer.ts | 41 ++++- .../typescript-expression-transformer.ts | 55 +++++- packages/schema/src/res/stdlib.zmodel | 5 +- .../tests/generator/expression-writer.test.ts | 162 +++++++++++++++++- .../validation/attribute-validation.test.ts | 11 -- 6 files changed, 248 insertions(+), 43 deletions(-) diff --git a/packages/schema/src/language-server/validator/expression-validator.ts b/packages/schema/src/language-server/validator/expression-validator.ts index 9a17414e8..18c826158 100644 --- a/packages/schema/src/language-server/validator/expression-validator.ts +++ b/packages/schema/src/language-server/validator/expression-validator.ts @@ -1,6 +1,6 @@ -import { BinaryExpr, Expression, isArrayExpr, isBinaryExpr, isEnum, isLiteralExpr } from '@zenstackhq/language/ast'; +import { BinaryExpr, Expression, isBinaryExpr, isEnum } from '@zenstackhq/language/ast'; import { ValidationAcceptor } from 'langium'; -import { getDataModelFieldReference, isAuthInvocation, isEnumFieldReference } from '../../utils/ast-utils'; +import { isAuthInvocation } from '../../utils/ast-utils'; import { AstValidator } from '../types'; /** @@ -37,21 +37,12 @@ export default class ExpressionValidator implements AstValidator { private validateBinaryExpr(expr: BinaryExpr, accept: ValidationAcceptor) { switch (expr.operator) { case 'in': { - if (!getDataModelFieldReference(expr.left)) { - accept('error', 'left operand of "in" must be a field reference', { node: expr.left }); - } - if (typeof expr.left.$resolvedType?.decl !== 'string' && !isEnum(expr.left.$resolvedType?.decl)) { accept('error', 'left operand of "in" must be of scalar type', { node: expr.left }); } - if ( - !( - isArrayExpr(expr.right) && - expr.right.items.every((item) => isLiteralExpr(item) || isEnumFieldReference(item)) - ) - ) { - accept('error', 'right operand of "in" must be an array of literals or enum values', { + if (!expr.right.$resolvedType?.array) { + accept('error', 'right operand of "in" must be an array', { node: expr.right, }); } diff --git a/packages/schema/src/plugins/access-policy/expression-writer.ts b/packages/schema/src/plugins/access-policy/expression-writer.ts index 03fe11fb1..238eed6c1 100644 --- a/packages/schema/src/plugins/access-policy/expression-writer.ts +++ b/packages/schema/src/plugins/access-policy/expression-writer.ts @@ -153,14 +153,35 @@ export class ExpressionWriter { } private writeIn(expr: BinaryExpr) { + const leftIsFieldAccess = this.isFieldAccess(expr.left); + const rightIsFieldAccess = this.isFieldAccess(expr.right); + this.block(() => { - this.writeFieldCondition( - expr.left, - () => { - this.plain(expr.right); - }, - 'in' - ); + if (!leftIsFieldAccess && !rightIsFieldAccess) { + // 'in' without referencing fields + this.guard(() => this.plain(expr)); + } else if (leftIsFieldAccess && !rightIsFieldAccess) { + // 'in' with left referencing a field, right is an array literal + this.writeFieldCondition( + expr.left, + () => { + this.plain(expr.right); + }, + 'in' + ); + } else if (!leftIsFieldAccess && rightIsFieldAccess) { + // 'in' with right referencing an array field, left is a literal + // transform it into a 'has' filter + this.writeFieldCondition( + expr.right, + () => { + this.plain(expr.left); + }, + 'has' + ); + } else { + throw new PluginError('"in" operator cannot be used with field references on both sides'); + } }); } @@ -520,6 +541,12 @@ export class ExpressionWriter { } if (FILTER_OPERATOR_FUNCTIONS.includes(funcDecl.name)) { + if (!expr.args.some((arg) => this.isFieldAccess(arg.value))) { + // filter functions without referencing fields + this.block(() => this.guard(() => this.plain(expr))); + return; + } + let valueArg = expr.args[1]?.value; // isEmpty function is zero arity, it's mapped to a boolean literal diff --git a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts index cb6dfba5e..af71c798f 100644 --- a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts +++ b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts @@ -12,7 +12,8 @@ import { ThisExpr, UnaryExpr, } from '@zenstackhq/language/ast'; -import { PluginError } from '@zenstackhq/sdk'; +import { getLiteral, PluginError } from '@zenstackhq/sdk'; +import { FILTER_OPERATOR_FUNCTIONS } from '../../language-server/constants'; import { isAuthInvocation } from '../../utils/ast-utils'; import { isFutureExpr } from './utils'; @@ -91,8 +92,54 @@ export default class TypeScriptExpressionTransformer { } private invocation(expr: InvocationExpr) { + if (!expr.function.ref) { + throw new PluginError(`Unresolved InvocationExpr`); + } + if (isAuthInvocation(expr)) { return 'user'; + } else if (FILTER_OPERATOR_FUNCTIONS.includes(expr.function.ref.name)) { + // arguments are already type-checked + + const arg0 = this.transform(expr.args[0].value); + let result: string; + switch (expr.function.ref.name) { + case 'contains': { + const caseInsensitive = getLiteral(expr.args[2]?.value) === true; + if (caseInsensitive) { + result = `${arg0}?.toLowerCase().includes(${this.transform( + expr.args[1].value + )}?.toLowerCase())`; + } else { + result = `${arg0}?.includes(${this.transform(expr.args[1].value)})`; + } + break; + } + case 'search': + throw new PluginError('"search" function must be used against a field'); + case 'startsWith': + result = `${arg0}?.startsWith(${this.transform(expr.args[1].value)})`; + break; + case 'endsWith': + result = `${arg0}?.endsWith(${this.transform(expr.args[1].value)})`; + break; + case 'has': + result = `${arg0}?.includes(${this.transform(expr.args[1].value)})`; + break; + case 'hasEvery': + result = `${this.transform(expr.args[1].value)}?.every((item) => ${arg0}?.includes(item))`; + break; + case 'hasSome': + result = `${this.transform(expr.args[1].value)}?.some((item) => ${arg0}?.includes(item))`; + break; + case 'isEmpty': + result = `${arg0}?.length === 0`; + break; + default: + throw new PluginError(`Function invocation is not supported: ${expr.function.ref?.name}`); + } + + return `(${result} ?? false)`; } else { throw new PluginError(`Function invocation is not supported: ${expr.function.ref?.name}`); } @@ -138,6 +185,10 @@ export default class TypeScriptExpressionTransformer { } private binary(expr: BinaryExpr): string { - return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right)})`; + if (expr.operator === 'in') { + return `(${this.transform(expr.right)}?.includes(${this.transform(expr.left)}) ?? false)`; + } else { + return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right)})`; + } } } diff --git a/packages/schema/src/res/stdlib.zmodel b/packages/schema/src/res/stdlib.zmodel index f76ef7251..8787f6e57 100644 --- a/packages/schema/src/res/stdlib.zmodel +++ b/packages/schema/src/res/stdlib.zmodel @@ -99,9 +99,10 @@ function future(): Any { } /* - * If the field value contains the search string + * If the field value contains the search string. By default, the search is case-sensitive, + * but you can override the behavior with the "caseInSensitive" argument. */ -function contains(field: String, search: String, caseSensitive: Boolean?): Boolean { +function contains(field: String, search: String, caseInSensitive: Boolean?): Boolean { } /* diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index e3296a720..35584f1fa 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -943,7 +943,7 @@ describe('Expression Writer Tests', () => { ); }); - it('filter operators', async () => { + it('filter operators field access', async () => { await check( ` enum Role { @@ -1134,6 +1134,153 @@ describe('Expression Writer Tests', () => { }); }); +it('filter operators non-field access', async () => { + const userInit = `{ id: 'user1', email: 'test@zenstack.dev', roles: [Role.ADMIN] }`; + const prelude = ` + enum Role { + USER + ADMIN + } + + model User { + id String @id + email String + roles Role[] + } + `; + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', ADMIN in auth().roles) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:((user?.roles??null)?.includes(Role.ADMIN)??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + roles Role[] + @@allow('all', ADMIN in roles) + } + `, + (model) => model.attributes[0].args[1].value, + `{roles:{has:Role.ADMIN}}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', contains(auth().email, 'test')) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:((user?.email??null)?.includes('test')??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', contains(auth().email, 'test', true)) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:((user?.email??null)?.toLowerCase().includes('test'?.toLowerCase())??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', startsWith(auth().email, 'test')) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:((user?.email??null)?.startsWith('test')??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', endsWith(auth().email, 'test')) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:((user?.email??null)?.endsWith('test')??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', has(auth().roles, ADMIN)) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:((user?.roles??null)?.includes(Role.ADMIN)??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', hasEvery(auth().roles, [ADMIN, USER])) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:([Role.ADMIN,Role.USER]?.every((item)=>(user?.roles??null)?.includes(item))??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', hasSome(auth().roles, [USER, ADMIN])) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:([Role.USER,Role.ADMIN]?.some((item)=>(user?.roles??null)?.includes(item))??false)}`, + userInit + ); + + await check( + ` + ${prelude} + model Test { + id String @id + @@allow('all', isEmpty(auth().roles)) + } + `, + (model) => model.attributes[0].args[1].value, + `{zenstack_guard:((user?.roles??null)?.length===0??false)}`, + userInit + ); +}); + async function check(schema: string, getExpr: (model: DataModel) => Expression, expected: string, userInit?: string) { if (!schema.includes('datasource ')) { schema = @@ -1155,12 +1302,6 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, overwrite: true, }); - // inject user variable - sf.addVariableStatement({ - declarationKind: VariableDeclarationKind.Const, - declarations: [{ name: 'user', initializer: userInit ?? '{ id: "user1" }' }], - }); - // inject enums model.declarations .filter((d) => isEnum(d)) @@ -1180,6 +1321,12 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, }); }); + // inject user variable + sf.addVariableStatement({ + declarationKind: VariableDeclarationKind.Const, + declarations: [{ name: 'user', initializer: userInit ?? '{ id: "user1" }' }], + }); + sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, declarations: [ @@ -1197,7 +1344,6 @@ async function check(schema: string, getExpr: (model: DataModel) => Expression, for (const d of project.getPreEmitDiagnostics()) { console.warn(`${d.getLineNumber()}: ${d.getMessageText()}`); } - console.log(`Generated source: ${sourcePath}`); throw new Error('Compilation errors occurred'); } diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index d541d43a9..97ee001c6 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -461,17 +461,6 @@ describe('Attribute tests', () => { `) ).toContain('argument is not assignable to parameter'); - expect( - await loadModelWithError(` - ${prelude} - model M { - id String @id - i Int[] - @@allow('all', 1 in i) - } - `) - ).toContain('left operand of "in" must be a field reference'); - expect( await loadModelWithError(` ${prelude} From dd63e525e3661db60651ca7f1d125684d3594761 Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 25 Apr 2023 22:01:14 -0700 Subject: [PATCH 2/3] fix tests --- .../schema/tests/schema/validation/attribute-validation.test.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/schema/tests/schema/validation/attribute-validation.test.ts b/packages/schema/tests/schema/validation/attribute-validation.test.ts index 97ee001c6..faf88eb9f 100644 --- a/packages/schema/tests/schema/validation/attribute-validation.test.ts +++ b/packages/schema/tests/schema/validation/attribute-validation.test.ts @@ -470,7 +470,7 @@ describe('Attribute tests', () => { @@allow('all', i in 1) } `) - ).toContain('right operand of "in" must be an array of literals or enum values'); + ).toContain('right operand of "in" must be an array'); expect( await loadModelWithError(` From a32baa7afc3b16f1006f47af505aa403f3becfcf Mon Sep 17 00:00:00 2001 From: ymc9 <104139426+ymc9@users.noreply.github.com> Date: Tue, 25 Apr 2023 23:36:52 -0700 Subject: [PATCH 3/3] more tests & optimization --- .../access-policy/policy-guard-generator.ts | 3 +- .../typescript-expression-transformer.ts | 68 ++++++++------ .../schema/src/plugins/model-meta/index.ts | 1 + .../tests/generator/expression-writer.test.ts | 18 ++-- .../e2e/filter-function-coverage.test.ts | 88 ++++++++++++++++++- 5 files changed, 137 insertions(+), 41 deletions(-) diff --git a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts index 7d04d1691..62079cef0 100644 --- a/packages/schema/src/plugins/access-policy/policy-guard-generator.ts +++ b/packages/schema/src/plugins/access-policy/policy-guard-generator.ts @@ -52,6 +52,7 @@ export default class PolicyGenerator { const project = createProject(); const sf = project.createSourceFile(path.join(output, 'policy.ts'), undefined, { overwrite: true }); + sf.addStatements('/* eslint-disable */'); sf.addImportDeclaration({ namedImports: [{ name: 'type QueryContext' }, { name: 'hasAllFields' }], @@ -361,7 +362,7 @@ export default class PolicyGenerator { func.addStatements( `const user = hasAllFields(context.user, [${userIdFields .map((f) => "'" + f.name + "'") - .join(', ')}]) ? context.user : null;` + .join(', ')}]) ? context.user as any : null;` ); } diff --git a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts index af71c798f..98dde9004 100644 --- a/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts +++ b/packages/schema/src/plugins/access-policy/typescript-expression-transformer.ts @@ -29,17 +29,17 @@ export default class TypeScriptExpressionTransformer { constructor(private readonly isPostGuard = false) {} /** - * - * @param expr + * Transforms the given expression to a TypeScript expression. + * @param normalizeUndefined if undefined values should be normalized to null * @returns */ - transform(expr: Expression): string { + transform(expr: Expression, normalizeUndefined = true): string { switch (expr.$type) { case LiteralExpr: return this.literal(expr as LiteralExpr); case ArrayExpr: - return this.array(expr as ArrayExpr); + return this.array(expr as ArrayExpr, normalizeUndefined); case NullExpr: return this.null(); @@ -51,16 +51,16 @@ export default class TypeScriptExpressionTransformer { return this.reference(expr as ReferenceExpr); case InvocationExpr: - return this.invocation(expr as InvocationExpr); + return this.invocation(expr as InvocationExpr, normalizeUndefined); case MemberAccessExpr: - return this.memberAccess(expr as MemberAccessExpr); + return this.memberAccess(expr as MemberAccessExpr, normalizeUndefined); case UnaryExpr: - return this.unary(expr as UnaryExpr); + return this.unary(expr as UnaryExpr, normalizeUndefined); case BinaryExpr: - return this.binary(expr as BinaryExpr); + return this.binary(expr as BinaryExpr, normalizeUndefined); default: throw new PluginError(`Unsupported expression type: ${expr.$type}`); @@ -73,7 +73,7 @@ export default class TypeScriptExpressionTransformer { return 'id'; } - private memberAccess(expr: MemberAccessExpr) { + private memberAccess(expr: MemberAccessExpr, normalizeUndefined: boolean) { if (!expr.member.ref) { throw new PluginError(`Unresolved MemberAccessExpr`); } @@ -86,12 +86,16 @@ export default class TypeScriptExpressionTransformer { } return expr.member.ref.name; } else { - // normalize field access to null instead of undefined to avoid accidentally use undefined in filter - return `(${this.transform(expr.operand)}?.${expr.member.ref.name} ?? null)`; + if (normalizeUndefined) { + // normalize field access to null instead of undefined to avoid accidentally use undefined in filter + return `(${this.transform(expr.operand, normalizeUndefined)}?.${expr.member.ref.name} ?? null)`; + } else { + return `${this.transform(expr.operand, normalizeUndefined)}?.${expr.member.ref.name}`; + } } } - private invocation(expr: InvocationExpr) { + private invocation(expr: InvocationExpr, normalizeUndefined: boolean) { if (!expr.function.ref) { throw new PluginError(`Unresolved InvocationExpr`); } @@ -101,36 +105,43 @@ export default class TypeScriptExpressionTransformer { } else if (FILTER_OPERATOR_FUNCTIONS.includes(expr.function.ref.name)) { // arguments are already type-checked - const arg0 = this.transform(expr.args[0].value); + const arg0 = this.transform(expr.args[0].value, false); let result: string; switch (expr.function.ref.name) { case 'contains': { const caseInsensitive = getLiteral(expr.args[2]?.value) === true; if (caseInsensitive) { result = `${arg0}?.toLowerCase().includes(${this.transform( - expr.args[1].value + expr.args[1].value, + normalizeUndefined )}?.toLowerCase())`; } else { - result = `${arg0}?.includes(${this.transform(expr.args[1].value)})`; + result = `${arg0}?.includes(${this.transform(expr.args[1].value, normalizeUndefined)})`; } break; } case 'search': throw new PluginError('"search" function must be used against a field'); case 'startsWith': - result = `${arg0}?.startsWith(${this.transform(expr.args[1].value)})`; + result = `${arg0}?.startsWith(${this.transform(expr.args[1].value, normalizeUndefined)})`; break; case 'endsWith': - result = `${arg0}?.endsWith(${this.transform(expr.args[1].value)})`; + result = `${arg0}?.endsWith(${this.transform(expr.args[1].value, normalizeUndefined)})`; break; case 'has': - result = `${arg0}?.includes(${this.transform(expr.args[1].value)})`; + result = `${arg0}?.includes(${this.transform(expr.args[1].value, normalizeUndefined)})`; break; case 'hasEvery': - result = `${this.transform(expr.args[1].value)}?.every((item) => ${arg0}?.includes(item))`; + result = `${this.transform( + expr.args[1].value, + normalizeUndefined + )}?.every((item) => ${arg0}?.includes(item))`; break; case 'hasSome': - result = `${this.transform(expr.args[1].value)}?.some((item) => ${arg0}?.includes(item))`; + result = `${this.transform( + expr.args[1].value, + normalizeUndefined + )}?.some((item) => ${arg0}?.includes(item))`; break; case 'isEmpty': result = `${arg0}?.length === 0`; @@ -168,8 +179,8 @@ export default class TypeScriptExpressionTransformer { return 'null'; } - private array(expr: ArrayExpr) { - return `[${expr.items.map((item) => this.transform(item)).join(', ')}]`; + private array(expr: ArrayExpr, normalizeUndefined: boolean) { + return `[${expr.items.map((item) => this.transform(item, normalizeUndefined)).join(', ')}]`; } private literal(expr: LiteralExpr) { @@ -180,15 +191,18 @@ export default class TypeScriptExpressionTransformer { } } - private unary(expr: UnaryExpr): string { - return `(${expr.operator} ${this.transform(expr.operand)})`; + private unary(expr: UnaryExpr, normalizeUndefined: boolean): string { + return `(${expr.operator} ${this.transform(expr.operand, normalizeUndefined)})`; } - private binary(expr: BinaryExpr): string { + private binary(expr: BinaryExpr, normalizeUndefined: boolean): string { if (expr.operator === 'in') { - return `(${this.transform(expr.right)}?.includes(${this.transform(expr.left)}) ?? false)`; + return `(${this.transform(expr.right, false)}?.includes(${this.transform( + expr.left, + normalizeUndefined + )}) ?? false)`; } else { - return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right)})`; + return `(${this.transform(expr.left)} ${expr.operator} ${this.transform(expr.right, normalizeUndefined)})`; } } } diff --git a/packages/schema/src/plugins/model-meta/index.ts b/packages/schema/src/plugins/model-meta/index.ts index 271dcde64..61a539e78 100644 --- a/packages/schema/src/plugins/model-meta/index.ts +++ b/packages/schema/src/plugins/model-meta/index.ts @@ -43,6 +43,7 @@ export default async function run(model: Model, options: PluginOptions) { } const sf = project.createSourceFile(path.join(output, 'model-meta.ts'), undefined, { overwrite: true }); + sf.addStatements('/* eslint-disable */'); sf.addVariableStatement({ declarationKind: VariableDeclarationKind.Const, declarations: [{ name: 'metadata', initializer: (writer) => generateModelMetadata(dataModels, writer) }], diff --git a/packages/schema/tests/generator/expression-writer.test.ts b/packages/schema/tests/generator/expression-writer.test.ts index 35584f1fa..178abe78e 100644 --- a/packages/schema/tests/generator/expression-writer.test.ts +++ b/packages/schema/tests/generator/expression-writer.test.ts @@ -1158,7 +1158,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:((user?.roles??null)?.includes(Role.ADMIN)??false)}`, + `{zenstack_guard:(user?.roles?.includes(Role.ADMIN)??false)}`, userInit ); @@ -1185,7 +1185,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:((user?.email??null)?.includes('test')??false)}`, + `{zenstack_guard:(user?.email?.includes('test')??false)}`, userInit ); @@ -1198,7 +1198,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:((user?.email??null)?.toLowerCase().includes('test'?.toLowerCase())??false)}`, + `{zenstack_guard:(user?.email?.toLowerCase().includes('test'?.toLowerCase())??false)}`, userInit ); @@ -1211,7 +1211,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:((user?.email??null)?.startsWith('test')??false)}`, + `{zenstack_guard:(user?.email?.startsWith('test')??false)}`, userInit ); @@ -1224,7 +1224,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:((user?.email??null)?.endsWith('test')??false)}`, + `{zenstack_guard:(user?.email?.endsWith('test')??false)}`, userInit ); @@ -1237,7 +1237,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:((user?.roles??null)?.includes(Role.ADMIN)??false)}`, + `{zenstack_guard:(user?.roles?.includes(Role.ADMIN)??false)}`, userInit ); @@ -1250,7 +1250,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:([Role.ADMIN,Role.USER]?.every((item)=>(user?.roles??null)?.includes(item))??false)}`, + `{zenstack_guard:([Role.ADMIN,Role.USER]?.every((item)=>user?.roles?.includes(item))??false)}`, userInit ); @@ -1263,7 +1263,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:([Role.USER,Role.ADMIN]?.some((item)=>(user?.roles??null)?.includes(item))??false)}`, + `{zenstack_guard:([Role.USER,Role.ADMIN]?.some((item)=>user?.roles?.includes(item))??false)}`, userInit ); @@ -1276,7 +1276,7 @@ it('filter operators non-field access', async () => { } `, (model) => model.attributes[0].args[1].value, - `{zenstack_guard:((user?.roles??null)?.length===0??false)}`, + `{zenstack_guard:(user?.roles?.length===0??false)}`, userInit ); }); diff --git a/tests/integration/tests/e2e/filter-function-coverage.test.ts b/tests/integration/tests/e2e/filter-function-coverage.test.ts index daedbbef3..a7d6088c4 100644 --- a/tests/integration/tests/e2e/filter-function-coverage.test.ts +++ b/tests/integration/tests/e2e/filter-function-coverage.test.ts @@ -1,7 +1,7 @@ import { loadSchema } from '@zenstackhq/testtools'; describe('Filter Function Coverage Tests', () => { - it('contains case-sensitive', async () => { + it('contains case-sensitive field', async () => { const { withPresets } = await loadSchema( ` model Foo { @@ -16,7 +16,27 @@ describe('Filter Function Coverage Tests', () => { await expect(withPresets().foo.create({ data: { string: 'bac' } })).toResolveTruthy(); }); - it('startsWith', async () => { + it('contains case-sensitive non-field', async () => { + const { withPresets } = await loadSchema( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', contains(auth().name, 'a')) + } + ` + ); + + await expect(withPresets().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bcd' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toResolveTruthy(); + }); + + it('startsWith field', async () => { const { withPresets } = await loadSchema( ` model Foo { @@ -31,7 +51,27 @@ describe('Filter Function Coverage Tests', () => { await expect(withPresets().foo.create({ data: { string: 'abc' } })).toResolveTruthy(); }); - it('endsWith', async () => { + it('startsWith non-field', async () => { + const { withPresets } = await loadSchema( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', startsWith(auth().name, 'a')) + } + ` + ); + + await expect(withPresets().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'abc' }).foo.create({ data: {} })).toResolveTruthy(); + }); + + it('endsWith field', async () => { const { withPresets } = await loadSchema( ` model Foo { @@ -46,7 +86,27 @@ describe('Filter Function Coverage Tests', () => { await expect(withPresets().foo.create({ data: { string: 'bca' } })).toResolveTruthy(); }); - it('in', async () => { + it('endsWith non-field', async () => { + const { withPresets } = await loadSchema( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', endsWith(auth().name, 'a')) + } + ` + ); + + await expect(withPresets().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bac' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'bca' }).foo.create({ data: {} })).toResolveTruthy(); + }); + + it('in left field', async () => { const { withPresets } = await loadSchema( ` model Foo { @@ -60,4 +120,24 @@ describe('Filter Function Coverage Tests', () => { await expect(withPresets().foo.create({ data: { string: 'c' } })).toBeRejectedByPolicy(); await expect(withPresets().foo.create({ data: { string: 'b' } })).toResolveTruthy(); }); + + it('in non-field', async () => { + const { withPresets } = await loadSchema( + ` + model User { + id String @id + name String + } + + model Foo { + id String @id @default(cuid()) + @@allow('all', auth().name in ['abc', 'bcd']) + } + ` + ); + + await expect(withPresets().foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'abd' }).foo.create({ data: {} })).toBeRejectedByPolicy(); + await expect(withPresets({ id: 'user1', name: 'abc' }).foo.create({ data: {} })).toResolveTruthy(); + }); });