sapthesh commited on
Commit
56e7b68
·
verified ·
1 Parent(s): 3c3a8dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -37
app.py CHANGED
@@ -1,50 +1,18 @@
1
  import warnings
2
  import gradio as gr
3
- from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
4
- import torch
5
 
6
  # Suppress the FutureWarning
7
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
8
 
9
- # Load the model and tokenizer
10
- model_name = "deepseek-ai/DeepSeek-V3"
11
- revision = "4c1f24cc10a2a1894304c7ab52edd9710c047571"
12
-
13
- print(f"Loading tokenizer from {model_name}...")
14
- tokenizer = AutoTokenizer.from_pretrained(model_name, revision=revision, trust_remote_code=True)
15
-
16
- print(f"Loading configuration from {model_name}...")
17
- config = AutoConfig.from_pretrained(model_name, revision=revision, trust_remote_code=True)
18
-
19
- # Remove quantization configuration if it exists
20
- if hasattr(config, 'quantization_config'):
21
- del config.quantization_config
22
-
23
- print(f"Loading model from {model_name}...")
24
- model = AutoModelForCausalLM.from_pretrained(model_name, config=config, revision=revision, trust_remote_code=True)
25
-
26
- # Check if the model loaded successfully
27
- if model is None:
28
- print("Failed to load model. Exiting...")
29
- exit(1)
30
- else:
31
- print("Model loaded successfully.")
32
 
33
  # Define the text classification function
34
  def classify_text(text):
35
  try:
36
- # Tokenize the input text
37
- inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
38
- # Pass the inputs to the model
39
- logits = model(**inputs)
40
- # Get the probabilities
41
- probabilities = torch.softmax(logits, dim=-1).tolist()[0]
42
- # Get the predicted class
43
- predicted_class = torch.argmax(logits, dim=-1).item()
44
- return {
45
- "Predicted Class": predicted_class,
46
- "Probabilities": probabilities
47
- }
48
  except Exception as e:
49
  print(f"Error during text classification: {e}")
50
  return {
 
1
  import warnings
2
  import gradio as gr
3
+ from proxy_model import RemoteModelProxy
 
4
 
5
  # Suppress the FutureWarning
6
  warnings.filterwarnings("ignore", category=FutureWarning, module="torch")
7
 
8
+ # Load the model via the proxy
9
+ model_proxy = RemoteModelProxy("deepseek-ai/DeepSeek-V3")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  # Define the text classification function
12
  def classify_text(text):
13
  try:
14
+ result = model_proxy.classify_text(text)
15
+ return result
 
 
 
 
 
 
 
 
 
 
16
  except Exception as e:
17
  print(f"Error during text classification: {e}")
18
  return {