在本文中,我们了解了 Tribuo 及其功能。然后,我们概述了 Tribuo 支持的一些机器学习算法。最后,我们训练了一个模型,使用回归算法来预测葡萄酒的质量。
机器学习 (ML) 和人工智能 (AI) 正在通过使系统能够从数据中学习并做出智能预测来重塑软件开发。
Tribuo是由 Oracle 开发的可立即投入生产的开源机器学习库。它简化了构建和部署强大的机器学习模型。与 Weka 和Deeplearning4j一样,Tribuo 支持各种机器学习任务,并且可以轻松地与 Java 应用程序集成。
在本教程中,我们将学习Tribuo中可用的不同机器学习算法。此外,我们将使用UCI红酒质量数据集构建一个回归模型来预测葡萄酒质量。
什么是Tribo?
Tribuo 是一个以 Java 为中心的机器学习库,支持:
此外,它是强类型的,这意味着它强制正确的输入和输出类型,这有助于防止运行时错误并确保一致的模型开发。
它支持以开放神经网络交换 (ONNX) 格式导入和导出模型,允许与其他 ML 框架(如 TensorFlow 和 PyTorch)集成。
Tribuo 的另一个突出特点是来源追踪。此功能记录有关数据集、模型参数和训练配置的元数据,从而提高透明度和可重复性。
随着人工智能在企业 Java 应用程序中不断找到自己的位置,Tribuo 提供了一个实用的工具包,可以将智能行为直接嵌入到基于 Java 的系统中。
支持的机器学习算法
Tribuo 支持各种 ML 任务,包括:
- 分类:预测离散的类别或标签。例如,预测一支足球队的胜负,或者根据质量阈值将葡萄酒分为好酒和坏酒。
- 回归:它预测连续值,例如葡萄酒质量评分或患者的胆固醇水平。
- 聚类:识别未标记数据中的组。例如,它可以根据酸度和酒精含量等化学特性对葡萄酒进行分组,而无需知道其质量评分。
建立Tribuo项目
让我们通过建立回归模型来预测葡萄酒质量,看看 Tribuo 的实际作用。
首先,让我们将Tribuo 依赖项添加到我们的pom.xml中:
<dependency> <groupId>org.tribuo</groupId> <artifactId>tribuo-all</artifactId> <version>4.3.2</version> </dependency>
|
tribuo -all依赖项提供使用特定训练算法加载和训练数据集的类。
另外,我们下载UCI 葡萄酒质量数据集 Red并将其放在src/main/resources/dataset目录中。该数据集包含 11 个物理化学特征,例如酸度和酒精含量:
UCI红酒数据集
质量列是适合回归的连续数值。
最后,让我们创建一个名为WineQualityRegression 的类:
public class WineQualityRegression { }
|
在后续章节中,课程将介绍不同的逻辑来训练和保存我们的模型以供将来使用。
类级变量
接下来,让我们定义以下类级变量:
public static final String DATASET_PATH = "src/main/resources/dataset/winequality-red.csv"; public static final String MODEL_PATH = "src/main/resources/model/winequality-red-regressor.ser"; public Model<Regressor> model; public Trainer<Regressor> trainer; public Dataset<Regressor> trainSet; public Dataset<Regressor> testSet;
|
在上面的代码中,我们定义了数据集的路径以及训练模型的保存或加载位置。
接下来,我们定义四个变量,分别代表以下内容:
- 模型——存储预测模型的类
- 训练器——可以训练预测模型的界面
- 数据集——包含用于训练的一组数据的类
此外,我们明确指定模型输出类型为Regressor。
加载和分割数据集
让我们定义一个方法来加载和分割数据集:
void createDatasets() throws Exception { RegressionFactory regressionFactory = new RegressionFactory(); CSVLoader<Regressor> csvLoader = new CSVLoader<>(';', CSVIterator.QUOTE, regressionFactory); DataSource<Regressor> dataSource = csvLoader.loadDataSource(Paths.get(DATASET_PATH), "quality"); TrainTestSplitter<Regressor> dataSplitter = new TrainTestSplitter<>(dataSource, 0.7, 1L); trainSet = new MutableDataset<>(dataSplitter.getTrain()); testSet = new MutableDataset<>(dataSplitter.getTest()); }
|
在这里,我们使用CSVLoader解析以分号分隔的 CSV 文件,并准备进行回归分析。RegressionFactory创建回归输出,并指定目标变量quality为连续变量。DataSource 保存解析后的数据。
然后,为了评估泛化能力和模型性能,我们使用TrainTestSplitter将数据集分成 70% 的训练集和 30% 的测试集。
训练回归模型
由于葡萄酒质量得分是一个数值,让我们使用分类和回归树(CART)作为基础学习器来训练模型,以预测葡萄酒质量:
void createTrainer() { CARTRegressionTrainer subsamplingTree = new CARTRegressionTrainer( Integer.MAX_VALUE, AbstractCARTTrainer.MIN_EXAMPLES, 0.001f, 0.7f, new MeanSquaredError(), Trainer.DEFAULT_SEED ); trainer = new RandomForestTrainer<>(subsamplingTree, new AveragingCombiner(), 10); model = trainer.train(trainSet); }
|
在上述方法中,CARTRegressionTrainer配置了一个不设最大深度、每次分割至少包含 6 个样本、并以均方误差 (MSE) 作为分割标准的决策树。然后,RandomForestTrainer会组合 10 棵 CART 决策树,并使用AveragingCombiner对其预测结果求平均值。
train ()方法在trainSet数据集上训练模型,生成用于预测葡萄酒质量分数的Model 。
评估
接下来,让我们使用RegressionEvaluator计算与模型相关的数据集的指标来评估回归模型的性能:
void evaluate(Model<Regressor> model, String datasetName, Dataset<Regressor> dataset) { RegressionEvaluator evaluator = new RegressionEvaluator(); RegressionEvaluation evaluation = evaluator.evaluate(model, dataset); Regressor dimension0 = new Regressor("DIM-0", Double.NaN); log.info("MAE: " + evaluation.mae(dimension0)); log.info("RMSE: " + evaluation.rmse(dimension0)); log.info("R^2: " + evaluation.r2(dimension0)); }
|
RegressionEvaluator评估模型在数据集上的性能。然后,我们将 MAE(平均绝对误差)、RMSE(均方根误差)和 R^2(判定系数)记录到控制台。
接下来,让我们使用evaluate()方法来评估我们的模型和数据集:
void evaluateModels() throws Exception { log.info("Training model"); evaluate(model, "trainSet", trainSet); log.info("Testing model"); evaluate(model, "testSet", testSet); }
|
这是我们执行程序时针对模型的训练和测试集的评估:
07:10:14.405 [main] INFO tribuo.WineQualityRegression - Training model 07:10:14.406 [main] INFO tribuo.WineQualityRegression - Results for trainSet--------------------- 07:10:14.537 [main] INFO tribuo.WineQualityRegression - MAE: 0.25025410332970005 07:10:14.537 [main] INFO tribuo.WineQualityRegression - RMSE: 0.3422557198486092 07:10:14.538 [main] INFO tribuo.WineQualityRegression - R^2: 0.8190947891297661 07:10:14.538 [main] INFO tribuo.WineQualityRegression - Testing model 07:10:14.540 [main] INFO tribuo.WineQualityRegression - Results for testSet--------------------- 07:10:14.565 [main] INFO tribuo.WineQualityRegression - MAE: 0.48711029366796743 07:10:14.565 [main] INFO tribuo.WineQualityRegression - RMSE: 0.6584973595553575 07:10:14.565 [main] INFO tribuo.WineQualityRegression - R^2: 0.3444460580874339
|
MAE表示训练集和测试集的预测值与实际值之间的绝对差。RMSE 度量 的是预测值与实际值平方差平均值的平方根。此外,R^2 表示模型对训练集和测试集方差的解释能力。
MAE和RMSE值越低, R^2值越高,表明预测性能越好。
保存模型
最后,让我们将模型保存为文件以供重复使用:
void saveModel() throws Exception { File modelFile = new File(MODEL_PATH); try (ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(modelFile))) { objectOutputStream.writeObject(model); } }
|
在上面的代码中,我们使用ObjectOutputStream类将训练好的模型序列化为文件。通过将模型保存到文件中,我们可以重复使用该模型进行后续预测,而无需重新训练。
调用方法
现在,让我们在main()方法中调用之前创建的方法:
public static void main(String[] args) throws Exception { WineQualityRegression wineQualityRegression = new WineQualityRegression(); wineQualityRegression.createDatasets(); wineQualityRegression.createTrainer(); wineQualityRegression.evaluateModels(); wineQualityRegression.saveModel(); }
|
这将编译代码并将模型保存在指定的目录中。
使用模型
让我们创建一个名为WinePredictor的新类,并在main()方法中加载保存的模型:
class WineQualityPredictor { private static final Logger log = LoggerFactory.getLogger(WineQualityPredictor.class); public static void main(String[] args) throws IOException, ClassNotFoundException { File modelFile = new File("src/main/resources/model/winequality-red-regressor.ser"); Model<Regressor> loadedModel = null; try (ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(modelFile))) { loadedModel = (Model<Regressor>) objectInputStream.readObject(); } }
|
我们知道,Tribuo 对类型敏感,因此我们指定了模型的类型,在本例中为回归器。
另外,我们通过创建ObjectInputStream并将模型路径作为参数传递来加载模型。
然后,让我们创建一个ArrayExample对象来表示单个葡萄酒样本:
ArrayExample<Regressor> wineAttribute = new ArrayExample<Regressor>(new Regressor("quality", Double.NaN)); wineAttribute.add("fixed acidity", 7.4f); wineAttribute.add("volatile acidity", 0.7f); wineAttribute.add("citric acid", 0.47f); wineAttribute.add("residual sugar", 1.9f); wineAttribute.add("chlorides", 0.076f); wineAttribute.add("free sulfur dioxide", 11.0f); wineAttribute.add("total sulfur dioxide", 34.0f); wineAttribute.add("density", 0.9978f); wineAttribute.add("pH", 3.51f); wineAttribute.add("sulphates", 0.56f); wineAttribute.add("alcohol", 9.4f);
|
最后,让我们使用Prediction类进行预测:
Prediction<Regressor> prediction = loadedModel.predict(wineAttribute); double predictQuality = prediction.getOutput().getValues()[0]; log.info("Predicted wine quality: " + predictQuality);
|
以下是预测的葡萄酒品质:
07:31:05.772 [main] INFO tribuo.WineQualityPredictor - Predicted wine quality: 5.028163673540464