From f817923a91ae97028fe71b31cc7b5e97bbea2426 Mon Sep 17 00:00:00 2001 From: Ishaaq Chandy Date: Wed, 5 Dec 2018 21:25:59 -0800 Subject: [PATCH 1/2] Add AugmentedManifestFile & ShuffleConfig support --- src/sagemaker/session.py | 36 ++++++++++++++++++++++++++++++------ tests/unit/test_estimator.py | 26 +++++++++++++++++++++++++- 2 files changed, 55 insertions(+), 7 deletions(-) diff --git a/src/sagemaker/session.py b/src/sagemaker/session.py index f0fb48ce8d..badd69d6a2 100644 --- a/src/sagemaker/session.py +++ b/src/sagemaker/session.py @@ -1222,7 +1222,7 @@ class s3_input(object): def __init__(self, s3_data, distribution='FullyReplicated', compression=None, content_type=None, record_wrapping=None, s3_data_type='S3Prefix', - input_mode=None): + input_mode=None, attribute_names=None, shuffle_config=None): """Create a definition for input data used by an SageMaker training job. See AWS documentation on the ``CreateTrainingJob`` API for more details on the parameters. @@ -1234,17 +1234,23 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, compression (str): Valid values: 'Gzip', None (default: None). This is used only in Pipe input mode. content_type (str): MIME type of the input data (default: None). record_wrapping (str): Valid values: 'RecordIO' (default: None). - s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile'. If 'S3Prefix', ``s3_data`` defines - a prefix of s3 objects to train on. All objects with s3 keys beginning with ``s3_data`` will - be used to train. If 'ManifestFile', then ``s3_data`` defines a single s3 manifest file, listing - each s3 object to train on. The Manifest file format is described in the SageMaker API documentation: - https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html + s3_data_type (str): Valid values: 'S3Prefix', 'ManifestFile', 'AugmentedManifestFile'. If 'S3Prefix', + ``s3_data`` defines a prefix of s3 objects to train on. All objects with s3 keys beginning with + ``s3_data`` will be used to train. If 'ManifestFile' or 'AugmentedManifestFile', then ``s3_data`` + defines a single s3 manifest file or augmented manifest file (respectively), listing the s3 data to + train on. Both the ManifestFile and AugmentedManifestFile formats are described in the SageMaker API + documentation: https://docs.aws.amazon.com/sagemaker/latest/dg/API_S3DataSource.html input_mode (str): Optional override for this channel's input mode (default: None). By default, channels will use the input mode defined on ``sagemaker.estimator.EstimatorBase.input_mode``, but they will ignore that setting if this parameter is set. * None - Amazon SageMaker will use the input mode specified in the ``Estimator``. * 'File' - Amazon SageMaker copies the training dataset from the S3 location to a local directory. * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a Unix-named pipe. + attribute_names (list[str]): A list of one or more attribute names to use that are found in a specified + AugmentedManifestFile. + shuffle_config (ShuffleConfig): If specified this configuration enables shuffling on this channel. See the + SageMaker API documentation for more info: + https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html """ self.config = { 'DataSource': { @@ -1264,6 +1270,24 @@ def __init__(self, s3_data, distribution='FullyReplicated', compression=None, self.config['RecordWrapperType'] = record_wrapping if input_mode is not None: self.config['InputMode'] = input_mode + if attribute_names is not None: + self.config['DataSource']['S3DataSource']['AttributeNames'] = attribute_names + if shuffle_config is not None: + self.config['ShuffleConfig'] = {'Seed': shuffle_config.seed} + + +class ShuffleConfig(object): + """ + Used to configure channel shuffling using a seed. See SageMaker + documentation for more detail: https://docs.aws.amazon.com/sagemaker/latest/dg/API_ShuffleConfig.html + """ + def __init__(self, seed): + """ + Create a ShuffleConfig. + Args: + seed (long): the long value used to seed the shuffled sequence. + """ + self.seed = seed class ModelContainer(object): diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index 1916f9ca7c..cf726f03e8 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -24,7 +24,7 @@ from sagemaker.estimator import Estimator, EstimatorBase, Framework, _TrainingJob from sagemaker.model import FrameworkModel from sagemaker.predictor import RealTimePredictor -from sagemaker.session import s3_input +from sagemaker.session import s3_input, ShuffleConfig from sagemaker.transformer import Transformer MODEL_DATA = "s3://bucket/model.tar.gz" @@ -277,6 +277,30 @@ def test_invalid_custom_code_bucket(sagemaker_session): assert "Expecting 's3' scheme" in str(error) +def test_augmented_manifest(sagemaker_session): + fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True) + fw.fit(inputs=s3_input('s3://mybucket/train_manifest', s3_data_type='AugmentedManifestFile', + attribute_names=["foo", "bar"])) + + _, _, train_kwargs = sagemaker_session.train.mock_calls[0] + s3_data_source = train_kwargs['input_config'][0]['DataSource']['S3DataSource'] + assert s3_data_source['S3Uri'] == 's3://mybucket/train_manifest' + assert s3_data_source['S3DataType'] == 'AugmentedManifestFile' + assert s3_data_source['AttributeNames'] == ["foo", "bar"] + + +def test_shuffle_config(sagemaker_session): + fw = DummyFramework(entry_point=SCRIPT_PATH, role='DummyRole', sagemaker_session=sagemaker_session, + train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, + enable_cloudwatch_metrics=True) + fw.fit(inputs=s3_input('s3://mybucket/train_manifest', shuffle_config=ShuffleConfig(100))) + _, _, train_kwargs = sagemaker_session.train.mock_calls[0] + channel = train_kwargs['input_config'][0] + assert channel['ShuffleConfig']['Seed'] == 100 + + BASE_HP = { 'sagemaker_program': json.dumps(SCRIPT_NAME), 'sagemaker_submit_directory': json.dumps('s3://mybucket/{}/source/sourcedir.tar.gz'.format(JOB_NAME)), From 8adbbb7d5f3847e1f3ce7ac37def7d70e2654aea Mon Sep 17 00:00:00 2001 From: Ishaaq Chandy Date: Thu, 6 Dec 2018 13:54:59 -0800 Subject: [PATCH 2/2] Update changelog and fix minor nit in test --- CHANGELOG.rst | 4 ++++ tests/unit/test_estimator.py | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index d21d3f80ee..44f9809d07 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,10 @@ CHANGELOG ========= +1.16.2 +====== +* feature: Add support for AugmentedManifestFile and ShuffleConfig + 1.16.1.post1 ============ diff --git a/tests/unit/test_estimator.py b/tests/unit/test_estimator.py index cf726f03e8..c33cbc4ece 100644 --- a/tests/unit/test_estimator.py +++ b/tests/unit/test_estimator.py @@ -282,13 +282,13 @@ def test_augmented_manifest(sagemaker_session): train_instance_count=INSTANCE_COUNT, train_instance_type=INSTANCE_TYPE, enable_cloudwatch_metrics=True) fw.fit(inputs=s3_input('s3://mybucket/train_manifest', s3_data_type='AugmentedManifestFile', - attribute_names=["foo", "bar"])) + attribute_names=['foo', 'bar'])) _, _, train_kwargs = sagemaker_session.train.mock_calls[0] s3_data_source = train_kwargs['input_config'][0]['DataSource']['S3DataSource'] assert s3_data_source['S3Uri'] == 's3://mybucket/train_manifest' assert s3_data_source['S3DataType'] == 'AugmentedManifestFile' - assert s3_data_source['AttributeNames'] == ["foo", "bar"] + assert s3_data_source['AttributeNames'] == ['foo', 'bar'] def test_shuffle_config(sagemaker_session):