diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java index dfaef55b56c964..4a1821abe6dc6b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnion.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalSetOperation; import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; @@ -60,7 +61,12 @@ public static boolean canPushProject(List projects, LogicalSetO if (e instanceof SlotReference) { return true; } else { - Expression expr = ExpressionUtils.getExpressionCoveredByCast(e.child(0)); + Expression expr; + if (logicalSetOperation.getQualifier().equals(Qualifier.ALL)) { + expr = ExpressionUtils.getExpressionCoveredByCast(e.child(0)); + } else { + expr = ExpressionUtils.getExpressionCoveredBySafetyCast(e.child(0)); + } return expr instanceof SlotReference; } }); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java index 084756303b64fc..015690e64348ce 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/ArrayType.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.types; import org.apache.doris.catalog.Type; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.ComplexDataType; import java.util.Objects; @@ -57,6 +58,14 @@ public DataType getItemType() { return itemType; } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof ArrayType) { + return itemType.isInjectiveCastTo(((ArrayType) target).itemType); + } + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { // Catalog ArrayType defaults containsNull to true via single-arg constructor diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java index 49b2a6e72d7c1d..708801a883f32b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/BooleanType.java @@ -31,6 +31,14 @@ public class BooleanType extends PrimitiveType { private BooleanType() { } + @Override + public boolean isInjectiveCastTo(DataType target) { + return target instanceof BooleanType || target.isIntegralType() || target.isFloatLikeType() + || (target instanceof DecimalV2Type && ((DecimalV2Type) target).getRange() >= 1) + || (target instanceof DecimalV3Type && ((DecimalV3Type) target).getRange() >= 1) + || target.isStringLikeType(); + } + @Override public Type toCatalogDataType() { return Type.BOOLEAN; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java index 8ded09c4b63c75..14295979dae637 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DataType.java @@ -813,6 +813,10 @@ public List getAllPromotions() { public abstract int width(); + public boolean isInjectiveCastTo(DataType target) { + return this.equals(target); + } + public static List trivialTypes() { return Type.getTrivialTypes() .stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java index 8a0250d7b447c9..a93bfda0364477 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeType.java @@ -19,6 +19,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.Config; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.DateLikeType; import java.time.DateTimeException; @@ -45,6 +46,14 @@ private DateTimeType(boolean shouldConversion) { this.shouldConversion = shouldConversion; } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof DateTimeType || target instanceof DateTimeV2Type || target instanceof CharacterType) { + return true; + } + return false; + } + @Override public DataType conversion() { if (Config.enable_date_conversion && shouldConversion) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java index f56b4662f8ba9c..13097339554e3e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateTimeV2Type.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.format.DateTimeChecker; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.DateLikeType; import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.types.coercion.ScaleTimeType; @@ -128,6 +129,18 @@ public String toSql() { return super.toSql() + "(" + scale + ")"; } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof DateTimeV2Type) { + DateTimeV2Type t2 = (DateTimeV2Type) target; + return this.scale <= t2.scale; + } + if (target instanceof DateTimeType) { + return this.scale == 0; + } + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { return ScalarType.createDatetimeV2Type(scale); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java index d127ab16069a16..c6ce702ebe75f5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DateType.java @@ -19,6 +19,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.common.Config; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.DateLikeType; import java.time.DateTimeException; @@ -45,6 +46,11 @@ private DateType(boolean shouldConversion) { this.shouldConversion = shouldConversion; } + @Override + public boolean isInjectiveCastTo(DataType target) { + return target instanceof DateType || target instanceof DateV2Type || target instanceof CharacterType; + } + @Override public DataType conversion() { if (Config.enable_date_conversion && shouldConversion) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java index b601aaa9f133ff..b055172f2629fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV2Type.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.ScalarType; import org.apache.doris.catalog.Type; import org.apache.doris.common.Config; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.FractionalType; import com.google.common.base.Preconditions; @@ -159,6 +160,19 @@ private static DecimalV2Type widerDecimalV2Type( return DecimalV2Type.createDecimalV2Type(range + scale, scale); } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof DecimalV2Type) { + DecimalV2Type decimalV2Type = (DecimalV2Type) target; + return decimalV2Type.getRange() >= this.getRange() && decimalV2Type.getScale() >= this.getScale(); + } + if (target instanceof DecimalV3Type) { + DecimalV3Type decimalV3Type = (DecimalV3Type) target; + return decimalV3Type.getRange() >= this.getRange() && decimalV3Type.getScale() >= this.getScale(); + } + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { return ScalarType.createDecimalType(PrimitiveType.DECIMALV2, precision, scale); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java index 3c0a83e95c4b56..b366568cb35933 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/DecimalV3Type.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.exceptions.NotSupportedException; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.FractionalType; import org.apache.doris.qe.ConnectContext; import org.apache.doris.qe.SessionVariable; @@ -214,6 +215,19 @@ private static DataType widerDecimalV3Type( } } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof DecimalV2Type) { + DecimalV2Type decimalV2Type = (DecimalV2Type) target; + return decimalV2Type.getRange() >= this.getRange() && decimalV2Type.getScale() >= this.getScale(); + } + if (target instanceof DecimalV3Type) { + DecimalV3Type decimalV3Type = (DecimalV3Type) target; + return decimalV3Type.getRange() >= this.getRange() && decimalV3Type.getScale() >= this.getScale(); + } + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { return ScalarType.createDecimalV3Type(precision, scale); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java index 176c1db1d0d6b1..fc6e9ba2f94060 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/MapType.java @@ -19,6 +19,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.ComplexDataType; import java.util.Objects; @@ -63,6 +64,15 @@ public DataType conversion() { return MapType.of(keyType.conversion(), valueType.conversion()); } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof MapType) { + MapType mapType = (MapType) target; + return keyType.isInjectiveCastTo(mapType.keyType) && valueType.isInjectiveCastTo(mapType.valueType); + } + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { return new org.apache.doris.catalog.MapType(keyType.toCatalogDataType(), valueType.toCatalogDataType()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java index 0c33a6d2decb1a..13f28c2e06e986 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/StructType.java @@ -20,6 +20,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.ComplexDataType; import com.google.common.collect.ImmutableList; @@ -84,6 +85,23 @@ public DataType conversion() { return new StructType(fields.stream().map(StructField::conversion).collect(Collectors.toList())); } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof StructType) { + StructType structType = (StructType) target; + if (this.fields.size() != structType.fields.size()) { + return false; + } + for (int i = 0; i < fields.size(); i++) { + if (!this.fields.get(i).getDataType().isInjectiveCastTo(structType.fields.get(i).getDataType())) { + return false; + } + } + return true; + } + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { return new org.apache.doris.catalog.StructType(fields.stream() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java index c3c99cad6fccf9..4f9c09b5f7f805 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeStampTzType.java @@ -21,6 +21,7 @@ import org.apache.doris.catalog.Type; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.DateLikeType; import org.apache.doris.nereids.types.coercion.ScaleTimeType; @@ -46,6 +47,15 @@ private TimeStampTzType(int scale) { this.scale = scale; } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof TimeStampTzType) { + TimeStampTzType timeStampTzType = (TimeStampTzType) target; + return timeStampTzType.getScale() >= this.scale; + } + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { return ScalarType.createTimeStampTzType(scale); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java index 39f420e6931032..af758c02cb522b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/TimeV2Type.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.TimeV2Literal; +import org.apache.doris.nereids.types.coercion.CharacterType; import org.apache.doris.nereids.types.coercion.IntegralType; import org.apache.doris.nereids.types.coercion.PrimitiveType; import org.apache.doris.nereids.types.coercion.RangeScalable; @@ -48,6 +49,15 @@ private TimeV2Type() { scale = 0; } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof TimeV2Type) { + TimeV2Type timeV2Type = (TimeV2Type) target; + return timeV2Type.scale >= scale; + } + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { return ScalarType.createTimeV2Type(scale); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java index 0a745d9cb7f5bb..1ba69fb4f50c9d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/VariantType.java @@ -116,6 +116,11 @@ public VariantType(List fields, int variantMaxSubcolumnsCount, this.enableNestedGroup = enableNestedGroup; } + @Override + public boolean isInjectiveCastTo(DataType target) { + return target.equals(this) || target instanceof VariantType; + } + @Override public DataType conversion() { return new VariantType(predefinedFields.stream().map(VariantField::conversion) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java index 781b125702837c..3d8590534f53b2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/CharacterType.java @@ -42,6 +42,11 @@ public int getLen() { return len; } + @Override + public boolean isInjectiveCastTo(DataType target) { + return target instanceof CharacterType; + } + @Override public Type toCatalogDataType() { throw new RuntimeException("CharacterType is only used for implicit cast."); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java index b1e588053881eb..fe625fa34bc589 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/types/coercion/IntegralType.java @@ -19,6 +19,8 @@ import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV3Type; +import org.apache.doris.nereids.types.LargeIntType; import org.apache.commons.lang3.NotImplementedException; @@ -44,6 +46,19 @@ public String simpleString() { return "integral"; } + @Override + public boolean isInjectiveCastTo(DataType target) { + if (target instanceof IntegralType) { + return this.equals(target) || ((IntegralType) target).widerThan(this); + } + if (target instanceof DecimalV3Type && !(this instanceof LargeIntType)) { + DecimalV3Type other = (DecimalV3Type) target; + DecimalV3Type self = DecimalV3Type.forType(this); + return other.getRange() >= self.getRange(); + } + return target instanceof CharacterType; + } + public boolean widerThan(IntegralType other) { return this.width() > other.width(); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 5dcd65d69c659a..f2898d43e973cc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -1133,6 +1133,20 @@ public static Expression getExpressionCoveredByCast(Expression expression) { return expression; } + /** + * Strip only casts that preserve distinctness of the child expression. + */ + public static Expression getExpressionCoveredBySafetyCast(Expression expression) { + while (expression instanceof Cast) { + if (((Cast) expression).child().getDataType().isInjectiveCastTo(expression.getDataType())) { + expression = ((Cast) expression).child(); + } else { + break; + } + } + return expression; + } + /** * the expressions can be used as runtime filter targets */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java index 328c390d52fc7d..157b7827a249f8 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushProjectThroughUnionTest.java @@ -28,6 +28,8 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalUnion; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DateTimeType; +import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.util.MemoTestUtils; import org.apache.doris.nereids.util.PlanChecker; @@ -107,6 +109,37 @@ public void testConstantExprIdsDistinctFromUnionOutput() { } } + @Test + public void testCastProjectPushThroughUnionByQualifierAndSafety() { + SlotReference unionOutput = new SlotReference(new ExprId(10), "s", + IntegerType.INSTANCE, true, ImmutableList.of()); + Alias castProject = new Alias(new ExprId(100), + new Cast(unionOutput, BigIntType.INSTANCE), "n"); + ImmutableList projects = ImmutableList.of(castProject); + + LogicalUnion unionAll = new LogicalUnion(Qualifier.ALL, + ImmutableList.of(unionOutput), ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of()); + Assertions.assertTrue(PushProjectThroughUnion.canPushProject(projects, unionAll)); + + LogicalUnion unionDistinct = new LogicalUnion(Qualifier.DISTINCT, + ImmutableList.of(unionOutput), ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of()); + Assertions.assertTrue(PushProjectThroughUnion.canPushProject(projects, unionDistinct)); + + SlotReference dateTimeOutput = new SlotReference(new ExprId(11), "dt", + DateTimeType.INSTANCE, true, ImmutableList.of()); + Alias unsafeCastProject = new Alias(new ExprId(101), + new Cast(dateTimeOutput, DateType.INSTANCE), "d"); + ImmutableList unsafeProjects = ImmutableList.of(unsafeCastProject); + + LogicalUnion unionAllWithUnsafeCast = new LogicalUnion(Qualifier.ALL, + ImmutableList.of(dateTimeOutput), ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of()); + Assertions.assertTrue(PushProjectThroughUnion.canPushProject(unsafeProjects, unionAllWithUnsafeCast)); + + LogicalUnion unionDistinctWithUnsafeCast = new LogicalUnion(Qualifier.DISTINCT, + ImmutableList.of(dateTimeOutput), ImmutableList.of(), ImmutableList.of(), false, ImmutableList.of()); + Assertions.assertFalse(PushProjectThroughUnion.canPushProject(unsafeProjects, unionDistinctWithUnsafeCast)); + } + private LogicalUnion findUnion(Plan p) { if (p instanceof LogicalUnion) { return (LogicalUnion) p; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java index 59509fce805848..9720fa1c5502a5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/types/DataTypeTest.java @@ -138,6 +138,62 @@ void testConvertFromString() { } + @Test + public void testIsInjectiveCastToForPrimitiveTypes() { + assertSafeCast(IntegerType.INSTANCE, IntegerType.INSTANCE); + assertSafeCast(IntegerType.INSTANCE, BigIntType.INSTANCE); + assertUnsafeCast(BigIntType.INSTANCE, IntegerType.INSTANCE); + assertSafeCast(IntegerType.INSTANCE, DecimalV3Type.createDecimalV3Type(10, 0)); + assertUnsafeCast(IntegerType.INSTANCE, DecimalV3Type.createDecimalV3Type(9, 0)); + assertUnsafeCast(LargeIntType.INSTANCE, DecimalV3Type.createDecimalV3Type(38, 0)); + + assertSafeCast(BooleanType.INSTANCE, DecimalV3Type.createDecimalV3Type(1, 0)); + assertUnsafeCast(BooleanType.INSTANCE, DecimalV3Type.createDecimalV3Type(1, 1)); + + assertSafeCast(DecimalV3Type.createDecimalV3Type(6, 2), DecimalV3Type.createDecimalV3Type(8, 3)); + assertUnsafeCast(DecimalV3Type.createDecimalV3Type(6, 2), DecimalV3Type.createDecimalV3Type(6, 1)); + assertUnsafeCast(DecimalV3Type.createDecimalV3Type(6, 2), DecimalV3Type.createDecimalV3Type(5, 2)); + + assertSafeCast(DateTimeType.INSTANCE, DateTimeV2Type.of(0)); + assertSafeCast(DateTimeV2Type.of(0), DateTimeType.INSTANCE); + assertSafeCast(DateTimeV2Type.of(3), DateTimeV2Type.of(6)); + assertUnsafeCast(DateTimeV2Type.of(3), DateTimeType.INSTANCE); + assertUnsafeCast(DateTimeType.INSTANCE, DateType.INSTANCE); + + assertSafeCast(VarcharType.createVarcharType(10), VarcharType.createVarcharType(20)); + assertSafeCast(VarcharType.createVarcharType(10), StringType.INSTANCE); + assertSafeCast(VarcharType.createVarcharType(20), VarcharType.createVarcharType(10)); + assertSafeCast(StringType.INSTANCE, VarcharType.createVarcharType(10)); + } + + @Test + public void testIsInjectiveCastToForComplexTypes() { + assertSafeCast(ArrayType.of(IntegerType.INSTANCE), ArrayType.of(BigIntType.INSTANCE)); + assertUnsafeCast(ArrayType.of(BigIntType.INSTANCE), ArrayType.of(IntegerType.INSTANCE)); + + assertSafeCast(MapType.of(IntegerType.INSTANCE, VarcharType.createVarcharType(10)), + MapType.of(BigIntType.INSTANCE, StringType.INSTANCE)); + assertUnsafeCast(MapType.of(BigIntType.INSTANCE, VarcharType.createVarcharType(10)), + MapType.of(IntegerType.INSTANCE, StringType.INSTANCE)); + + StructType intStringStruct = new StructType(ImmutableList.of( + new StructField("a", IntegerType.INSTANCE, true, ""), + new StructField("b", VarcharType.createVarcharType(10), true, ""))); + StructType bigintStringStruct = new StructType(ImmutableList.of( + new StructField("a", BigIntType.INSTANCE, true, ""), + new StructField("b", StringType.INSTANCE, true, ""))); + StructType intOnlyStruct = new StructType(ImmutableList.of( + new StructField("a", IntegerType.INSTANCE, true, ""))); + + assertSafeCast(intStringStruct, bigintStringStruct); + assertUnsafeCast(bigintStringStruct, intStringStruct); + assertUnsafeCast(intOnlyStruct, intStringStruct); + + assertSafeCast(ArrayType.of(IntegerType.INSTANCE), StringType.INSTANCE); + assertSafeCast(MapType.of(IntegerType.INSTANCE, StringType.INSTANCE), StringType.INSTANCE); + assertSafeCast(intStringStruct, StringType.INSTANCE); + } + @Test public void testAnyAccept() { AnyDataType dateType = AnyDataType.INSTANCE_WITHOUT_INDEX; @@ -654,4 +710,14 @@ public void testDeepValidComplexNestingAccepted() { DataType type = ArrayType.of(MapType.of(VarcharType.SYSTEM_DEFAULT, IntegerType.INSTANCE)); Assertions.assertDoesNotThrow(type::validateDataType); } + + private void assertSafeCast(DataType source, DataType target) { + Assertions.assertTrue(source.isInjectiveCastTo(target), source.toSql() + " should safely cast to " + + target.toSql()); + } + + private void assertUnsafeCast(DataType source, DataType target) { + Assertions.assertFalse(source.isInjectiveCastTo(target), source.toSql() + " should not safely cast to " + + target.toSql()); + } } diff --git a/regression-test/data/nereids_syntax_p0/set_operation.out b/regression-test/data/nereids_syntax_p0/set_operation.out index 5afc4fac2adec7..c649e9e3b05a20 100644 --- a/regression-test/data/nereids_syntax_p0/set_operation.out +++ b/regression-test/data/nereids_syntax_p0/set_operation.out @@ -592,6 +592,10 @@ hell0 -- !union45 -- 2 +-- !union46 -- +2020-01-01 +2020-01-01 + -- !check_child_col_order -- 205548764.21875 3601 53950855.65625 3602 @@ -599,4 +603,3 @@ hell0 -- !intersect_case -- 0 1 - diff --git a/regression-test/suites/nereids_syntax_p0/set_operation.groovy b/regression-test/suites/nereids_syntax_p0/set_operation.groovy index 360dfd55d937c6..09282213b842d4 100644 --- a/regression-test/suites/nereids_syntax_p0/set_operation.groovy +++ b/regression-test/suites/nereids_syntax_p0/set_operation.groovy @@ -291,6 +291,17 @@ suite("set_operation") { select count(*) from (select 1, 2 union select 1,1 ) a; """ + // do not push non-injective cast project below UNION DISTINCT. + // The two datetime values are distinct before the outer cast, but become + // equal after casting to date. The correct result keeps both rows. + order_qt_union46 """ + select cast(dt as date) from ( + select cast('2020-01-01 00:00:00' as datetime) dt + union + select cast('2020-01-01 01:00:00' as datetime) dt + ) t + """ + def tables = [ "dwd_daytable", ]