Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import org.apache.spark.sql.types._
import org.apache.spark.sql.util.SchemaUtils
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.{CalendarInterval, TimestampNanosVal, UTF8String}
import org.apache.spark.util.ArrayImplicits._

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -430,6 +430,11 @@ abstract class HashExpression[E] extends Expression {
s"$result = $hasherClassName.hashInt($input.months, $microsecondsHash);"
}

/**
 * Builds the generated-code statement that hashes a nanosecond timestamp value.
 *
 * The emitted statement first hashes `input.epochMicros` with `hashLong`, seeded by the current
 * value of `result`, and then feeds that hash as the seed of `hashInt` over
 * `input.nanosWithinMicro`, assigning the outcome back to `result`.
 *
 * @param input  name of the generated-code variable holding the timestamp value
 * @param result name of the generated-code variable holding the running hash
 * @return a single assignment statement, terminated by `;`
 */
protected def genHashTimestampNanos(input: String, result: String): String = {
  // Inner call: hash of the microsecond part, seeded with the running hash.
  val seedExpr = s"$hasherClassName.hashLong($input.epochMicros, $result)"
  // Outer call: mix in the sub-microsecond nanos on top of the microsecond hash.
  val hashExpr = s"$hasherClassName.hashInt($input.nanosWithinMicro, $seedExpr)"
  s"$result = $hashExpr;"
}

