Multi Level Sankey Diagram

AkshayArora1 Partner, Registered Posts: 11 Partner

Hi Team,

How can we create a multilevel Sankey diagram when we have multiple columns? In the plugin I only see source and destination columns.

Thanks

Akshay

Answers

  • Sarina
    Sarina Dataiker, Dataiku DSS Core Designer, Dataiku DSS Adv Designer, Registered Posts: 317 Dataiker
    edited July 2024

    Hi @AkshayArora1,

    For more complex Sankey diagrams, I would suggest using a Python notebook. You can then save your Sankey chart as an insight and add it to a dashboard or use it elsewhere. Below is an example based on the Sankey Python method outlined in this blog post. I am also attaching a Python notebook that you can import into DSS to test this implementation.

    You can try it out by going to Notebooks > New notebook > Upload and selecting the attached notebook (you'll need to replace the input with your own Sankey input dataset). The notebook uses the insights API to save the Sankey chart as an insight, which you can then add to dashboards, just as you can with charts.

    Here's an example, where my input dataset is "sankey_multi_test" and looks like this:

    Screen Shot 2021-12-13 at 3.42.48 PM.png
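In case the screenshot does not render, here is a rough sketch of the shape the input is assumed to have: one column per level plus a numeric value column. Only the column names (`lvl1` to `lvl4` and `count`) come from this post; the row values are made up for illustration.

```python
import pandas as pd

# Hypothetical rows; only the column names (lvl1..lvl4, count) come from the post.
sample_df = pd.DataFrame(
    {
        "lvl1": ["A", "A", "B"],
        "lvl2": ["C", "D", "D"],
        "lvl3": ["E", "E", "F"],
        "lvl4": ["G", "H", "H"],
        "count": [10, 5, 7],
    }
)
print(sample_df)
```

Each row describes one path through the levels, and `count` is the weight carried along that path.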

    Using the code from the blog post: https://medium.com/kenlok/how-to-create-sankey-diagrams-from-dataframes-in-python-e221c1b4d6b0, I created the following notebook code to load my dataset "sankey_multi_test" and write out an insight of the Sankey diagram. Here is the full code:

    # -*- coding: utf-8 -*-
    import dataiku
    import pandas as pd, numpy as np
    from dataiku import pandasutils as pdu
    import plotly
    from dataiku import insights
    
    # Read recipe inputs
    sankey = dataiku.Dataset("sankey_multi_test")
    sankey_df = sankey.get_dataframe()
    
    sankeychart_df = sankey_df 
    
    # credit to https://medium.com/kenlok/how-to-create-sankey-diagrams-from-dataframes-in-python-e221c1b4d6b0 
    def genSankey(df,cat_cols=[],value_cols='',title='Sankey Diagram'):
        # one color per category column: the palette has 5 colors, so at most 5 category columns
        colorPalette = ['#2ab1ac','#e038e3','#FFE873','#70dd1d','#671ddd']
        labelList = []
        colorNumList = []
        for catCol in cat_cols:
            labelListTemp =  list(set(df[catCol].values))
            colorNumList.append(len(labelListTemp))
            labelList = labelList + labelListTemp
            
        # remove duplicates from labelList
        labelList = list(dict.fromkeys(labelList))
        
        # define colors based on number of levels
        colorList = []
        for idx, colorNum in enumerate(colorNumList):
            colorList = colorList + [colorPalette[idx]]*colorNum
            
        # transform df into a source-target pair
        for i in range(len(cat_cols)-1):
            if i==0:
                sourceTargetDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
                sourceTargetDf.columns = ['source','target','count']
            else:
                tempDf = df[[cat_cols[i],cat_cols[i+1],value_cols]]
                tempDf.columns = ['source','target','count']
                sourceTargetDf = pd.concat([sourceTargetDf,tempDf])
            sourceTargetDf = sourceTargetDf.groupby(['source','target']).agg({'count':'sum'}).reset_index()
            
        # add index for source-target pair
        sourceTargetDf['sourceID'] = sourceTargetDf['source'].apply(lambda x: labelList.index(x))
        sourceTargetDf['targetID'] = sourceTargetDf['target'].apply(lambda x: labelList.index(x))
        
        # creating the sankey diagram
        data = dict(
            type='sankey',
            node = dict(
              pad = 15,
              thickness = 20,
              line = dict(
                color = "black",
                width = 0.5
              ),
              label = labelList,
              color = colorList
            ),
            link = dict(
              source = sourceTargetDf['sourceID'],
              target = sourceTargetDf['targetID'],
              value = sourceTargetDf['count']
            )
          )
        
        layout =  dict(
            title = title,
            font = dict(
              size = 10
            )
        )
           
        fig = dict(data=[data], layout=layout)
        return fig
    
    fig = genSankey(sankeychart_df,cat_cols=['lvl1', 'lvl2', 'lvl3', 'lvl4'],value_cols='count',title='My Sankey test')
    
    insights.save_plotly("my-sankey-chart", fig)
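To check the multi-level reshaping outside DSS, the step that turns level columns into source/target pairs can be sketched on its own with pandas. This is a minimal sketch, not the notebook's exact code: `to_links` and the `demo` data are hypothetical names and values introduced here for illustration.

```python
import pandas as pd


def to_links(df, cat_cols, value_col):
    """Stack each adjacent pair of level columns into source/target/count rows."""
    pairs = []
    for left, right in zip(cat_cols[:-1], cat_cols[1:]):
        pair = df[[left, right, value_col]].copy()
        pair.columns = ["source", "target", "count"]
        pairs.append(pair)
    links = pd.concat(pairs, ignore_index=True)
    # Sum the values of identical source/target pairs into a single link
    return links.groupby(["source", "target"], as_index=False)["count"].sum()


# Hypothetical three-level input
demo = pd.DataFrame(
    {
        "lvl1": ["A", "A", "B"],
        "lvl2": ["C", "D", "D"],
        "lvl3": ["E", "E", "F"],
        "count": [10, 5, 7],
    }
)
links = to_links(demo, ["lvl1", "lvl2", "lvl3"], "count")
print(links)
```

Each adjacent pair of levels contributes its own set of links, which is what lets a plain source/target structure represent more than two levels. Note that `genSankey` looks nodes up by label, so a value that appears in two different level columns is drawn as a single node.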


    Now if I navigate to Insights, I see my sankey test chart:

    Screen Shot 2021-12-13 at 3.35.07 PM.png

    I hope that information is helpful.

    Thank you,
    Sarina 
