Java中使用Stream实现6种算法教程

在算法问题解决领域,效率和优雅常常齐头并进。 Java 作为最广泛使用的编程语言之一,提供了各种工具和库来应对此类挑战。 Java 8 中引入的 Stream API 就是这样一个强大的工具,它提供了一种处理元素集合的功能方法。

1、Java Stream API:矩阵乘法
矩阵乘法是线性代数中的基本运算,在计算机图形学、物理模拟和机器学习等各个领域都有广泛的应用。在 Java 中,高效执行矩阵乘法对于优化处理大型数据集的应用程序的性能至关重要。幸运的是,随着 Java 8 中 Stream API 的引入,可以简化矩阵乘法的过程,使代码简洁,并且具有潜在的可扩展性。

在深入研究实现细节之前,我们先简单回顾一下矩阵乘法的过程。给定两个矩阵 A 和 B,其中 A 的维度为 mxn,B 的维度为 nxp,所得矩阵 C 的维度为 mx p。

结果矩阵 C 的第 (i, j) 个元素计算为矩阵 A 第 i 行和矩阵 B 第 j 列的点积:

C[j] = A[0] * B[0][j] + A[1] * B[1][j] + ... + A[n -1] * B[n-1][j]

矩阵乘法的传统方法:
传统的矩阵乘法方法涉及嵌套循环来迭代矩阵的元素并执行必要的计算。以下是该算法的基本概要:

public static int[][] multiplyMatrices(int[][] A, int[][] B) {
    int m = A.length;
    int n = A[0].length;
    int p = B[0].length;
    int[][] C = new int[m][p];

    for (int i = 0; i < m; i++) {
        for (int j = 0; j < p; j++) {
            int sum = 0;
            for (int k = 0; k < n; k++) {
                sum += A[i][k] * B[k][j];
            }
            C[i][j] = sum;
        }
    }
    return C;
}

虽然这种方法效果很好,但它的表达能力不是特别强,而且可能很冗长,尤其是对于较大的矩阵。

使用 Stream API 进行矩阵乘法:
随着 Java 8 中 Stream API 的引入,我们可以利用其函数式编程特性来简化矩阵乘法过程。我们可以利用流来并行计算并更简洁地表达乘法逻辑。

以下是我们如何使用 Stream API 实现矩阵乘法:

import java.util.Arrays;

public static int[][] multiplyMatrices(int[][] A, int[][] B) {
    int m = A.length;
    int n = A[0].length;
    int p = B[0].length;

    return Arrays.stream(A)
            .parallel()
            .map(row -> Arrays.stream(transpose(B))
                    .mapToInt(col -> dotProduct(row, col))
                    .toArray())
            .toArray(int[][]::new);
}

private static int[][] transpose(int[][] matrix) {
    int m = matrix.length;
    int n = matrix[0].length;
    int[][] transposed = new int[n][m];
    for (int i = 0; i < m; i++) {
        for (int j = 0; j < n; j++) {
            transposed[j][i] = matrix[i][j];
        }
    }
    return transposed;
}

private static int dotProduct(int[] a, int[] b) {
    return IntStream.range(0, a.length)
            .map(i -> a[i] * b[i])
            .sum();
}

在此实现中,我们使用流来并行计算结果矩阵的每一行。我们将矩阵 A 的每一行与转置矩阵 B 的列映射到点积流。最后,我们将结果收集到表示结果矩阵 C 的二维数组中。


2、Java Stream API:查找具有最大总和的子数组
利用 Java Stream API 的强大功能:查找总和最大的子数组.

在深入研究解决方案之前,让我们首先了解当前的问题。给定一个整数数组,我们的目标是找到总和最大的连续子数组(连续元素的子序列)。这个问题也称为“最大子数组问题”,有多种解决方案,包括暴力破解、动态规划和分而治之算法。在这里,我们将重点使用 Java Stream API 来解决这个问题,它提供了一种简洁而富有表现力的方法来操作集合。

方法
使用 Stream API 解决此问题的关键思想是将数组分解为所有可能的子数组,计算它们的总和,然后找到其中的最大总和。这可以通过使用诸如map、reduce和max之类的流操作来实现。

执行
让我们深入了解一下实现:

import java.util.Arrays;

