如何在Python中使用Java类? - codecentric

21-11-21 banq

让 Java 和 Python 携手合作非常容易,这在开发原型时尤其有价值。

我们从一个实现 Snake 游戏逻辑的 Java 程序开始:场上总有一块食物。每当蛇到达食物时,它就会生长并出现新的食物。如果蛇咬自己或咬墙,游戏结束。

我们的目标是训练一个神经网络来控制蛇,让蛇在犯错和游戏结束之前吃掉尽可能多的食物。首先,我们需要一个代表游戏当前状态的张量。它充当我们神经网络的输入,以便网络可以使用它来预测下一步要采取的最佳步骤。为了让这个例子简单,我们的张量只是一个包含七个元素的向量,可以是 1 或 0:前四个表示食物是在蛇的右边、左边、前面还是后面,接下来的三个条目表示如果蛇头的左边、前面和右边的田地都被一堵墙或蛇的尾巴挡住了。

我们示例的完整源代码可在 GitHub 上找到。

使用JPype导入Java类即可:

import jpype
import jpype.imports
from jpype.types import *
 
# launch the JVM
jpype.startJVM(classpath=['../target/autosnake-1.0-SNAPSHOT.jar'])
 
# import the Java module
from me.schawe.autosnake import SnakeLogic
 
# construct an object of the `SnakeLogic` class ...
width, height = 10, 10
snake_logic = SnakeLogic(width, height)
 
# ... and call a method on it
print(snake_logic.trainingState())

JPype 在与 Python 解释器相同的进程中启动 JVM,并让它们使用 Java 本机接口 (JNI) 进行通信。

其他选项:

  • Jython直接在 JVM 中执行 Python 解释器,这样 Python 和 Java 就可以非常高效地使用相同的数据结构。但这对使用原生 Python 库有一些缺点——因为我们将使用numpy和tensorflow,这对我们来说不是一个选择。
  • Py4J处于频谱的另一侧。它在 Java 代码中启动一个套接字,它可以通过它与 Python 程序进行通信。优点是任意数量的 Python 进程可以连接到一个长时间运行的 Java 进程——或者相反,一个 Python 进程可以连接到多个 JVM,甚至通过网络。缺点是套接字通信的开销较大。

 

在 Java 中加载模型

使用deeplearning4j将训练好的模型加载到 Java 中……

// https://deeplearning4j.konduit.ai/deeplearning4j/how-to-guides/keras-import
public class Autopilot {
    ComputationGraph model;
 
    public Autopilot(String pathToModel) {
        try {
            model = KerasModelImport.importKerasModelAndWeights(pathToModel, false);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }
 
    // infer the next move from the given state
    public int nextMove(boolean[] state) {
        INDArray input = Nd4j.create(state).reshape(1, state.length);
        INDArray output = model.output(input)[0];
 
        int action = output.ravel().argMax().getInt(0);
 
        return action;
    }
}
调用:
public class SnakeLogic {
    Autopilot autopilot = new Autopilot("path/to/model.h5");
 
    public void update() {
        int action = autopilot.nextMove(trainingState());
        turnRelative(action);
 
        // rest of the update omitted
    }
 
    // further methods omitted
}

 

1
猜你喜欢