你是否思考过TensorFlow的tutorial
TensorFlow的基本概念可以去查看TensorFlow官方文档。这里将帮你更好的理解TensorFlow Learn中estimator的工作原理,并指导你构建适合自己特定应用的estimator。
BaseEstimator和Estimator的理解
BaseEstimator是TensorFlow训练和评估模块的抽象和基类。它利用graph_actions.py的隐藏逻辑,提供像fit(),partial_fit(),evaluate()和predict()的基本功能,处理不同类型的输入数据批量拉取(Note:未来learn.DataFrame将替代DataFeeder)。它通过dtypes来检查输入数据的兼容,考虑输入数据是否稀疏需要使用estimators.tensor_signature。
BaseEstimator为monitors,checkpointing等初始化设置,并提供了构建和评估自定义模块的大部分逻辑。_get_train_ops(), _get_eval_ops()和_get_predict_ops()放在子类中实现,给Estimator自定义带来了更大的自由。BaseEstimator也是分布式的。
TensorFlow模块中Estimator的实现给我们重写BaseEstimator子类提供了很好的范本。
例如,Estimator中的_get_train_ops()载入features和targets作为输入,返回训练Operation和损失Tensor的一个tuple。如果你想完成自己的estimator,并且用于非监督机器学习训练,这时你就可以自由决定targets是否可忽略。
类似地,子类中的_get_eval_ops()可自定义metric来评估每步的训练。在TensorFlow的high-level模块中可发现一打适用的metric。它们会返回Tensor对象的字典,表示指定metric的评价ops。
_get_predict_ops()可实现自定义的prediction,例如 概率 v.s. 实际预测输出。它将返回一个Tensor或者Tensor对象的字典,表示预测ops。你可以很轻松的使用父类的predict()函数实现像transform()的功能。
Estimator示例
逻辑回归(LogisticRegressor)
Estimator已经提供了自定义estimator大部分实现。例如,LogisticRegressor仅需实现自己的metric即可,比如AUC,accuracy,precision和recall。开发者使用LogisticRegressor子类即可实现二值分类问题。
随机森林(TensorForestEstimator)
TensorForestEstimator已经增加到TensorFlow Learn。contrib.tensor_forest详细的实现了随机森林算法(Random Forests)评估器,并对外提供high-level API使得开发者构建随机森林评估器更简单。
例如,开发者只需传入params到构造器,params使用params.fill()来填充,而不用传入所有的超参数,Tensor Forest自己的RandomForestGraphs使用这些参数来构建整幅图。
|
|
随机森林算法的接口实现有许多细节,_get_predict_ops()利用tensor_forest.RandomForestGraphs来构建随机森林图,调用graph_builder.inference_graph来获取预测ops。
|
|
类似地,使用graph_builder.training_loss来实现_get_train_ops()。注意,TensorForestEstimator使用了tensor_forest.data.data_ops的模块功能,比如 ParseDataTensorOrDict和ParseLabelTensorOrDict解析输入特征和标签。
其它用例
K-means聚类的estimator刚加入项目,放在contrib.factorization.python.ops.kmeans。 更多的例子可以在learn.estimators中找到。
强烈推荐你先领悟代码整体结构,开始实现自己的estimator之旅!
参考:http://terrytangyuan.github.io/2016/07/08/understand-and-build-tensorflow-estimator
侠天,专注于大数据、机器学习和数学相关的内容,并有个人公众号:bigdata_ny分享相关技术文章。
若发现以上文章有任何不妥,请联系我。