理解问题

出租车的车费不仅与距离有关,还涉及乘客数量,是否使用信用卡等因素(这是的出租车是指纽约市的)。所以并不是一个简单的一元方程问题。

准备数据

建立一控制台应用程序工程,新建Data文件夹,在其目录下添加taxi-fare-train.csvtaxi-fare-test.csv文件,不要忘了把它们的Copy to Output Directory属性改为Copy if newer。之后,添加Microsoft.ML类库包。

加载数据

新建MLContext对象,及创建TextLoader对象。TextLoader对象可用于从文件中读取数据。

MLContext mlContext = new MLContext(seed: 0);

_textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
{
Separator = ",",
HasHeader = true,
Column = new[]
{
new TextLoader.Column("VendorId", DataKind.Text, 0),
new TextLoader.Column("RateCode", DataKind.Text, 1),
new TextLoader.Column("PassengerCount", DataKind.R4, 2),
new TextLoader.Column("TripTime", DataKind.R4, 3),
new TextLoader.Column("TripDistance", DataKind.R4, 4),
new TextLoader.Column("PaymentType", DataKind.Text, 5),
new TextLoader.Column("FareAmount", DataKind.R4, 6)
}
});

提取特征

数据集文件里共有七列,前六列做为特征数据,最后一列是标记数据。

public class TaxiTrip
{
[Column("0")]
public string VendorId; [Column("1")]
public string RateCode; [Column("2")]
public float PassengerCount; [Column("3")]
public float TripTime; [Column("4")]
public float TripDistance; [Column("5")]
public string PaymentType; [Column("6")]
public float FareAmount;
} public class TaxiTripFarePrediction
{
[ColumnName("Score")]
public float FareAmount;
}

训练模型

首先读取训练数据集,其次建立管道。管道中第一步是把FareAmount列复制到Label列,做为标记数据。第二步,通过OneHotEncoding方式将VendorIdRateCodePaymentType三个字符串类型列转换成数值类型列。第三步,合并六个数据列为一个特征数据列。最后一步,选择FastTreeRegressionTrainer算法做为训练方法。

完成管道后,开始训练模型。

IDataView dataView = _textLoader.Read(dataPath);
var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
.Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
.Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
.Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
.Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
.Append(mlContext.Regression.Trainers.FastTree());
var model = pipeline.Fit(dataView);

评估模型

这里要使用测试数据集,并用回归问题的Evaluate方法进行评估。

IDataView dataView = _textLoader.Read(_testDataPath);
var predictions = model.Transform(dataView);
var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
Console.WriteLine();
Console.WriteLine($"*************************************************");
Console.WriteLine($"* Model quality metrics evaluation ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
Console.WriteLine($"* RMS loss: {metrics.Rms:#.##}");

保存模型

完成训练的模型可以被保存为zip文件以备之后使用。

using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
mlContext.Model.Save(model, fileStream);

使用模型

首先加载已经保存的模型。接着建立预测函数对象,TaxiTrip为函数的输入类型,TaxiTripFarePrediction为输出类型。之后执行预测方法,传入待测数据。

ITransformer loadedModel;
using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
{
loadedModel = mlContext.Model.Load(stream);
} var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext); var taxiTripSample = new TaxiTrip()
{
VendorId = "VTS",
RateCode = "1",
PassengerCount = 1,
TripTime = 1140,
TripDistance = 3.75f,
PaymentType = "CRD",
FareAmount = 0 // To predict. Actual/Observed = 15.5
}; var prediction = predictionFunction.Predict(taxiTripSample); Console.WriteLine($"**********************************************************************");
Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
Console.WriteLine($"**********************************************************************");

完整示例代码

