Source code for magmaviz.corrplot

import altair as alt
import pandas as pd

def corrplot(
    df: pd.DataFrame, print_corr: bool = True, shape: str = "circle"
) -> "alt.Chart":
    """Plot the Spearman correlation between the numeric columns of a dataframe.

    Every pair of numeric columns becomes one mark whose colour encodes the
    signed correlation and whose size encodes its absolute value. The
    self-correlation of each variable is shown as 0 so the diagonal does not
    dominate the plot. Optionally, each unordered pair of distinct variables
    is printed once together with its correlation.

    Parameters
    ----------
    df : DataFrame
        The dataframe containing the data to be plotted. Only the columns
        selected by ``select_dtypes('number')`` are used.
    print_corr : boolean
        Whether to print the correlation values (one row per pair of
        distinct variables) before returning the chart.
    shape : string
        Mark used to represent each correlation. Its possible values are
        "circle" and "square".

    Returns
    -------
    altair.Chart
        Correlation plot between the numerical variables of the dataframe.

    Raises
    ------
    TypeError
        If `df` is not a pandas DataFrame, `print_corr` is not a bool, or
        `shape` is not a str.
    ValueError
        If `shape` is neither "circle" nor "square".

    Examples
    --------
    >>> from magmaviz.corrplot import corrplot
    >>> corrplot(movies_data, print_corr=True, shape="square")  # doctest: +SKIP
    """
    # Validate the arguments up front so failures are reported before any work.
    if not isinstance(df, pd.DataFrame):
        raise TypeError(
            "'df' should be of type 'pandas.core.frame.DataFrame', a pandas dataframe."
        )
    if not isinstance(print_corr, bool):
        raise TypeError("'print_corr' should be of type 'boolean'.")
    if not isinstance(shape, str):
        raise TypeError("'shape' should be of type 'str'.")
    if shape not in ("circle", "square"):
        raise ValueError(
            "The value specified for 'shape' is not valid. "
            "It should be 'circle' or 'square'"
        )

    corr_matrix = df.select_dtypes("number").corr("spearman")

    # Reshape the square matrix into long format for altair. Naming both axes
    # explicitly (instead of renaming 'level_0'/'level_1' afterwards) keeps
    # this working when df.columns carries its own name. The explicit dropna()
    # removes undefined coefficients (e.g. constant columns) regardless of
    # whether the installed pandas stack() implementation drops NaN itself.
    corr_df = (
        corr_matrix
        .rename_axis(index="Variable1", columns="Variable2")
        .stack()
        .dropna()
        .reset_index(name="Correlation")
    )

    # Zero out only the diagonal (a variable compared with itself). Testing
    # for "Correlation == 1" would also erase genuinely perfect correlations
    # between two different variables.
    is_self_pair = corr_df["Variable1"] == corr_df["Variable2"]
    corr_df.loc[is_self_pair, "Correlation"] = 0

    # The absolute value drives the mark size.
    corr_df["abs"] = corr_df["Correlation"].abs()

    # Both shapes share the same encoding; only the mark differs.
    base = alt.Chart(corr_df)
    marked = base.mark_circle() if shape == "circle" else base.mark_square()
    chart = marked.encode(
        x="Variable1",
        y="Variable2",
        size="abs",
        color=alt.Color(
            "Correlation",
            scale=alt.Scale(scheme="purpleorange", domain=(-1, 1)),
        ),
    )

    if print_corr:
        # Keep each unordered pair once by selecting the rows strictly below
        # the diagonal of the correlation matrix. Comparing positions rather
        # than concatenated labels works for non-string column names and
        # cannot confuse pairs such as ("a", "bc") and ("ab", "c").
        position = {name: i for i, name in enumerate(corr_matrix.columns)}
        below_diagonal = (
            corr_df["Variable1"].map(position) > corr_df["Variable2"].map(position)
        )
        unique_pairs = corr_df.loc[
            below_diagonal, ["Variable1", "Variable2", "Correlation"]
        ].reset_index(drop=True)
        print(unique_pairs)

    return chart