diff --git a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala index f43571e44..464b6f742 100644 --- a/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala +++ b/codepropertygraph/src/main/scala/io/shiftleft/passes/CpgPass.scala @@ -7,6 +7,7 @@ import org.slf4j.{Logger, LoggerFactory, MDC} import java.util.function.{BiConsumer, Supplier} import scala.annotation.nowarn +import scala.compiletime.uninitialized import scala.concurrent.duration.DurationLong import scala.util.{Failure, Success, Try} @@ -137,6 +138,147 @@ abstract class ForkJoinParallelCpgPass[T <: AnyRef](cpg: Cpg, @nowarn outName: S } +/** A [[ForkJoinParallelCpgPass]] that additionally maintains a thread-local accumulator of type [[R]] which is merged + * across all threads after processing completes. This enables map-reduce style aggregation alongside the usual + * DiffGraph-based graph modifications. + * + * Each thread gets its own accumulator instance (via [[newAccumulator]]). After all parts are processed, the + * accumulators are merged using [[mergeAccumulators]] and the result is passed to [[onAccumulatorComplete]]. + * + * This variant uses the `stream.collect` / `BiConsumer` API (just like [[ForkJoinParallelCpgPass]]) with a combined + * container that holds both a [[DiffGraphBuilder]] and an accumulator per fork, so no `ThreadLocal` or + * `ConcurrentLinkedQueue` is needed. + * + * @tparam T + * the part type (same as in [[ForkJoinParallelCpgPass]]) + * @tparam R + * the accumulator type + */ +abstract class ForkJoinParallelCpgPassWithAccumulator[T <: AnyRef, R](cpg: Cpg, @nowarn outName: String = "") + extends CpgPassBase { + type DiffGraphBuilder = io.shiftleft.codepropertygraph.generated.DiffGraphBuilder + + /** Generate Array of parts that can be processed in parallel. */ + def generateParts(): Array[? <: AnyRef] + + /** Setup large data structures, acquire external resources. */ + def init(): Unit = {} + + /** Override this to disable parallelism of passes. Useful for debugging. */ + def isParallel: Boolean = true + + /** Create a fresh, empty accumulator. Called once per fork (thread). */ + protected def newAccumulator(): R + + /** Merge two accumulators. Must be associative. The result may reuse either argument. */ + protected def mergeAccumulators(left: R, right: R): R + + /** Process a single part, writing graph changes to `builder` and aggregated data to `acc`. */ + protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: R, part: T): Unit + + /** Called after all parts are processed with the fully merged accumulator. Override `finish()` if you need to release + * resources; `onAccumulatorComplete` is invoked from within the default `finish()` implementation. + */ + protected def onAccumulatorComplete(acc: R): Unit = {} + + /** Container pairing a per-fork DiffGraphBuilder with a per-fork accumulator. */ + private class BuilderWithAccumulator(val builder: DiffGraphBuilder, var acc: R) + + @volatile private var _accResult: R = uninitialized + @volatile private var _hasResult: Boolean = false + + /** Release large data structures and external resources. The default implementation calls [[onAccumulatorComplete]] + * with the merged accumulator (or a fresh one if processing failed). Subclasses that override this method must call + * `super.finish()` to ensure the accumulator callback fires. + */ + def finish(): Unit = { + val acc = if (_hasResult) _accResult else newAccumulator() + onAccumulatorComplete(acc) + _hasResult = false + } + + override def createAndApply(): Unit = { + baseLogger.info(s"Start of pass: $name") + val nanosStart = System.nanoTime() + var nParts = 0 + var nanosBuilt = -1L + var nDiff = -1 + var nDiffT = -1 + try { + val diffGraph = Cpg.newDiffGraphBuilder + nParts = runWithBuilder(diffGraph) + nanosBuilt = System.nanoTime() + nDiff = diffGraph.size + + nDiffT = flatgraph.DiffGraphApplier.applyDiff(cpg.graph, diffGraph) + } catch { + case exc: Exception => + baseLogger.error(s"Pass ${name} failed", exc) + throw exc + } finally { + val nanosStop = System.nanoTime() + val fracRun = if (nanosBuilt == -1) 0.0 else (nanosStop - nanosBuilt) * 100.0 / (nanosStop - nanosStart + 1) + baseLogger.info( + f"Pass $name completed in ${(nanosStop - nanosStart) * 1e-6}%.0f ms ($fracRun%.0f%% on mutations). $nDiff%d + ${nDiffT - nDiff}%d changes committed from $nParts%d parts." + ) + } + } + + override def runWithBuilder(externalBuilder: DiffGraphBuilder): Int = { + _hasResult = false + try { + init() + val parts = generateParts() + val nParts = parts.size + _accResult = nParts match { + case 0 => + newAccumulator() + case 1 => + val acc = newAccumulator() + runOnPartWithAccumulator(externalBuilder, acc, parts(0).asInstanceOf[T]) + acc + case _ => + val stream = + if (!isParallel) + java.util.Arrays + .stream(parts) + .sequential() + else + java.util.Arrays + .stream(parts) + .parallel() + val result = stream.collect( + new Supplier[BuilderWithAccumulator] { + override def get(): BuilderWithAccumulator = + new BuilderWithAccumulator(Cpg.newDiffGraphBuilder, newAccumulator()) + }, + new BiConsumer[BuilderWithAccumulator, AnyRef] { + override def accept(bwa: BuilderWithAccumulator, part: AnyRef): Unit = + runOnPartWithAccumulator(bwa.builder, bwa.acc, part.asInstanceOf[T]) + }, + new BiConsumer[BuilderWithAccumulator, BuilderWithAccumulator] { + override def accept(left: BuilderWithAccumulator, right: BuilderWithAccumulator): Unit = { + left.builder.absorb(right.builder) + left.acc = mergeAccumulators(left.acc, right.acc) + } + } + ) + externalBuilder.absorb(result.builder) + result.acc + } + _hasResult = true + nParts + } finally { + finish() + } + } + + @deprecated("Please use createAndApply") + override def createApplySerializeAndStore(serializedCpg: SerializedCpg, prefix: String = ""): Unit = { + createAndApply() + } +} + trait CpgPassBase { protected def baseLogger: Logger = LoggerFactory.getLogger(getClass) diff --git a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala index dc32d7975..467e786ba 100644 --- a/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala +++ b/codepropertygraph/src/test/scala/io/shiftleft/passes/CpgPassNewTests.scala @@ -90,4 +90,114 @@ class CpgPassNewTests extends AnyWordSpec with Matchers { } } + "ForkJoinParallelCpgPassWithAccumulator" should { + "merge accumulators and invoke completion callback once" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-pass") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int] + override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] = + left ++= right + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = acc += part.length + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array("a", "bb", "ccc") + override def isParallel: Boolean = false + } + + pass.createAndApply() + + completed.toSeq shouldBe Seq(6) + } + + "use a fresh accumulator when there are no parts" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-empty") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer(42) + override protected def mergeAccumulators(left: ArrayBuffer[Int], right: ArrayBuffer[Int]): ArrayBuffer[Int] = + left ++= right + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = () + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array.empty + } + + pass.createAndApply() + + completed.toSeq shouldBe Seq(42) + } + + "clear accumulator state between runs" in { + val cpg = Cpg.empty + val completed = ArrayBuffer.empty[Int] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]] = + new ForkJoinParallelCpgPassWithAccumulator[String, ArrayBuffer[Int]](cpg, "acc-rerun") { + override protected def newAccumulator(): ArrayBuffer[Int] = ArrayBuffer.empty[Int] + override protected def mergeAccumulators( + left: ArrayBuffer[Int], + right: ArrayBuffer[Int] + ): ArrayBuffer[Int] = { + left ++= right + } + override protected def runOnPartWithAccumulator( + builder: DiffGraphBuilder, + acc: ArrayBuffer[Int], + part: String + ): Unit = acc += part.toInt + override protected def onAccumulatorComplete(acc: ArrayBuffer[Int]): Unit = completed += acc.sum + override def generateParts(): Array[String] = Array("1", "2", "3") + override def isParallel: Boolean = false + } + + pass.createAndApply() + pass.createAndApply() + + completed.toSeq shouldBe Seq(6, 6) + } + + "invoke completion callback once when a part fails" in { + val cpg = Cpg.empty + val events = ArrayBuffer.empty[String] + + val pass: ForkJoinParallelCpgPassWithAccumulator[String, Int] = + new ForkJoinParallelCpgPassWithAccumulator[String, Int](cpg, "acc-fail") { + override protected def newAccumulator(): Int = 0 + override protected def mergeAccumulators(left: Int, right: Int): Int = left + right + override protected def runOnPartWithAccumulator(builder: DiffGraphBuilder, acc: Int, part: String): Unit = { + events += "run" + throw new RuntimeException("boom") + } + override protected def onAccumulatorComplete(acc: Int): Unit = events += s"complete:$acc" + override def generateParts(): Array[String] = Array("p1") + override def isParallel: Boolean = false + override def init(): Unit = { + events += "init" + super.init() + } + override def finish(): Unit = { + events += "finish" + super.finish() + } + } + + intercept[RuntimeException] { + pass.createAndApply() + } + + events.toSeq shouldBe Seq("init", "run", "finish", "complete:0") + } + } + }