protected def genHashString(
ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
if (stringType.supportsBinaryEquality) {
Expand Down Expand Up @@ -549,6 +554,8 @@ abstract class HashExpression[E] extends Expression {
case ByteType | ShortType | IntegerType | DateType => genHashInt(input, result)
case LongType | _: TimeType => genHashLong(input, result)
case TimestampType | TimestampNTZType => genHashTimestamp(input, result)
case _: TimestampNTZNanosType | _: TimestampLTZNanosType =>
genHashTimestampNanos(input, result)
case FloatType => genHashFloat(input, result)
case DoubleType => genHashDouble(input, result)
case d: DecimalType => genHashDecimal(ctx, d, input, result)
Expand Down Expand Up @@ -636,6 +643,7 @@ abstract class InterpretedHashFunction {
hashUnsafeBytes(bytes, Platform.BYTE_ARRAY_OFFSET, bytes.length, seed)
}
case c: CalendarInterval => hashInt(c.months, hashInt(c.days, hashLong(c.microseconds, seed)))
case t: TimestampNanosVal => hashInt(t.nanosWithinMicro, hashLong(t.epochMicros, seed))
case a: Array[Byte] =>
hashUnsafeBytes(a, Platform.BYTE_ARRAY_OFFSET, a.length, seed)
case s: UTF8String =>
Expand Down Expand Up @@ -977,6 +985,12 @@ case class HiveHash(children: Seq[Expression]) extends HashExpression[Int] {
$result = (int) ${HiveHashFunction.getClass.getName.stripSuffix("$")}.hashTimestamp($input);
"""

override protected def genHashTimestampNanos(input: String, result: String): String = {
  // Generated code calls the hashTimestampNanos helper on HiveHashFunction; the trailing "$" of
  // the Scala object's class name is stripped before it is spliced into the statement, and the
  // returned value is narrowed to int.
  val hiveHashFunctionClass = HiveHashFunction.getClass.getName.stripSuffix("$")
  s"""
$result = (int)
$hiveHashFunctionClass.hashTimestampNanos($input);
"""
}

override protected def genHashString(
ctx: CodegenContext, stringType: StringType, input: String, result: String): String = {
if (stringType.supportsBinaryEquality || !isCollationAware) {
Expand Down Expand Up @@ -1144,6 +1158,17 @@ object HiveHashFunction extends InterpretedHashFunction {
((result >>> 32) ^ result).toInt
}

/**
 * Hashes a [[TimestampNanosVal]] in two steps: the `epochMicros` part is hashed with
 * [[hashTimestamp]], and the `nanosWithinMicro` part is then folded in as
 * `microsHash * 37 + nanosWithinMicro`, the same folding shape that [[hashCalendarInterval]]
 * uses. This is a Spark-defined hash that is deterministic for equal values; it is not intended
 * to reproduce a hash computed by Hive itself.
 *
 * @param t the nanosecond timestamp value to hash
 * @return the combined hash of both fields
 */
def hashTimestampNanos(t: TimestampNanosVal): Long = {
  val microsHash = hashTimestamp(t.epochMicros)
  37 * microsHash + t.nanosWithinMicro
}

/**
* Hive allows input intervals to be defined using units below but the intervals
* have to be from the same category:
Expand Down Expand Up @@ -1242,6 +1267,7 @@ object HiveHashFunction extends InterpretedHashFunction {

case d: Decimal => normalizeDecimal(d.toJavaBigDecimal).hashCode()
case timestamp: Long if dataType.isInstanceOf[TimestampType] => hashTimestamp(timestamp)
case timestampNanos: TimestampNanosVal => hashTimestampNanos(timestampNanos)
case calendarInterval: CalendarInterval => hashCalendarInterval(calendarInterval)
case _ => super.hash(value, dataType, 0, isCollationAware, legacyCollationAwareHashing)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,8 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjecti
import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, CollationFactory, DateTimeUtils, GenericArrayData, IntervalUtils}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, StructType, _}
import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.unsafe.hash.Murmur3_x86_32
import org.apache.spark.unsafe.types.{TimestampNanosVal, UTF8String}
import org.apache.spark.util.ArrayImplicits._

class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
Expand Down Expand Up @@ -885,6 +886,86 @@ class HashExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
checkEvaluation(HiveHash(Seq(time)), -1567775210)
}

// Verifies that Murmur3Hash, XxHash64 and HiveHash evaluate consistently for nanosecond timestamp
// values of both the NTZ and LTZ nanos types (precisions 9 and 7), including a null input.
// Every expected value is taken from the same expression's own eval(), so this test checks
// agreement between the evaluation paths exercised by checkEvaluation, not absolute hash values.
test("HashExpression supports nanosecond timestamp types") {
// (epochMicros, nanosWithinMicro) pairs covering zero/mid/max nanos, negative micros, and
// the Long epoch-micro boundaries.
val values = Seq(
TimestampNanosVal.fromParts(0L, 0.toShort),
TimestampNanosVal.fromParts(1L, 1.toShort),
TimestampNanosVal.fromParts(1234567890L, 999.toShort),
TimestampNanosVal.fromParts(-1L, 500.toShort),
TimestampNanosVal.fromParts(Long.MinValue, 0.toShort),
TimestampNanosVal.fromParts(Long.MaxValue, 999.toShort))

Seq(TimestampNTZNanosType(9), TimestampLTZNanosType(9),
TimestampNTZNanosType(7), TimestampLTZNanosType(7)).foreach { dt =>
(values :+ null).foreach { v =>
// 1) Literal child: the value is embedded as a constant, so this asserts that the
// interpreted and codegen paths agree. (The unsafe projection here only round-trips the
// scalar hash result, not the nanos input -- that path is covered below.)
val lit = Literal.create(v, dt)
checkEvaluation(Murmur3Hash(Seq(lit), 42), Murmur3Hash(Seq(lit), 42).eval())
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The expected value is Murmur3Hash(Seq(lit), 42).eval(), i.e. computed by the same expression under test, so this only proves the eval paths agree with each other (and, via the second test, that both fields contribute). A bug shared across all paths (e.g. a wrong constant, or a symmetric field swap) wouldn't be caught. The existing tests in this suite pin literals (e.g. checkEvaluation(HiveHash(Seq(time)), -1567775210)). Could we pin at least one golden constant per algorithm for a fixed (epochMicros, nanosWithinMicro) pair?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added a golden test. Expecteds are composed independently of the expression under test.

checkEvaluation(XxHash64(Seq(lit), 42L), XxHash64(Seq(lit), 42L).eval())
checkEvaluation(HiveHash(Seq(lit)), HiveHash(Seq(lit)).eval())

// 2) BoundReference over a row: drives the ordinal row-read (getTimestampNTZNanos /
// getTimestampLTZNanos) and the UnsafeRow round-trip of the nanos value itself -- the
// real GROUP BY / shuffle / join input path that the literal case above skips.
val row = InternalRow(v)
val ref = BoundReference(0, dt, nullable = true)
checkEvaluation(Murmur3Hash(Seq(ref), 42), Murmur3Hash(Seq(ref), 42).eval(row), row)
checkEvaluation(XxHash64(Seq(ref), 42L), XxHash64(Seq(ref), 42L).eval(row), row)
checkEvaluation(HiveHash(Seq(ref)), HiveHash(Seq(ref)).eval(row), row)
}
}
}

test("nanosecond timestamp hash is consistent with equality") {
  val nanosType = TimestampNTZNanosType(9)

  // Wraps an (epochMicros, nanosWithinMicro) pair into a literal of the type under test.
  def nanosLiteral(epochMicros: Long, subMicroNanos: Short): Literal = {
    val value = TimestampNanosVal.fromParts(epochMicros, subMicroNanos)
    Literal.create(value, nanosType)
  }

  val base = nanosLiteral(1234567890L, 123)
  val sameAsBase = nanosLiteral(1234567890L, 123)
  // Differs from `base` only in the sub-microsecond nanos.
  val nanosChanged = nanosLiteral(1234567890L, 124)
  // Differs from `base` only in the epoch microseconds.
  val microsChanged = nanosLiteral(1234567891L, 123)

  val hashers = Seq[Expression => Any](
    expr => Murmur3Hash(Seq(expr), 42).eval(),
    expr => XxHash64(Seq(expr), 42L).eval(),
    expr => HiveHash(Seq(expr)).eval())

  for (hashOf <- hashers) {
    val baseHash = hashOf(base)
    // Two separately built but equal values must produce the same hash.
    assert(baseHash === hashOf(sameAsBase))
    // Changing either field alone must change the hash, so neither field can be ignored.
    assert(baseHash !== hashOf(nanosChanged))
    assert(baseHash !== hashOf(microsChanged))
  }
}

test("nanosecond timestamp hash matches expected golden values") {
  // Expected hashes are built here straight from the primitive hashers (and from hashTimestamp
  // for Hive), hashing epochMicros first and then folding in nanosWithinMicro. Because they do
  // not go through the expressions being checked, a wrong seed/constant or a swapped field order
  // in the expressions shows up as a mismatch instead of cancelling out.
  val seed = 42
  val epochMicros = 1234567890L
  val subMicroNanos: Short = 789
  val value = TimestampNanosVal.fromParts(epochMicros, subMicroNanos)

  val expectedMurmur3 =
    Murmur3_x86_32.hashInt(subMicroNanos, Murmur3_x86_32.hashLong(epochMicros, seed))
  val expectedXxHash64 =
    XXH64.hashInt(subMicroNanos, XXH64.hashLong(epochMicros, seed.toLong))
  val expectedHive =
    ((HiveHashFunction.hashTimestamp(epochMicros) * 37) + subMicroNanos).toInt

  for (dt <- Seq(TimestampNTZNanosType(9), TimestampLTZNanosType(9))) {
    val lit = Literal.create(value, dt)
    checkEvaluation(Murmur3Hash(Seq(lit), seed), expectedMurmur3)
    checkEvaluation(XxHash64(Seq(lit), seed.toLong), expectedXxHash64)
    checkEvaluation(HiveHash(Seq(lit)), expectedHive)
  }
}

private def testHash(inputSchema: StructType): Unit = {
val inputGenerator = RandomDataGenerator.forType(inputSchema, nullable = false).get
val toRow = ExpressionEncoder(inputSchema).createSerializer()
Expand Down