python 实现可视化3D图形
首先加载了鸢尾花数据集,将特征向量存储在X中,类别标签存储在y中。然后使用TSNE算法对数据进行降维,得到X_tsne。接下来定义了plot_embedding_3d函数,用于绘制3D图形。
鸢尾花数据集(Iris dataset)是一个经典的多分类问题数据集,常用于机器学习和模式识别中。这个数据集包含了150个样本,每个样本对应一朵鸢尾花。数据集中的每个样本都有四个特征测量值:花萼长度、花萼宽度、花瓣长度和花瓣宽度。
根据这四个特征,每朵鸢尾花被分为三个类别之一:山鸢尾(Setosa)、变色鸢尾(Versicolor)和维吉尼亚鸢尾(Virginica)。每个类别包含50个样本。这个数据集是从鸢尾花的实际测量值中获取的,由英国统计学家和生物学家Ronald Fisher于1936年创建。
鸢尾花数据集是一个经典的机器学习数据集,被广泛用于分类算法的训练和评估。因为它包含了多个特征和多个类别,并且具有较高的可分性,所以它成为了许多分类算法的基准数据集之
最后,调用plot_embedding_3d函数将降维后的数据可视化成3D图形,并添加了标题 "TSNE Visualization of Iris dataset"
1、效果图
2、代码展示
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.datasets import load_iris
from sklearn.manifold import TSNE
# 加载鸢尾花数据集
iris = load_iris()
X = iris.data # 特征向量
y = iris.target # 类别标签
# 使用TSNE进行降维
tsne = TSNE(n_components=3, random_state=0)
X_tsne = tsne.fit_transform(X)
# 定义绘制3D图形的函数
def plot_embedding_3d(X, y, title=None):
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for i in np.unique(y):
ax.scatter(X[y==i, 0], X[y==i, 1], X[y==i, 2], label=iris.target_names[i])
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
if title is not None:
plt.title(title)
plt.show()
# 绘制降维后的3D图形
plot_embedding_3d(X_tsne, y, title='TSNE Visualization of Iris dataset')
3、代码运行逻辑详解
我们使用load_iris函数加载了Iris数据集并将其存储在iris对象中:
iris = load_iris()
然后,我们从iris对象中提取特征向量和类别标签:
y = iris.target # 类别标签
接下来,我们创建了一个TSNE对象,设置参数n_components=3表示希望降维到3维空间,并设置random_state=0作为随机种子:
tsne = TSNE(n_components=3, random_state=0)
然后,我们使用fit_transform方法对数据进行降维,并将结果存储在X_tsne中:
X_tsne = tsne.fit_transform(X)
接下来,我们定义了一个名为plot_embedding_3d的函数,用于绘制3D图形。该函数将降维后的数据X和对应的标签y作为输入参数:
def plot_embedding_3d(X, y, title=None):
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for i in np.unique(y):
ax.scatter(X[y==i, 0], X[y==i, 1], X[y==i, 2], label=iris.target_names[i])
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')
ax.legend()
if title is not None:
plt.title(title)
plt.show()
在函数中,我们首先创建了一个3D图形的Figure对象和一个Axes3D子图对象。然后,我们使用循环迭代每个类别,并通过scatter方法绘制相应类别的数据点。ax.scatter接受三个维度的坐标作为参数,分别是X、Y、Z。我们使用布尔索引筛选出属于当前类别的数据点,并通过iris.target_names获取类别的名称。最后,我们设置坐标轴的标签和图例,并根据需要添加标题。
最后,我们调用plot_embedding_3d函数来绘制降维后的3D图形:
plot_embedding_3d(X_tsne, y, title='TSNE Visualization of Iris dataset')