Skip to content

Commit

Permalink
Fix chatglm (#13)
Browse files Browse the repository at this point in the history
* Update chatglm2.py

* Update app.py

* parse_codeblock
  • Loading branch information
ypwhs authored Apr 1, 2023
1 parent 20d71c6 commit a9dbbe6
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 6 deletions.
4 changes: 2 additions & 2 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# model_name = 'BelleGroup/BELLE-LLAMA-7B-2M-gptq'

if 'chatglm' in model_name.lower():
from predictors.chatglm import ChatGLM
from predictors.chatglm2 import ChatGLM
predictor = ChatGLM(model_name)
elif 'gptq' in model_name.lower():
from predictors.llama_gptq import LLaMaGPTQ
Expand All @@ -26,7 +26,7 @@
from predictors.debug import Debug
predictor = Debug(model_name)
else:
from predictors.chatglm import ChatGLM
from predictors.chatglm2 import ChatGLM
predictor = ChatGLM(model_name)


Expand Down
14 changes: 14 additions & 0 deletions predictors/base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,20 @@
from abc import ABC, abstractmethod


def parse_codeblock(text):
    """Render model output as HTML.

    Fenced code blocks (``` markers) become <pre><code> sections — an
    opening fence like ```python turns its language tag into the code
    element's class attribute. All other lines are HTML-escaped and the
    lines are joined with <br/> separators (no <br/> before the first
    line), so the result is a single HTML string.

    Fix over the previous version: the first line is now escaped too.
    Before, only lines with i > 0 had < and > replaced, so raw markup in
    the first line of a response was injected verbatim into the page.
    """
    lines = text.split("\n")
    for i, line in enumerate(lines):
        if "```" in line:
            if line != "```":
                # Opening fence: text after the ``` (e.g. "python") becomes
                # the class, so syntax highlighters can pick it up.
                # NOTE(review): assumes the fence starts the line — [3:]
                # would misbehave on a mid-line ``` ; confirm inputs.
                lines[i] = f'<pre><code class="{lines[i][3:]}">'
            else:
                # Bare ``` closes the current code block.
                lines[i] = '</code></pre>'
        else:
            # Escape HTML on every text line; keep <br/> only as a
            # separator between lines, never before the first one.
            escaped = line.replace("<", "&lt;").replace(">", "&gt;")
            lines[i] = escaped if i == 0 else "<br/>" + escaped
    return "".join(lines)


class BasePredictor(ABC):

@abstractmethod
Expand Down
9 changes: 5 additions & 4 deletions predictors/chatglm.py → predictors/chatglm2.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from transformers import AutoModel, AutoTokenizer
from transformers import LogitsProcessor, LogitsProcessorList

from predictors.base import BasePredictor
from predictors.base import BasePredictor, parse_codeblock
from chatglm.modeling_chatglm import ChatGLMForConditionalGeneration


class InvalidScoreLogitsProcessor(LogitsProcessor):
Expand All @@ -27,7 +28,7 @@ def __init__(self, model_name):
self.tokenizer = AutoTokenizer.from_pretrained(
model_name, trust_remote_code=True, resume_download=True)
if 'int4' not in model_name:
model = AutoModel.from_pretrained(
model = ChatGLMForConditionalGeneration.from_pretrained(
model_name,
trust_remote_code=True,
resume_download=True,
Expand All @@ -36,7 +37,7 @@ def __init__(self, model_name):
device_map={'': self.device}
)
else:
model = AutoModel.from_pretrained(
model = ChatGLMForConditionalGeneration.from_pretrained(
model_name,
trust_remote_code=True,
resume_download=True
Expand Down Expand Up @@ -105,4 +106,4 @@ def stream_chat_continue(self,
outputs = outputs.tolist()[0][input_length:]
response = tokenizer.decode(outputs)
response = model.process_response(response)
yield response
yield parse_codeblock(response)

0 comments on commit a9dbbe6

Please sign in to comment.