Save and load a PyTorch model to a managed folder on an S3 bucket

apointa
Level 1

I train a PyTorch model in a Python recipe, with a managed folder configured as both input and output.
The managed folder is not on the local file system but on S3 storage.
Unfortunately, torch.save and torch.load do not work with the folder's existing API methods (get_writer() and get_download_stream(), respectively).

Is there an option to make this work?

2 Replies
ZachM
Dataiker

Hi @apointa,

You can save/load a PyTorch model to/from a managed folder by first copying it to a temporary file object that supports the operations needed by torch.save and torch.load.

There are two ways you can do it depending on the size of the file.

The first way is to create a file object that's loaded in memory. This is best suited for smaller models that can fit in memory:

import io
import shutil

import dataiku
import torch

folder = dataiku.Folder("MY_FOLDER")

# Download the model into memory, then load it
with io.BytesIO() as stream:
    with folder.get_download_stream("my-model.pt") as folder_stream:
        shutil.copyfileobj(folder_stream, stream)
    stream.seek(0)
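    # Note: on PyTorch 2.6+, torch.load defaults to weights_only=True, so a fully
    # pickled model may need weights_only=False (only for files you trust)
    my_model = torch.load(stream)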

# Save the model into memory, then upload it to the managed folder
with io.BytesIO() as stream:
    torch.save(my_model, stream)
    stream.seek(0)
    folder.upload_stream("my-model.pt", stream)

 

The second way is to download the model to a temporary file on the local filesystem. This is best suited for models that are too large to fit in memory:

import shutil
import tempfile

import dataiku
import torch

folder = dataiku.Folder("MY_FOLDER")

# Download the model to a temporary file, then load it
with tempfile.TemporaryFile() as temp_file:
    with folder.get_download_stream("my-model.pt") as folder_stream:
        shutil.copyfileobj(folder_stream, temp_file)
    temp_file.seek(0)
    my_model = torch.load(temp_file)

# Save the model to a temporary file, then upload it to the managed folder
with tempfile.TemporaryFile() as temp_file:
    torch.save(my_model, temp_file)
    temp_file.seek(0)
    folder.upload_stream("my-model.pt", temp_file)
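
If you need this in several places, the temporary-file pattern above could be wrapped into two small helpers that take the serializer as a parameter. This is only a sketch: `save_to_folder`, `load_from_folder` and `FakeFolder` are hypothetical names, not part of the Dataiku API, and `FakeFolder` is an in-memory stand-in so the example runs outside DSS. It uses `pickle.dump` / `pickle.load` for the demonstration; `torch.save` / `torch.load` take their arguments in the same order (object, then file) and should be usable as `dump` / `load` in the same way.

```python
import io
import pickle
import shutil
import tempfile


def load_from_folder(folder, path, load):
    """Copy `path` from the folder into a seekable temp file, then deserialize it."""
    with tempfile.TemporaryFile() as temp_file:
        with folder.get_download_stream(path) as folder_stream:
            shutil.copyfileobj(folder_stream, temp_file)
        temp_file.seek(0)  # rewind so `load` reads from the start
        return load(temp_file)


def save_to_folder(folder, path, obj, dump):
    """Serialize `obj` into a temp file with `dump`, then upload it to the folder."""
    with tempfile.TemporaryFile() as temp_file:
        dump(obj, temp_file)
        temp_file.seek(0)  # rewind so the upload reads from the start
        folder.upload_stream(path, temp_file)


class FakeFolder:
    """Hypothetical in-memory stand-in for dataiku.Folder (assumption, for local testing only)."""

    def __init__(self):
        self.files = {}

    def get_download_stream(self, path):
        return io.BytesIO(self.files[path])

    def upload_stream(self, path, f):
        self.files[path] = f.read()


# Round trip with pickle; in a recipe, `folder` would be dataiku.Folder("MY_FOLDER")
folder = FakeFolder()
save_to_folder(folder, "my-model.pkl", {"weights": [1.0, 2.0]}, pickle.dump)
restored = load_from_folder(folder, "my-model.pkl", pickle.load)
```

Passing the serializer in keeps the folder handling separate from the model format, so the copy-to-a-seekable-file step only has to be written once.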

 

Thanks,

Zach

pbena64
Level 2

Hi @ZachM,

How would this work for other (i.e., non-PyTorch) models? Thanks!

Regards,