Save and load a PyTorch model to a managed folder on an S3 bucket
I train a PyTorch model in a Python recipe that has a managed folder configured as both input and output.
The managed folder is not on the local filesystem but on S3 storage.
Unfortunately, torch.save and torch.load do not work with the folder's existing API methods (get_writer() and get_download_stream(), respectively).
Is there a way to make this work?
Answers
-
Hi @apointa,
You can save/load a PyTorch model to/from a managed folder by first copying it to a temporary file object that supports the operations torch.save and torch.load need (notably random access via seek(), which the folder's streams typically do not provide).
There are two ways to do it, depending on the size of the file.
The first way is to stage the model in an in-memory file object. This is best suited for smaller models that fit in memory:
import io
import shutil

import dataiku
import torch

folder = dataiku.Folder("MY_FOLDER")

# Download the model into memory, then load it
with io.BytesIO() as stream:
    with folder.get_download_stream("my-model.pt") as folder_stream:
        shutil.copyfileobj(folder_stream, stream)
    stream.seek(0)
    my_model = torch.load(stream)

# Save the model into memory, then upload it to the managed folder
with io.BytesIO() as stream:
    torch.save(my_model, stream)
    stream.seek(0)
    folder.upload_stream("my-model.pt", stream)
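A related detail for the load step: if the model was trained on a GPU but the recipe loading it runs on a CPU-only machine, torch.load needs a map_location argument to remap the tensors. A minimal variation of the in-memory load above, assuming the same folder and file name:

import io
import shutil

import dataiku
import torch

folder = dataiku.Folder("MY_FOLDER")

# Remap any GPU tensors onto the CPU while loading
with io.BytesIO() as stream:
    with folder.get_download_stream("my-model.pt") as folder_stream:
        shutil.copyfileobj(folder_stream, stream)
    stream.seek(0)
    my_model = torch.load(stream, map_location=torch.device("cpu"))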
The second way is to download the model to a temporary file on the local filesystem. This is best suited for models that are too large to fit in memory:
import shutil
import tempfile

import dataiku
import torch

folder = dataiku.Folder("MY_FOLDER")

# Download the model to a temporary file, then load it
with tempfile.TemporaryFile() as temp_file:
    with folder.get_download_stream("my-model.pt") as folder_stream:
        shutil.copyfileobj(folder_stream, temp_file)
    temp_file.seek(0)
    my_model = torch.load(temp_file)

# Save the model to a temporary file, then upload it to the managed folder
with tempfile.TemporaryFile() as temp_file:
    torch.save(my_model, temp_file)
    temp_file.seek(0)
    folder.upload_stream("my-model.pt", temp_file)
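As a closing note: if you follow PyTorch's recommended practice of persisting only the model's state_dict rather than the whole pickled model object, either pattern above applies unchanged. Here is a minimal in-memory sketch, assuming a hypothetical MyModel class and the file name my-model-weights.pt (neither appears in the original answer):

import io
import shutil

import dataiku
import torch
import torch.nn as nn

class MyModel(nn.Module):
    # Hypothetical architecture, standing in for your own model class
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 1)

    def forward(self, x):
        return self.linear(x)

folder = dataiku.Folder("MY_FOLDER")

# Save only the learnable parameters into memory, then upload them
model = MyModel()
with io.BytesIO() as stream:
    torch.save(model.state_dict(), stream)
    stream.seek(0)
    folder.upload_stream("my-model-weights.pt", stream)

# Load the parameters back into a freshly constructed model
restored = MyModel()
with io.BytesIO() as stream:
    with folder.get_download_stream("my-model-weights.pt") as folder_stream:
        shutil.copyfileobj(folder_stream, stream)
    stream.seek(0)
    restored.load_state_dict(torch.load(stream))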
Thanks,
Zach
-