Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 38 additions & 4 deletions src/together/cli/api/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
DownloadCheckpointType,
FinetuneEventType,
FinetuneTrainingLimits,
FullTrainingType,
LoRATrainingType,
)
from together.utils import (
finetune_price_to_dollars,
Expand All @@ -29,13 +31,21 @@

# Confirmation prompt shown before a fine-tuning job is submitted from the CLI.
# Format placeholders:
#   {price}   - the already-formatted price estimate for the job
#   {warning} - an empty string, or a newline-terminated warning that is
#               inserted right before the `-y` / `--confirm` hint
_CONFIRMATION_MESSAGE = (
    "You are about to create a fine-tuning job. "
    "The estimated price of this job is {price}. "
    "The actual cost of your job will be determined by the model size, the number of tokens "
    "in the training file, the number of tokens in the validation file, the number of epochs, and "
    "the number of evaluations. Visit https://www.together.ai/pricing to learn more about fine-tuning pricing.\n"
    "{warning}"
    "You can pass `-y` or `--confirm` to your command to skip this message.\n\n"
    "Do you want to proceed?"
)

# Newline-terminated warning about insufficient funds. It is built from its
# three sentences, joined by single spaces.
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
    " ".join(
        [
            "The estimated price of this job is significantly greater than "
            "your current credit limit and balance combined.",
            "It will likely get cancelled due to insufficient funds.",
            "Consider increasing your credit limit at "
            "https://api.together.xyz/settings/profile",
        ]
    )
    + "\n"
)


class DownloadCheckpointTypeChoice(click.Choice):
def __init__(self) -> None:
Expand Down Expand Up @@ -357,12 +367,36 @@ def create(
"You have specified a number of evaluation loops but no validation file."
)

if confirm or click.confirm(_CONFIRMATION_MESSAGE, default=True, show_default=True):
finetune_price_estimation_result = client.fine_tuning.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model,
n_epochs=n_epochs,
n_evals=n_evals,
training_type="lora" if lora else "full",
training_method=training_method,
)

price = click.style(
f"${finetune_price_estimation_result.estimated_total_price:.2f}",
bold=True,
)

if not finetune_price_estimation_result.allowed_to_proceed:
warning = click.style(_WARNING_MESSAGE_INSUFFICIENT_FUNDS, fg="red", bold=True)
else:
warning = ""

confirmation_message = _CONFIRMATION_MESSAGE.format(
price=price,
warning=warning,
)

if confirm or click.confirm(confirmation_message, default=True, show_default=True):
response = client.fine_tuning.create(
**training_args,
verbose=True,
)

report_string = f"Successfully submitted a fine-tuning job {response.id}"
if response.created_at is not None:
created_time = datetime.strptime(
Expand Down
205 changes: 204 additions & 1 deletion src/together/resources/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
FinetuneLRScheduler,
FinetuneRequest,
FinetuneResponse,
FinetunePriceEstimationRequest,
FinetunePriceEstimationResponse,
FinetuneTrainingLimits,
FullTrainingType,
LinearLRScheduler,
Expand All @@ -31,7 +33,7 @@
TrainingMethodSFT,
TrainingType,
)
from together.types.finetune import DownloadCheckpointType
from together.types.finetune import DownloadCheckpointType, TrainingMethod
from together.utils import log_warn_once, normalize_key


Expand All @@ -42,6 +44,12 @@
TrainingMethodSFT().method,
TrainingMethodDPO().method,
}
# Warning about a job whose estimated price exceeds the available funds.
# Takes one positional format argument: the estimated price as a number, which
# is rendered as a dollar amount with two decimals (e.g. 12.3456 -> "$12.35")
# instead of the raw float representation.
_WARNING_MESSAGE_INSUFFICIENT_FUNDS = (
    "The estimated price of the fine-tuning job is ${:.2f} which is significantly "
    "greater than your current credit limit and balance combined. "
    "It will likely get cancelled due to insufficient funds. "
    "Proceed at your own risk."
)


def create_finetune_request(
Expand Down Expand Up @@ -473,12 +481,34 @@ def create(
hf_api_token=hf_api_token,
hf_output_repo_name=hf_output_repo_name,
)
if from_checkpoint is None and from_hf_model is None:
price_estimation_result = self.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model_name,
n_epochs=finetune_request.n_epochs,
n_evals=finetune_request.n_evals,
training_type="lora" if lora else "full",
training_method=training_method,
)
price_limit_passed = price_estimation_result.allowed_to_proceed
else:
# unsupported case
price_limit_passed = True

if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
finetune_request,
)
if not price_limit_passed:
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price
)
+ "[/red]",
)
parameter_payload = finetune_request.model_dump(exclude_none=True)

