이번에는 차원축소방법중 하나인 TSNE에 대해서 알아보겠습니다.
간단한 설명 후 실제로 Python 에서 어떻게 T-SNE를 이용하는지에 대한 예를 들어 보겠습니다.
차원 축소는 기계학습(Machine Learning)에서 매우 중요합니다. 왜냐하면 고차원의 데이터를 가지고 모델을 만들면 Under fitting 되기 쉽기 때문입니다. 즉, 쓸모 없는 데이터가 너무 많아서 학습이 되지 않는 현상이지요. 고차원, 즉 여러 데이터 중에서 가장 중요한 데이터만 골라서 (저차원으로 만들어서) 모델에 사용할 수도 있고, 여러 데이터를 이용하여 새로운 데이터를 만들어서 저차원으로 만들수도 있습니다. 어째튼 이처럼 고차원의 데이터를 저차원으로 변환하는 것이 필요합니다. 이것이 바로 차원 축소(Dimension Reduction) 방법입니다. (이외에도 Feature를 만드는 방법에는 Feature Elimination, Feature Selection 등의 방법이 있습니다.) 차원 축소 방법은 선형 방법(Principal Component Analysis (PCA), Independent Component Analysis, Linear Discriminant Analysis, 등)과 비선형 방법(Manifold, Auto-encoder 등)이 있습니다. TSNE는 Manifold 방법중 하나입니다.
SNE (Stochastic Neighbor Embedding)에서 t-SNE (t-distributed Stochastic Neighbor Embedding)로 발전했고, 이후 다시 UMAP (Uniform Manifold Approximation and Projection)으로 발전 하였습니다.
간단한 코드와 설명은 아래와 같습니다.
# python 3.8.6
# 필요한 패키지를 설치 합니다.
#! pip install sklearn
#! pip install seaborn
#! pip install matplotlib
# 필요한 패키지를 로딩 합니다.
from sklearn.datasets import load_digits
from sklearn.manifold import TSNE
import seaborn as sns
from matplotlib import pyplot as plt
# 필요한 데이터를 로드합니다. 여기서는 0부터 9까지의 숫자 데이터 입니다.
data = load_digits()
# 설명을 위한 참고 부분
# 로딩한 데이터의 첫번째 샘플을 보면 아래와 같습니다. 0은 하얀색이고 높은 숫자일 수록 검은 색에 가까움을 나타냅니다.
# 0이 아닌 숫자들을 연결해보면 중앙부분에 하얀색(0)이 있는 숫자 0을 나타내고 있음을 알 수 있습니다.
# >>> data.data[0]
# [ 0., 0., 5., 13., 9., 1., 0., 0.,
# 0., 0., 13., 15., 10., 15., 5., 0.,
# 0., 3., 15., 2., 0., 11., 8., 0.,
# 0., 4., 12., 0., 0., 8., 8., 0.,
# 0., 5., 8., 0., 0., 9., 8., 0.,
# 0., 4., 11., 0., 1., 12., 7., 0.,
# 0., 2., 14., 5., 10., 12., 0., 0.,
# 0., 0., 6., 13., 10., 0., 0., 0. ]
# 실제로 타겟의 첫번째에는 첫번째 샘플의 정답인 0이 들어잇습니다.
# >>> data.target[0]
# 0
# 축소한 차원의 수를 정합니다.
n_components = 2
# TSNE 모델의 인스턴스를 만듭니다.
model = TSNE(n_components=n_components)
# data를 가지고 TSNE 모델을 훈련(적용) 합니다.
X_embedded = model.fit_transform(data.data)
# 훈련된(차원 축소된) 데이터의 첫번째 값을 출력해 봅니다.
print(X_embedded[0])
# [65.49378 -7.3817754]
# 차원 축소된 데이터를 그래프로 만들어서 화면에 출력해 봅니다.
palette = sns.color_palette("bright", 10)
sns.scatterplot(X_embedded[:,0], X_embedded[:,1], hue=data.target, legend='full', palette=palette)
plt.show()
아래는 화면에 출력된 결과 입니다.