Save and load a PyTorch model to a managed folder on an S3 bucket

apointa
apointa Registered Posts: 1

I train a PyTorch model via a Python recipe. I have a managed folder configured as both input and output.
The managed folder is not on the local file system but on S3 storage.
Unfortunately, the torch.save and torch.load functions do not work with the folder's existing API methods (get_writer() and get_download_stream(), respectively).

Is there an option to make this work?

Answers

  • Zach
    Zach Dataiker, Dataiku DSS Core Designer, Dataiku DSS Adv Designer, Registered Posts: 153 Dataiker
    edited July 2024

    Hi @apointa,

    You can save/load a PyTorch model to/from a managed folder by first copying it to a temporary file object that supports the operations needed by torch.save and torch.load.

    There are two ways you can do it depending on the size of the file.

    The first way is to create a file object that's loaded in memory. This is best suited for smaller models that can fit in memory:

    import io
    import shutil
    
    import dataiku
    import torch
    
    folder = dataiku.Folder("MY_FOLDER")
    
    # Download the model into memory, then load it
    with io.BytesIO() as stream:
        with folder.get_download_stream("my-model.pt") as folder_stream:
            shutil.copyfileobj(folder_stream, stream)
        stream.seek(0)
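        # Note: since PyTorch 2.6, torch.load defaults to weights_only=True, so loading
        # a full pickled model may need weights_only=False (only for files you trust).
        my_model = torch.load(stream)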
    
    # Save the model into memory, then upload it to the managed folder
    with io.BytesIO() as stream:
        torch.save(my_model, stream)
        stream.seek(0)
        folder.upload_stream("my-model.pt", stream)

    The second way is to download the model to a temporary file on the local filesystem. This is best suited for models whose serialized file is too large to buffer in memory:

    import shutil
    import tempfile
    
    import dataiku
    import torch
    
    folder = dataiku.Folder("MY_FOLDER")
    
    # Download the model to a temporary file, then load it
    with tempfile.TemporaryFile() as temp_file:
        with folder.get_download_stream("my-model.pt") as folder_stream:
            shutil.copyfileobj(folder_stream, temp_file)
        temp_file.seek(0)
        my_model = torch.load(temp_file)
    
    # Save the model to a temporary file, then upload it to the managed folder
    with tempfile.TemporaryFile() as temp_file:
        torch.save(my_model, temp_file)
        temp_file.seek(0)
        folder.upload_stream("my-model.pt", temp_file)
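
    Both snippets follow the same pattern: copy the data through a seekable buffer, then hand that buffer to the library. As a rough sketch, the pattern could be wrapped in two helpers. The names load_from_folder, save_to_folder, and the in_memory flag are hypothetical (they are not part of the Dataiku API); the only folder methods assumed are get_download_stream() and upload_stream(), as used above.

```python
import io
import shutil
import tempfile


def load_from_folder(folder, path, load_fn, in_memory=True):
    """Copy `path` from the folder into a seekable buffer, then call load_fn(buffer).

    Hypothetical helper: `folder` only needs a get_download_stream(path) method.
    """
    # BytesIO keeps the bytes in RAM; TemporaryFile spills them to local disk.
    buffer_factory = io.BytesIO if in_memory else tempfile.TemporaryFile
    with buffer_factory() as buffer:
        with folder.get_download_stream(path) as folder_stream:
            shutil.copyfileobj(folder_stream, buffer)
        buffer.seek(0)  # rewind so load_fn reads from the beginning
        return load_fn(buffer)


def save_to_folder(folder, path, obj, save_fn, in_memory=True):
    """Serialize obj with save_fn(obj, buffer), then upload the buffer to `path`.

    Hypothetical helper: `folder` only needs an upload_stream(path, file) method.
    """
    buffer_factory = io.BytesIO if in_memory else tempfile.TemporaryFile
    with buffer_factory() as buffer:
        save_fn(obj, buffer)
        buffer.seek(0)  # rewind so the upload starts at the first byte
        folder.upload_stream(path, buffer)
```

    With PyTorch this would be called as save_to_folder(folder, "my-model.pt", my_model, torch.save) and load_from_folder(folder, "my-model.pt", torch.load). Any save/load pair that accepts a file object, such as pickle.dump and pickle.load, should fit the same two calls.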

    Thanks,

    Zach

  • pbena64
    pbena64 Registered Posts: 11

    Hi @ZachM,

    How would this work for other (i.e., non-PyTorch) models? Thanks!

    Regards,
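
    For libraries whose save and load functions accept a file object, the buffer approach from the answer above should carry over unchanged. Some libraries only accept a file path, though. For those, one possible sketch is to go through a temporary directory on the local filesystem. The helper names below are hypothetical, and the code assumes only the folder methods already used in this thread (get_download_stream() and upload_stream()). It also assumes load_fn reads the whole file before returning, since the temporary directory is deleted afterwards.

```python
import os
import shutil
import tempfile


def save_via_local_path(folder, remote_path, obj, save_fn):
    """Call save_fn(obj, local_path), then upload the resulting file.

    Hypothetical helper for libraries that can only write to a file path.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        # Keep the original file name so extension-based format detection still works.
        local_path = os.path.join(tmp_dir, os.path.basename(remote_path))
        save_fn(obj, local_path)
        with open(local_path, "rb") as local_file:
            folder.upload_stream(remote_path, local_file)


def load_via_local_path(folder, remote_path, load_fn):
    """Download remote_path to a local temporary file, then return load_fn(local_path).

    Hypothetical helper for libraries that can only read from a file path.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        local_path = os.path.join(tmp_dir, os.path.basename(remote_path))
        with folder.get_download_stream(remote_path) as folder_stream:
            with open(local_path, "wb") as local_file:
                shutil.copyfileobj(folder_stream, local_file)
        # The local file is closed here, so load_fn can open it by path.
        return load_fn(local_path)
```

    Here save_fn and load_fn stand for whatever path-based functions the model library provides, wrapped if needed so they take (obj, path) and (path) respectively.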
