Fuad Hassan

Global Insights: Country Clustering Analysis Using Unsupervised Machine Learning

Objective:

To categorise countries using the socio-economic and health factors that determine their overall development.

About the organization:

HELP International is an international humanitarian NGO committed to fighting poverty and providing the people of underdeveloped countries with basic amenities and relief during disasters and natural calamities.

Problem Statement:

HELP International has raised around $10 million. The CEO of the NGO now needs to decide how to use this money strategically and effectively, which means choosing the countries that are in the direst need of aid. Your job as a data scientist is to categorise the countries using socio-economic and health factors that determine their overall development, and then to suggest the countries the CEO should focus on most.

Data

The dataset used for this analysis is available on Kaggle and is titled "Unsupervised Learning on Country Data". The dataset contains information about various countries and includes the following columns:

  • country: Name of the country
  • child_mort: Child mortality rate (per 1000 live births)
  • exports: Exports of goods and services (% of GDP)
  • health: Total health spending (% of GDP)
  • imports: Imports of goods and services (% of GDP)
  • income: Net income per person
  • inflation: The measurement of the annual growth rate of the Total GDP
  • life_expec: Life expectancy at birth (in years)
  • total_fer: The number of children that would be born to each woman if the current age-fertility rates remain the same
  • gdpp: Gross Domestic Product per capita

The dataset can be accessed through the following link: Unsupervised Learning on Country Data.

In [1]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import plotly.express as px
from scipy.stats import boxcox
from sklearn.preprocessing import StandardScaler,MinMaxScaler
import plotly.graph_objects as go
import warnings
warnings.filterwarnings("ignore")
In [2]:
countries_data = pd.read_csv('https://raw.githubusercontent.com/fuadh246/Global-Insights-Clustering/main/Data/Country-data.csv')
countries_data.head(2)
Out[2]:
country child_mort exports health imports income inflation life_expec total_fer gdpp
0 Afghanistan 90.2 10.0 7.58 44.9 1610 9.44 56.2 5.82 553
1 Albania 16.6 28.0 6.55 48.6 9930 4.49 76.3 1.65 4090
In [3]:
countries_data = countries_data.set_index('country')
countries_data.head()
Out[3]:
child_mort exports health imports income inflation life_expec total_fer gdpp
country
Afghanistan 90.2 10.0 7.58 44.9 1610 9.44 56.2 5.82 553
Albania 16.6 28.0 6.55 48.6 9930 4.49 76.3 1.65 4090
Algeria 27.3 38.4 4.17 31.4 12900 16.10 76.5 2.89 4460
Angola 119.0 62.3 2.85 42.9 5900 22.40 60.1 6.16 3530
Antigua and Barbuda 10.3 45.5 6.03 58.9 19100 1.44 76.8 2.13 12200
In [4]:
Bangladesh_data = countries_data.loc['Bangladesh']
Bangladesh_data
Out[4]:
child_mort      49.40
exports         16.00
health           3.52
imports         21.80
income        2440.00
inflation        7.14
life_expec      70.40
total_fer        2.33
gdpp           758.00
Name: Bangladesh, dtype: float64

Learning about the data

In [5]:
countries_data.shape
Out[5]:
(167, 9)
In [6]:
countries_data.info()
<class 'pandas.core.frame.DataFrame'>
Index: 167 entries, Afghanistan to Zambia
Data columns (total 9 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   child_mort  167 non-null    float64
 1   exports     167 non-null    float64
 2   health      167 non-null    float64
 3   imports     167 non-null    float64
 4   income      167 non-null    int64  
 5   inflation   167 non-null    float64
 6   life_expec  167 non-null    float64
 7   total_fer   167 non-null    float64
 8   gdpp        167 non-null    int64  
dtypes: float64(7), int64(2)
memory usage: 17.1+ KB
In [7]:
countries_data.describe()
Out[7]:
child_mort exports health imports income inflation life_expec total_fer gdpp
count 167.000000 167.000000 167.000000 167.000000 167.000000 167.000000 167.000000 167.000000 167.000000
mean 38.270060 41.108976 6.815689 46.890215 17144.688623 7.781832 70.555689 2.947964 12964.155689
std 40.328931 27.412010 2.746837 24.209589 19278.067698 10.570704 8.893172 1.513848 18328.704809
min 2.600000 0.109000 1.810000 0.065900 609.000000 -4.210000 32.100000 1.150000 231.000000
25% 8.250000 23.800000 4.920000 30.200000 3355.000000 1.810000 65.300000 1.795000 1330.000000
50% 19.300000 35.000000 6.320000 43.300000 9960.000000 5.390000 73.100000 2.410000 4660.000000
75% 62.100000 51.350000 8.600000 58.750000 22800.000000 10.750000 76.800000 3.880000 14050.000000
max 208.000000 200.000000 17.900000 174.000000 125000.000000 104.000000 82.800000 7.490000 105000.000000

Data Cleaning

In [8]:
# Number of missing values per column
countries_data.isna().sum()
Out[8]:
child_mort    0
exports       0
health        0
imports       0
income        0
inflation     0
life_expec    0
total_fer     0
gdpp          0
dtype: int64
In [9]:
def find_outliers(df, col, q1, q3):
    iqr = q3 - q1
    lower_bound = q1 - 1.5 * iqr
    upper_bound = q3 + 1.5 * iqr
    outliers = df[(df[col] < lower_bound) | (df[col] > upper_bound)]
    return outliers

def box_plot_all(df):
    # Draw one box plot per numeric column and report its IQR-based outliers
    for col in df.select_dtypes(exclude='object').columns:
        sns.boxplot(data=df[col])

        # Calculate mean, Q1, and Q3
        mean_val = df[col].mean()
        q1 = np.percentile(df[col], 25)
        q3 = np.percentile(df[col], 75)

        # Plot mean, Q1, and Q3 levels
        plt.axhline(mean_val, color='r', linestyle='--', label='Mean')
        plt.axhline(q1, color='g', linestyle='--', label='Q1')
        plt.axhline(q3, color='b', linestyle='--', label='Q3')
        plt.title(f'Box Plot of {col}')
        plt.legend()
        plt.show()

        # Find and display percentage of outliers
        outliers = find_outliers(df, col, q1, q3)
        print('{} - % of outliers: {}%'.format(col, round(outliers.shape[0] / df.shape[0] * 100, 2)))
        print(outliers[col])
In [10]:
box_plot_all(countries_data)
[Figure: Box Plot of child_mort]
child_mort - % of outliers: 2.4%
country
Central African Republic    149.0
Chad                        150.0
Haiti                       208.0
Sierra Leone                160.0
Name: child_mort, dtype: float64
[Figure: Box Plot of exports]
exports - % of outliers: 2.99%
country
Ireland       103.0
Luxembourg    175.0
Malta         153.0
Seychelles     93.8
Singapore     200.0
Name: exports, dtype: float64
[Figure: Box Plot of health]
health - % of outliers: 1.2%
country
Micronesia, Fed. Sts.    14.2
United States            17.9
Name: health, dtype: float64
[Figure: Box Plot of imports]
imports - % of outliers: 2.4%
country
Luxembourg    142.0
Malta         154.0
Seychelles    108.0
Singapore     174.0
Name: imports, dtype: float64
[Figure: Box Plot of income]
income - % of outliers: 4.79%
country
Brunei                   80600
Kuwait                   75200
Luxembourg               91700
Norway                   62300
Qatar                   125000
Singapore                72100
Switzerland              55500
United Arab Emirates     57600
Name: income, dtype: int64
[Figure: Box Plot of inflation]
inflation - % of outliers: 2.99%
country
Equatorial Guinea     24.9
Mongolia              39.2
Nigeria              104.0
Timor-Leste           26.5
Venezuela             45.9
Name: inflation, dtype: float64
[Figure: Box Plot of life_expec]
life_expec - % of outliers: 1.8%
country
Central African Republic    47.5
Haiti                       32.1
Lesotho                     46.5
Name: life_expec, dtype: float64
[Figure: Box Plot of total_fer]
total_fer - % of outliers: 0.6%
country
Niger    7.49
Name: total_fer, dtype: float64
[Figure: Box Plot of gdpp]
gdpp - % of outliers: 14.97%
country
Australia                51900
Austria                  46900
Belgium                  44400
Brunei                   35300
Canada                   47400
Denmark                  58000
Finland                  46200
France                   40600
Germany                  41800
Iceland                  41900
Ireland                  48700
Italy                    35800
Japan                    44500
Kuwait                   38500
Luxembourg              105000
Netherlands              50300
New Zealand              33700
Norway                   87800
Qatar                    70300
Singapore                46600
Sweden                   52100
Switzerland              74600
United Arab Emirates     35000
United Kingdom           38900
United States            48400
Name: gdpp, dtype: int64

The flagged values appear to be genuine observations rather than data-entry errors, so we keep all outliers and do not modify them.
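
To make the 1.5 × IQR rule used by find_outliers concrete, here is a minimal sketch that recomputes the fences for child_mort from the quartiles reported by describe() (Q1 = 8.25, Q3 = 62.1). The helper name iqr_fences is hypothetical and not part of the notebook.

```python
def iqr_fences(q1, q3, k=1.5):
    """Return the (lower, upper) fences of the k * IQR outlier rule."""
    iqr = q3 - q1
    return q1 - k * iqr, q3 + k * iqr

# Quartiles of child_mort taken from the describe() output
lower, upper = iqr_fences(8.25, 62.1)
print(round(lower, 3), round(upper, 3))

# The four countries flagged for child_mort all sit above the upper fence
flagged = {'Central African Republic': 149.0, 'Chad': 150.0,
           'Haiti': 208.0, 'Sierra Leone': 160.0}
print(all(value > upper for value in flagged.values()))
```

Because the lower fence is negative and child mortality cannot be, only the upper fence can flag anything for this column.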

EDA

In [11]:
def create_choropleth_maps(df):
    for i in df.columns:
        fig = px.choropleth(df.reset_index(),
                            locationmode='country names',
                            locations='country',
                            title=i.upper(),
                            color=i,
                            color_continuous_scale="Blues",
                        )
        fig.update_geos(fitbounds="locations", visible=True)
        fig.show(renderer="svg")
        # fig.write_image(f"../Img/{i}.pdf")
create_choropleth_maps(countries_data)
[Figures: nine choropleth maps, one per column, each titled with the upper-cased column name]
In [12]:
sns.pairplot(countries_data)
Out[12]:
<seaborn.axisgrid.PairGrid at 0x7f92f6f7dd50>
[Figure: pairplot of all columns]

Correlation

In [13]:
countries_data.corr().style.background_gradient()
Out[13]:
child_mort exports health imports income inflation life_expec total_fer gdpp
child_mort 1.000000 -0.318093 -0.200402 -0.127211 -0.524315 0.288276 -0.886676 0.848478 -0.483032
exports -0.318093 1.000000 -0.114408 0.737381 0.516784 -0.107294 0.316313 -0.320011 0.418725
health -0.200402 -0.114408 1.000000 0.095717 0.129579 -0.255376 0.210692 -0.196674 0.345966
imports -0.127211 0.737381 0.095717 1.000000 0.122406 -0.246994 0.054391 -0.159048 0.115498
income -0.524315 0.516784 0.129579 0.122406 1.000000 -0.147756 0.611962 -0.501840 0.895571
inflation 0.288276 -0.107294 -0.255376 -0.246994 -0.147756 1.000000 -0.239705 0.316921 -0.221631
life_expec -0.886676 0.316313 0.210692 0.054391 0.611962 -0.239705 1.000000 -0.760875 0.600089
total_fer 0.848478 -0.320011 -0.196674 -0.159048 -0.501840 0.316921 -0.760875 1.000000 -0.454910
gdpp -0.483032 0.418725 0.345966 0.115498 0.895571 -0.221631 0.600089 -0.454910 1.000000
In [14]:
def norm_check(data, threshold=0.75):
    # Heuristic only: |skew| <= threshold means roughly symmetric,
    # which is not a formal test of normality
    skewness = data.skew()  # compute once instead of on every access
    for name, value in skewness.items():
        if abs(value) > threshold:
            print('"{}" skew: {}. The variable is NOT normally distributed.'.format(name, round(value, 2)))
        else:
            print('"{}" skew: {}. The Variable is normally distributed'.format(name, round(value, 2)))


norm_check(countries_data)
"child_mort" skew: 1.45. The variable is NOT normally distributed.
"exports" skew: 2.45. The variable is NOT normally distributed.
"health" skew: 0.71. The Variable is normally distributed
"imports" skew: 1.91. The variable is NOT normally distributed.
"income" skew: 2.23. The variable is NOT normally distributed.
"inflation" skew: 5.15. The variable is NOT normally distributed.
"life_expec" skew: -0.97. The variable is NOT normally distributed.
"total_fer" skew: 0.97. The variable is NOT normally distributed.
"gdpp" skew: 2.22. The variable is NOT normally distributed.
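
A skewness threshold is only a rough symmetry heuristic: a low skew does not by itself establish normality. As a complementary check, a formal test such as Shapiro-Wilk could be run per column. The sketch below uses synthetic data and a hypothetical helper name (shapiro_report). The 0.05 significance level is an assumption, not something the notebook specifies.

```python
import numpy as np
import pandas as pd
from scipy.stats import shapiro

def shapiro_report(df, alpha=0.05):
    """Run the Shapiro-Wilk test on every numeric column of df."""
    rows = []
    for name in df.select_dtypes(include=[np.number]).columns:
        stat, p_value = shapiro(df[name].dropna())
        rows.append({'column': name, 'W': stat, 'p_value': p_value,
                     'reject_normality': p_value < alpha})
    return pd.DataFrame(rows).set_index('column')

# Synthetic illustration: one Gaussian column and one heavily right-skewed column
rng = np.random.default_rng(0)
demo = pd.DataFrame({'gaussian': rng.normal(size=300),
                     'skewed': rng.exponential(size=300)})
report = shapiro_report(demo)
print(report)
```

With only 167 countries, both the heuristic and the test are best read as rough guides rather than verdicts.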
In [15]:
def corr_matrix(data):
    corr_matrix = data.corr()

    for var in corr_matrix.columns:
        corr_matrix[var] = corr_matrix[var].apply(lambda x:x if abs(x) > 0.70 and abs(x) < 1 else '')

    return corr_matrix

corr_matrix(countries_data)
Out[15]:
child_mort exports health imports income inflation life_expec total_fer gdpp
child_mort -0.886676 0.848478
exports 0.737381
health
imports 0.737381
income 0.895571
inflation
life_expec -0.886676 -0.760875
total_fer 0.848478 -0.760875
gdpp 0.895571

Observations and Explanations

  • Close to 1: Strong positive relationship.

  • Close to -1: Strong negative relationship.

  • Close to 0: Weak or no relationship.

  • child_mort (Child Mortality):

    • With income: Correlation of -0.524 means that higher income is linked to lower child mortality.
    • With life_expec: Correlation of -0.887 shows that higher child mortality is linked to lower life expectancy.
    • With total_fer: Correlation of 0.848 suggests that higher child mortality is linked to higher total fertility rates.
  • exports:

    • With imports: Correlation of 0.737 means higher exports are linked to higher imports.
    • With gdpp: Correlation of 0.419 means higher exports are moderately linked to higher GDP per capita.
  • health:

    • With gdpp: Correlation of 0.346 shows a positive relationship between health investment and GDP per capita.
  • income:

    • With gdpp: Correlation of 0.896 indicates a strong positive link between income and GDP per capita.
    • With life_expec: Correlation of 0.612 suggests a positive link between income and life expectancy.
  • inflation:

    • With income: Correlation of -0.148 shows a weak negative relationship between inflation and income.
    • With gdpp: Correlation of -0.222 shows a weak negative relationship between inflation and GDP per capita.
  • life_expec (Life Expectancy):

    • With gdpp: Correlation of 0.600 suggests that higher life expectancy is linked to higher GDP per capita.
  • total_fer (Total Fertility Rate):

    • With gdpp: Correlation of -0.455 suggests higher fertility rates are linked to lower GDP per capita.

Overall, the correlation matrix shows strong relationships between child mortality, life expectancy, and total fertility, and between income and GDP per capita. Higher income and GDP per capita are associated with better health outcomes such as lower child mortality and higher life expectancy, although correlation alone does not establish causation.
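
The filtered correlation table reports every strong pair twice, once in each direction. A minimal sketch of listing each pair only once, using the strict upper triangle of the matrix, is shown below. The helper name strong_pairs and the toy data are illustrative and not part of the notebook. The 0.70 threshold mirrors the one used in corr_matrix.

```python
import numpy as np
import pandas as pd

def strong_pairs(df, threshold=0.70):
    """List each pair of columns once when |Pearson r| exceeds threshold."""
    corr = df.corr()
    # Keep only the strict upper triangle so (a, b) and (b, a) are not both reported
    mask = np.triu(np.ones(corr.shape, dtype=bool), k=1)
    pairs = corr.where(mask).stack()
    pairs = pairs[pairs.abs() > threshold]
    return pairs.sort_values(key=lambda s: s.abs(), ascending=False)

# Toy data (not the country dataset): b tracks a, c moves against a, d is unrelated
toy = pd.DataFrame({
    'a': [1, 2, 3, 4, 5, 6],
    'b': [2, 4, 5, 8, 10, 13],
    'c': [9, 7, 6, 4, 3, 1],
    'd': [3, 1, 4, 1, 5, 2],
})
result = strong_pairs(toy)
print(result)
```

Applied to countries_data, the same helper would return a Series indexed by column pairs, which is easier to scan than a mostly empty matrix.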

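Transformation

Box-Cox Transformation

Most variables are strongly skewed, so we apply a Box-Cox transformation to each column with the boxcox_transform helper below. Box-Cox requires strictly positive input, so any column that contains zero or negative values is shifted first.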
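
For reference, the standard one-parameter Box-Cox transform of a strictly positive value x is given below. When no λ is supplied, scipy.stats.boxcox estimates it by maximizing the log-likelihood, so each column here gets its own fitted λ.

```latex
y^{(\lambda)} =
\begin{cases}
\dfrac{x^{\lambda} - 1}{\lambda}, & \lambda \neq 0 \\[6pt]
\ln x, & \lambda = 0
\end{cases}
\qquad x > 0
```

Because λ is fitted per column, the transformed columns end up on very different scales, which is why they are standardized before clustering.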

In [16]:
def boxcox_transform(df, constant=0.01):
    transformed_data = pd.DataFrame(index=df.index)  # Initialize DataFrame for transformed data with original index

    for column in df.select_dtypes(include=[np.number]):
        column_data = df[column].dropna()  # Drop missing values
        min_value = column_data.min()
        if min_value <= 0:
            # Box-Cox needs strictly positive input: shift so the minimum becomes `constant`
            column_data = column_data + abs(min_value) + constant
        transformed_values, _ = boxcox(column_data)  # lambda is fitted by maximum likelihood
        # Re-attach the index so any rows dropped as NaN stay NaN instead of misaligning
        transformed_data[column] = pd.Series(transformed_values, index=column_data.index)

    return transformed_data

transformed_df = boxcox_transform(countries_data)
In [17]:
transformed_df
Out[17]:
child_mort exports health imports income inflation life_expec total_fer gdpp
country
Afghanistan 4.596747 3.577908 2.941042 10.015629 10.833801 3.745020 8.586499e+05 1.189628 6.275043
Albania 2.846095 6.419617 2.652941 10.457084 14.938781 2.906548 2.675528e+06 0.445233 8.246487
Algeria 3.357802 7.523315 1.849401 8.205481 15.591688 4.571104 2.701690e+06 0.831924 8.331633
Angola 4.885951 9.476151 1.263567 9.768596 13.688497 5.183649 1.101861e+06 1.213716 8.101683
Antigua and Barbuda 2.357391 8.169779 2.496145 11.595983 16.602865 2.190407 2.741283e+06 0.634312 9.319889
... ... ... ... ... ... ... ... ... ...
Vanuatu 3.427188 8.264001 2.243473 10.925065 12.119789 2.495148 1.312811e+06 0.941978 7.931759
Venezuela 2.876553 6.478217 2.125639 5.824603 16.221298 6.799245 2.560084e+06 0.733832 9.419216
Vietnam 3.194584 10.128990 2.736880 13.642586 13.057682 4.104579 2.281623e+06 0.571532 7.125736
Yemen 4.106507 6.650154 2.219629 8.640211 13.052601 5.287811 1.696602e+06 1.089779 7.125736
Zambia 4.511321 7.386964 2.452434 8.130824 12.352925 4.335625 6.433180e+05 1.156830 7.232565

167 rows × 9 columns

In [18]:
norm_check(transformed_df)
"child_mort" skew: -0.0. The Variable is normally distributed
"exports" skew: 0.19. The Variable is normally distributed
"health" skew: -0.01. The Variable is normally distributed
"imports" skew: 0.27. The Variable is normally distributed
"income" skew: -0.04. The Variable is normally distributed
"inflation" skew: 0.34. The Variable is normally distributed
"life_expec" skew: -0.19. The Variable is normally distributed
"total_fer" skew: 0.09. The Variable is normally distributed
"gdpp" skew: 0.0. The Variable is normally distributed
In [19]:
transformed_df.corr().style.background_gradient()
Out[19]:
child_mort exports health imports income inflation life_expec total_fer gdpp
child_mort 1.000000 -0.418049 -0.345278 -0.159022 -0.857661 0.429721 -0.913018 0.861221 -0.875941
exports -0.418049 1.000000 -0.076892 0.670786 0.528811 -0.178139 0.339487 -0.369051 0.482957
health -0.345278 -0.076892 1.000000 0.150612 0.135468 -0.303367 0.273514 -0.265663 0.251192
imports -0.159022 0.670786 0.150612 1.000000 0.042087 -0.322692 0.044210 -0.160017 0.056133
income -0.857661 0.528811 0.135468 0.042087 1.000000 -0.253511 0.838362 -0.742142 0.974390
inflation 0.429721 -0.178139 -0.303367 -0.322692 -0.253511 1.000000 -0.365269 0.391192 -0.319505
life_expec -0.913018 0.339487 0.273514 0.044210 0.838362 -0.365269 1.000000 -0.778634 0.849076
total_fer 0.861221 -0.369051 -0.265663 -0.160017 -0.742142 0.391192 -0.778634 1.000000 -0.725514
gdpp -0.875941 0.482957 0.251192 0.056133 0.974390 -0.319505 0.849076 -0.725514 1.000000
In [20]:
# corr_matrix was defined earlier; reuse it on the Box-Cox transformed data
corr_matrix(transformed_df)
Out[20]:
child_mort exports health imports income inflation life_expec total_fer gdpp
child_mort -0.857661 -0.913018 0.861221 -0.875941
exports
health
imports
income -0.857661 0.838362 -0.742142 0.97439
inflation
life_expec -0.913018 0.838362 -0.778634 0.849076
total_fer 0.861221 -0.742142 -0.778634 -0.725514
gdpp -0.875941 0.97439 0.849076 -0.725514
In [21]:
print('Pairplot - Data BoxCox Transformed')
sns.pairplot(transformed_df)
plt.show()
Pairplot - Data BoxCox Transformed
[Figure: pairplot of the Box-Cox transformed data]
In [22]:
transformed_df.describe()
Out[22]:
child_mort exports health imports income inflation life_expec total_fer gdpp
count 167.000000 167.000000 167.000000 167.000000 167.000000 167.000000 1.670000e+02 167.000000 167.000000
mean 3.097442 7.317915 2.626042 9.880608 14.901508 3.165718 2.149662e+06 0.736725 8.424489
std 1.193050 2.324811 0.790563 2.744099 2.980422 1.275490 8.480345e+05 0.296682 1.474337
min 0.959732 -1.529486 0.659003 -1.573824 8.922385 -2.683594 1.070752e+05 0.135186 5.412460
25% 2.129960 5.898570 2.129177 8.025118 12.402926 2.290620 1.499989e+06 0.510173 7.140556
50% 3.000858 7.186307 2.584629 9.818504 14.946213 3.081479 2.281623e+06 0.717791 8.374756
75% 4.208314 8.655259 3.202136 11.580214 17.071952 3.927933 2.741283e+06 0.997138 9.457630
max 5.471019 15.840537 4.967345 20.211924 22.019941 9.159315 3.625682e+06 1.291821 11.427076
In [23]:
scaled_transformed_df = StandardScaler().fit_transform(transformed_df)
scaled_transformed_df
Out[23]:
array([[ 1.2604782 , -1.61357374,  0.39964884, ..., -1.52693606,
         1.53115139, -1.46229157],
       [-0.21131003, -0.38755809,  0.03412817, ...,  0.62196501,
        -0.98546004, -0.121097  ],
       [ 0.21888667,  0.08861708, -0.98534321, ...,  0.65290794,
         0.32184544, -0.063171  ],
       ...,
       [ 0.08166812,  1.21279898,  0.14062292, ...,  0.15607676,
        -0.5584737 , -0.88355629],
       [ 0.84832929, -0.2880961 , -0.51562558, ..., -0.53585296,
         1.19358826, -0.88355629],
       [ 1.18865967,  0.0297901 , -0.220261  , ..., -1.78161855,
         1.42027079, -0.81087922]])

Model Development

In [24]:
# libraries
from sklearn.cluster import KMeans
from sklearn.cluster import AgglomerativeClustering
from sklearn.cluster import DBSCAN
from scipy.cluster import hierarchy
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA

KMeans

In [25]:
def distortion_function(k, dataset, init='k-means++', random_state=98):
    # Fit KMeans with the specified number of clusters
    # (init and random_state are now honoured instead of being hard-coded)
    km = KMeans(init=init, n_clusters=k, random_state=random_state)
    km.fit(dataset)

    # Squared Euclidean distance of each data point from its assigned cluster center
    data = np.asarray(dataset)
    sq_dist = ((data - km.cluster_centers_[km.labels_]) ** 2).sum(axis=1)

    # Mean squared distance per cluster, then averaged across clusters
    distortion = pd.Series(sq_dist).groupby(km.labels_).mean().mean()

    return distortion

# Initialize lists to store inertia, distortion, and silhouette scores
inertia = []
distortion = []
silhouette = []

# Iterate through different values of k (number of clusters)
for k in range(2, 11):
    # Fit KMeans with the specified number of clusters
    km = KMeans(init='k-means++', n_clusters=k, random_state=98)
    km.fit(scaled_transformed_df)
    
    # Append inertia, distortion, and silhouette scores to respective lists
    inertia.append(km.inertia_)
    distortion.append(distortion_function(k=k, dataset=scaled_transformed_df))
    silhouette.append(silhouette_score(scaled_transformed_df, km.labels_))

# The lists 'inertia', 'distortion', and 'silhouette' now contain the respective scores
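
The loop above collects silhouette scores alongside inertia. One possible way to use them, sketched here on synthetic blobs rather than the country data, is to pick the k with the highest mean silhouette. The helper name best_k_by_silhouette and the blob centers are illustrative assumptions.

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs
from sklearn.metrics import silhouette_score

def best_k_by_silhouette(data, k_values, random_state=98):
    """Return (best_k, scores), where scores maps each k to its mean silhouette."""
    scores = {}
    for k in k_values:
        km = KMeans(init='k-means++', n_clusters=k, n_init=10, random_state=random_state)
        labels = km.fit_predict(data)
        scores[k] = silhouette_score(data, labels)
    best_k = max(scores, key=scores.get)
    return best_k, scores

# Synthetic data: three tight, well-separated blobs (not the country dataset)
X, _ = make_blobs(n_samples=300, centers=[(-6, -6), (0, 6), (6, -6)],
                  cluster_std=0.5, random_state=0)
best_k, scores = best_k_by_silhouette(X, range(2, 7))
print(best_k)
```

On real data the silhouette curve is usually flatter than on clean blobs, so it is best read together with the inertia elbow rather than on its own.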
In [26]:
#plt.figure(figsize=(8,5.5))
plt.plot(range(2,11), inertia)
plt.scatter(range(2,11),inertia, color='red')
plt.title('Number of Clusters vs Inertia')
plt.xlabel('Number of Clusters')
plt.ylabel('Inertia')
plt.xticks(ticks=range(2,11), labels=range(2,11))
plt.grid(True)
plt.annotate('Inflection Point: K=3',
             xy=(3.05,745),
             xytext=(7,850),
             arrowprops = dict(arrowstyle='->', connectionstyle='arc3', color='orange', lw=2))
plt.annotate('Inflection Point: K=5',
            xy=(5.05,600),
            xytext=(9,790),
            arrowprops = dict(arrowstyle='->', connectionstyle='arc3', color='orange', lw=2))
plt.show()
[Figure: number of clusters vs inertia, with K=3 and K=5 annotated]
In [27]:
km = KMeans(init='k-means++', n_clusters=4, random_state=98)
km.fit(scaled_transformed_df)
transformed_df['cluster'] = km.labels_
print('Pairplot - Number of K-Means clusters (using all data features): 4')
sns.pairplot(transformed_df, hue='cluster', palette={0:'blue', 1:'orange', 2:'green',3:'Black'})
plt.show()
Pairplot - Number of K-Means clusters (using all data features): 4
[Figure: pairplot of the transformed data, colored by the 4 K-Means clusters (all features)]
In [28]:
slct_ds = transformed_df.drop(['health', 'imports', 'inflation', 'exports', 'cluster'], axis=1)
slct_ds_scld = StandardScaler().fit_transform(slct_ds)
In [29]:
km = KMeans(init='k-means++', n_clusters=4, random_state=98)
km.fit(slct_ds_scld)
km_cluster = pd.DataFrame(slct_ds_scld, index=slct_ds.index)
km_cluster['cluster'] = km.labels_
km_cluster.tail()
Out[29]:
0 1 2 3 4 cluster
country
Vanuatu 0.277220 -0.936138 -0.989779 0.693909 -0.335210 2
Venezuela -0.185703 0.444152 0.485425 -0.009781 0.676724 3
Vietnam 0.081668 -0.620507 0.156077 -0.558474 -0.883556 2
Yemen 0.848329 -0.622217 -0.535853 1.193588 -0.883556 2
Zambia 1.188660 -0.857680 -1.781619 1.420271 -0.810879 1
In [30]:
print('Pairplot - Number of Cluster: 4 (Selected Features only)')
sns.pairplot(km_cluster, hue='cluster', palette={0:'blue', 1:'orange', 2:'green', 3:'red'})
plt.show()
Pairplot - Number of Cluster: 4 (Selected Features only)
[Figure: pairplot of the selected features, colored by the 4 K-Means clusters]
In [31]:
cluster_size = km_cluster.groupby(['cluster'], as_index=False)['cluster'].count()
cluster_size.rename(columns={'cluster':'size'}, inplace=True)
cluster_size['cluster'] = [0,1,2,3]
cluster_size
Out[31]:
size cluster
0 40 0
1 39 1
2 37 2
3 51 3

Extra:

PCA

In [32]:
# Reduce the dimension to 2 components to visualize clusters on a 2d plot
pca = PCA(n_components=2)
pca_data = pca.fit_transform(km_cluster.iloc[:,:-1])
pca_data = pd.DataFrame(pca_data, index=slct_ds.index, columns=['1st Principal Component', '2nd Principal Component'])
pca_data['cluster'] = km.labels_
pca_data.head()
Out[32]:
1st Principal Component 2nd Principal Component cluster
country
Afghanistan -3.190472 0.132609 1
Albania 0.739870 -0.843691 3
Algeria 0.133452 0.319914 3
Angola -2.209307 1.244740 1
Antigua and Barbuda 1.281350 0.136909 3
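
A 2D PCA projection is only a faithful picture of the clusters if the first two components retain most of the variance. A quick check, sketched here on synthetic correlated data (all variable names are hypothetical), is to inspect explained_variance_ratio_ after fitting.

```python
import numpy as np
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Synthetic stand-in for five correlated features driven by one latent factor
rng = np.random.default_rng(0)
latent = rng.normal(size=(200, 1))
features = np.hstack([latent + 0.3 * rng.normal(size=(200, 1)) for _ in range(5)])
scaled = StandardScaler().fit_transform(features)

pca = PCA(n_components=2)
projected = pca.fit_transform(scaled)

# Share of total variance captured by each component, and by both together
ratios = pca.explained_variance_ratio_
print(ratios, ratios.sum())
```

If the two ratios sum to well below 1 on the real data, clusters that overlap in the 2D plot may still be separated in the full feature space.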
In [33]:
fig = plt.figure(figsize=(15,11))
clusters = fig.add_subplot(2,2,1)
sns.scatterplot(data=pca_data, x='1st Principal Component', y='2nd Principal Component', hue='cluster',
                 palette='tab10', s=75, ax=clusters)
clusters.set_title('Clusters - Using PCA for plotting points on 2 dimensions')
clusters.set_xlabel('1st Principal Component')
clusters.set_ylabel('2nd Principal Component')
Out[33]:
Text(0, 0.5, '2nd Principal Component')
[Figure: scatter plot of the clusters on the first two principal components]

Cluster 0: countries with the highest values in Income, Life Expectancy, and GDP per capita, and the lowest values in Child Mortality and Total Fertility. We can label these countries as "Highly developed".

Cluster 3: these countries have lower Income, Life Expectancy, and GDP per capita than the Highly Developed group. Their mean Life Expectancy is still above the dataset mean, and their mean Income and GDP per capita are above the dataset median, though below the dataset mean. Child Mortality and Total Fertility are higher than in cluster 0 but below the dataset mean. We can label these countries as "Upper-middle developed".

Cluster 2: countries with mean Income, Life Expectancy, and GDP per capita below the dataset mean, and mean Child Mortality and Total Fertility above it. We can label these countries as "Lower-middle developed".

Cluster 1: countries with, by far, the highest values in Child Mortality and Total Fertility and the lowest values in Income, Life Expectancy, and GDP per capita. We can label these countries as the "Least developed".
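
K-Means assigns cluster ids arbitrarily, so a hard-coded id-to-label mapping can silently break if the seed, the feature set, or the library version changes. One hedged alternative is to derive the mapping from the cluster means, for example by ranking clusters on mean GDP per capita. The helper label_clusters_by and the toy values below are illustrative and are not the notebook's data.

```python
import pandas as pd

def label_clusters_by(df, cluster_col, rank_col, labels_low_to_high):
    """Map cluster ids to labels ordered by the cluster mean of rank_col."""
    ordered_ids = df.groupby(cluster_col)[rank_col].mean().sort_values().index.tolist()
    mapping = dict(zip(ordered_ids, labels_low_to_high))
    return df[cluster_col].map(mapping), mapping

# Toy example with made-up GDP per capita values and arbitrary cluster ids
toy = pd.DataFrame({
    'gdpp': [500, 900, 3000, 3500, 8000, 9000, 40000, 45000],
    'cluster': [1, 1, 2, 2, 3, 3, 0, 0],
})
labels = ['Least Developed', 'Lower-Middle Developed',
          'Upper-middle developed', 'Highly Developed']
toy['label'], mapping = label_clusters_by(toy, 'cluster', 'gdpp', labels)
print(mapping)
```

Ranking on a single column is a simplification. A composite of several indicators could be used instead if one variable is not considered representative.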

In [35]:
countries_data['cluster'] = km.labels_
countries_data['label'] = None

countries_data.loc[countries_data['cluster'] == 0, 'label'] = 'Highly Developed'
countries_data.loc[countries_data['cluster'] == 1, 'label'] = 'Least Developed'
countries_data.loc[countries_data['cluster'] == 2, 'label'] = 'Lower-Middle Developed'
countries_data.loc[countries_data['cluster'] == 3, 'label'] = 'Upper-middle developed'
countries_data.head()
Out[35]:
child_mort exports health imports income inflation life_expec total_fer gdpp cluster label
country
Afghanistan 90.2 10.0 7.58 44.9 1610 9.44 56.2 5.82 553 1 Least Developed
Albania 16.6 28.0 6.55 48.6 9930 4.49 76.3 1.65 4090 3 Upper-middle developed
Algeria 27.3 38.4 4.17 31.4 12900 16.10 76.5 2.89 4460 3 Upper-middle developed
Angola 119.0 62.3 2.85 42.9 5900 22.40 60.1 6.16 3530 1 Least Developed
Antigua and Barbuda 10.3 45.5 6.03 58.9 19100 1.44 76.8 2.13 12200 3 Upper-middle developed
In [36]:
countries_data.groupby('label')['cluster'].count()
Out[36]:
label
Highly Developed          40
Least Developed           39
Lower-Middle Developed    37
Upper-middle developed    51
Name: cluster, dtype: int64
In [37]:
countries_data.groupby(['cluster','label'], as_index=True).mean()
Out[37]:
child_mort exports health imports income inflation life_expec total_fer gdpp
cluster label
0 Highly Developed 5.050000 58.730000 8.641500 51.977500 43277.500000 2.567900 79.692500 1.725750 39590.000000
1 Least Developed 98.094872 26.895128 6.400769 43.784615 2082.128205 11.166795 58.464103 5.196154 933.897436
2 Lower-Middle Developed 42.581081 37.929432 5.670270 47.504484 7065.675676 9.694081 67.772973 3.211892 3319.783784
3 Upper-middle developed 15.449020 40.464706 6.531961 44.829412 15479.019608 7.895373 74.654902 1.995882 8277.647059
In [38]:
def create_choropleth_map(df, as_svg=False):
    fig = px.choropleth(df.reset_index(),
                        locationmode='country names',
                        locations='country',
                        color='label',
                        title='Countries',
                        hover_name='country',
                        labels={'label': 'Cluster Label'},
                        color_discrete_map={
                            'Least Developed': 'lightcoral',
                            'Upper-middle developed': 'lightblue',
                        }
                    )
    fig.update_geos(fitbounds="locations", visible=True)
    fig.update_layout(height=600, width=1000)
    fig.show(fullscreen=True, renderer="svg" if as_svg else None)
    # fig.write_image("../Img/choropleth_map_of_all_countries.png")

# Call the function with the provided data
create_choropleth_map(countries_data)
In [39]:
create_choropleth_map(countries_data, as_svg=True)
[Figure: choropleth map of countries colored by cluster label]
In [40]:
print(f"Highly developed countries are {countries_data[countries_data['cluster']==0].index.values}")
Highly developed countries are ['Australia' 'Austria' 'Bahrain' 'Belgium' 'Brunei' 'Canada' 'Croatia'
 'Cyprus' 'Czech Republic' 'Denmark' 'Estonia' 'Finland' 'France'
 'Germany' 'Greece' 'Hungary' 'Iceland' 'Ireland' 'Israel' 'Italy' 'Japan'
 'Kuwait' 'Luxembourg' 'Malta' 'Netherlands' 'New Zealand' 'Norway'
 'Poland' 'Portugal' 'Qatar' 'Singapore' 'Slovak Republic' 'Slovenia'
 'South Korea' 'Spain' 'Sweden' 'Switzerland' 'United Arab Emirates'
 'United Kingdom' 'United States']
In [41]:
print(f"Least developed countries are {countries_data[countries_data['cluster']==1].index.values}")
Least developed countries are ['Afghanistan' 'Angola' 'Benin' 'Burkina Faso' 'Burundi' 'Cameroon'
 'Central African Republic' 'Chad' 'Comoros' 'Congo, Dem. Rep.'
 'Congo, Rep.' "Cote d'Ivoire" 'Eritrea' 'Gambia' 'Ghana' 'Guinea'
 'Guinea-Bissau' 'Haiti' 'Kenya' 'Kiribati' 'Lesotho' 'Liberia'
 'Madagascar' 'Malawi' 'Mali' 'Mauritania' 'Mozambique' 'Niger' 'Nigeria'
 'Pakistan' 'Rwanda' 'Senegal' 'Sierra Leone' 'Solomon Islands' 'Sudan'
 'Tanzania' 'Togo' 'Uganda' 'Zambia']
In [42]:
print(f"Lower-Middle developed countries are {countries_data[countries_data['cluster']==2].index.values}")
Lower-Middle developed countries are ['Bangladesh' 'Belize' 'Bhutan' 'Bolivia' 'Botswana' 'Cambodia'
 'Cape Verde' 'Egypt' 'Equatorial Guinea' 'Fiji' 'Gabon' 'Guatemala'
 'Guyana' 'India' 'Indonesia' 'Iraq' 'Jordan' 'Kyrgyz Republic' 'Lao'
 'Micronesia, Fed. Sts.' 'Mongolia' 'Morocco' 'Myanmar' 'Namibia' 'Nepal'
 'Paraguay' 'Philippines' 'Samoa' 'South Africa' 'Tajikistan'
 'Timor-Leste' 'Tonga' 'Turkmenistan' 'Uzbekistan' 'Vanuatu' 'Vietnam'
 'Yemen']
In [43]:
print(f"Upper-Middle Developed countries are {countries_data[countries_data['cluster']==3].index.values}")
Upper-Middle Developed countries are ['Albania' 'Algeria' 'Antigua and Barbuda' 'Argentina' 'Armenia'
 'Azerbaijan' 'Bahamas' 'Barbados' 'Belarus' 'Bosnia and Herzegovina'
 'Brazil' 'Bulgaria' 'Chile' 'China' 'Colombia' 'Costa Rica'
 'Dominican Republic' 'Ecuador' 'El Salvador' 'Georgia' 'Grenada' 'Iran'
 'Jamaica' 'Kazakhstan' 'Latvia' 'Lebanon' 'Libya' 'Lithuania'
 'Macedonia, FYR' 'Malaysia' 'Maldives' 'Mauritius' 'Moldova' 'Montenegro'
 'Oman' 'Panama' 'Peru' 'Romania' 'Russia' 'Saudi Arabia' 'Serbia'
 'Seychelles' 'Sri Lanka' 'St. Vincent and the Grenadines' 'Suriname'
 'Thailand' 'Tunisia' 'Turkey' 'Ukraine' 'Uruguay' 'Venezuela']
In [44]:
cluster_ds = countries_data.groupby(['cluster','label'], as_index=True).mean()
cluster_ds.T
Out[44]:
cluster                    0                1                       2                       3
label       Highly Developed  Least Developed  Lower-Middle Developed  Upper-middle developed
child_mort           5.05000        98.094872               42.581081               15.449020
exports             58.73000        26.895128               37.929432               40.464706
health               8.64150         6.400769                5.670270                6.531961
imports             51.97750        43.784615               47.504484               44.829412
income           43277.50000      2082.128205             7065.675676            15479.019608
inflation            2.56790        11.166795                9.694081                7.895373
life_expec          79.69250        58.464103               67.772973               74.654902
total_fer            1.72575         5.196154                3.211892                1.995882
gdpp             39590.00000       933.897436             3319.783784             8277.647059
In [45]:
country_ds_norm = pd.DataFrame(MinMaxScaler().fit_transform(countries_data[countries_data.columns[:-2]]),
                               columns=countries_data.columns[:-2])
country_ds_norm['cluster'] = km.labels_

group_ds = country_ds_norm.groupby(['cluster'], as_index=True).mean()
labels = cluster_ds.columns
cluster_0 = np.ravel(group_ds.loc[0].values)
cluster_1 = np.ravel(group_ds.loc[1].values)
cluster_2 = np.ravel(group_ds.loc[2].values)
cluster_3 = np.ravel(group_ds.loc[3].values)

x = np.arange(len(labels))  # the label locations
width = 0.2  # the width of the bars

fig, ax = plt.subplots(figsize=(16,8))
rects1 = ax.bar(x - width*1.5, cluster_2, width, label='Lower-Middle Developed', color='green')
rects2 = ax.bar(x - width/2, cluster_0, width, label='Highly Developed', color='steelblue')
rects3 = ax.bar(x + width/2, cluster_3, width, label='Upper-middle developed', color='red')
rects4 = ax.bar(x + width*1.5, cluster_1, width, label='Least Developed', color='orange')

# Add some text for labels, title and custom x-axis tick labels, etc.
ax.set_ylabel('Normalized values (scale from 0 to 1)', fontsize=12)
ax.set_xlabel('Data Features', fontsize=12)
ax.set_title('Overview Cluster Characteristics: Feature mean value per cluster', fontsize=14)
ax.set_xticks(x, labels, fontsize=12)
ax.legend()

fig.tight_layout()
plt.show()

K-Means grouped the 167 countries in the dataset into the following four categories:

  1. Cluster 0 - Highly developed: countries with the highest mean Income, Life Expectancy, and GDP per capita, and the lowest mean Child Mortality and Total Fertility.
  2. Cluster 3 - Upper-middle developed: mean Income and GDP per capita are far below the Highly developed countries' averages and also somewhat below the dataset averages (about 15,500 vs. 17,100 for Income and 8,300 vs. 13,000 for GDP per capita), while mean Life Expectancy is above the dataset average. Mean Child Mortality and Total Fertility are higher than in cluster 0 but still below the dataset averages.
  3. Cluster 2 - Lower-middle developed: countries whose mean Income, Life Expectancy, and GDP per capita are below the dataset averages, and whose mean Child Mortality and Total Fertility are above them.
  4. Cluster 1 - Least developed: countries with, by far, the highest mean Child Mortality and Total Fertility and the lowest mean Income, Life Expectancy, and GDP per capita.
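The "above average" and "below average" comparisons in this list can be checked directly from the cluster means. The following is a minimal sketch, not part of the original notebook: it hard-codes the means shown in Out[44] and the cluster sizes counted from the country lists printed in cells 40 to 43 (40, 39, 37, and 51 countries). The names `cluster_means`, `cluster_sizes`, `dataset_mean`, and `ratio_to_average` are introduced here for illustration only.

```python
import pandas as pd

# Cluster means copied from Out[44] (a subset of features, for brevity).
cluster_means = pd.DataFrame(
    {
        "child_mort": [5.05, 98.094872, 42.581081, 15.449020],
        "income": [43277.5, 2082.128205, 7065.675676, 15479.019608],
        "life_expec": [79.6925, 58.464103, 67.772973, 74.654902],
        "total_fer": [1.72575, 5.196154, 3.211892, 1.995882],
        "gdpp": [39590.0, 933.897436, 3319.783784, 8277.647059],
    },
    index=[
        "Highly developed",
        "Least developed",
        "Lower-middle developed",
        "Upper-middle developed",
    ],
)

# Cluster sizes counted from the printed country lists (assumption: lists are complete).
cluster_sizes = pd.Series([40, 39, 37, 51], index=cluster_means.index)

# The size-weighted mean of the cluster means equals the mean over all countries.
dataset_mean = cluster_means.mul(cluster_sizes, axis=0).sum() / cluster_sizes.sum()

# A ratio above 1 means the cluster mean lies above the dataset average for that feature.
ratio_to_average = cluster_means.div(dataset_mean, axis=1)

print(dataset_mean.round(2))
print(ratio_to_average.round(2))
```

Reading the ratio table row by row gives a quick sanity check on each cluster description. In the notebook itself, the same figures could presumably be obtained from `countries_data` with a plain column mean, assuming it still holds the unscaled feature values.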
In [46]:
import plotly.graph_objects as go  # needed for go.Figure and go.Scatterpolar below

df = cluster_ds.reset_index()

# Define the attributes (variables)
attributes = ['child_mort', 'exports', 'health', 'imports', 'income', 'inflation', 'life_expec', 'total_fer', 'gdpp']

# Define colors for each cluster
colors = ['steelblue', 'orange', 'green', 'red']

# Standardise each attribute across the four cluster means so the radar axes share one scale
scaler = StandardScaler()
df[attributes] = scaler.fit_transform(df[attributes])

# Iterate through each cluster and create a separate radar plot
for i, cluster_id in enumerate(df['cluster']):
    # Get the cluster label and values
    cluster_label = df.loc[i, 'label']
    cluster_values = df.loc[i, attributes].values
    
    # Create a new radar plot for the current cluster
    fig = go.Figure()
    
    # Add the trace for the current cluster
    fig.add_trace(go.Scatterpolar(r=cluster_values,
                                  theta=attributes,
                                  fill='toself',
                                  fillcolor=colors[i],
                                  opacity=.80,
                                  name=cluster_label))
    
    # Customize layout
    fig.update_layout(polar=dict(radialaxis=dict(visible=True)),
                      showlegend=True,
                      title=f"Radar Plot - Cluster {cluster_id} ({cluster_label})")
    
    # Show the plot
    fig.show()