diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index d6beacadbb674..1b8afc1511081 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.catalyst.plans.logical +import scala.collection.mutable + import org.apache.spark.{SparkException, SparkIllegalArgumentException, SparkUnsupportedOperationException} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis.{AnalysisContext, AssignmentUtils, EliminateSubqueryAliases, FieldName, NamedRelation, PartitionSpec, ResolvedIdentifier, ResolvedProcedure, ResolveSchemaEvolution, TypeCheckResult, UnresolvedAttribute, UnresolvedException, UnresolvedProcedure, ViewSchemaMode} @@ -38,9 +40,10 @@ import org.apache.spark.sql.connector.expressions.filter.Predicate import org.apache.spark.sql.connector.write.{DeltaWrite, RowLevelOperation, RowLevelOperationTable, SupportsDelta, Write} import org.apache.spark.sql.connector.write.RowLevelOperation.Command.{DELETE, MERGE, UPDATE} import org.apache.spark.sql.errors.DataTypeErrors.toSQLType +import org.apache.spark.sql.errors.QueryExecutionErrors import org.apache.spark.sql.execution.datasources.v2.{DataSourceV2Relation, ExtractV2Table} -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructType} +import org.apache.spark.sql.types.{ArrayType, AtomicType, BooleanType, DataType, IntegerType, MapType, MetadataBuilder, StringType, StructField, StructType} +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.ArrayImplicits._ import org.apache.spark.util.Utils @@ -1004,11 +1007,15 @@ case class 
MergeIntoTable( case _ => false } + /** + * Catalog changes for MERGE auto schema evolution, produced from UPDATE/INSERT assignments. + * + * Unlike INSERT evolution (struct diff of table vs query), MERGE uses assignment-driven + * [[TableChange]]s from [[MergeIntoTable.computePendingSchemaChanges]]. + */ override lazy val pendingSchemaChanges: Seq[TableChange] = { if (schemaEvolutionEnabled && schemaEvolutionReady) { - val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(this) - ResolveSchemaEvolution.computeSchemaChanges( - targetTable.schema, referencedSourceSchema, isByName = true).toSeq + MergeIntoTable.computePendingSchemaChanges(this) } else { Seq.empty } @@ -1062,52 +1069,168 @@ object MergeIntoTable { .toSet } - // A pruned version of source schema that only contains columns/nested fields - // explicitly and directly assigned to a target counterpart in MERGE INTO actions, - // which are relevant for schema evolution. - // Examples: - // * UPDATE SET target.a = source.a - // * UPDATE SET nested.a = source.nested.a - // * INSERT (a, nested.b) VALUES (source.a, source.nested.b) - // New columns/nested fields in this schema that are not existing in target schema - // will be added for schema evolution. - def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = { + /** + * Builds the list of catalog changes for `MERGE ... WITH SCHEMA EVOLUTION` from the explicit + * `SET` / `VALUES` assignments in `WHEN MATCHED` (UPDATE) and `WHEN NOT MATCHED` (INSERT) + * clauses. + * + * `UPDATE *` and `INSERT *` return no changes; those branches need to be expanded by other + * analysis steps before this logic applies. + * + * Only assignments that copy from a source column (or nested field) into the same path on the + * target are considered, including new target columns that do not exist on the table yet but name + * the same path as that source field. 
+ * + * From those assignments we may produce: + * - `addColumn` when the assignment targets a new column and the table does not already have it + * at that name/path. + * - `updateColumnType` when an existing target column and the matching source column disagree on + * a simple (non-struct) type (for example widening `INT` to `BIGINT`). + * - Extra nested `addColumn` steps when the source side has struct fields (including inside + * arrays or maps) that the target table row does not yet store at the same path. + * - Nothing extra when the types already line up for that assignment. + * + * @param merge analyzed MERGE command (must satisfy `schemaEvolutionEnabled` and + * `schemaEvolutionReady` on the caller side) + * @return catalog edits to apply to the target table, deduplicated and ordered by assignment + * then stable set iteration + */ + private def computePendingSchemaChanges(merge: MergeIntoTable): Seq[TableChange] = { val actions = merge.matchedActions ++ merge.notMatchedActions - val assignments = actions.collect { + val originalTarget = merge.targetTable.schema + val originalSource = merge.sourceTable.schema + + val schemaEvolutionAssignments = actions.flatMap { case a: UpdateAction => a.assignments case a: InsertAction => a.assignments - }.flatten - - val containsStarAction = actions.exists { - case _: UpdateStarAction => true - case _: InsertStarAction => true - case _ => false + case _: UpdateStarAction | _: InsertStarAction => Seq.empty + case _ => Seq.empty + }.filter(isSchemaEvolutionCandidate(_, merge.sourceTable)) + + val changes = mutable.LinkedHashSet.empty[TableChange] + val failIncompatible: () => Nothing = () => + throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError( + originalTarget, originalSource, null) + + schemaEvolutionAssignments.foreach { + case a if !a.key.resolved => + val fieldPath = extractFieldPath(a.key, allowUnresolved = true) + if (fieldPath.nonEmpty && + !SchemaUtils.fieldExistsAtPath(originalTarget, fieldPath)) { 
+ changes += TableChange.addColumn(fieldPath.toArray, a.value.dataType.asNullable) + } + case a if a.key.dataType != a.value.dataType => + computeTypeSchemaChanges( + a.key.dataType, + a.value.dataType, + changes, + fieldPath = extractFieldPath(a.key, allowUnresolved = false), + targetTypeAtPath = originalTarget, + failIncompatible) + case _ => } - def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType = - StructType(sourceSchema.flatMap { field => - val fieldPath = basePath :+ field.name - - field.dataType match { - // Specifically assigned to in one clause: - // always keep, including all nested attributes - case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field) - // If this is a struct and one of the children is being assigned to in a merge clause, - // keep it and continue filtering children. - case struct: StructType if assignments.exists(assign => - isPrefix(fieldPath, extractFieldPath(assign.key, allowUnresolved = true))) => - Some(field.copy(dataType = filterSchema(struct, fieldPath))) - // The field isn't assigned to directly or indirectly (i.e. its children) in any non-* - // clause. Check if it should be kept with any * action. - case struct: StructType if containsStarAction => - Some(field.copy(dataType = filterSchema(struct, fieldPath))) - case _ if containsStarAction => Some(field) - // The field and its children are not assigned to in any * or non-* action, drop it. - case _ => None + changes.toSeq + } + + /** + * Recursively compares assignment key vs value types at `fieldPath` and appends matching + * `addColumn` / `updateColumnType` entries to `changes`. + * + * `keyType` and `valueType` come from the assignment expression types (MERGE target side vs + * source side). `targetTypeAtPath` is the type of the same path in the current + * MERGE target table. 
+ *
 + * @param keyType type of the assignment key at this path (MERGE target column expression)
 + * @param valueType type of the assignment value at this path (typically source column)
 + * @param changes accumulator for [[TableChange]] instances
 + * @param fieldPath qualified path segments for nested columns (`element` / `key` / `value`
 + *                  under arrays and maps)
 + * @param targetTypeAtPath type of the loaded MERGE target table at `fieldPath`
 + * @param failIncompatible error handling when assignment types cannot be reconciled
 + */
 + private def computeTypeSchemaChanges(
 + keyType: DataType,
 + valueType: DataType,
 + changes: mutable.LinkedHashSet[TableChange],
 + fieldPath: Seq[String],
 + targetTypeAtPath: DataType,
 + failIncompatible: () => Nothing): Unit = {
 + (keyType, valueType) match {
 + case (StructType(keyFields), StructType(valueFields)) =>
 + val keyFieldMap = SchemaUtils.toFieldMap(keyFields)
 + val valueFieldMap = SchemaUtils.toFieldMap(valueFields)
 + val targetFieldMap = targetTypeAtPath match {
 + case st: StructType => SchemaUtils.toFieldMap(st.fields)
 + case _ => Map.empty[String, StructField]
 + }
 +
 + keyFields
 + .filter(f => valueFieldMap.contains(f.name))
 + .foreach { f =>
 + val nextTargetType =
 + targetFieldMap.get(f.name).map(_.dataType).getOrElse(f.dataType)
 + computeTypeSchemaChanges(
 + f.dataType,
 + valueFieldMap(f.name).dataType,
 + changes,
 + fieldPath ++ Seq(f.name),
 + nextTargetType,
 + failIncompatible)
 + }
 +
 + valueFields
 + .filterNot(f => keyFieldMap.contains(f.name))
 + .foreach { f =>
 + if (!targetFieldMap.contains(f.name)) {
 + changes += TableChange.addColumn(
 + (fieldPath :+ f.name).toArray,
 + f.dataType.asNullable)
 + }
 + }
 +
 + case (ArrayType(keyElemType, _), ArrayType(valueElemType, _)) =>
 + val nextTargetType = targetTypeAtPath match {
 + case ArrayType(elementType, _) => elementType
 + case _ => keyElemType
 + }
 + computeTypeSchemaChanges(
 + keyElemType,
 + valueElemType,
 + changes,
 + fieldPath :+ "element",
 + nextTargetType,
 + 
failIncompatible) + + case (MapType(keySideMapKeyType, keySideMapValueType, _), + MapType(valueSideMapKeyType, valueSideMapValueType, _)) => + val (nextMapKeyTargetType, nextMapValueTargetType) = targetTypeAtPath match { + case MapType(kt, vt, _) => (kt, vt) + case _ => (keySideMapKeyType, keySideMapValueType) } - }) + computeTypeSchemaChanges( + keySideMapKeyType, + valueSideMapKeyType, + changes, + fieldPath :+ "key", + nextMapKeyTargetType, + failIncompatible) + computeTypeSchemaChanges( + keySideMapValueType, + valueSideMapValueType, + changes, + fieldPath :+ "value", + nextMapValueTargetType, + failIncompatible) + + case (kt: AtomicType, vt: AtomicType) if kt != vt => + changes += TableChange.updateColumnType(fieldPath.toArray, vt) + + case (kt, vt) if kt == vt => - filterSchema(merge.sourceTable.schema, Seq.empty) + case _ => + failIncompatible() + } } // Helper method to extract field path from an Expression. @@ -1121,24 +1244,6 @@ object MergeIntoTable { } } - // Helper method to check if a given field path is a prefix of another path. - private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean = - prefix.length <= path.length && prefix.zip(path).forall { - case (prefixNamePart, pathNamePart) => - SQLConf.get.resolver(prefixNamePart, pathNamePart) - } - - // Helper method to check if an assignment key is equal to a source column - // and if the assignment value is that same source column. 
- // Example: UPDATE SET target.a = source.a - private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = { - // key must be a non-qualified field path that may be added to target schema via evolution - val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true) - // value should always be resolved (from source) - val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false) - assignmentKeyExpr == assignmentValueExpr && assignmentKeyExpr == sourceFieldPath - } - private def areSchemaEvolutionReady( assignments: Seq[Assignment], source: LogicalPlan): Boolean = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 58ababa04739f..385a1534d4cc5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -25,9 +25,11 @@ import scala.collection.mutable import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, NamedExpression} +import org.apache.spark.sql.catalyst.util.CaseInsensitiveMap import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.MultipartIdentifierHelper import org.apache.spark.sql.connector.expressions.{BucketTransform, FieldReference, NamedTransform, Transform} import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} +import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_TOP_LEVEL_FIELDS, PROHIBIT_CHANGES} import org.apache.spark.util.ArrayImplicits._ @@ -125,6 +127,70 @@ private[spark] object SchemaUtils { } } + /** + * Returns whether a column path exists in the schema's nested structure. 
+ * + * @param root type schema + * @param path name segments + */ + def fieldExistsAtPath( + root: StructType, + path: Seq[String]): Boolean = { + if (path.isEmpty) { + false + } else { + fieldExistsAtPathInternal(root, path) + } + } + + private def fieldExistsAtPathInternal( + dt: DataType, + parts: Seq[String]): Boolean = { + def checkAndRecurse( + nextType: DataType, + remaining: Seq[String]): Boolean = { + if (remaining.isEmpty) { + true + } else { + fieldExistsAtPathInternal(nextType, remaining) + } + } + + if (parts.isEmpty) { + true + } else { + dt match { + case st: StructType => + toFieldMap(st.fields).get(parts.head) match { + case Some(f) => checkAndRecurse(f.dataType, parts.tail) + case None => false + } + case ArrayType(elementType, _) if parts.head == "element" => + checkAndRecurse(elementType, parts.tail) + case MapType(keyType, _, _) if parts.head == "key" => + checkAndRecurse(keyType, parts.tail) + case MapType(_, valueType, _) if parts.head == "value" => + checkAndRecurse(valueType, parts.tail) + case _ => false + } + } + } + + /** + * Returns a map of field name to StructField for the given fields. + * @param fields the fields to create the map for + * @return a map of field name to StructField for the given fields + */ + def toFieldMap( + fields: Array[StructField]): Map[String, StructField] = { + val fieldMap = fields.map(f => f.name -> f).toMap + if (SQLConf.get.caseSensitiveAnalysis) { + fieldMap + } else { + CaseInsensitiveMap(fieldMap) + } + } + /** * Checks if input column names have duplicate identifiers. This throws an exception if * the duplication exists. 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala index a277bb021c3f6..91b4edb618a6e 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/util/SchemaUtilsSuite.scala @@ -21,10 +21,12 @@ import java.util.Locale import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis._ -import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StructType} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types.{ArrayType, LongType, MapType, StringType, StructType} -class SchemaUtilsSuite extends SparkFunSuite { +class SchemaUtilsSuite extends SparkFunSuite with SQLConfHelper { private def resolver(caseSensitiveAnalysis: Boolean): Resolver = { if (caseSensitiveAnalysis) { @@ -110,4 +112,27 @@ class SchemaUtilsSuite extends SparkFunSuite { parameters = Map("columnName" -> "`camelcase`")) } } + + test("fieldExistsAtPath: structs, arrays, maps, and name case rules") { + val nested = new StructType().add("y", LongType) + val root = new StructType() + .add("a", LongType) + .add("S", nested) + .add("arr", ArrayType(LongType)) + .add("m", MapType(StringType, LongType)) + + assert(!SchemaUtils.fieldExistsAtPath(root, Seq.empty)) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("a"))) + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "true") { + assert(!SchemaUtils.fieldExistsAtPath(root, Seq("A"))) + } + withSQLConf(SQLConf.CASE_SENSITIVE.key -> "false") { + assert(SchemaUtils.fieldExistsAtPath(root, Seq("A"))) + } + assert(SchemaUtils.fieldExistsAtPath(root, Seq("S", "y"))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("arr", "element"))) + assert(SchemaUtils.fieldExistsAtPath(root, Seq("m", "key"))) + 
assert(SchemaUtils.fieldExistsAtPath(root, Seq("m", "value"))) + assert(!SchemaUtils.fieldExistsAtPath(root, Seq("missing"))) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionExtraSQLTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionExtraSQLTests.scala index 8565c0b31c0c1..f4308162277c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionExtraSQLTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionExtraSQLTests.scala @@ -142,7 +142,7 @@ trait MergeIntoSchemaEvolutionExtraSQLTests extends RowLevelOperationSuiteBase { s"Error message should mention table name: ${ex.getMessage}") val msg = ex.getMessage - val expectedChanges = "ALTER COLUMN pk TYPE BIGINT; ADD COLUMN active BOOLEAN" + val expectedChanges = "ADD COLUMN active BOOLEAN; ALTER COLUMN pk TYPE BIGINT" assert(msg.contains(expectedChanges), s"Error message should contain exact changes '$expectedChanges': $msg") }