diff --git a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricComputeAlgo.java b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricComputeAlgo.java index c26edafb4..7c36368cb 100644 --- a/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricComputeAlgo.java +++ b/geaflow/geaflow-core/geaflow-api/src/main/java/org/apache/geaflow/api/graph/base/algo/AbstractIncVertexCentricComputeAlgo.java @@ -32,6 +32,10 @@ public AbstractIncVertexCentricComputeAlgo(long iterations, String name) { super(iterations, name); } + public String getPythonTransformClassName() { + return null; + } + public abstract FUNC getIncComputeFunction(); } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOp.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOp.java index 7de8eca8d..2af30cafb 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOp.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/main/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOp.java @@ -20,8 +20,10 @@ package org.apache.geaflow.operator.impl.graph.compute.dynamic; import java.io.IOException; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import org.apache.geaflow.api.function.iterator.RichIteratorFunction; import org.apache.geaflow.api.graph.base.algo.AbstractIncVertexCentricComputeAlgo; @@ -164,7 +166,8 @@ class IncGraphInferComputeContextImpl extends IncGraphComputeContextImpl 
im public IncGraphInferComputeContextImpl() { if (clientLocal.get() == null) { try { - inferContext = new InferContext<>(runtimeContext.getConfiguration()); + inferContext = new InferContext<>(buildInferConfiguration(runtimeContext.getConfiguration(), + function.getPythonTransformClassName())); } catch (Exception e) { throw new GeaflowRuntimeException(e); } @@ -191,4 +194,16 @@ public void close() throws IOException { } } } + + static Configuration buildInferConfiguration(Configuration baseConfig, String pythonTransformClassName) { + if (pythonTransformClassName == null || pythonTransformClassName.trim().isEmpty()) { + return baseConfig; + } + Map configMap = new HashMap<>(baseConfig.getConfigMap()); + configMap.put(FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME.getKey(), + pythonTransformClassName); + Configuration configuration = new Configuration(configMap); + configuration.setMasterId(baseConfig.getMasterId()); + return configuration; + } } diff --git a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOpTest.java b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOpTest.java index 3e913f4f6..402669f65 100644 --- a/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOpTest.java +++ b/geaflow/geaflow-core/geaflow-runtime/geaflow-operator/src/test/java/org/apache/geaflow/operator/impl/graph/compute/dynamic/DynamicGraphVertexCentricComputeOpTest.java @@ -34,6 +34,7 @@ import org.apache.geaflow.collector.ICollector; import org.apache.geaflow.common.config.Configuration; import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; +import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; import org.apache.geaflow.common.task.TaskArgs; 
import org.apache.geaflow.common.type.primitive.IntegerType; import org.apache.geaflow.common.utils.ReflectionUtil; @@ -105,6 +106,21 @@ public VertexCentricCombineFunction getCombineFunction() { Assert.assertEquals(3L, ((RuntimeContext) ReflectionUtil.getField(operator, "runtimeContext")).getWindowId()); } + @Test + public void testBuildInferConfigurationOverride() { + Configuration config = new Configuration(); + config.put(FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME, "GlobalTransform"); + Configuration overridden = DynamicGraphVertexCentricComputeOp.buildInferConfiguration(config, + "AlgoTransform"); + Assert.assertEquals(config.getString(FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME), + "GlobalTransform"); + Assert.assertEquals(overridden.getString(FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME), + "AlgoTransform"); + + Configuration unchanged = DynamicGraphVertexCentricComputeOp.buildInferConfiguration(config, null); + Assert.assertSame(unchanged, config); + } + public class TestRuntimeContext extends AbstractRuntimeContext { public TestRuntimeContext() { diff --git a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmRuntimeContext.java index 5a73c8c1b..62e9ffd50 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-common/src/main/java/org/apache/geaflow/dsl/common/algo/AlgorithmRuntimeContext.java @@ -151,6 +151,13 @@ public interface AlgorithmRuntimeContext { */ Configuration getConfig(); + /** + * Invoke model inference when runtime infer support is enabled. + */ + default OUT infer(Object... modelInputs) { + throw new UnsupportedOperationException("Inference is not enabled. 
Set INFER_ENV_ENABLE=true to enable inference."); + } + /** * Sends a termination vote to the coordinator to signal algorithm completion. * This method allows vertices to vote for algorithm termination when they @@ -160,4 +167,4 @@ public interface AlgorithmRuntimeContext { * @param voteValue The vote value (typically 1 for termination vote) */ void voteToTerminate(String terminationReason, Object voteValue); -} \ No newline at end of file +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java index 47addc84a..5c08ecbeb 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/schema/function/BuildInSqlFunctionTable.java @@ -39,6 +39,7 @@ import org.apache.geaflow.dsl.udf.graph.ClusterCoefficient; import org.apache.geaflow.dsl.udf.graph.CommonNeighbors; import org.apache.geaflow.dsl.udf.graph.ConnectedComponents; +import org.apache.geaflow.dsl.udf.graph.GCN; import org.apache.geaflow.dsl.udf.graph.IncKHopAlgorithm; import org.apache.geaflow.dsl.udf.graph.IncMinimumSpanningTree; import org.apache.geaflow.dsl.udf.graph.IncWeakConnectedComponents; @@ -241,6 +242,7 @@ public class BuildInSqlFunctionTable extends ListSqlOperatorTable { .add(GeaFlowFunction.of(LabelPropagation.class)) .add(GeaFlowFunction.of(ConnectedComponents.class)) .add(GeaFlowFunction.of(Louvain.class)) + .add(GeaFlowFunction.of(GCN.class)) .build(); public BuildInSqlFunctionTable(GQLJavaTypeFactory typeFactory) { diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCN.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCN.java new file mode 100644 index 
000000000..de593a361 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCN.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import org.apache.geaflow.common.type.Types; +import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; +import org.apache.geaflow.dsl.common.algo.AlgorithmUserFunction; +import org.apache.geaflow.dsl.common.algo.IncrementalAlgorithmUserFunction; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowEdge; +import org.apache.geaflow.dsl.common.data.RowVertex; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; +import org.apache.geaflow.dsl.common.function.Description; +import org.apache.geaflow.dsl.common.types.ArrayType; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.dsl.common.types.StructType; +import org.apache.geaflow.dsl.common.types.TableField; 
+import org.apache.geaflow.dsl.udf.graph.gcn.GCNConfig; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNFeatureCollector; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNInferPayload; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNInferResult; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNResultParser; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNSubgraphBuilder; +import org.apache.geaflow.model.graph.edge.EdgeDirection; + +@Description(name = "gcn", description = "built-in udga for GCN inference") +public class GCN implements AlgorithmUserFunction, IncrementalAlgorithmUserFunction { + + private AlgorithmRuntimeContext context; + private GCNConfig config; + private GCNSubgraphBuilder subgraphBuilder; + private final GCNFeatureCollector featureCollector = new GCNFeatureCollector(); + private final GCNResultParser resultParser = new GCNResultParser(); + + @Override + public void init(AlgorithmRuntimeContext context, Object[] params) { + this.context = context; + this.config = parseConfig(params); + this.subgraphBuilder = new GCNSubgraphBuilder(config); + } + + @Override + public void process(RowVertex vertex, Optional updatedValues, Iterator messages) { + if (context.getCurrentIterationId() > 1) { + return; + } + GCNInferPayload payload = subgraphBuilder.build(vertex.getId(), new DynamicGraphAdapter(vertex)); + Object rawResult = context.infer(payload); + GCNInferResult result = resultParser.parse(vertex.getId(), rawResult); + context.take(ObjectRow.create(result.toRowValues())); + } + + @Override + public void finish(RowVertex graphVertex, Optional updatedValues) { + } + + @Override + public StructType getOutputType(GraphSchema graphSchema) { + return new StructType( + new TableField("id", graphSchema.getIdType(), false), + new TableField("embedding", new ArrayType(Types.DOUBLE), false), + new TableField("predicted_class", Types.LONG, false), + new TableField("confidence", Types.DOUBLE, false) + ); + } + + private GCNConfig parseConfig(Object[] params) { + if 
(params.length == 0) { + return new GCNConfig(); + } + if (params.length != 2 && params.length != 3) { + throw new IllegalArgumentException("GCN accepts 0, 2 or 3 parameters"); + } + int numHops = Integer.parseInt(String.valueOf(params[0])); + int numSamplesPerHop = Integer.parseInt(String.valueOf(params[1])); + String className = params.length == 3 ? String.valueOf(params[2]) + : GCNConfig.DEFAULT_PYTHON_TRANSFORM_CLASS; + return new GCNConfig(numHops, numSamplesPerHop, className); + } + + private class DynamicGraphAdapter implements GCNSubgraphBuilder.GraphAdapter { + + private final GraphSchema graphSchema = context.getGraphSchema(); + private final RowVertex rootVertex; + + DynamicGraphAdapter(RowVertex rootVertex) { + this.rootVertex = rootVertex; + } + + @Override + public List loadNeighbors(Object nodeId) { + RowVertex current = switchVertex(nodeId); + if (current == null) { + restoreVertex(rootVertex.getId()); + return Collections.emptyList(); + } + List edges = context.loadEdges(EdgeDirection.BOTH); + List neighbors = new ArrayList<>(edges.size()); + for (RowEdge edge : edges) { + Object neighborId = nodeId.equals(edge.getSrcId()) ? edge.getTargetId() + : edge.getSrcId(); + neighbors.add(neighborId); + } + restoreVertex(rootVertex.getId()); + return neighbors; + } + + @Override + public double[] loadFeatures(Object nodeId) { + RowVertex vertex = nodeId.equals(rootVertex.getId()) ? rootVertex : switchVertex(nodeId); + try { + return vertex == null ? 
new double[config.getFeatureDimLimit()] + : featureCollector.collectFromRowVertex(vertex, graphSchema, + config.getFeatureDimLimit()); + } finally { + restoreVertex(rootVertex.getId()); + } + } + + private RowVertex switchVertex(Object nodeId) { + try { + Method setVertexId = context.getClass().getMethod("setVertexId", Object.class); + setVertexId.invoke(context, nodeId); + Method loadVertex = context.getClass().getMethod("loadVertex"); + return (RowVertex) loadVertex.invoke(context); + } catch (NoSuchMethodException e) { + throw new IllegalStateException("GCN requires dynamic runtime context with loadVertex support", e); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException("Failed to switch runtime context vertex", e); + } + } + + private void restoreVertex(Object rootId) { + try { + Method setVertexId = context.getClass().getMethod("setVertexId", Object.class); + setVertexId.invoke(context, rootId); + } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException("Failed to restore runtime context vertex", e); + } + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCNCompute.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCNCompute.java new file mode 100644 index 000000000..0308d7fcf --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/GCNCompute.java @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph; + +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import org.apache.geaflow.api.graph.compute.IncVertexCentricCompute; +import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction; +import org.apache.geaflow.api.graph.function.vc.VertexCentricCombineFunction; +import org.apache.geaflow.api.graph.function.vc.base.IncGraphInferContext; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNConfig; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNFeatureCollector; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNInferPayload; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNInferResult; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNResultParser; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNSubgraphBuilder; +import org.apache.geaflow.model.graph.edge.IEdge; +import org.apache.geaflow.model.graph.vertex.IVertex; +import org.apache.geaflow.model.graph.vertex.impl.ValueVertex; + +public class GCNCompute extends IncVertexCentricCompute, Object, Object> { + + private final GCNConfig config; + + public GCNCompute() { + this(GCNConfig.DEFAULT_NUM_HOPS, GCNConfig.DEFAULT_NUM_SAMPLES_PER_HOP, + GCNConfig.DEFAULT_PYTHON_TRANSFORM_CLASS); + } + + public GCNCompute(int numHops, int numSamplesPerHop) { + this(numHops, numSamplesPerHop, GCNConfig.DEFAULT_PYTHON_TRANSFORM_CLASS); + } + + public GCNCompute(int numHops, int numSamplesPerHop, String 
pythonTransformClassName) { + super(1L); + this.config = new GCNConfig(numHops, numSamplesPerHop, pythonTransformClassName); + } + + @Override + public String getPythonTransformClassName() { + return config.getPythonTransformClassName(); + } + + @Override + public IncVertexCentricComputeFunction, Object, Object> getIncComputeFunction() { + return new GCNComputeFunction(config); + } + + @Override + public VertexCentricCombineFunction getCombineFunction() { + return null; + } + + static class GCNComputeFunction implements IncVertexCentricComputeFunction, Object, Object> { + + private final GCNConfig config; + private final GCNFeatureCollector featureCollector = new GCNFeatureCollector(); + private final GCNResultParser resultParser = new GCNResultParser(); + private final GCNSubgraphBuilder subgraphBuilder; + private IncGraphComputeContext, Object, Object> graphContext; + private IncGraphInferContext inferContext; + + GCNComputeFunction(GCNConfig config) { + this.config = config; + this.subgraphBuilder = new GCNSubgraphBuilder(config); + } + + @Override + @SuppressWarnings("unchecked") + public void init(IncGraphComputeContext, Object, Object> incGraphContext) { + this.graphContext = incGraphContext; + if (incGraphContext instanceof IncGraphInferContext) { + this.inferContext = (IncGraphInferContext) incGraphContext; + } + } + + @Override + public void evolve(Object vertexId, TemporaryGraph, Object> temporaryGraph) { + inferAndCollect(vertexId, temporaryGraph.getVertex()); + } + + @Override + public void compute(Object vertexId, Iterator messageIterator) { + } + + @Override + public void finish(Object vertexId, MutableGraph, Object> mutableGraph) { + } + + private void inferAndCollect(Object rootId, IVertex> rootVertex) { + if (inferContext == null) { + throw new IllegalStateException("GCNCompute requires infer-enabled runtime context"); + } + GCNInferPayload payload = subgraphBuilder.build(rootId, new ApiGraphAdapter(rootId, rootVertex)); + Object rawResult = 
inferContext.infer(payload); + GCNInferResult result = resultParser.parse(rootId, rawResult); + List value = new ArrayList<>(); + value.add(result.getEmbedding()); + value.add(result.getPredictedClass()); + value.add(result.getConfidence()); + graphContext.collect(new ValueVertex<>(rootId, value)); + } + + private class ApiGraphAdapter implements GCNSubgraphBuilder.GraphAdapter { + + private final Object rootId; + private final IVertex> rootVertex; + + ApiGraphAdapter(Object rootId, IVertex> rootVertex) { + this.rootId = rootId; + this.rootVertex = rootVertex; + } + + @Override + public List loadNeighbors(Object nodeId) { + Object query = graphContext.getHistoricalGraph() + .getSnapShot(graphContext.getHistoricalGraph().getLatestVersionId()).edges(); + try { + Method withId = query.getClass().getMethod("withId", Object.class); + withId.invoke(query, nodeId); + Method getEdges = query.getClass().getMethod("getEdges"); + List> edges = (List>) getEdges.invoke(query); + List neighbors = new ArrayList<>(edges.size()); + for (IEdge edge : edges) { + Object neighborId = nodeId.equals(edge.getSrcId()) ? edge.getTargetId() + : edge.getSrcId(); + neighbors.add(neighborId); + } + return neighbors; + } catch (NoSuchMethodException e) { + return Collections.emptyList(); + } catch (IllegalAccessException | InvocationTargetException e) { + throw new IllegalStateException("Failed to load neighbors for GCNCompute", e); + } + } + + @Override + public double[] loadFeatures(Object nodeId) { + IVertex> vertex = rootId.equals(nodeId) ? rootVertex + : graphContext.getHistoricalGraph().getSnapShot( + graphContext.getHistoricalGraph().getLatestVersionId()) + .vertex().withId(nodeId).get(); + return vertex == null ? 
new double[config.getFeatureDimLimit()] + : featureCollector.collectFromValue(vertex.getValue(), config.getFeatureDimLimit()); + } + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNConfig.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNConfig.java new file mode 100644 index 000000000..cb4529b27 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNConfig.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.io.Serializable; + +public class GCNConfig implements Serializable { + + public static final int DEFAULT_NUM_HOPS = 2; + public static final int DEFAULT_NUM_SAMPLES_PER_HOP = 25; + public static final boolean DEFAULT_WITH_SELF_LOOP = true; + public static final int DEFAULT_FEATURE_DIM_LIMIT = 64; + public static final long DEFAULT_RANDOM_SEED = 20260322L; + public static final String DEFAULT_PYTHON_TRANSFORM_CLASS = "GCNTransFormFunction"; + + private final int numHops; + private final int numSamplesPerHop; + private final String pythonTransformClassName; + private final boolean withSelfLoop; + private final int featureDimLimit; + private final long randomSeed; + + public GCNConfig() { + this(DEFAULT_NUM_HOPS, DEFAULT_NUM_SAMPLES_PER_HOP, DEFAULT_PYTHON_TRANSFORM_CLASS, + DEFAULT_WITH_SELF_LOOP, DEFAULT_FEATURE_DIM_LIMIT, DEFAULT_RANDOM_SEED); + } + + public GCNConfig(int numHops, int numSamplesPerHop, String pythonTransformClassName) { + this(numHops, numSamplesPerHop, pythonTransformClassName, DEFAULT_WITH_SELF_LOOP, + DEFAULT_FEATURE_DIM_LIMIT, DEFAULT_RANDOM_SEED); + } + + public GCNConfig(int numHops, int numSamplesPerHop, String pythonTransformClassName, + boolean withSelfLoop, int featureDimLimit, long randomSeed) { + if (numHops <= 0) { + throw new IllegalArgumentException("numHops must be positive"); + } + if (numSamplesPerHop <= 0) { + throw new IllegalArgumentException("numSamplesPerHop must be positive"); + } + if (featureDimLimit <= 0) { + throw new IllegalArgumentException("featureDimLimit must be positive"); + } + if (pythonTransformClassName == null || pythonTransformClassName.trim().isEmpty()) { + throw new IllegalArgumentException("pythonTransformClassName cannot be blank"); + } + this.numHops = numHops; + this.numSamplesPerHop = numSamplesPerHop; + this.pythonTransformClassName = pythonTransformClassName; + this.withSelfLoop = withSelfLoop; + this.featureDimLimit = 
featureDimLimit; + this.randomSeed = randomSeed; + } + + public int getNumHops() { + return numHops; + } + + public int getNumSamplesPerHop() { + return numSamplesPerHop; + } + + public String getPythonTransformClassName() { + return pythonTransformClassName; + } + + public boolean isWithSelfLoop() { + return withSelfLoop; + } + + public int getFeatureDimLimit() { + return featureDimLimit; + } + + public long getRandomSeed() { + return randomSeed; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNFeatureCollector.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNFeatureCollector.java new file mode 100644 index 000000000..ca3a96bf3 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNFeatureCollector.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.lang.reflect.Array; +import java.util.List; +import org.apache.geaflow.common.type.IType; +import org.apache.geaflow.common.type.primitive.DoubleType; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowVertex; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.dsl.common.types.VertexType; + +public class GCNFeatureCollector { + + public double[] collectFromRowVertex(RowVertex vertex, GraphSchema graphSchema, int featureDimLimit) { + VertexType vertexType = graphSchema.getVertices().get(0); + IType[] valueTypes = vertexType.getValueTypes(); + double[] features = new double[featureDimLimit]; + for (int i = 0; i < featureDimLimit && i < valueTypes.length; i++) { + Object value = vertex.getField(vertexType.getValueOffset() + i, valueTypes[i]); + features[i] = toDouble(value); + } + return features; + } + + public double[] collectFromValue(Object value, int featureDimLimit) { + double[] features = new double[featureDimLimit]; + if (value == null) { + return features; + } + if (value instanceof List) { + List list = (List) value; + for (int i = 0; i < featureDimLimit && i < list.size(); i++) { + features[i] = toDouble(list.get(i)); + } + return features; + } + if (value instanceof ObjectRow) { + Object[] fields = ((ObjectRow) value).getFields(); + for (int i = 0; i < featureDimLimit && i < fields.length; i++) { + features[i] = toDouble(fields[i]); + } + return features; + } + if (value instanceof Row) { + for (int i = 0; i < featureDimLimit; i++) { + try { + features[i] = toDouble(((Row) value).getField(i, DoubleType.INSTANCE)); + } catch (Exception e) { + break; + } + } + return features; + } + if (value.getClass().isArray()) { + int length = Array.getLength(value); + for (int i = 0; i < featureDimLimit && i < length; i++) { + features[i] = toDouble(Array.get(value, 
i)); + } + return features; + } + features[0] = toDouble(value); + return features; + } + + private double toDouble(Object value) { + if (value == null) { + return 0D; + } + if (value instanceof Number) { + return ((Number) value).doubleValue(); + } + return Double.parseDouble(String.valueOf(value)); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNInferPayload.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNInferPayload.java new file mode 100644 index 000000000..57144ecb3 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNInferPayload.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.io.Serializable; +import java.util.List; + +public class GCNInferPayload implements Serializable { + + private final Object centerNodeId; + private final List sampledNodes; + private final List nodeFeatures; + private final int[][] edgeIndex; + private final double[] edgeWeight; + + public GCNInferPayload(Object centerNodeId, List sampledNodes, List nodeFeatures, + int[][] edgeIndex, double[] edgeWeight) { + this.centerNodeId = centerNodeId; + this.sampledNodes = sampledNodes; + this.nodeFeatures = nodeFeatures; + this.edgeIndex = edgeIndex; + this.edgeWeight = edgeWeight; + } + + public Object getCenter_node_id() { + return centerNodeId; + } + + public List getSampled_nodes() { + return sampledNodes; + } + + public List getNode_features() { + return nodeFeatures; + } + + public int[][] getEdge_index() { + return edgeIndex; + } + + public double[] getEdge_weight() { + return edgeWeight; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNInferResult.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNInferResult.java new file mode 100644 index 000000000..90502cb6f --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNInferResult.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.io.Serializable; + +public class GCNInferResult implements Serializable { + + private final Object nodeId; + private final double[] embedding; + private final long predictedClass; + private final double confidence; + + public GCNInferResult(Object nodeId, double[] embedding, long predictedClass, double confidence) { + this.nodeId = nodeId; + this.embedding = embedding; + this.predictedClass = predictedClass; + this.confidence = confidence; + } + + public Object getNodeId() { + return nodeId; + } + + public double[] getEmbedding() { + return embedding; + } + + public long getPredictedClass() { + return predictedClass; + } + + public double getConfidence() { + return confidence; + } + + public Object[] toRowValues() { + Double[] boxedEmbedding = new Double[embedding.length]; + for (int i = 0; i < embedding.length; i++) { + boxedEmbedding[i] = embedding[i]; + } + return new Object[]{nodeId, boxedEmbedding, predictedClass, confidence}; + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNResultParser.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNResultParser.java new file mode 100644 index 000000000..960ed4959 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNResultParser.java @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.lang.reflect.Array; +import java.util.List; +import java.util.Map; + +public class GCNResultParser { + + public GCNInferResult parse(Object expectedNodeId, Object rawResult) { + if (rawResult == null) { + return new GCNInferResult(expectedNodeId, new double[0], -1L, 0D); + } + if (rawResult instanceof String && ((String) rawResult).startsWith("python_exception:")) { + throw new IllegalStateException((String) rawResult); + } + if (rawResult instanceof GCNInferResult) { + return (GCNInferResult) rawResult; + } + if (rawResult instanceof Map) { + return parseMap(expectedNodeId, (Map) rawResult); + } + if (rawResult instanceof List) { + return parseList(expectedNodeId, (List) rawResult); + } + if (rawResult.getClass().isArray()) { + return parseArray(expectedNodeId, rawResult); + } + throw new IllegalArgumentException("Unsupported GCN infer result type: " + + rawResult.getClass().getName()); + } + + private GCNInferResult parseMap(Object expectedNodeId, Map rawResult) { + Object nodeId = rawResult.containsKey("node_id") ? rawResult.get("node_id") + : (rawResult.containsKey("id") ? 
rawResult.get("id") : expectedNodeId); + Object embedding = rawResult.get("embedding"); + Object prediction = rawResult.containsKey("prediction") ? rawResult.get("prediction") + : rawResult.get("predicted_class"); + Object confidence = rawResult.get("confidence"); + return new GCNInferResult(nodeId, parseEmbedding(embedding), parseLong(prediction), + parseDouble(confidence)); + } + + private GCNInferResult parseList(Object expectedNodeId, List rawResult) { + if (rawResult.size() >= 4) { + return new GCNInferResult(rawResult.get(0), parseEmbedding(rawResult.get(1)), + parseLong(rawResult.get(2)), parseDouble(rawResult.get(3))); + } + if (rawResult.size() >= 3) { + return new GCNInferResult(expectedNodeId, parseEmbedding(rawResult.get(0)), + parseLong(rawResult.get(1)), parseDouble(rawResult.get(2))); + } + if (rawResult.size() == 1) { + return new GCNInferResult(expectedNodeId, parseEmbedding(rawResult.get(0)), -1L, 0D); + } + throw new IllegalArgumentException("GCN infer result list is empty"); + } + + private GCNInferResult parseArray(Object expectedNodeId, Object rawResult) { + int length = Array.getLength(rawResult); + if (length >= 4) { + return new GCNInferResult(Array.get(rawResult, 0), parseEmbedding(Array.get(rawResult, 1)), + parseLong(Array.get(rawResult, 2)), parseDouble(Array.get(rawResult, 3))); + } + if (length >= 3) { + return new GCNInferResult(expectedNodeId, parseEmbedding(Array.get(rawResult, 0)), + parseLong(Array.get(rawResult, 1)), parseDouble(Array.get(rawResult, 2))); + } + if (length == 1) { + return new GCNInferResult(expectedNodeId, parseEmbedding(Array.get(rawResult, 0)), -1L, 0D); + } + throw new IllegalArgumentException("GCN infer result array is empty"); + } + + private double[] parseEmbedding(Object embedding) { + if (embedding == null) { + return new double[0]; + } + if (embedding instanceof double[]) { + return (double[]) embedding; + } + if (embedding instanceof Double[]) { + Double[] boxed = (Double[]) embedding; + double[] 
values = new double[boxed.length]; + for (int i = 0; i < boxed.length; i++) { + values[i] = boxed[i] == null ? 0D : boxed[i]; + } + return values; + } + if (embedding instanceof List) { + List list = (List) embedding; + double[] values = new double[list.size()]; + for (int i = 0; i < list.size(); i++) { + values[i] = parseDouble(list.get(i)); + } + return values; + } + if (embedding.getClass().isArray()) { + int len = Array.getLength(embedding); + double[] values = new double[len]; + for (int i = 0; i < len; i++) { + values[i] = parseDouble(Array.get(embedding, i)); + } + return values; + } + return new double[]{parseDouble(embedding)}; + } + + private long parseLong(Object value) { + if (value == null) { + return -1L; + } + if (value instanceof Number) { + return ((Number) value).longValue(); + } + return Long.parseLong(String.valueOf(value)); + } + + private double parseDouble(Object value) { + if (value == null) { + return 0D; + } + if (value instanceof Number) { + return ((Number) value).doubleValue(); + } + return Double.parseDouble(String.valueOf(value)); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNSubgraphBuilder.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNSubgraphBuilder.java new file mode 100644 index 000000000..a3911b685 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNSubgraphBuilder.java @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Random; +import java.util.Set; + +public class GCNSubgraphBuilder { + + private final GCNConfig config; + + public GCNSubgraphBuilder(GCNConfig config) { + this.config = config; + } + + public GCNInferPayload build(Object rootId, GraphAdapter adapter) { + Map nodeFeatures = new LinkedHashMap<>(); + List edges = new ArrayList<>(); + Set currentFrontier = new LinkedHashSet<>(); + currentFrontier.add(rootId); + nodeFeatures.put(rootId, adapter.loadFeatures(rootId)); + Random random = new Random(config.getRandomSeed() ^ rootId.hashCode()); + for (int hop = 0; hop < config.getNumHops(); hop++) { + Set nextFrontier = new LinkedHashSet<>(); + for (Object nodeId : currentFrontier) { + List neighbors = new ArrayList<>(adapter.loadNeighbors(nodeId)); + if (neighbors.isEmpty()) { + continue; + } + Collections.shuffle(neighbors, random); + int sampleSize = Math.min(config.getNumSamplesPerHop(), neighbors.size()); + for (int i = 0; i < sampleSize; i++) { + Object neighborId = neighbors.get(i); + nextFrontier.add(neighborId); + if (!nodeFeatures.containsKey(neighborId)) { + nodeFeatures.put(neighborId, adapter.loadFeatures(neighborId)); + } + edges.add(new int[]{indexOf(nodeFeatures, nodeId), indexOf(nodeFeatures, neighborId)}); + } + } + currentFrontier = nextFrontier; + if 
(currentFrontier.isEmpty()) { + break; + } + } + if (config.isWithSelfLoop()) { + for (int i = 0; i < nodeFeatures.size(); i++) { + edges.add(new int[]{i, i}); + } + } + List sampledNodes = new ArrayList<>(nodeFeatures.keySet()); + List features = new ArrayList<>(nodeFeatures.values()); + int[][] edgeIndex = new int[2][edges.size()]; + for (int i = 0; i < edges.size(); i++) { + edgeIndex[0][i] = edges.get(i)[0]; + edgeIndex[1][i] = edges.get(i)[1]; + } + return new GCNInferPayload(rootId, sampledNodes, features, edgeIndex, null); + } + + private int indexOf(Map nodeFeatures, Object nodeId) { + int index = 0; + for (Object id : nodeFeatures.keySet()) { + if (id.equals(nodeId)) { + return index; + } + index++; + } + throw new IllegalArgumentException("Node not found in sampled subgraph: " + nodeId); + } + + public interface GraphAdapter extends Serializable { + + List loadNeighbors(Object nodeId); + + double[] loadFeatures(Object nodeId); + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/resources/TransFormFunctionUDF.py b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/resources/TransFormFunctionUDF.py new file mode 100644 index 000000000..f03455adb --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/main/resources/TransFormFunctionUDF.py @@ -0,0 +1,233 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
import abc
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.set_num_threads(1)


class TransFormFunction(abc.ABC):
    """Base class for inference transforms: load a model, pre- and post-process."""

    def __init__(self, input_size):
        self.input_size = input_size

    @abc.abstractmethod
    def load_model(self, *args):
        pass

    @abc.abstractmethod
    def transform_pre(self, *args):
        pass

    @abc.abstractmethod
    def transform_post(self, *args):
        pass


class EmptyGCNModel(nn.Module):
    """Two-round mean-aggregation network used as the default GCN architecture.

    ``forward`` returns ``(embedding, logits)``, one row per input node.
    """

    def __init__(self, input_dim, hidden_dim, embedding_dim, num_classes):
        super(EmptyGCNModel, self).__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        self.embedding_proj = nn.Linear(hidden_dim, embedding_dim)
        self.classifier = nn.Linear(embedding_dim, num_classes)

    def forward(self, node_features, edge_index):
        hidden = self._aggregate(node_features, edge_index)
        hidden = F.relu(self.input_proj(hidden))
        embedding = self.embedding_proj(self._aggregate(hidden, edge_index))
        logits = self.classifier(embedding)
        return embedding, logits

    def _aggregate(self, features, edge_index):
        """Average, for every node, the features of the sources of its incoming edges.

        ``edge_index`` has two rows: source positions and destination positions.
        Edges with an endpoint outside ``[0, num_nodes)`` are ignored. A node with no
        valid incoming edge gets an all-zero row. When ``edge_index`` is ``None`` or
        empty the input is returned unchanged.
        """
        if edge_index is None or edge_index.numel() == 0:
            return features
        num_nodes = features.size(0)
        src_index = edge_index[0].to(device=features.device, dtype=torch.long)
        dst_index = edge_index[1].to(device=features.device, dtype=torch.long)
        valid = ((src_index >= 0) & (src_index < num_nodes)
                 & (dst_index >= 0) & (dst_index < num_nodes))
        src_index = src_index[valid]
        dst_index = dst_index[valid]
        # Scatter-add all edges at once instead of looping over them in Python.
        aggregated = torch.zeros_like(features).index_add_(0, dst_index, features[src_index])
        degree = torch.zeros((num_nodes, 1), dtype=features.dtype, device=features.device)
        degree.index_add_(0, dst_index,
                          torch.ones((dst_index.size(0), 1), dtype=features.dtype,
                                     device=features.device))
        # Avoid dividing by zero for nodes without incoming edges.
        degree = torch.clamp(degree, min=1.0)
        return aggregated / degree


class GCNTransFormFunction(TransFormFunction):
    """Runs GCN inference for one sampled subgraph and reports the center node's result.

    A checkpoint is read from ``model.pt`` in the current working directory when it
    exists; otherwise a deterministically seeded, untrained ``EmptyGCNModel`` is created
    on first use.
    """

    def __init__(self):
        super(GCNTransFormFunction, self).__init__(input_size=1)
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")
        self.hidden_dim = 32
        self.embedding_dim = 16
        self.num_classes = 2
        self.model = None
        self.model_path = os.path.join(os.getcwd(), "model.pt")
        self.model_loaded = False
        self.load_model(self.model_path)

    def load_model(self, model_path=None):
        """Load a checkpoint from ``model_path`` (defaults to ``self.model_path``).

        A missing file is not an error: the model is cleared and ``model_loaded`` set
        to ``False``. The checkpoint may be a whole ``nn.Module`` or a dict that is
        either a state dict or contains one under ``"state_dict"``; layer sizes come
        from the optional ``input_dim``/``hidden_dim``/``embedding_dim``/``num_classes``
        entries or are derived from the weight shapes.

        Raises:
            ValueError: if the checkpoint is neither a module nor a dict.
        """
        if model_path is None:
            model_path = self.model_path
        if not os.path.exists(model_path):
            self.model = None
            self.model_loaded = False
            return
        # NOTE(review): torch.load unpickles the file, so it must only be pointed at
        # trusted checkpoints.
        checkpoint = torch.load(model_path, map_location=self.device)
        if isinstance(checkpoint, nn.Module):
            self.model = checkpoint.to(self.device)
            self.model.eval()
            self.model_loaded = True
            return
        if not isinstance(checkpoint, dict):
            raise ValueError("Unsupported model checkpoint type: {}".format(type(checkpoint)))
        state_dict = checkpoint.get("state_dict", checkpoint)
        input_dim = int(checkpoint.get("input_dim", state_dict["input_proj.weight"].shape[1]))
        hidden_dim = int(checkpoint.get("hidden_dim", state_dict["input_proj.weight"].shape[0]))
        embedding_dim = int(checkpoint.get("embedding_dim",
                                           state_dict["embedding_proj.weight"].shape[0]))
        num_classes = int(checkpoint.get("num_classes", state_dict["classifier.weight"].shape[0]))
        self.model = EmptyGCNModel(input_dim, hidden_dim, embedding_dim,
                                   num_classes).to(self.device)
        self.model.load_state_dict(state_dict)
        self.model.eval()
        self.model_loaded = True

    def transform_pre(self, *args):
        """Run the model on the payload in ``args[0]``.

        The payload may be a dict or an object exposing the values as attributes or
        getters (snake_case or camelCase).

        Returns:
            ``(result, center_node_id)`` where ``result`` is a dict with the keys
            ``node_id``, ``embedding``, ``predicted_class`` and ``confidence``.
        """
        payload = args[0]
        center_node_id = self._read_value(payload, ["center_node_id", "centerNodeId"],
                                          ["getCenter_node_id", "getCenterNodeId"])
        sampled_nodes = self._read_value(payload, ["sampled_nodes", "sampledNodes"],
                                         ["getSampled_nodes", "getSampledNodes"])
        node_features = self._read_value(payload, ["node_features", "nodeFeatures"],
                                         ["getNode_features", "getNodeFeatures"])
        edge_index = self._read_value(payload, ["edge_index", "edgeIndex"],
                                      ["getEdge_index", "getEdgeIndex"])

        feature_tensor = self._build_feature_tensor(node_features)
        edge_tensor = self._build_edge_tensor(edge_index, feature_tensor.size(0))
        self._ensure_model(feature_tensor.size(1))

        with torch.no_grad():
            embedding_matrix, logits = self.model(feature_tensor, edge_tensor)

        center_index = self._find_center_index(sampled_nodes, center_node_id)
        center_embedding = embedding_matrix[center_index]
        center_logits = logits[center_index]
        probs = F.softmax(center_logits, dim=0)
        predicted_class = int(torch.argmax(probs).item())
        confidence = float(torch.max(probs).item())

        result = {
            "node_id": center_node_id,
            "embedding": center_embedding.cpu().tolist(),
            "predicted_class": predicted_class,
            "confidence": confidence
        }
        return result, center_node_id

    def transform_post(self, *args):
        """Return the first argument unchanged, or ``None`` when called without one."""
        if len(args) == 0:
            return None
        return args[0]

    def _ensure_model(self, input_dim):
        """Make sure ``self.model`` can consume features of width ``input_dim``.

        The current model is kept when its ``input_proj.in_features`` matches, or when
        it exposes no such layer (a custom module loaded from a checkpoint, whose
        input width cannot be checked here). Otherwise a seeded ``EmptyGCNModel`` of
        the right width replaces it.
        """
        if self.model is not None:
            input_proj = getattr(self.model, "input_proj", None)
            expected_dim = getattr(input_proj, "in_features", None)
            if expected_dim is None or expected_dim == input_dim:
                return
        # Fixed seed keeps the fallback weights identical across processes and runs.
        torch.manual_seed(20260322)
        self.model = EmptyGCNModel(input_dim, self.hidden_dim, self.embedding_dim,
                                   self.num_classes).to(self.device)
        self.model.eval()
        # The active model is no longer the one read from disk.
        self.model_loaded = False

    def _build_feature_tensor(self, node_features):
        """Turn ragged per-node feature rows into a dense float32 matrix.

        ``None`` rows and ``None`` values become ``0.0``; shorter rows are padded with
        zeros on the right. Missing or empty input yields a ``1 x 1`` zero tensor.
        """
        if node_features is None or len(node_features) == 0:
            return torch.zeros((1, 1), dtype=torch.float32, device=self.device)
        rows = []
        for feature_row in node_features:
            if feature_row is None:
                rows.append([0.0])
                continue
            row = [float(0.0 if value is None else value) for value in feature_row]
            rows.append(row if len(row) > 0 else [0.0])
        max_dim = max(1, max(len(row) for row in rows))
        normalized_rows = [row + [0.0] * (max_dim - len(row)) for row in rows]
        return torch.tensor(normalized_rows, dtype=torch.float32, device=self.device)

    def _build_edge_tensor(self, edge_index, num_nodes):
        """Build a ``2 x E`` long tensor, dropping edges that leave ``[0, num_nodes)``."""
        if edge_index is None or len(edge_index) < 2:
            return torch.zeros((2, 0), dtype=torch.long, device=self.device)
        src_list = list(edge_index[0])
        dst_list = list(edge_index[1])
        edges = []
        for idx in range(min(len(src_list), len(dst_list))):
            src = int(src_list[idx])
            dst = int(dst_list[idx])
            if src < 0 or dst < 0 or src >= num_nodes or dst >= num_nodes:
                continue
            edges.append([src, dst])
        if len(edges) == 0:
            return torch.zeros((2, 0), dtype=torch.long, device=self.device)
        edge_tensor = torch.tensor(edges, dtype=torch.long, device=self.device)
        return edge_tensor.t().contiguous()

    def _find_center_index(self, sampled_nodes, center_node_id):
        """Return the position of the center node, falling back to ``0`` when absent."""
        if sampled_nodes is None or len(sampled_nodes) == 0:
            return 0
        for idx in range(len(sampled_nodes)):
            if sampled_nodes[idx] == center_node_id:
                return idx
        return 0

    def _read_value(self, obj, attr_names, method_names):
        """Read a value from a dict key, an attribute or a zero-argument getter.

        Lookups are tried in that order; ``None`` is returned when nothing matches.
        """
        if isinstance(obj, dict):
            for name in attr_names:
                if name in obj:
                    return obj[name]
        for name in attr_names:
            if hasattr(obj, name):
                return getattr(obj, name)
        for name in method_names:
            if hasattr(obj, name):
                method = getattr(obj, name)
                if callable(method):
                    return method()
        return None
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +torch diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java index a1de8505a..884b59767 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/GQLValidateGraphAlgorithmTest.java @@ -40,5 +40,33 @@ public void testGraphAlgorithm() { .validate() .expectValidateType( "RecordType(BIGINT vid, BIGINT distance)"); + + String script3 = "CALL GCN() YIELD (id, embedding, predicted_class, confidence)\n" + + "RETURN id, embedding, predicted_class, confidence"; + + PlanTester.build() + .gql(script3) + .validate() + .expectValidateType( + "RecordType(BIGINT id, DOUBLE ARRAY embedding, BIGINT predicted_class, DOUBLE confidence)"); + + String script4 = "CALL GCN(2, 25) YIELD (id, embedding, predicted_class, confidence)\n" + + "RETURN id, embedding, predicted_class, confidence"; + + PlanTester.build() + .gql(script4) + .validate() + .expectValidateType( + "RecordType(BIGINT id, DOUBLE ARRAY embedding, BIGINT predicted_class, DOUBLE confidence)"); + + String script5 = "CALL GCN(2, 25, 'GCNTransFormFunction') " + + "YIELD (id, embedding, predicted_class, confidence)\n" + + "RETURN id, embedding, predicted_class, confidence"; + + PlanTester.build() + .gql(script5) + .validate() + .expectValidateType( + "RecordType(BIGINT id, DOUBLE ARRAY embedding, BIGINT 
predicted_class, DOUBLE confidence)"); } } diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNAlgorithmTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNAlgorithmTest.java new file mode 100644 index 000000000..50340f89c --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/GCNAlgorithmTest.java @@ -0,0 +1,882 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.udf.graph; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import org.apache.geaflow.api.context.RuntimeContext; +import org.apache.geaflow.api.graph.function.vc.IncVertexCentricComputeFunction.IncGraphComputeContext; +import org.apache.geaflow.api.graph.function.vc.base.IncGraphInferContext; +import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.GraphSnapShot; +import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.HistoricalGraph; +import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.MutableGraph; +import org.apache.geaflow.api.graph.function.vc.base.IncVertexCentricFunction.TemporaryGraph; +import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.EdgeQuery; +import org.apache.geaflow.api.graph.function.vc.base.VertexCentricFunction.VertexQuery; +import org.apache.geaflow.common.binary.BinaryString; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.iterator.CloseableIterator; +import org.apache.geaflow.common.type.IType; +import org.apache.geaflow.common.type.Types; +import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowEdge; +import org.apache.geaflow.dsl.common.data.RowVertex; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.dsl.common.types.TableField; +import org.apache.geaflow.dsl.common.types.VertexType; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNConfig; +import 
org.apache.geaflow.dsl.udf.graph.gcn.GCNInferPayload; +import org.apache.geaflow.model.graph.edge.EdgeDirection; +import org.apache.geaflow.model.graph.edge.IEdge; +import org.apache.geaflow.model.graph.vertex.IVertex; +import org.apache.geaflow.model.graph.vertex.impl.ValueVertex; +import org.apache.geaflow.state.pushdown.filter.IFilter; +import org.apache.geaflow.state.pushdown.filter.IVertexFilter; +import org.testng.annotations.Test; + +public class GCNAlgorithmTest { + + @Test + public void testGCNProcessCollectsInferenceResult() { + TestAlgorithmRuntimeContext context = new TestAlgorithmRuntimeContext(); + context.addVertex(1L, new TestRowVertex(1L, ObjectRow.create(1, 2D))); + context.addVertex(2L, new TestRowVertex(2L, ObjectRow.create(3, 4D))); + context.addVertex(3L, new TestRowVertex(3L, ObjectRow.create(5, 6D))); + context.addEdge(1L, new TestRowEdge(1L, 2L)); + context.addEdge(1L, new TestRowEdge(1L, 3L)); + Map inferResult = new HashMap<>(); + inferResult.put("node_id", 1L); + inferResult.put("embedding", Arrays.asList(0.1D, 0.2D)); + inferResult.put("prediction", 7L); + inferResult.put("confidence", 0.8D); + context.inferResult = inferResult; + + GCN gcn = new GCN(); + gcn.init(context, new Object[0]); + gcn.process(context.vertices.get(1L), java.util.Optional.empty(), Collections.emptyIterator()); + + ObjectRow row = (ObjectRow) context.takenRows.get(0); + assertEquals(row.getFields()[0], 1L); + assertEquals(row.getFields()[2], 7L); + assertEquals(row.getFields()[3], 0.8D); + assertNotNull(context.capturedPayload); + assertEquals(context.capturedPayload.getCenter_node_id(), 1L); + assertTrue(context.capturedPayload.getSampled_nodes().containsAll(Arrays.asList(1L, 2L, 3L))); + } + + @Test + public void testGCNProcessSkipsAfterFirstIteration() { + TestAlgorithmRuntimeContext context = new TestAlgorithmRuntimeContext(); + context.iterationId = 2L; + context.addVertex(1L, new TestRowVertex(1L, ObjectRow.create(1, 2D))); + + GCN gcn = new GCN(); + 
gcn.init(context, new Object[0]); + gcn.process(context.vertices.get(1L), java.util.Optional.empty(), Collections.emptyIterator()); + + assertTrue(context.takenRows.isEmpty()); + assertEquals(context.inferCallCount, 0); + } + + @Test(expectedExceptions = IllegalStateException.class) + public void testGCNProcessRequiresDynamicContextMethods() { + GCN gcn = new GCN(); + gcn.init(new MissingDynamicMethodsContext(buildGraphSchema()), new Object[0]); + gcn.process(new TestRowVertex(1L, ObjectRow.create(1, 2D)), + java.util.Optional.empty(), Collections.emptyIterator()); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testGCNInitRejectsInvalidParamCount() { + new GCN().init(new TestAlgorithmRuntimeContext(), new Object[]{1}); + } + + @Test + public void testGCNComputeFunctionCollectsVertexResult() { + GCNCompute.GCNComputeFunction function = new GCNCompute.GCNComputeFunction( + new GCNConfig(1, 2, GCNConfig.DEFAULT_PYTHON_TRANSFORM_CLASS)); + TestIncGraphContext context = new TestIncGraphContext(); + context.addVertex(1L, new ValueVertex>(1L, Arrays.asList(1D, 2D))); + context.addVertex(2L, new ValueVertex>(2L, Arrays.asList(3D, 4D))); + context.addVertex(3L, new ValueVertex>(3L, Arrays.asList(5D, 6D))); + context.addEdges(1L, Arrays.>asList(new TestEdge(1L, 2L), new TestEdge(1L, 3L))); + Map inferResult = new HashMap<>(); + inferResult.put("embedding", Arrays.asList(0.3D, 0.4D)); + inferResult.put("predicted_class", 9L); + inferResult.put("confidence", 0.95D); + context.inferResult = inferResult; + + function.init(context); + function.evolve(1L, new TestTemporaryGraph(context.vertices.get(1L))); + + ValueVertex> collected = (ValueVertex>) context.collectedVertex; + assertEquals(collected.getId(), 1L); + assertEquals(((double[]) collected.getValue().get(0)), new double[]{0.3D, 0.4D}); + assertEquals(collected.getValue().get(1), 9L); + assertEquals(collected.getValue().get(2), 0.95D); + } + + @Test(expectedExceptions = 
IllegalStateException.class) + public void testGCNComputeFunctionRequiresInferContext() { + GCNCompute.GCNComputeFunction function = new GCNCompute.GCNComputeFunction( + new GCNConfig(1, 1, GCNConfig.DEFAULT_PYTHON_TRANSFORM_CLASS)); + NonInferIncGraphContext context = new NonInferIncGraphContext(); + context.addVertex(1L, new ValueVertex>(1L, Arrays.asList(1D, 2D))); + + function.init(context); + function.evolve(1L, new TestTemporaryGraph(context.getVertex(1L))); + } + + private static GraphSchema buildGraphSchema() { + return new GraphSchema("g", Collections.singletonList( + new TableField("person", new VertexType(Arrays.asList( + new TableField("~id", Types.LONG, false), + new TableField("~label", Types.STRING, false), + new TableField("f0", Types.INTEGER, true), + new TableField("f1", Types.DOUBLE, true) + )), false) + )); + } + + private static class TestAlgorithmRuntimeContext implements AlgorithmRuntimeContext { + + private final Map vertices = new HashMap<>(); + private final Map> edges = new HashMap<>(); + private final List takenRows = new ArrayList<>(); + private final GraphSchema graphSchema = buildGraphSchema(); + private Object currentVertexId; + private long iterationId = 1L; + private int inferCallCount; + private Object inferResult; + private GCNInferPayload capturedPayload; + + void addVertex(Object id, TestRowVertex vertex) { + vertices.put(id, vertex); + currentVertexId = id; + } + + void addEdge(Object nodeId, RowEdge edge) { + List rowEdges = edges.get(nodeId); + if (rowEdges == null) { + rowEdges = new ArrayList<>(); + edges.put(nodeId, rowEdges); + } + rowEdges.add(edge); + } + + public void setVertexId(Object vertexId) { + this.currentVertexId = vertexId; + } + + public RowVertex loadVertex() { + return vertices.get(currentVertexId); + } + + @Override + public List loadEdges(EdgeDirection direction) { + return edges.getOrDefault(currentVertexId, Collections.emptyList()); + } + + @Override + public void take(Row value) { + 
takenRows.add(value); + } + + public Object infer(Object... modelInputs) { + inferCallCount++; + capturedPayload = (GCNInferPayload) modelInputs[0]; + return inferResult; + } + + @Override + public long getCurrentIterationId() { + return iterationId; + } + + @Override + public GraphSchema getGraphSchema() { + return graphSchema; + } + + @Override + public Configuration getConfig() { + return new Configuration(); + } + + @Override + public CloseableIterator loadEdgesIterator(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadEdgesIterator(IFilter filter) { + throw new UnsupportedOperationException(); + } + + @Override + public List loadStaticEdges(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadStaticEdgesIterator(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadStaticEdgesIterator(IFilter filter) { + throw new UnsupportedOperationException(); + } + + @Override + public List loadDynamicEdges(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(IFilter filter) { + throw new UnsupportedOperationException(); + } + + @Override + public void sendMessage(Object vertexId, Object message) { + throw new UnsupportedOperationException(); + } + + @Override + public void updateVertexValue(Row value) { + throw new UnsupportedOperationException(); + } + + @Override + public void voteToTerminate(String terminationReason, Object voteValue) { + throw new UnsupportedOperationException(); + } + } + + private static final class MissingDynamicMethodsContext implements AlgorithmRuntimeContext { + + private final GraphSchema graphSchema; + + private 
MissingDynamicMethodsContext(GraphSchema graphSchema) { + this.graphSchema = graphSchema; + } + + @Override + public List loadEdges(EdgeDirection direction) { + return Collections.emptyList(); + } + + @Override + public CloseableIterator loadEdgesIterator(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadEdgesIterator(IFilter filter) { + throw new UnsupportedOperationException(); + } + + @Override + public List loadStaticEdges(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadStaticEdgesIterator(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadStaticEdgesIterator(IFilter filter) { + throw new UnsupportedOperationException(); + } + + @Override + public List loadDynamicEdges(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(EdgeDirection direction) { + throw new UnsupportedOperationException(); + } + + @Override + public CloseableIterator loadDynamicEdgesIterator(IFilter filter) { + throw new UnsupportedOperationException(); + } + + @Override + public void sendMessage(Object vertexId, Object message) { + throw new UnsupportedOperationException(); + } + + @Override + public void updateVertexValue(Row value) { + throw new UnsupportedOperationException(); + } + + @Override + public void take(Row value) { + throw new UnsupportedOperationException(); + } + + @Override + public long getCurrentIterationId() { + return 1L; + } + + @Override + public GraphSchema getGraphSchema() { + return graphSchema; + } + + @Override + public Configuration getConfig() { + return new Configuration(); + } + + @Override + public void voteToTerminate(String terminationReason, Object voteValue) { + throw new UnsupportedOperationException(); + } + } + + private static class TestIncGraphContext 
implements IncGraphComputeContext, Object, Object>, + IncGraphInferContext { + + private final Map>> vertices = new HashMap<>(); + private final Map>> edges = new HashMap<>(); + private IVertex> collectedVertex; + private Object inferResult; + + void addVertex(Object id, IVertex> vertex) { + vertices.put(id, vertex); + } + + void addEdges(Object id, List> vertexEdges) { + edges.put(id, vertexEdges); + } + + IVertex> getVertex(Object id) { + return vertices.get(id); + } + + @Override + public void collect(IVertex> vertex) { + this.collectedVertex = vertex; + } + + @Override + @SuppressWarnings("unchecked") + public Object infer(Object... modelInputs) { + return inferResult; + } + + @Override + public HistoricalGraph, Object> getHistoricalGraph() { + return new HistoricalGraph, Object>() { + @Override + public Long getLatestVersionId() { + return 1L; + } + + @Override + public List getAllVersionIds() { + return Collections.singletonList(1L); + } + + @Override + public Map>> getAllVertex() { + throw new UnsupportedOperationException(); + } + + @Override + public Map>> getAllVertex(List versions) { + throw new UnsupportedOperationException(); + } + + @Override + public Map>> getAllVertex(List versions, + IVertexFilter> vertexFilter) { + throw new UnsupportedOperationException(); + } + + @Override + public GraphSnapShot, Object> getSnapShot(long version) { + return new GraphSnapShot, Object>() { + @Override + public long getVersion() { + return version; + } + + @Override + public VertexQuery> vertex() { + return new TestVertexQuery(vertices); + } + + @Override + public EdgeQuery edges() { + return new TestEdgeQuery(edges); + } + }; + } + }; + } + + @Override + public long getJobId() { + return 1L; + } + + @Override + public long getIterationId() { + return 1L; + } + + @Override + public RuntimeContext getRuntimeContext() { + return null; + } + + @Override + public MutableGraph, Object> getMutableGraph() { + return null; + } + + @Override + public TemporaryGraph, Object> 
getTemporaryGraph() { + return null; + } + + @Override + public void sendMessage(Object vertexId, Object message) { + } + + @Override + public void sendMessageToNeighbors(Object message) { + } + + @Override + public void close() throws IOException { + } + } + + private static final class NonInferIncGraphContext implements IncGraphComputeContext, Object, Object> { + + private final Map>> vertices = new HashMap<>(); + + void addVertex(Object id, IVertex> vertex) { + vertices.put(id, vertex); + } + + IVertex> getVertex(Object id) { + return vertices.get(id); + } + + @Override + public void collect(IVertex> vertex) { + } + + @Override + public long getJobId() { + return 1L; + } + + @Override + public long getIterationId() { + return 1L; + } + + @Override + public RuntimeContext getRuntimeContext() { + return null; + } + + @Override + public MutableGraph, Object> getMutableGraph() { + return null; + } + + @Override + public TemporaryGraph, Object> getTemporaryGraph() { + return null; + } + + @Override + public HistoricalGraph, Object> getHistoricalGraph() { + return new TestIncGraphContext().getHistoricalGraph(); + } + + @Override + public void sendMessage(Object vertexId, Object message) { + } + + @Override + public void sendMessageToNeighbors(Object message) { + } + } + + private static final class TestTemporaryGraph implements TemporaryGraph, Object> { + + private final IVertex> vertex; + + private TestTemporaryGraph(IVertex> vertex) { + this.vertex = vertex; + } + + @Override + public IVertex> getVertex() { + return vertex; + } + + @Override + public List> getEdges() { + return Collections.emptyList(); + } + + @Override + public void updateVertexValue(List value) { + vertex.withValue(value); + } + } + + private static final class TestVertexQuery implements VertexQuery> { + + private final Map>> vertices; + private Object currentId; + + private TestVertexQuery(Map>> vertices) { + this.vertices = vertices; + } + + @Override + public VertexQuery> withId(Object 
vertexId) { + this.currentId = vertexId; + return this; + } + + @Override + public IVertex> get() { + return vertices.get(currentId); + } + + @Override + public IVertex> get(IFilter vertexFilter) { + return get(); + } + } + + private static final class TestEdgeQuery implements EdgeQuery { + + private final Map>> edges; + private Object currentId; + + private TestEdgeQuery(Map>> edges) { + this.edges = edges; + } + + public TestEdgeQuery withId(Object vertexId) { + this.currentId = vertexId; + return this; + } + + @Override + public List> getEdges() { + return edges.getOrDefault(currentId, Collections.emptyList()); + } + + @Override + public List> getOutEdges() { + return getEdges(); + } + + @Override + public List> getInEdges() { + return getEdges(); + } + + @Override + public CloseableIterator> getEdges(IFilter edgeFilter) { + throw new UnsupportedOperationException(); + } + } + + private static class TestRowVertex implements RowVertex { + + private Object id; + private Row value; + private String label = "person"; + private BinaryString binaryLabel; + + private TestRowVertex(Object id, Row value) { + this.id = id; + this.value = value; + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public Object getId() { + return id; + } + + @Override + public void setId(Object id) { + this.id = id; + } + + @Override + public Row getValue() { + return value; + } + + @Override + public IVertex withValue(Row value) { + this.value = value; + return this; + } + + @Override + public IVertex withLabel(String label) { + this.label = label; + return this; + } + + @Override + public IVertex withTime(long time) { + return this; + } + + @Override + public Object getField(int i, IType type) { + if (i == 0) { + return id; + } + if (i == 1) { + return label; + } + return value.getField(i - 2, type); + } + + @Override + public String getLabel() { + return label; + } + + @Override + public void setLabel(String label) { + this.label = label; + } + 
+ @Override + public BinaryString getBinaryLabel() { + return binaryLabel; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.binaryLabel = label; + } + + @Override + public int compareTo(Object o) { + return String.valueOf(id).compareTo(String.valueOf(((IVertex) o).getId())); + } + } + + private static final class TestRowEdge implements RowEdge { + + private Object srcId; + private Object targetId; + private Row value = Row.EMPTY; + private String label = "knows"; + private BinaryString binaryLabel; + private EdgeDirection direction = EdgeDirection.OUT; + + private TestRowEdge(Object srcId, Object targetId) { + this.srcId = srcId; + this.targetId = targetId; + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public RowEdge withDirection(EdgeDirection direction) { + this.direction = direction; + return this; + } + + @Override + public RowEdge identityReverse() { + return new TestRowEdge(targetId, srcId).withDirection(direction.reverse()); + } + + @Override + public Object getSrcId() { + return srcId; + } + + @Override + public void setSrcId(Object srcId) { + this.srcId = srcId; + } + + @Override + public Object getTargetId() { + return targetId; + } + + @Override + public void setTargetId(Object targetId) { + this.targetId = targetId; + } + + @Override + public EdgeDirection getDirect() { + return direction; + } + + @Override + public void setDirect(EdgeDirection direction) { + this.direction = direction; + } + + @Override + public Row getValue() { + return value; + } + + @Override + public IEdge withValue(Row value) { + this.value = value; + return this; + } + + @Override + public IEdge reverse() { + return new TestRowEdge(targetId, srcId); + } + + @Override + public Object getField(int i, IType type) { + return value.getField(i, type); + } + + @Override + public String getLabel() { + return label; + } + + @Override + public void setLabel(String label) { + this.label = label; + } + + @Override + 
public BinaryString getBinaryLabel() { + return binaryLabel; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.binaryLabel = label; + } + } + + private static final class TestEdge implements IEdge { + + private Object srcId; + private Object targetId; + + private TestEdge(Object srcId, Object targetId) { + this.srcId = srcId; + this.targetId = targetId; + } + + @Override + public Object getSrcId() { + return srcId; + } + + @Override + public void setSrcId(Object srcId) { + this.srcId = srcId; + } + + @Override + public Object getTargetId() { + return targetId; + } + + @Override + public void setTargetId(Object targetId) { + this.targetId = targetId; + } + + @Override + public EdgeDirection getDirect() { + return EdgeDirection.OUT; + } + + @Override + public void setDirect(EdgeDirection direction) { + } + + @Override + public Object getValue() { + return null; + } + + @Override + public IEdge withValue(Object value) { + return this; + } + + @Override + public IEdge reverse() { + return new TestEdge(targetId, srcId); + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNComponentsTest.java b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNComponentsTest.java new file mode 100644 index 000000000..8e2b3bbf6 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-plan/src/test/java/org/apache/geaflow/dsl/udf/graph/gcn/GCNComponentsTest.java @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.geaflow.dsl.udf.graph.gcn; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.apache.geaflow.common.binary.BinaryString; +import org.apache.geaflow.common.type.IType; +import org.apache.geaflow.common.type.Types; +import org.apache.geaflow.dsl.common.data.Row; +import org.apache.geaflow.dsl.common.data.RowVertex; +import org.apache.geaflow.dsl.common.data.impl.ObjectRow; +import org.apache.geaflow.dsl.common.types.GraphSchema; +import org.apache.geaflow.dsl.common.types.TableField; +import org.apache.geaflow.dsl.common.types.VertexType; +import org.testng.annotations.Test; + +public class GCNComponentsTest { + + @Test + public void testResultParserMap() { + GCNResultParser parser = new GCNResultParser(); + Map rawResult = new HashMap<>(); + rawResult.put("node_id", 1L); + rawResult.put("embedding", Arrays.asList(0.1D, 0.2D)); + rawResult.put("prediction", 2); + rawResult.put("confidence", 0.9D); + GCNInferResult result = parser.parse(1L, rawResult); + assertEquals(result.getNodeId(), 1L); + assertEquals(result.getPredictedClass(), 2L); + assertEquals(result.getConfidence(), 0.9D); + assertEquals(result.getEmbedding(), new double[]{0.1D, 0.2D}); + } + + @Test + public void testResultParserListArrayNullAndPythonException() { + GCNResultParser parser = new 
GCNResultParser(); + + GCNInferResult listResult = parser.parse(1L, Arrays.asList(Arrays.asList(1.0D, 2.0D), 3, 0.5D)); + assertEquals(listResult.getNodeId(), 1L); + assertEquals(listResult.getEmbedding(), new double[]{1.0D, 2.0D}); + assertEquals(listResult.getPredictedClass(), 3L); + assertEquals(listResult.getConfidence(), 0.5D); + + Object[] arrayResult = new Object[]{2L, new Double[]{3.0D, null, 4.0D}, 5L, 0.7D}; + GCNInferResult parsedArray = parser.parse(1L, arrayResult); + assertEquals(parsedArray.getNodeId(), 2L); + assertEquals(parsedArray.getEmbedding(), new double[]{3.0D, 0D, 4.0D}); + assertEquals(parsedArray.getPredictedClass(), 5L); + assertEquals(parsedArray.getConfidence(), 0.7D); + + GCNInferResult nullResult = parser.parse(7L, null); + assertEquals(nullResult.getNodeId(), 7L); + assertEquals(nullResult.getEmbedding(), new double[0]); + assertEquals(nullResult.getPredictedClass(), -1L); + assertEquals(nullResult.getConfidence(), 0D); + } + + @Test(expectedExceptions = IllegalStateException.class) + public void testResultParserThrowsOnPythonExceptionMarker() { + new GCNResultParser().parse(1L, "python_exception:test"); + } + + @Test + public void testFeatureCollectorSupportsMultipleInputShapes() { + GCNFeatureCollector collector = new GCNFeatureCollector(); + + assertEquals(collector.collectFromValue(Arrays.asList(1, 2, 3), 5), + new double[]{1D, 2D, 3D, 0D, 0D}); + assertEquals(collector.collectFromValue(new Object[]{1, "2.5"}, 4), + new double[]{1D, 2.5D, 0D, 0D}); + assertEquals(collector.collectFromValue(ObjectRow.create(4, 5.5D), 4), + new double[]{4D, 5.5D, 0D, 0D}); + assertEquals(collector.collectFromValue(6, 3), + new double[]{6D, 0D, 0D}); + assertEquals(collector.collectFromValue(null, 2), + new double[]{0D, 0D}); + } + + @Test + public void testFeatureCollectorFromRowVertexHonorsSchemaAndPads() { + GraphSchema graphSchema = new GraphSchema("g", Collections.singletonList( + new TableField("person", new VertexType(Arrays.asList( + new 
TableField("~id", Types.LONG, false), + new TableField("~label", Types.STRING, false), + new TableField("f0", Types.INTEGER, true), + new TableField("f1", Types.DOUBLE, true) + )), false) + )); + TestRowVertex vertex = new TestRowVertex(1L, ObjectRow.create(10, 20.5D)); + + assertEquals(new GCNFeatureCollector().collectFromRowVertex(vertex, graphSchema, 4), + new double[]{10D, 20.5D, 0D, 0D}); + } + + @Test + public void testSubgraphBuilderAddsSelfLoopAndSamplesDeterministically() { + GCNConfig config = new GCNConfig(2, 2, GCNConfig.DEFAULT_PYTHON_TRANSFORM_CLASS); + GCNSubgraphBuilder builder = new GCNSubgraphBuilder(config); + GCNInferPayload payload = builder.build(1L, new GCNSubgraphBuilder.GraphAdapter() { + @Override + public List loadNeighbors(Object nodeId) { + if (nodeId.equals(1L)) { + return Arrays.asList(2L, 3L); + } + if (nodeId.equals(2L)) { + return Collections.singletonList(4L); + } + return Collections.emptyList(); + } + + @Override + public double[] loadFeatures(Object nodeId) { + return new double[]{((Number) nodeId).doubleValue()}; + } + }); + assertEquals(payload.getCenter_node_id(), 1L); + assertEquals(payload.getSampled_nodes().size(), 4); + assertTrue(payload.getSampled_nodes().containsAll(Arrays.asList(1L, 2L, 3L, 4L))); + assertEquals(payload.getNode_features().size(), 4); + assertEquals(payload.getEdge_index()[0].length, 7); + assertNotNull(payload.getEdge_index()); + } + + @Test + public void testSubgraphBuilderRespectsSampleLimitAndCanDisableSelfLoop() { + GCNConfig config = new GCNConfig(1, 1, GCNConfig.DEFAULT_PYTHON_TRANSFORM_CLASS, + false, 4, 123L); + GCNSubgraphBuilder builder = new GCNSubgraphBuilder(config); + GCNInferPayload payload = builder.build(1L, new GCNSubgraphBuilder.GraphAdapter() { + @Override + public List loadNeighbors(Object nodeId) { + return Arrays.asList(2L, 3L, 4L); + } + + @Override + public double[] loadFeatures(Object nodeId) { + return new double[]{((Number) nodeId).doubleValue()}; + } + }); + + 
assertEquals(payload.getSampled_nodes().size(), 2); + assertEquals(payload.getEdge_index()[0].length, 1); + assertTrue(payload.getSampled_nodes().contains(1L)); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testConfigRejectsBlankTransformClass() { + new GCNConfig(1, 1, " "); + } + + @Test(expectedExceptions = IllegalArgumentException.class) + public void testConfigRejectsNonPositiveHopCount() { + new GCNConfig(0, 1, GCNConfig.DEFAULT_PYTHON_TRANSFORM_CLASS); + } + + private static class TestRowVertex implements RowVertex { + + private final Object id; + private Row value; + private String label = "person"; + private BinaryString binaryLabel; + + private TestRowVertex(Object id, Row value) { + this.id = id; + this.value = value; + } + + @Override + public void setValue(Row value) { + this.value = value; + } + + @Override + public Object getId() { + return id; + } + + @Override + public Row getValue() { + return value; + } + + @Override + public void setId(Object id) { + } + + @Override + public Object getField(int i, IType type) { + if (i == 0) { + return id; + } + if (i == 1) { + return label; + } + return value.getField(i - 2, type); + } + + @Override + public org.apache.geaflow.model.graph.vertex.IVertex withValue(Row value) { + this.value = value; + return this; + } + + @Override + public org.apache.geaflow.model.graph.vertex.IVertex withLabel(String label) { + this.label = label; + return this; + } + + @Override + public org.apache.geaflow.model.graph.vertex.IVertex withTime(long time) { + return this; + } + + @Override + public String getLabel() { + return label; + } + + @Override + public void setLabel(String label) { + this.label = label; + } + + @Override + public BinaryString getBinaryLabel() { + return binaryLabel; + } + + @Override + public void setBinaryLabel(BinaryString label) { + this.binaryLabel = label; + } + + @Override + public int compareTo(Object o) { + return 
String.valueOf(id).compareTo(String.valueOf(((RowVertex) o).getId())); + } + } +} diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/pom.xml b/geaflow/geaflow-dsl/geaflow-dsl-runtime/pom.xml index e9863eb2e..8cdbb9932 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/pom.xml +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/pom.xml @@ -111,6 +111,11 @@ geaflow-pipeline + + org.apache.geaflow + geaflow-infer + + org.testng testng @@ -217,4 +222,4 @@ - \ No newline at end of file + diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java index daca980db..fe4c52c99 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmAggTraversalFunction.java @@ -33,6 +33,7 @@ import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.message.ITraversalAgg; +import org.apache.geaflow.infer.InferContext; import org.apache.geaflow.model.traversal.ITraversalRequest; import org.apache.geaflow.state.KeyValueState; import org.apache.geaflow.state.StateFactory; @@ -73,7 +74,13 @@ public GeaFlowAlgorithmAggTraversalFunction(GraphSchema graphSchema, public void open( VertexCentricTraversalFuncContext vertexCentricFuncContext) { this.traversalContext = vertexCentricFuncContext; - this.algorithmCtx = new GeaFlowAlgorithmRuntimeContext(this, traversalContext, graphSchema); + InferContext inferContext = null; + if (traversalContext.getRuntimeContext().getConfiguration().getBoolean( + org.apache.geaflow.common.config.keys.FrameworkConfigKeys.INFER_ENV_ENABLE)) { + 
inferContext = new InferContext<>(traversalContext.getRuntimeContext().getConfiguration()); + } + this.algorithmCtx = new GeaFlowAlgorithmRuntimeContext(this, traversalContext, graphSchema, + inferContext); this.userFunction.init(algorithmCtx, params); this.invokeVIds = new HashSet<>(); String stateName = traversalContext.getTraversalOpName() + "_" + STATE_SUFFIX; diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java index 98c475b15..198c38589 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicAggTraversalFunction.java @@ -37,6 +37,7 @@ import org.apache.geaflow.dsl.common.data.RowVertex; import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.message.ITraversalAgg; +import org.apache.geaflow.infer.InferContext; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; import org.apache.geaflow.model.traversal.ITraversalRequest; @@ -90,8 +91,13 @@ public void open( IncVertexCentricTraversalFuncContext vertexCentricFuncContext) { this.traversalContext = vertexCentricFuncContext; this.materializeInFinish = traversalContext.getRuntimeContext().getConfiguration().getBoolean(FrameworkConfigKeys.UDF_MATERIALIZE_GRAPH_IN_FINISH); + InferContext inferContext = null; + if (traversalContext.getRuntimeContext().getConfiguration().getBoolean( + FrameworkConfigKeys.INFER_ENV_ENABLE)) { + inferContext = new InferContext<>(traversalContext.getRuntimeContext().getConfiguration()); + } this.algorithmCtx = new 
GeaFlowAlgorithmDynamicRuntimeContext(this, traversalContext, - graphSchema); + graphSchema, inferContext); this.initVertices = new HashSet<>(); this.userFunction.init(algorithmCtx, params); this.mutableGraph = traversalContext.getMutableGraph(); diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java index d929ae441..4947d3ce6 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmDynamicRuntimeContext.java @@ -21,13 +21,13 @@ import java.util.ArrayList; import java.util.List; -import java.util.Objects; import org.apache.geaflow.api.graph.function.aggregate.VertexCentricAggContextFunction.VertexCentricAggContext; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction.IncVertexCentricTraversalFuncContext; import org.apache.geaflow.api.graph.function.vc.IncVertexCentricTraversalFunction.TraversalGraphSnapShot; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.TraversalEdgeQuery; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.TraversalVertexQuery; import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.exception.GeaflowRuntimeException; import org.apache.geaflow.common.iterator.CloseableIterator; import org.apache.geaflow.dsl.common.algo.AlgorithmRuntimeContext; import org.apache.geaflow.dsl.common.data.Row; @@ -35,6 +35,7 @@ import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.message.ITraversalAgg; 
+import org.apache.geaflow.infer.InferContext; import org.apache.geaflow.model.graph.edge.EdgeDirection; import org.apache.geaflow.model.graph.edge.IEdge; import org.apache.geaflow.model.graph.vertex.IVertex; @@ -59,6 +60,7 @@ public class GeaFlowAlgorithmDynamicRuntimeContext implements AlgorithmRuntimeCo protected TraversalEdgeQuery edgeQuery; private final transient GeaFlowAlgorithmDynamicAggTraversalFunction traversalFunction; + private final InferContext inferContext; private Object vertexId; @@ -66,12 +68,20 @@ public class GeaFlowAlgorithmDynamicRuntimeContext implements AlgorithmRuntimeCo public GeaFlowAlgorithmDynamicRuntimeContext(GeaFlowAlgorithmDynamicAggTraversalFunction traversalFunction, IncVertexCentricTraversalFuncContext traversalContext, GraphSchema graphSchema) { + this(traversalFunction, traversalContext, graphSchema, null); + } + + public GeaFlowAlgorithmDynamicRuntimeContext(GeaFlowAlgorithmDynamicAggTraversalFunction traversalFunction, + IncVertexCentricTraversalFuncContext traversalContext, + GraphSchema graphSchema, + InferContext inferContext) { this.traversalFunction = traversalFunction; this.incVCTraversalCtx = traversalContext; this.graphSchema = graphSchema; TraversalGraphSnapShot graphSnapShot = incVCTraversalCtx.getHistoricalGraph().getSnapShot(0L); this.vertexQuery = graphSnapShot.vertex(); this.edgeQuery = graphSnapShot.edges(); + this.inferContext = inferContext; } public void setVertexId(Object vertexId) { @@ -248,7 +258,9 @@ public void finish() { } public void close() { - + if (inferContext != null) { + inferContext.close(); + } } @Override @@ -256,12 +268,28 @@ public GraphSchema getGraphSchema() { return graphSchema; } + @Override + @SuppressWarnings("unchecked") + public OUT infer(Object... 
modelInputs) { + if (inferContext == null) { + return AlgorithmRuntimeContext.super.infer(modelInputs); + } + try { + return (OUT) inferContext.infer(modelInputs); + } catch (Exception e) { + throw new GeaflowRuntimeException("model infer failed", e); + } + } + public VertexCentricAggContext getAggContext() { return aggContext; } public void setAggContext(VertexCentricAggContext aggContext) { - this.aggContext = Objects.requireNonNull(aggContext); + if (aggContext == null) { + throw new NullPointerException("aggContext"); + } + this.aggContext = aggContext; } public IncVertexCentricTraversalFuncContext getIncVCTraversalCtx() { diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java index 7696b4f10..b46e7296c 100644 --- a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/main/java/org/apache/geaflow/dsl/runtime/engine/GeaFlowAlgorithmRuntimeContext.java @@ -21,7 +21,6 @@ import java.util.ArrayList; import java.util.List; -import java.util.Objects; import org.apache.geaflow.api.graph.function.aggregate.VertexCentricAggContextFunction.VertexCentricAggContext; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.TraversalEdgeQuery; import org.apache.geaflow.api.graph.function.vc.VertexCentricTraversalFunction.VertexCentricTraversalFuncContext; @@ -34,6 +33,7 @@ import org.apache.geaflow.dsl.common.exception.GeaFlowDSLException; import org.apache.geaflow.dsl.common.types.GraphSchema; import org.apache.geaflow.dsl.runtime.traversal.message.ITraversalAgg; +import org.apache.geaflow.infer.InferContext; import org.apache.geaflow.model.graph.edge.EdgeDirection; import 
org.apache.geaflow.model.traversal.ITraversalResponse; import org.apache.geaflow.model.traversal.TraversalType.ResponseType; @@ -51,6 +51,7 @@ public class GeaFlowAlgorithmRuntimeContext implements AlgorithmRuntimeContext edgeQuery; private final transient GeaFlowAlgorithmAggTraversalFunction traversalFunction; + private final InferContext inferContext; private Object vertexId; private long lastSendAggMsgIterationId = -1L; @@ -59,11 +60,20 @@ public GeaFlowAlgorithmRuntimeContext( GeaFlowAlgorithmAggTraversalFunction traversalFunction, VertexCentricTraversalFuncContext traversalContext, GraphSchema graphSchema) { + this(traversalFunction, traversalContext, graphSchema, null); + } + + public GeaFlowAlgorithmRuntimeContext( + GeaFlowAlgorithmAggTraversalFunction traversalFunction, + VertexCentricTraversalFuncContext traversalContext, + GraphSchema graphSchema, + InferContext inferContext) { this.traversalFunction = traversalFunction; this.traversalContext = traversalContext; this.edgeQuery = traversalContext.edges(); this.graphSchema = graphSchema; this.aggContext = null; + this.inferContext = inferContext; } public void setVertexId(Object vertexId) { @@ -172,7 +182,9 @@ public void finish() { } public void close() { - + if (inferContext != null) { + inferContext.close(); + } } public long getCurrentIterationId() { @@ -189,12 +201,28 @@ public GraphSchema getGraphSchema() { return graphSchema; } + @Override + @SuppressWarnings("unchecked") + public OUT infer(Object... 
modelInputs) { + if (inferContext == null) { + return AlgorithmRuntimeContext.super.infer(modelInputs); + } + try { + return (OUT) inferContext.infer(modelInputs); + } catch (Exception e) { + throw new GeaflowRuntimeException("model infer failed", e); + } + } + public VertexCentricAggContext getAggContext() { return aggContext; } public void setAggContext(VertexCentricAggContext aggContext) { - this.aggContext = Objects.requireNonNull(aggContext); + if (aggContext == null) { + throw new NullPointerException("aggContext"); + } + this.aggContext = aggContext; } @Override diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GCNInferIntegrationTest.java b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GCNInferIntegrationTest.java new file mode 100644 index 000000000..fd062d1f3 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/java/org/apache/geaflow/dsl/runtime/query/GCNInferIntegrationTest.java @@ -0,0 +1,404 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.geaflow.dsl.runtime.query; + +import com.google.common.io.Resources; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStreamWriter; +import java.net.URL; +import java.nio.charset.StandardCharsets; +import java.nio.file.Paths; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.jar.JarEntry; +import java.util.jar.JarOutputStream; +import org.apache.commons.io.FileUtils; +import org.apache.geaflow.common.config.Configuration; +import org.apache.geaflow.common.config.keys.ExecutionConfigKeys; +import org.apache.geaflow.common.config.keys.FrameworkConfigKeys; +import org.apache.geaflow.dsl.udf.graph.gcn.GCNInferPayload; +import org.apache.geaflow.file.FileConfigKeys; +import org.apache.geaflow.infer.InferContext; +import org.apache.geaflow.infer.InferEnvironmentManager; +import org.testng.Assert; +import org.testng.SkipException; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.Test; + +public class GCNInferIntegrationTest { + + private static final String TEST_WORK_DIR = "/tmp/geaflow/gcn_infer_test"; + private static final String PYTHON_UDF_DIR = TEST_WORK_DIR + "/python_udf"; + private static final String TEST_JOB_JAR = "gcn-test-job.jar"; + private static final String GCN_TRANSFORM_CLASS = "GCNTransFormFunction"; + private static final String FAILING_TRANSFORM_CLASS = "GCNFailingTransformFunction"; + private static final String CONDA_URL_ENV = "GEAFLOW_GCN_INFER_CONDA_URL"; + private static final String GCN_QUERY_PATH = "/query/gql_algorithm_gcn_infer_001.sql"; + + @AfterMethod + public void tearDown() { + FileUtils.deleteQuietly(new File(TEST_WORK_DIR)); + FileUtils.deleteQuietly(new File(getClasspathRoot(), TEST_JOB_JAR)); + } + + @Test(timeOut = 30000) + public void 
testGCNPythonUDFDirectWithoutModelFile() throws Exception { + ensurePythonModuleAvailable("torch"); + File udfDir = new File(PYTHON_UDF_DIR); + FileUtils.forceMkdir(udfDir); + copyResourceToDirectory("TransFormFunctionUDF.py", udfDir); + + String testScript = String.join("\n", + "import os", + "import sys", + "os.chdir('" + escapeForPython(udfDir.getAbsolutePath()) + "')", + "sys.path.insert(0, os.getcwd())", + "from TransFormFunctionUDF import GCNTransFormFunction", + "payload = {", + " 'center_node_id': 1,", + " 'sampled_nodes': [1, 2, 3],", + " 'node_features': [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]],", + " 'edge_index': [[0, 0, 1, 2, 0, 1, 2], [1, 2, 0, 0, 0, 1, 2]]", + "}", + "func = GCNTransFormFunction()", + "assert func.model_loaded is False", + "result, center = func.transform_pre(payload)", + "output = func.transform_post(result)", + "assert center == 1", + "assert output['node_id'] == 1", + "assert len(output['embedding']) == 16", + "assert 'predicted_class' in output", + "assert 'confidence' in output", + "assert 0.0 <= float(output['confidence']) <= 1.0", + "print('GCN direct python test passed')" + ); + + File scriptFile = new File(udfDir, "test_gcn_udf.py"); + try (OutputStreamWriter writer = new OutputStreamWriter( + new FileOutputStream(scriptFile), StandardCharsets.UTF_8)) { + writer.write(testScript); + } + + Process process = new ProcessBuilder(findPythonExecutable(), scriptFile.getAbsolutePath()) + .directory(udfDir) + .redirectErrorStream(true) + .start(); + String output = readProcessOutput(process); + int exitCode = process.waitFor(); + + Assert.assertEquals(exitCode, 0, output); + Assert.assertTrue(output.contains("GCN direct python test passed"), output); + } + + @Test + public void testPythonModulesAvailable() throws Exception { + ensurePythonModuleAvailable("torch"); + } + + @Test(timeOut = 600000) + public void testGCNInferContextEndToEndWhenCondaConfigured() throws Exception { + String condaUrl = System.getenv(CONDA_URL_ENV); + if 
(condaUrl == null || condaUrl.trim().isEmpty()) { + throw new SkipException("Skip GCN infer end-to-end test: " + CONDA_URL_ENV + " is not set"); + } + + ensureInferEnvironmentReady(condaUrl); + + InferContext inferContext = new InferContext<>( + createInferConfiguration(condaUrl, GCN_TRANSFORM_CLASS)); + try { + assertInferResult(inferContext.infer(createPayload(1L)), 1L); + } finally { + inferContext.close(); + } + } + + @Test(timeOut = 600000) + public void testGCNInferContextMultipleCallsWhenCondaConfigured() throws Exception { + String condaUrl = System.getenv(CONDA_URL_ENV); + if (condaUrl == null || condaUrl.trim().isEmpty()) { + throw new SkipException("Skip repeated GCN infer test: " + CONDA_URL_ENV + " is not set"); + } + + ensureInferEnvironmentReady(condaUrl); + + InferContext inferContext = new InferContext<>( + createInferConfiguration(condaUrl, GCN_TRANSFORM_CLASS)); + try { + assertInferResult(inferContext.infer(createPayload(1L)), 1L); + assertInferResult(inferContext.infer(createPayload(2L)), 2L); + assertInferResult(inferContext.infer(createPayload(3L)), 3L); + } finally { + inferContext.close(); + } + } + + @Test(timeOut = 600000) + public void testGCNInferContextPropagatesPythonExceptionWhenCondaConfigured() throws Exception { + String condaUrl = System.getenv(CONDA_URL_ENV); + if (condaUrl == null || condaUrl.trim().isEmpty()) { + throw new SkipException("Skip GCN python exception test: " + CONDA_URL_ENV + " is not set"); + } + + ensureInferEnvironmentReady(condaUrl); + + InferContext inferContext = new InferContext<>( + createInferConfiguration(condaUrl, FAILING_TRANSFORM_CLASS)); + try { + Object rawResult = inferContext.infer(createPayload(1L)); + Assert.assertTrue(rawResult instanceof String, + "Expected python exception string, but got: " + rawResult); + String error = (String) rawResult; + Assert.assertTrue(error.startsWith("python_exception:"), + "Expected python exception marker, but got: " + error); + 
Assert.assertTrue(error.contains("gcn failing transform invoked"), error); + } finally { + inferContext.close(); + } + } + + @Test(timeOut = 600000) + public void testGCNQueryRuntimeEndToEndWhenCondaConfigured() throws Exception { + String condaUrl = System.getenv(CONDA_URL_ENV); + if (condaUrl == null || condaUrl.trim().isEmpty()) { + throw new SkipException("Skip GCN DSL infer query test: " + CONDA_URL_ENV + " is not set"); + } + + createTestJobJar(); + + QueryTester + .build() + .withConfig(FrameworkConfigKeys.INFER_ENV_ENABLE.getKey(), "true") + .withConfig(FrameworkConfigKeys.INFER_ENV_CONDA_URL.getKey(), condaUrl) + .withConfig(FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME.getKey(), + GCN_TRANSFORM_CLASS) + .withConfig(FrameworkConfigKeys.INFER_ENV_INIT_TIMEOUT_SEC.getKey(), "300") + .withConfig(ExecutionConfigKeys.JOB_WORK_PATH.getKey(), TEST_WORK_DIR) + .withConfig(FileConfigKeys.USER_NAME.getKey(), "gcn_test_user") + .withQueryPath(GCN_QUERY_PATH) + .execute() + .checkSinkResult(); + } + + private void ensurePythonModuleAvailable(String moduleName) throws Exception { + String pythonExecutable = findPythonExecutable(); + Process process = new ProcessBuilder( + pythonExecutable, "-c", "import " + moduleName + "; print('ok')").start(); + String output = readProcessOutput(process); + int exitCode = process.waitFor(); + if (exitCode != 0) { + throw new SkipException("Skip python integration test, missing module " + moduleName + + ": " + output); + } + } + + private String findPythonExecutable() { + String[] candidates = new String[]{"python3", "python"}; + for (String candidate : candidates) { + try { + Process process = new ProcessBuilder(candidate, "--version") + .redirectErrorStream(true) + .start(); + int exitCode = process.waitFor(); + if (exitCode == 0) { + return candidate; + } + } catch (Exception e) { + // try next candidate + } + } + throw new SkipException("Skip python integration test, no python executable found"); + } + + private String 
readProcessOutput(Process process) throws IOException { + StringBuilder builder = new StringBuilder(); + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(process.getInputStream(), StandardCharsets.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + builder.append(line).append('\n'); + } + } + return builder.toString(); + } + + private void copyResourceToDirectory(String resourceName, File targetDirectory) throws IOException { + try (InputStream inputStream = getRequiredResource(resourceName).openStream()) { + File targetFile = new File(targetDirectory, resourceName); + FileUtils.copyInputStreamToFile(inputStream, targetFile); + } + } + + private URL getRequiredResource(String resourceName) { + URL resource = GCNInferIntegrationTest.class.getClassLoader().getResource(resourceName); + Assert.assertNotNull(resource, "Missing resource " + resourceName); + return resource; + } + + private void createTestJobJar() throws IOException { + File classpathRoot = getClasspathRoot(); + File jarFile = new File(classpathRoot, TEST_JOB_JAR); + if (jarFile.exists()) { + FileUtils.forceDelete(jarFile); + } + try (JarOutputStream jarOutputStream = new JarOutputStream(new FileOutputStream(jarFile))) { + writeStringEntry(jarOutputStream, "TransFormFunctionUDF.py", buildTestPythonUdfContent()); + writeResourceEntry(jarOutputStream, "requirements.txt"); + } + } + + private void ensureInferEnvironmentReady(String condaUrl) throws Exception { + createTestJobJar(); + InferEnvironmentManager.buildInferEnvironmentManager( + createInferConfiguration(condaUrl, GCN_TRANSFORM_CLASS)).createEnvironment(); + } + + private Configuration createInferConfiguration(String condaUrl, String transformClass) { + Configuration config = new Configuration(); + config.put(FrameworkConfigKeys.INFER_ENV_ENABLE.getKey(), "true"); + config.put(FrameworkConfigKeys.INFER_ENV_CONDA_URL.getKey(), condaUrl); + 
config.put(FrameworkConfigKeys.INFER_ENV_USER_TRANSFORM_CLASSNAME.getKey(), transformClass); + config.put(FrameworkConfigKeys.INFER_ENV_INIT_TIMEOUT_SEC.getKey(), "300"); + config.put(ExecutionConfigKeys.JOB_UNIQUE_ID.getKey(), + "gcn-infer-test-" + System.currentTimeMillis()); + config.put(ExecutionConfigKeys.JOB_WORK_PATH.getKey(), TEST_WORK_DIR); + config.put(FileConfigKeys.USER_NAME.getKey(), "gcn_test_user"); + return config; + } + + private void assertInferResult(Object rawResult, long expectedNodeId) { + Assert.assertTrue(rawResult instanceof Map, + "Expected Python dict result, but got: " + rawResult); + Map result = (Map) rawResult; + Assert.assertEquals(result.get("node_id"), expectedNodeId); + Assert.assertNotNull(result.get("embedding")); + Assert.assertNotNull(result.get("predicted_class")); + Assert.assertNotNull(result.get("confidence")); + } + + private GCNInferPayload createPayload(long centerNodeId) { + if (centerNodeId == 1L) { + return new GCNInferPayload( + 1L, + Arrays.asList(1L, 2L, 3L), + Arrays.asList( + featureVector(1.0D, 2.0D), + featureVector(0.0D, 1.0D), + featureVector(2.0D, 1.0D) + ), + new int[][]{{0, 0, 0, 1, 2}, {1, 2, 0, 1, 2}}, + null + ); + } + if (centerNodeId == 2L) { + return new GCNInferPayload( + 2L, + Arrays.asList(2L, 1L), + Arrays.asList( + featureVector(0.0D, 1.0D), + featureVector(1.0D, 2.0D) + ), + new int[][]{{0, 0, 1}, {1, 0, 1}}, + null + ); + } + if (centerNodeId == 3L) { + return new GCNInferPayload( + 3L, + Arrays.asList(3L, 1L), + Arrays.asList( + featureVector(2.0D, 1.0D), + featureVector(1.0D, 2.0D) + ), + new int[][]{{0, 0, 1}, {1, 0, 1}}, + null + ); + } + throw new IllegalArgumentException("Unsupported center node id " + centerNodeId); + } + + private double[] featureVector(double first, double second) { + double[] values = new double[64]; + values[0] = first; + values[1] = second; + return values; + } + + private void writeResourceEntry(JarOutputStream jarOutputStream, String resourceName) + throws 
IOException { + JarEntry entry = new JarEntry(resourceName); + jarOutputStream.putNextEntry(entry); + try (InputStream inputStream = getRequiredResource(resourceName).openStream()) { + byte[] buffer = new byte[4096]; + int bytesRead; + while ((bytesRead = inputStream.read(buffer)) != -1) { + jarOutputStream.write(buffer, 0, bytesRead); + } + } + jarOutputStream.closeEntry(); + } + + private void writeStringEntry(JarOutputStream jarOutputStream, String resourceName, String content) + throws IOException { + JarEntry entry = new JarEntry(resourceName); + jarOutputStream.putNextEntry(entry); + jarOutputStream.write(content.getBytes(StandardCharsets.UTF_8)); + jarOutputStream.closeEntry(); + } + + private String buildTestPythonUdfContent() throws IOException { + StringBuilder builder = new StringBuilder(); + try (InputStream inputStream = getRequiredResource("TransFormFunctionUDF.py").openStream(); + InputStreamReader reader = new InputStreamReader(inputStream, StandardCharsets.UTF_8); + BufferedReader bufferedReader = new BufferedReader(reader)) { + String line; + while ((line = bufferedReader.readLine()) != null) { + builder.append(line).append('\n'); + } + } + builder.append('\n'); + builder.append("class ").append(FAILING_TRANSFORM_CLASS).append("(object):\n"); + builder.append(" input_size = 1\n\n"); + builder.append(" def transform_pre(self, *inputs):\n"); + builder.append(" raise RuntimeError('gcn failing transform invoked')\n\n"); + builder.append(" def transform_post(self, *inputs):\n"); + builder.append(" return None\n"); + return builder.toString(); + } + + private File getClasspathRoot() { + try { + return Paths.get(Resources.getResource(".").toURI()).toFile(); + } catch (Exception e) { + throw new RuntimeException("Failed to locate classpath root", e); + } + } + + private String escapeForPython(String path) { + return path.replace("\\", "\\\\").replace("'", "\\'"); + } +} diff --git 
a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/gcn_infer_edge.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/gcn_infer_edge.txt new file mode 100644 index 000000000..ddd8ca16f --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/gcn_infer_edge.txt @@ -0,0 +1,2 @@ +1,2 +1,3 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/gcn_infer_vertex.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/gcn_infer_vertex.txt new file mode 100644 index 000000000..cf869735e --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/data/gcn_infer_vertex.txt @@ -0,0 +1,3 @@ +1.0,2.0,1 +0.0,1.0,2 +2.0,1.0,3 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_infer_001.txt b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_infer_001.txt new file mode 100644 index 000000000..493e9807d --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/expect/gql_algorithm_gcn_infer_001.txt @@ -0,0 +1,3 @@ +1,1 +2,1 +3,1 diff --git a/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_infer_001.sql b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_infer_001.sql new file mode 100644 index 000000000..274e20828 --- /dev/null +++ b/geaflow/geaflow-dsl/geaflow-dsl-runtime/src/test/resources/query/gql_algorithm_gcn_infer_001.sql @@ -0,0 +1,60 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +CREATE TABLE v_gcn ( + feature0 double, + feature1 double, + id bigint +) WITH ( + type='file', + geaflow.dsl.window.size = -1, + geaflow.dsl.file.path = 'resource:///data/gcn_infer_vertex.txt' +); + +CREATE TABLE e_gcn ( + srcId bigint, + targetId bigint +) WITH ( + type='file', + geaflow.dsl.window.size = -1, + geaflow.dsl.file.path = 'resource:///data/gcn_infer_edge.txt' +); + +CREATE GRAPH gcn_graph ( + Vertex feature using v_gcn WITH ID(id), + Edge relation using e_gcn WITH ID(srcId, targetId) +) WITH ( + storeType='memory', + shardCount = 1 +); + +CREATE TABLE tbl_result ( + id bigint, + predicted_class bigint +) WITH ( + type='file', + geaflow.dsl.file.path='${target}' +); + +USE GRAPH gcn_graph; + +INSERT INTO tbl_result +CALL GCN(1, 25, 'GCNTransFormFunction') YIELD (id, embedding, predicted_class, confidence) +RETURN cast(id as bigint), cast(predicted_class as bigint) +;