@@ -306,6 +306,29 @@ def test_filter(df):
306306 assert result .column (2 ) == pa .array ([5 ])
307307
308308
def test_filter_string_predicates(df):
    """Exercise ``df.filter`` with SQL-string predicates.

    Three call shapes are covered: a single string, a string together with
    an expression built from ``column``/``literal``, and two strings. For
    each one the filtered frame is collected and the first three columns of
    the first returned batch are compared against the expected arrays.
    """
    # One SQL string predicate.
    batch = df.filter("a > 2").collect()[0]
    for index, expected in enumerate(([3], [6], [8])):
        assert batch.column(index) == pa.array(expected)

    # A SQL string combined with an expression predicate.
    batch = df.filter("a > 1", column("b") != literal(6)).collect()[0]
    for index, expected in enumerate(([2], [5], [5])):
        assert batch.column(index) == pa.array(expected)

    # Two SQL string predicates; expected to select the same row as above.
    batch = df.filter("a > 1", "b < 6").collect()[0]
    for index, expected in enumerate(([2], [5], [5])):
        assert batch.column(index) == pa.array(expected)
330+
331+
309332def test_parse_sql_expr (df ):
310333 plan1 = df .filter (df .parse_sql_expr ("a > 2" )).logical_plan ()
311334 plan2 = df .filter (column ("a" ) > literal (2 )).logical_plan ()
@@ -388,9 +411,16 @@ def test_aggregate_tuple_aggs(df):
388411 assert result_tuple == result_list
389412
390413
def test_filter_string_unsupported(df):
    """Expect ``df.filter`` to raise ``TypeError`` for a plain string.

    The error message must match ``EXPR_TYPE_ERROR`` literally.
    """
    # NOTE(review): these lines are on the removed ("-") side of the diff
    # and contradict test_filter_string_equivalent, which expects the same
    # call to succeed -- confirm this test is meant to be deleted.
    expected_message = re.escape(EXPR_TYPE_ERROR)
    with pytest.raises(TypeError, match=expected_message):
        df.filter("a > 1")
def test_filter_string_equivalent(df):
    """Check a SQL-string predicate against its expression counterpart.

    Filtering with ``"a > 1"`` and with ``column("a") > literal(1)`` must
    produce equal ``to_pydict()`` results.
    """
    from_string = df.filter("a > 1")
    from_expr = df.filter(column("a") > literal(1))
    assert from_string.to_pydict() == from_expr.to_pydict()
418+
419+
def test_filter_string_invalid(df):
    """Check that an unparsable string predicate raises on filter/collect.

    The raised error's text must not contain ``"Expected Expr"``.
    """
    bad_predicate = "this is not valid sql"
    # NOTE(review): the broad ``Exception`` is kept as in the original; the
    # concrete exception type is not visible from this file -- narrow it if
    # the library documents one.
    with pytest.raises(Exception) as excinfo:
        df.filter(bad_predicate).collect()
    # Presumably guards against the old argument-type error being reported
    # instead of a parse failure -- confirm against the implementation.
    message = str(excinfo.value)
    assert "Expected Expr" not in message
394424
395425
396426def test_drop (df ):
0 commit comments