response, _, _ = requestor.request(
Expand All @@ -493,6 +523,81 @@ def create(

return FinetuneResponse(**response.data)

def estimate_price(
    self,
    *,
    training_file: str,
    model: str,
    validation_file: str | None = None,
    n_epochs: int | None = 1,
    n_evals: int | None = 0,
    training_type: str = "lora",
    training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
    """
    Estimates the price of a fine-tuning job

    Args:
        training_file (str): File-ID of a file uploaded to the Together API
        model (str): Name of the base model to run fine-tune job on
        validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
        n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
            ``None`` is treated as the default.
        n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
            ``None`` is treated as the default.
        training_type (str, optional): Training type, "lora" or "full" (case-insensitive).
            Defaults to "lora".
        training_method (str, optional): Training method, "sft" or "dpo" (case-insensitive).
            Defaults to "sft".

    Returns:
        FinetunePriceEstimationResponse: Object containing the price estimation result.

    Raises:
        ValueError: If ``training_method`` or ``training_type`` is not recognized.
    """
    # Validate both selectors up front so bad input fails before any request
    # object is constructed.
    method_name = training_method.lower()
    if method_name not in ("sft", "dpo"):
        raise ValueError(f"Unknown training method: {training_method}")

    type_name = training_type.lower()
    if type_name not in ("lora", "full"):
        raise ValueError(f"Unknown training type: {training_type}")

    training_method_cls: TrainingMethod
    if method_name == "sft":
        training_method_cls = TrainingMethodSFT(method="sft")
    else:
        training_method_cls = TrainingMethodDPO(method="dpo")

    training_type_cls: TrainingType
    if type_name == "lora":
        # parameters of lora are unused in price estimation
        # but we need to set them to valid values
        training_type_cls = LoRATrainingType(
            type="Lora",
            lora_r=16,
            lora_alpha=16,
            lora_dropout=0.0,
            lora_trainable_modules="all-linear",
        )
    else:
        training_type_cls = FullTrainingType(type="Full")

    request = FinetunePriceEstimationRequest(
        training_file=training_file,
        validation_file=validation_file,
        model=model,
        # The request model declares plain ints for these two fields, so map
        # an explicit None onto the documented defaults.
        n_epochs=1 if n_epochs is None else n_epochs,
        n_evals=0 if n_evals is None else n_evals,
        training_type=training_type_cls,
        training_method=training_method_cls,
    )
    # exclude_none drops validation_file from the payload when it is not set
    parameter_payload = request.model_dump(exclude_none=True)
    requestor = api_requestor.APIRequestor(
        client=self._client,
    )

    response, _, _ = requestor.request(
        options=TogetherRequest(
            method="POST", url="fine-tunes/estimate-price", params=parameter_payload
        ),
        stream=False,
    )
    assert isinstance(response, TogetherResponse)

    return FinetunePriceEstimationResponse(**response.data)

def list(self) -> FinetuneList:
"""
Lists fine-tune job history
Expand Down Expand Up @@ -941,11 +1046,34 @@ async def create(
hf_output_repo_name=hf_output_repo_name,
)

if from_checkpoint is None and from_hf_model is None:
price_estimation_result = await self.estimate_price(
training_file=training_file,
validation_file=validation_file,
model=model_name,
n_epochs=finetune_request.n_epochs,
n_evals=finetune_request.n_evals,
training_type="lora" if lora else "full",
training_method=training_method,
)
price_limit_passed = price_estimation_result.allowed_to_proceed
else:
# unsupported case
price_limit_passed = True

if verbose:
rprint(
"Submitting a fine-tuning job with the following parameters:",
finetune_request,
)
if not price_limit_passed:
rprint(
"[red]"
+ _WARNING_MESSAGE_INSUFFICIENT_FUNDS.format(
price_estimation_result.estimated_total_price
)
+ "[/red]",
)
parameter_payload = finetune_request.model_dump(exclude_none=True)

response, _, _ = await requestor.arequest(
Expand All @@ -961,6 +1089,81 @@ async def create(

return FinetuneResponse(**response.data)

async def estimate_price(
    self,
    *,
    training_file: str,
    model: str,
    validation_file: str | None = None,
    n_epochs: int | None = 1,
    n_evals: int | None = 0,
    training_type: str = "lora",
    training_method: str = "sft",
) -> FinetunePriceEstimationResponse:
    """
    Async method to estimate the price of a fine-tuning job

    Args:
        training_file (str): File-ID of a file uploaded to the Together API
        model (str): Name of the base model to run fine-tune job on
        validation_file (str, optional): File ID of a file uploaded to the Together API for validation.
        n_epochs (int, optional): Number of epochs for fine-tuning. Defaults to 1.
            ``None`` is treated as the default.
        n_evals (int, optional): Number of evaluation loops to run. Defaults to 0.
            ``None`` is treated as the default.
        training_type (str, optional): Training type, "lora" or "full" (case-insensitive).
            Defaults to "lora".
        training_method (str, optional): Training method, "sft" or "dpo" (case-insensitive).
            Defaults to "sft".

    Returns:
        FinetunePriceEstimationResponse: Object containing the price estimation result.

    Raises:
        ValueError: If ``training_method`` or ``training_type`` is not recognized.
    """
    # Validate both selectors up front so bad input fails before any request
    # object is constructed.
    method_name = training_method.lower()
    if method_name not in ("sft", "dpo"):
        raise ValueError(f"Unknown training method: {training_method}")

    type_name = training_type.lower()
    if type_name not in ("lora", "full"):
        raise ValueError(f"Unknown training type: {training_type}")

    training_method_cls: TrainingMethod
    if method_name == "sft":
        training_method_cls = TrainingMethodSFT(method="sft")
    else:
        training_method_cls = TrainingMethodDPO(method="dpo")

    training_type_cls: TrainingType
    if type_name == "lora":
        # parameters of lora are unused in price estimation
        # but we need to set them to valid values
        training_type_cls = LoRATrainingType(
            type="Lora",
            lora_r=16,
            lora_alpha=16,
            lora_dropout=0.0,
            lora_trainable_modules="all-linear",
        )
    else:
        training_type_cls = FullTrainingType(type="Full")

    request = FinetunePriceEstimationRequest(
        training_file=training_file,
        validation_file=validation_file,
        model=model,
        # The request model declares plain ints for these two fields, so map
        # an explicit None onto the documented defaults.
        n_epochs=1 if n_epochs is None else n_epochs,
        n_evals=0 if n_evals is None else n_evals,
        training_type=training_type_cls,
        training_method=training_method_cls,
    )
    # exclude_none drops validation_file from the payload when it is not set
    parameter_payload = request.model_dump(exclude_none=True)
    requestor = api_requestor.APIRequestor(
        client=self._client,
    )

    response, _, _ = await requestor.arequest(
        options=TogetherRequest(
            method="POST", url="fine-tunes/estimate-price", params=parameter_payload
        ),
        stream=False,
    )
    assert isinstance(response, TogetherResponse)

    return FinetunePriceEstimationResponse(**response.data)

async def list(self) -> FinetuneList:
"""
Async method to list fine-tune job history
Expand Down
4 changes: 4 additions & 0 deletions src/together/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@
FinetuneListEvents,
FinetuneRequest,
FinetuneResponse,
FinetunePriceEstimationRequest,
FinetunePriceEstimationResponse,
FinetuneDeleteResponse,
FinetuneTrainingLimits,
FullTrainingType,
Expand Down Expand Up @@ -103,6 +105,8 @@
"FinetuneDeleteResponse",
"FinetuneDownloadResult",
"FinetuneLRScheduler",
"FinetunePriceEstimationRequest",
"FinetunePriceEstimationResponse",
"LinearLRScheduler",
"LinearLRSchedulerArgs",
"CosineLRScheduler",
Expand Down
26 changes: 26 additions & 0 deletions src/together/types/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,32 @@ def validate_training_type(cls, v: TrainingType) -> TrainingType:
raise ValueError("Unknown training type")


class FinetunePriceEstimationRequest(BaseModel):
    """
    Fine-tune price estimation request type

    Describes the fine-tuning job whose price should be estimated.
    """

    # File-ID of the training file uploaded to the Together API
    training_file: str
    # File-ID of the validation file uploaded to the Together API, if any
    validation_file: str | None = None
    # Name of the base model to run the fine-tune job on
    model: str
    # Number of epochs for fine-tuning (required; no default at this level)
    n_epochs: int
    # Number of evaluation loops to run (required; no default at this level)
    n_evals: int
    # Training type configuration (a TrainingType instance)
    training_type: TrainingType
    # Training method configuration (a TrainingMethod instance)
    training_method: TrainingMethod


class FinetunePriceEstimationResponse(BaseModel):
    """
    Fine-tune price estimation response type
    """

    # Estimated total price of the job
    # NOTE(review): presumably denominated in US dollars — confirm against the API
    estimated_total_price: float
    # NOTE(review): presumably the spending limit the estimate is compared
    # against (credit limit plus balance) — confirm against the API
    user_limit: float
    # Estimated token counts for the job
    # NOTE(review): presumably training vs. validation tokens — confirm
    estimated_train_token_count: int
    estimated_eval_token_count: int
    # Whether the job is allowed to proceed
    # NOTE(review): presumably False when the estimate exceeds available funds — confirm
    allowed_to_proceed: bool


class FinetuneList(BaseModel):
# object type
object: Literal["list"] | None = None
Expand Down
Loading