diff --git a/packages/cubejs-api-gateway/src/gateway.ts b/packages/cubejs-api-gateway/src/gateway.ts index 08c462c00d178..12bf993d5b9f9 100644 --- a/packages/cubejs-api-gateway/src/gateway.ts +++ b/packages/cubejs-api-gateway/src/gateway.ts @@ -1343,6 +1343,10 @@ class ApiGateway { currentQuery = this.parseMemberExpressionsInQuery(currentQuery); } + if ((currentQuery as any).maskedMembers) { + throw new UserError('maskedMembers cannot be provided in the query'); + } + return { normalizedQuery: (normalizeQuery(currentQuery, persistent, cacheMode)), hasExpressionsInQuery @@ -1372,6 +1376,8 @@ class ApiGateway { context ) : queryWithRlsFilters; + rewrittenQuery.maskedMembers = queryWithRlsFilters.maskedMembers; + // applyRowLevelSecurity may add new filters which may contain raw member expressions // if that's the case, we should run an extra pass of parsing here to make sure // nothing breaks down the road diff --git a/packages/cubejs-api-gateway/src/query.js b/packages/cubejs-api-gateway/src/query.js index af9f52c798b36..88d5132648848 100644 --- a/packages/cubejs-api-gateway/src/query.js +++ b/packages/cubejs-api-gateway/src/query.js @@ -195,7 +195,10 @@ const querySchema = Joi.object().keys({ responseFormat: Joi.valid('default', 'compact', 'columnar'), subqueryJoins: Joi.array().items(subqueryJoin), joinHints: Joi.array().items(joinHint), - maskedMembers: Joi.array().items(Joi.string()), + maskedMembers: Joi.array().items(Joi.object().keys({ + member: Joi.string().required(), + filter: Joi.object(), + })), }); const normalizeQueryOrder = order => { diff --git a/packages/cubejs-api-gateway/src/types/query.ts b/packages/cubejs-api-gateway/src/types/query.ts index 8224edb0f5266..26b0e26e263da 100644 --- a/packages/cubejs-api-gateway/src/types/query.ts +++ b/packages/cubejs-api-gateway/src/types/query.ts @@ -166,7 +166,7 @@ interface NormalizedQuery extends Query { filters?: NormalizedQueryFilter[]; rowLimit?: null | number; order?: { id: string; desc: boolean }[]; - maskedMembers?: string[]; + maskedMembers?: { member: string; filter?: any }[]; } export { diff --git a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js index 30528a028ca41..4107babf4709f 100644 --- a/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js +++ b/packages/cubejs-schema-compiler/src/adapter/BaseQuery.js @@ -253,7 +253,14 @@ export class BaseQuery { securityContext: {}, ...this.options.contextSymbols, }; - this.maskedMembers = new Set(this.options.maskedMembers || []); + this.maskedMembers = new Set(); + this.memberMaskFilters = {}; + for (const item of this.options.maskedMembers || []) { + this.maskedMembers.add(item.member); + if (item.filter) { + this.memberMaskFilters[item.member] = item.filter; + } + } this.compilerCache = this.compilers.compiler.compilerCache; this.queryCache = this.compilerCache.getQueryCache({ measures: this.options.measures, @@ -3299,13 +3306,18 @@ export class BaseQuery { this.safeEvaluateSymbolContext().currentMember = memberPath; try { - if (this.maskedMembers && this.maskedMembers.has(memberPath) && !memberExpressionType) { + if (this.maskedMembers && this.maskedMembers.has(memberPath) && !memberExpressionType && + !this.safeEvaluateSymbolContext().skipMasking) { // In ungrouped queries, only apply static masks to measures. // SQL masks (mask.sql) reference columns that don't apply per-row. const isMeasure = type === 'measure'; const isUngrouped = this.options.ungrouped; const hasSqlMask = symbol.mask && typeof symbol.mask === 'object' && symbol.mask.sql; if (!isMeasure || !isUngrouped || !hasSqlMask) { + const maskFilter = this.memberMaskFilters && this.memberMaskFilters[memberPath]; + if (maskFilter) { + return this.conditionalMemberMaskSql(cubeName, name, symbol, maskFilter); + } return this.memberMaskSql(cubeName, name, symbol); } } @@ -3486,6 +3498,37 @@ export class BaseQuery { return this.defaultMaskSql(symbol.type); } + conditionalMemberMaskSql(cubeName, name, symbol, maskFilter) { + const maskedSql = this.memberMaskSql(cubeName, name, symbol); + const result = this.evaluateSymbolSqlWithContext( + () => { + const filterSql = this.maskFilterToSql(maskFilter); + if (!filterSql) { + return maskedSql; + } + const originalSql = this.autoPrefixAndEvaluateSql(cubeName, symbol.sql); + return this.caseWhenStatement([{ sql: filterSql, label: originalSql }], maskedSql); + }, + { skipMasking: true, currentMember: null } + ); + return result; + } + + maskFilterToSql(filter) { + if (!filter) return null; + const filterItems = this.extractFiltersAsTree([filter]); + if (!filterItems.length) return null; + const initialized = filterItems.map(this.initFilter.bind(this)); + if (initialized.length === 1) { + return initialized[0].filterToWhere(); + } + const groupFilter = this.newGroupFilter({ + operator: 'and', + values: initialized, + }); + return groupFilter.filterToWhere(); + } + defaultMaskSql(memberType) { const envMasks = { string: getEnv('accessPolicyMaskString'), diff --git a/packages/cubejs-schema-compiler/test/unit/transpilers.test.ts b/packages/cubejs-schema-compiler/test/unit/transpilers.test.ts index c2ec0c71a8e39..5854317cc0ee2 100644 --- a/packages/cubejs-schema-compiler/test/unit/transpilers.test.ts +++ b/packages/cubejs-schema-compiler/test/unit/transpilers.test.ts @@ -410,7 +410,7 @@ describe('Transpilers', () => { { measures: ['Test.count'], dimensions: ['Test.secret'], - maskedMembers: ['Test.secret'], + maskedMembers: [{ member: 'Test.secret' }], } ); const sql = query.buildSqlAndParams(); diff --git a/packages/cubejs-schema-compiler/test/unit/yaml-schema.test.ts b/packages/cubejs-schema-compiler/test/unit/yaml-schema.test.ts index bef4e2b69b76d..c4501ef926463 100644 --- a/packages/cubejs-schema-compiler/test/unit/yaml-schema.test.ts +++ b/packages/cubejs-schema-compiler/test/unit/yaml-schema.test.ts @@ -1665,7 +1665,7 @@ cubes: { measures: ['orders.count'], dimensions: ['orders.status'], - maskedMembers: ['orders.status'], + maskedMembers: [{ member: 'orders.status' }], contextSymbols: { securityContext: { cubeCloud: { userAttributes: { hasStatusAccess: true } } } } @@ -1806,7 +1806,7 @@ views: const query = new PostgresQuery(compilers, { measures: ['users_secure_view.users_count'], dimensions: ['users_secure_view.users_city_sensitive_masked'], - maskedMembers: ['users_secure_view.users_city_sensitive_masked'], + maskedMembers: [{ member: 'users_secure_view.users_city_sensitive_masked' }], contextSymbols: { securityContext: { cubeCloud: { groups } } }, @@ -1845,4 +1845,214 @@ views: }); }); }); + + describe('Conditional masking with row-level filters (memberMaskFilters)', () => { + it('generates CASE WHEN with row filter for masked members that have conditional full access', async () => { + const compilers = prepareYamlCompiler(` +cubes: + - name: users + sql_table: public.users + dimensions: + - name: id + sql: id + type: number + primary_key: true + - name: city + sql: city + type: string + - name: data_region + sql: data_region + type: string + measures: + - name: count + type: count + `); + + await compilers.compiler.compile(); + + const query = new PostgresQuery(compilers, { + measures: ['users.count'], + dimensions: ['users.city'], + maskedMembers: [{ + member: 'users.city', + filter: { + member: 'users.data_region', + operator: 'equals', + values: ['RESEARCH', 'DEMO'], + } + }], + }); + const [sql] = query.buildSqlAndParams(); + expect(sql).toMatch(/CASE\s+WHEN/); + expect(sql).toMatch(/WHEN.*data_region.*THEN.*city.*ELSE.*NULL.*END/s); + }); + + it('generates CASE WHEN with AND row filter for multiple filter conditions', async () => { + const compilers = prepareYamlCompiler(` +cubes: + - name: users + sql_table: public.users + dimensions: + - name: id + sql: id + type: number + primary_key: true + - name: city + sql: city + type: string + - name: data_region + sql: data_region + type: string + - name: region_lock + sql: region_lock + type: number + measures: + - name: count + type: count + `); + + await compilers.compiler.compile(); + + const query = new PostgresQuery(compilers, { + measures: ['users.count'], + dimensions: ['users.city'], + maskedMembers: [{ + member: 'users.city', + filter: { + and: [ + { + member: 'users.data_region', + operator: 'equals', + values: ['RESEARCH'], + }, + { + member: 'users.region_lock', + operator: 'equals', + values: ['0'], + } + ] + } + }], + }); + const [sql] = query.buildSqlAndParams(); + expect(sql).toMatch(/CASE\s+WHEN/); + expect(sql).toMatch(/WHEN.*AND.*THEN.*city.*ELSE.*NULL.*END/s); + }); + + it('uses mask.sql as the ELSE branch when dimension has a custom mask', async () => { + const compilers = prepareYamlCompiler(` +cubes: + - name: users + sql_table: public.users + dimensions: + - name: id + sql: id + type: number + primary_key: true + - name: city + sql: city + type: string + mask: + sql: "'***MASKED***'" + - name: data_region + sql: data_region + type: string + measures: + - name: count + type: count + `); + + await compilers.compiler.compile(); + + const query = new PostgresQuery(compilers, { + measures: ['users.count'], + dimensions: ['users.city'], + maskedMembers: [{ + member: 'users.city', + filter: { + member: 'users.data_region', + operator: 'equals', + values: ['RESEARCH'], + } + }], + }); + const [sql] = query.buildSqlAndParams(); + expect(sql).toMatch(/CASE\s+WHEN/); + expect(sql).toMatch(/WHEN.*data_region.*THEN.*city.*ELSE.*MASKED.*END/s); + }); + + it('applies regular masking (no CASE WHEN) when no memberMaskFilters', async () => { + const compilers = prepareYamlCompiler(` +cubes: + - name: users + sql_table: public.users + dimensions: + - name: id + sql: id + type: number + primary_key: true + - name: city + sql: city + type: string + measures: + - name: count + type: count + `); + + await compilers.compiler.compile(); + + const query = new PostgresQuery(compilers, { + measures: ['users.count'], + dimensions: ['users.city'], + maskedMembers: [{ member: 'users.city' }], + }); + const [sql] = query.buildSqlAndParams(); + expect(sql).not.toMatch(/CASE\s+WHEN/); + expect(sql).toContain('NULL'); + }); + + it('does not recurse when filter member is also masked', async () => { + const compilers = prepareYamlCompiler(` +cubes: + - name: items + sql_table: public.items + dimensions: + - name: id + sql: id + type: number + primary_key: true + - name: product_id + sql: product_id + type: number + - name: price + sql: price + type: number + mask: -1 + measures: + - name: count + type: count + `); + + await compilers.compiler.compile(); + + const query = new PostgresQuery(compilers, { + measures: ['items.count'], + dimensions: ['items.product_id', 'items.price'], + maskedMembers: [ + { + member: 'items.product_id', + filter: { member: 'items.product_id', operator: 'lte', values: ['3'] } + }, + { + member: 'items.price', + filter: { member: 'items.product_id', operator: 'lte', values: ['3'] } + }, + ], + }); + const [sql] = query.buildSqlAndParams(); + expect(sql).toMatch(/CASE\s+WHEN/); + expect(sql).toMatch(/product_id/); + expect(sql).not.toMatch(/Maximum call stack/); + }); + }); }); diff --git a/packages/cubejs-server-core/src/core/CompilerApi.ts b/packages/cubejs-server-core/src/core/CompilerApi.ts index 0750a360b5cc9..afca38ddc8cce 100644 --- a/packages/cubejs-server-core/src/core/CompilerApi.ts +++ b/packages/cubejs-server-core/src/core/CompilerApi.ts @@ -548,6 +548,7 @@ export class CompilerApi { const viewFiltersPerCubePerRole: Record> = {}; const hasAllowAllForCube: Record = {}; const maskedMembersSet = new Set(); + const memberMaskFiltersMap: Record = {}; for (const cubeName of queryCubes) { const cube = cubeEvaluator.cubeFromPath(cubeName); @@ -680,26 +681,58 @@ export class CompilerApi { }); // Determine which members need masking: a member is masked if no covering - // policy grants it full access via memberLevel AND at least one covering - // policy defines memberMasking that includes the member. - // Masking follows the same pattern as row-level security: it is applied - // at both cube and view levels. When a cube is accessed through a view, - // both the cube's and the view's masking policies are evaluated. + // policy grants it unconditional full access via memberLevel AND at least + // one covering policy defines memberMasking that includes the member. + // + // When a policy grants full memberLevel access but also has row_level filters, + // the full access is conditional on the row filter. In that case, we track + // the row filters so that the generated SQL uses: + // CASE WHEN {rowFilter} THEN {originalValue} ELSE {maskedValue} END + // This ensures that rows outside the filter range see masked values. const cubeMembersInQuery = Array.from(queryMemberNames).filter( memberName => memberName.startsWith(`${cubeName}.`) ); for (const memberName of cubeMembersInQuery) { - const hasFullAccessInAnyPolicy = policiesWithMemberAccess.some(policy => { - if (!policy.memberLevel) return true; - return policy.memberLevel.includesMembers.includes(memberName) && + const hasUnconditionalFullAccess = policiesWithMemberAccess.some(policy => { + if (!policy.memberLevel) { + return !policy.rowLevel || policy.rowLevel.allowAll; + } + const inIncludes = policy.memberLevel.includesMembers.includes(memberName) && !policy.memberLevel.excludesMembers.includes(memberName); + return inIncludes && (!policy.rowLevel || policy.rowLevel.allowAll); }); - if (!hasFullAccessInAnyPolicy && policiesWithMemberAccess.length > 0) { - const isMaskedByAnyPolicy = policiesWithMemberAccess.some( - (policy) => policy.memberMasking && policy.memberMasking.includesMembers.includes(memberName) && !policy.memberMasking.excludesMembers.includes(memberName) + + if (!hasUnconditionalFullAccess) { + const hasMaskingPolicy = policiesWithMemberAccess.some( + (policy) => policy.memberMasking && + policy.memberMasking.includesMembers.includes(memberName) && + !policy.memberMasking.excludesMembers.includes(memberName) ); - if (isMaskedByAnyPolicy) { + + if (hasMaskingPolicy) { + const conditionalFullAccessPolicies = policiesWithMemberAccess.filter(policy => { + const hasFullMemberAccess = !policy.memberLevel || + (policy.memberLevel.includesMembers.includes(memberName) && + !policy.memberLevel.excludesMembers.includes(memberName)); + return hasFullMemberAccess && + policy.rowLevel && !policy.rowLevel.allowAll && + policy.rowLevel.filters?.length > 0; + }); + maskedMembersSet.add(memberName); + if (conditionalFullAccessPolicies.length > 0) { + const policyFilters = conditionalFullAccessPolicies.map(policy => { + const filters = (policy.rowLevel.filters || []).map( + (filter: any) => this.evaluateNestedFilter(filter, cube, context, cubeEvaluator) + ); + return filters.length === 1 ? filters[0] : { and: filters }; + }); + if (policyFilters.length > 0) { + memberMaskFiltersMap[memberName] = policyFilters.length === 1 + ? policyFilters[0] + : { or: policyFilters }; + } + } } } } @@ -746,7 +779,10 @@ export class CompilerApi { query.filters.push(rlsFilter); } if (maskedMembersSet.size > 0) { - query.maskedMembers = Array.from(maskedMembersSet); + query.maskedMembers = Array.from(maskedMembersSet).map(member => ({ + member, + filter: memberMaskFiltersMap[member], + })); } return { query, denied: false }; } diff --git a/packages/cubejs-testing/birdbox-fixtures/rbac/cube.js b/packages/cubejs-testing/birdbox-fixtures/rbac/cube.js index 6a4d8ad267b99..192d11c9cac7c 100644 --- a/packages/cubejs-testing/birdbox-fixtures/rbac/cube.js +++ b/packages/cubejs-testing/birdbox-fixtures/rbac/cube.js @@ -245,6 +245,40 @@ module.exports = { }, }; } + if (user === 'conditional_mask_user') { + if (password && password !== 'conditional_mask_password') { + throw new Error(`Password doesn't match for ${user}`); + } + return { + password, + superuser: false, + securityContext: { + auth: { + username: 'conditional_mask_user', + userAttributes: {}, + roles: ['conditional_mask_role'], + groups: [], + }, + }, + }; + } + if (user === 'conditional_mask_multi_user') { + if (password && password !== 'conditional_mask_multi_password') { + throw new Error(`Password doesn't match for ${user}`); + } + return { + password, + superuser: false, + securityContext: { + auth: { + username: 'conditional_mask_multi_user', + userAttributes: {}, + roles: ['conditional_mask_role', 'conditional_mask_role_extra'], + groups: [], + }, + }, + }; + } throw new Error(`User "${user}" doesn't exist`); } }; diff --git a/packages/cubejs-testing/birdbox-fixtures/rbac/model/cubes/conditional_masking_test.yaml b/packages/cubejs-testing/birdbox-fixtures/rbac/model/cubes/conditional_masking_test.yaml new file mode 100644 index 0000000000000..3d778e72b493d --- /dev/null +++ b/packages/cubejs-testing/birdbox-fixtures/rbac/model/cubes/conditional_masking_test.yaml @@ -0,0 +1,53 @@ +cubes: + - name: conditional_masking_test + sql_table: public.line_items + + dimensions: + - name: id + sql: id + type: number + primary_key: true + + - name: product_id + sql: product_id + type: number + + - name: price + sql: price + type: number + mask: -1 + + measures: + - name: count + type: count + + - name: total_price + sql: price + type: sum + + access_policy: + - role: "*" + member_level: + includes: [] + member_masking: + includes: "*" + + - role: "conditional_mask_role" + member_level: + includes: "*" + row_level: + filters: + - member: product_id + operator: lte + values: + - "3" + + - role: "conditional_mask_role_extra" + member_level: + includes: "*" + row_level: + filters: + - member: product_id + operator: equals + values: + - "5" diff --git a/packages/cubejs-testing/test/smoke-rbac.test.ts b/packages/cubejs-testing/test/smoke-rbac.test.ts index 7599c570cc099..fd810cf39d65f 100644 --- a/packages/cubejs-testing/test/smoke-rbac.test.ts +++ b/packages/cubejs-testing/test/smoke-rbac.test.ts @@ -491,6 +491,83 @@ describe('Cube RBAC Engine', () => { }); }); + /** + * Conditional masking: when a policy grants full member_level access WITH + * row_level filters, the masking is conditional on the row filter. + * Rows matching the filter see unmasked values; other rows see masked values. + * + * conditional_masking_test cube: + * - role "*": member_level includes=[], member_masking includes="*" + * - role "conditional_mask_role": member_level includes="*", row_level filter product_id <= 3 + * + * For conditional_mask_user (role: conditional_mask_role): + * - product_id dimension: rows with product_id <= 3 show real value, others show masked (-1 for price) + */ + describe('RBAC conditional masking with row-level filters via SQL API', () => { + let connection: PgClient; + + beforeAll(async () => { + connection = await createPostgresClient('conditional_mask_user', 'conditional_mask_password'); + }); + + afterAll(async () => { + await connection.end(); + }, JEST_AFTER_ALL_DEFAULT_TIMEOUT); + + test('conditional_masking_test returns CASE WHEN masked values based on row filter', async () => { + const res = await connection.query( + 'SELECT product_id, price FROM conditional_masking_test ORDER BY product_id LIMIT 10' + ); + expect(res.rows.length).toBeGreaterThan(0); + for (const row of res.rows) { + if (Number(row.product_id) <= 3) { + // Rows matching the row filter should have real (unmasked) price + expect(Number(row.price)).not.toBe(-1); + } else { + // Rows NOT matching the row filter should have masked price (-1) + expect(Number(row.price)).toBe(-1); + } + } + }); + }); + + /** + * Multiple conditional policies use OR across policies: + * - role "conditional_mask_role": product_id <= 3 + * - role "conditional_mask_role_extra": product_id = 5 + * + * For conditional_mask_multi_user (both roles): + * Unmasked when product_id <= 3 OR product_id = 5, masked otherwise. + */ + describe('RBAC conditional masking with multiple policies (OR across policies)', () => { + let connection: PgClient; + + beforeAll(async () => { + connection = await createPostgresClient('conditional_mask_multi_user', 'conditional_mask_multi_password'); + }); + + afterAll(async () => { + await connection.end(); + }, JEST_AFTER_ALL_DEFAULT_TIMEOUT); + + test('unmasked when any policy filter matches (OR across policies)', async () => { + const res = await connection.query( + 'SELECT product_id, price FROM conditional_masking_test ORDER BY product_id LIMIT 10' + ); + expect(res.rows.length).toBeGreaterThan(0); + for (const row of res.rows) { + const pid = Number(row.product_id); + if (pid <= 3 || pid === 5) { + // Matches either policy filter → unmasked + expect(Number(row.price)).not.toBe(-1); + } else { + // Matches neither policy filter → masked + expect(Number(row.price)).toBe(-1); + } + } + }); + }); + /** * View masking tests — masking follows the RLS pattern and is applied at * both cube and view levels. If a cube masks a member, it stays masked diff --git a/rust/cube/cubesqlplanner/cubesqlplanner/src/cube_bridge/base_query_options.rs b/rust/cube/cubesqlplanner/cubesqlplanner/src/cube_bridge/base_query_options.rs index 92b6759a52d4a..c164ec45cf5b6 100644 --- a/rust/cube/cubesqlplanner/cubesqlplanner/src/cube_bridge/base_query_options.rs +++ b/rust/cube/cubesqlplanner/cubesqlplanner/src/cube_bridge/base_query_options.rs @@ -14,6 +14,12 @@ use std::any::Any; use std::collections::HashMap; use std::rc::Rc; +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct MaskedMemberItem { + pub member: String, + pub filter: Option, +} + #[derive(Serialize, Deserialize, Debug, Clone)] pub struct TimeDimension { pub dimension: String, @@ -77,7 +83,7 @@ pub struct BaseQueryOptionsStatic { #[serde(rename = "convertTzForRawTimeDimension")] pub convert_tz_for_raw_time_dimension: Option, #[serde(rename = "maskedMembers")] - pub masked_members: Option>, + pub masked_members: Option>, #[serde(rename = "memberToAlias", default)] pub member_to_alias: Option>, } diff --git a/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/query_tools.rs b/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/query_tools.rs index 377a3b78a068e..79d69376998af 100644 --- a/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/query_tools.rs +++ b/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/query_tools.rs @@ -1,5 +1,6 @@ use super::sql_evaluator::Compiler; use super::ParamsAllocator; +use crate::cube_bridge::base_query_options::{FilterItem, MaskedMemberItem}; use crate::cube_bridge::base_tools::BaseTools; use crate::cube_bridge::evaluator::CubeEvaluator; use crate::cube_bridge::join_definition::JoinDefinition; @@ -32,6 +33,7 @@ pub struct QueryTools { timezone: Tz, convert_tz_for_raw_time_dimension: bool, masked_members: HashSet, + member_mask_filters: HashMap, } impl QueryTools { @@ -43,7 +45,7 @@ impl QueryTools { timezone_name: Option, export_annotated_sql: bool, convert_tz_for_raw_time_dimension: bool, - masked_members: Option>, + masked_members: Option>, member_to_alias: Option>, ) -> Result, CubeError> { let templates_render = base_tools.sql_templates()?; @@ -61,6 +63,16 @@ impl QueryTools { timezone.clone(), member_to_alias, ))); + let mut masked_set = HashSet::new(); + let mut mask_filters = HashMap::new(); + if let Some(items) = masked_members { + for item in items { + masked_set.insert(item.member.clone()); + if let Some(filter) = item.filter { + mask_filters.insert(item.member, filter); + } + } + } Ok(Rc::new(Self { cube_evaluator, base_tools, @@ -70,7 +82,8 @@ impl QueryTools { evaluator_compiler, timezone, convert_tz_for_raw_time_dimension, - masked_members: masked_members.unwrap_or_default().into_iter().collect(), + masked_members: masked_set, + member_mask_filters: mask_filters, })) } @@ -78,6 +91,10 @@ impl QueryTools { self.masked_members.contains(member_path) } + pub fn member_mask_filter(&self, member_path: &str) -> Option<&FilterItem> { + self.member_mask_filters.get(member_path) + } + pub fn cube_evaluator(&self) -> &Rc { &self.cube_evaluator } diff --git a/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/masked.rs b/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/masked.rs index b4d768d85954c..781151f89a71f 100644 --- a/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/masked.rs +++ b/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/masked.rs @@ -1,8 +1,12 @@ use super::SqlNode; +use crate::cube_bridge::base_query_options::FilterItem as NativeFilterItem; +use crate::plan::filter::FilterItem; +use crate::planner::filter::compiler::FilterCompiler; use crate::planner::query_tools::QueryTools; use crate::planner::sql_evaluator::MemberSymbol; use crate::planner::sql_evaluator::SqlEvaluatorVisitor; use crate::planner::sql_templates::PlanSqlTemplates; +use crate::planner::VisitorContext; use cubenativeutils::CubeError; use std::any::Any; use std::rc::Rc; @@ -39,9 +43,10 @@ impl MaskedSqlNode { if !query_tools.is_member_masked(&full_name) { return Ok(None); } - if let Some(mask_call) = node.mask_sql() { - // In ungrouped mode, skip SQL masks (has deps) on measures - // since they reference aggregated columns not meaningful per-row. + + let mask_filter = query_tools.member_mask_filter(&full_name).cloned(); + + let masked_sql = if let Some(mask_call) = node.mask_sql() { if self.ungrouped { if let MemberSymbol::Measure(_) = node.as_ref() { if mask_call.dependencies_count() > 0 { @@ -49,14 +54,74 @@ impl MaskedSqlNode { } } } - Ok(Some(mask_call.eval( + mask_call.eval( visitor, + node_processor.clone(), + query_tools.clone(), + templates, + )? + } else { + "(NULL)".to_string() + }; + + if let Some(filter_item) = mask_filter { + let original_sql = self.input.to_sql( + visitor, + node, + query_tools.clone(), node_processor, - query_tools, templates, - )?)) + )?; + let filter_sql = + self.compile_filter_to_sql(&filter_item, query_tools.clone(), templates)?; + if let Some(filter_sql) = filter_sql { + Ok(Some(templates.case( + None, + vec![(filter_sql, original_sql)], + Some(masked_sql), + )?)) + } else { + Ok(Some(masked_sql)) + } + } else { + Ok(Some(masked_sql)) + } + } + + fn compile_filter_to_sql( + &self, + native_filter: &NativeFilterItem, + query_tools: Rc, + templates: &PlanSqlTemplates, + ) -> Result, CubeError> { + let filter_item = { + let mut compiler = query_tools.evaluator_compiler().borrow_mut(); + let mut filter_compiler = FilterCompiler::new(&mut compiler, query_tools.clone()); + filter_compiler.add_item(native_filter)?; + let (dimension_filters, _, _) = filter_compiler.extract_result(); + if dimension_filters.is_empty() { + return Ok(None); + } + if dimension_filters.len() == 1 { + dimension_filters.into_iter().next().unwrap() + } else { + FilterItem::Group(Rc::new(crate::plan::filter::FilterGroup::new( + crate::plan::filter::FilterGroupOperator::And, + dimension_filters, + ))) + } + }; + // TODO: support FILTER_PARAMS in mask filter SQL by passing + // proper FiltersContext with filter_params_columns + let context = Rc::new(VisitorContext::new_with_node_processor( + query_tools.clone(), + self.input.clone(), + )); + let sql = filter_item.to_sql(templates, context)?; + if sql.is_empty() { + Ok(None) } else { - Ok(Some("(NULL)".to_string())) + Ok(Some(sql)) } } } diff --git a/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/visitor_context.rs b/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/visitor_context.rs index 793321ec6ac51..91a007e140797 100644 --- a/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/visitor_context.rs +++ b/rust/cube/cubesqlplanner/cubesqlplanner/src/planner/visitor_context.rs @@ -60,6 +60,19 @@ impl VisitorContext { } } + pub fn new_with_node_processor( + query_tools: Rc, + node_processor: Rc, + ) -> Self { + Self { + query_tools, + node_processor, + cube_ref_evaluator: Rc::new(CubeRefEvaluator::new(HashMap::new(), HashMap::new())), + all_filters: None, + filters_context: FiltersContext::default(), + } + } + pub fn make_visitor(&self, query_tools: Rc) -> SqlEvaluatorVisitor { SqlEvaluatorVisitor::new( query_tools, diff --git a/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/cube_bridge/base_query_options.rs b/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/cube_bridge/base_query_options.rs index 96a1ce29cf8da..4a390a5af082b 100644 --- a/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/cube_bridge/base_query_options.rs +++ b/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/cube_bridge/base_query_options.rs @@ -8,7 +8,8 @@ use typed_builder::TypedBuilder; use crate::{ cube_bridge::{ base_query_options::{ - BaseQueryOptions, BaseQueryOptionsStatic, FilterItem, OrderByItem, TimeDimension, + BaseQueryOptions, BaseQueryOptionsStatic, FilterItem, MaskedMemberItem, OrderByItem, + TimeDimension, }, base_tools::BaseTools, evaluator::CubeEvaluator, @@ -71,7 +72,7 @@ pub struct MockBaseQueryOptions { #[builder(default)] convert_tz_for_raw_time_dimension: Option, #[builder(default)] - masked_members: Option>, + masked_members: Option>, #[builder(default)] member_to_alias: Option>, } diff --git a/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/cube_bridge/yaml/base_query_options.rs b/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/cube_bridge/yaml/base_query_options.rs index 0de5d60d4efb1..f00e3e2abe179 100644 --- a/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/cube_bridge/yaml/base_query_options.rs +++ b/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/cube_bridge/yaml/base_query_options.rs @@ -1,4 +1,6 @@ -use crate::cube_bridge::base_query_options::{FilterItem, OrderByItem, TimeDimension}; +use crate::cube_bridge::base_query_options::{ + FilterItem, MaskedMemberItem, OrderByItem, TimeDimension, +}; use serde::de; use serde::{Deserialize, Deserializer}; use std::collections::HashMap; @@ -44,7 +46,7 @@ pub struct YamlBaseQueryOptions { #[serde(default)] pub timezone: Option, #[serde(default, rename = "maskedMembers")] - pub masked_members: Option>, + pub masked_members: Option>, } #[derive(Debug, Deserialize)] diff --git a/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/test_utils/test_context.rs b/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/test_utils/test_context.rs index e3cc23bb9a678..d3be16a4ae35c 100644 --- a/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/test_utils/test_context.rs +++ b/rust/cube/cubesqlplanner/cubesqlplanner/src/test_fixtures/test_utils/test_context.rs @@ -1,4 +1,4 @@ -use crate::cube_bridge::base_query_options::BaseQueryOptions; +use crate::cube_bridge::base_query_options::{BaseQueryOptions, MaskedMemberItem}; use crate::cube_bridge::join_hints::JoinHintItem; use crate::logical_plan::PreAggregationUsage; #[cfg(feature = "integration-postgres")] @@ -78,7 +78,14 @@ impl TestContext { schema: MockSchema, masked_members: Vec, ) -> Result { - Self::new_with_options(schema, Tz::UTC, Some(masked_members), None, false, false) + let items: Vec = masked_members + .into_iter() + .map(|member| MaskedMemberItem { + member, + filter: None, + }) + .collect(); + Self::new_with_options(schema, Tz::UTC, Some(items), None, false, false) } fn for_options(&self, options: &dyn BaseQueryOptions) -> Result { @@ -104,7 +111,7 @@ impl TestContext { fn new_with_options( schema: MockSchema, timezone: Tz, - masked_members: Option>, + masked_members: Option>, member_to_alias: Option>, export_annotated_sql: bool, convert_tz_for_raw_time_dimension: bool,