public class MaximumSubarrayUsingStream {

    public static void main(String[] args) {
        int[] array = { -2, 1, -3, 4, -1, 2, 1, -5, 4 }; // Example array
        int[] maxSubarray = findMaxSubarray(array);
        System.out.println(
"Maximum subarray: " + Arrays.toString(maxSubarray));
    }

    public static int[] findMaxSubarray(int[] array) {
        return Arrays.stream(array)
                .mapToObj(start -> Arrays.stream(array)
                        .skip(start)
                        .mapToObj(end -> Arrays.copyOfRange(array, start, end + 1)))
                .flatMap(subArrays -> subArrays)
                .max((arr1, arr2) -> Arrays.stream(arr1).sum() - Arrays.stream(arr2).sum())
                .orElse(new int[0]);
    }
}


解释

  • - 我们首先使用 Arrays.stream(array) 迭代数组中的每个元素。
  • - 对于每个元素,我们使用 mapToObj 生成从该元素开始的子数组流。
  • - 在嵌套流中,我们使用skip(start)跳过当前元素之前的元素,并生成以每个后续元素结尾的子数组。
  • - 使用 flatMap 将生成的子数组展平为单个流。
  • - 最后,我们通过使用 max 比较每个子数组的和来找到最大子数组。


3、Java Stream API:查找最长公共子序列
使用 Java Stream API 查找两个字符串之间的最长公共子序列

给定两个字符串,我们需要找到最长公共子序列的长度和实际子序列本身。例如,对于字符串“ABCBDAB”和“BDCAB”,最长公共子序列是“BCAB”,长度为4。

方法:
我们将使用动态规划方法来解决这个问题。我们将创建一个二维数组 dp[][],其中 dp[j] 将存储第一个字符串的前 i 个字符和第二个字符串的前 j 个字符之间的最长公共子序列的长度。

然后我们将迭代两个字符串的字符。如果字符相等,我们将最长公共子序列的长度加1。否则,我们将取前一个字符的最长公共子序列长度的最大值。

最后,我们将通过dp[][]数组回溯构造最长公共子序列。

使用 Stream API 的 Java 代码:

import java.util.stream.IntStream;

public class LongestCommonSubsequence {

    public static String findLCS(String s1, String s2) {
        int m = s1.length();
        int n = s2.length();

        int[][] dp = new int[m + 1][n + 1];

        IntStream.rangeClosed(1, m).forEach(i ->
                IntStream.rangeClosed(1, n).forEach(j -> {
                    if (s1.charAt(i - 1) == s2.charAt(j - 1)) {
                        dp[i][j] = dp[i - 1][j - 1] + 1;
                    } else {
                        dp[i][j] = Math.max(dp[i - 1][j], dp[i][j - 1]);
                    }
                }));

        StringBuilder lcs = new StringBuilder();
        int i = m, j = n;
        while (i > 0 && j > 0) {
            if (s1.charAt(i - 1) == s2.charAt(j - 1)) {
                lcs.insert(0, s1.charAt(i - 1));
                i--;
                j--;
            } else if (dp[i - 1][j] > dp[i][j - 1]) {
                i--;
            } else {
                j--;
            }
        }

        return lcs.toString();
    }

    public static void main(String[] args) {
        String s1 = "ABCBDAB";
        String s2 =
"BDCAB";

        System.out.println(
"Longest Common Subsequence: " + findLCS(s1, s2));
    }
}


解释:

  • - 我们首先创建一个二维数组 dp[][] 来存储最长公共子序列的长度。
  • - 然后我们使用嵌套的 IntStream 迭代两个字符串的每个字符并计算最长公共子序列的长度。
  • - 最后,我们回溯 dp[][] 数组来构造最长公共子序列。

4、Java Stream API:实现 Prim 算法
Prim 算法是一种贪心算法,用于为带权无向图找到最小生成树。它从任意节点开始,并通过添加将树连接到新节点的最便宜的边来增长生成树。重复此过程,直到所有节点都包含在树中。

在深入实现之前,我们先简单回顾一下 Prim 算法所涉及的步骤:
1. 选择任意一个节点作为起点。
2. 初始化一个优先级队列来存储边,键为边的权重。
3. 虽然仍有节点需要包含在树中:
   A。找到将树中的节点连接到树外的节点的权重最小的边。
   b.将此边添加到树中。
   C。将新添加的节点标记为包含在树中。