using Microsoft.ML;
using Microsoft.ML.Core.Data;
using Microsoft.ML.Runtime.Data;
using System;
using System.IO; namespace TexiFarePredictor
{
class Program
{
static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-train.csv");
static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-test.csv");
static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "Model.zip");
static TextLoader _textLoader; static void Main(string[] args)
{
MLContext mlContext = new MLContext(seed: 0); _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
{
Separator = ",",
HasHeader = true,
Column = new[]
{
new TextLoader.Column("VendorId", DataKind.Text, 0),
new TextLoader.Column("RateCode", DataKind.Text, 1),
new TextLoader.Column("PassengerCount", DataKind.R4, 2),
new TextLoader.Column("TripTime", DataKind.R4, 3),
new TextLoader.Column("TripDistance", DataKind.R4, 4),
new TextLoader.Column("PaymentType", DataKind.Text, 5),
new TextLoader.Column("FareAmount", DataKind.R4, 6)
}
}); var model = Train(mlContext, _trainDataPath); Evaluate(mlContext, model); TestSinglePrediction(mlContext); Console.Read();
} public static ITransformer Train(MLContext mlContext, string dataPath)
{
IDataView dataView = _textLoader.Read(dataPath);
var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
.Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
.Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
.Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
.Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
.Append(mlContext.Regression.Trainers.FastTree());
var model = pipeline.Fit(dataView);
SaveModelAsFile(mlContext, model);
return model;
} private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
{
using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
mlContext.Model.Save(model, fileStream);
} private static void Evaluate(MLContext mlContext, ITransformer model)
{
IDataView dataView = _textLoader.Read(_testDataPath);
var predictions = model.Transform(dataView);
var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");
Console.WriteLine();
Console.WriteLine($"*************************************************");
Console.WriteLine($"* Model quality metrics evaluation ");
Console.WriteLine($"*------------------------------------------------");
Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
Console.WriteLine($"* RMS loss: {metrics.Rms:#.##}");
} private static void TestSinglePrediction(MLContext mlContext)
{
ITransformer loadedModel;
using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
{
loadedModel = mlContext.Model.Load(stream);
} var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext); var taxiTripSample = new TaxiTrip()
{
VendorId = "VTS",
RateCode = "1",
PassengerCount = 1,
TripTime = 1140,
TripDistance = 3.75f,
PaymentType = "CRD",
FareAmount = 0 // To predict. Actual/Observed = 15.5
}; var prediction = predictionFunction.Predict(taxiTripSample); Console.WriteLine($"**********************************************************************");
Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
Console.WriteLine($"**********************************************************************");
}
}
}

程序运行后显示的结果:

*************************************************
* Model quality metrics evaluation
*------------------------------------------------
* R2 Score: 0.92
* RMS loss: 2.81
**********************************************************************
Predicted fare: 15.7855, actual fare: 15.5
**********************************************************************

最后的预测结果还是比较符合实际数值的。

