summaryrefslogtreecommitdiffstats
path: root/_blog/neo4j-a-star-search.md
diff options
context:
space:
mode:
Diffstat (limited to '_blog/neo4j-a-star-search.md')
-rw-r--r--_blog/neo4j-a-star-search.md318
1 files changed, 318 insertions, 0 deletions
diff --git a/_blog/neo4j-a-star-search.md b/_blog/neo4j-a-star-search.md
new file mode 100644
index 0000000..117931b
--- /dev/null
+++ b/_blog/neo4j-a-star-search.md
@@ -0,0 +1,318 @@
+---
+title: Neo4J A* search
+date: 2025-09-14
+layout: post
+---
+
+Back in 2018, we used the <a href="https://neo4j.com/" class="external"
+target="_blank" rel="noopener noreferrer">Neo4J</a> graph database to track the
+movement of marine vessels. We were interested in the shortest path a ship
+could take through a network of about 13,000 route points. Algorithms based on
+graph theory, such as A* search, provide optimal solutions to such problems.
+In other words, the set of route points lends itself well to a model based on
+graphs.
+
+A graph is a finite set of vertices, and a subset of vertex pairs (edges).
+Edges can have weights. In the case of vessel tracking, the route points form
+the vertices of a graph; the routes between them, the edges; and the distances
+between them are the weights. For different reasons, people are interested in
+minimizing (or maximizing) the weight of a path through a set of vertices. For
+instance, we may want to find the shortest path between two ports.
+
+Given such a graph, an algorithm like Dijkstra's search could compute the
+shortest path between two vertices. In fact, this was the algorithm Neo4J
+shipped with at the time. One drawback of Dijkstra's algorithm is that it
+expands outward from the source in every direction, computing the shortest
+path to every vertex nearer than the destination before it finally reaches the
+destination vertex. The exhaustive nature of this search limited us to
+networks of about 4,000 route points.
+
+The following enhancement to Dijkstra's search, also known as the A* search,
+employs a heuristic to steer the search toward the destination, so that it
+terminates more quickly. In the case of our network of route points, which lie
+on the earth's surface, spherical distance is a good candidate for a heuristic:
+
+```
+package org.neo4j.graphalgo.impl;
+
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
+
+import org.neo4j.graphalgo.api.Graph;
+import org.neo4j.graphalgo.core.utils.ProgressLogger;
+import org.neo4j.graphalgo.core.utils.queue.IntPriorityQueue;
+import org.neo4j.graphalgo.core.utils.queue.SharedIntPriorityQueue;
+import org.neo4j.graphalgo.core.utils.traverse.SimpleBitSet;
+import org.neo4j.graphdb.Direction;
+import org.neo4j.graphdb.Node;
+import org.neo4j.kernel.internal.GraphDatabaseAPI;
+
+import com.carrotsearch.hppc.IntArrayDeque;
+import com.carrotsearch.hppc.IntDoubleMap;
+import com.carrotsearch.hppc.IntDoubleScatterMap;
+import com.carrotsearch.hppc.IntIntMap;
+import com.carrotsearch.hppc.IntIntScatterMap;
+
+/**
+ * A* shortest-path search between two nodes of a {@link Graph}.
+ *
+ * <p>The search keeps two cost maps: {@code gCosts} holds the cheapest known
+ * cost from the start node to each discovered node, and {@code fCosts} holds
+ * that cost plus a heuristic estimate of the remaining distance to the goal.
+ * The heuristic is the haversine (great-circle) distance between the latitude
+ * and longitude properties of two nodes, expressed in nautical miles.
+ */
+public class ShortestPathAStar extends Algorithm<ShortestPathAStar> {
+
+    private final GraphDatabaseAPI dbService;
+
+    /** Sentinel predecessor marking the beginning of a reconstructed path. */
+    private static final int PATH_END = -1;
+
+    /** Earth radius, in kilometres, used by the haversine heuristic. */
+    private static final double EARTH_RADIUS_KM = 6371;
+
+    /** Conversion factor from kilometres to nautical miles. */
+    private static final double KM_TO_NAUTICAL_MILES = 0.539957;
+
+    private Graph graph;
+    private final int nodeCount;
+    private IntDoubleMap gCosts;
+    private IntDoubleMap fCosts;
+    private double totalCost;
+    private IntPriorityQueue openNodes;
+    private IntIntMap path;
+    private IntArrayDeque shortestPath;
+    private SimpleBitSet closedNodes;
+    private final ProgressLogger progressLogger;
+
+    /** Value of {@link #getTotalCost()} when no path has been found. */
+    public static final double NO_PATH_FOUND = -1.0;
+
+    /**
+     * Creates a search over the given graph.
+     *
+     * @param graph     the graph to search
+     * @param dbService database handle used to read node coordinate properties
+     */
+    public ShortestPathAStar(
+            final Graph graph,
+            final GraphDatabaseAPI dbService) {
+
+        this.graph = graph;
+        this.dbService = dbService;
+
+        nodeCount = Math.toIntExact(graph.nodeCount());
+        gCosts = new IntDoubleScatterMap(nodeCount);
+        fCosts = new IntDoubleScatterMap(nodeCount);
+        // The open queue is built over fCosts, so nodes are ordered by
+        // estimated total cost.
+        openNodes = SharedIntPriorityQueue.min(
+                nodeCount,
+                fCosts,
+                Double.MAX_VALUE);
+        path = new IntIntScatterMap(nodeCount);
+        closedNodes = new SimpleBitSet(nodeCount);
+        shortestPath = new IntArrayDeque();
+        progressLogger = getProgressLogger();
+    }
+
+    /**
+     * Runs the search from {@code startNode} to {@code goalNode} and stores the
+     * resulting path and its total cost in this instance.
+     *
+     * <p>State from any previous run is discarded first. A path is reported
+     * only if a predecessor was recorded for the goal; since predecessors are
+     * recorded only for nodes reached through a relationship, a goal equal to
+     * the start yields an empty path and {@link #NO_PATH_FOUND}.
+     *
+     * @param startNode      original (Neo4J) id of the start node
+     * @param goalNode       original (Neo4J) id of the goal node
+     * @param propertyKeyLat name of the node property holding the latitude
+     * @param propertyKeyLon name of the node property holding the longitude
+     * @param direction      relationship direction to follow
+     * @return this instance
+     */
+    public ShortestPathAStar compute(
+            final long startNode,
+            final long goalNode,
+            final String propertyKeyLat,
+            final String propertyKeyLon,
+            final Direction direction) {
+
+        reset();
+
+        final int startNodeInternal =
+                graph.toMappedNodeId(startNode);
+        final double startNodeLat =
+                getNodeCoordinate(startNodeInternal, propertyKeyLat);
+        final double startNodeLon =
+                getNodeCoordinate(startNodeInternal, propertyKeyLon);
+
+        final int goalNodeInternal =
+                graph.toMappedNodeId(goalNode);
+        final double goalNodeLat =
+                getNodeCoordinate(goalNodeInternal, propertyKeyLat);
+        final double goalNodeLon =
+                getNodeCoordinate(goalNodeInternal, propertyKeyLon);
+
+        final double initialHeuristic =
+                computeHeuristic(startNodeLat,
+                        startNodeLon,
+                        goalNodeLat,
+                        goalNodeLon);
+
+        gCosts.put(startNodeInternal, 0.0);
+        fCosts.put(startNodeInternal, initialHeuristic);
+        // NOTE(review): the queue was created over fCosts, so the explicit
+        // cost passed here is presumably not what orders it -- confirm
+        // against SharedIntPriorityQueue.
+        openNodes.add(startNodeInternal, 0.0);
+
+        run(goalNodeInternal,
+                propertyKeyLat,
+                propertyKeyLon,
+                direction);
+
+        if (path.containsKey(goalNodeInternal)) {
+            totalCost = gCosts.get(goalNodeInternal);
+            // Walk the predecessor chain backwards from the goal; the start
+            // node has no predecessor, so the lookup falls back to PATH_END.
+            int node = goalNodeInternal;
+            while (node != PATH_END) {
+                shortestPath.addFirst(node);
+                node = path.getOrDefault(node, PATH_END);
+            }
+        }
+        return this;
+    }
+
+    /**
+     * Main search loop: repeatedly takes the head of the open queue, stops if
+     * it is the goal, and otherwise relaxes all of its relationships.
+     *
+     * @param goalNodeId     mapped (internal) id of the goal node
+     * @param propertyKeyLat name of the node property holding the latitude
+     * @param propertyKeyLon name of the node property holding the longitude
+     * @param direction      relationship direction to follow
+     */
+    private void run(
+            final int goalNodeId,
+            final String propertyKeyLat,
+            final String propertyKeyLon,
+            final Direction direction) {
+
+        final double goalLat =
+                getNodeCoordinate(goalNodeId, propertyKeyLat);
+        final double goalLon =
+                getNodeCoordinate(goalNodeId, propertyKeyLon);
+
+        // A single-node graph would otherwise make the progress fraction
+        // below divide by zero.
+        final int progressDenominator = Math.max(1, nodeCount - 1);
+
+        while (!openNodes.isEmpty() && running()) {
+            final int currentNodeId = openNodes.pop();
+            if (currentNodeId == goalNodeId) {
+                return;
+            }
+
+            closedNodes.put(currentNodeId);
+
+            final double currentNodeCost =
+                    this.gCosts.getOrDefault(
+                            currentNodeId,
+                            Double.MAX_VALUE);
+
+            graph.forEachRelationship(
+                    currentNodeId,
+                    direction,
+                    (source, target, relationshipId, weight) -> {
+                        final double neighbourLat =
+                                getNodeCoordinate(target, propertyKeyLat);
+                        final double neighbourLon =
+                                getNodeCoordinate(target, propertyKeyLon);
+                        final double heuristic =
+                                computeHeuristic(
+                                        neighbourLat,
+                                        neighbourLon,
+                                        goalLat,
+                                        goalLon);
+
+                        updateCosts(
+                                source,
+                                target,
+                                weight + currentNodeCost,
+                                heuristic);
+
+                        // NOTE(review): the target is enqueued whenever it is
+                        // not closed, even if it is already in the queue or
+                        // its cost did not improve -- confirm the queue
+                        // tolerates repeated adds.
+                        if (!closedNodes.contains(target)) {
+                            openNodes.add(target, 0);
+                        }
+                        return true;
+                    });
+
+            progressLogger.logProgress(
+                    (double) currentNodeId / progressDenominator);
+        }
+    }
+
+    /**
+     * Haversine great-circle distance between two coordinates.
+     *
+     * @param lat1 latitude of the first point, in degrees
+     * @param lon1 longitude of the first point, in degrees
+     * @param lat2 latitude of the second point, in degrees
+     * @param lon2 longitude of the second point, in degrees
+     * @return the distance in nautical miles
+     */
+    private double computeHeuristic(
+            final double lat1,
+            final double lon1,
+            final double lat2,
+            final double lon2) {
+
+        final double latDistance = Math.toRadians(lat2 - lat1);
+        final double lonDistance = Math.toRadians(lon2 - lon1);
+        final double sinHalfLat = Math.sin(latDistance / 2);
+        final double sinHalfLon = Math.sin(lonDistance / 2);
+        final double a = sinHalfLat
+                * sinHalfLat
+                + Math.cos(Math.toRadians(lat1))
+                * Math.cos(Math.toRadians(lat2))
+                * sinHalfLon
+                * sinHalfLon;
+        // Rounding can push a marginally above 1 for near-antipodal points;
+        // clamp so the square root never sees a negative argument (NaN).
+        final double c = 2
+                * Math.atan2(Math.sqrt(a), Math.sqrt(Math.max(0.0, 1 - a)));
+        return EARTH_RADIUS_KM * c * KM_TO_NAUTICAL_MILES;
+    }
+
+    /**
+     * Reads a coordinate property of a node from the database.
+     *
+     * @param nodeId         mapped (internal) id of the node
+     * @param coordinateType name of the property to read
+     * @return the property value as a {@code double}
+     */
+    private double getNodeCoordinate(
+            final int nodeId,
+            final String coordinateType) {
+
+        final long neo4jId = graph.toOriginalNodeId(nodeId);
+        final Node node = dbService.getNodeById(neo4jId);
+        // Go through Number so that coordinates stored as Integer, Long or
+        // Float are accepted as well; a direct (double) cast of the Object
+        // only works for Double values.
+        return ((Number) node.getProperty(coordinateType)).doubleValue();
+    }
+
+    /**
+     * Records a cheaper route to {@code target}, if {@code newCost} is one.
+     *
+     * @param source    node the route arrives from
+     * @param target    node being reached
+     * @param newCost   cost from the start node to {@code target} via {@code source}
+     * @param heuristic estimated remaining cost from {@code target} to the goal
+     */
+    private void updateCosts(
+            final int source,
+            final int target,
+            final double newCost,
+            final double heuristic) {
+
+        final double oldCost =
+                gCosts.getOrDefault(target, Double.MAX_VALUE);
+
+        if (newCost < oldCost) {
+            gCosts.put(target, newCost);
+            fCosts.put(target, newCost + heuristic);
+            path.put(target, source);
+        }
+    }
+
+    /** Clears all per-run state and marks the result as "no path found". */
+    private void reset() {
+        closedNodes.clear();
+        openNodes.clear();
+        gCosts.clear();
+        fCosts.clear();
+        path.clear();
+        shortestPath.clear();
+        totalCost = NO_PATH_FOUND;
+    }
+
+    /**
+     * Streams the nodes of the computed path in order from start to goal.
+     *
+     * @return one {@link Result} per path node, carrying the original node id
+     *         and the cost recorded for reaching that node
+     */
+    public Stream<Result> resultStream() {
+        return StreamSupport.stream(
+                shortestPath.spliterator(), false)
+                .map(cursor -> new Result(
+                        graph.toOriginalNodeId(cursor.value),
+                        gCosts.get(cursor.value)));
+    }
+
+    /**
+     * @return the computed path as mapped (internal) node ids; empty if no
+     *         path was found
+     */
+    public IntArrayDeque getFinalPath() {
+        return shortestPath;
+    }
+
+    /**
+     * @return the cost of the computed path, or {@link #NO_PATH_FOUND}
+     */
+    public double getTotalCost() {
+        return totalCost;
+    }
+
+    /**
+     * @return the number of nodes on the computed path
+     */
+    public int getPathLength() {
+        return shortestPath.size();
+    }
+
+    @Override
+    public ShortestPathAStar me() {
+        return this;
+    }
+
+    /**
+     * Drops the references to the graph and to all search state. The instance
+     * must not be used for further computation afterwards.
+     *
+     * @return this instance
+     */
+    @Override
+    public ShortestPathAStar release() {
+        graph = null;
+        gCosts = null;
+        fCosts = null;
+        openNodes = null;
+        path = null;
+        shortestPath = null;
+        closedNodes = null;
+        return this;
+    }
+
+    /** One node of the computed path together with its cost. */
+    public static class Result {
+
+        /**
+         * the neo4j node id
+         */
+        public final Long nodeId;
+
+        /**
+         * cost to reach the node from startNode
+         */
+        public final Double cost;
+
+        public Result(Long nodeId, Double cost) {
+            this.nodeId = nodeId;
+            this.cost = cost;
+        }
+    }
+}
+```
+
+The heuristic function is domain-specific. If chosen wisely, it can
+significantly speed up the search. In our case, we achieved a 300x speedup,
+enabling us to expand our search from 4,000 to 13,000 route points. The <a
+href="https://github.com/neo4j-contrib/neo4j-graph-algorithms/releases/tag/3.4.0.0"
+class="external" target="_blank" rel="noopener noreferrer">v3.4.0</a> release
+of the Neo4J graph algorithms library shipped with the A* search algorithm.
+