实施 Prim 算法
为了实现 Prim 算法,我们将图表示为边列表,其中每条边都是一个包含源节点、目标节点和边权重的元组。我们将使用 Set 来跟踪树中包含的节点,并使用 PriorityQueue 来存储按权重排序的边。

import java.util.*;
import java.util.stream.*;

class Edge {
    int src, dest, weight;

    Edge(int src, int dest, int weight) {
        this.src = src;
        this.dest = dest;
        this.weight = weight;
    }
}

public class PrimAlgorithm {

    public static List<Edge> primMST(List<Edge> edges, int numVertices) {
        Set<Integer> visited = new HashSet<>();
        PriorityQueue<Edge> pq = new PriorityQueue<>(Comparator.comparingInt(e -> e.weight));
        List<Edge> mst = new ArrayList<>();

        // Start from the first node
        int startNode = 0;
        visited.add(startNode);

       
// Add all edges from the starting node to the priority queue
        pq.addAll(edges.stream()
                       .filter(e -> e.src == startNode || e.dest == startNode)
                       .collect(Collectors.toList()));

        while (!pq.isEmpty() && visited.size() < numVertices) {
            Edge minEdge = pq.poll();
            int nextNode = visited.contains(minEdge.src) ? minEdge.dest : minEdge.src;

            if (!visited.contains(nextNode)) {
                visited.add(nextNode);
                mst.add(minEdge);

                pq.addAll(edges.stream()
                               .filter(e -> e.src == nextNode || e.dest == nextNode)
                               .collect(Collectors.toList()));
            }
        }

        return mst;
    }

    public static void main(String[] args) {
        List<Edge> edges = Arrays.asList(
                new Edge(0, 1, 2),
                new Edge(0, 2, 3),
                new Edge(1, 2, 1),
                new Edge(1, 3, 4),
                new Edge(2, 4, 5),
                new Edge(3, 4, 6)
        );

        List<Edge> mst = primMST(edges, 5);

        System.out.println(
"Minimum Spanning Tree:");
        for (Edge edge : mst) {
            System.out.println(edge.src +
" - " + edge.dest + ": " + edge.weight);
        }
    }
}


在这个实现中,我们从第一个节点开始,重复添加连接树中节点和树外节点的最小权重边,直到所有节点都包含在树中。 PriorityQueue 确保我们始终有效地选择最小权重边。


5、Java Stream API:实现 Dijkstra 算法
Dijkstra 算法是一种经典算法,用于查找图中从单个源节点到所有其他节点的最短路径。

Dijkstra 算法的工作原理是迭代选择距源节点最短距离的节点并更新到其邻居节点的最短距离。它维护一个优先级队列(或最小堆)以有效地选择下一个要访问的节点。

我们将使用邻接列表来表示该图,其中每个节点都与其相邻节点及其相应的边权重的列表相关联。

import java.util.*;

class Graph {
    private final Map<Integer, List<Edge>> adjacencyList = new HashMap<>();

    public void addEdge(int source, int destination, int weight) {
        adjacencyList.computeIfAbsent(source, k -> new ArrayList<>()).add(new Edge(destination, weight));
        adjacencyList.computeIfAbsent(destination, k -> new ArrayList<>()).add(new Edge(source, weight)); // for undirected graph
    }

    public List<Edge> getNeighbors(int node) {
        return adjacencyList.getOrDefault(node, Collections.emptyList());
    }

    static class Edge {
        int destination;
        int weight;

        public Edge(int destination, int weight) {
            this.destination = destination;
            this.weight = weight;
        }
    }
}


实施 Dijkstra 算法
我们将使用 PriorityQueue 来维护与源节点的最短距离尚未最终确定的节点集。我们还将使用 Map 来跟踪从源节点到图中每个节点的最短距离。

import java.util.*;

