Problem with passing column labels in a plugin custom prediction algorithm component

Antal Dataiku DSS Core Designer, Dataiku DSS ML Practitioner, Neuron, Dataiku DSS Adv Designer, Registered, Neuron 2023 Posts: 85 Neuron

I need the column labels of the preprocessed dataset in the fit function of a custom prediction algorithm that I'm building into a plugin. The classifier I call in the fit function only accepts a pd.DataFrame, so I have to convert the numpy array produced by the preprocessing steps back into a DataFrame with the correct column labels.

I found these pieces of documentation, which suggest that column labels can be passed down to model functions through a set_column_labels() method.

Component: Prediction algorithm — Dataiku DSS 12 documentation

In-memory Python — Dataiku DSS 12 documentation

The documentation suggests that if the model defines this method, DSS will call it automatically.

However, I can't get this to work.

I've implemented it as follows in a custom prediction algorithm plugin component, but column_labels is still None when fit is called:

Has anyone tried this or can give a suggestion on how to get this functionality to work?

from dataiku.doctor.plugins.custom_prediction_algorithm import BaseCustomPredictionAlgorithm
from some_package import SomeClassifier  # placeholder for the actual DataFrame-only classifier
import pandas as pd


class CustomPredictionAlgorithm(BaseCustomPredictionAlgorithm):
    def __init__(self, prediction_type=None, params=None, column_labels=None):
        self.params = params
        self.column_labels = column_labels
        self.clf = SomeClassifier()
        super(CustomPredictionAlgorithm, self).__init__(prediction_type, params)

    def get_clf(self):
        return self.clf

    def set_column_labels(self, column_labels):
        self.column_labels = column_labels

    def fit(self, X, y):
        # Rebuild a labeled DataFrame from the preprocessed numpy array
        X_pd = pd.DataFrame(X, columns=self.column_labels)
        return super(CustomPredictionAlgorithm, self).fit(X_pd, y)

    def predict(self, X):
        X_pd = pd.DataFrame(X, columns=self.column_labels)
        return super(CustomPredictionAlgorithm, self).predict(X_pd)

I've confirmed the column_labels attribute stays empty (None) by printing it out along the way and checking the logs.
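For context, the symptom is silent rather than an error because pandas accepts columns=None and falls back to integer labels. A minimal illustration (the array values and the column names "age" and "income" are made up):

```python
import numpy as np
import pandas as pd

# Simulated output of the preprocessing step: a plain numpy array, no labels
X = np.array([[1.0, 2.0], [3.0, 4.0]])

# What fit() effectively builds while self.column_labels is still None
X_unlabeled = pd.DataFrame(X, columns=None)
print(list(X_unlabeled.columns))  # [0, 1]

# What it builds once labels have been set (hypothetical feature names)
X_labeled = pd.DataFrame(X, columns=["age", "income"])
print(list(X_labeled.columns))  # ['age', 'income']
```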

Operating system used: AWS Linux

Best Answer

  • Gaspard
    Gaspard Dataiker, Dataiku DSS Core Designer, Dataiku DSS ML Practitioner Posts: 4 Dataiker


    The problem is that the plugin's "CustomPredictionAlgorithm" class and the custom model's class should be two separate classes.

    For instance, "CustomPredictionAlgorithm" should not implement "fit", "predict", or "set_column_labels".
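The toy sketch below illustrates the point above under one stated assumption: that DSS works only with the estimator returned by get_clf() and calls set_column_labels on that object if it has one. Wrapper, PlainModel, LabelAwareModel and simulate_training are hypothetical names, not Dataiku APIs, and DSS's real internals may differ.

```python
class Wrapper:
    """Hypothetical stand-in for the plugin's CustomPredictionAlgorithm."""

    def __init__(self, clf):
        self.clf = clf
        self.column_labels = None

    def get_clf(self):
        return self.clf

    def set_column_labels(self, column_labels):
        # Defined on the wrapper, as in the question; the caller never uses it
        self.column_labels = column_labels


class PlainModel:
    """Hypothetical model without a set_column_labels hook (like SomeClassifier)."""

    def fit(self, X, y):
        return self


class LabelAwareModel(PlainModel):
    """Hypothetical model that does expose the hook."""

    def set_column_labels(self, column_labels):
        self.column_labels = column_labels


def simulate_training(algo, X, y, column_labels):
    """Assumed behaviour: only the estimator returned by get_clf() is used."""
    clf = algo.get_clf()
    if hasattr(clf, "set_column_labels"):
        clf.set_column_labels(column_labels)
    clf.fit(X, y)
    return clf


# Hook on the wrapper only: the labels never arrive
algo = Wrapper(PlainModel())
simulate_training(algo, [[1, 2]], [0], ["a", "b"])
print(algo.column_labels)  # None

# Hook on the model returned by get_clf(): the labels arrive
algo2 = Wrapper(LabelAwareModel())
clf2 = simulate_training(algo2, [[1, 2]], [0], ["a", "b"])
print(clf2.column_labels)  # ['a', 'b']
```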

    So, you must have two different classes:

    • The first one will be called something like "MyCustomClassifier". It will implement "fit", "predict", "set_column_labels", etc. This class must inherit from "sklearn.base.BaseEstimator" and follow sklearn's guidelines.
    • The second one is "CustomPredictionAlgorithm". It will wrap "MyCustomClassifier" to make it work with the plugin (i.e., it should follow the format described here, and "self.clf" should be an instance of "MyCustomClassifier").

    So, if I copy/paste the examples given in the doc, the file would look like this:

    import numpy as np
    from sklearn.base import BaseEstimator
    from dataiku.doctor.plugins.custom_prediction_algorithm import BaseCustomPredictionAlgorithm


    class MyCustomRegressor(BaseEstimator):
        """This model predicts random values between the minimum and the maximum of y"""

        def fit(self, X, y):
            self.y_range = [np.min(y), np.max(y)]
            return self  # sklearn convention: fit returns the fitted estimator

        def predict(self, X):
            return np.random.uniform(self.y_range[0], self.y_range[1], size=X.shape[0])


    class CustomPredictionAlgorithm(BaseCustomPredictionAlgorithm):
        def __init__(self, prediction_type=None, params=None):
            self.clf = MyCustomRegressor()
            super(CustomPredictionAlgorithm, self).__init__(prediction_type, params)

        def get_clf(self):
            return self.clf

    Please tell me if that helps, or if you still have some questions.


  • Antal

    Thanks for that!

    Yes, I started out with everything split over two classes, but that didn't work, so I switched.

    But I must have implemented something wrong the first time, because I've redone the code with split classes and now it works!

    Thanks for pointing me in the right direction.