ML.NET教程之出租车车费预测(回归问题)的更多相关文章

  1. ML.NET教程之情感分析(二元分类问题)

    机器学习的工作流程分为以下几个步骤: 理解问题 准备数据 加载数据 提取特征 构建与训练 训练模型 评估模型 运行 使用模型 理解问题 本教程需要解决的问题是根据网站内评论的意见采取合适的行动. 可用 ...

  2. ML 05、分类、标注与回归

    机器学习算法 原理.实现与实践 —— 分类.标注与回归 1. 分类问题 分类问题是监督学习的一个核心问题.在监督学习中,当输出变量$Y$取有限个离散值时,预测问题便成为分类问题. 监督学习从数据中学习 ...

  3. 学习ML.NET(2): 使用模型进行预测

    训练模型 在上一篇文章中,我们已经通过LearningPipeline训练好了一个“鸢尾花瓣预测”模型, var model = pipeline.Train<IrisData, IrisPre ...

  4. ML.NET教程之客户细分(聚类问题)

    理解问题 客户细分需要解决的问题是按照客户之间的相似特征区分不同客户群体.这个问题的先决条件中没有可供使用的客户分类列表,只有客户的人物画像. 数据集 已有的数据是公司的历史商业活动记录以及客户的购买 ...

  5. 教程 | Kaggle网站流量预测任务第一名解决方案:从模型到代码详解时序预测

    https://mp.weixin.qq.com/s/JwRXBNmXBaQM2GK6BDRqMw 选自GitHub 作者:Artur Suilin 机器之心编译 参与:蒋思源.路雪.黄小天 近日,A ...

  6. ML.NET 示例:回归之价格预测

    写在前面 准备近期将微软的machinelearning-samples翻译成中文,水平有限,如有错漏,请大家多多指正. 如果有朋友对此感兴趣,可以加入我:https://github.com/fei ...

  7. .NET开发人员如何开始使用ML.NET

    随着谷歌,Facebook发布他们的工具机器学习工具Tensorflow 2和PyTorch ,微软的CNTK 2.7之后不再继续更新(https://docs.microsoft.com/zh-cn ...

  8. ML.NET Model Builder 更新

    ML.NET是面向.NET开发人员的跨平台机器学习框架,而Model Builder是Visual Studio中的UI工具,它使用自动机器学习(AutoML)轻松地允许您训练和使用自定义ML.NET ...

  9. ML.NET相关资源整理

      在人工智能领域,无论是机器学习,还是深度学习等,Python编程语言都是绝对的主流,尽管底层都是C++实现的,似乎人工智能和C#/F#编程语言没什么关系.在人工智能的工程实现,通常都是将Pytho ...

随机推荐

  1. 基于Centos7.5搭建Docker环境

    docker很火,基于容器化技术,实现一次编译到运行.实现运行环境+服务的一键式打包! 00.部署环境 centos7.5(基于vmware搭建的测试环境,可以跟互联网交互,桥接方式联网) docke ...

  2. SQL语句调优三板斧

    改装有顺序------常开的爱车下手 你的系统中有成千上万的语句,那么优化语句从何入手呢 ? 当然是系统中运行最频繁,最核心的语句了.废话不多说,上例子: 这是一天的语句执行情况,里面柱状图表示的是对 ...

  3. swift4.0 对 afn 进行二次封装

    先将  afn 用pod导入到 工程中 创建一个类 ZHttpTools 继承自 AFHTTPSessionManager 一般我们不希望网络请求同时有多个存在,所以我们将这个工具类  设计成单例 代 ...

  4. SQL Server 数据库基础笔记分享(下)

    前言 本文是个人学习SQL Server 数据库时的以往笔记的整理,内容主要是对数据库的基本增删改查的SQL语句操作和约束,视图,存储过程,触发器的基本了解. 注:内容比较基础,适合入门者对SQL S ...

  5. OpenCV 学习笔记 06 SIFT使用中出现版权问题error: (-213:The function/feature is not implemented)

    1 错误原因 1.1 报错全部信息: cv2.error: OpenCV(4.0.1) D:\Build\OpenCV\opencv_contrib-4.0.1\modules\xfeatures2d ...

  6. mysqlcheck与myisamchk的区别

    mysqlcheck和myisamchk都可以用来检测和修复表.(主要是MyISAM表)但是也有以下不同点:1.都可以检查.分析和修复myisam表.但是mysqlcheck也可以检查.分析innod ...

  7. maven项目中利用jmeter-maven-plugin插件直接执行jmeter jmx脚本

    jmeter脚本需要执行脚本,先得下载jmeter并解压jmeter.如想在maven项目中通过mvn install 直接执行jmx文件,这样就能在测试服务器上通过一个命令就能执行下性能测试了,给自 ...

  8. linux每日命令(7):rmdir命令

    rmdir是常用的命令,该命令的功能是删除空目录,一个目录被删除之前必须是空的.(注意,rm - r dir命令可代替rmdir,但是有很大危险性.)删除某目录时也必须具有对父目录的写权限. 一.命令 ...

  9. --save与--save-dev的区别

    --save安装的包会在生产和开发环境中都使用: --save-dev的包只在开发环境中使用,在生产环境中就不需要这个包了,不再打包:

  10. Android开发(二十一)——自动更新

    参考: [1] Android应用自动更新功能的代码实现.http://www.cnblogs.com/coolszy/archive/2012/04/27/2474279.html