diff --git a/docs/api/trainingclient.md b/docs/api/trainingclient.md index b2c00d0..a4651da 100644 --- a/docs/api/trainingclient.md +++ b/docs/api/trainingclient.md @@ -57,7 +57,7 @@ data = [types.Datum( )] future = training_client.forward(data, "cross_entropy") result = await future -print(f"Loss: {result.loss}") +print(f"Loss: {result.metrics['loss:sum']}") ``` #### `forward_async` @@ -108,7 +108,7 @@ optim_future = training_client.optim_step( ) fwdbwd_result = await fwdbwd_future -print(f"Loss: {fwdbwd_result.loss}") +print(f"Loss: {fwdbwd_result.metrics['loss:sum']}") ``` #### `forward_backward_async` @@ -157,7 +157,6 @@ def custom_loss(data, logprobs_list): future = training_client.forward_backward_custom(data, custom_loss) result = future.result() -print(f"Custom loss: {result.loss}") print(f"Metrics: {result.metrics}") ``` diff --git a/src/tinker/lib/public_interfaces/training_client.py b/src/tinker/lib/public_interfaces/training_client.py index cc812f8..90d55ce 100644 --- a/src/tinker/lib/public_interfaces/training_client.py +++ b/src/tinker/lib/public_interfaces/training_client.py @@ -206,7 +206,7 @@ def forward( )] future = training_client.forward(data, "cross_entropy") result = await future - print(f"Loss: {result.loss}") + print(f"Loss: {result.metrics['loss:sum']}") ``` """ requests = self._chunked_requests(data) @@ -294,7 +294,7 @@ def forward_backward( ) fwdbwd_result = await fwdbwd_future - print(f"Loss: {fwdbwd_result.loss}") + print(f"Loss: {fwdbwd_result.metrics['loss:sum']}") ``` """ requests = self._chunked_requests(data) @@ -368,7 +368,6 @@ def custom_loss(data, logprobs_list): future = training_client.forward_backward_custom(data, custom_loss) result = future.result() - print(f"Custom loss: {result.loss}") print(f"Metrics: {result.metrics}") ``` """