写在前面

准备近期将微软的machinelearning-samples翻译成中文,水平有限,如有错漏,请大家多多指正。

如果有朋友对此感兴趣,可以加入我:https://github.com/feiyun0112/machinelearning-samples.zh-cn

鸢尾花分类

ML.NET 版本 API 类型 状态 应用程序类型 数据类型 场景 机器学习任务 算法
v0.7 动态 API 最新版本 控制台应用程序 .txt 文件 鸢尾花分类 多类分类 Sdca Multi-class

在这个介绍性示例中,您将看到如何使用ML.NET来预测鸢尾花的类型。 在机器学习领域,这种类型的预测被称为多类分类

问题

这个问题集中在根据花瓣长度,花瓣宽度等花的参数预测鸢尾花(setosa,versicolor或virginica)的类型。

为了解决这个问题,我们将建立一个ML模型,它有4个输入参数:

  • petal length
  • petal width
  • sepal length
  • sepal width

并预测该花属于哪种鸢尾花类型:

  • setosa
  • versicolor
  • virginica

确切地说,模型将返回花属于每个类型的概率。

ML 任务 - 多类分类

多类分类的广义问题是将项目分类为三个或更多类别中的一个。 (将项目分类为两个类别之一称为二元分类)。

多类分类的其他例子包括:

  • 手写数字识别:预测图像中包含10个数字(0~9)。
  • 问题标记:预测问题属于哪个类别(UI,后端,文档)。
  • 根据患者的测试结果预测疾病阶段。

所有这些例子的共同特点是我们要预测的参数可以取几个(超过两个)值中的一个。换句话说,这个值由enum表示,而不是由integerfloatdoubleboolean类型表示。

解决方案

为了解决这个问题,首先我们将建立一个ML模型。然后,我们将在现有数据上训练模型,评估其有多好,最后我们将使用该模型来预测鸢尾花类型。

1. 建立模型

