python 实现可视化3D图形

liftword4个月前 (01-20)技术文章23

首先加载了鸢尾花数据集,将特征向量存储在X中,类别标签存储在y中。然后使用TSNE算法对数据进行降维,得到X_tsne。接下来定义了plot_embedding_3d函数,用于绘制3D图形。



鸢尾花数据集(Iris dataset)是一个经典的多分类问题数据集,常用于机器学习和模式识别中。这个数据集包含了150个样本,每个样本对应一朵鸢尾花。数据集中的每个样本都有四个特征测量值:花萼长度、花萼宽度、花瓣长度和花瓣宽度。

根据这四个特征,每朵鸢尾花被分为三个类别之一:山鸢尾(Setosa)、变色鸢尾(Versicolor)和维吉尼亚鸢尾(Virginica)。每个类别包含50个样本。这个数据集是从鸢尾花的实际测量值中获取的,由英国统计学家和生物学家Ronald Fisher于1936年创建。

鸢尾花数据集是一个经典的机器学习数据集,被广泛用于分类算法的训练和评估。因为它包含了多个特征和多个类别,并且具有较高的可分性,所以它成为了许多分类算法的基准数据集之


最后,调用plot_embedding_3d函数将降维后的数据可视化成3D图形,并添加了标题 "TSNE Visualization of Iris dataset"

1、效果图

2、代码展示

import numpy as np

import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D

from sklearn.datasets import load_iris

from sklearn.manifold import TSNE

# 加载鸢尾花数据集

iris = load_iris()

X = iris.data # 特征向量

y = iris.target # 类别标签

# 使用TSNE进行降维

tsne = TSNE(n_components=3, random_state=0)

X_tsne = tsne.fit_transform(X)

# 定义绘制3D图形的函数

def plot_embedding_3d(X, y, title=None):

fig = plt.figure()

ax = fig.add_subplot(111, projection='3d')

for i in np.unique(y):

ax.scatter(X[y==i, 0], X[y==i, 1], X[y==i, 2], label=iris.target_names[i])

ax.set_xlabel('X')

ax.set_ylabel('Y')

ax.set_zlabel('Z')

ax.legend()

if title is not None:

plt.title(title)

plt.show()

# 绘制降维后的3D图形

plot_embedding_3d(X_tsne, y, title='TSNE Visualization of Iris dataset')

3、代码运行逻辑详解


我们使用load_iris函数加载了Iris数据集并将其存储在iris对象中:

iris = load_iris()

然后,我们从iris对象中提取特征向量和类别标签:

y = iris.target  # 类别标签

接下来,我们创建了一个TSNE对象,设置参数n_components=3表示希望降维到3维空间,并设置random_state=0作为随机种子:

tsne = TSNE(n_components=3, random_state=0)

然后,我们使用fit_transform方法对数据进行降维,并将结果存储在X_tsne中:

X_tsne = tsne.fit_transform(X)

接下来,我们定义了一个名为plot_embedding_3d的函数,用于绘制3D图形。该函数将降维后的数据X和对应的标签y作为输入参数:

def plot_embedding_3d(X, y, title=None):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for i in np.unique(y):
        ax.scatter(X[y==i, 0], X[y==i, 1], X[y==i, 2], label=iris.target_names[i])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.legend()

    if title is not None:
        plt.title(title)

    plt.show()

在函数中,我们首先创建了一个3D图形的Figure对象和一个Axes3D子图对象。然后,我们使用循环迭代每个类别,并通过scatter方法绘制相应类别的数据点。ax.scatter接受三个维度的坐标作为参数,分别是X、Y、Z。我们使用布尔索引筛选出属于当前类别的数据点,并通过iris.target_names获取类别的名称。最后,我们设置坐标轴的标签和图例,并根据需要添加标题。

最后,我们调用plot_embedding_3d函数来绘制降维后的3D图形:

plot_embedding_3d(X_tsne, y, title='TSNE Visualization of Iris dataset')


相关文章

Python超炫技巧!Tkinter神级应用打造震撼视觉的图形化界面设计

目录一、图形化界面设计的基本理解二、窗体控件布局2.1.根窗体显示实例2.2. tkinter 常用控件2.2.1 控件的共同属性2.3 控件布局2.3.1 pack()方法2.3.2 grid()方...

Python | graphics、tkinter 画回归线

前言写 Python 题时遇到一道绘制回归线的题目。要求点击数据点输入进行绘制,我用 graphics 完成了。但是,这样输入并不精确,加上高中受到线性回归方程那庞大计算量的折磨,于是想写个能输数据的...

Python可视化很简单,一文学会绘制柱状图、条形图和直方图

matplotlib库作为Python数据化可视化的最经典和最常用库,掌握了它就相当于学会了Python的数据化可视化,通过前几次呢,咱们已经讨论了使用matplotlib库中的图表组成元素的几个重要...