python 实现可视化3D图形

liftword6个月前 (01-20)技术文章39

首先加载了鸢尾花数据集,将特征向量存储在X中,类别标签存储在y中。然后使用TSNE算法对数据进行降维,得到X_tsne。接下来定义了plot_embedding_3d函数,用于绘制3D图形。



鸢尾花数据集(Iris dataset)是一个经典的多分类问题数据集,常用于机器学习和模式识别中。这个数据集包含了150个样本,每个样本对应一朵鸢尾花。数据集中的每个样本都有四个特征测量值:花萼长度、花萼宽度、花瓣长度和花瓣宽度。

根据这四个特征,每朵鸢尾花被分为三个类别之一:山鸢尾(Setosa)、变色鸢尾(Versicolor)和维吉尼亚鸢尾(Virginica)。每个类别包含50个样本。这个数据集是从鸢尾花的实际测量值中获取的,由英国统计学家和生物学家Ronald Fisher于1936年创建。

鸢尾花数据集是一个经典的机器学习数据集,被广泛用于分类算法的训练和评估。因为它包含了多个特征和多个类别,并且具有较高的可分性,所以它成为了许多分类算法的基准数据集之


最后,调用plot_embedding_3d函数将降维后的数据可视化成3D图形,并添加了标题 "TSNE Visualization of Iris dataset"

1、效果图

2、代码展示

import numpy as np

import matplotlib.pyplot as plt

from mpl_toolkits.mplot3d import Axes3D

from sklearn.datasets import load_iris

from sklearn.manifold import TSNE

# 加载鸢尾花数据集

iris = load_iris()

X = iris.data # 特征向量

y = iris.target # 类别标签

# 使用TSNE进行降维

tsne = TSNE(n_components=3, random_state=0)

X_tsne = tsne.fit_transform(X)

# 定义绘制3D图形的函数

def plot_embedding_3d(X, y, title=None):

fig = plt.figure()

ax = fig.add_subplot(111, projection='3d')

for i in np.unique(y):

ax.scatter(X[y==i, 0], X[y==i, 1], X[y==i, 2], label=iris.target_names[i])

ax.set_xlabel('X')

ax.set_ylabel('Y')

ax.set_zlabel('Z')

ax.legend()

if title is not None:

plt.title(title)

plt.show()

# 绘制降维后的3D图形

plot_embedding_3d(X_tsne, y, title='TSNE Visualization of Iris dataset')

3、代码运行逻辑详解


我们使用load_iris函数加载了Iris数据集并将其存储在iris对象中:

iris = load_iris()

然后,我们从iris对象中提取特征向量和类别标签:

y = iris.target  # 类别标签

接下来,我们创建了一个TSNE对象,设置参数n_components=3表示希望降维到3维空间,并设置random_state=0作为随机种子:

tsne = TSNE(n_components=3, random_state=0)

然后,我们使用fit_transform方法对数据进行降维,并将结果存储在X_tsne中:

X_tsne = tsne.fit_transform(X)

接下来,我们定义了一个名为plot_embedding_3d的函数,用于绘制3D图形。该函数将降维后的数据X和对应的标签y作为输入参数:

def plot_embedding_3d(X, y, title=None):
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    for i in np.unique(y):
        ax.scatter(X[y==i, 0], X[y==i, 1], X[y==i, 2], label=iris.target_names[i])
    ax.set_xlabel('X')
    ax.set_ylabel('Y')
    ax.set_zlabel('Z')
    ax.legend()

    if title is not None:
        plt.title(title)

    plt.show()

在函数中,我们首先创建了一个3D图形的Figure对象和一个Axes3D子图对象。然后,我们使用循环迭代每个类别,并通过scatter方法绘制相应类别的数据点。ax.scatter接受三个维度的坐标作为参数,分别是X、Y、Z。我们使用布尔索引筛选出属于当前类别的数据点,并通过iris.target_names获取类别的名称。最后,我们设置坐标轴的标签和图例,并根据需要添加标题。

最后,我们调用plot_embedding_3d函数来绘制降维后的3D图形:

plot_embedding_3d(X_tsne, y, title='TSNE Visualization of Iris dataset')


相关文章

Python超炫技巧!Tkinter神级应用打造震撼视觉的图形化界面设计

目录一、图形化界面设计的基本理解二、窗体控件布局2.1.根窗体显示实例2.2. tkinter 常用控件2.2.1 控件的共同属性2.3 控件布局2.3.1 pack()方法2.3.2 grid()方...

用Python制作一个带图形界面的计算器

大家好,今天我要带大家使用Python制作一个具有图形界面的计算器应用程序。这个项目不仅可以帮助你巩固Python编程基础,还可以让你初步体验图形化编程的乐趣。我们将使用Python的tkinter库...

Python高级技巧5:Python可视化的三种方法

Python可视化是经常在工作中需要用到的方法,比如看一个数据集的数据分布,呈现某几个变量间的关系,排查问题等。Python可视化也有多种库可以使用,每种库都有其特点和适用的场景。本文将为大家介绍常用...