建立模型包括:

  • 使用DataReader上传数据(iris-train.txt
  • 创建一个评估器并将数据转换为一列,以便ML算法(使用Concatenate)可以有效地使用它。
  • 选择学习算法(StochasticDualCoordinateAscent)。

初始代码类似以下内容:

  1. // Create MLContext to be shared across the model creation workflow objects
  2. // Set a random seed for repeatable/deterministic results across multiple trainings.
  3. var mlContext = new MLContext(seed: 0);
  4. // STEP 1: Common data loading configuration
  5. var textLoader = IrisTextLoaderFactory.CreateTextLoader(mlContext);
  6. var trainingDataView = textLoader.Read(TrainDataPath);
  7. var testDataView = textLoader.Read(TestDataPath);
  8. // STEP 2: Common data process configuration with pipeline data transformations
  9. var dataProcessPipeline = mlContext.Transforms.Concatenate("Features", "SepalLength",
  10. "SepalWidth",
  11. "PetalLength",
  12. "PetalWidth" );
  13. // STEP 3: Set the training algorithm, then create and config the modelBuilder
  14. var modelBuilder = new Common.ModelBuilder<IrisData, IrisPrediction>(mlContext, dataProcessPipeline);
  15. // We apply our selected Trainer
  16. var trainer = mlContext.MulticlassClassification.Trainers.StochasticDualCoordinateAscent(labelColumn: "Label", featureColumn: "Features");
  17. modelBuilder.AddTrainer(trainer);

2. 训练

训练模型是在训练数据(已知鸢尾花类型)上运行所选算法以调整模型参数的过程。它在评估器对象中的Fit() 方法中实现。

为了执行训练,我们只需调用方法时传入在DataView对象中提供的训练数据集(iris-train.txt文件)。

  1. // STEP 4: Train the model fitting to the DataSet
  2. modelBuilder.Train(trainingDataView);
  3. [...]
  4. public ITransformer Train(IDataView trainingData)
  5. {
  6. TrainedModel = TrainingPipeline.Fit(trainingData);
  7. return TrainedModel;
  8. }

3. 评估模型

我们需要这一步来总结我们的模型对新数据的准确性。 为此,上一步中的模型针对另一个未在训练中使用的数据集(iris-test.txt)运行。 此数据集还包含已知的鸢尾花类型。

MulticlassClassification.Evaluate计算模型预测的值和已知类型之间差异的各种指标。

  1. var metrics = modelBuilder.EvaluateMultiClassClassificationModel(testDataView, "Label");
  2. Common.ConsoleHelper.PrintMultiClassClassificationMetrics(trainer.ToString(), metrics);
  3. [...]
  4. public MultiClassClassifierEvaluator.Result EvaluateMultiClassClassificationModel(IDataView testData, string label="Label", string score="Score")
  5. {
  6. CheckTrained();
  7. var predictions = TrainedModel.Transform(testData);
  8. var metrics = _mlcontext.MulticlassClassification.Evaluate(predictions, label: label, score: score);
  9. return metrics;
  10. }

要了解关于如何理解指标的更多信息,请参阅ML.NET指南 中的机器学习词汇表,或者使用任何有关数据科学和机器学习的可用材料.

如果您对模型的质量不满意,可以采用多种方法来改进,这将在examples类别中进行介绍。

4. 使用模型

在模型被训练之后,我们可以使用Predict() API来预测这种花属于每个鸢尾花类型的概率。

  1. var modelScorer = new Common.ModelScorer<IrisData, IrisPrediction>(mlContext);
  2. modelScorer.LoadModelFromZipFile(ModelPath);
  3. var prediction = modelScorer.PredictSingle(SampleIrisData.Iris1);
  4. Console.WriteLine($"Actual: setosa. Predicted probability: setosa: {prediction.Score[0]:0.####}");
  5. Console.WriteLine($" versicolor: {prediction.Score[1]:0.####}");
  6. Console.WriteLine($" virginica: {prediction.Score[2]:0.####}");
  7. [...]
  8. public TPrediction PredictSingle(TObservation input)
  9. {
  10. CheckTrainedModelIsLoaded();
  11. return PredictionFunction.Predict(input);
  12. }

TestIrisData.Iris1中存储有关我们想要预测类型的花的信息。

  1. internal class TestIrisData
  2. {
  3. internal static readonly IrisData Iris1 = new IrisData()
  4. {
  5. SepalLength = 3.3f,
  6. SepalWidth = 1.6f,
  7. PetalLength = 0.2f,
  8. PetalWidth= 5.1f,
  9. }
  10. (...)
  11. }

ML.NET 示例:多类分类之鸢尾花分类的更多相关文章

  1. ML.NET 示例:开篇

    写在前面 准备近期将微软的machinelearning-samples翻译成中文,水平有限,如有错漏,请大家多多指正. 如果有朋友对此感兴趣,可以加入我:https://github.com/fei ...

  2. ML.NET 示例:多类分类之问题分类

    写在前面 准备近期将微软的machinelearning-samples翻译成中文,水平有限,如有错漏,请大家多多指正. 如果有朋友对此感兴趣,可以加入我:https://github.com/fei ...

  3. ML.NET 示例:二元分类之信用卡欺诈检测

    写在前面 准备近期将微软的machinelearning-samples翻译成中文,水平有限,如有错漏,请大家多多指正. 如果有朋友对此感兴趣,可以加入我:https://github.com/fei ...

  4. ML.NET 示例:目录

    ML.NET 示例中文版:https://github.com/feiyun0112/machinelearning-samples.zh-cn 英文原版请访问:https://github.com/ ...

  5. [Python]基于K-Nearest Neighbors[K-NN]算法的鸢尾花分类问题解决方案

    看了原理,总觉得需要用具体问题实现一下机器学习算法的模型,才算学习深刻.而写此博文的目的是,网上关于K-NN解决此问题的博文很多,但大都是调用Python高级库实现,尤其不利于初级学习者本人对模型的理 ...

  6. 做一个logitic分类之鸢尾花数据集的分类

    做一个logitic分类之鸢尾花数据集的分类 Iris 鸢尾花数据集是一个经典数据集,在统计学习和机器学习领域都经常被用作示例.数据集内包含 3 类共 150 条记录,每类各 50 个数据,每条记录都 ...

  7. ipv4理论知识2-分类编址、ip分类、网络标识、主机标识、地址类、地址块

    分类编址 ipv4的体系结构中有分类编址和无分类编址(后续会介绍到),在分类编址时,ipv4地址分为A.B.C.D.E这5类.每类占用的IP比例和个数如下图: ipv4分类识别 计算机以二进制方式存储 ...

  8. OC语言类的本质和分类

    OC语言类的深入和分类 一.分类 (一)分类的基本知识  概念:Category  分类是OC特有的语言,依赖于类. 分类的作用:在不改变原来的类内容的基础上,为类增加一些方法. 添加一个分类: 文件 ...

  9. 李洪强iOS开发之OC语言类的深入和分类

    OC语言类的深入和分类 一.分类 (一)分类的基本知识  概念:Category  分类是OC特有的语言,依赖于类. 分类的作用:在不改变原来的类内容的基础上,为类增加一些方法. 添加一个分类: 文件 ...

随机推荐

  1. WPF:完美自定义MeaagseBox 2.0

    很久前做个一个MessageBox,原文链接:http://www.cnblogs.com/DoNetCoder/p/3843658.html. 不过对比MessageBox还有一些瑕疵.这些天有时间 ...

  2. canvas代替imgage,可以有效的提高大图片加载的速度!

    //加载zepto插件 <script> //定义图片的数量 var total = 17; //获取屏幕的宽度 var zWin = $(window); //定义渲染图片的方法 var ...

  3. L2-024. 部落

    在一个社区里,每个人都有自己的小圈子,还可能同时属于很多不同的朋友圈.我们认为朋友的朋友都算在一个部落里,于是要请你统计一下,在一个给定社区中,到底有多少个互不相交的部落?并且检查任意两个人是否属于同 ...

  4. ntohs, ntohl, htons,htonl对比

    ntohs =net to host short int 16位htons=host to net short int 16位ntohl =net to host long int 32位htonl= ...

  5. mssql sql server 系统更新,如何正确的增加表字段

    转自: http://www.maomao365.com/?p=5277摘要:下文主要讲述,如何对"已上线的系统"中的表,增加新的字段. 系统部署脚本,增加列的方法:在系统脚本发布 ...

  6. ASP.Net上传文件

    在做Web项目时,上传文件是经常会碰到的需求.ASP.Net的WebForm开发模式中,封装了FileUpload控件,可以方便的进行文件上传操作.但有时,你可能不希望使用ASP.Net中的服务器控件 ...

  7. ASP.NET系统对象

    一.ASP.NET 系统对象        Request:用来获取客户端在Web请求期间发送的值,如URL参数,表单参数        Response:用来负者返回到客户端的HTTP输出      ...

  8. Java中的生产消费者问题

    package day190109; import java.util.LinkedList; import java.util.Queue; import java.util.Random; pub ...

  9. MySQL注入与防御

    1.简介 1.1.含义 在一个应用中,数据的安全无疑是最重要的.数据的最终归宿都是数据库,因此如何保证数据库不被恶意攻击者入侵是一项重要且严肃的问题! SQL注入作为一种很流行的攻击手段,一直以来都受 ...

  10. 第六章 第一个Linux驱动程序: 统计单词个数

    一.编写Linux驱动程序的步骤 第1 步:建立Linux 驱动骨架(装载和卸载Linux 驱动) 骨架部分主要是Linux驱动的初始化和退出函数,代码如下: #include <linux/m ...