【机器学习】在TensorFlow中构建自定义Estimator:深度解析TensorFlow组件Estimator

你是否思考过TensorFlow的tutorial其背后的“魔力”?希望这篇文章至少能给你思考的正确方向。

TensorFlow的基本概念可以去查看TensorFlow官方文档。这里将帮你更好的理解TensorFlow Learn中estimator的工作原理,并指导你构建适合自己特定应用的estimator。

BaseEstimator和Estimator的理解

BaseEstimator是TensorFlow训练和评估模块的抽象和基类。它利用graph_actions.py的隐藏逻辑,提供像fit(),partial_fit(),evaluate()和predict()的基本功能,处理不同类型的输入数据批量拉取(Note:未来learn.DataFrame将替代DataFeeder)。它通过dtypes来检查输入数据的兼容,考虑输入数据是否稀疏需要使用estimators.tensor_signature。

BaseEstimator为monitors,checkpointing等初始化设置,并提供了构建和评估自定义模块的大部分逻辑。_get_train_ops(), _get_eval_ops()和_get_predict_ops()放在子类中实现,给Estimator自定义带来了更大的自由。BaseEstimator也是分布式的。

TensorFlow模块中Estimator的实现给我们重写BaseEstimator子类提供了很好的范本。
例如,Estimator中的_get_train_ops()载入features和targets作为输入,返回训练Operation和损失Tensor的一个tuple。如果你想完成自己的estimator,并且用于非监督机器学习训练,这时你就可以自由决定targets是否可忽略。

类似地,子类中的_get_eval_ops()可自定义metric来评估每步的训练。在TensorFlow的high-level模块中可发现一打适用的metric。它们会返回Tensor对象的字典,表示指定metric的评价ops。

_get_predict_ops()可实现自定义的prediction,例如 概率 v.s. 实际预测输出。它将返回一个Tensor或者Tensor对象的字典,表示预测ops。你可以很轻松的使用父类的predict()函数实现像transform()的功能。

Estimator示例

逻辑回归(LogisticRegressor)

Estimator已经提供了自定义estimator大部分实现。例如,LogisticRegressor仅需实现自己的metric即可,比如AUC,accuracy,precision和recall。开发者使用LogisticRegressor子类即可实现二值分类问题。

随机森林(TensorForestEstimator)

TensorForestEstimator已经增加到TensorFlow Learn。contrib.tensor_forest详细的实现了随机森林算法(Random Forests)评估器,并对外提供high-level API使得开发者构建随机森林评估器更简单。

例如,开发者只需传入params到构造器,params使用params.fill()来填充,而不用传入所有的超参数,Tensor Forest自己的RandomForestGraphs使用这些参数来构建整幅图。

1
2
3
4
5
6
7
8
9
class TensorForestEstimator(estimator.BaseEstimator):
"""An estimator that can train and evaluate a random forest."""
def __init__(self, params, device_assigner=None, model_dir=None,
graph_builder_class=tensor_forest.RandomForestGraphs,
master='', accuracy_metric=None,
tf_random_seed=None, verbose=1,
config=None):
self.params = params.fill()

随机森林算法的接口实现有许多细节,_get_predict_ops()利用tensor_forest.RandomForestGraphs来构建随机森林图,调用graph_builder.inference_graph来获取预测ops。

1
2
3
4
5
6
def _get_predict_ops(self, features):
graph_builder = self.graph_builder_class(
self.params, device_assigner=self.device_assigner, training=False,
**self.construction_args)
features, spec = data_ops.ParseDataTensorOrDict(features)
return graph_builder.inference_graph(features, data_spec=spec)

类似地,使用graph_builder.training_loss来实现_get_train_ops()。注意,TensorForestEstimator使用了tensor_forest.data.data_ops的模块功能,比如 ParseDataTensorOrDict和ParseLabelTensorOrDict解析输入特征和标签。

其它用例

K-means聚类的estimator刚加入项目,放在contrib.factorization.python.ops.kmeans。 更多的例子可以在learn.estimators中找到。

强烈推荐你先领悟代码整体结构,开始实现自己的estimator之旅!

参考:http://terrytangyuan.github.io/2016/07/08/understand-and-build-tensorflow-estimator

侠天,专注于大数据、机器学习和数学相关的内容,并有个人公众号:bigdata_ny分享相关技术文章。

若发现以上文章有任何不妥,请联系我。

image