11import json
2+ import logging
23import os
34import posixpath
45import sys
56import urllib .parse
6- from typing import Any , Dict , List , Optional , Union
7+ from typing import Dict , List , Optional , Union
78
89import requests
910from tqdm import tqdm
1213DEFAULT_ENDPOINT = "https://api.together.xyz/"
1314
1415
15- def exception_handler (exception_type : Any , exception : Any , traceback : Any ) -> None :
16- # All your trace are belong to us!
17- # your format
18- # print("%s: %s" % (exception_type.__name__, exception))
19- pass
20-
21-
2216class JSONException (Exception ):
2317 pass
2418
@@ -39,6 +33,7 @@ class Files:
3933 def __init__ (
4034 self ,
4135 endpoint_url : Optional [str ] = None ,
36+ log_level : str = "WARNING" ,
4237 ) -> None :
4338 self .together_api_key = os .environ .get ("TOGETHER_API_KEY" , None )
4439 if self .together_api_key is None :
@@ -51,6 +46,17 @@ def __init__(
5146
5247 self .endpoint_url = urllib .parse .urljoin (endpoint_url , "/v1/files/" )
5348
49+ self .logger = logging .getLogger (__name__ )
50+
51+ # Setup logging
52+ logging .basicConfig (
53+ format = "%(asctime)s - %(levelname)s - %(name)s - %(message)s" ,
54+ datefmt = "%m/%d/%Y %H:%M:%S" ,
55+ handlers = [logging .StreamHandler (sys .stdout )],
56+ )
57+
58+ self .logger .setLevel (log_level )
59+
5460 def list_files (self ) -> Dict [str , List [Dict [str , Union [str , int ]]]]:
5561 headers = {
5662 "Authorization" : f"Bearer { self .together_api_key } " ,
@@ -72,40 +78,66 @@ def list_files(self) -> Dict[str, List[Dict[str, Union[str, int]]]]:
7278 return response_json
7379
7480 def upload_file (self , file : str ) -> Dict [str , Union [str , int ]]:
75- data = {"purpose" : "fine-tune" }
76- headers = {
77- "Authorization" : f"Bearer { self .together_api_key } " ,
78- }
81+ try :
82+ data = {"purpose" : "fine-tune" , "file_name" : os .path .basename (file )}
7983
80- if not validate_json (file = file ):
81- raise ValueError ("Could not load file: invalid .jsonl file detected." )
84+ headers = {
85+ "Authorization" : f"Bearer { self .together_api_key } " ,
86+ }
87+
88+ if not validate_json (file = file ):
89+ raise ValueError ("Could not load file: invalid .jsonl file detected." )
90+
91+ session = requests .Session ()
92+ init_endpoint = self .endpoint_url [:- 1 ]
93+
94+ self .logger .debug (
95+ f"Upload file POST request: data={ data } , headers={ headers } , URL={ init_endpoint } , allow_redirects=False"
96+ )
97+
98+ response = session .post (
99+ init_endpoint ,
100+ data = data ,
101+ headers = headers ,
102+ allow_redirects = False ,
103+ )
104+ self .logger .debug (f"Response: { response .text } " )
105+ r2_signed_url = response .headers ["Location" ]
106+ file_id = response .headers ["X-Together-File-Id" ]
107+
108+ self .logger .info (f"R2 Signed URL: { r2_signed_url } " )
109+ self .logger .info ("File-ID" )
110+
111+ self .logger .info ("Uploading file..." )
112+ # print("> Uploading file...")
82113
83- sys .excepthook = exception_handler # send request
84- try :
85114 with open (file , "rb" ) as f :
86- response = requests .post (
87- self .endpoint_url ,
88- headers = headers ,
89- files = {"file" : f },
90- data = data ,
91- )
115+ response = session .put (r2_signed_url , files = {"file" : f })
92116
93- except Exception :
94- print (
95- "ERROR: An exception occurred during file upload, likely due to trying to upload a large file. Please note, that we have a 100MB file upload limit at the moment. Up to 5GB uploads are coming soon, stay tuned!"
117+ self .logger .info ("> File uploaded." )
118+ self .logger .debug (f"status code: { response .status_code } " )
119+ self .logger .info ("> Processing file..." )
120+ preprocess_url = urllib .parse .urljoin (
121+ self .endpoint_url , f"{ file_id } /preprocess"
96122 )
97- sys .exit (0 )
98- # print(response.text)
99- # raise ValueError(f"Error raised by files endpoint: {e}")
100123
101- try :
102- response_json = dict (response .json ())
103- except Exception :
104- raise ValueError (
105- f"JSON Error raised. \n Response status code: { str (response .status_code )} "
124+ response = session .post (
125+ preprocess_url ,
126+ headers = headers ,
106127 )
107128
108- return response_json
129+ self .logger .info ("> File processed" )
130+ self .logger .debug (f"Status code: { response .status_code } " )
131+
132+ except Exception :
133+ self .logger .critical ("Response error raised." )
134+ sys .exit (1 )
135+
136+ return {
137+ "filename" : os .path .basename (file ),
138+ "id" : str (file_id ),
139+ "object" : "file" ,
140+ }
109141
110142 def delete_file (self , file_id : str ) -> Dict [str , str ]:
111143 delete_url = urllib .parse .urljoin (self .endpoint_url , file_id )
@@ -131,6 +163,7 @@ def delete_file(self, file_id: str) -> Dict[str, str]:
131163
132164 def retrieve_file (self , file_id : str ) -> Dict [str , Union [str , int ]]:
133165 retrieve_url = urllib .parse .urljoin (self .endpoint_url , file_id )
166+ print (retrieve_url )
134167
135168 headers = {
136169 "Authorization" : f"Bearer { self .together_api_key } " ,
0 commit comments