add
This commit is contained in:
parent
25e718d9a6
commit
a27b0aa2c6
49
20251014.md
49
20251014.md
@ -74,7 +74,7 @@ from pydantic import BaseModel
|
|||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.sampling_params import GuidedDecodingParams
|
from vllm.sampling_params import GuidedDecodingParams
|
||||||
|
|
||||||
# Guided decoding by JSON using Pydantic schema
|
# 定义结构化输出 schema
|
||||||
class CarType(str, Enum):
|
class CarType(str, Enum):
|
||||||
sedan = "sedan"
|
sedan = "sedan"
|
||||||
suv = "SUV"
|
suv = "SUV"
|
||||||
@ -86,10 +86,11 @@ class CarDescription(BaseModel):
|
|||||||
model: str
|
model: str
|
||||||
car_type: CarType
|
car_type: CarType
|
||||||
|
|
||||||
|
# 获取 JSON schema
|
||||||
json_schema = CarDescription.model_json_schema()
|
json_schema = CarDescription.model_json_schema()
|
||||||
# guided_decoding_params_json = GuidedDecodingParams(json=json_schema)
|
|
||||||
sampling_params_json = SamplingParams(guided_decoding={})
|
# 设置 prompt
|
||||||
prompt_json = (
|
prompt = (
|
||||||
"Generate a JSON with the brand, model and car_type of "
|
"Generate a JSON with the brand, model and car_type of "
|
||||||
"the most iconic car from the 90's"
|
"the most iconic car from the 90's"
|
||||||
)
|
)
|
||||||
@ -97,14 +98,40 @@ prompt_json = (
|
|||||||
def format_output(title: str, output: str):
|
def format_output(title: str, output: str):
|
||||||
print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")
|
print(f"{'-' * 50}\n{title}: {output}\n{'-' * 50}")
|
||||||
|
|
||||||
def generate_output(prompt: str, sampling_params: SamplingParams, llm: LLM):
|
|
||||||
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
|
|
||||||
return outputs[0].outputs[0].text
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
llm = LLM(model="qwen", max_model_len=100)
|
# 1. 初始化本地 LLM,加载本地模型文件
|
||||||
json_output = generate_output(prompt_json, sampling_params_json, llm)
|
llm = LLM(
|
||||||
format_output("Guided decoding by JSON", json_output)
|
model="/home/ss/vllm-py12/qwen3-06b", # 指向你的本地模型路径
|
||||||
|
max_model_len=1024,
|
||||||
|
enable_prefix_caching=True,
|
||||||
|
gpu_memory_utilization=0.9,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 2. 构造一个无效的 guided_decoding:没有任何有效字段
|
||||||
|
# 这将导致 get_structured_output_key() 中 raise ValueError
|
||||||
|
guided_decoding_invalid = GuidedDecodingParams(
|
||||||
|
json=None,
|
||||||
|
json_object=False,
|
||||||
|
regex=None,
|
||||||
|
choice=None,
|
||||||
|
grammar=None,
|
||||||
|
structural_tag=None
|
||||||
|
)
|
||||||
|
|
||||||
|
sampling_params = SamplingParams(
|
||||||
|
temperature=0.0,
|
||||||
|
max_tokens=512,
|
||||||
|
guided_decoding=guided_decoding_invalid # ✅ 传入但无有效字段
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. 生成输出(预期会触发 ValueError)
|
||||||
|
try:
|
||||||
|
outputs = llm.generate(prompts=prompt, sampling_params=sampling_params)
|
||||||
|
for output in outputs:
|
||||||
|
generated_text = output.outputs[0].text
|
||||||
|
format_output("Output", generated_text)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Caught expected error: {e}")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user