From 56115122a4ce7e96e24734335363bd58c44efe81 Mon Sep 17 00:00:00 2001 From: Nadia Yakimakha <32335935+nadiaya@users.noreply.github.com> Date: Thu, 8 Nov 2018 11:40:31 -0800 Subject: [PATCH] Support optional input channels in local mode. --- CHANGELOG.rst | 5 +++++ src/sagemaker/local/image.py | 1 - src/sagemaker/local/local_session.py | 6 +++--- tests/unit/test_local_session.py | 12 ++++++------ 4 files changed, 14 insertions(+), 10 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 3d6093f49e..14fe52a20a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -2,6 +2,11 @@ CHANGELOG ========= +1.14.1.dev +========== + +* enhancement: Local Mode: support optional input channels + 1.14.0 ====== diff --git a/src/sagemaker/local/image.py b/src/sagemaker/local/image.py index e950893e7b..eafc6f1262 100644 --- a/src/sagemaker/local/image.py +++ b/src/sagemaker/local/image.py @@ -262,7 +262,6 @@ def write_config_files(self, host, hyperparameters, input_data_config): 'hosts': self.hosts } - print(input_data_config) json_input_data_config = {} for c in input_data_config: channel_name = c['ChannelName'] diff --git a/src/sagemaker/local/local_session.py b/src/sagemaker/local/local_session.py index 134d90ffe3..d935c7a97d 100644 --- a/src/sagemaker/local/local_session.py +++ b/src/sagemaker/local/local_session.py @@ -53,8 +53,8 @@ def __init__(self, sagemaker_session=None): """ self.sagemaker_session = sagemaker_session or LocalSession() - def create_training_job(self, TrainingJobName, AlgorithmSpecification, InputDataConfig, OutputDataConfig, - ResourceConfig, **kwargs): + def create_training_job(self, TrainingJobName, AlgorithmSpecification, OutputDataConfig, + ResourceConfig, InputDataConfig=None, **kwargs): """ Create a training job in Local Mode Args: @@ -66,7 +66,7 @@ def create_training_job(self, TrainingJobName, AlgorithmSpecification, InputData HyperParameters (dict) [optional]: Specifies these algorithm-specific parameters to influence the quality of the final model. """ - + InputDataConfig = InputDataConfig or {} container = _SageMakerContainer(ResourceConfig['InstanceType'], ResourceConfig['InstanceCount'], AlgorithmSpecification['TrainingImage'], self.sagemaker_session) training_job = _LocalTrainingJob(container) diff --git a/tests/unit/test_local_session.py b/tests/unit/test_local_session.py index 78d9b49e77..882f86124d 100644 --- a/tests/unit/test_local_session.py +++ b/tests/unit/test_local_session.py @@ -61,8 +61,8 @@ def test_create_training_job(train, LocalSession): resource_config = {'InstanceType': 'local', 'InstanceCount': instance_count} hyperparameters = {'a': 1, 'b': 'bee'} - local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config, - output_data_config, resource_config, HyperParameters=hyperparameters) + local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config, + InputDataConfig=input_data_config, HyperParameters=hyperparameters) expected = { 'ResourceConfig': {'InstanceCount': instance_count}, @@ -111,8 +111,8 @@ def test_create_training_job_invalid_data_source(train, LocalSession): hyperparameters = {'a': 1, 'b': 'bee'} with pytest.raises(ValueError): - local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config, - output_data_config, resource_config, HyperParameters=hyperparameters) + local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config, + InputDataConfig=input_data_config, HyperParameters=hyperparameters) @patch('sagemaker.local.image._SageMakerContainer.train', return_value="/some/path/to/model") @@ -141,8 +141,8 @@ def test_create_training_job_not_fully_replicated(train, LocalSession): hyperparameters = {'a': 1, 'b': 'bee'} with pytest.raises(RuntimeError): - local_sagemaker_client.create_training_job('my-training-job', algo_spec, input_data_config, - output_data_config, resource_config, HyperParameters=hyperparameters) + local_sagemaker_client.create_training_job('my-training-job', algo_spec, output_data_config, resource_config, + InputDataConfig=input_data_config, HyperParameters=hyperparameters) @patch('sagemaker.local.local_session.LocalSession')