diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index ca7cdb6aac8705..d55d47f4cb2a37 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -30,6 +30,7 @@ import org.apache.doris.nereids.rules.expression.rules.NestedCaseWhenCondToLiteral; import org.apache.doris.nereids.rules.expression.rules.NullSafeEqualToEqual; import org.apache.doris.nereids.rules.expression.rules.PushIntoCaseWhenBranch; +import org.apache.doris.nereids.rules.expression.rules.RegexpFunctionRewrite; import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; import org.apache.doris.nereids.rules.expression.rules.SimplifyConflictCompound; import org.apache.doris.nereids.rules.expression.rules.SimplifyInPredicate; @@ -70,6 +71,7 @@ public class ExpressionOptimization extends ExpressionRewrite { PushIntoCaseWhenBranch.INSTANCE, NullSafeEqualToEqual.INSTANCE, LikeToEqualRewrite.INSTANCE, + RegexpFunctionRewrite.INSTANCE, BetweenToEqual.INSTANCE, StringEmptyToLengthRule.INSTANCE ) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java index 726026cccbf869..914cb56bacb09d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java @@ -51,6 +51,7 @@ public enum ExpressionRuleType { NULL_SAFE_EQUAL_TO_EQUAL, PUSH_INTO_CASE_WHEN_BRANCH, REPLACE_VARIABLE_BY_LITERAL, + REGEXP_FUNCTION_REWRITE, SIMPLIFY_ARITHMETIC_COMPARISON, SIMPLIFY_ARITHMETIC, SIMPLIFY_CAST, diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RegexpFunctionRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RegexpFunctionRewrite.java new file mode 100644 index 00000000000000..5aa4b8d4d350a7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RegexpFunctionRewrite.java @@ -0,0 +1,217 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtract; +import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplace; +import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplaceOne; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * Rewrites regexp functions to cheaper equivalent forms when the regexp shape proves the rewrite is safe. 
+ */
+public class RegexpFunctionRewrite implements ExpressionPatternRuleFactory {
+    public static final RegexpFunctionRewrite INSTANCE = new RegexpFunctionRewrite();
+
+    /**
+     * Registers one rewrite for {@link RegexpReplace} and one for {@link RegexpExtract}; both are reported
+     * under {@link ExpressionRuleType#REGEXP_FUNCTION_REWRITE}.
+     */
+    @Override
+    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
+        return ImmutableList.of(
+                matchesType(RegexpReplace.class)
+                        .then(RegexpFunctionRewrite::rewriteRegexpReplace)
+                        .toRule(ExpressionRuleType.REGEXP_FUNCTION_REWRITE),
+                matchesType(RegexpExtract.class)
+                        .then(RegexpFunctionRewrite::rewriteRegexpExtract)
+                        .toRule(ExpressionRuleType.REGEXP_FUNCTION_REWRITE)
+        );
+    }
+
+    /**
+     * Replaces {@code regexp_replace} with {@code regexp_replace_one} when the literal pattern is anchored.
+     *
+     * <p>The rewrite is applied only if the pattern is a non-empty string literal that starts with {@code ^}
+     * or ends with an unescaped {@code $} outside a character class, and contains no unescaped alternation,
+     * no inline {@code m} flag and no {@code \Q} literal-text span. In every other case the original
+     * expression is returned unchanged.
+     *
+     * @param regexpReplace the matched expression; child 1 is the pattern, child 3 (if present) the options
+     * @return an equivalent {@link RegexpReplaceOne}, or {@code regexpReplace} itself when not applicable
+     */
+    private static Expression rewriteRegexpReplace(RegexpReplace regexpReplace) {
+        String pattern = getStringLiteral(regexpReplace.child(1));
+        if (pattern == null || pattern.isEmpty()) {
+            return regexpReplace;
+        }
+        if (!startsWithUnescapedCaret(pattern) && !endsWithUnescapedDollar(pattern)) {
+            return regexpReplace;
+        }
+        // Alternation may contain an unanchored branch (e.g. "^a|b") and a multiline flag changes what the
+        // anchors match, so both are rejected.
+        if (hasUnescapedAlternation(pattern) || hasInlineRegexpFlag(pattern, 'm')) {
+            return regexpReplace;
+        }
+        // Inside \Q...\E the characters '^', '$' and '|' are plain text, which the scanners above do not
+        // model; be conservative and keep the original function.
+        if (hasLiteralQuote(pattern)) {
+            return regexpReplace;
+        }
+
+        if (regexpReplace.arity() == 3) {
+            return new RegexpReplaceOne(regexpReplace.child(0), regexpReplace.child(1), regexpReplace.child(2));
+        }
+        return new RegexpReplaceOne(regexpReplace.child(0), regexpReplace.child(1), regexpReplace.child(2),
+                regexpReplace.child(3));
+    }
+
+    /**
+     * Drops a trailing {@code .*$} from the pattern of {@code regexp_extract} when that is safe.
+     *
+     * <p>The rewrite requires a non-empty string-literal pattern, an integer-literal group index of at least 1,
+     * at least one plain capturing group, no unescaped alternation, no inline {@code s} flag and no
+     * {@code \Q} literal-text span. If nothing is trimmed the original expression is returned.
+     *
+     * @param regexpExtract the matched expression; child 1 is the pattern, child 2 the group index
+     * @return a {@link RegexpExtract} with the shortened pattern, or {@code regexpExtract} itself
+     */
+    private static Expression rewriteRegexpExtract(RegexpExtract regexpExtract) {
+        String pattern = getStringLiteral(regexpExtract.child(1));
+        if (pattern == null || pattern.isEmpty() || !isPositiveGroupIndex(regexpExtract.child(2))
+                || !hasCapturingGroup(pattern) || hasUnescapedAlternation(pattern)
+                || hasInlineRegexpFlag(pattern, 's')) {
+            return regexpExtract;
+        }
+        // Inside \Q...\E the text ".*$" and "(" are literals, so trimming would change the match.
+        if (hasLiteralQuote(pattern)) {
+            return regexpExtract;
+        }
+
+        String trimmedPattern = trimExtractPattern(pattern);
+        if (trimmedPattern.equals(pattern)) {
+            return regexpExtract;
+        }
+        return new RegexpExtract(regexpExtract.child(0), new VarcharLiteral(trimmedPattern), regexpExtract.child(2));
+    }
+
+    /** Returns the pattern without its trailing {@code .*$}, or the pattern itself if it has none. */
+    private static String trimExtractPattern(String pattern) {
+        String trimmed = pattern;
+        if (endsWithUnescapedDotStarDollar(trimmed)) {
+            // ".*$" is exactly three characters long.
+            trimmed = trimmed.substring(0, trimmed.length() - 3);
+        }
+        return trimmed;
+    }
+
+    /** Returns the string value of a string-like literal, or {@code null} for any other expression. */
+    private static String getStringLiteral(Expression expression) {
+        if (!(expression instanceof Literal) || !expression.getDataType().isStringLikeType()) {
+            return null;
+        }
+        return ((Literal) expression).getStringValue();
+    }
+
+    /** Returns true when the expression is an integer-like literal whose value is at least 1. */
+    private static boolean isPositiveGroupIndex(Expression expression) {
+        return expression instanceof IntegerLikeLiteral && ((IntegerLikeLiteral) expression).getLongValue() >= 1;
+    }
+
+    /** Returns true when the first character of the pattern is {@code ^}. */
+    private static boolean startsWithUnescapedCaret(String pattern) {
+        return !pattern.isEmpty() && pattern.charAt(0) == '^';
+    }
+
+    /** Returns true when the last character is a {@code $} that is neither escaped nor inside {@code [...]}. */
+    private static boolean endsWithUnescapedDollar(String pattern) {
+        int dollarPos = pattern.length() - 1;
+        return dollarPos >= 0 && pattern.charAt(dollarPos) == '$'
+                && isUnescaped(pattern, dollarPos) && !isInCharClass(pattern, dollarPos);
+    }
+
+    /** Returns true when the pattern ends with {@code .*$} whose dot is neither escaped nor inside {@code [...]}. */
+    private static boolean endsWithUnescapedDotStarDollar(String pattern) {
+        if (pattern.length() < 3 || !pattern.endsWith(".*$")) {
+            return false;
+        }
+        int dotPos = pattern.length() - 3;
+        return isUnescaped(pattern, dotPos) && !isInCharClass(pattern, dotPos);
+    }
+
+    /**
+     * Returns true when the character at {@code pos} is preceded by an even number (including zero) of
+     * consecutive backslashes, i.e. it is not itself escaped.
+     */
+    private static boolean isUnescaped(String pattern, int pos) {
+        int backslashCount = 0;
+        for (int i = pos - 1; i >= 0 && pattern.charAt(i) == '\\'; i--) {
+            backslashCount++;
+        }
+        return backslashCount % 2 == 0;
+    }
+
+    /**
+     * Returns true when the pattern contains an unescaped {@code \Q}, the start of a {@code \Q...\E}
+     * literal-text span in which regexp metacharacters are matched as ordinary characters.
+     */
+    private static boolean hasLiteralQuote(String pattern) {
+        for (int i = pattern.indexOf("\\Q"); i >= 0; i = pattern.indexOf("\\Q", i + 1)) {
+            // The backslash at i starts an escape only if it is not itself escaped ("\\Q" is '\' then 'Q').
+            if (isUnescaped(pattern, i)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /** Returns true when an unescaped {@code |} occurs outside a {@code [...]} character class. */
+    private static boolean hasUnescapedAlternation(String pattern) {
+        boolean inCharClass = false;
+        for (int i = 0; i < pattern.length(); i++) {
+            char ch = pattern.charAt(i);
+            if (!isUnescaped(pattern, i)) {
+                continue;
+            }
+            if (ch == '[') {
+                inCharClass = true;
+            } else if (ch == ']' && inCharClass) {
+                inCharClass = false;
+            } else if (ch == '|' && !inCharClass) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Returns true when an unescaped {@code (} that is not followed by {@code ?} occurs outside a character
+     * class. Groups written as {@code (?...)} are therefore never counted.
+     */
+    private static boolean hasCapturingGroup(String pattern) {
+        boolean inCharClass = false;
+        for (int i = 0; i < pattern.length(); i++) {
+            char ch = pattern.charAt(i);
+            if (!isUnescaped(pattern, i)) {
+                continue;
+            }
+            if (ch == '[') {
+                inCharClass = true;
+            } else if (ch == ']' && inCharClass) {
+                inCharClass = false;
+            } else if (ch == '(' && !inCharClass && (i + 1 >= pattern.length() || pattern.charAt(i + 1) != '?')) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Returns true when an inline flag group such as {@code (?m)} or {@code (?i-s:} mentions
+     * {@code targetFlag}. The flag counts whether it is being switched on or off.
+     */
+    private static boolean hasInlineRegexpFlag(String pattern, char targetFlag) {
+        boolean inCharClass = false;
+        // Stop two characters early so that i + 1 and i + 2 are always valid indexes.
+        for (int i = 0; i < pattern.length() - 2; i++) {
+            char ch = pattern.charAt(i);
+            if (!isUnescaped(pattern, i)) {
+                continue;
+            }
+            if (ch == '[') {
+                inCharClass = true;
+            } else if (ch == ']' && inCharClass) {
+                inCharClass = false;
+            } else if (ch == '(' && !inCharClass && pattern.charAt(i + 1) == '?'
+                    && isInlineFlagChar(pattern.charAt(i + 2))) {
+                // Walk the flag list until it is closed by ')' or ':' or a non-flag character appears.
+                for (int j = i + 2; j < pattern.length(); j++) {
+                    char flag = pattern.charAt(j);
+                    if (flag == ':' || flag == ')') {
+                        break;
+                    }
+                    if (flag == targetFlag) {
+                        return true;
+                    }
+                    if (!isInlineFlagChar(flag)) {
+                        break;
+                    }
+                }
+            }
+        }
+        return false;
+    }
+
+    /** Returns true for the characters accepted in an inline flag list: {@code i m s U -}. */
+    private static boolean isInlineFlagChar(char ch) {
+        return ch == 'i' || ch == 'm' || ch == 's' || ch == 'U' || ch == '-';
+    }
+
+    /** Returns true when position {@code pos} lies inside a {@code [...]} opened by an unescaped {@code [}. */
+    private static boolean isInCharClass(String pattern, int pos) {
+        boolean inCharClass = false;
+        for (int i = 0; i < pos; i++) {
+            char ch = pattern.charAt(i);
+            if (!isUnescaped(pattern, i)) {
+                continue;
+            }
+            if (ch == '[') {
+                inCharClass = true;
+            } else if (ch == ']' && inCharClass) {
+                inCharClass = false;
+            }
+        }
+        return inCharClass;
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/RegexpFunctionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/RegexpFunctionRewriteTest.java
new file mode 100644
index 00000000000000..86d758d59ceffd
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/RegexpFunctionRewriteTest.java
@@ -0,0 +1,141 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license
agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.ExpressionRewriteTestHelper;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleExecutor;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.SlotReference;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpExtract;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplace;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.RegexpReplaceOne;
+import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
+import org.apache.doris.nereids.types.StringType;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/** Unit tests for {@link RegexpFunctionRewrite}, run through a bottom-up executor holding only that rule. */
+public class RegexpFunctionRewriteTest extends ExpressionRewriteTestHelper {
+    private SlotReference str;
+
+    @BeforeEach
+    public void setup() {
+        str = new SlotReference("str", StringType.INSTANCE, true);
+        executor = new ExpressionRuleExecutor(ImmutableList.of(bottomUp(RegexpFunctionRewrite.INSTANCE)));
+    }
+
+    // ---- regexp_replace -> regexp_replace_one ----
+
+    @Test
+    public void testRewriteAnchoredRegexpReplace() {
+        VarcharLiteral urlPattern = new VarcharLiteral("^https?://(?:www\\.)?([^/]+)/.*$");
+        VarcharLiteral replacement = new VarcharLiteral("\\1");
+        assertRuleRewrite(new RegexpReplace(str, urlPattern, replacement),
+                new RegexpReplaceOne(str, urlPattern, replacement));
+    }
+
+    @Test
+    public void testRewriteAnchoredRegexpReplaceWithOptions() {
+        VarcharLiteral anchored = new VarcharLiteral("^abc");
+        VarcharLiteral replacement = new VarcharLiteral("x");
+        VarcharLiteral options = new VarcharLiteral("ignore_invalid_escape");
+        assertRuleRewrite(new RegexpReplace(str, anchored, replacement, options),
+                new RegexpReplaceOne(str, anchored, replacement, options));
+    }
+
+    @Test
+    public void testDoNotRewriteRegexpReplaceWithAlternation() {
+        assertReplaceUntouched("^a|b");
+    }
+
+    @Test
+    public void testDoNotRewriteRegexpReplaceWithEscapedDollar() {
+        assertReplaceUntouched("a\\$");
+    }
+
+    @Test
+    public void testDoNotRewriteRegexpReplaceWithInlineMultilineFlag() {
+        assertReplaceUntouched("(?m)a$");
+    }
+
+    // ---- regexp_extract suffix trimming ----
+
+    @Test
+    public void testRewriteRegexpExtractTrimSuffix() {
+        assertExtractTrimmed("^.*(abc).*$", "^.*(abc)");
+    }
+
+    @Test
+    public void testRewriteRegexpExtractTrimSuffixOnly() {
+        assertExtractTrimmed("^([a-z]+).*$", "^([a-z]+)");
+    }
+
+    @Test
+    public void testDoNotRewriteRegexpExtractGroupZero() {
+        assertExtractUntouched("^.*(abc).*$", 0);
+    }
+
+    @Test
+    public void testDoNotRewriteRegexpExtractWithoutCapture() {
+        assertExtractUntouched("^.*abc.*$", 1);
+    }
+
+    @Test
+    public void testRewriteRegexpExtractLazyPrefixSuffixOnly() {
+        assertExtractTrimmed("^.*?(abc).*$", "^.*?(abc)");
+    }
+
+    @Test
+    public void testDoNotRewriteRegexpExtractEscapedDotSuffix() {
+        assertExtractUntouched("^(a)\\.*$", 1);
+    }
+
+    @Test
+    public void testDoNotRewriteRegexpExtractWithAlternation() {
+        assertExtractUntouched("(a)|(b).*$", 1);
+    }
+
+    @Test
+    public void testDoNotRewriteRegexpExtractWithInlineDotFlag() {
+        assertExtractUntouched("(?-s)^(a).*$", 1);
+    }
+
+    /** Asserts that {@code regexp_replace(str, regex, 'x')} is left unchanged by the rule. */
+    private void assertReplaceUntouched(String regex) {
+        assertRuleNoRewrite(new RegexpReplace(str, new VarcharLiteral(regex), new VarcharLiteral("x")));
+    }
+
+    /** Asserts that group 1 extraction with {@code regex} is rewritten to use {@code shortened}. */
+    private void assertExtractTrimmed(String regex, String shortened) {
+        RegexpExtract original = new RegexpExtract(str, new VarcharLiteral(regex), new BigIntLiteral(1));
+        RegexpExtract rewritten = new RegexpExtract(str, new VarcharLiteral(shortened), new BigIntLiteral(1));
+        assertRuleRewrite(original, rewritten);
+    }
+
+    /** Asserts that {@code regexp_extract(str, regex, group)} is left unchanged by the rule. */
+    private void assertExtractUntouched(String regex, long group) {
+        assertRuleNoRewrite(new RegexpExtract(str, new VarcharLiteral(regex), new BigIntLiteral(group)));
+    }
+
+    private void assertRuleRewrite(Expression before, Expression expected) {
+        Expression actual = executor.rewrite(before, context);
+        Assertions.assertEquals(expected, actual);
+    }
+
+    private void assertRuleNoRewrite(Expression before) {
+        Expression actual = executor.rewrite(before, context);
+        Assertions.assertEquals(before, actual);
+    }
+}
diff --git a/regression-test/suites/nereids_rules_p0/expression/test_regexp_function_rewrite.groovy b/regression-test/suites/nereids_rules_p0/expression/test_regexp_function_rewrite.groovy
new file mode 100644
index
00000000000000..7c5de3049a4ce8
--- /dev/null
+++ b/regression-test/suites/nereids_rules_p0/expression/test_regexp_function_rewrite.groovy
@@ -0,0 +1,90 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+// Checks the REGEXP_FUNCTION_REWRITE expression rule end to end: the plan must show
+// regexp_replace_one only when the rule is active, and query results must be the same
+// with the rule enabled and disabled.
+suite("test_regexp_function_rewrite") {
+    // Start from a fresh single-bucket table holding a handful of URLs and one NULL row.
+    sql "DROP TABLE IF EXISTS test_regexp_function_rewrite"
+    sql """
+        CREATE TABLE test_regexp_function_rewrite (
+            id INT,
+            url STRING
+        )
+        DUPLICATE KEY(id)
+        DISTRIBUTED BY HASH(id) BUCKETS 1
+        PROPERTIES("replication_num" = "1")
+    """
+
+    sql """
+        INSERT INTO test_regexp_function_rewrite VALUES
+        (1, 'https://www.example.com/path/index.html'),
+        (2, 'http://doris.apache.org/docs'),
+        (3, 'https://sub.domain.org/'),
+        (4, 'ftp://www.example.com/path'),
+        (5, 'https://www.example.com'),
+        (6, NULL)
+    """
+
+    // With the rule active, an anchored regexp_replace must appear as regexp_replace_one in the plan.
+    def shapeEnabled = sql """
+        EXPLAIN VERBOSE
+        SELECT regexp_replace(url, '^https?://(?:www\\\\.)?([^/]+)/.*\$', '\\\\1')
+        FROM test_regexp_function_rewrite
+    """
+    assertTrue(shapeEnabled.toString().contains("regexp_replace_one"))
+
+    // Patterns carrying the inline flags (?m) / (?-s) must not produce regexp_replace_one.
+    def inlineFlagShapeEnabled = sql """
+        EXPLAIN VERBOSE
+        SELECT
+            regexp_replace(concat(url, char(10), url), '(?m).*\$', 'x'),
+            regexp_extract(concat(url, char(10), 'tail'), '(?-s)^([^/]+).*\$', 1)
+        FROM test_regexp_function_rewrite
+    """
+    assertTrue(!inlineFlagShapeEnabled.toString().contains("regexp_replace_one"))
+
+    // Capture the results while the rule is still enabled.
+    def resultEnabled = sql """
+        SELECT
+            id,
+            regexp_replace(url, '^https?://(?:www\\\\.)?([^/]+)/.*\$', '\\\\1') AS host_by_replace,
+            regexp_replace(url, '.*\\\\.org\$', 'ORG') AS org_suffix,
+            regexp_extract(url, '^https?://(?:www\\\\.)?([^/]+).*\$', 1) AS host_by_extract,
+            regexp_replace(concat(url, char(10), url), '(?m).*\$', 'x') AS multiline_replace,
+            regexp_extract(concat(url, char(10), 'tail'), '(?-s)^([^/]+).*\$', 1) AS dot_mode_extract
+        FROM test_regexp_function_rewrite
+        ORDER BY id
+    """
+
+    // Switch the rule off for this session; everything below runs without the rewrite.
+    sql "SET disable_nereids_expression_rules='REGEXP_FUNCTION_REWRITE'"
+
+    def shapeDisabled = sql """
+        EXPLAIN VERBOSE
+        SELECT regexp_replace(url, '^https?://(?:www\\\\.)?([^/]+)/.*\$', '\\\\1')
+        FROM test_regexp_function_rewrite
+    """
+    // The plan must still call regexp_replace, but no longer the _one variant.
+    assertTrue(shapeDisabled.toString().contains("regexp_replace"))
+    assertTrue(!shapeDisabled.toString().contains("regexp_replace_one"))
+
+    // Same query as resultEnabled, now evaluated without the rewrite.
+    def resultDisabled = sql """
+        SELECT
+            id,
+            regexp_replace(url, '^https?://(?:www\\\\.)?([^/]+)/.*\$', '\\\\1') AS host_by_replace,
+            regexp_replace(url, '.*\\\\.org\$', 'ORG') AS org_suffix,
+            regexp_extract(url, '^https?://(?:www\\\\.)?([^/]+).*\$', 1) AS host_by_extract,
+            regexp_replace(concat(url, char(10), url), '(?m).*\$', 'x') AS multiline_replace,
+            regexp_extract(concat(url, char(10), 'tail'), '(?-s)^([^/]+).*\$', 1) AS dot_mode_extract
+        FROM test_regexp_function_rewrite
+        ORDER BY id
+    """
+    // The rewrite is only valid if it never changes query results.
+    assertEquals(resultEnabled, resultDisabled)
+}