Split dataset by stratified sampling.

Thomas_K Registered Posts: 15 ✭✭✭✭

When I try to "split" a dataset randomly, I currently get the following options:

- Full random

- Random subset

Neither of those is what I usually use to split data into training and test sets: stratified sampling, which ensures that classes with very low presence (e.g. only a few dozen rows out of 10,000) appear in both sets. Did I overlook something, or is this not currently implemented?

Best Answer

  • Clément_Stenac
    Clément_Stenac Dataiker, Dataiku DSS Core Designer, Registered Posts: 753 Dataiker
    Answer ✓
    Hi,

    This feature does not exist yet. It is in our backlog, but we don't have a target date for it.

    This can be done in a Python recipe with a bit of help from pandas and scikit-learn.
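As a minimal sketch of what such a recipe could compute, the snippet below builds a small synthetic DataFrame (the column names, the `rare`/`common` labels, and the fixed `random_state` are assumptions made for illustration, not part of the original answer) and splits it with scikit-learn's `train_test_split`, using `stratify` so that the rare class lands in both sets.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the real dataset (hypothetical columns):
# 10,000 rows, of which only 30 belong to the rare class.
df = pd.DataFrame({
    "feature": range(10_000),
    "label": ["rare"] * 30 + ["common"] * 9_970,
})

# stratify=df["label"] keeps the class proportions the same in both parts.
df_train, df_test = train_test_split(
    df, test_size=0.2, stratify=df["label"], random_state=42
)

# Count how many rare rows ended up on each side of the split.
train_rare = int((df_train["label"] == "rare").sum())
test_rare = int((df_test["label"] == "rare").sum())
print(f"rare rows: train={train_rare}, test={test_rare}")
```

In a DSS Python recipe, the DataFrame would typically be read with `dataiku.Dataset(...).get_dataframe()` and each part written back with `write_with_schema`; the dataset names depend on your own flow.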

Answers

  • theoplatt
    theoplatt Registered Posts: 5 ✭✭✭✭
    Thanks, I implemented it in a Python recipe as you suggested. For others reading this later, it's a one-liner:

    df_train, df_test = train_test_split(df, test_size=0.2, stratify=df['label'])

    But I'd love to see this added as a feature!
  • MarkPundurs
    MarkPundurs Dataiku DSS Core Designer, Dataiku DSS ML Practitioner, Registered Posts: 26 ✭✭✭✭

    Well, a two-liner, the other line being:

    from sklearn.model_selection import train_test_split

  • MarkPundurs
    MarkPundurs Dataiku DSS Core Designer, Dataiku DSS ML Practitioner, Registered Posts: 26 ✭✭✭✭

    And if you have values of 'label' that appear in only a single record, and you want to make sure those records go to the training set, you need a few more lines:

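    import numpy as np
    import pandas as pd
    from sklearn.model_selection import train_test_split

    # Count how many records carry each label value
    values, counts = np.unique(df['label'], return_counts=True)
    valseq1 = values[counts == 1]  # labels that appear exactly once
    valsgt1 = values[counts > 1]   # labels that appear more than once
    counteq1_df = df[df['label'].isin(valseq1)]
    countgt1_df = df[df['label'].isin(valsgt1)]
    # Stratify only the labels that have at least two records
    df_train, df_test = train_test_split(countgt1_df, test_size=0.2, stratify=countgt1_df['label'])
    # Put the single-record labels in the training set
    df_train = pd.concat([df_train, counteq1_df], axis=0)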
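The steps above can also be wrapped in a small helper. This is only a sketch: the function name `stratified_split_keep_singletons`, its parameters, and the toy data are hypothetical, and it uses pandas `value_counts` in place of `np.unique`. The special handling matters because `train_test_split` raises a `ValueError` when `stratify` is given a class with only one member.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def stratified_split_keep_singletons(df, label_col, test_size=0.2, random_state=None):
    """Stratified split that sends single-record labels to the training set.

    Hypothetical helper for illustration, not a Dataiku or scikit-learn API.
    """
    # Number of records per label value, mapped back onto every row.
    counts = df[label_col].value_counts()
    is_singleton = df[label_col].map(counts) == 1

    singletons = df[is_singleton]
    rest = df[~is_singleton]

    # Stratify only the labels that have at least two records.
    train, test = train_test_split(
        rest, test_size=test_size, stratify=rest[label_col], random_state=random_state
    )

    # Single-record labels cannot be stratified, so they go to training.
    train = pd.concat([train, singletons], axis=0)
    return train, test


# Toy data (assumed): two labels with 10 records each, one label with 1 record.
df = pd.DataFrame({"x": range(21), "label": ["a"] * 10 + ["b"] * 10 + ["c"]})
df_train, df_test = stratified_split_keep_singletons(df, "label", random_state=0)
print(df_train["label"].value_counts().to_dict())
print(df_test["label"].value_counts().to_dict())
```

Note that a label with only two or three records can still be left out of the test set when `test_size` is small, so it may be worth checking the per-label counts of both parts after the split.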
