DataLearner 标志DataLearnerAI
AI技术博客
大模型评测排行
大模型评测基准
AI大模型大全
AI资源仓库
AI工具导航

加载中...

DataLearner 标志DataLearner AI

专注大模型评测、数据资源与实践教学的知识平台,持续更新可落地的 AI 能力图谱。

产品

  • 评测榜单
  • 模型对比
  • 数据资源

资源

  • 部署教程
  • 原创内容
  • 工具导航

关于

  • 关于我们
  • 隐私政策
  • 数据收集方法
  • 联系我们

© 2026 DataLearner AI. DataLearner 持续整合行业数据与案例,为科研、企业与开发者提供可靠的大模型情报与实践指南。

隐私政策服务条款
目录
目录
  1. 首页/
  2. 博客列表/
  3. 博客详情

TFboys:使用Tensorflow搭建深层网络分类器

2017/03/08 09:53:51
5,267 阅读
DNNTensorflowtf.contrib.learn神经网络

前言

根据官方文档整理而来的,主要是对Iris数据集进行分类。使用tf.contrib.learn.tf.contrib.learn快速搭建一个深层网络分类器,

步骤

  1. 导入csv数据
  2. 搭建网络分类器
  3. 训练网络
  4. 计算测试集正确率
  5. 对新样本进行分类

数据

Iris数据集包含150行数据,有三种不同的Iris品种分类。每一行数据给出了四个特征信息和一个分类信息。 现在已经将数据分为训练集和测试集

  • A training set of 120 samples (iris_training.csv)
  • A test set of 30 samples (iris_test.csv)

网络搭建

1. 首先,导入tensorflow 和 numpy

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

2. 导入数据

# 定义数据地址
IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"
# 导入数据
training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

load_csv_with_header() 有三个参数

  • filename, 数据地址
  • target_dtype, 目标值的numpy datatype(iris的目标值是0,1,2,所以是np.int)
  • features_dtype, 特征值的numpy datatype .

3. 搭建网络结构

# 每行数据4个特征,都是real-value的
feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

# 3层DNN,3分类问题
classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="iris_model")

参数解释

  • feature_columns 特征值
  • hidden_units=[10, 20, 10]. 3个隐藏层,包含的隐藏神经元依次是10, 20, 10
  • n_classes 类别个数
  • model_dir 模型保存地址

4. 训练数据

classifier.fit(x=training_set.data, y=training_set.target, steps=2000)

steps 为训练次数

5. 计算准确率

accuracy_score = classifier.evaluate(x=test_set.data, y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

运行结果是

Accuracy: 0.966667

6. 对新样本进行预测

# Classify two new flower samples.
new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))

运行结果为:

Prediction: [1 2]

完整代码

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
import numpy as np

IRIS_TRAINING = "iris_training.csv"
IRIS_TEST = "iris_test.csv"

training_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TRAINING,
    target_dtype=np.int,
    features_dtype=np.float32)
test_set = tf.contrib.learn.datasets.base.load_csv_with_header(
    filename=IRIS_TEST,
    target_dtype=np.int,
    features_dtype=np.float32)

feature_columns = [tf.contrib.layers.real_valued_column("", dimension=4)]

classifier = tf.contrib.learn.DNNClassifier(feature_columns=feature_columns,
                                            hidden_units=[10, 20, 10],
                                            n_classes=3,
                                            model_dir="iris_model")

classifier.fit(x=training_set.data,
               y=training_set.target,
               steps=2000)

accuracy_score = classifier.evaluate(x=test_set.data,
                                     y=test_set.target)["accuracy"]
print('Accuracy: {0:f}'.format(accuracy_score))

new_samples = np.array(
    [[6.4, 3.2, 4.5, 1.5], [5.8, 3.1, 5.0, 1.7]], dtype=float)
y = list(classifier.predict(new_samples, as_iterable=True))
print('Predictions: {}'.format(str(y)))

参考

  • tf.contrib.learn Quickstart
  • tf.contrib.learn API

DataLearner 官方微信

欢迎关注 DataLearner 官方微信,获得最新 AI 技术推送

DataLearner 官方微信二维码
返回博客列表

相关博客

  • 谷歌官方高性能大规模高维数据处理库TensorStore发布!
  • Stable Diffusion的Tensorflow/Keras实现及使用
  • TensorFlow与PyTorch近几年发展对比
  • TensorFlow中常见的错误解释及解决方法
  • Tensorflow中数据集的使用方法(tf.data.Dataset)
  • Tensorflow自定义训练模型的样例写法
  • Tensorflow中关于tf.metrics的返回值详解及其用法以及它和tf.losses的区别
  • 自定义Tensorflow模型的训练

热门博客

  • 1Dirichlet Distribution(狄利克雷分布)与Dirichlet Process(狄利克雷过程)
  • 2回归模型中的交互项简介(Interactions in Regression)
  • 3贝塔分布(Beta Distribution)简介及其应用
  • 4矩母函数简介(Moment-generating function)
  • 5普通最小二乘法(Ordinary Least Squares,OLS)的详细推导过程
  • 6使用R语言进行K-means聚类并分析结果
  • 7深度学习技巧之Early Stopping(早停法)
  • 8H5文件简介和使用