diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSchemaEvolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSchemaEvolution.scala index d18116ed36ede..be1210409c3b4 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSchemaEvolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSchemaEvolution.scala @@ -98,64 +98,57 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] { } /** - * Computes the set of table changes needed to evolve `originalTarget` schema - * to accommodate `originalSource` schema. When `isByName` is true, fields are matched + * Computes the set of table changes needed to evolve `target` schema + * to accommodate `source` schema. When `isByName` is true, fields are matched * by name. When false, fields are matched by position. */ def computeSchemaChanges( - originalTarget: StructType, - originalSource: StructType, + target: StructType, + source: StructType, isByName: Boolean): Array[TableChange] = computeSchemaChanges( - originalTarget, - originalSource, - originalTarget, - originalSource, + target, + source, fieldPath = Nil, - isByName) + isByName, + error = throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError( + target, source, null)) private def computeSchemaChanges( currentType: DataType, newType: DataType, - originalTarget: StructType, - originalSource: StructType, fieldPath: List[String], - isByName: Boolean): Array[TableChange] = { + isByName: Boolean, + error: => Nothing): Array[TableChange] = { (currentType, newType) match { case (StructType(currentFields), StructType(newFields)) => if (isByName) { - computeSchemaChangesByName( - currentFields, newFields, originalTarget, originalSource, fieldPath) + computeSchemaChangesByName(currentFields, newFields, fieldPath, error) } else { - computeSchemaChangesByPosition( - currentFields, newFields, originalTarget, originalSource, 
fieldPath) + computeSchemaChangesByPosition(currentFields, newFields, fieldPath, error) } case (ArrayType(currentElementType, _), ArrayType(newElementType, _)) => computeSchemaChanges( currentElementType, newElementType, - originalTarget, - originalSource, fieldPath :+ "element", - isByName) + isByName, + error) - case (MapType(currentKeyType, currentValueType, _), - MapType(newKeyType, newValueType, _)) => + case (MapType(currentKeyType, currentValueType, _), MapType(newKeyType, newValueType, _)) => val keyChanges = computeSchemaChanges( currentKeyType, newKeyType, - originalTarget, - originalSource, fieldPath :+ "key", - isByName) + isByName, + error) val valueChanges = computeSchemaChanges( currentValueType, newValueType, - originalTarget, - originalSource, fieldPath :+ "value", - isByName) + isByName, + error) keyChanges ++ valueChanges case (currentType: AtomicType, newType: AtomicType) if currentType != newType => @@ -167,8 +160,7 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] { case _ => // Do not support change between atomic and complex types for now - throw QueryExecutionErrors.failedToMergeIncompatibleSchemasError( - originalTarget, originalSource, null) + error } } @@ -179,9 +171,8 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] { private def computeSchemaChangesByName( currentFields: Array[StructField], newFields: Array[StructField], - originalTarget: StructType, - originalSource: StructType, - fieldPath: List[String]): Array[TableChange] = { + fieldPath: List[String], + error: => Nothing): Array[TableChange] = { val currentFieldMap = toFieldMap(currentFields) val newFieldMap = toFieldMap(newFields) @@ -192,10 +183,9 @@ object ResolveSchemaEvolution extends Rule[LogicalPlan] { computeSchemaChanges( f.dataType, newFieldMap(f.name).dataType, - originalTarget, - originalSource, fieldPath :+ f.name, - isByName = true) + isByName = true, + error) } // Collect newly added fields @@ -213,18 +203,16 @@ object ResolveSchemaEvolution 
extends Rule[LogicalPlan] { private def computeSchemaChangesByPosition( currentFields: Array[StructField], newFields: Array[StructField], - originalTarget: StructType, - originalSource: StructType, - fieldPath: List[String]): Array[TableChange] = { + fieldPath: List[String], + error: => Nothing): Array[TableChange] = { // Update existing field types by pairing fields at the same position. val updates = currentFields.zip(newFields).flatMap { case (currentField, newField) => computeSchemaChanges( currentField.dataType, newField.dataType, - originalTarget, - originalSource, fieldPath :+ currentField.name, - isByName = false) + isByName = false, + error) } // Extra source fields beyond the target's field count are new additions. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala index fc95b29d6546e..a3fe0617387ba 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/v2Commands.scala @@ -1041,9 +1041,10 @@ case class MergeIntoTable( override lazy val pendingSchemaChanges: Seq[TableChange] = { if (schemaEvolutionEnabled && schemaEvolutionReady) { - val referencedSourceSchema = MergeIntoTable.sourceSchemaForSchemaEvolution(this) - ResolveSchemaEvolution.computeSchemaChanges( - targetTable.schema, referencedSourceSchema, isByName = true).toSeq + val allChanges = ResolveSchemaEvolution.computeSchemaChanges( + targetTable.schema, sourceTable.schema, isByName = true) + MergeIntoTable.filterValidSchemaEvolution( + allChanges, matchedActions ++ notMatchedActions, sourceTable) } else { Seq.empty } @@ -1097,52 +1098,36 @@ object MergeIntoTable { .toSet } - // A pruned version of source schema that only contains columns/nested fields - // explicitly and directly assigned to a target counterpart in MERGE INTO actions, 
- // which are relevant for schema evolution. - // Examples: - // * UPDATE SET target.a = source.a - // * UPDATE SET nested.a = source.nested.a - // * INSERT (a, nested.b) VALUES (source.a, source.nested.b) - // New columns/nested fields in this schema that are not existing in target schema - // will be added for schema evolution. - def sourceSchemaForSchemaEvolution(merge: MergeIntoTable): StructType = { - val actions = merge.matchedActions ++ merge.notMatchedActions + /** + * Filters schema changes to only those relevant to identity assignments + * (e.g. `target.x = source.x`) in the MERGE actions. Only identity assignments can + * introduce new columns or type changes via schema evolution. + * + * A schema change is kept if its field path is equal to or nested under the key path + * of an identity assignment. + */ + private def filterValidSchemaEvolution( + changes: Array[TableChange], + actions: Seq[MergeAction], + source: LogicalPlan): Seq[TableChange] = { val assignments = actions.collect { case a: UpdateAction => a.assignments case a: InsertAction => a.assignments }.flatten - val containsStarAction = actions.exists { - case _: UpdateStarAction => true - case _: InsertStarAction => true - case _ => false - } - - def filterSchema(sourceSchema: StructType, basePath: Seq[String]): StructType = - StructType(sourceSchema.flatMap { field => - val fieldPath = basePath :+ field.name - - field.dataType match { - // Specifically assigned to in one clause: - // always keep, including all nested attributes - case _ if assignments.exists(isEqual(_, fieldPath)) => Some(field) - // If this is a struct and one of the children is being assigned to in a merge clause, - // keep it and continue filtering children. - case struct: StructType if assignments.exists(assign => - isPrefix(fieldPath, extractFieldPath(assign.key, allowUnresolved = true))) => - Some(field.copy(dataType = filterSchema(struct, fieldPath))) - // The field isn't assigned to directly or indirectly (i.e. 
its children) in any non-* - // clause. Check if it should be kept with any * action. - case struct: StructType if containsStarAction => - Some(field.copy(dataType = filterSchema(struct, fieldPath))) - case _ if containsStarAction => Some(field) - // The field and its children are not assigned to in any * or non-* action, drop it. - case _ => None - } - }) - - filterSchema(merge.sourceTable.schema, Seq.empty) + val evolutionPaths = assignments + .filter(isSameColumnAssignment(_, source)) + .map(a => extractFieldPath(a.key, allowUnresolved = true)) + .filter(_.nonEmpty) + + val resolver = SQLConf.get.resolver + changes.filter { case change: TableChange.ColumnChange => + val changePath = change.fieldNames().toSeq + evolutionPaths.exists { ep => + ep.length <= changePath.length && + ep.zip(changePath).forall { case (a, b) => resolver(a, b) } + } + }.toSeq } // Helper method to extract field path from an Expression. @@ -1156,24 +1141,6 @@ object MergeIntoTable { } } - // Helper method to check if a given field path is a prefix of another path. - private def isPrefix(prefix: Seq[String], path: Seq[String]): Boolean = - prefix.length <= path.length && prefix.zip(path).forall { - case (prefixNamePart, pathNamePart) => - SQLConf.get.resolver(prefixNamePart, pathNamePart) - } - - // Helper method to check if an assignment key is equal to a source column - // and if the assignment value is that same source column. 
- // Example: UPDATE SET target.a = source.a - private def isEqual(assignment: Assignment, sourceFieldPath: Seq[String]): Boolean = { - // key must be a non-qualified field path that may be added to target schema via evolution - val assignmentKeyExpr = extractFieldPath(assignment.key, allowUnresolved = true) - // value should always be resolved (from source) - val assignmentValueExpr = extractFieldPath(assignment.value, allowUnresolved = false) - assignmentKeyExpr == assignmentValueExpr && assignmentKeyExpr == sourceFieldPath - } - private def areSchemaEvolutionReady( assignments: Seq[Assignment], source: LogicalPlan): Boolean = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionTypeWideningAndExtraFieldTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionTypeWideningAndExtraFieldTests.scala index cd151f6e1e0ea..62c5e34f22006 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionTypeWideningAndExtraFieldTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoSchemaEvolutionTypeWideningAndExtraFieldTests.scala @@ -142,6 +142,94 @@ trait MergeIntoSchemaEvolutionTypeWideningAndExtraFieldTests (3, 75, "newdep")).toDF("pk", "salary", "dep") ) + // When assigning s.bonus to existing t.salary and source.salary has a wider type (long) than + // target.salary (int), no evolution should occur because the assignment uses s.bonus, not + // s.salary. The type mismatch on the same-named column should be irrelevant. 
+  testEvolution("source has extra column with type mismatch on existing column - " +
+    "should not evolve when assigning from differently named source column")(
+    targetData = {
+      val schema = StructType(Seq(
+        StructField("pk", IntegerType, nullable = false),
+        StructField("salary", IntegerType),
+        StructField("dep", StringType)
+      ))
+      spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row(1, 100, "hr"),
+        Row(2, 200, "software")
+      )), schema)
+    },
+    sourceData = {
+      val schema = StructType(Seq(
+        StructField("pk", IntegerType, nullable = false),
+        StructField("salary", LongType),
+        StructField("dep", StringType),
+        StructField("bonus", LongType)
+      ))
+      spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row(2, 150L, "dummy", 50L),
+        Row(3, 250L, "dummy", 75L)
+      )), schema)
+    },
+    clauses = Seq(
+      update(set = "salary = s.bonus"),
+      insert(values = "(pk, salary, dep) VALUES (s.pk, s.bonus, 'newdep')")
+    ),
+    expected = Seq(
+      (1, 100, "hr"),
+      (2, 50, "software"),
+      (3, 75, "newdep")).toDF("pk", "salary", "dep"),
+    expectedWithoutEvolution = Seq(
+      (1, 100, "hr"),
+      (2, 50, "software"),
+      (3, 75, "newdep")).toDF("pk", "salary", "dep"),
+    expectedSchema = StructType(Seq(
+      StructField("pk", IntegerType, nullable = false),
+      StructField("salary", IntegerType),
+      StructField("dep", StringType)
+    )),
+    expectedSchemaWithoutEvolution = StructType(Seq(
+      StructField("pk", IntegerType, nullable = false),
+      StructField("salary", IntegerType),
+      StructField("dep", StringType)
+    ))
+  )
+
+  // When assigning s.bonus (StringType) to target salary (IntegerType), the types are
+  // incompatible. This should fail both with and without schema evolution because the explicit
+  // assignment has mismatched types regardless of evolution.
+  testEvolution("source has extra column with type mismatch on existing column - " +
+    "should fail when assigning from incompatible source column")(
+    targetData = {
+      val schema = StructType(Seq(
+        StructField("pk", IntegerType, nullable = false),
+        StructField("salary", IntegerType),
+        StructField("dep", StringType)
+      ))
+      spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row(1, 100, "hr"),
+        Row(2, 200, "software")
+      )), schema)
+    },
+    sourceData = {
+      val schema = StructType(Seq(
+        StructField("pk", IntegerType, nullable = false),
+        StructField("salary", LongType),
+        StructField("dep", StringType),
+        StructField("bonus", StringType)
+      ))
+      spark.createDataFrame(spark.sparkContext.parallelize(Seq(
+        Row(2, 150L, "dummy", "fifty"),
+        Row(3, 250L, "dummy", "seventy-five")
+      )), schema)
+    },
+    clauses = Seq(
+      update(set = "salary = s.bonus"),
+      insert(values = "(pk, salary, dep) VALUES (s.pk, s.bonus, 'newdep')")
+    ),
+    expectErrorContains = "Cannot safely cast",
+    expectErrorWithoutEvolutionContains = "Cannot safely cast"
+  )
+
   // No evolution when using named_struct to construct value without referencing new field
   testNestedStructsEvolution("source has extra struct field -" +
     "no evolution when not directly referencing new field - INSERT")(