Skip to content

Commit 69311ca

Browse files
committed
Updated model list endpoint
1 parent 46d49f6 commit 69311ca

File tree

2 files changed

+18
-37
lines changed

2 files changed

+18
-37
lines changed

src/together/api.py

+9-14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Any, Dict, List, Optional, cast
33

44
import requests
5+
import urllib.parse
56

67
from together.files import Files
78
from together.finetune import Finetune
@@ -39,21 +40,22 @@ def __init__(
3940
self.endpoint_url = endpoint_url
4041
self.supply_endpoint_url = supply_endpoint_url
4142

42-
def get_supply(self) -> Dict[str, Any]:
43+
def get_all_models(self) -> Dict[str, Any]:
44+
model_url = urllib.parse.urljoin(self.endpoint_url, "models/info?=")
45+
headers = {
46+
"Authorization": f"Bearer {self.together_api_key}",
47+
}
4348
try:
4449
response = requests.get(
45-
self.supply_endpoint_url,
46-
json={
47-
"method": "together_getDepth",
48-
"id": 1,
49-
},
50+
model_url,
51+
headers=headers,
5052
)
5153
except requests.exceptions.RequestException as e:
5254
self.logger.critical(f"Response error raised: {e}")
5355
exit_1(self.logger)
5456

5557
try:
56-
response_json = dict(response.json())
58+
response_json = list(response.json())
5759
except Exception as e:
5860
self.logger.critical(
5961
f"JSON Error raised: {e}\nResponse status code = {response.status_code}"
@@ -62,13 +64,6 @@ def get_supply(self) -> Dict[str, Any]:
6264

6365
return response_json
6466

65-
def get_all_models(self) -> List[str]:
66-
models = cast(List[str], self.get_supply()["result"].keys())
67-
68-
models = [str(sub[:-1]) for sub in models] # remove the ? after the model names
69-
70-
return models
71-
7267
def get_available_models(self) -> List[str]:
7368
res = self.get_supply()
7469
names = res["result"].keys()

src/together/commands/api.py

+9-23
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ def add_parser(
1717
child_parsers = parser.add_subparsers(required=True)
1818

1919
_add_list(child_parsers, parents=parents)
20-
_add_raw(child_parsers, parents=parents)
2120

2221

2322
def _add_list(
@@ -26,35 +25,22 @@ def _add_list(
2625
) -> None:
2726
list_model_subparser = parser.add_parser("list", parents=parents)
2827
list_model_subparser.add_argument(
29-
"--all",
30-
"-a",
31-
help="List all models (available and unavailable)",
28+
"--raw",
29+
help="Raw details of all models",
3230
default=False,
3331
action="store_true",
3432
)
3533
list_model_subparser.set_defaults(func=_run_list)
3634

3735

38-
def _add_raw(
39-
parser: argparse._SubParsersAction[argparse.ArgumentParser],
40-
parents: List[argparse.ArgumentParser],
41-
) -> None:
42-
list_parser = parser.add_parser("raw-supply", parents=parents)
43-
list_parser.set_defaults(func=_run_raw)
44-
45-
4636
def _run_list(args: argparse.Namespace) -> None:
4737
api = API(endpoint_url=args.endpoint, log_level=args.log)
48-
49-
if args.all:
50-
response = api.get_all_models()
38+
response = api.get_all_models()
39+
if args.raw:
40+
print(json.dumps(response, indent=4))
5141
else:
52-
response = api.get_available_models()
53-
54-
print(json.dumps(response, indent=4))
42+
models = []
43+
for i in response:
44+
models.append(i["name"])
45+
print(json.dumps(models, indent=4))
5546

56-
57-
def _run_raw(args: argparse.Namespace) -> None:
58-
api = API(endpoint_url=args.endpoint, log_level=args.log)
59-
response = api.get_supply()
60-
print(json.dumps(response, indent=4))

0 commit comments

Comments
 (0)