diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 66f85059096a5..ee96d6d83f90e 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -5652,6 +5652,12 @@ ], "sqlState" : "42617" }, + "PARSE_INPUT_NOT_STRING_TYPE" : { + "message" : [ + "Input DataFrame column must be StringType, but got <dataType>." + ], + "sqlState" : "42K09" + }, "PARSE_MODE_UNSUPPORTED" : { "message" : [ "The function <funcName> doesn't support the <mode> mode. Acceptable modes are PERMISSIVE and FAILFAST." diff --git a/python/pyspark/sql/connect/plan.py b/python/pyspark/sql/connect/plan.py index 3dac4fc47ee70..6aff0ef636820 100644 --- a/python/pyspark/sql/connect/plan.py +++ b/python/pyspark/sql/connect/plan.py @@ -43,7 +43,7 @@ from pyspark.serializers import CloudPickleSerializer from pyspark.storagelevel import StorageLevel -from pyspark.sql.types import DataType +from pyspark.sql.types import DataType, StructType import pyspark.sql.connect.proto as proto from pyspark.sql.column import Column @@ -383,6 +383,40 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation: return plan +class Parse(LogicalPlan): + """Parse a DataFrame with a single string column into a structured DataFrame.""" + + def __init__( + self, + child: "LogicalPlan", + format: "proto.Parse.ParseFormat.ValueType", + schema: Optional[str] = None, + options: Optional[Mapping[str, str]] = None, + ) -> None: + super().__init__(child) + self._format = format + self._schema = schema + self._options = options + + def plan(self, session: "SparkConnectClient") -> proto.Relation: + assert self._child is not None + plan = self._create_proto_relation() + plan.parse.input.CopyFrom(self._child.plan(session)) + plan.parse.format = self._format + if self._schema is not None and len(self._schema) > 0: + plan.parse.schema.CopyFrom( + pyspark_types_to_proto_types( 
StructType.fromDDL(self._schema) + if not self._schema.startswith("{") + else StructType.fromJson(json.loads(self._schema)) + ) + ) + if self._options is not None: + for k, v in self._options.items(): + plan.parse.options[k] = v + return plan + + class Read(LogicalPlan): def __init__( self, diff --git a/python/pyspark/sql/connect/readwriter.py b/python/pyspark/sql/connect/readwriter.py index c951a9caf6a56..027da28a31cb8 100644 --- a/python/pyspark/sql/connect/readwriter.py +++ b/python/pyspark/sql/connect/readwriter.py @@ -25,7 +25,9 @@ LogicalPlan, WriteOperation, WriteOperationV2, + Parse, ) +import pyspark.sql.connect.proto as proto from pyspark.sql.types import StructType from pyspark.sql.utils import to_str from pyspark.sql.readwriter import ( @@ -165,7 +167,7 @@ def changes(self, tableName: str) -> "DataFrame": def json( self, - path: PathOrPaths, + path: Union[PathOrPaths, "DataFrame"], schema: Optional[Union[StructType, str]] = None, primitivesAsString: Optional[Union[bool, str]] = None, prefersDecimal: Optional[Union[bool, str]] = None, @@ -220,7 +222,32 @@ def json( ) if isinstance(path, str): path = [path] - return self.load(path=path, format="json", schema=schema) + if isinstance(path, list): + return self.load(path=path, format="json", schema=schema) + + from pyspark.sql.connect.dataframe import DataFrame + + if isinstance(path, DataFrame): + # Schema must be set explicitly here because the DataFrame path + # bypasses load(), which normally calls self.schema(schema). 
+ if schema is not None: + self.schema(schema) + return self._df( + Parse( + child=path._plan, + format=proto.Parse.ParseFormat.PARSE_FORMAT_JSON, + schema=self._schema, + options=self._options, + ) + ) + raise PySparkTypeError( + errorClass="NOT_EXPECTED_TYPE", + messageParameters={ + "arg_name": "path", + "expected_type": "str, list, or DataFrame", + "arg_type": type(path).__name__, + }, + ) json.__doc__ = PySparkDataFrameReader.json.__doc__ diff --git a/python/pyspark/sql/readwriter.py b/python/pyspark/sql/readwriter.py index bed87788d2c11..7ada41a71655d 100644 --- a/python/pyspark/sql/readwriter.py +++ b/python/pyspark/sql/readwriter.py @@ -320,7 +320,7 @@ def load( def json( self, - path: Union[str, List[str], "RDD[str]"], + path: Union[str, List[str], "RDD[str]", "DataFrame"], schema: Optional[Union[StructType, str]] = None, primitivesAsString: Optional[Union[bool, str]] = None, prefersDecimal: Optional[Union[bool, str]] = None, @@ -361,11 +361,15 @@ def json( .. versionchanged:: 3.4.0 Supports Spark Connect. + .. versionchanged:: 4.2.0 + Supports DataFrame input. + Parameters ---------- - path : str, list or :class:`RDD` + path : str, list, :class:`RDD`, or :class:`DataFrame` string represents path to the JSON dataset, or a list of paths, - or RDD of Strings storing JSON objects. + or RDD of Strings storing JSON objects, + or a DataFrame with a single string column containing JSON strings. schema : :class:`pyspark.sql.types.StructType` or str, optional an optional :class:`pyspark.sql.types.StructType` for the input schema or a DDL-formatted string (For example ``col0 INT, col1 DOUBLE``). @@ -434,6 +438,20 @@ def json( +----+---+ | Bob| 30| +----+---+ + + Example 4: Parse JSON from a DataFrame with a single string column. + + >>> json_df = spark.createDataFrame( + ... [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)], + ... schema="value STRING", + ... 
) + >>> spark.read.json(json_df).sort("name").show() + +---+-----+ + |age| name| + +---+-----+ + | 25|Alice| + | 30| Bob| + +---+-----+ """ self._set_opts( schema=schema, @@ -486,12 +504,20 @@ def func(iterator: Iterable) -> Iterable: assert self._spark._jvm is not None jrdd = keyed._jrdd.map(self._spark._jvm.BytesToString()) return self._df(self._jreader.json(jrdd)) + + from pyspark.sql.dataframe import DataFrame + + if isinstance(path, DataFrame): + assert self._spark._jvm is not None + return self._df( + self._spark._jvm.PythonSQLUtils.jsonFromDataFrame(self._jreader, path._jdf) + ) else: raise PySparkTypeError( errorClass="NOT_EXPECTED_TYPE", messageParameters={ "arg_name": "path", - "expected_type": "str or list[RDD]", + "expected_type": "str, list, RDD, or DataFrame", "arg_type": type(path).__name__, }, ) diff --git a/python/pyspark/sql/tests/connect/test_connect_readwriter.py b/python/pyspark/sql/tests/connect/test_connect_readwriter.py index fc27771fff74d..9e8986b3c8623 100644 --- a/python/pyspark/sql/tests/connect/test_connect_readwriter.py +++ b/python/pyspark/sql/tests/connect/test_connect_readwriter.py @@ -177,6 +177,43 @@ def test_csv(self): # Read the text file as a DataFrame. 
self.assert_eq(self.connect.read.csv(d).toPandas(), self.spark.read.csv(d).toPandas()) + def test_json_with_dataframe_input(self): + json_df = self.connect.createDataFrame( + [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)], + schema="value STRING", + ) + result = self.connect.read.json(json_df) + expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")] + self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected) + + def test_json_with_dataframe_input_and_schema(self): + json_df = self.connect.createDataFrame( + [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)], + schema="value STRING", + ) + result = self.connect.read.json(json_df, schema="name STRING, age INT") + expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)] + self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected) + + def test_json_with_dataframe_input_non_string_column(self): + int_df = self.connect.createDataFrame([(1,), (2,)], schema="value INT") + with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + self.connect.read.json(int_df).collect() + + def test_json_with_dataframe_input_multiple_columns(self): + multi_df = self.connect.createDataFrame( + [('{"name": "Alice"}', "extra"), ('{"name": "Bob"}', "extra")], + schema="value STRING, other STRING", + ) + result = self.connect.read.json(multi_df) + expected = [Row(name="Alice"), Row(name="Bob")] + self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected) + + def test_json_with_dataframe_input_zero_columns(self): + empty_schema_df = self.connect.range(1).select() + with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + self.connect.read.json(empty_schema_df).collect() + def test_multi_paths(self): # SPARK-42041: DataFrameReader should support list of paths diff --git a/python/pyspark/sql/tests/test_datasources.py b/python/pyspark/sql/tests/test_datasources.py index 1ceb74c1d907c..d742a96ed5f2e 100644 --- 
a/python/pyspark/sql/tests/test_datasources.py +++ b/python/pyspark/sql/tests/test_datasources.py @@ -93,6 +93,43 @@ def test_linesep_json(self): finally: shutil.rmtree(tpath) + def test_json_with_dataframe_input(self): + json_df = self.spark.createDataFrame( + [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)], + schema="value STRING", + ) + result = self.spark.read.json(json_df) + expected = [Row(age=25, name="Alice"), Row(age=30, name="Bob")] + self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected) + + def test_json_with_dataframe_input_and_schema(self): + json_df = self.spark.createDataFrame( + [('{"name": "Alice", "age": 25}',), ('{"name": "Bob", "age": 30}',)], + schema="value STRING", + ) + result = self.spark.read.json(json_df, schema="name STRING, age INT") + expected = [Row(name="Alice", age=25), Row(name="Bob", age=30)] + self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected) + + def test_json_with_dataframe_input_non_string_column(self): + int_df = self.spark.createDataFrame([(1,), (2,)], schema="value INT") + with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + self.spark.read.json(int_df).collect() + + def test_json_with_dataframe_input_multiple_columns(self): + multi_df = self.spark.createDataFrame( + [('{"name": "Alice"}', "extra"), ('{"name": "Bob"}', "extra")], + schema="value STRING, other STRING", + ) + result = self.spark.read.json(multi_df) + expected = [Row(name="Alice"), Row(name="Bob")] + self.assertEqual(sorted(result.collect(), key=lambda r: r.name), expected) + + def test_json_with_dataframe_input_zero_columns(self): + empty_schema_df = self.spark.range(1).select() + with self.assertRaisesRegex(Exception, "PARSE_INPUT_NOT_STRING_TYPE"): + self.spark.read.json(empty_schema_df).collect() + def test_multiline_csv(self): ages_newlines = self.spark.read.csv( "python/test_support/sql/ages_newlines.csv", multiLine=True diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index b6fb05f3f1b1a..60ed6b74cf288 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -3495,6 +3495,12 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat ) } + def parseInputNotStringTypeError(dataType: DataType): Throwable = { + new AnalysisException( + errorClass = "PARSE_INPUT_NOT_STRING_TYPE", + messageParameters = Map("dataType" -> toSQLType(dataType))) + } + def textDataSourceWithMultiColumnsError(schema: StructType): Throwable = { new AnalysisException( errorClass = "_LEGACY_ERROR_TEMP_1290", diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala index 37bcf995ee16d..3a8a8f5a766c5 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala @@ -1760,7 +1760,19 @@ class SparkConnectPlanner( localMap.foreach { case (key, value) => reader.option(key, value) } reader } - def ds: Dataset[String] = Dataset(session, transformRelation(rel.getInput))(Encoders.STRING) + def ds: Dataset[String] = { + val input = transformRelation(rel.getInput) + val df = Dataset.ofRows(session, input) + val fields = df.schema.fields + if (fields.isEmpty) { + throw QueryCompilationErrors.parseInputNotStringTypeError( + org.apache.spark.sql.types.NullType) + } + if (fields.head.dataType != org.apache.spark.sql.types.StringType) { + throw QueryCompilationErrors.parseInputNotStringTypeError(fields.head.dataType) + } + 
df.select(df.columns.head).as(Encoders.STRING) + } rel.getFormat match { case ParseFormat.PARSE_FORMAT_CSV => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala index 5607c98bf29e5..aa941c81e9806 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/api/python/PythonSQLUtils.scala @@ -27,14 +27,16 @@ import org.apache.spark.api.python.DechunkedInputStream import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.CLASS_LOADER import org.apache.spark.security.SocketAuthServer -import org.apache.spark.sql.{internal, Column, DataFrame, Row, SparkSession, TableArg} +import org.apache.spark.sql.{internal, Column, DataFrame, DataFrameReader, Encoders, Row, SparkSession, TableArg} import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow} import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TableFunctionRegistry} import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.classic.{DataFrameReader => ClassicDataFrameReader} import org.apache.spark.sql.classic.ClassicConversions._ import org.apache.spark.sql.classic.ExpressionUtils.expression +import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.{ExplainMode, QueryExecution} import org.apache.spark.sql.execution.arrow.ArrowConverters import org.apache.spark.sql.execution.python.EvaluatePython @@ -193,6 +195,26 @@ private[sql] object PythonSQLUtils extends Logging { @scala.annotation.varargs def internalFn(name: String, inputs: Column*): Column = Column.internalFn(name, inputs: _*) + /** + * Parses a [[DataFrame]] containing JSON strings into a structured [[DataFrame]]. 
+ * The input DataFrame must have exactly one column of StringType. + * This is used by PySpark to avoid manual Dataset[String] conversion on the Python side. + */ + def jsonFromDataFrame( + reader: DataFrameReader, + df: DataFrame): DataFrame = { + val classicReader = reader.asInstanceOf[ClassicDataFrameReader] + val fields = df.schema.fields + if (fields.isEmpty) { + throw QueryCompilationErrors.parseInputNotStringTypeError( + org.apache.spark.sql.types.NullType) + } + if (fields.head.dataType != org.apache.spark.sql.types.StringType) { + throw QueryCompilationErrors.parseInputNotStringTypeError(fields.head.dataType) + } + classicReader.json(df.select(df.columns.head).as(Encoders.STRING)) + } + def cleanupPythonWorkerLogs(sessionUUID: String, sparkContext: SparkContext): Unit = { if (!sparkContext.isStopped) { try {