Fuad Hassan
Global Insights: Country Clustering Analysis Using Unsupervised Machine Learning
Objective:
To categorise the countries using socio-economic and health factors that determine the overall development of the country.
About the organization:
HELP International is an international humanitarian NGO that is committed to fighting poverty and providing the people of backward countries with basic amenities and relief during the time of disasters and natural calamities.
Problem Statement:
HELP International has raised around $10 million. The CEO of the NGO now needs to decide how to use this money strategically and effectively, which means choosing the countries that are in the direst need of aid. Your job as a data scientist is to categorise the countries using socio-economic and health factors that determine their overall development, and then to suggest the countries the CEO should focus on most.
Data
The dataset used for this analysis is available on Kaggle and is titled "Unsupervised Learning on Country Data". The dataset contains information about various countries and includes the following columns:
- country: Name of the country
- child_mort: Child mortality rate (per 1000 live births)
- exports: Exports of goods and services (% of GDP)
- health: Total health spending (% of GDP)
- imports: Imports of goods and services (% of GDP)
- income: Net income per person
- inflation: The measurement of the annual growth rate of the Total GDP
- life_expec: Life expectancy at birth (in years)
- total_fer: The number of children that would be born to each woman if the current age-fertility rates remain the same
- gdpp: Gross Domestic Product per capita
The dataset can be accessed through the following link: Unsupervised Learning on Country Data.
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import plotly.express as px
from scipy.stats import boxcox
from sklearn.preprocessing import StandardScaler,MinMaxScaler
import plotly.graph_objects as go
import warnings
warnings.filterwarnings("ignore")
countries_data = pd.read_csv('https://raw.githubusercontent.com/fuadh246/Global-Insights-Clustering/main/Data/Country-data.csv')
countries_data.head(2)
  | country | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|---|
0 | Afghanistan | 90.2 | 10.0 | 7.58 | 44.9 | 1610 | 9.44 | 56.2 | 5.82 | 553 |
1 | Albania | 16.6 | 28.0 | 6.55 | 48.6 | 9930 | 4.49 | 76.3 | 1.65 | 4090 |
countries_data = countries_data.set_index('country')
countries_data.head()
country | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|
Afghanistan | 90.2 | 10.0 | 7.58 | 44.9 | 1610 | 9.44 | 56.2 | 5.82 | 553 |
Albania | 16.6 | 28.0 | 6.55 | 48.6 | 9930 | 4.49 | 76.3 | 1.65 | 4090 |
Algeria | 27.3 | 38.4 | 4.17 | 31.4 | 12900 | 16.10 | 76.5 | 2.89 | 4460 |
Angola | 119.0 | 62.3 | 2.85 | 42.9 | 5900 | 22.40 | 60.1 | 6.16 | 3530 |
Antigua and Barbuda | 10.3 | 45.5 | 6.03 | 58.9 | 19100 | 1.44 | 76.8 | 2.13 | 12200 |
Bangladesh_data = countries_data.loc['Bangladesh']
Bangladesh_data
child_mort      49.40
exports         16.00
health           3.52
imports         21.80
income        2440.00
inflation        7.14
life_expec      70.40
total_fer        2.33
gdpp           758.00
Name: Bangladesh, dtype: float64
Learning about the data
countries_data.shape
(167, 9)
countries_data.info()
<class 'pandas.core.frame.DataFrame'>
Index: 167 entries, Afghanistan to Zambia
Data columns (total 9 columns):
 #   Column      Non-Null Count  Dtype
---  ------      --------------  -----
 0   child_mort  167 non-null    float64
 1   exports     167 non-null    float64
 2   health      167 non-null    float64
 3   imports     167 non-null    float64
 4   income      167 non-null    int64
 5   inflation   167 non-null    float64
 6   life_expec  167 non-null    float64
 7   total_fer   167 non-null    float64
 8   gdpp        167 non-null    int64
dtypes: float64(7), int64(2)
memory usage: 17.1+ KB
countries_data.describe()
  | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|
count | 167.000000 | 167.000000 | 167.000000 | 167.000000 | 167.000000 | 167.000000 | 167.000000 | 167.000000 | 167.000000 |
mean | 38.270060 | 41.108976 | 6.815689 | 46.890215 | 17144.688623 | 7.781832 | 70.555689 | 2.947964 | 12964.155689 |
std | 40.328931 | 27.412010 | 2.746837 | 24.209589 | 19278.067698 | 10.570704 | 8.893172 | 1.513848 | 18328.704809 |
min | 2.600000 | 0.109000 | 1.810000 | 0.065900 | 609.000000 | -4.210000 | 32.100000 | 1.150000 | 231.000000 |
25% | 8.250000 | 23.800000 | 4.920000 | 30.200000 | 3355.000000 | 1.810000 | 65.300000 | 1.795000 | 1330.000000 |
50% | 19.300000 | 35.000000 | 6.320000 | 43.300000 | 9960.000000 | 5.390000 | 73.100000 | 2.410000 | 4660.000000 |
75% | 62.100000 | 51.350000 | 8.600000 | 58.750000 | 22800.000000 | 10.750000 | 76.800000 | 3.880000 | 14050.000000 |
max | 208.000000 | 200.000000 | 17.900000 | 174.000000 | 125000.000000 | 104.000000 | 82.800000 | 7.490000 | 105000.000000 |
Data Cleaning
# num of missing value
countries_data.isna().sum()
child_mort    0
exports       0
health        0
imports       0
income        0
inflation     0
life_expec    0
total_fer     0
gdpp          0
dtype: int64
def find_outliers(df, col, q1, q3):
    # Flag rows that fall outside the 1.5 * IQR fences for the given column
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    outliers = df[(df[col] < lower_bound) | (df[col] > upper_bound)]
    return outliers
def box_plot_all(df):
    # Draw one box plot per numeric column and report its IQR outliers
    for col in df.select_dtypes(exclude='object').columns:
        sns.boxplot(data=df[col])
        # Calculate mean, Q1, and Q3
        mean_val = df[col].mean()
        q1 = np.percentile(df[col], 25)
        q3 = np.percentile(df[col], 75)
        # Plot mean, Q1, and Q3 levels
        plt.axhline(mean_val, color='r', linestyle='--', label='Mean')
        plt.axhline(q1, color='g', linestyle='--', label='Q1')
        plt.axhline(q3, color='b', linestyle='--', label='Q3')
        plt.title(f'Box Plot of {col}')
        plt.legend()
        plt.show()
        # Find and display percentage of outliers
        outliers = find_outliers(df, col, q1, q3)
        print('{} - % of outliers: {}%'.format(col, round(outliers.shape[0] / df.shape[0] * 100, 2)))
        print(outliers[col])
box_plot_all(countries_data)
child_mort - % of outliers: 2.4% country Central African Republic 149.0 Chad 150.0 Haiti 208.0 Sierra Leone 160.0 Name: child_mort, dtype: float64
exports - % of outliers: 2.99% country Ireland 103.0 Luxembourg 175.0 Malta 153.0 Seychelles 93.8 Singapore 200.0 Name: exports, dtype: float64
health - % of outliers: 1.2% country Micronesia, Fed. Sts. 14.2 United States 17.9 Name: health, dtype: float64
imports - % of outliers: 2.4% country Luxembourg 142.0 Malta 154.0 Seychelles 108.0 Singapore 174.0 Name: imports, dtype: float64
income - % of outliers: 4.79% country Brunei 80600 Kuwait 75200 Luxembourg 91700 Norway 62300 Qatar 125000 Singapore 72100 Switzerland 55500 United Arab Emirates 57600 Name: income, dtype: int64
inflation - % of outliers: 2.99% country Equatorial Guinea 24.9 Mongolia 39.2 Nigeria 104.0 Timor-Leste 26.5 Venezuela 45.9 Name: inflation, dtype: float64
life_expec - % of outliers: 1.8% country Central African Republic 47.5 Haiti 32.1 Lesotho 46.5 Name: life_expec, dtype: float64
total_fer - % of outliers: 0.6% country Niger 7.49 Name: total_fer, dtype: float64
gdpp - % of outliers: 14.97% country Australia 51900 Austria 46900 Belgium 44400 Brunei 35300 Canada 47400 Denmark 58000 Finland 46200 France 40600 Germany 41800 Iceland 41900 Ireland 48700 Italy 35800 Japan 44500 Kuwait 38500 Luxembourg 105000 Netherlands 50300 New Zealand 33700 Norway 87800 Qatar 70300 Singapore 46600 Sweden 52100 Switzerland 74600 United Arab Emirates 35000 United Kingdom 38900 United States 48400 Name: gdpp, dtype: int64
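The flagged values appear to be genuine extremes rather than data-entry errors (for example, very high GDP per capita in Luxembourg or very high child mortality in Haiti), so we keep every observation and do not modify any outlier.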
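To make the 1.5 × IQR rule used by find_outliers concrete, here is a minimal sketch on a toy array. The values and the helper name iqr_bounds are made up for illustration and are not part of the original notebook.

```python
import numpy as np

def iqr_bounds(values, k=1.5):
    """Return the (lower, upper) fences of the k * IQR rule (hypothetical helper)."""
    q1, q3 = np.percentile(values, [25, 75])
    iqr = q3 - q1
    return q1 - k * iqr, q3 + k * iqr

# Toy data (not taken from the dataset): one value sits far above the rest.
toy = np.array([10.0, 12.0, 13.0, 15.0, 18.0, 20.0, 95.0])
lower, upper = iqr_bounds(toy)

# Anything outside the fences is flagged as an outlier.
flagged = toy[(toy < lower) | (toy > upper)]
print(lower, upper, flagged)
```

Being flagged by this rule only means a value is far from the middle half of the data; whether it is an error still has to be judged from domain knowledge.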
EDA
def create_choropleth_maps(df):
    # Draw one world map per indicator, shaded by that indicator's value
    for i in df.columns:
        fig = px.choropleth(df.reset_index(),
                            locationmode='country names',
                            locations='country',
                            title=i.upper(),
                            color=i,
                            color_continuous_scale="Blues",
                            )
        fig.update_geos(fitbounds="locations", visible=True)
        fig.show(renderer="svg")
        # fig.write_image(f"../Img/{i}.pdf")
create_choropleth_maps(countries_data)
sns.pairplot(countries_data)
Correlation
countries_data.corr().style.background_gradient()
  | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|
child_mort | 1.000000 | -0.318093 | -0.200402 | -0.127211 | -0.524315 | 0.288276 | -0.886676 | 0.848478 | -0.483032 |
exports | -0.318093 | 1.000000 | -0.114408 | 0.737381 | 0.516784 | -0.107294 | 0.316313 | -0.320011 | 0.418725 |
health | -0.200402 | -0.114408 | 1.000000 | 0.095717 | 0.129579 | -0.255376 | 0.210692 | -0.196674 | 0.345966 |
imports | -0.127211 | 0.737381 | 0.095717 | 1.000000 | 0.122406 | -0.246994 | 0.054391 | -0.159048 | 0.115498 |
income | -0.524315 | 0.516784 | 0.129579 | 0.122406 | 1.000000 | -0.147756 | 0.611962 | -0.501840 | 0.895571 |
inflation | 0.288276 | -0.107294 | -0.255376 | -0.246994 | -0.147756 | 1.000000 | -0.239705 | 0.316921 | -0.221631 |
life_expec | -0.886676 | 0.316313 | 0.210692 | 0.054391 | 0.611962 | -0.239705 | 1.000000 | -0.760875 | 0.600089 |
total_fer | 0.848478 | -0.320011 | -0.196674 | -0.159048 | -0.501840 | 0.316921 | -0.760875 | 1.000000 | -0.454910 |
gdpp | -0.483032 | 0.418725 | 0.345966 | 0.115498 | 0.895571 | -0.221631 | 0.600089 | -0.454910 | 1.000000 |
def norm_check(data):
    # Rule of thumb: |skew| > 0.75 is treated as a clearly skewed (non-normal) distribution
    skewness = data.skew()
    for name, value in skewness.items():
        if abs(value) > 0.75:
            print('"{}" skew: {}. The variable is NOT normally distributed.'.format(name, round(value, 2)))
        else:
            print('"{}" skew: {}. The Variable is normally distributed'.format(name, round(value, 2)))
norm_check(countries_data)
"child_mort" skew: 1.45. The variable is NOT normally distributed.
"exports" skew: 2.45. The variable is NOT normally distributed.
"health" skew: 0.71. The Variable is normally distributed
"imports" skew: 1.91. The variable is NOT normally distributed.
"income" skew: 2.23. The variable is NOT normally distributed.
"inflation" skew: 5.15. The variable is NOT normally distributed.
"life_expec" skew: -0.97. The variable is NOT normally distributed.
"total_fer" skew: 0.97. The variable is NOT normally distributed.
"gdpp" skew: 2.22. The variable is NOT normally distributed.
def corr_matrix(data):
    # Keep only strong off-diagonal correlations (0.70 < |r| < 1) and blank out the rest
    corr = data.corr()
    for var in corr.columns:
        corr[var] = corr[var].apply(lambda x: x if 0.70 < abs(x) < 1 else '')
    return corr
corr_matrix(countries_data)
  | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|
child_mort | | | | | | | -0.886676 | 0.848478 | |
exports | | | | 0.737381 | | | | | |
health | | | | | | | | | |
imports | | 0.737381 | | | | | | | |
income | | | | | | | | | 0.895571 |
inflation | | | | | | | | | |
life_expec | -0.886676 | | | | | | | -0.760875 | |
total_fer | 0.848478 | | | | | | -0.760875 | | |
gdpp | | | | | 0.895571 | | | | |
Observations and Explanations
Close to 1: Strong positive relationship.
Close to -1: Strong negative relationship.
Close to 0: Weak or no relationship.
child_mort (Child Mortality):
- With income: Correlation of -0.524 means that higher income is linked to lower child mortality.
- With life_expec: Correlation of -0.887 shows that higher child mortality is linked to lower life expectancy.
- With total_fer: Correlation of 0.848 suggests that higher child mortality is linked to higher total fertility rates.
exports:
- With imports: Correlation of 0.737 means higher exports are linked to higher imports.
- With gdpp: Correlation of 0.419 means higher exports are linked to higher GDP per capita.
health:
- With gdpp: Correlation of 0.346 shows a positive relationship between health investment and GDP per capita.
income:
- With gdpp: Correlation of 0.896 indicates a strong positive link between income and GDP per capita.
- With life_expec: Correlation of 0.612 suggests a positive link between income and life expectancy.
inflation:
- With income: Correlation of -0.148 shows a weak negative relationship between inflation and income.
- With gdpp: Correlation of -0.222 shows a weak negative relationship between inflation and GDP per capita.
life_expec (Life Expectancy):
- With gdpp: Correlation of 0.600 suggests that higher life expectancy is linked to higher GDP per capita.
total_fer (Total Fertility Rate):
- With gdpp: Correlation of -0.455 suggests higher fertility rates are linked to lower GDP per capita.
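Overall, the correlation matrix shows strong relationships among child mortality, life expectancy, and total fertility, and between income and GDP per capita. Income and GDP per capita are moderately related to the health outcomes (|r| of roughly 0.5 to 0.6), so higher income and GDP per capita are associated with lower child mortality and higher life expectancy. Correlation alone does not establish that one causes the other.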
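As a reminder of what the coefficients above measure, the following sketch computes a Pearson correlation by hand and checks it against NumPy. The two toy arrays and the helper name pearson_r are illustrative assumptions, not values from the dataset.

```python
import numpy as np

def pearson_r(x, y):
    """Pearson correlation: covariance divided by the product of the spreads (hypothetical helper)."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    xc = x - x.mean()
    yc = y - y.mean()
    return (xc * yc).sum() / np.sqrt((xc ** 2).sum() * (yc ** 2).sum())

# Toy data: mortality falls in a straight line as income rises.
income = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
mortality = np.array([10.0, 8.0, 6.0, 4.0, 2.0])

r = pearson_r(income, mortality)
print(r, np.corrcoef(income, mortality)[0, 1])
```

A perfectly linear decreasing relationship gives the extreme value of the scale; real indicators such as child_mort and life_expec fall short of that because the relationship is noisy.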
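Transformation
BoxCox Transformation
Most variables are strongly skewed, so we apply a Box-Cox transformation to each column with the boxcox_transform helper. Box-Cox requires strictly positive input, so any column with zero or negative values (here, inflation) is shifted by a constant first.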
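The |skew| > 0.75 cutoff used by norm_check is a rule of thumb, not a formal normality test. As a hedged illustration on synthetic data (not the country dataset), the sketch below uses the Shapiro-Wilk test from SciPy to show that a sample can have low skewness and still be clearly non-normal.

```python
import numpy as np
from scipy.stats import shapiro, skew

rng = np.random.default_rng(0)
skewed = rng.exponential(scale=1.0, size=500)  # right-skewed sample
flat = rng.uniform(0.0, 1.0, size=500)         # symmetric, but not bell-shaped

skew_skewed, skew_flat = skew(skewed), skew(flat)
_, p_skewed = shapiro(skewed)
_, p_flat = shapiro(flat)

# A small p-value is evidence against normality.
print(f"exponential: skew={skew_skewed:.2f}, Shapiro-Wilk p={p_skewed:.3g}")
print(f"uniform:     skew={skew_flat:.2f}, Shapiro-Wilk p={p_flat:.3g}")
```

For clustering purposes the goal is mainly to reduce skew so that a few extreme countries do not dominate the distances, so the skewness check is a reasonable proxy even though it does not prove normality.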
def boxcox_transform(df, constant=0.01):
    transformed_data = pd.DataFrame(index=df.index)  # Result keeps the original index
    for column in df.select_dtypes(include=[np.number]):
        column_data = df[column].dropna()  # Box-Cox cannot handle missing values
        min_value = column_data.min()
        if min_value <= 0:
            # Box-Cox needs strictly positive input, so shift the column above zero
            column_data = column_data + abs(min_value) + constant
        transformed_values, _ = boxcox(column_data)
        # Re-attach the index so rows stay aligned even if missing values were dropped
        transformed_data[column] = pd.Series(transformed_values, index=column_data.index)
    return transformed_data
transformed_df = boxcox_transform(countries_data)
transformed_df
country | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|
Afghanistan | 4.596747 | 3.577908 | 2.941042 | 10.015629 | 10.833801 | 3.745020 | 8.586499e+05 | 1.189628 | 6.275043 |
Albania | 2.846095 | 6.419617 | 2.652941 | 10.457084 | 14.938781 | 2.906548 | 2.675528e+06 | 0.445233 | 8.246487 |
Algeria | 3.357802 | 7.523315 | 1.849401 | 8.205481 | 15.591688 | 4.571104 | 2.701690e+06 | 0.831924 | 8.331633 |
Angola | 4.885951 | 9.476151 | 1.263567 | 9.768596 | 13.688497 | 5.183649 | 1.101861e+06 | 1.213716 | 8.101683 |
Antigua and Barbuda | 2.357391 | 8.169779 | 2.496145 | 11.595983 | 16.602865 | 2.190407 | 2.741283e+06 | 0.634312 | 9.319889 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
Vanuatu | 3.427188 | 8.264001 | 2.243473 | 10.925065 | 12.119789 | 2.495148 | 1.312811e+06 | 0.941978 | 7.931759 |
Venezuela | 2.876553 | 6.478217 | 2.125639 | 5.824603 | 16.221298 | 6.799245 | 2.560084e+06 | 0.733832 | 9.419216 |
Vietnam | 3.194584 | 10.128990 | 2.736880 | 13.642586 | 13.057682 | 4.104579 | 2.281623e+06 | 0.571532 | 7.125736 |
Yemen | 4.106507 | 6.650154 | 2.219629 | 8.640211 | 13.052601 | 5.287811 | 1.696602e+06 | 1.089779 | 7.125736 |
Zambia | 4.511321 | 7.386964 | 2.452434 | 8.130824 | 12.352925 | 4.335625 | 6.433180e+05 | 1.156830 | 7.232565 |
167 rows × 9 columns
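To show what scipy.stats.boxcox is doing to each column, here is a small sketch on synthetic, strictly positive data. The gamma-distributed sample is an assumption chosen only because it is right-skewed; it is not drawn from the country dataset.

```python
import numpy as np
from scipy.stats import boxcox, skew

rng = np.random.default_rng(42)
raw = rng.gamma(shape=2.0, scale=1.0, size=1000)  # synthetic right-skewed sample

# boxcox returns the transformed values and the fitted lambda
transformed, lam = boxcox(raw)

# The Box-Cox formula: (x**lam - 1) / lam for lam != 0, and log(x) for lam == 0
manual = (raw ** lam - 1) / lam if lam != 0 else np.log(raw)

print(f"lambda={lam:.3f}, skew before={skew(raw):.2f}, after={skew(transformed):.2f}")
```

Because lambda is fitted separately for every column, the transformed columns end up on very different scales (as the life_expec column above shows), which is why the data is standardized before clustering.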
norm_check(transformed_df)
"child_mort" skew: -0.0. The Variable is normally distributed
"exports" skew: 0.19. The Variable is normally distributed
"health" skew: -0.01. The Variable is normally distributed
"imports" skew: 0.27. The Variable is normally distributed
"income" skew: -0.04. The Variable is normally distributed
"inflation" skew: 0.34. The Variable is normally distributed
"life_expec" skew: -0.19. The Variable is normally distributed
"total_fer" skew: 0.09. The Variable is normally distributed
"gdpp" skew: 0.0. The Variable is normally distributed
transformed_df.corr().style.background_gradient()
  | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|
child_mort | 1.000000 | -0.418049 | -0.345278 | -0.159022 | -0.857661 | 0.429721 | -0.913018 | 0.861221 | -0.875941 |
exports | -0.418049 | 1.000000 | -0.076892 | 0.670786 | 0.528811 | -0.178139 | 0.339487 | -0.369051 | 0.482957 |
health | -0.345278 | -0.076892 | 1.000000 | 0.150612 | 0.135468 | -0.303367 | 0.273514 | -0.265663 | 0.251192 |
imports | -0.159022 | 0.670786 | 0.150612 | 1.000000 | 0.042087 | -0.322692 | 0.044210 | -0.160017 | 0.056133 |
income | -0.857661 | 0.528811 | 0.135468 | 0.042087 | 1.000000 | -0.253511 | 0.838362 | -0.742142 | 0.974390 |
inflation | 0.429721 | -0.178139 | -0.303367 | -0.322692 | -0.253511 | 1.000000 | -0.365269 | 0.391192 | -0.319505 |
life_expec | -0.913018 | 0.339487 | 0.273514 | 0.044210 | 0.838362 | -0.365269 | 1.000000 | -0.778634 | 0.849076 |
total_fer | 0.861221 | -0.369051 | -0.265663 | -0.160017 | -0.742142 | 0.391192 | -0.778634 | 1.000000 | -0.725514 |
gdpp | -0.875941 | 0.482957 | 0.251192 | 0.056133 | 0.974390 | -0.319505 | 0.849076 | -0.725514 | 1.000000 |
corr_matrix(transformed_df)
  | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|
child_mort | | | | | -0.857661 | | -0.913018 | 0.861221 | -0.875941 |
exports | | | | | | | | | |
health | | | | | | | | | |
imports | | | | | | | | | |
income | -0.857661 | | | | | | 0.838362 | -0.742142 | 0.97439 |
inflation | | | | | | | | | |
life_expec | -0.913018 | | | | 0.838362 | | | -0.778634 | 0.849076 |
total_fer | 0.861221 | | | | -0.742142 | | -0.778634 | | -0.725514 |
gdpp | -0.875941 | | | | 0.97439 | | 0.849076 | -0.725514 | |
print('Pairplot - Data BoxCox Transformed')
sns.pairplot(transformed_df)
plt.show()
Pairplot - Data BoxCox Transformed
transformed_df.describe()
  | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|
count | 167.000000 | 167.000000 | 167.000000 | 167.000000 | 167.000000 | 167.000000 | 1.670000e+02 | 167.000000 | 167.000000 |
mean | 3.097442 | 7.317915 | 2.626042 | 9.880608 | 14.901508 | 3.165718 | 2.149662e+06 | 0.736725 | 8.424489 |
std | 1.193050 | 2.324811 | 0.790563 | 2.744099 | 2.980422 | 1.275490 | 8.480345e+05 | 0.296682 | 1.474337 |
min | 0.959732 | -1.529486 | 0.659003 | -1.573824 | 8.922385 | -2.683594 | 1.070752e+05 | 0.135186 | 5.412460 |
25% | 2.129960 | 5.898570 | 2.129177 | 8.025118 | 12.402926 | 2.290620 | 1.499989e+06 | 0.510173 | 7.140556 |
50% | 3.000858 | 7.186307 | 2.584629 | 9.818504 | 14.946213 | 3.081479 | 2.281623e+06 | 0.717791 | 8.374756 |
75% | 4.208314 | 8.655259 | 3.202136 | 11.580214 | 17.071952 | 3.927933 | 2.741283e+06 | 0.997138 | 9.457630 |
max | 5.471019 | 15.840537 | 4.967345 | 20.211924 | 22.019941 | 9.159315 | 3.625682e+06 | 1.291821 | 11.427076 |
scaled_transformed_df = StandardScaler().fit_transform(transformed_df)
scaled_transformed_df
array([[ 1.2604782 , -1.61357374, 0.39964884, ..., -1.52693606, 1.53115139, -1.46229157], [-0.21131003, -0.38755809, 0.03412817, ..., 0.62196501, -0.98546004, -0.121097 ], [ 0.21888667, 0.08861708, -0.98534321, ..., 0.65290794, 0.32184544, -0.063171 ], ..., [ 0.08166812, 1.21279898, 0.14062292, ..., 0.15607676, -0.5584737 , -0.88355629], [ 0.84832929, -0.2880961 , -0.51562558, ..., -0.53585296, 1.19358826, -0.88355629], [ 1.18865967, 0.0297901 , -0.220261 , ..., -1.78161855, 1.42027079, -0.81087922]])
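For readers unfamiliar with StandardScaler, the sketch below shows on a toy array (values invented for illustration) that it simply subtracts each column's mean and divides by its population standard deviation, so every feature contributes on a comparable scale to the Euclidean distances used by K-Means.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy matrix: the second column is 100 times larger than the first.
toy = np.array([[1.0, 100.0],
                [2.0, 200.0],
                [3.0, 300.0],
                [4.0, 400.0]])

scaled = StandardScaler().fit_transform(toy)

# Same result by hand: z = (x - mean) / std, with the population std (ddof=0)
manual = (toy - toy.mean(axis=0)) / toy.std(axis=0)

print(scaled.mean(axis=0), scaled.std(axis=0))
```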
Model Development
# libraries
from sklearn.cluster import KMeans
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import DBSCAN
from scipy.cluster import hierarchy
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA
KMeans
def distortion_function(k, dataset, init='k-means++', random_state=98):
    # Fit KMeans with the specified number of clusters
    km = KMeans(init=init, n_clusters=k, random_state=random_state)
    km.fit(dataset)
    # Squared Euclidean distance of each data point from its assigned cluster center
    data = np.asarray(dataset)
    sq_dist = ((data - km.cluster_centers_[km.labels_]) ** 2).sum(axis=1)
    # Average within each cluster, then average those cluster means
    distortion = pd.Series(sq_dist).groupby(km.labels_).mean().mean()
    return distortion
# Initialize lists to store inertia, distortion, and silhouette scores
inertia = []
distortion = []
silhouette = []
# Iterate through different values of k (number of clusters)
for k in range(2, 11):
    # Fit KMeans with the specified number of clusters
    km = KMeans(init='k-means++', n_clusters=k, random_state=98)
    km.fit(scaled_transformed_df)
    # Append inertia, distortion, and silhouette scores to respective lists
    inertia.append(km.inertia_)
    distortion.append(distortion_function(k=k, dataset=scaled_transformed_df))
    silhouette.append(silhouette_score(scaled_transformed_df, km.labels_))
# The lists 'inertia', 'distortion', and 'silhouette' now contain the respective scores
#plt.figure(figsize=(8,5.5))
plt.plot(range(2,11), inertia)
plt.scatter(range(2,11),inertia, color='red')
plt.title('Number of Clusters vs Inertia')
plt.xlabel('Number of Clusters')
plt.ylabel('Inertia')
plt.xticks(ticks=range(2,11), labels=range(2,11))
plt.grid(True)
plt.annotate('Inflection Point: K=3',
xy=(3.05,745),
xytext=(7,850),
arrowprops = dict(arrowstyle='->', connectionstyle='arc3', color='orange', lw=2))
plt.annotate('Inflection Point: K=5',
xy=(5.05,600),
xytext=(9,790),
arrowprops = dict(arrowstyle='->', connectionstyle='arc3', color='orange', lw=2))
plt.show()
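The elbow plot is read by eye, and the silhouette scores collected in the loop above give a complementary, more mechanical way to compare values of k. The sketch below illustrates that idea on synthetic blobs; the data, the blob centers, and the variable names are assumptions for illustration and say nothing about the best k for the country data.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

# Synthetic data: four well-separated groups in two dimensions.
centers = [[0, 0], [10, 0], [0, 10], [10, 10]]
X, _ = make_blobs(n_samples=300, centers=centers, cluster_std=0.5, random_state=0)

scores = {}
for k in range(2, 8):
    labels = KMeans(n_clusters=k, n_init=10, random_state=0).fit_predict(X)
    # Silhouette ranges from -1 to 1; higher means tighter, better separated clusters
    scores[k] = silhouette_score(X, labels)

best_k = max(scores, key=scores.get)
print(scores, best_k)
```

On real data the silhouette maximum and the elbow do not always agree, so the final choice of k usually also weighs how interpretable the resulting groups are.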
km = KMeans(init='k-means++', n_clusters=4, random_state=98)
km.fit(scaled_transformed_df)
transformed_df['cluster'] = km.labels_
print('Pairplot - Number of K-Means clusters (using all data features): 4')
sns.pairplot(transformed_df, hue='cluster', palette={0:'blue', 1:'orange', 2:'green',3:'Black'})
plt.show()
Pairplot - Number of K-Means clusters (using all data features): 4
slct_ds = transformed_df.drop(['health', 'imports', 'inflation', 'exports', 'cluster'], axis=1)
slct_ds_scld = StandardScaler().fit_transform(slct_ds)
km = KMeans(init='k-means++', n_clusters=4, random_state=98)
km.fit(slct_ds_scld)
km_cluster = pd.DataFrame(slct_ds_scld, index=slct_ds.index)
km_cluster['cluster'] = km.labels_
km_cluster.tail()
country | 0 | 1 | 2 | 3 | 4 | cluster |
---|---|---|---|---|---|---|
Vanuatu | 0.277220 | -0.936138 | -0.989779 | 0.693909 | -0.335210 | 2 |
Venezuela | -0.185703 | 0.444152 | 0.485425 | -0.009781 | 0.676724 | 3 |
Vietnam | 0.081668 | -0.620507 | 0.156077 | -0.558474 | -0.883556 | 2 |
Yemen | 0.848329 | -0.622217 | -0.535853 | 1.193588 | -0.883556 | 2 |
Zambia | 1.188660 | -0.857680 | -1.781619 | 1.420271 | -0.810879 | 1 |
print('Pairplot - Number of Cluster: 4 (Selected Features only)')
sns.pairplot(km_cluster, hue='cluster', palette={0:'blue', 1:'orange', 2:'green', 3:'red'})
plt.show()
Pairplot - Number of Cluster: 4 (Selected Features only)
cluster_size = km_cluster.groupby(['cluster'], as_index=False)['cluster'].count()
cluster_size.rename(columns={'cluster':'size'}, inplace=True)
cluster_size['cluster'] = [0,1,2,3]
cluster_size
  | size | cluster |
---|---|---|
0 | 40 | 0 |
1 | 39 | 1 |
2 | 37 | 2 |
3 | 51 | 3 |
Extra:
PCA
# Reduce the dimension to 2 components to visualize clusters on a 2d plot
pca = PCA(n_components=2)
pca_data = pca.fit_transform(km_cluster.iloc[:,:-1])
pca_data = pd.DataFrame(pca_data, index=slct_ds.index, columns=['1st Principal Component', '2nd Principal Component'])
pca_data['cluster'] = km.labels_
pca_data.head()
country | 1st Principal Component | 2nd Principal Component | cluster |
---|---|---|---|
Afghanistan | -3.190472 | 0.132609 | 1 |
Albania | 0.739870 | -0.843691 | 3 |
Algeria | 0.133452 | 0.319914 | 3 |
Angola | -2.209307 | 1.244740 | 1 |
Antigua and Barbuda | 1.281350 | 0.136909 | 3 |
fig = plt.figure(figsize=(15,11))
clusters = fig.add_subplot(2,2,1)
sns.scatterplot(data=pca_data, x='1st Principal Component', y='2nd Principal Component', hue='cluster',
palette='tab10', s=75, ax=clusters)
clusters.set_title('Clusters - Using PCA for plotting points on 2 dimensions')
clusters.set_xlabel('1st Principal Component')
clusters.set_ylabel('2nd Principal Component')
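A two-component PCA plot is only a faithful picture of the clusters if those two components retain most of the variance. The following sketch, on synthetic data built from a single latent factor (an assumption made purely for illustration), shows how explained_variance_ratio_ can be inspected to check this.

```python
import numpy as np
from sklearn.decomposition import PCA

rng = np.random.default_rng(0)
latent = rng.normal(size=(200, 1))

# Five synthetic features driven mostly by one shared factor plus a little noise
X = latent @ np.ones((1, 5)) + 0.1 * rng.normal(size=(200, 5))

pca = PCA(n_components=2).fit(X)
ratio = pca.explained_variance_ratio_

# Share of total variance captured by each of the two components
print(ratio, ratio.sum())
```

The same attribute is available on the pca object fitted in the notebook, so the share of variance shown in the scatter plot can be reported alongside it.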
Cluster 0: countries with the highest values in Income, Life Expectancy, GDP per capita, and lowest values in Child Mortality and Total Fertility. We can label these countries as "Highly developed".
Cluster 3: these countries have lower Income, Life Expectancy, and GDP per capita than the Highly developed group but still sit above the dataset median. They also have higher Child Mortality and Total Fertility than cluster 0, although both remain below the dataset average. We can label these countries as "Upper-middle developed".
Cluster 2: countries with Income, Life Expectancy, GDP per capita mean values below the average, and Child Mortality and Total Fertility mean values above the average. We can label these countries as "Lower-middle developed".
Cluster 1: countries with, by far, the highest values in Child Mortality and Total Fertility as well as the smallest values in Income, Life Expectancy, and GDP per capita. We can label these countries as the "Least developed".
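K-Means assigns cluster IDs arbitrarily, so hard-coding which ID means which development level is fragile if the model is refit. One possible alternative, sketched here on a toy frame with invented values, is to order the clusters by their mean income and derive the labels from that order. Ranking by income alone is an assumption made for simplicity.

```python
import pandas as pd

# Toy stand-in for the clustered data (values are illustrative only)
toy = pd.DataFrame({
    "income": [40000, 1500, 7000, 15000, 42000, 1800],
    "cluster": [2, 0, 3, 1, 2, 0],
})

# Label names from poorest to richest, matching the notebook's wording
names = ["Least Developed", "Lower-Middle Developed",
         "Upper-middle developed", "Highly Developed"]

# Sort cluster IDs by mean income and pair them with the ordered names
order = toy.groupby("cluster")["income"].mean().sort_values().index
label_map = dict(zip(order, names))

toy["label"] = toy["cluster"].map(label_map)
print(label_map)
```

With this approach the mapping stays correct even if a different random_state shuffles the cluster numbers.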
countries_data['cluster'] = km.labels_
countries_data['label'] = None
countries_data.loc[countries_data['cluster'] == 0, 'label'] = 'Highly Developed'
countries_data.loc[countries_data['cluster'] == 1, 'label'] = 'Least Developed'
countries_data.loc[countries_data['cluster'] == 2, 'label'] = 'Lower-Middle Developed'
countries_data.loc[countries_data['cluster'] == 3, 'label'] = 'Upper-middle developed'
countries_data.head()
country | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp | cluster | label |
---|---|---|---|---|---|---|---|---|---|---|---|
Afghanistan | 90.2 | 10.0 | 7.58 | 44.9 | 1610 | 9.44 | 56.2 | 5.82 | 553 | 1 | Least Developed |
Albania | 16.6 | 28.0 | 6.55 | 48.6 | 9930 | 4.49 | 76.3 | 1.65 | 4090 | 3 | Upper-middle developed |
Algeria | 27.3 | 38.4 | 4.17 | 31.4 | 12900 | 16.10 | 76.5 | 2.89 | 4460 | 3 | Upper-middle developed |
Angola | 119.0 | 62.3 | 2.85 | 42.9 | 5900 | 22.40 | 60.1 | 6.16 | 3530 | 1 | Least Developed |
Antigua and Barbuda | 10.3 | 45.5 | 6.03 | 58.9 | 19100 | 1.44 | 76.8 | 2.13 | 12200 | 3 | Upper-middle developed |
countries_data.groupby('label')['cluster'].count()
label
Highly Developed          40
Least Developed           39
Lower-Middle Developed    37
Upper-middle developed    51
Name: cluster, dtype: int64
countries_data.groupby(['cluster','label'], as_index=True).mean()
cluster | label | child_mort | exports | health | imports | income | inflation | life_expec | total_fer | gdpp |
---|---|---|---|---|---|---|---|---|---|---|
0 | Highly Developed | 5.050000 | 58.730000 | 8.641500 | 51.977500 | 43277.500000 | 2.567900 | 79.692500 | 1.725750 | 39590.000000 |
1 | Least Developed | 98.094872 | 26.895128 | 6.400769 | 43.784615 | 2082.128205 | 11.166795 | 58.464103 | 5.196154 | 933.897436 |
2 | Lower-Middle Developed | 42.581081 | 37.929432 | 5.670270 | 47.504484 | 7065.675676 | 9.694081 | 67.772973 | 3.211892 | 3319.783784 |
3 | Upper-middle developed | 15.449020 | 40.464706 | 6.531961 | 44.829412 | 15479.019608 | 7.895373 | 74.654902 | 1.995882 | 8277.647059 |
def create_choropleth_map(df, as_svg=False):
    # World map colored by the cluster label assigned to each country
    fig = px.choropleth(df.reset_index(),
                        locationmode='country names',
                        locations='country',
                        color='label',
                        title='Countries',
                        hover_name='country',
                        labels={'label': 'Cluster Label'},
                        # Labels not listed here fall back to Plotly's default colors
                        color_discrete_map={
                            'Least Developed': 'lightcoral',
                            'Upper-middle developed': 'lightblue',
                        }
                        )
    fig.update_geos(fitbounds="locations", visible=True)
    fig.update_layout(height=600, width=1000)
    fig.show(fullscreen=True, renderer="svg" if as_svg else None)
    # fig.write_image("../Img/choropleth_map_of_all_countries.png")
# Call the function with the provided data
create_choropleth_map(countries_data)
create_choropleth_map(countries_data, as_svg=True)
print(f"Highly developed countries are {countries_data[countries_data['cluster']==0].index.values}")
Highly developed countries are ['Australia' 'Austria' 'Bahrain' 'Belgium' 'Brunei' 'Canada' 'Croatia' 'Cyprus' 'Czech Republic' 'Denmark' 'Estonia' 'Finland' 'France' 'Germany' 'Greece' 'Hungary' 'Iceland' 'Ireland' 'Israel' 'Italy' 'Japan' 'Kuwait' 'Luxembourg' 'Malta' 'Netherlands' 'New Zealand' 'Norway' 'Poland' 'Portugal' 'Qatar' 'Singapore' 'Slovak Republic' 'Slovenia' 'South Korea' 'Spain' 'Sweden' 'Switzerland' 'United Arab Emirates' 'United Kingdom' 'United States']
print(f"Least developed countries are {countries_data[countries_data['cluster']==1].index.values}")
Least developed countries are ['Afghanistan' 'Angola' 'Benin' 'Burkina Faso' 'Burundi' 'Cameroon' 'Central African Republic' 'Chad' 'Comoros' 'Congo, Dem. Rep.' 'Congo, Rep.' "Cote d'Ivoire" 'Eritrea' 'Gambia' 'Ghana' 'Guinea' 'Guinea-Bissau' 'Haiti' 'Kenya' 'Kiribati' 'Lesotho' 'Liberia' 'Madagascar' 'Malawi' 'Mali' 'Mauritania' 'Mozambique' 'Niger' 'Nigeria' 'Pakistan' 'Rwanda' 'Senegal' 'Sierra Leone' 'Solomon Islands' 'Sudan' 'Tanzania' 'Togo' 'Uganda' 'Zambia']
print(f"Lower-Middle developed countries are {countries_data[countries_data['cluster']==2].index.values}")
Lower-Middle developed countries are ['Bangladesh' 'Belize' 'Bhutan' 'Bolivia' 'Botswana' 'Cambodia' 'Cape Verde' 'Egypt' 'Equatorial Guinea' 'Fiji' 'Gabon' 'Guatemala' 'Guyana' 'India' 'Indonesia' 'Iraq' 'Jordan' 'Kyrgyz Republic' 'Lao' 'Micronesia, Fed. Sts.' 'Mongolia' 'Morocco' 'Myanmar' 'Namibia' 'Nepal' 'Paraguay' 'Philippines' 'Samoa' 'South Africa' 'Tajikistan' 'Timor-Leste' 'Tonga' 'Turkmenistan' 'Uzbekistan' 'Vanuatu' 'Vietnam' 'Yemen']
print(f"Upper-Middle Developed countries are {countries_data[countries_data['cluster']==3].index.values}")
Upper-Middle Developed countries are ['Albania' 'Algeria' 'Antigua and Barbuda' 'Argentina' 'Armenia' 'Azerbaijan' 'Bahamas' 'Barbados' 'Belarus' 'Bosnia and Herzegovina' 'Brazil' 'Bulgaria' 'Chile' 'China' 'Colombia' 'Costa Rica' 'Dominican Republic' 'Ecuador' 'El Salvador' 'Georgia' 'Grenada' 'Iran' 'Jamaica' 'Kazakhstan' 'Latvia' 'Lebanon' 'Libya' 'Lithuania' 'Macedonia, FYR' 'Malaysia' 'Maldives' 'Mauritius' 'Moldova' 'Montenegro' 'Oman' 'Panama' 'Peru' 'Romania' 'Russia' 'Saudi Arabia' 'Serbia' 'Seychelles' 'Sri Lanka' 'St. Vincent and the Grenadines' 'Suriname' 'Thailand' 'Tunisia' 'Turkey' 'Ukraine' 'Uruguay' 'Venezuela']
cluster_ds = countries_data.groupby(['cluster','label'], as_index=True).mean()
cluster_ds.T
cluster | 0 | 1 | 2 | 3 |
---|---|---|---|---|
label | Highly Developed | Least Developed | Lower-Middle Developed | Upper-middle developed |
child_mort | 5.05000 | 98.094872 | 42.581081 | 15.449020 |
exports | 58.73000 | 26.895128 | 37.929432 | 40.464706 |
health | 8.64150 | 6.400769 | 5.670270 | 6.531961 |
imports | 51.97750 | 43.784615 | 47.504484 | 44.829412 |
income | 43277.50000 | 2082.128205 | 7065.675676 | 15479.019608 |
inflation | 2.56790 | 11.166795 | 9.694081 | 7.895373 |
life_expec | 79.69250 | 58.464103 | 67.772973 | 74.654902 |
total_fer | 1.72575 | 5.196154 | 3.211892 | 1.995882 |
gdpp | 39590.00000 | 933.897436 | 3319.783784 | 8277.647059 |
country_ds_norm = pd.DataFrame(MinMaxScaler().fit_transform(countries_data[countries_data.columns[:-2]]),
columns=countries_data.columns[:-2])
country_ds_norm['cluster'] = km.labels_
group_ds = country_ds_norm.groupby(['cluster'], as_index=True).mean()
labels = cluster_ds.columns
cluster_0 = np.ravel(group_ds.loc[0].values)
cluster_1 = np.ravel(group_ds.loc[1].values)
cluster_2 = np.ravel(group_ds.loc[2].values)
cluster_3 = np.ravel(group_ds.loc[3].values)
x = np.arange(len(labels)) # the label locations
width = 0.2 # the width of the bars
fig, ax = plt.subplots(figsize=(16,8))
rects1 = ax.bar(x - width*1.5, cluster_2, width, label='Lower-Middle Developed', color='green')
rects2 = ax.bar(x - width/2, cluster_0, width, label='Highly Developed', color='steelblue')
rects3 = ax.bar(x + width/2, cluster_3, width, label='Upper-middle developed', color='red')
rects4 = ax.bar(x + width*1.5, cluster_1, width, label='Least Developed', color='orange')
# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel('Normalized values (scale from 0 to 1)', fontsize=12)
ax.set_xlabel('Data Features', fontsize=12)
ax.set_title('Overview Cluster Characteristics: Feature mean value per cluster', fontsize=14)
ax.set_xticks(x, labels, fontsize=12)
ax.legend()
fig.tight_layout()
plt.show()
K-Means clustered the 167 countries included in the dataset into the following categories:
- Cluster 0 - Highly developed: countries with the highest values in Income, Life Expectancy, GDP per capita, and lowest values in Child Mortality and Total Fertility.
- Cluster 3 - Upper-middle developed: Income, Life Expectancy, and GDP per capita mean values are smaller than the Highly developed countries' average, but above the dataset median. Child Mortality and Total Fertility mean values are greater than the cluster 0 mean values, but still below the dataset average.
- Cluster 2 - Lower-middle developed: countries with Income, Life Expectancy, GDP per capita mean values below the average, and Child Mortality and Total Fertility mean values above the average.
- Cluster 1 - Least developed: countries with, by far, the highest values in Child Mortality and Total Fertility as well as the smallest values in Income, Life Expectancy, and GDP per capita.
df = cluster_ds.reset_index()
# Define the attributes (variables)
attributes = ['child_mort', 'exports', 'health', 'imports', 'income', 'inflation', 'life_expec', 'total_fer', 'gdpp']
# Define colors for each cluster
colors = ['steelblue', 'orange', 'green', 'red']
scaler = StandardScaler()
df[attributes] = scaler.fit_transform(df[attributes])
# Iterate through each cluster and create a separate radar plot
for i, cluster_id in enumerate(df['cluster']):
    # Get the cluster label and values
    cluster_label = df.loc[i, 'label']
    cluster_values = df.loc[i, attributes].values
    # Create a new radar plot for the current cluster
    fig = go.Figure()
    # Add the trace for the current cluster
    fig.add_trace(go.Scatterpolar(r=cluster_values,
                                  theta=attributes,
                                  fill='toself',
                                  fillcolor=colors[i],
                                  opacity=.80,
                                  name=cluster_label))
    # Customize layout
    fig.update_layout(polar=dict(radialaxis=dict(visible=True)),
                      showlegend=True,
                      title=f"Radar Plot - Cluster {cluster_id} ({cluster_label})")
    # Show the plot
    fig.show()