import altair as alt
import pandas as pd
import re
from pandas.api.types import is_string_dtype
from pandas.api.types import is_numeric_dtype
def scatterplot(
    df,
    x,
    y,
    c="",
    t="",
    o=0.5,
    s=50,
    xtitle="",
    ytitle="",
    ctitle="",
    xzero=False,
    yzero=False,
    shapes=True,
):
    """Plot a scatterplot of two numeric columns with the magma color scheme.

    Parameters
    ----------
    df : pandas.DataFrame
        Dataframe containing the numerical features ``x`` and ``y``.
    x : str
        Column name of the numerical variable plotted on the x-axis.
    y : str
        Column name of the numerical variable plotted on the y-axis.
    c : str, optional
        Column name of the string (categorical) variable used to color-code
        the data points. Default is ``""``, meaning no color column.
    t : str, optional
        Title of the plot. If blank (default), it is built from the axis
        titles and, when ``c`` is given, the color-legend title.
    o : float, optional
        Opacity of the data points, in the range [0.1, 1.0]. Integer values
        such as ``1`` are also accepted (booleans are not). Default is 0.5.
    s : int, optional
        Size of the data points, in the range [1, 100]. Default is 50.
    xtitle : str, optional
        Title of the x-axis. If blank (default), it is derived from ``x`` by
        replacing the characters ``_ . , -`` with spaces. The displayed axis
        title always has ``str.capitalize`` applied.
    ytitle : str, optional
        Title of the y-axis; same rules as ``xtitle``, derived from ``y``.
    ctitle : str, optional
        Title of the color legend; same rules as ``xtitle``, derived from ``c``.
    xzero : bool, optional
        If True, the x-axis scale is forced to include 0. Default is False.
    yzero : bool, optional
        If True, the y-axis scale is forced to include 0. Default is False.
    shapes : bool, optional
        If True (default) and ``c`` is given, ``c`` is also mapped to the
        point shape.

    Returns
    -------
    altair.Chart
        Scatterplot of ``y`` against ``x``.

    Raises
    ------
    TypeError
        If an argument has the wrong type, or if ``o`` / ``s`` is outside its
        allowed range (kept as ``TypeError`` for backward compatibility).
    AssertionError
        If ``x``, ``y`` or ``c`` is not a column of ``df``, if ``x`` / ``y``
        is not numeric, or if ``c`` is not a string column.

    Example
    -------
    >>> from magmaviz.scatterplot import scatterplot
    >>> from vega_datasets import data
    >>> scatterplot(data.iris(), "sepalLength", "sepalWidth", "species",
    ...             "Iris Sepal Length vs Sepal Width across Species",
    ...             1.0, 50, "Sepal Length", "Sepal Width", "", False, False, True)
    """
    # --- Basic type checks (performed in parameter order) -------------------
    if not isinstance(df, pd.DataFrame):
        raise TypeError(
            "'df' should be of type 'pandas.core.frame.DataFrame', a pandas dataframe."
        )
    if not isinstance(x, str):
        raise TypeError(
            "Invalid value passed to 'x' axis: Assign column name as a 'string'."
        )
    if not isinstance(y, str):
        raise TypeError(
            "Invalid value passed to 'y' axis: Assign column name as a 'string'."
        )
    if not isinstance(c, str):
        raise TypeError(
            "Invalid value passed to 'color' variable: Assign column name as a 'string'."
        )
    if not isinstance(t, str):
        raise TypeError(
            "Invalid value passed to 't' variable: Assign title as a 'string'."
        )
    # Accept ints (e.g. o=1) as well as floats; bool is a subclass of int,
    # so it is excluded explicitly.
    if isinstance(o, bool) or not isinstance(o, (int, float)):
        raise TypeError(
            "Invalid value passed to 'o' variable: Assign opacity value as a decimal between 0 and 1."
        )
    if not isinstance(s, int):
        raise TypeError(
            "Invalid value passed to 's' variable: Assign size value as an integer between 1 and 100."
        )
    if not isinstance(xtitle, str):
        raise TypeError(
            "Invalid value passed to 'xtitle' variable: Assign x-axis title as a 'string'."
        )
    if not isinstance(ytitle, str):
        raise TypeError(
            "Invalid value passed to 'ytitle' variable: Assign y-axis title as a 'string'."
        )
    if not isinstance(ctitle, str):
        raise TypeError(
            "Invalid value passed to 'ctitle' variable: Assign color legend title as a 'string'."
        )
    if not isinstance(xzero, bool):
        raise TypeError(
            "Invalid value passed to 'xzero' variable: Assign boolean True to begin x axis from zero."
        )
    if not isinstance(yzero, bool):
        raise TypeError(
            "Invalid value passed to 'yzero' variable: Assign boolean True to begin y axis from zero."
        )
    if not isinstance(shapes, bool):
        raise TypeError(
            "Invalid value passed to 'shapes' variable: Assign boolean True to show different shapes for each color category."
        )

    # --- Value checks --------------------------------------------------------
    # These are raised explicitly rather than with ``assert`` so they are not
    # stripped under ``python -O``; AssertionError is kept because existing
    # callers may already catch it.
    if x not in df.columns:
        raise AssertionError(
            "The column specified for 'x' axis does not exist in the dataframe."
        )
    if y not in df.columns:
        raise AssertionError(
            "The column specified for 'y' axis does not exist in the dataframe."
        )
    if c != "" and c not in df.columns:
        raise AssertionError(
            "The column specified for 'color' does not exist in the dataframe."
        )
    if o < 0.1 or o > 1.0:
        raise TypeError("Opacity value must be in the range [0.1, 1.0]")
    if s < 1 or s > 100:
        raise TypeError("Size of data points must be in the range [1, 100]")
    if not is_numeric_dtype(df[x]):
        raise AssertionError("The column assigned to 'x' axis is not of type numeric.")
    if not is_numeric_dtype(df[y]):
        raise AssertionError("The column assigned to 'y' axis is not of type numeric.")
    if c != "" and not is_string_dtype(df[c]):
        raise AssertionError("The column assigned to 'color' is not of type string.")

    # --- Default titles: column names with separators turned into spaces ----
    if xtitle == "":
        xtitle = re.sub(r"[_.,-]", " ", x)
    if ytitle == "":
        ytitle = re.sub(r"[_.,-]", " ", y)
    if ctitle == "":
        ctitle = re.sub(r"[_.,-]", " ", c)
    if t == "":
        t = f"{xtitle.title()} vs {ytitle.title()}"
        if c != "":
            t += f" by {ctitle.title()}"

    # --- Build the chart -----------------------------------------------------
    encodings = [
        alt.X(x, title=xtitle.capitalize(), scale=alt.Scale(zero=xzero)),
        alt.Y(y, title=ytitle.capitalize(), scale=alt.Scale(zero=yzero)),
    ]
    mark_props = {"opacity": o, "size": s}
    extra_channels = {}
    if c == "":
        # No color column: draw every point in one fixed color.
        mark_props["color"] = "#721F81"
    else:
        encodings.append(
            alt.Color(c, title=ctitle.capitalize(), scale=alt.Scale(scheme="magma"))
        )
        if shapes:
            # Redundantly encode the category as the point shape.
            extra_channels["shape"] = c

    return (
        alt.Chart(data=df, title=t)
        .mark_point(**mark_props)
        .encode(*encodings, **extra_channels)
    )