{"User":"Summarize this long article: <long article here...>","AI":"This article discusses the impact of..."}
Timestep 3 (T3) - The user sends the prompt "Are there any people mentioned in the article?". The entire chat history is sent to the LLM so that it has the proper context to continue the conversation.
{"User":"Summarize this long article: <long article here...>","AI":"This article discusses the impact of...","User":"Are there any people mentioned in the article?"}
Timestep 4 (T4) - The LLM engine runs inference on the input (all of the tokens from T1, T2, and T3), computing large matrix multiplications between the model weights and the input token embeddings to produce the output tokens: "Yes, the article mentions several key figures, including..."
{"User":"Summarize this long article: <long article here...>","AI":"This article discusses the impact of...","User":"Are there any people mentioned in the article?","AI":"Yes, the article mentions several key figures, including..."}
KV Cache
The KV cache takes advantage of the fact that by the time we reach T3 and ask the LLM about the "people mentioned in the article", we have already performed, in T1 and T2, the very same matrix computations that would otherwise need to be recomputed at T3.
{
# START OF PREVIOUSLY COMPUTED
"User":"Summarize this long article: <long article here...>","AI":"This article discusses the impact of..."
# END OF PREVIOUSLY COMPUTED
"User":"Are there any people mentioned in the article?"}
So, if we save the results of the T1 and T2 computations into a KV cache and give the engine access to that KV cache at T3, the engine only needs to run computation on the new portion of the prompt: "Are there any people mentioned in the article?"
{
KV CACHE,"User":"Are there any people mentioned in the article?"}
def process_prompt(self, prompt_tokens, cache_wrapper, generate_args) -> mx.array:
    """
    This method processes the prompt and adds its tokens to the cache history
    """
    # --snip--
    # prefill cache with prompt_tokens, except those that need to have a repetition penalty applied
    # (repetition penalty not currently possible for cached tokens)
    if "repetition_context_size" not in generate_args:
        generate_args["repetition_context_size"] = (
            20  # default value for mlx_lm.utils.generate_step
        )
    repetition_context_size = generate_args["repetition_context_size"]
    cache_history, generate_step_input = cache_wrapper.update_cache(
        prompt_tokens,
        num_tokens_to_exclude=repetition_context_size
    )
    generate_args["cache_history"] = cache_history
    return generate_step_input
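For intuition, here is a rough, simplified sketch of the job cache_wrapper.update_cache plays in the snippet above. This is not the mlx-engine implementation; the class below is a toy assumption that only illustrates the contract suggested by the surrounding code: tokens already covered by the cache are skipped, the uncached middle is prefilled, and the last num_tokens_to_exclude tokens are handed back as generate_step_input so the repetition penalty can still be applied to them.

class ToyCacheWrapper:
    """Illustrative only; the real cache wrapper manages per-layer K/V state."""

    def __init__(self):
        self.cached_tokens: list[int] = []
        self.cache_history = None            # stand-in for per-layer K/V tensors

    def update_cache(self, prompt_tokens: list[int], num_tokens_to_exclude: int):
        cutoff = len(prompt_tokens) - num_tokens_to_exclude
        # How many leading tokens of the new prompt are already cached?
        common = 0
        while (common < min(len(self.cached_tokens), cutoff)
               and self.cached_tokens[common] == prompt_tokens[common]):
            common += 1
        to_prefill = prompt_tokens[common:cutoff]
        # A real wrapper would run the model over `to_prefill` here (see the
        # chunked loop below) to extend the cached K/V state.
        self.cached_tokens = self.cached_tokens[:common] + to_prefill
        generate_step_input = prompt_tokens[cutoff:]
        return self.cache_history, generate_step_input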
# adapted from https://github.com/ml-explore/mlx-examples/blob/324184d670ec11916a5e92314171d497b312eefe/llms/mlx_lm/cache_prompt.py#L121-L137
step_size = 512
processed: int = 0
while processed < len(tokens_to_process):
    # Here we evaluate the input prompt chunk by chunk to fill the cache
    chunk: mx.array = tokens_to_process[processed:processed+step_size]
    self.model(chunk[None], cache=self.cache)
    mx.eval([c.state for c in self.cache])
    self.tokens: mx.array = mx.concatenate([self.tokens, chunk]) if self.tokens is not None else chunk
    processed += chunk.size
# `process_prompt` function from above
generate_step_input = process_prompt(prompt_tokens, cache_wrapper, generate_args)
max_tokens = generate_args.pop("max_tokens")
for (token, _), n in zip(
    # generate_step_input is now just the uncached repetition penalty tokens
    # generate_args has "cache_history" member, set in `process_prompt`
    mlx_lm.utils.generate_step(generate_step_input, model, **generate_args),
    range(max_tokens),
):