lightrag-comments/examples/lightrag_azure_openai_demo.py

126 lines
3.8 KiB
Python
Raw Permalink Normal View History

import os
import asyncio
from lightrag import LightRAG, QueryParam
from lightrag.utils import EmbeddingFunc
import numpy as np
from dotenv import load_dotenv
import aiohttp
import logging
# Show INFO-level log records while the demo runs.
logging.basicConfig(level=logging.INFO)

# Load variables from a .env file (if any) before the environment is read below.
load_dotenv()

# Chat-completions deployment settings. Each value is None when the variable
# is not set; nothing is validated here.
AZURE_OPENAI_API_VERSION = os.environ.get("AZURE_OPENAI_API_VERSION")
AZURE_OPENAI_DEPLOYMENT = os.environ.get("AZURE_OPENAI_DEPLOYMENT")
AZURE_OPENAI_API_KEY = os.environ.get("AZURE_OPENAI_API_KEY")
AZURE_OPENAI_ENDPOINT = os.environ.get("AZURE_OPENAI_ENDPOINT")

# Embedding deployment settings (the API key and endpoint above are shared).
AZURE_EMBEDDING_DEPLOYMENT = os.environ.get("AZURE_EMBEDDING_DEPLOYMENT")
AZURE_EMBEDDING_API_VERSION = os.environ.get("AZURE_EMBEDDING_API_VERSION")

# Directory handed to LightRAG as its working_dir.
WORKING_DIR = "./dickens"

# Start every run from an empty directory: remove whatever a previous run
# left behind, then create it afresh.
if os.path.exists(WORKING_DIR):
    import shutil

    shutil.rmtree(WORKING_DIR)
os.mkdir(WORKING_DIR)
async def llm_model_func(
    prompt, system_prompt=None, history_messages=None, **kwargs
) -> str:
    """Send one chat-completions request to the Azure OpenAI deployment.

    Args:
        prompt: Text sent as the final ``user`` message.
        system_prompt: Optional text sent first, as a ``system`` message.
        history_messages: Optional list of earlier message dicts; they are
            placed, unchanged, between the system message and the prompt.
        **kwargs: ``temperature``, ``top_p`` and ``n`` are forwarded to the
            API (defaults 0, 1 and 1). Any other keyword is ignored.

    Returns:
        The ``content`` of the first choice in the JSON response.

    Raises:
        ValueError: If the HTTP status of the response is not 200.
    """
    headers = {
        "Content-Type": "application/json",
        "api-key": AZURE_OPENAI_API_KEY,
    }

    # Join base URL and path with exactly one slash, so the request works
    # whether or not AZURE_OPENAI_ENDPOINT was configured with a trailing "/".
    base_url = (AZURE_OPENAI_ENDPOINT or "").rstrip("/")
    endpoint = (
        f"{base_url}/openai/deployments/{AZURE_OPENAI_DEPLOYMENT}"
        f"/chat/completions?api-version={AZURE_OPENAI_API_VERSION}"
    )

    # Conversation order: system message, prior history, then the new prompt.
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    if history_messages:
        messages.extend(history_messages)
    messages.append({"role": "user", "content": prompt})

    payload = {
        "messages": messages,
        "temperature": kwargs.get("temperature", 0),
        "top_p": kwargs.get("top_p", 1),
        "n": kwargs.get("n", 1),
    }

    # A fresh HTTP session is opened (and closed) for every call.
    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=headers, json=payload) as response:
            if response.status != 200:
                raise ValueError(
                    f"Request failed with status {response.status}: {await response.text()}"
                )
            result = await response.json()
            return result["choices"][0]["message"]["content"]
async def embedding_func(texts: list[str]) -> np.ndarray:
    """Embed a batch of texts with the Azure OpenAI embedding deployment.

    Args:
        texts: Strings to embed; sent together in a single request.

    Returns:
        A NumPy array built from the ``embedding`` field of every item in
        the response's ``data`` list, in the order the service returned them.

    Raises:
        ValueError: If the HTTP status of the response is not 200.
    """
    headers = {
        "Content-Type": "application/json",
        "api-key": AZURE_OPENAI_API_KEY,
    }

    # Join base URL and path with exactly one slash, so the request works
    # whether or not AZURE_OPENAI_ENDPOINT was configured with a trailing "/".
    base_url = (AZURE_OPENAI_ENDPOINT or "").rstrip("/")
    endpoint = (
        f"{base_url}/openai/deployments/{AZURE_EMBEDDING_DEPLOYMENT}"
        f"/embeddings?api-version={AZURE_EMBEDDING_API_VERSION}"
    )

    payload = {"input": texts}

    # A fresh HTTP session is opened (and closed) for every call.
    async with aiohttp.ClientSession() as session:
        async with session.post(endpoint, headers=headers, json=payload) as response:
            if response.status != 200:
                raise ValueError(
                    f"Request failed with status {response.status}: {await response.text()}"
                )
            result = await response.json()
            embeddings = [item["embedding"] for item in result["data"]]
            return np.array(embeddings)
async def test_funcs():
    """Call both Azure helpers once with a short sample and print the results.

    The printed labels are in Portuguese: "Response from llm_model_func",
    "Result of embedding_func" and "Embedding dimension".
    """
    sample_text = "How are you?"

    llm_reply = await llm_model_func(sample_text)
    print("Resposta do llm_model_func: ", llm_reply)

    vectors = await embedding_func([sample_text])
    dims = vectors.shape
    print("Resultado do embedding_func: ", dims)
    # Second axis of the returned array: the length of one embedding vector.
    print("Dimensão da embedding: ", dims[1])
# Call both Azure helpers once before any indexing starts; an exception
# raised by either one stops the script here.
asyncio.run(test_funcs())

# Vector length declared to LightRAG. NOTE(review): presumably this has to
# match the "Dimensão da embedding" value printed by test_funcs() for the
# configured embedding deployment -- confirm before changing deployments.
embedding_dimension = 3072

rag = LightRAG(
    working_dir=WORKING_DIR,
    llm_model_func=llm_model_func,
    embedding_func=EmbeddingFunc(
        embedding_dim=embedding_dimension,
        max_token_size=8192,
        func=embedding_func,
    ),
)

# Read each book inside a context manager so the file handle is closed as
# soon as its text is in memory, rather than staying open for the whole run.
with open("./book_1.txt", encoding="utf-8") as book1:
    book1_text = book1.read()
with open("./book_2.txt", encoding="utf-8") as book2:
    book2_text = book2.read()

rag.insert([book1_text, book2_text])

query_text = "What are the main themes?"

# Ask the same question once per retrieval mode, printing a heading before
# each answer.
query_runs = (
    ("Result (Naive):", "naive"),
    ("\nResult (Local):", "local"),
    ("\nResult (Global):", "global"),
    ("\nResult (Hybrid):", "hybrid"),
)
for heading, mode in query_runs:
    print(heading)
    print(rag.query(query_text, param=QueryParam(mode=mode)))