class Dijkstra {
    public static Map<Integer, Integer> shortestPaths(Graph graph, int source) {
        Map<Integer, Integer> distances = new HashMap<>();
        PriorityQueue<Integer> pq = new PriorityQueue<>(Comparator.comparingInt(distances::get));
        Set<Integer> visited = new HashSet<>();

        distances.put(source, 0);
        pq.add(source);

        while (!pq.isEmpty()) {
            int current = pq.poll();
            if (visited.contains(current)) {
                continue;
            }
            visited.add(current);

            for (Graph.Edge neighbor : graph.getNeighbors(current)) {
                int newDistance = distances.get(current) + neighbor.weight;
                if (!distances.containsKey(neighbor.destination) || newDistance < distances.get(neighbor.destination)) {
                    distances.put(neighbor.destination, newDistance);
                    pq.add(neighbor.destination);
                }
            }
        }

        return distances;
    }
}

使用 Java Stream API 实现简洁
现在,让我们使用Java Stream API 重写shortestPaths 方法,使其更加简洁和富有表现力。

import java.util.*;

class Dijkstra {
    public static Map<Integer, Integer> shortestPaths(Graph graph, int source) {
        Map<Integer, Integer> distances = new HashMap<>();
        PriorityQueue<Integer> pq = new PriorityQueue<>(Comparator.comparingInt(distances::get));
        Set<Integer> visited = new HashSet<>();

        distances.put(source, 0);
        pq.add(source);

        while (!pq.isEmpty()) {
            int current = pq.poll();
            if (visited.contains(current)) {
                continue;
            }
            visited.add(current);

            graph.getNeighbors(current).stream()
                    .filter(neighbor -> !visited.contains(neighbor.destination))
                    .forEach(neighbor -> {
                        int newDistance = distances.get(current) + neighbor.weight;
                        distances.put(neighbor.destination, Math.min(distances.getOrDefault(neighbor.destination, Integer.MAX_VALUE), newDistance));
                        pq.add(neighbor.destination);
                    });
        }

        return distances;
    }
}


使用 Stream API,我们消除了显式循环并将其替换为流操作,从而简化了循环内的代码。这使得代码更具可读性并保持了编程的函数式风格。


6、Java Stream API:实现深度优先搜索算法
使用 Java Stream API 实现图遍历的深度优先搜索算法

图遍历是一个基本的算法问题,涉及访问图的所有节点。深度优先搜索 (DFS) 是遍历图的方法之一,从特定节点开始,然后以深度运动递归访问其邻居,然后再移动到下一个邻居。

首先,让我们使用邻接列表定义一个简单的图形表示。我们将使用“Map”来表示图,其中键是节点,值是相邻节点的列表。

import java.util.*;

public class Graph {
    private Map<Integer, List<Integer>> adjacencyList;

    public Graph(int vertices) {
        adjacencyList = new HashMap<>();
        for (int i = 0; i < vertices; i++) {
            adjacencyList.put(i, new LinkedList<>());
        }
    }

    public void addEdge(int source, int destination) {
        adjacencyList.get(source).add(destination);
    }

    public List<Integer> getNeighbors(int vertex) {
        return adjacencyList.get(vertex);
    }

    public int getVertexCount() {
        return adjacencyList.size();
    }
}


深度优先搜索算法
现在,让我们使用 Java Stream API 来实现 DFS 算法。 DFS的基本思想是从给定的节点开始,将其标记为已访问,然后递归地访问其所有尚未访问过的邻居。

import java.util.*;

public class DepthFirstSearch {

    public static void main(String[] args) {
        Graph graph = new Graph(5);
        graph.addEdge(0, 1);
        graph.addEdge(0, 2);
        graph.addEdge(1, 3);
        graph.addEdge(1, 4);
        graph.addEdge(2, 4);

        dfs(graph, 0);
    }

    public static void dfs(Graph graph, int start) {
        Set<Integer> visited = new HashSet<>();
        dfsRecursive(graph, start, visited);
    }

    private static void dfsRecursive(Graph graph, int current, Set<Integer> visited) {
        visited.add(current);
        System.out.println("Visiting node: " + current);

        graph.getNeighbors(current).stream()
            .filter(neighbor -> !visited.contains(neighbor))
            .forEach(neighbor -> dfsRecursive(graph, neighbor, visited));
    }
}


在此实现中,我们从节点 0 开始 DFS 遍历。“dfsRecursive”方法将当前节点标记为已访问,打印其值,然后在当前节点的所有未访问邻居上递归调用自身。