Multi Level Sankey Diagram

Hi Team,

How can we create a multi-level Sankey diagram when the data has multiple columns? The plugin only seems to offer source and destination columns.

1 Reply

Hi @AkshayArora1,

For more complex Sankey diagrams, I would suggest using a Python notebook. You can then save the Sankey chart as an insight and add it to a dashboard or use it elsewhere. Here's an example based on the Sankey Python method outlined in this blog post. I'm attaching a Python notebook built from that example, which you can import into DSS to test this Python implementation.

To try it out, go to Notebooks > New notebook > Upload and upload the attached notebook (you'll need to replace the input with your own Sankey input dataset). The notebook uses the insights API to save the Sankey chart as an insight, which you can then add to dashboards just as you can with charts.

Here's an example, where my input dataset is "sankey_multi_test" and looks like this:

[Screenshot: the "sankey_multi_test" input dataset]
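
For readers who cannot see the screenshot, here is a minimal sketch of the shape such an input might take. The column names (lvl1 to lvl4 and count) come from the genSankey call later in this post; the row values are invented purely for illustration and are not the actual contents of "sankey_multi_test".

```python
import pandas as pd

# Hypothetical rows: the column names match the genSankey call in this post,
# but the values are made up for illustration only.
sample_df = pd.DataFrame(
    {
        "lvl1": ["A", "A", "B"],
        "lvl2": ["C", "D", "D"],
        "lvl3": ["E", "E", "F"],
        "lvl4": ["G", "H", "H"],
        "count": [10, 5, 7],
    }
)

# One row per path through the levels, with the flow size in "count".
print(sample_df)
```

Each row describes one path through the levels, and the numeric column gives the size of that flow.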

Using the code from the blog post, I created the following notebook code to load my dataset "sankey_multi_test" and write out the Sankey diagram as an insight. Here is the full code:


# -*- coding: utf-8 -*-
import dataiku
import pandas as pd, numpy as np
from dataiku import pandasutils as pdu
import plotly
from dataiku import insights

# Read recipe inputs
sankey = dataiku.Dataset("sankey_multi_test")
sankey_df = sankey.get_dataframe()

sankeychart_df = sankey_df 

# credit to 
def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
    # one color per level (category column); the palette holds 5 colors,
    # so at most 5 category columns are supported
    colorPalette = ['#2ab1ac','#e038e3','#FFE873','#70dd1d','#671ddd']
    labelList = []
    colorNumList = []
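    for catCol in cat_cols:
        labelListTemp =  list(set(df[catCol].values))
        # record how many labels this level contributes, used for coloring below
        colorNumList.append(len(labelListTemp))
        labelList = labelList + labelListTemp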
    # remove duplicates from labelList
    labelList = list(dict.fromkeys(labelList))
    # define colors based on number of levels
    colorList = []
    for idx, colorNum in enumerate(colorNumList):
        colorList = colorList + [colorPalette[idx]]*colorNum
    # transform df into a source-target pair
    for i in range(len(cat_cols)-1):
        if i==0:
            # first pair of adjacent levels starts the source-target table
            sourceTargetDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
            sourceTargetDf.columns = ['source','target','count']
        else:
            # each later pair of adjacent levels is appended to it
            tempDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
            tempDf.columns = ['source','target','count']
            sourceTargetDf = pd.concat([sourceTargetDf,tempDf])
        sourceTargetDf = sourceTargetDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()
    # add index for source-target pair
    sourceTargetDf['sourceID'] = sourceTargetDf['source'].apply(lambda x: labelList.index(x))
    sourceTargetDf['targetID'] = sourceTargetDf['target'].apply(lambda x: labelList.index(x))
    # creating the sankey diagram
    data = dict(
        # trace type must be set so plotly renders a Sankey diagram
        type = 'sankey',
        node = dict(
          pad = 15,
          thickness = 20,
          line = dict(
            color = "black",
            width = 0.5
          ),
          label = labelList,
          color = colorList
        ),
        link = dict(
          source = sourceTargetDf['sourceID'],
          target = sourceTargetDf['targetID'],
          value = sourceTargetDf['count']
        )
    )
    layout =  dict(
        title = title,
        font = dict(
          size = 10
        )
    )
    fig = dict(data=[data], layout=layout)
    return fig

fig = genSankey(sankeychart_df,cat_cols=['lvl1', 'lvl2', 'lvl3', 'lvl4'],value_cols='count',title='My Sankey test')

insights.save_plotly("my-sankey-chart", fig)
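
To make the reshaping step inside genSankey easier to follow, here is a minimal standalone sketch of the same idea: each adjacent pair of level columns is stacked into source/target rows, and duplicate links are summed. The helper name to_links and the demo data are hypothetical and are not part of the blog post or of DSS. The sketch assumes at least two category columns and runs with pandas alone, outside DSS.

```python
import pandas as pd


def to_links(df, cat_cols, value_col):
    """Stack each adjacent pair of category columns into source/target rows.

    Hypothetical helper for illustration; assumes len(cat_cols) >= 2.
    """
    pairs = []
    for left, right in zip(cat_cols[:-1], cat_cols[1:]):
        pair = df[[left, right, value_col]].copy()
        pair.columns = ["source", "target", "count"]
        pairs.append(pair)
    links = pd.concat(pairs, ignore_index=True)
    # Sum flows that share the same source and target.
    return links.groupby(["source", "target"], as_index=False)["count"].sum()


# Hypothetical three-level data, invented for illustration.
demo_df = pd.DataFrame(
    {
        "lvl1": ["A", "A", "B"],
        "lvl2": ["C", "C", "D"],
        "lvl3": ["E", "F", "F"],
        "count": [1, 2, 3],
    }
)

links_df = to_links(demo_df, ["lvl1", "lvl2", "lvl3"], "count")
print(links_df)
```

The result is a long table with one row per link (source, target, count), which is the form the Sankey trace consumes once labels are mapped to node indices.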


Now if I navigate to Insights, I see my Sankey test chart:

[Screenshot: the "My Sankey test" chart shown under Insights]

I hope that information is helpful.

Thank you,